Skip to main content

brainos_cortex/
llm.rs

1//! LLM client — hybrid provider with trait-based adapter.
2//!
3//! `LlmProvider` trait with multiple implementations:
4//! - `OllamaProvider` — local Ollama server
5//! - `OpenAiProvider` — OpenAI compatible APIs
6
7use std::pin::Pin;
8
9use futures::Stream;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
13mod ollama;
14mod openai;
15
16#[cfg(test)]
17mod tests;
18
19pub use ollama::OllamaProvider;
20pub use openai::OpenAiProvider;
21
22mod failover;
23
24// ─── Errors ─────────────────────────────────────────────────────────────────
25
26/// Errors from the LLM layer.
27#[derive(Debug, Error)]
28pub enum LlmError {
29    #[error("HTTP request failed: {0}")]
30    Http(#[from] reqwest::Error),
31
32    #[error("API error: {status} - {message}")]
33    Api { status: u16, message: String },
34
35    #[error("Stream error: {0}")]
36    Stream(String),
37
38    #[error("Invalid response format: {0}")]
39    InvalidFormat(String),
40
41    #[error("Provider not available: {0}")]
42    ProviderUnavailable(String),
43
44    #[error("Rate limited")]
45    RateLimited,
46
47    #[error("Timeout")]
48    Timeout,
49}
50
51// ─── Types ──────────────────────────────────────────────────────────────────
52
53/// A message in the conversation.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct Message {
56    pub role: Role,
57    pub content: String,
58}
59
60/// Message roles.
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
62#[serde(rename_all = "lowercase")]
63pub enum Role {
64    System,
65    User,
66    Assistant,
67}
68
/// One piece of a streamed LLM response.
#[derive(Debug, Clone)]
pub struct ResponseChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Completion flag; presumably `true` on the final chunk of a stream —
    /// confirm against the provider implementations.
    pub is_done: bool,
}
75
76/// Complete LLM response.
77#[derive(Debug, Clone)]
78pub struct Response {
79    pub content: String,
80    pub usage: Option<Usage>,
81}
82
/// Token counts attached to a [`Response`].
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total token count; presumably prompt + completion, but nothing in this
    /// type enforces that.
    pub total_tokens: u32,
}
90
91// ─── Provider Trait ─────────────────────────────────────────────────────────
92
93/// Trait for LLM providers.
94#[async_trait::async_trait]
95pub trait LlmProvider: Send + Sync {
96    /// Generate a complete response (non-streaming).
97    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;
98
99    /// Generate a streaming response.
100    async fn generate_stream(
101        &self,
102        messages: &[Message],
103    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;
104
105    /// Check if the provider is available.
106    async fn health_check(&self) -> bool;
107
108    /// Get the provider name.
109    fn name(&self) -> &str;
110
111    /// Get the active model name.
112    fn model(&self) -> &str;
113
114    /// List models available from this provider. Used by `select_provider`
115    /// to probe reachability and match `preferred_models` during startup.
116    async fn list_models(&self) -> Result<Vec<String>, LlmError>;
117}
118
119// ─── Provider Factory ───────────────────────────────────────────────────────
120
/// Settings consumed by [`create_provider`] to build a single provider.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Provider kind: `"ollama"`, `"openai_compat"`, or a preset name.
    pub provider: String,
    /// Endpoint base URL; for OpenAI-compatible kinds an empty string defers
    /// to the preset's URL.
    pub base_url: String,
    /// API key, if one is configured.
    pub api_key: Option<String>,
    /// Model name handed to the provider.
    pub model: String,
    /// Sampling temperature handed to the provider.
    pub temperature: f64,
    /// Token limit handed to the provider.
    pub max_tokens: i32,
}
131
132impl Default for ProviderConfig {
133    fn default() -> Self {
134        Self {
135            provider: "ollama".to_string(),
136            base_url: "http://localhost:11434".to_string(),
137            api_key: None,
138            model: "qwen2.5-coder:7b".to_string(),
139            temperature: 0.7,
140            max_tokens: 4096,
141        }
142    }
143}
144
145/// Create an LLM provider from configuration.
146///
147/// Resolution order:
148/// 1. `ollama` → `OllamaProvider`.
149/// 2. `openai_compat` (or a built-in preset: openai, openrouter, groq,
150///    deepseek, together, gemini-compat) → OpenAI-compatible provider.
151///    An explicit non-empty `base_url` overrides the preset default.
152/// 3. Unknown provider → fall back to default Ollama with a warning.
153pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
154    if config.provider == "ollama" {
155        let provider = OllamaProvider::new(
156            &config.base_url,
157            &config.model,
158            config.temperature,
159            config.max_tokens,
160        )
161        .or_else(|e| {
162            tracing::error!(error = %e, "Failed to create Ollama provider, falling back to default");
163            OllamaProvider::default_config()
164        })?;
165        return Ok(Box::new(provider));
166    }
167
168    let preset_base = crate::presets::resolve(&config.provider).map(|p| p.base_url);
169
170    if config.provider == "openai_compat" || preset_base.is_some() {
171        let base_url = if !config.base_url.is_empty() {
172            config.base_url.as_str()
173        } else if let Some(b) = preset_base {
174            b
175        } else {
176            return Err(LlmError::ProviderUnavailable(format!(
177                "provider `{}` has no base_url configured",
178                config.provider
179            )));
180        };
181        return Ok(Box::new(OpenAiProvider::new(
182            base_url,
183            config.api_key.as_deref(),
184            &config.model,
185            config.temperature,
186            Some(config.max_tokens),
187        )?));
188    }
189
190    tracing::warn!(
191        provider = %config.provider,
192        "Unknown LLM provider, falling back to default Ollama"
193    );
194    Ok(Box::new(OllamaProvider::default_config()?))
195}
196
197// ─── Multi-provider selection ───────────────────────────────────────────────
198
199/// Build a `ProviderConfig` from a `brain_core::ProviderEntry` and shared
200/// temperature/max_tokens. `model_override` lets `select_provider` swap in
201/// a preferred model discovered via `list_models`.
202fn provider_config_from_entry(
203    entry: &brain_core::ProviderEntry,
204    temperature: f64,
205    max_tokens: i32,
206    model_override: Option<&str>,
207) -> ProviderConfig {
208    let api_key = entry.api_key.trim();
209    ProviderConfig {
210        provider: entry.kind.clone(),
211        base_url: entry.base_url.clone(),
212        api_key: if api_key.is_empty() {
213            None
214        } else {
215            Some(api_key.to_string())
216        },
217        model: model_override.unwrap_or(&entry.model).to_string(),
218        temperature,
219        max_tokens,
220    }
221}
222
223/// Probe every configured provider, pick the first reachable one whose
224/// `preferred_models` intersects the live model list, and return it.
225///
226/// When `llm.providers` is empty we synthesise a single entry from the
227/// legacy `llm.provider`/`model`/`base_url`/`api_key` fields — so existing
228/// configs keep working unchanged.
229///
230/// Fail-safe: if no provider answers `list_models`, we still return the
231/// first entry as a best effort rather than erroring out (the underlying
232/// generate call will surface the real problem when used).
233pub async fn select_provider(
234    llm: &brain_core::LlmConfig,
235) -> Result<Box<dyn LlmProvider>, LlmError> {
236    let entries = synthesise_entries(llm);
237    let max_tokens = llm.max_tokens as i32;
238
239    if entries.is_empty() {
240        return Err(LlmError::ProviderUnavailable(
241            "no LLM providers configured".into(),
242        ));
243    }
244
245    for entry in &entries {
246        let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
247        let probe = match create_provider(&cfg) {
248            Ok(p) => p,
249            Err(e) => {
250                tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
251                continue;
252            }
253        };
254
255        match probe.list_models().await {
256            Ok(models) => {
257                let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
258                tracing::info!(
259                    name = %entry.name,
260                    kind = %entry.kind,
261                    model = %chosen,
262                    "LLM provider selected"
263                );
264                let cfg =
265                    provider_config_from_entry(entry, llm.temperature, max_tokens, Some(&chosen));
266                return create_provider(&cfg);
267            }
268            Err(e) => {
269                tracing::warn!(
270                    name = %entry.name,
271                    error = %e,
272                    "provider unreachable — trying next"
273                );
274            }
275        }
276    }
277
278    // All probes failed — fall back to the first entry so startup continues
279    // and the caller surfaces the real failure on first generate().
280    let first = &entries[0];
281    tracing::warn!(
282        name = %first.name,
283        "no provider answered list_models — falling back to first entry"
284    );
285    let cfg = provider_config_from_entry(first, llm.temperature, max_tokens, None);
286    create_provider(&cfg)
287}
288
289/// Build a failover chain from all configured providers.
290///
291/// The chain is ordered: the startup-probed winner goes first; the remaining
292/// entries (built without probing) follow as fallbacks. At request time the
293/// chain tries each in order whenever the current provider returns a retriable
294/// error (429 / 5xx / unavailable / timeout).
295pub async fn build_failover_chain(
296    llm: &brain_core::LlmConfig,
297) -> Result<failover::FalloverProvider, LlmError> {
298    let entries = synthesise_entries(llm);
299    let max_tokens = llm.max_tokens as i32;
300
301    if entries.is_empty() {
302        return Err(LlmError::ProviderUnavailable(
303            "no LLM providers configured".into(),
304        ));
305    }
306
307    // Find the primary via probing (same logic as select_provider).
308    let mut primary_idx = None;
309    for (i, entry) in entries.iter().enumerate() {
310        let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
311        let probe = match create_provider(&cfg) {
312            Ok(p) => p,
313            Err(e) => {
314                tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
315                continue;
316            }
317        };
318        match probe.list_models().await {
319            Ok(models) => {
320                let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
321                tracing::info!(
322                    name = %entry.name,
323                    kind = %entry.kind,
324                    model = %chosen,
325                    "LLM provider selected"
326                );
327                primary_idx = Some((i, chosen));
328                break;
329            }
330            Err(e) => {
331                tracing::warn!(name = %entry.name, error = %e, "provider unreachable — trying next");
332            }
333        }
334    }
335
336    // If no probe succeeded, fall back to index 0 (best-effort).
337    let (primary_i, model_override) = primary_idx.unwrap_or_else(|| {
338        tracing::warn!("no provider answered list_models — using first entry as primary");
339        (0, entries[0].model.clone())
340    });
341
342    // Build all providers: primary first, rest appended in config order.
343    let mut providers: Vec<Box<dyn LlmProvider>> = Vec::with_capacity(entries.len());
344    let primary_cfg = provider_config_from_entry(
345        &entries[primary_i],
346        llm.temperature,
347        max_tokens,
348        Some(&model_override),
349    );
350    providers.push(create_provider(&primary_cfg)?);
351
352    for (i, entry) in entries.iter().enumerate() {
353        if i == primary_i {
354            continue;
355        }
356        let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
357        match create_provider(&cfg) {
358            Ok(p) => {
359                tracing::info!(name = %entry.name, "registered as fallback provider");
360                providers.push(p);
361            }
362            Err(e) => {
363                tracing::warn!(name = %entry.name, error = %e, "fallback provider construction failed — skipping");
364            }
365        }
366    }
367
368    Ok(failover::FalloverProvider::new(providers))
369}
370
371fn synthesise_entries(llm: &brain_core::LlmConfig) -> Vec<brain_core::ProviderEntry> {
372    if !llm.providers.is_empty() {
373        return llm.providers.clone();
374    }
375    vec![brain_core::ProviderEntry {
376        name: "default".to_string(),
377        kind: llm.provider.clone(),
378        base_url: llm.base_url.clone(),
379        api_key: llm.api_key.clone(),
380        model: llm.model.clone(),
381        preferred_models: Vec::new(),
382    }]
383}
384
/// Choose the model to run.
///
/// Walks `preferred` in order and returns the first name that also occurs in
/// `available` (exact string match). If nothing matches, `fallback` is
/// returned.
fn pick_model(preferred: &[String], available: &[String], fallback: &str) -> String {
    preferred
        .iter()
        .find(|candidate| available.contains(*candidate))
        .cloned()
        .unwrap_or_else(|| fallback.to_string())
}
393
394/// Extract a JSON object from an LLM response string.
395///
396/// LLMs sometimes wrap JSON in markdown fences or explanatory text.
397/// This tries direct parse first, then finds the outermost `{...}`.
398pub fn extract_json_from_response<T: serde::de::DeserializeOwned>(raw: &str) -> Option<T> {
399    let trimmed = raw.trim();
400    if let Ok(parsed) = serde_json::from_str::<T>(trimmed) {
401        return Some(parsed);
402    }
403    let start = trimmed.find('{')?;
404    let end = trimmed.rfind('}')?;
405    serde_json::from_str::<T>(&trimmed[start..=end]).ok()
406}