1use std::pin::Pin;
8
9use futures::Stream;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
13mod ollama;
14mod openai;
15
16#[cfg(test)]
17mod tests;
18
19pub use ollama::OllamaProvider;
20pub use openai::OpenAiProvider;
21
22mod failover;
23
24#[derive(Debug, Error)]
28pub enum LlmError {
29 #[error("HTTP request failed: {0}")]
30 Http(#[from] reqwest::Error),
31
32 #[error("API error: {status} - {message}")]
33 Api { status: u16, message: String },
34
35 #[error("Stream error: {0}")]
36 Stream(String),
37
38 #[error("Invalid response format: {0}")]
39 InvalidFormat(String),
40
41 #[error("Provider not available: {0}")]
42 ProviderUnavailable(String),
43
44 #[error("Rate limited")]
45 RateLimited,
46
47 #[error("Timeout")]
48 Timeout,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, Default)]
60pub struct Message {
61 pub role: Role,
62 pub content: String,
63 #[serde(default, skip_serializing_if = "Vec::is_empty")]
67 pub tool_calls: Vec<ProposedToolCall>,
68 #[serde(default, skip_serializing_if = "Option::is_none")]
72 pub tool_call_id: Option<String>,
73}
74
75impl Message {
76 pub fn system(content: impl Into<String>) -> Self {
78 Self::plain(Role::System, content)
79 }
80
81 pub fn user(content: impl Into<String>) -> Self {
83 Self::plain(Role::User, content)
84 }
85
86 pub fn assistant(content: impl Into<String>) -> Self {
88 Self::plain(Role::Assistant, content)
89 }
90
91 pub fn assistant_with_tool_calls(
94 content: impl Into<String>,
95 tool_calls: Vec<ProposedToolCall>,
96 ) -> Self {
97 Self {
98 role: Role::Assistant,
99 content: content.into(),
100 tool_calls,
101 tool_call_id: None,
102 }
103 }
104
105 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
107 Self {
108 role: Role::Tool,
109 content: content.into(),
110 tool_calls: Vec::new(),
111 tool_call_id: Some(tool_call_id.into()),
112 }
113 }
114
115 fn plain(role: Role, content: impl Into<String>) -> Self {
116 Self {
117 role,
118 content: content.into(),
119 tool_calls: Vec::new(),
120 tool_call_id: None,
121 }
122 }
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
127#[serde(rename_all = "lowercase")]
128pub enum Role {
129 System,
130 #[default]
131 User,
132 Assistant,
133 Tool,
136}
137
138impl Role {
139 pub fn as_wire_str(&self) -> &'static str {
143 match self {
144 Role::System => "system",
145 Role::User => "user",
146 Role::Assistant => "assistant",
147 Role::Tool => "tool",
148 }
149 }
150}
151
152pub(crate) fn build_http_client(timeout: std::time::Duration) -> Result<reqwest::Client, LlmError> {
157 reqwest::Client::builder()
158 .timeout(timeout)
159 .build()
160 .map_err(|e| LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}")))
161}
162
163pub(crate) async fn ensure_ok(resp: reqwest::Response) -> Result<reqwest::Response, LlmError> {
167 if resp.status().is_success() {
168 return Ok(resp);
169 }
170 let status = resp.status();
171 let body = resp.text().await.unwrap_or_default();
172 Err(LlmError::Api {
173 status: status.as_u16(),
174 message: body,
175 })
176}
177
/// One increment of a streaming response.
#[derive(Debug, Clone)]
pub struct ResponseChunk {
    /// Text carried by this chunk (may be empty).
    pub content: String,
    /// Set on the chunk that ends the stream — presumably the last item; confirm
    /// per provider implementation.
    pub is_done: bool,
}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ToolDef {
193 pub name: String,
194 pub description: String,
195 pub parameters: serde_json::Value,
196}
197
198#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
202pub struct ProposedToolCall {
203 #[serde(default, skip_serializing_if = "Option::is_none")]
205 pub id: Option<String>,
206 pub name: String,
207 pub arguments: serde_json::Value,
210}
211
212#[derive(Debug, Clone, Default)]
214pub struct Response {
215 pub content: String,
216 pub usage: Option<Usage>,
217 pub tool_calls: Vec<ProposedToolCall>,
220}
221
222impl Response {
223 pub fn text(content: impl Into<String>, usage: Option<Usage>) -> Self {
227 Self {
228 content: content.into(),
229 usage,
230 tool_calls: Vec::new(),
231 }
232 }
233}
234
/// Token accounting reported by a backend for one request.
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total as reported by the backend (not recomputed here).
    pub total_tokens: u32,
}
242
243#[async_trait::async_trait]
247pub trait LlmProvider: Send + Sync {
248 async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;
250
251 async fn generate_with_tools(
258 &self,
259 messages: &[Message],
260 tools: &[ToolDef],
261 ) -> Result<Response, LlmError> {
262 let _ = tools;
263 self.generate(messages).await
264 }
265
266 async fn generate_stream(
268 &self,
269 messages: &[Message],
270 ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;
271
272 async fn health_check(&self) -> bool;
274
275 fn name(&self) -> &str;
277
278 fn model(&self) -> &str;
280
281 async fn list_models(&self) -> Result<Vec<String>, LlmError>;
284
285 async fn fetch_context_window(&self) -> Option<usize> {
291 known_context_window(self.model())
292 }
293}
294
/// Flat, backend-agnostic settings consumed by [`create_provider`].
///
/// `Debug` is implemented by hand so the secret `api_key` never reaches logs: it
/// prints as `Some("<redacted>")` or `None` instead of the key itself.
#[derive(Clone)]
pub struct ProviderConfig {
    /// Backend selector: `"ollama"`, `"openai_compat"`, or a preset name
    /// understood by `crate::presets::resolve`.
    pub provider: String,
    /// Endpoint root URL; may be empty when a preset supplies its own.
    pub base_url: String,
    /// API key, if any. Redacted in `Debug` output.
    pub api_key: Option<String>,
    /// Model identifier sent to the backend.
    pub model: String,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f64,
    /// Token limit forwarded to the provider (presumably the completion cap).
    pub max_tokens: i32,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("provider", &self.provider)
            .field("base_url", &self.base_url)
            // Reveal only whether a key is present, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
307
308impl Default for ProviderConfig {
309 fn default() -> Self {
310 Self {
311 provider: "ollama".to_string(),
312 base_url: "http://localhost:11434".to_string(),
313 api_key: None,
314 model: "qwen2.5-coder:7b".to_string(),
315 temperature: 0.7,
316 max_tokens: 4096,
317 }
318 }
319}
320
321pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
330 if config.provider == "ollama" {
331 let provider = OllamaProvider::new(
332 &config.base_url,
333 &config.model,
334 config.temperature,
335 config.max_tokens,
336 )
337 .or_else(|e| {
338 tracing::error!(error = %e, "Failed to create Ollama provider, falling back to default");
339 OllamaProvider::default_config()
340 })?;
341 return Ok(Box::new(provider));
342 }
343
344 let preset_base = crate::presets::resolve(&config.provider).map(|p| p.base_url);
345
346 if config.provider == "openai_compat" || preset_base.is_some() {
347 let base_url = if !config.base_url.is_empty() {
348 config.base_url.as_str()
349 } else if let Some(b) = preset_base {
350 b
351 } else {
352 return Err(LlmError::ProviderUnavailable(format!(
353 "provider `{}` has no base_url configured",
354 config.provider
355 )));
356 };
357 return Ok(Box::new(OpenAiProvider::new(
358 base_url,
359 config.api_key.as_deref(),
360 &config.model,
361 config.temperature,
362 Some(config.max_tokens),
363 )?));
364 }
365
366 tracing::warn!(
367 provider = %config.provider,
368 "Unknown LLM provider, falling back to default Ollama"
369 );
370 Ok(Box::new(OllamaProvider::default_config()?))
371}
372
373fn provider_config_from_entry(
379 entry: &brain::ProviderEntry,
380 temperature: f64,
381 max_tokens: i32,
382 model_override: Option<&str>,
383) -> ProviderConfig {
384 let api_key = match entry.api_key_file.as_ref() {
390 Some(path) => match std::fs::read_to_string(path) {
391 Ok(raw) => {
392 let trimmed = raw.trim().to_string();
393 if trimmed.is_empty() {
394 tracing::warn!(
395 provider = %entry.name,
396 path = %path.display(),
397 "llm.providers[].api_key_file is empty; falling back to inline api_key"
398 );
399 entry.api_key.trim().to_string()
400 } else {
401 trimmed
402 }
403 }
404 Err(e) => {
405 tracing::warn!(
406 provider = %entry.name,
407 path = %path.display(),
408 error = %e,
409 "llm.providers[].api_key_file unreadable; falling back to inline api_key"
410 );
411 entry.api_key.trim().to_string()
412 }
413 },
414 None => entry.api_key.trim().to_string(),
415 };
416 ProviderConfig {
417 provider: entry.kind.clone(),
418 base_url: entry.base_url.clone(),
419 api_key: if api_key.is_empty() {
420 None
421 } else {
422 Some(api_key)
423 },
424 model: model_override.unwrap_or(&entry.model).to_string(),
425 temperature,
426 max_tokens,
427 }
428}
429
430pub async fn select_provider(llm: &brain::LlmConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
441 let entries = synthesise_entries(llm);
442 let max_tokens = llm.max_tokens as i32;
443
444 if entries.is_empty() {
445 return Err(LlmError::ProviderUnavailable(
446 "no LLM providers configured".into(),
447 ));
448 }
449
450 for entry in &entries {
451 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
452 let probe = match create_provider(&cfg) {
453 Ok(p) => p,
454 Err(e) => {
455 tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
456 continue;
457 }
458 };
459
460 match probe.list_models().await {
461 Ok(models) => {
462 let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
463 tracing::info!(
464 name = %entry.name,
465 kind = %entry.kind,
466 model = %chosen,
467 "LLM provider selected"
468 );
469 let cfg =
470 provider_config_from_entry(entry, llm.temperature, max_tokens, Some(&chosen));
471 return create_provider(&cfg);
472 }
473 Err(e) => {
474 tracing::warn!(
475 name = %entry.name,
476 error = %e,
477 "provider unreachable — trying next"
478 );
479 }
480 }
481 }
482
483 let first = &entries[0];
486 tracing::warn!(
487 name = %first.name,
488 "no provider answered list_models — falling back to first entry"
489 );
490 let cfg = provider_config_from_entry(first, llm.temperature, max_tokens, None);
491 create_provider(&cfg)
492}
493
494pub async fn build_failover_chain(
501 llm: &brain::LlmConfig,
502) -> Result<failover::FailoverProvider, LlmError> {
503 let entries = synthesise_entries(llm);
504 let max_tokens = llm.max_tokens as i32;
505
506 if entries.is_empty() {
507 return Err(LlmError::ProviderUnavailable(
508 "no LLM providers configured".into(),
509 ));
510 }
511
512 let mut primary_idx = None;
514 for (i, entry) in entries.iter().enumerate() {
515 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
516 let probe = match create_provider(&cfg) {
517 Ok(p) => p,
518 Err(e) => {
519 tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
520 continue;
521 }
522 };
523 match probe.list_models().await {
524 Ok(models) => {
525 let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
526 tracing::info!(
527 name = %entry.name,
528 kind = %entry.kind,
529 model = %chosen,
530 "LLM provider selected"
531 );
532 primary_idx = Some((i, chosen));
533 break;
534 }
535 Err(e) => {
536 tracing::warn!(name = %entry.name, error = %e, "provider unreachable — trying next");
537 }
538 }
539 }
540
541 let (primary_i, model_override) = primary_idx.unwrap_or_else(|| {
543 tracing::warn!("no provider answered list_models — using first entry as primary");
544 (0, entries[0].model.clone())
545 });
546
547 let mut providers: Vec<Box<dyn LlmProvider>> = Vec::with_capacity(entries.len());
549 let primary_cfg = provider_config_from_entry(
550 &entries[primary_i],
551 llm.temperature,
552 max_tokens,
553 Some(&model_override),
554 );
555 providers.push(create_provider(&primary_cfg)?);
556
557 for (i, entry) in entries.iter().enumerate() {
558 if i == primary_i {
559 continue;
560 }
561 let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
562 match create_provider(&cfg) {
563 Ok(p) => {
564 tracing::info!(name = %entry.name, "registered as fallback provider");
565 providers.push(p);
566 }
567 Err(e) => {
568 tracing::warn!(name = %entry.name, error = %e, "fallback provider construction failed — skipping");
569 }
570 }
571 }
572
573 Ok(failover::FailoverProvider::new(providers))
574}
575
576fn synthesise_entries(llm: &brain::LlmConfig) -> Vec<brain::ProviderEntry> {
577 if !llm.providers.is_empty() {
578 return llm.providers.clone();
579 }
580 #[allow(deprecated)]
586 let entry = brain::ProviderEntry {
587 name: "default".to_string(),
588 kind: llm.provider.clone(),
589 base_url: llm.base_url.clone(),
590 api_key: llm.api_key.clone(),
591 api_key_file: llm.api_key_file.clone(),
592 model: llm.model.clone(),
593 preferred_models: Vec::new(),
594 };
595 vec![entry]
596}
597
/// Best-effort context-window size, in tokens, inferred from a model name.
///
/// Matching is ASCII case-insensitive and substring-based. Rules are checked in
/// order and the first hit wins, so specific families precede generic catch-alls
/// (e.g. `gpt-4o` before `gpt-4`). After the known families come size hints
/// embedded in the name (`128k`, `200k`, `1m`, large parameter counts, `oss`).
/// Returns `None` when nothing matches.
pub(crate) fn known_context_window(model: &str) -> Option<usize> {
    let lower = model.to_ascii_lowercase();
    let has = |needle: &str| lower.contains(needle);

    // Gemini: everything except `gemini-2.0-flash-lite`, which falls through.
    if has("gemini") && !has("gemini-2.0-flash-lite") {
        return Some(1_000_000);
    }

    // Claude: every tier (sonnet/opus/haiku or unnamed) gets the same window.
    if has("claude") {
        return Some(200_000);
    }

    // OpenAI. The specific gpt-4 variants must be tested before the bare `gpt-4`.
    if has("gpt-4o") || has("gpt-4.5") || has("gpt-4-turbo") {
        return Some(128_000);
    }
    if has("gpt-3.5") {
        return Some(16_000);
    }
    if has("gpt-4") {
        return Some(32_000);
    }
    if lower.starts_with("o1") || lower.starts_with("o3") {
        return Some(200_000);
    }

    if has("deepseek") || has("qwen") {
        return Some(128_000);
    }

    // Llama: only the 3.x generation is treated as long-context. Match the
    // generation marker itself; testing for any `3` anywhere in the name would
    // misfile e.g. `llama2:13b` or `codellama:34b` as 128k models.
    if has("llama") {
        let is_llama3 = ["llama3", "llama-3", "llama_3", "llama 3"]
            .iter()
            .any(|&marker| has(marker));
        return Some(if is_llama3 { 128_000 } else { 8_192 });
    }

    // Mistral family. `codestral` only counts here when the name also contains
    // `mistral` or `mixtral`.
    if has("mistral") || has("mixtral") {
        let long = has("large") || has("nemo") || has("codestral");
        return Some(if long { 128_000 } else { 32_000 });
    }

    // Cohere Command R (the `command-r` substring also covers `command-r+`).
    if has("command-r") {
        return Some(128_000);
    }

    if has("dbrx") || has("mpt") {
        return Some(32_000);
    }

    // Explicit window hints in the model id.
    if has("128k") || has("131k") || has("131072") {
        return Some(131_072);
    }
    if has("200k") {
        return Some(200_000);
    }
    if has("1m") || has("1000k") {
        return Some(1_000_000);
    }

    // Very large parameter counts: assume a modern long-context model.
    if has("70b") || has("120b") || has("180b") || has("240b") {
        return Some(131_072);
    }

    // Open-weight `oss` releases.
    if has("/oss") || has("oss-") || has("-oss") {
        return Some(131_072);
    }

    None
}
708
/// Returns the first entry of `preferred` that also appears in `available`, or
/// `fallback` when none does (including when `preferred` is empty).
fn pick_model(preferred: &[String], available: &[String], fallback: &str) -> String {
    preferred
        .iter()
        .find(|candidate| available.contains(*candidate))
        .map_or_else(|| fallback.to_owned(), String::clone)
}
717
718pub fn extract_json_from_response<T: serde::de::DeserializeOwned>(raw: &str) -> Option<T> {
723 let trimmed = raw.trim();
724 if let Ok(parsed) = serde_json::from_str::<T>(trimmed) {
725 return Some(parsed);
726 }
727 let start = trimmed.find('{')?;
728 let end = trimmed.rfind('}')?;
729 serde_json::from_str::<T>(&trimmed[start..=end]).ok()
730}