use std::pin::Pin;
use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;
mod ollama;
mod openai;
#[cfg(test)]
mod tests;
pub use ollama::OllamaProvider;
pub use openai::OpenAiProvider;
mod failover;
/// Error type returned by the LLM provider layer in this module.
///
/// `Display` text for every variant comes from the `#[error(...)]`
/// attributes (via `thiserror`).
#[derive(Debug, Error)]
pub enum LlmError {
/// A `reqwest` error; `#[from]` lets `?` convert `reqwest::Error` into this variant.
#[error("HTTP request failed: {0}")]
Http(#[from] reqwest::Error),
/// A response with a non-success HTTP status. `ensure_ok` builds this from
/// the status code and the response body text.
#[error("API error: {status} - {message}")]
Api { status: u16, message: String },
/// A streaming failure, described by a free-form message.
#[error("Stream error: {0}")]
Stream(String),
/// A response whose shape could not be interpreted, described by a message.
#[error("Invalid response format: {0}")]
InvalidFormat(String),
/// A provider could not be constructed or used (for example, HTTP client
/// creation failed or no providers are configured).
#[error("Provider not available: {0}")]
ProviderUnavailable(String),
/// Rate-limit condition. NOTE(review): not constructed in the visible part
/// of this file — presumably raised by the provider implementations.
#[error("Rate limited")]
RateLimited,
/// Timeout condition. NOTE(review): not constructed in the visible part of
/// this file — presumably raised by the provider implementations.
#[error("Timeout")]
Timeout,
}
/// A single chat message exchanged with an LLM provider.
///
/// The serde attributes make `tool_calls` and `tool_call_id` optional on
/// input and omit them from serialized output when empty / `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Message {
/// Who authored the message.
pub role: Role,
/// The message text.
pub content: String,
/// Tool calls attached to the message; skipped on serialization when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_calls: Vec<ProposedToolCall>,
/// Identifier of the tool call this message answers (set by
/// `Message::tool_result`); skipped on serialization when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl Message {
    /// Creates a `Role::System` message with no tool metadata.
    pub fn system(content: impl Into<String>) -> Self {
        Self::plain(Role::System, content)
    }

    /// Creates a `Role::User` message with no tool metadata.
    pub fn user(content: impl Into<String>) -> Self {
        Self::plain(Role::User, content)
    }

