1use std::pin::Pin;
8
9use futures::Stream;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
13mod ollama;
14mod openai;
15
16#[cfg(test)]
17mod tests;
18
19pub use ollama::OllamaProvider;
20pub use openai::OpenAiProvider;
21
22mod failover;
23
/// Errors produced by LLM providers and by the provider-selection helpers in
/// this module.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Transport-level failure from `reqwest`; converted automatically via the
    /// `#[from]` attribute, so `?` works on `reqwest::Error`.
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),

    /// An error reported by the provider's API. `status` is presumably the
    /// HTTP status code — NOTE(review): confirm against provider implementations.
    #[error("API error: {status} - {message}")]
    Api { status: u16, message: String },

    /// An error raised while consuming a streamed response.
    #[error("Stream error: {0}")]
    Stream(String),

    /// The provider's response could not be interpreted.
    #[error("Invalid response format: {0}")]
    InvalidFormat(String),

    /// No usable provider. Also returned by `create_provider`,
    /// `select_provider` and `build_failover_chain` for configuration problems
    /// (e.g. no providers configured, missing base URL).
    #[error("Provider not available: {0}")]
    ProviderUnavailable(String),

    /// The provider reported rate limiting.
    #[error("Rate limited")]
    RateLimited,

    /// The request timed out.
    #[error("Timeout")]
    Timeout,
}
50
/// A single chat message passed to [`LlmProvider::generate`] and
/// [`LlmProvider::generate_stream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
62#[serde(rename_all = "lowercase")]
63pub enum Role {
64 System,
65 User,
66 Assistant,
67}
68
/// One item of the stream returned by [`LlmProvider::generate_stream`].
#[derive(Debug, Clone)]
pub struct ResponseChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Presumably `true` on the final chunk of the stream — NOTE(review):
    /// confirm against the provider implementations.
    pub is_done: bool,
}
75
/// A complete, non-streamed result of [`LlmProvider::generate`].
#[derive(Debug, Clone)]
pub struct Response {
    /// Generated text.
    pub content: String,
    /// Token accounting, when the provider supplies it.
    pub usage: Option<Usage>,
}
82
/// Token counts attached to a [`Response`].
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported; not computed or validated here.
    pub total_tokens: u32,
}
90
/// Common interface implemented by every LLM backend (for example
/// [`OllamaProvider`] and [`OpenAiProvider`], which `create_provider` boxes
/// as `dyn LlmProvider`).
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Generates a complete, non-streamed response for `messages`.
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;

    /// Starts a streamed generation for `messages`.
    ///
    /// The outer `Result` reports a failure to start the stream. Each item of
    /// the returned stream is either a [`ResponseChunk`] or an [`LlmError`]
    /// raised mid-stream.
    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;

    /// Returns `true` if the backend appears healthy; the exact check is
    /// implementation-defined.
    async fn health_check(&self) -> bool;

    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Model identifier this provider is configured with.
    fn model(&self) -> &str;

    /// Lists the model names the backend reports as available.
    ///
    /// `select_provider` and `build_failover_chain` also use this call as a
    /// reachability probe.
    async fn list_models(&self) -> Result<Vec<String>, LlmError>;
}
118
/// Settings consumed by [`create_provider`] to build a concrete [`LlmProvider`].
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Provider kind. `create_provider` recognises `"ollama"`, `"openai_compat"`
    /// and any name resolved by `crate::presets::resolve`; anything else falls
    /// back to the default Ollama provider.
    pub provider: String,
    /// Base URL of the provider API. For OpenAI-compatible kinds an empty value
    /// means "use the preset's base URL".
    pub base_url: String,
    /// Optional API key; forwarded only to OpenAI-compatible providers.
    pub api_key: Option<String>,
    /// Model identifier requested from the provider.
    pub model: String,
    /// Temperature passed to the provider constructor.
    pub temperature: f64,
    /// Token limit passed to the provider constructor.
    pub max_tokens: i32,
}
131
132impl Default for ProviderConfig {
133 fn default() -> Self {
134 Self {
135 provider: "ollama".to_string(),
136 base_url: "http://localhost:11434".to_string(),
137 api_key: None,
138 model: "qwen2.5-coder:7b".to_string(),
139 temperature: 0.7,
140 max_tokens: 4096,
141 }
142 }
143}
144
145pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
154 if config.provider == "ollama" {
155 let provider = OllamaProvider::new(
156 &config.base_url,
157 &config.model,
158 config.temperature,
159 config.max_tokens,
160 )
161 .or_else(|e| {
162 tracing::error!(error = %e, "Failed to create Ollama provider, falling back to default");
163 OllamaProvider::default_config()
164 })?;
165 return Ok(Box::new(provider));
166 }
167
168 let preset_base = crate::presets::resolve(&config.provider).map(|p| p.base_url);
169
170 if config.provider == "openai_compat" || preset_base.is_some() {
171 let base_url = if !config.base_url.is_empty() {
172 config.base_url.as_str()
173 } else if let Some(b) = preset_base {
174 b
175 } else {
176 return Err(LlmError::ProviderUnavailable(format!(
177 "provider `{}` has no base_url configured",
178 config.provider
179 )));
180 };
181 return Ok(Box::new(OpenAiProvider::new(
182 base_url,
183 config.api_key.as_deref(),
184 &config.model,
185 config.temperature,
186 Some(config.max_tokens),
187 )?));
188 }
189
190 tracing::warn!(
191 provider = %config.provider,
192 "Unknown LLM provider, falling back to default Ollama"
193 );
194 Ok(Box::new(OllamaProvider::default_config()?))
195}
196
197fn provider_config_from_entry(
203 entry: &brain_core::ProviderEntry,
204 temperature: f64,
205 max_tokens: i32,
206 model_override: Option<&str>,
207) -> ProviderConfig {
208 let api_key = entry.api_key.trim();
209 ProviderConfig {
210 provider: entry.kind.clone(),
211 base_url: entry.base_url.clone(),
212 api_key: if api_key.is_empty() {
213 None
214 } else {
215 Some(api_key.to_string())
216 },
217 model: model_override.unwrap_or(&entry.model).to_string(),
218 temperature,
219 max_tokens,
220 }
221}
222
223pub async fn select_provider(
234 llm: &brain_core::LlmConfig,
235) -> Result<Box<dyn LlmProvider>, LlmError> {
236 let entries = synthesise_entries(llm);
237 let max_tokens = llm.max_tokens as i32;
238
239 if entries.is_empty() {
240 return Err(LlmError::ProviderUnavailable(
241 "no LLM providers configured".into(),
242 ));
243 }
244
245 for entry in &entries {
246 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
247 let probe = match create_provider(&cfg) {
248 Ok(p) => p,
249 Err(e) => {
250 tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
251 continue;
252 }
253 };
254
255 match probe.list_models().await {
256 Ok(models) => {
257 let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
258 tracing::info!(
259 name = %entry.name,
260 kind = %entry.kind,
261 model = %chosen,
262 "LLM provider selected"
263 );
264 let cfg =
265 provider_config_from_entry(entry, llm.temperature, max_tokens, Some(&chosen));
266 return create_provider(&cfg);
267 }
268 Err(e) => {
269 tracing::warn!(
270 name = %entry.name,
271 error = %e,
272 "provider unreachable — trying next"
273 );
274 }
275 }
276 }
277
278 let first = &entries[0];
281 tracing::warn!(
282 name = %first.name,
283 "no provider answered list_models — falling back to first entry"
284 );
285 let cfg = provider_config_from_entry(first, llm.temperature, max_tokens, None);
286 create_provider(&cfg)
287}
288
289pub async fn build_failover_chain(
296 llm: &brain_core::LlmConfig,
297) -> Result<failover::FalloverProvider, LlmError> {
298 let entries = synthesise_entries(llm);
299 let max_tokens = llm.max_tokens as i32;
300
301 if entries.is_empty() {
302 return Err(LlmError::ProviderUnavailable(
303 "no LLM providers configured".into(),
304 ));
305 }
306
307 let mut primary_idx = None;
309 for (i, entry) in entries.iter().enumerate() {
310 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
311 let probe = match create_provider(&cfg) {
312 Ok(p) => p,
313 Err(e) => {
314 tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
315 continue;
316 }
317 };
318 match probe.list_models().await {
319 Ok(models) => {
320 let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
321 tracing::info!(
322 name = %entry.name,
323 kind = %entry.kind,
324 model = %chosen,
325 "LLM provider selected"
326 );
327 primary_idx = Some((i, chosen));
328 break;
329 }
330 Err(e) => {
331 tracing::warn!(name = %entry.name, error = %e, "provider unreachable — trying next");
332 }
333 }
334 }
335
336 let (primary_i, model_override) = primary_idx.unwrap_or_else(|| {
338 tracing::warn!("no provider answered list_models — using first entry as primary");
339 (0, entries[0].model.clone())
340 });
341
342 let mut providers: Vec<Box<dyn LlmProvider>> = Vec::with_capacity(entries.len());
344 let primary_cfg = provider_config_from_entry(
345 &entries[primary_i],
346 llm.temperature,
347 max_tokens,
348 Some(&model_override),
349 );
350 providers.push(create_provider(&primary_cfg)?);
351
352 for (i, entry) in entries.iter().enumerate() {
353 if i == primary_i {
354 continue;
355 }
356 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
357 match create_provider(&cfg) {
358 Ok(p) => {
359 tracing::info!(name = %entry.name, "registered as fallback provider");
360 providers.push(p);
361 }
362 Err(e) => {
363 tracing::warn!(name = %entry.name, error = %e, "fallback provider construction failed — skipping");
364 }
365 }
366 }
367
368 Ok(failover::FalloverProvider::new(providers))
369}
370
371fn synthesise_entries(llm: &brain_core::LlmConfig) -> Vec<brain_core::ProviderEntry> {
372 if !llm.providers.is_empty() {
373 return llm.providers.clone();
374 }
375 vec![brain_core::ProviderEntry {
376 name: "default".to_string(),
377 kind: llm.provider.clone(),
378 base_url: llm.base_url.clone(),
379 api_key: llm.api_key.clone(),
380 model: llm.model.clone(),
381 preferred_models: Vec::new(),
382 }]
383}
384
/// Chooses which model to request from a provider.
///
/// Returns the first entry of `preferred`, in order, that also appears in
/// `available`. Returns `fallback` when no preferred model is available, or
/// when `preferred` is empty.
fn pick_model(preferred: &[String], available: &[String], fallback: &str) -> String {
    preferred
        .iter()
        .find(|&candidate| available.contains(candidate))
        .cloned()
        .unwrap_or_else(|| fallback.to_string())
}
393
394pub fn extract_json_from_response<T: serde::de::DeserializeOwned>(raw: &str) -> Option<T> {
399 let trimmed = raw.trim();
400 if let Ok(parsed) = serde_json::from_str::<T>(trimmed) {
401 return Some(parsed);
402 }
403 let start = trimmed.find('{')?;
404 let end = trimmed.rfind('}')?;
405 serde_json::from_str::<T>(&trimmed[start..=end]).ok()
406}