// koda_core/providers/mod.rs
//! LLM provider abstraction layer.
//!
//! Defines a common [`LlmProvider`] trait for all backends and exposes
//! the concrete implementations as submodules.
//!
//! ## Supported providers
//!
//! | Provider | Module | API style | Local? |
//! |---|---|---|---|
//! | Anthropic Claude | `anthropic` | Native | No |
//! | Google Gemini | `gemini` | Native | No |
//! | OpenAI / GPT | `openai_compat` | OpenAI-compat | No |
//! | LM Studio | `openai_compat` | OpenAI-compat | Yes |
//! | Ollama | `openai_compat` | OpenAI-compat | Yes |
//! | Groq | `openai_compat` | OpenAI-compat | No |
//! | Grok (xAI) | `openai_compat` | OpenAI-compat | No |
//! | DeepSeek | `openai_compat` | OpenAI-compat | No |
//! | OpenRouter | `openai_compat` | OpenAI-compat | No |
//! | Together | `openai_compat` | OpenAI-compat | No |
//! | Mistral | `openai_compat` | OpenAI-compat | No |
//! | Cerebras | `openai_compat` | OpenAI-compat | No |
//! | Fireworks | `openai_compat` | OpenAI-compat | No |
//! | Custom | `openai_compat` | OpenAI-compat | Varies |
//!
//! All OpenAI-compatible providers share the same module with different
//! base URLs. Use `--base-url` to point at any compatible endpoint.
//!
//! ## Design (DESIGN.md)
//!
//! - **Any model, any provider (P1)**: No vendor lock-in.
//!   The tool serves the person, not the platform.
//! - **Context Window Auto-Detection (P1, P3)**: Capabilities are queried
//!   from the provider API at startup. Hardcoded lookup is the fallback,
//!   not the primary source.

36/// Anthropic Claude API provider.
37pub mod anthropic;
38/// Google Gemini API provider.
39pub mod gemini;
40/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
41pub mod openai_compat;
42/// Shared SSE stream collector for all providers.
43pub mod stream_collector;
44/// Streaming XML tag filter for think/reasoning tags.
45pub mod stream_tag_filter;
46
47/// Mock provider for deterministic testing.
48#[cfg(any(test, feature = "test-support"))]
49pub mod mock;
50
51use anyhow::Result;
52use async_trait::async_trait;
53use serde::{Deserialize, Serialize};
54
55/// A tool call requested by the LLM.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ToolCall {
58 /// Provider-assigned call ID (echoed back in tool results).
59 pub id: String,
60 /// Name of the tool to invoke.
61 pub function_name: String,
62 /// Raw JSON string of tool arguments.
63 pub arguments: String,
64 /// Gemini-specific: thought signature that must be echoed back in history.
65 #[serde(skip_serializing_if = "Option::is_none", default)]
66 pub thought_signature: Option<String>,
67}
68
/// Token usage from an LLM response.
///
/// `Default` yields all-zero counters and an empty `stop_reason`.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens sent to the model.
    pub prompt_tokens: i64,
    /// Output tokens generated by the model.
    pub completion_tokens: i64,
    /// Tokens read from provider cache (e.g. Anthropic prompt caching, Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written to provider cache on this request.
    pub cache_creation_tokens: i64,
    /// Tokens used for reasoning/thinking (e.g. OpenAI reasoning_tokens, Anthropic thinking).
    pub thinking_tokens: i64,
    /// Why the model stopped: "end_turn", "max_tokens", "stop_sequence", etc.
    /// Empty string means unknown (provider didn't report it).
    pub stop_reason: String,
}