    /// Creates a `Role::Assistant` message with no tool metadata.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::plain(Role::Assistant, content)
    }

    /// Creates a `Role::Assistant` message that carries the given tool calls.
    pub fn assistant_with_tool_calls(
        content: impl Into<String>,
        tool_calls: Vec<ProposedToolCall>,
    ) -> Self {
        let mut msg = Self::plain(Role::Assistant, content);
        msg.tool_calls = tool_calls;
        msg
    }

    /// Creates a `Role::Tool` message whose `tool_call_id` is set to
    /// `tool_call_id`.
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        let mut msg = Self::plain(Role::Tool, content);
        msg.tool_call_id = Some(tool_call_id.into());
        msg
    }

    /// Shared constructor: a message with the given role and text, an empty
    /// tool-call list and no tool-call id.
    fn plain(role: Role, content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role,
            content,
            tool_calls: vec![],
            tool_call_id: None,
        }
    }
}
/// Author role of a [`Message`].
///
/// Serialized in lowercase (`rename_all = "lowercase"`); `User` is the
/// `Default` variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `"system"`.
System,
/// Serialized as `"user"`; the default role.
#[default]
User,
/// Serialized as `"assistant"`.
Assistant,
/// Serialized as `"tool"`; used by `Message::tool_result`.
Tool,
}
impl Role {
pub fn as_wire_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
/// Builds a `reqwest::Client` configured with the given request `timeout`.
///
/// # Errors
///
/// Returns [`LlmError::ProviderUnavailable`] carrying the builder's error
/// text when the client cannot be constructed.
pub(crate) fn build_http_client(timeout: std::time::Duration) -> Result<reqwest::Client, LlmError> {
    let builder = reqwest::Client::builder().timeout(timeout);
    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(LlmError::ProviderUnavailable(format!(
            "Failed to create HTTP client: {e}"
        ))),
    }
}
/// Passes a successful response through and turns any other status into an
/// [`LlmError::Api`].
///
/// # Errors
///
/// For a non-success status, returns `LlmError::Api` with the numeric status
/// code and the response body as the message. If the body cannot be read the
/// message is left empty (best effort).
pub(crate) async fn ensure_ok(resp: reqwest::Response) -> Result<reqwest::Response, LlmError> {
    let status = resp.status();
    if !status.is_success() {
        // Reading the body consumes the response; a read failure yields "".
        let message = resp.text().await.unwrap_or_default();
        return Err(LlmError::Api {
            status: status.as_u16(),
            message,
        });
    }
    Ok(resp)
}
/// One item yielded by the stream returned from `LlmProvider::generate_stream`.
#[derive(Debug, Clone)]
pub struct ResponseChunk {
/// Text carried by this chunk.
pub content: String,
/// Completion flag. NOTE(review): presumably `true` on the final chunk —
/// confirm against the provider implementations.
pub is_done: bool,
}
/// Description of a tool passed to `LlmProvider::generate_with_tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
/// Tool name.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Arbitrary JSON describing the tool's parameters.
/// NOTE(review): presumably a JSON Schema object — confirm with the providers.
pub parameters: serde_json::Value,
}
/// A tool invocation carried on a [`Message`] or [`Response`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProposedToolCall {
/// Optional call identifier; defaults to `None` on input and is omitted
/// from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Name of the tool to call.
pub name: String,
/// Arguments for the call as arbitrary JSON.
pub arguments: serde_json::Value,
}
/// Result of a non-streaming generation call.
#[derive(Debug, Clone, Default)]
pub struct Response {
/// Generated text.
pub content: String,
/// Token accounting, when available.
pub usage: Option<Usage>,
/// Tool calls attached to the response; empty for `Response::text`.
pub tool_calls: Vec<ProposedToolCall>,
}
impl Response {
    /// Builds a text-only response: the given content and usage, with an
    /// empty tool-call list.
    pub fn text(content: impl Into<String>, usage: Option<Usage>) -> Self {
        let content = content.into();
        Self {
            content,
            usage,
            tool_calls: vec![],
        }
    }
}
/// Token counts attached to a [`Response`].
#[derive(Debug, Clone)]
pub struct Usage {
/// Tokens counted for the prompt.
pub prompt_tokens: u32,
/// Tokens counted for the completion.
pub completion_tokens: u32,
/// Total token count. NOTE(review): presumably prompt + completion; nothing
/// in this file enforces that.
pub total_tokens: u32,
}
/// Common interface implemented by every LLM backend in this module.
///
/// Implementors must be `Send + Sync`; async methods are provided through
/// `async_trait`.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
/// Generates a complete response for the given conversation.
async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;
/// Generates a response with tool definitions available.
///
/// The default implementation ignores `tools` and delegates to
/// [`LlmProvider::generate`]; providers that support tools override it.
async fn generate_with_tools(
&self,
messages: &[Message],
tools: &[ToolDef],
) -> Result<Response, LlmError> {
// Explicitly discard `tools` so the default body has no unused parameter.
let _ = tools;
self.generate(messages).await
}
/// Starts a streaming generation and returns a boxed, `Send` stream whose
/// items are `Result<ResponseChunk, LlmError>`.
async fn generate_stream(
&self,
messages: &[Message],
) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;
/// Reports provider health as a boolean.
async fn health_check(&self) -> bool;
/// Provider name.
fn name(&self) -> &str;
/// Model identifier; also the input to the default `fetch_context_window`.
fn model(&self) -> &str;
/// Lists model names; used by provider selection as a reachability probe.
async fn list_models(&self) -> Result<Vec<String>, LlmError>;
/// Returns the context-window size for the current model, if known.
///
/// The default implementation looks the model name up with
/// `known_context_window`.
async fn fetch_context_window(&self) -> Option<usize> {
known_context_window(self.model())
}
}
/// Flat settings consumed by `create_provider` to build a provider.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Provider kind, e.g. `"ollama"` or `"openai_compat"`.
    pub provider: String,
    /// Base URL of the provider's API.
    pub base_url: String,
    /// Optional API key.
    pub api_key: Option<String>,
    /// Model identifier.
    pub model: String,
    /// Sampling temperature.
    pub temperature: f64,
    /// Token limit passed to the provider constructors.
    pub max_tokens: i32,
}

