koda_core/providers/mod.rs
//! LLM provider abstraction layer.
//!
//! Defines the common [`LlmProvider`] trait implemented by all backends and
//! exposes the concrete implementations as public submodules.
//!
//! ## Supported providers
//!
//! | Provider | Module | API style | Local? |
//! |---|---|---|---|
//! | Anthropic Claude | `anthropic` | Native | No |
//! | Google Gemini | `gemini` | Native | No |
//! | OpenAI / GPT | `openai_compat` | OpenAI-compat | No |
//! | LM Studio | `openai_compat` | OpenAI-compat | Yes |
//! | Ollama | `openai_compat` | OpenAI-compat | Yes |
//! | Groq | `openai_compat` | OpenAI-compat | No |
//! | Grok (xAI) | `openai_compat` | OpenAI-compat | No |
//! | DeepSeek | `openai_compat` | OpenAI-compat | No |
//! | OpenRouter | `openai_compat` | OpenAI-compat | No |
//! | Together | `openai_compat` | OpenAI-compat | No |
//! | Mistral | `openai_compat` | OpenAI-compat | No |
//! | Cerebras | `openai_compat` | OpenAI-compat | No |
//! | Fireworks | `openai_compat` | OpenAI-compat | No |
//! | Custom | `openai_compat` | OpenAI-compat | Varies |
//!
//! All OpenAI-compatible providers share the same module and differ only in
//! their base URL. Use `--base-url` to point at any compatible endpoint.
//!
//! ## Design (DESIGN.md)
//!
//! - **Any model, any provider (P1)**: No vendor lock-in.
//!   The tool serves the person, not the platform.
//! - **Context Window Auto-Detection (P1, P3)**: Capabilities are queried
//!   from the provider API at startup. The hardcoded lookup is the fallback,
//!   not the primary source.
35
36/// Anthropic Claude API provider.
37pub mod anthropic;
38/// Google Gemini API provider.
39pub mod gemini;
40/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
41pub mod openai_compat;
42/// Shared SSE stream collector for all providers.
43pub mod stream_collector;
44/// Streaming XML tag filter for think/reasoning tags.
45pub mod stream_tag_filter;
46
47/// Mock provider for deterministic testing.
48pub mod mock;
49
50use anyhow::Result;
51use async_trait::async_trait;
52use serde::{Deserialize, Serialize};
53
54/// A tool call requested by the LLM.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ToolCall {
57 /// Provider-assigned call ID (echoed back in tool results).
58 pub id: String,
59 /// Name of the tool to invoke.
60 pub function_name: String,
61 /// Raw JSON string of tool arguments.
62 pub arguments: String,
63 /// Gemini-specific: thought signature that must be echoed back in history.
64 #[serde(skip_serializing_if = "Option::is_none", default)]
65 pub thought_signature: Option<String>,
66}
67
/// Token accounting reported for a single LLM response.
///
/// `Default` yields all-zero counters and an empty `stop_reason`.
#[derive(Clone, Debug, Default)]
pub struct TokenUsage {
    /// Number of input (prompt) tokens sent to the model.
    pub prompt_tokens: i64,
    /// Number of output tokens the model generated.
    pub completion_tokens: i64,
    /// Tokens served from the provider's cache
    /// (e.g. Anthropic prompt caching, Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written into the provider's cache by this request.
    pub cache_creation_tokens: i64,
    /// Tokens spent on reasoning/thinking
    /// (e.g. OpenAI `reasoning_tokens`, Anthropic thinking).
    pub thinking_tokens: i64,
    /// Reason the model stopped, such as "end_turn", "max_tokens" or
    /// "stop_sequence". An empty string means the provider did not report one.
    pub stop_reason: String,
}
85
86/// The LLM's response: either text, tool calls, or both.
87#[derive(Debug, Clone)]
88pub struct LlmResponse {
89 /// Text content of the response (may be `None` if tool-calls only).
90 pub content: Option<String>,
91 /// Tool calls requested by the model.
92 pub tool_calls: Vec<ToolCall>,
93 /// Token usage statistics.
94 pub usage: TokenUsage,
95}
96
97/// Base64-encoded image data for multi-modal messages.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ImageData {
100 /// MIME type (e.g. "image/png", "image/jpeg").
101 pub media_type: String,
102 /// Base64-encoded image bytes.
103 pub base64: String,
104}
105
106/// A single message in the conversation.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ChatMessage {
109 /// Message role: `"user"`, `"assistant"`, or `"tool"`.
110 pub role: String,
111 /// Text content (may be `None` for tool-call-only messages).
112 pub content: Option<String>,
113 #[serde(skip_serializing_if = "Option::is_none")]
114 /// Tool calls requested by the assistant.
115 pub tool_calls: Option<Vec<ToolCall>>,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 /// ID of the tool call this message responds to.
118 pub tool_call_id: Option<String>,
119 /// Attached images (only used in-flight, not persisted to DB).
120 #[serde(skip_serializing_if = "Option::is_none", default)]
121 pub images: Option<Vec<ImageData>>,
122}
123
124impl ChatMessage {
125 /// Create a simple text message (convenience for the common case).
126 pub fn text(role: &str, content: &str) -> Self {
127 Self {
128 role: role.to_string(),
129 content: Some(content.to_string()),
130 tool_calls: None,
131 tool_call_id: None,
132 images: None,
133 }
134 }
135}
136
137/// Tool definition sent to the LLM.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct ToolDefinition {
140 /// Tool name (e.g. `"Read"`, `"Bash"`).
141 pub name: String,
142 /// Human-readable description for the LLM.
143 pub description: String,
144 /// JSON Schema for the tool's parameters.
145 pub parameters: serde_json::Value,
146}
147
/// A model discovered by listing a provider's models.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Identifier of the model, e.g. `"claude-3-5-sonnet-20241022"`.
    pub id: String,
    /// Provider or organization that owns the model, when known.
    // NOTE(review): the lint is silenced presumably because nothing reads this
    // field yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}