87/// The LLM's response: either text, tool calls, or both.
88#[derive(Debug, Clone)]
89pub struct LlmResponse {
90 /// Text content of the response (may be `None` if tool-calls only).
91 pub content: Option<String>,
92 /// Tool calls requested by the model.
93 pub tool_calls: Vec<ToolCall>,
94 /// Token usage statistics.
95 pub usage: TokenUsage,
96}
97
98/// Base64-encoded image data for multi-modal messages.
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ImageData {
101 /// MIME type (e.g. "image/png", "image/jpeg").
102 pub media_type: String,
103 /// Base64-encoded image bytes.
104 pub base64: String,
105}
106
107/// A single message in the conversation.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ChatMessage {
110 /// Message role: `"user"`, `"assistant"`, or `"tool"`.
111 pub role: String,
112 /// Text content (may be `None` for tool-call-only messages).
113 pub content: Option<String>,
114 #[serde(skip_serializing_if = "Option::is_none")]
115 /// Tool calls requested by the assistant.
116 pub tool_calls: Option<Vec<ToolCall>>,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 /// ID of the tool call this message responds to.
119 pub tool_call_id: Option<String>,
120 /// Attached images (only used in-flight, not persisted to DB).
121 #[serde(skip_serializing_if = "Option::is_none", default)]
122 pub images: Option<Vec<ImageData>>,
123}
124
125impl ChatMessage {
126 /// Create a simple text message (convenience for the common case).
127 ///
128 /// # Examples
129 ///
130 /// ```
131 /// use koda_core::providers::ChatMessage;
132 ///
133 /// let msg = ChatMessage::text("user", "Hello!");
134 /// assert_eq!(msg.role, "user");
135 /// assert_eq!(msg.content.as_deref(), Some("Hello!"));
136 /// assert!(msg.tool_calls.is_none());
137 /// ```
138 pub fn text(role: &str, content: &str) -> Self {
139 Self {
140 role: role.to_string(),
141 content: Some(content.to_string()),
142 tool_calls: None,
143 tool_call_id: None,
144 images: None,
145 }
146 }
147}
148
149/// Tool definition sent to the LLM.
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ToolDefinition {
152 /// Tool name (e.g. `"Read"`, `"Bash"`).
153 pub name: String,
154 /// Human-readable description for the LLM.
155 pub description: String,
156 /// JSON Schema for the tool's parameters.
157 pub parameters: serde_json::Value,
158}
159
/// A discovered model from a provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"claude-3-5-sonnet-20241022"`).
    pub id: String,
    /// Provider/organization that owns the model.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}

/// Model capabilities queried from the provider API.
///
/// `Default` leaves both fields `None`, i.e. "not reported".
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilities {
    /// Maximum context window in tokens (input + output).
    pub context_window: Option<usize>,
    /// Maximum output tokens the model supports.
    pub max_output_tokens: Option<usize>,
}

/// Is this URL pointing to a local (loopback) address?
///
/// Returns `true` only when the URL's *host* is `localhost` (any case, an
/// optional trailing dot is tolerated), an IPv4 loopback literal such as
/// `127.0.0.1`, or the bracketed IPv6 loopback `[::1]`.
///
/// The host is isolated before comparing, because this predicate gates
/// disabling TLS validation. A plain substring search would also accept
/// `http://localhost.evil.com`, `http://localhost@evil.com` and
/// `https://evil.com/?next=http://localhost`.
///
/// URLs without a `scheme://` prefix are never considered local.
fn is_localhost_url(url: &str) -> bool {
    let after_scheme = match url.find("://") {
        Some(idx) => &url[idx + 3..],
        None => return false,
    };

    // The authority ends at the first path/query/fragment delimiter. A
    // backslash is treated like '/' so the result agrees with URL parsers
    // that normalise it for http(s) URLs.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '\\' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Drop any `user:pass@` prefix; the host follows the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    // Bracketed IPv6 literal, optionally followed by `:port`.
    if let Some(bracketed) = host_port.strip_prefix('[') {
        return match bracketed.find(']') {
            Some(end) => {
                let tail = &bracketed[end + 1..];
                (tail.is_empty() || tail.starts_with(':'))
                    && bracketed[..end]
                        .parse::<std::net::Ipv6Addr>()
                        .is_ok_and(|ip| ip.is_loopback())
            }
            None => false,
        };
    }

    // Hostname or IPv4 literal: strip the port, then a trailing root dot.
    let host = host_port.split(':').next().unwrap_or("");
    let host = host.strip_suffix('.').unwrap_or(host);

    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::Ipv4Addr>()
            .is_ok_and(|ip| ip.is_loopback())
}