impl Default for ProviderConfig {
    /// Defaults to Ollama at `http://localhost:11434` with no API key.
    fn default() -> Self {
        let provider = String::from("ollama");
        let base_url = String::from("http://localhost:11434");
        let model = String::from("qwen2.5-coder:7b");
        Self {
            provider,
            base_url,
            api_key: None,
            model,
            temperature: 0.7,
            max_tokens: 4096,
        }
    }
}
/// Builds a boxed provider from `config`.
///
/// * `"ollama"` builds an `OllamaProvider`. If construction with the
///   configured values fails, the error is logged and the default Ollama
///   configuration is used instead.
/// * `"openai_compat"`, or any name that `crate::presets::resolve`
///   recognises, builds an `OpenAiProvider`. A non-empty `config.base_url`
///   wins over the preset's base URL.
/// * Any other name logs a warning and falls back to the default Ollama
///   provider.
///
/// # Errors
///
/// Returns [`LlmError::ProviderUnavailable`] when an OpenAI-compatible
/// provider has neither a configured nor a preset base URL, and propagates
/// errors from the provider constructors.
pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
    if config.provider == "ollama" {
        let attempt = OllamaProvider::new(
            &config.base_url,
            &config.model,
            config.temperature,
            config.max_tokens,
        );
        let ollama = match attempt {
            Ok(built) => built,
            Err(e) => {
                tracing::error!(error = %e, "Failed to create Ollama provider, falling back to default");
                OllamaProvider::default_config()?
            }
        };
        return Ok(Box::new(ollama));
    }

    let preset_base = crate::presets::resolve(&config.provider).map(|p| p.base_url);
    let openai_like = config.provider == "openai_compat" || preset_base.is_some();
    if !openai_like {
        tracing::warn!(
            provider = %config.provider,
            "Unknown LLM provider, falling back to default Ollama"
        );
        return Ok(Box::new(OllamaProvider::default_config()?));
    }

    // An explicitly configured URL takes precedence over the preset's.
    let base_url = if !config.base_url.is_empty() {
        config.base_url.as_str()
    } else {
        match preset_base {
            Some(b) => b,
            None => {
                return Err(LlmError::ProviderUnavailable(format!(
                    "provider `{}` has no base_url configured",
                    config.provider
                )));
            }
        }
    };

    let openai = OpenAiProvider::new(
        base_url,
        config.api_key.as_deref(),
        &config.model,
        config.temperature,
        Some(config.max_tokens),
    )?;
    Ok(Box::new(openai))
}
/// Converts a configured provider entry into a [`ProviderConfig`].
///
/// The API key is taken from `entry.api_key_file` when that file is set,
/// readable and non-empty after trimming. Otherwise a warning is logged (if
/// a file was configured) and the trimmed inline `entry.api_key` is used. An
/// empty key becomes `None`. `model_override`, when given, replaces
/// `entry.model`.
fn provider_config_from_entry(
    entry: &brain::ProviderEntry,
    temperature: f64,
    max_tokens: i32,
    model_override: Option<&str>,
) -> ProviderConfig {
    let inline_key = || entry.api_key.trim().to_string();

    // `None` here means "no usable key file"; the inline key is used instead.
    let key_from_file = entry.api_key_file.as_ref().and_then(|path| {
        match std::fs::read_to_string(path) {
            Ok(raw) => {
                let trimmed = raw.trim();
                if trimmed.is_empty() {
                    tracing::warn!(
                        provider = %entry.name,
                        path = %path.display(),
                        "llm.providers[].api_key_file is empty; falling back to inline api_key"
                    );
                    None
                } else {
                    Some(trimmed.to_string())
                }
            }
            Err(e) => {
                tracing::warn!(
                    provider = %entry.name,
                    path = %path.display(),
                    error = %e,
                    "llm.providers[].api_key_file unreadable; falling back to inline api_key"
                );
                None
            }
        }
    });

    let api_key = key_from_file.unwrap_or_else(inline_key);
    let model: &str = model_override.unwrap_or(&entry.model);

    ProviderConfig {
        provider: entry.kind.clone(),
        base_url: entry.base_url.clone(),
        api_key: Some(api_key).filter(|key| !key.is_empty()),
        model: model.to_string(),
        temperature,
        max_tokens,
    }
}
/// Probes `entries` in order and returns the index of the first entry whose
/// provider can be constructed and answers `list_models`, together with the
/// model chosen for it by `pick_model`.
///
/// Entries that fail to construct or to list models are logged and skipped.
/// Returns `None` when no entry answers.
async fn probe_first_reachable(
    entries: &[brain::ProviderEntry],
    llm: &brain::LlmConfig,
    max_tokens: i32,
) -> Option<(usize, String)> {
    for (idx, entry) in entries.iter().enumerate() {
        let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
        let probe = match create_provider(&cfg) {
            Ok(p) => p,
            Err(e) => {
                tracing::warn!(name = %entry.name, error = %e, "skipping provider — construction failed");
                continue;
            }
        };
        match probe.list_models().await {
            Ok(models) => {
                let chosen = pick_model(&entry.preferred_models, &models, &entry.model);
                tracing::info!(
                    name = %entry.name,
                    kind = %entry.kind,
                    model = %chosen,
                    "LLM provider selected"
                );
                return Some((idx, chosen));
            }
            Err(e) => {
                tracing::warn!(
                    name = %entry.name,
                    error = %e,
                    "provider unreachable — trying next"
                );
            }
        }
    }
    None
}