157
/// Capabilities of a model as reported by the provider API.
///
/// `Default` leaves every field `None`.
#[derive(Clone, Debug, Default)]
pub struct ModelCapabilities {
    /// Largest context window, in tokens (input + output).
    pub context_window: Option<usize>,
    /// Largest number of output tokens the model supports.
    pub max_output_tokens: Option<usize>,
}
166
/// Is this URL pointing to a local address?
///
/// Returns `true` only when the URL has a `scheme://` prefix and its *host*
/// is exactly `localhost` (case-insensitive), `127.0.0.1`, or the IPv6
/// loopback `[::1]`. Userinfo (`user:pass@`) and a port are ignored.
///
/// The host is extracted from the authority component rather than matched as
/// a substring of the whole URL, so look-alikes such as
/// `https://localhost.evil.com` or a `://localhost` buried in a path or query
/// string are not treated as local.
fn is_localhost_url(url: &str) -> bool {
    // Without a scheme separator there is no authority to inspect.
    let Some((_scheme, after_scheme)) = url.split_once("://") else {
        return false;
    };

    // The authority ends at the first path, query or fragment delimiter.
    // A backslash is treated as a delimiter too, because some URL parsers
    // handle it like `/`; otherwise `evil.com\@localhost` would look local.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '\\' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Drop any `userinfo@` prefix; the host follows the last `@`.
    let host_port = match authority.rsplit_once('@') {
        Some((_userinfo, rest)) => rest,
        None => authority,
    };

    // IPv6 literal: `[addr]`, optionally followed by `:port`.
    if let Some(bracketed) = host_port.strip_prefix('[') {
        return bracketed
            .split_once(']')
            .and_then(|(addr, _port)| addr.parse::<std::net::Ipv6Addr>().ok())
            .map_or(false, |ip| ip == std::net::Ipv6Addr::LOCALHOST);
    }

    // Hostname or IPv4 literal: everything before an optional `:port`.
    let host = host_port.split(':').next().unwrap_or("");
    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::Ipv4Addr>()
            .map_or(false, |ip| ip == std::net::Ipv4Addr::LOCALHOST)
}
172
173/// Build a reqwest client with proper proxy configuration.
174///
175/// - Reads HTTPS_PROXY / HTTP_PROXY from env
176/// - Supports proxy auth via URL (http://user:pass@proxy:port)
177/// - Supports separate PROXY_USER / PROXY_PASS env vars
178/// - Bypasses proxy for localhost (LM Studio)
179pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
180 let mut builder = reqwest::Client::builder();
181
182 let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
183 .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
184 .or_else(|| crate::runtime_env::get("https_proxy"))
185 .or_else(|| crate::runtime_env::get("http_proxy"));
186
187 if let Some(ref url) = proxy_url
188 && !url.is_empty()
189 {
190 match reqwest::Proxy::all(url) {
191 Ok(mut proxy) => {
192 // Bypass proxy for local addresses
193 proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));
194
195 // If URL doesn't contain creds, check env vars
196 if !url.contains('@') {
197 let user = crate::runtime_env::get("PROXY_USER");
198 let pass = crate::runtime_env::get("PROXY_PASS");
199 if let (Some(u), Some(p)) = (user, pass) {
200 proxy = proxy.basic_auth(&u, &p);
201 tracing::debug!("Using proxy with basic auth (credentials redacted)");
202 }
203 }
204
205 builder = builder.proxy(proxy);
206 tracing::debug!("Using proxy: {}", redact_url_credentials(url));
207 }
208 Err(e) => {
209 tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
210 }
211 }
212 }
213
214 // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
215 // The env var is still required, but it's now scoped to local addresses.
216 let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
217 .map(|v| v == "1" || v == "true")
218 .unwrap_or(false);
219 let is_local = base_url.is_some_and(is_localhost_url);
220 if wants_skip_tls && is_local {
221 tracing::info!("TLS certificate validation disabled for local provider.");
222 builder = builder.danger_accept_invalid_certs(true);
223 } else if wants_skip_tls {
224 tracing::warn!(
225 "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
226 TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
227 );
228 }
229
230 builder.build().unwrap_or_else(|_| reqwest::Client::new())
231}
232
/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
///
/// Everything between the scheme (if any) and the credential-terminating `@`
/// is replaced by `***:***`. The terminating `@` is the *last* `@` inside the
/// authority (the part before the first `/`, `?` or `#`), so a password that
/// itself contains `@` is fully hidden and an `@` in the path or query does
/// not cause the host to be cut off. If the authority has no `@` but one
/// appears later (e.g. a malformed URL whose password contains `/`), the last
/// `@` in the string is used so the text in front of it is still hidden.
/// URLs without any `@` are returned unchanged.
fn redact_url_credentials(url: &str) -> String {
    // Scheme-less input ("user:pass@proxy:8080") is treated as starting
    // directly with the authority.
    let authority_start = url.find("://").map_or(0, |idx| idx + 3);
    let (prefix, rest) = url.split_at(authority_start);

    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());

    // Prefer the last '@' of the authority; otherwise fail closed on any '@'.
    let at_pos = rest[..authority_end]
        .rfind('@')
        .or_else(|| rest.rfind('@'));

    match at_pos {
        Some(at) => {
            let host_part = &rest[at..]; // "@proxy:8080/..."
            format!("{prefix}***:***{host_part}")
        }
        None => url.to_string(),
    }
}
247
248/// A streaming chunk from the LLM.
249#[derive(Debug, Clone)]
250pub enum StreamChunk {
251 /// A text delta (partial content).
252 TextDelta(String),
253 /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
254 ThinkingDelta(String),
255 /// A single tool call whose arguments finished streaming.
256 ///
257 /// Emitted by providers that support per-block completion events (Anthropic
258 /// `content_block_stop`). Enables eager execution of read-only tools while
259 /// subsequent tool calls are still being streamed.
260 ///
261 /// Providers that don't support per-block events (OpenAI, Gemini) never
262 /// emit this — they only emit `ToolCalls` at stream end.
263 ToolCallReady(ToolCall),
264 /// All tool calls from the response (batch, emitted at stream end).
265 ///
266 /// For Anthropic, this only contains tool calls NOT already emitted via
267 /// `ToolCallReady`. For other providers, this contains all tool calls.
268 ToolCalls(Vec<ToolCall>),
269 /// Stream finished with usage info.
270 Done(TokenUsage),
271 /// The underlying HTTP connection was dropped before the stream completed.
272 ///
273 /// Distinct from `Done` (clean finish) and user-initiated cancellation
274 /// (Ctrl+C). The partial response MUST be discarded — it is incomplete
275 /// and storing it would corrupt the session history on resume.
276 NetworkError(String),
277}
278
279/// Trait for LLM provider backends.
280#[async_trait]
281pub trait LlmProvider: Send + Sync {
282 /// Send a chat completion request (non-streaming).
283 async fn chat(
284 &self,
285 messages: &[ChatMessage],
286 tools: &[ToolDefinition],
287 settings: &crate::config::ModelSettings,
288 ) -> Result<LlmResponse>;
289
290 /// Send a streaming chat completion request.
291 /// Returns a channel receiver that yields chunks as they arrive.
292 async fn chat_stream(
293 &self,
294 messages: &[ChatMessage],
295 tools: &[ToolDefinition],
296 settings: &crate::config::ModelSettings,
297 ) -> Result<tokio::sync::mpsc::Receiver<StreamChunk>>;
298
299 /// List available models from the provider.
300 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
301
302 /// Query model capabilities (context window, max output tokens) from the API.
303 ///
304 /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
305 /// when the API doesn't expose them. Callers should fall back to the
306 /// hardcoded lookup table for any `None` fields.
307 async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
308 Ok(ModelCapabilities::default())
309 }
310
311 /// Provider display name (for UI).
312 fn provider_name(&self) -> &str;
313}
314
315// ── Provider factory ──────────────────────────────────────────
316
317use crate::config::{KodaConfig, ProviderType};
318
319/// Create an LLM provider from the given configuration.
320pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
321 let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
322 match config.provider_type {
323 ProviderType::Anthropic => {
324 let key = api_key.unwrap_or_else(|| {
325 tracing::warn!("No ANTHROPIC_API_KEY set");
326 String::new()
327 });
328 Box::new(anthropic::AnthropicProvider::new(
329 key,
330 Some(&config.base_url),
331 ))
332 }
333 ProviderType::Gemini => {
334 let key = api_key.unwrap_or_else(|| {
335 tracing::warn!("No GEMINI_API_KEY set");
336 String::new()
337 });
338 Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
339 }
340 ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
341 _ => Box::new(openai_compat::OpenAiCompatProvider::new(
342 &config.base_url,
343 api_key,
344 )),
345 }
346}