185/// Build a reqwest client with proper proxy configuration.
186///
187/// - Reads HTTPS_PROXY / HTTP_PROXY from env
188/// - Supports proxy auth via URL (http://user:pass@proxy:port)
189/// - Supports separate PROXY_USER / PROXY_PASS env vars
190/// - Bypasses proxy for localhost (LM Studio)
191pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
192 let mut builder = reqwest::Client::builder();
193
194 let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
195 .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
196 .or_else(|| crate::runtime_env::get("https_proxy"))
197 .or_else(|| crate::runtime_env::get("http_proxy"));
198
199 if let Some(ref url) = proxy_url
200 && !url.is_empty()
201 {
202 match reqwest::Proxy::all(url) {
203 Ok(mut proxy) => {
204 // Bypass proxy for local addresses
205 proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));
206
207 // If URL doesn't contain creds, check env vars
208 if !url.contains('@') {
209 let user = crate::runtime_env::get("PROXY_USER");
210 let pass = crate::runtime_env::get("PROXY_PASS");
211 if let (Some(u), Some(p)) = (user, pass) {
212 proxy = proxy.basic_auth(&u, &p);
213 tracing::debug!("Using proxy with basic auth (credentials redacted)");
214 }
215 }
216
217 builder = builder.proxy(proxy);
218 tracing::debug!("Using proxy: {}", redact_url_credentials(url));
219 }
220 Err(e) => {
221 tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
222 }
223 }
224 }
225
226 // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
227 // The env var is still required, but it's now scoped to local addresses.
228 let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
229 .map(|v| v == "1" || v == "true")
230 .unwrap_or(false);
231 let is_local = base_url.is_some_and(is_localhost_url);
232 if wants_skip_tls && is_local {
233 tracing::info!("TLS certificate validation disabled for local provider.");
234 builder = builder.danger_accept_invalid_certs(true);
235 } else if wants_skip_tls {
236 tracing::warn!(
237 "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
238 TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
239 );
240 }
241
242 builder.build().unwrap_or_else(|_| reqwest::Client::new())
243}
244
/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
///
/// Everything between `scheme://` and the **last** `@` that follows it is
/// replaced with `***:***`. Using the last `@` means a password that itself
/// contains `@` (`user:p@ss@proxy`) is fully hidden. The trade-off is that
/// an `@` in the path or query is also treated as a credential separator:
/// the output is for logs only, so over-redaction is preferred to a leak.
///
/// The input is returned unchanged when it has no `://` or no `@` after it.
fn redact_url_credentials(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(idx) => idx + 3,
        None => return url.to_string(),
    };

    // Only an '@' after the scheme separator can delimit credentials.
    match url[authority_start..].rfind('@') {
        Some(rel_at) => {
            let prefix = &url[..authority_start]; // "http://"
            let host_part = &url[authority_start + rel_at..]; // "@proxy:8080/..."
            format!("{prefix}***:***{host_part}")
        }
        None => url.to_string(),
    }
}