/// Selects a single provider from the LLM configuration.
///
/// Entries are probed in order with `list_models`; the first one that
/// answers is rebuilt with the model chosen by `pick_model` and returned.
/// If no entry answers, the first entry is used as-is.
///
/// # Errors
///
/// Returns [`LlmError::ProviderUnavailable`] when no providers are
/// configured, and propagates the construction error of the provider that
/// is finally built.
pub async fn select_provider(llm: &brain::LlmConfig) -> Result<Box<dyn LlmProvider>, LlmError> {
    let entries = synthesise_entries(llm);
    // NOTE(review): `as` silently truncates out-of-range values; confirm that
    // `llm.max_tokens` always fits in `i32`.
    let max_tokens = llm.max_tokens as i32;
    if entries.is_empty() {
        return Err(LlmError::ProviderUnavailable(
            "no LLM providers configured".into(),
        ));
    }

    if let Some((idx, chosen)) = probe_first_reachable(&entries, llm, max_tokens).await {
        let cfg =
            provider_config_from_entry(&entries[idx], llm.temperature, max_tokens, Some(&chosen));
        return create_provider(&cfg);
    }

    let first = &entries[0];
    tracing::warn!(
        name = %first.name,
        "no provider answered list_models — falling back to first entry"
    );
    let cfg = provider_config_from_entry(first, llm.temperature, max_tokens, None);
    create_provider(&cfg)
}

/// Builds a failover chain from the LLM configuration.
///
/// The primary provider is the first entry that answers `list_models`
/// (falling back to the first entry when none does). It is placed first in
/// the chain; every other entry that can be constructed is appended as a
/// fallback in configuration order.
///
/// # Errors
///
/// Returns [`LlmError::ProviderUnavailable`] when no providers are
/// configured, and propagates a construction failure of the primary
/// provider. Fallback construction failures are logged and skipped.
pub async fn build_failover_chain(
    llm: &brain::LlmConfig,
) -> Result<failover::FailoverProvider, LlmError> {
    let entries = synthesise_entries(llm);
    // NOTE(review): `as` silently truncates out-of-range values; confirm that
    // `llm.max_tokens` always fits in `i32`.
    let max_tokens = llm.max_tokens as i32;
    if entries.is_empty() {
        return Err(LlmError::ProviderUnavailable(
            "no LLM providers configured".into(),
        ));
    }

    let primary_idx = probe_first_reachable(&entries, llm, max_tokens).await;
    let (primary_i, model_override) = primary_idx.unwrap_or_else(|| {
        tracing::warn!("no provider answered list_models — using first entry as primary");
        (0, entries[0].model.clone())
    });

    let mut providers: Vec<Box<dyn LlmProvider>> = Vec::with_capacity(entries.len());
    let primary_cfg = provider_config_from_entry(
        &entries[primary_i],
        llm.temperature,
        max_tokens,
        Some(&model_override),
    );
    providers.push(create_provider(&primary_cfg)?);

    for (i, entry) in entries.iter().enumerate() {
        if i == primary_i {
            continue;
        }
        let cfg = provider_config_from_entry(entry, llm.temperature, max_tokens, None);
        match create_provider(&cfg) {
            Ok(p) => {
                tracing::info!(name = %entry.name, "registered as fallback provider");
                providers.push(p);
            }
            Err(e) => {
                tracing::warn!(name = %entry.name, error = %e, "fallback provider construction failed — skipping");
            }
        }
    }
    Ok(failover::FailoverProvider::new(providers))
}
/// Returns the provider entries to try, in order.
///
/// When `llm.providers` is non-empty it is returned as a clone. Otherwise a
/// single entry named `"default"` is built from the flat fields on `llm`,
/// with no preferred models.
fn synthesise_entries(llm: &brain::LlmConfig) -> Vec<brain::ProviderEntry> {
    if llm.providers.is_empty() {
        // The `allow` suggests the flat fields read below are deprecated.
        #[allow(deprecated)]
        let legacy = brain::ProviderEntry {
            name: "default".to_string(),
            kind: llm.provider.clone(),
            base_url: llm.base_url.clone(),
            api_key: llm.api_key.clone(),
            api_key_file: llm.api_key_file.clone(),
            model: llm.model.clone(),
            preferred_models: Vec::new(),
        };
        return vec![legacy];
    }
    llm.providers.clone()
}
/// Reports whether `lower` (an already lower-cased model name) names a
/// Llama 3.x model.
///
/// The first run of digits after an occurrence of `"llama"` must be exactly
/// `3` and must not be a parameter count (`3b`). This keeps names such as
/// `llama2:13b`, `codellama:34b` and `open_llama_3b` from being mistaken for
/// Llama 3 just because a `3` appears somewhere in the name.
fn names_llama_3(lower: &str) -> bool {
    lower.match_indices("llama").any(|(at, word)| {
        let tail = lower[at + word.len()..].trim_start_matches(|c: char| !c.is_ascii_digit());
        // Digits are ASCII, so slicing at `digits` is on a char boundary.
        let digits = tail.bytes().take_while(u8::is_ascii_digit).count();
        &tail[..digits] == "3" && !tail[digits..].starts_with('b')
    })
}

/// Best-effort context-window size (in tokens) inferred from a model name.
///
/// Matching is case-insensitive and substring-based. Rules are checked from
/// most specific to most generic; the first match wins. Returns `None` when
/// no rule matches.
pub(crate) fn known_context_window(model: &str) -> Option<usize> {
    let lower = model.to_ascii_lowercase();
    let has = |needle: &str| lower.contains(needle);

    if has("gemini") && !has("gemini-2.0-flash-lite") {
        return Some(1_000_000);
    }
    if has("claude") {
        return Some(200_000);
    }
    // The specific GPT-4 variants must be tested before the bare "gpt-4".
    if has("gpt-4o") || has("gpt-4.5") || has("gpt-4-turbo") {
        return Some(128_000);
    }
    if has("gpt-3.5") {
        return Some(16_000);
    }
    if has("gpt-4") {
        return Some(32_000);
    }
    if lower.starts_with("o1") || lower.starts_with("o3") {
        return Some(200_000);
    }
    if has("deepseek") || has("qwen") {
        return Some(128_000);
    }
    if has("llama") {
        return Some(if names_llama_3(&lower) { 128_000 } else { 8_192 });
    }
    // "codestral" does not contain "mistral", so it is tested explicitly.
    if has("mistral") || has("mixtral") || has("codestral") {
        let long_context = has("large") || has("nemo") || has("codestral");
        return Some(if long_context { 128_000 } else { 32_000 });
    }
    if has("command-r") {
        return Some(128_000);
    }
    if has("dbrx") || has("mpt") {
        return Some(32_000);
    }
    // Generic hints: explicit window sizes embedded in the name.
    if has("128k") || has("131k") || has("131072") {
        return Some(131_072);
    }
    if has("200k") {
        return Some(200_000);
    }
    if has("1m") || has("1000k") {
        return Some(1_000_000);
    }
    // Generic hints: large parameter counts and "oss" naming.
    if has("70b") || has("120b") || has("180b") || has("240b") {
        return Some(131_072);
    }
    if has("/oss") || has("oss-") || has("-oss") {
        return Some(131_072);
    }
    None
}
/// Chooses the model to use for a provider.
///
/// Returns the first entry of `preferred` that appears verbatim in
/// `available`, or `fallback` when none does.
fn pick_model(preferred: &[String], available: &[String], fallback: &str) -> String {
    preferred
        .iter()
        .find(|candidate| available.contains(*candidate))
        .cloned()
        .unwrap_or_else(|| fallback.to_owned())
}
/// Parses a `T` out of a raw model response.
///
/// The trimmed text is first parsed as-is. If that fails, the span from the
/// first `{` to the last `}` is parsed instead, which tolerates prose or
/// code fences around a JSON object.
///
/// Returns `None` when neither attempt succeeds, including when the text has
/// no `{`…`}` span in that order.
pub fn extract_json_from_response<T: serde::de::DeserializeOwned>(raw: &str) -> Option<T> {
    let trimmed = raw.trim();
    if let Ok(parsed) = serde_json::from_str::<T>(trimmed) {
        return Some(parsed);
    }
    let start = trimmed.find('{')?;
    let end = trimmed.rfind('}')?;
    // Checked slicing: when the last `}` precedes the first `{` (e.g. "} {")
    // the range is inverted and plain indexing would panic.
    let candidate = trimmed.get(start..=end)?;
    serde_json::from_str::<T>(candidate).ok()
}