260/// A streaming chunk from the LLM.
261#[derive(Debug, Clone)]
262pub enum StreamChunk {
263 /// A text delta (partial content).
264 TextDelta(String),
265 /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
266 ThinkingDelta(String),
267 /// A single tool call whose arguments finished streaming.
268 ///
269 /// Emitted by providers that support per-block completion events (Anthropic
270 /// `content_block_stop`). Enables eager execution of read-only tools while
271 /// subsequent tool calls are still being streamed.
272 ///
273 /// Providers that don't support per-block events (OpenAI, Gemini) never
274 /// emit this — they only emit `ToolCalls` at stream end.
275 ToolCallReady(ToolCall),
276 /// All tool calls from the response (batch, emitted at stream end).
277 ///
278 /// For Anthropic, this only contains tool calls NOT already emitted via
279 /// `ToolCallReady`. For other providers, this contains all tool calls.
280 ToolCalls(Vec<ToolCall>),
281 /// Stream finished with usage info.
282 Done(TokenUsage),
283 /// The underlying HTTP connection was dropped before the stream completed.
284 ///
285 /// Distinct from `Done` (clean finish) and user-initiated cancellation
286 /// (Ctrl+C). The partial response MUST be discarded — it is incomplete
287 /// and storing it would corrupt the session history on resume.
288 NetworkError(String),
289}
290
291/// Trait for LLM provider backends.
292#[async_trait]
293pub trait LlmProvider: Send + Sync {
294 /// Send a chat completion request (non-streaming).
295 async fn chat(
296 &self,
297 messages: &[ChatMessage],
298 tools: &[ToolDefinition],
299 settings: &crate::config::ModelSettings,
300 ) -> Result<LlmResponse>;
301
302 /// Send a streaming chat completion request.
303 /// Returns an [`stream_collector::SseCollector`] with the chunk receiver and a task handle
304 /// that can be aborted to immediately kill the HTTP read (#825).
305 async fn chat_stream(
306 &self,
307 messages: &[ChatMessage],
308 tools: &[ToolDefinition],
309 settings: &crate::config::ModelSettings,
310 ) -> Result<stream_collector::SseCollector>;
311
312 /// List available models from the provider.
313 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
314
315 /// Query model capabilities (context window, max output tokens) from the API.
316 ///
317 /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
318 /// when the API doesn't expose them. Callers should fall back to the
319 /// hardcoded lookup table for any `None` fields.
320 async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
321 Ok(ModelCapabilities::default())
322 }
323
324 /// Provider display name (for UI).
325 fn provider_name(&self) -> &str;
326}
327
328// ── Provider factory ──────────────────────────────────────────
329
330use crate::config::{KodaConfig, ProviderType};
331
332/// Create an LLM provider from the given configuration.
333pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
334 let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
335 match config.provider_type {
336 ProviderType::Anthropic => {
337 let key = api_key.unwrap_or_else(|| {
338 tracing::warn!("No ANTHROPIC_API_KEY set");
339 String::new()
340 });
341 Box::new(anthropic::AnthropicProvider::new(
342 key,
343 Some(&config.base_url),
344 ))
345 }
346 ProviderType::Gemini => {
347 let key = api_key.unwrap_or_else(|| {
348 tracing::warn!("No GEMINI_API_KEY set");
349 String::new()
350 });
351 Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
352 }
353 #[cfg(any(test, feature = "test-support"))]
354 ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
355 _ => Box::new(openai_compat::OpenAiCompatProvider::new(
356 &config.base_url,
357 api_key,
358 )),
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // ── is_localhost_url ────────────────────────────────────────────────

    #[test]
    fn test_is_localhost_url_localhost() {
        assert!(is_localhost_url("http://localhost:1234/v1"));
        assert!(is_localhost_url("HTTP://LOCALHOST:11434/api"));
    }

    #[test]
    fn test_is_localhost_url_127() {
        assert!(is_localhost_url("http://127.0.0.1:8000/v1"));
    }

    #[test]
    fn test_is_localhost_url_ipv6() {
        assert!(is_localhost_url("http://[::1]:1234/v1"));
    }

    #[test]
    fn test_is_localhost_url_remote() {
        assert!(!is_localhost_url("https://api.openai.com/v1"));
        assert!(!is_localhost_url("https://api.anthropic.com/v1"));
    }

    // ── redact_url_credentials ─────────────────────────────────────────

    #[test]
    fn test_redact_with_credentials() {
        let result = redact_url_credentials("http://user:secret@proxy.corp.com:8080");
        assert!(
            !result.contains("secret"),
            "credentials should be redacted: {result}"
        );
        assert!(
            result.contains("***:***"),
            "should have redacted placeholder: {result}"
        );
        assert!(
            result.contains("proxy.corp.com"),
            "host should be preserved: {result}"
        );
    }

    #[test]
    fn test_redact_without_credentials() {
        let url = "https://proxy.corp.com:8080";
        assert_eq!(redact_url_credentials(url), url);
    }

    #[test]
    fn test_redact_empty_url() {
        assert_eq!(redact_url_credentials(""), "");
    }

    // ── ChatMessage::text ─────────────────────────────────────────────

    #[test]
    fn test_chat_message_text_builder() {
        let msg = ChatMessage::text("user", "hello world");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content.as_deref(), Some("hello world"));
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.images.is_none());
    }

    #[test]
    fn test_chat_message_text_assistant() {
        let msg = ChatMessage::text("assistant", "I can help with that.");
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content.as_deref(), Some("I can help with that."));
    }

    // ── TokenUsage defaults ────────────────────────────────────────────

    #[test]
    fn test_token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.prompt_tokens, 0);
        assert_eq!(usage.completion_tokens, 0);
        assert!(
            usage.stop_reason.is_empty(),
            "default stop_reason should be empty"
        );
    }
}
451}