Skip to main content

koda_core/providers/
mod.rs

1//! LLM provider abstraction layer.
2//!
3//! Defines a common `Provider` trait for all backends and re-exports
4//! the concrete implementations.
5//!
6//! ## Supported providers
7//!
8//! | Provider | Module | API style | Local? |
9//! |---|---|---|---|
10//! | Anthropic Claude | `anthropic` | Native | No |
11//! | Google Gemini | `gemini` | Native | No |
12//! | OpenAI / GPT | `openai_compat` | OpenAI-compat | No |
13//! | LM Studio | `openai_compat` | OpenAI-compat | Yes |
14//! | Ollama | `openai_compat` | OpenAI-compat | Yes |
15//! | Groq | `openai_compat` | OpenAI-compat | No |
16//! | Grok (xAI) | `openai_compat` | OpenAI-compat | No |
17//! | DeepSeek | `openai_compat` | OpenAI-compat | No |
18//! | OpenRouter | `openai_compat` | OpenAI-compat | No |
19//! | Together | `openai_compat` | OpenAI-compat | No |
20//! | Mistral | `openai_compat` | OpenAI-compat | No |
21//! | Cerebras | `openai_compat` | OpenAI-compat | No |
22//! | Fireworks | `openai_compat` | OpenAI-compat | No |
23//! | Custom | `openai_compat` | OpenAI-compat | Varies |
24//!
25//! All OpenAI-compatible providers share the same module with different
26//! base URLs. Use `--base-url` to point at any compatible endpoint.
27//!
28//! ## Design (DESIGN.md)
29//!
30//! - **Any model, any provider (P1)**: No vendor lock-in.
31//!   The tool serves the person, not the platform.
32//! - **Context Window Auto-Detection (P1, P3)**: Capabilities are queried
33//!   from the provider API at startup. Hardcoded lookup is the fallback,
34//!   not the primary source.
35
36/// Anthropic Claude API provider.
37pub mod anthropic;
38/// Google Gemini API provider.
39pub mod gemini;
40/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
41pub mod openai_compat;
42/// Shared SSE stream collector for all providers.
43pub mod stream_collector;
44/// Streaming XML tag filter for think/reasoning tags.
45pub mod stream_tag_filter;
46
47/// Mock provider for deterministic testing.
48#[cfg(any(test, feature = "test-support"))]
49pub mod mock;
50
51use anyhow::Result;
52use async_trait::async_trait;
53use serde::{Deserialize, Serialize};
54
55/// A tool call requested by the LLM.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ToolCall {
58    /// Provider-assigned call ID (echoed back in tool results).
59    pub id: String,
60    /// Name of the tool to invoke.
61    pub function_name: String,
62    /// Raw JSON string of tool arguments.
63    pub arguments: String,
64    /// Gemini-specific: thought signature that must be echoed back in history.
65    #[serde(skip_serializing_if = "Option::is_none", default)]
66    pub thought_signature: Option<String>,
67}
68
/// Token accounting reported for one LLM response.
///
/// Every counter defaults to `0` and `stop_reason` defaults to an empty string.
#[derive(Default, Debug, Clone)]
pub struct TokenUsage {
    /// Tokens in the request sent to the model.
    pub prompt_tokens: i64,
    /// Tokens the model generated.
    pub completion_tokens: i64,
    /// Tokens served from a provider-side cache (e.g. Anthropic prompt caching,
    /// Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written into the provider-side cache by this request.
    pub cache_creation_tokens: i64,
    /// Tokens spent on reasoning/thinking (e.g. OpenAI `reasoning_tokens`,
    /// Anthropic thinking).
    pub thinking_tokens: i64,
    /// Reason generation stopped, e.g. `"end_turn"`, `"max_tokens"`, `"stop_sequence"`.
    /// An empty string means the provider did not say.
    pub stop_reason: String,
}
86
87/// The LLM's response: either text, tool calls, or both.
88#[derive(Debug, Clone)]
89pub struct LlmResponse {
90    /// Text content of the response (may be `None` if tool-calls only).
91    pub content: Option<String>,
92    /// Tool calls requested by the model.
93    pub tool_calls: Vec<ToolCall>,
94    /// Token usage statistics.
95    pub usage: TokenUsage,
96}
97
98/// Base64-encoded image data for multi-modal messages.
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ImageData {
101    /// MIME type (e.g. "image/png", "image/jpeg").
102    pub media_type: String,
103    /// Base64-encoded image bytes.
104    pub base64: String,
105}
106
107/// A single message in the conversation.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ChatMessage {
110    /// Message role: `"user"`, `"assistant"`, or `"tool"`.
111    pub role: String,
112    /// Text content (may be `None` for tool-call-only messages).
113    pub content: Option<String>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    /// Tool calls requested by the assistant.
116    pub tool_calls: Option<Vec<ToolCall>>,
117    #[serde(skip_serializing_if = "Option::is_none")]
118    /// ID of the tool call this message responds to.
119    pub tool_call_id: Option<String>,
120    /// Attached images (only used in-flight, not persisted to DB).
121    #[serde(skip_serializing_if = "Option::is_none", default)]
122    pub images: Option<Vec<ImageData>>,
123}
124
125impl ChatMessage {
126    /// Create a simple text message (convenience for the common case).
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// use koda_core::providers::ChatMessage;
132    ///
133    /// let msg = ChatMessage::text("user", "Hello!");
134    /// assert_eq!(msg.role, "user");
135    /// assert_eq!(msg.content.as_deref(), Some("Hello!"));
136    /// assert!(msg.tool_calls.is_none());
137    /// ```
138    pub fn text(role: &str, content: &str) -> Self {
139        Self {
140            role: role.to_string(),
141            content: Some(content.to_string()),
142            tool_calls: None,
143            tool_call_id: None,
144            images: None,
145        }
146    }
147}
148
149/// Tool definition sent to the LLM.
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ToolDefinition {
152    /// Tool name (e.g. `"Read"`, `"Bash"`).
153    pub name: String,
154    /// Human-readable description for the LLM.
155    pub description: String,
156    /// JSON Schema for the tool's parameters.
157    pub parameters: serde_json::Value,
158}
159
/// A model advertised by a provider's model-listing endpoint.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier, e.g. `"claude-3-5-sonnet-20241022"`.
    pub id: String,
    /// Organization that owns the model, when the provider reports it.
    // NOTE(review): the lint allow is kept from the original; presumably this
    // field is not read anywhere in the crate — confirm before removing it.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}
169
/// Model limits as reported by the provider API.
///
/// Both fields default to `None`, meaning "not reported".
#[derive(Default, Debug, Clone)]
pub struct ModelCapabilities {
    /// Largest context window in tokens (input plus output).
    pub context_window: Option<usize>,
    /// Largest number of output tokens the model can produce.
    pub max_output_tokens: Option<usize>,
}
178
/// Does this URL's *host* refer to the local machine?
///
/// Returns `true` only when the host component (after stripping any scheme,
/// `user:pass@` userinfo, port, path, query and fragment) is `localhost`
/// (ASCII case-insensitive) or a loopback IP literal (`127.0.0.0/8`, `::1`,
/// written as `[::1]` in a URL). URLs without a `scheme://` prefix are not
/// considered local.
///
/// This gates the TLS-verification bypass, so a plain substring match is not
/// good enough. Look-alike hosts must be rejected, for example:
/// - `https://localhost.evil.example`
/// - `https://127.0.0.1.nip.io`
/// - `https://localhost@evil.example` (the userinfo is `localhost`, the host is `evil.example`)
/// - `https://evil.example/?u=http://localhost`
fn is_localhost_url(url: &str) -> bool {
    let Some((_, rest)) = url.split_once("://") else {
        return false;
    };

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(rest, |end| &rest[..end]);

    // Userinfo may itself contain '@'; the host always follows the last one.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    let host = match host_port.strip_prefix('[') {
        // Bracketed IPv6 literal, e.g. `[::1]:8080`.
        Some(bracketed) => match bracketed.split_once(']') {
            Some((inner, _)) => inner,
            None => return false,
        },
        None => host_port.split(':').next().unwrap_or(host_port),
    };

    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
}
184
185/// Build a reqwest client with proper proxy configuration.
186///
187/// - Reads HTTPS_PROXY / HTTP_PROXY from env
188/// - Supports proxy auth via URL (http://user:pass@proxy:port)
189/// - Supports separate PROXY_USER / PROXY_PASS env vars
190/// - Bypasses proxy for localhost (LM Studio)
191/// - Applies a connect timeout (default 30s, env: `KODA_CONNECT_TIMEOUT_SECS`)
192///   and a read timeout (default 180s, env: `KODA_READ_TIMEOUT_SECS`).
193///   We deliberately avoid the total-request `.timeout()` because it would
194///   kill long-running SSE streams during slow tool/agent turns.
195pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
196    let mut builder = reqwest::Client::builder();
197
198    // ── Timeouts ──────────────────────────────────────────────────────────
199    //
200    // connect_timeout: time to establish the TCP+TLS connection. A stuck
201    // SYN or hung TLS handshake aborts after this. Always safe to apply
202    // because it only governs connection setup, not response reading.
203    //
204    // read_timeout: maximum idle time between successive reads from the
205    // socket. An actively-streaming SSE response keeps resetting this on
206    // every chunk, so it doesn't penalize long agent turns; but a server
207    // that goes silent (or a half-open connection a NAT box has dropped)
208    // will fail fast instead of hanging the agent forever.
209    let connect_timeout = crate::runtime_env::get("KODA_CONNECT_TIMEOUT_SECS")
210        .and_then(|v| v.parse::<u64>().ok())
211        .unwrap_or(30);
212    let read_timeout = crate::runtime_env::get("KODA_READ_TIMEOUT_SECS")
213        .and_then(|v| v.parse::<u64>().ok())
214        .unwrap_or(180);
215    builder = builder
216        .connect_timeout(std::time::Duration::from_secs(connect_timeout))
217        .read_timeout(std::time::Duration::from_secs(read_timeout));
218
219    let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
220        .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
221        .or_else(|| crate::runtime_env::get("https_proxy"))
222        .or_else(|| crate::runtime_env::get("http_proxy"));
223
224    if let Some(ref url) = proxy_url
225        && !url.is_empty()
226    {
227        match reqwest::Proxy::all(url) {
228            Ok(mut proxy) => {
229                // Bypass proxy for local addresses
230                proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));
231
232                // If URL doesn't contain creds, check env vars
233                if !url.contains('@') {
234                    let user = crate::runtime_env::get("PROXY_USER");
235                    let pass = crate::runtime_env::get("PROXY_PASS");
236                    if let (Some(u), Some(p)) = (user, pass) {
237                        proxy = proxy.basic_auth(&u, &p);
238                        tracing::debug!("Using proxy with basic auth (credentials redacted)");
239                    }
240                }
241
242                builder = builder.proxy(proxy);
243                tracing::debug!("Using proxy: {}", redact_url_credentials(url));
244            }
245            Err(e) => {
246                tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
247            }
248        }
249    }
250
251    // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
252    // The env var is still required, but it's now scoped to local addresses.
253    let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
254        .map(|v| v == "1" || v == "true")
255        .unwrap_or(false);
256    let is_local = base_url.is_some_and(is_localhost_url);
257    if wants_skip_tls && is_local {
258        tracing::info!("TLS certificate validation disabled for local provider.");
259        builder = builder.danger_accept_invalid_certs(true);
260    } else if wants_skip_tls {
261        tracing::warn!(
262            "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
263             TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
264        );
265    }
266
267    builder.build().unwrap_or_else(|_| reqwest::Client::new())
268}
269
/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
///
/// Only the authority component is inspected: the part after `scheme://` (or
/// from the start, for scheme-less values like `user:pass@proxy:8080`) up to
/// the first `/`, `?` or `#`. Everything before the *last* `@` in that span is
/// replaced. A password containing `@` is therefore fully hidden, and an `@`
/// in the path or query is left alone. URLs without userinfo are returned
/// unchanged.
fn redact_url_credentials(url: &str) -> String {
    // Treat "://" as the scheme separator only when what precedes it looks like
    // a scheme (RFC 3986: ALPHA / DIGIT / "+" / "-" / "."). Otherwise a "://"
    // inside a query string would be mistaken for it.
    let authority_start = url
        .find("://")
        .filter(|&i| {
            i > 0
                && url[..i]
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        })
        .map_or(0, |i| i + 3);

    let rest = &url[authority_start..];
    let authority = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(rest, |end| &rest[..end]);

    match authority.rfind('@') {
        // `authority` is a prefix of `rest`, so `at` is a valid index into `rest`.
        Some(at) => format!("{}***:***{}", &url[..authority_start], &rest[at..]),
        None => url.to_string(),
    }
}
284
285/// A streaming chunk from the LLM.
286#[derive(Debug, Clone)]
287pub enum StreamChunk {
288    /// A text delta (partial content).
289    TextDelta(String),
290    /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
291    ThinkingDelta(String),
292    /// A single tool call whose arguments finished streaming.
293    ///
294    /// Emitted by providers that support per-block completion events (Anthropic
295    /// `content_block_stop`). Enables eager execution of read-only tools while
296    /// subsequent tool calls are still being streamed.
297    ///
298    /// Providers that don't support per-block events (OpenAI, Gemini) never
299    /// emit this — they only emit `ToolCalls` at stream end.
300    ToolCallReady(ToolCall),
301    /// All tool calls from the response (batch, emitted at stream end).
302    ///
303    /// For Anthropic, this only contains tool calls NOT already emitted via
304    /// `ToolCallReady`. For other providers, this contains all tool calls.
305    ToolCalls(Vec<ToolCall>),
306    /// Stream finished with usage info.
307    Done(TokenUsage),
308    /// The underlying HTTP connection was dropped before the stream completed.
309    ///
310    /// Distinct from `Done` (clean finish) and user-initiated cancellation
311    /// (Ctrl+C). The partial response MUST be discarded — it is incomplete
312    /// and storing it would corrupt the session history on resume.
313    NetworkError(String),
314}
315
316/// Trait for LLM provider backends.
317#[async_trait]
318pub trait LlmProvider: Send + Sync {
319    /// Send a chat completion request (non-streaming).
320    async fn chat(
321        &self,
322        messages: &[ChatMessage],
323        tools: &[ToolDefinition],
324        settings: &crate::config::ModelSettings,
325    ) -> Result<LlmResponse>;
326
327    /// Send a streaming chat completion request.
328    /// Returns an [`stream_collector::SseCollector`] with the chunk receiver and a task handle
329    /// that can be aborted to immediately kill the HTTP read (#825).
330    async fn chat_stream(
331        &self,
332        messages: &[ChatMessage],
333        tools: &[ToolDefinition],
334        settings: &crate::config::ModelSettings,
335    ) -> Result<stream_collector::SseCollector>;
336
337    /// List available models from the provider.
338    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
339
340    /// Query model capabilities (context window, max output tokens) from the API.
341    ///
342    /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
343    /// when the API doesn't expose them. Callers should fall back to the
344    /// hardcoded lookup table for any `None` fields.
345    async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
346        Ok(ModelCapabilities::default())
347    }
348
349    /// Provider display name (for UI).
350    fn provider_name(&self) -> &str;
351}
352
353// ── Provider factory ──────────────────────────────────────────
354
355use crate::config::{KodaConfig, ProviderType};
356
357/// Create an LLM provider from the given configuration.
358pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider + Send + Sync> {
359    let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
360    match config.provider_type {
361        ProviderType::Anthropic => {
362            let key = api_key.unwrap_or_else(|| {
363                tracing::warn!("No ANTHROPIC_API_KEY set");
364                String::new()
365            });
366            Box::new(anthropic::AnthropicProvider::new(
367                key,
368                Some(&config.base_url),
369            ))
370        }
371        ProviderType::Gemini => {
372            let key = api_key.unwrap_or_else(|| {
373                tracing::warn!("No GEMINI_API_KEY set");
374                String::new()
375            });
376            Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
377        }
378        #[cfg(any(test, feature = "test-support"))]
379        ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
380        _ => Box::new(openai_compat::OpenAiCompatProvider::new(
381            &config.base_url,
382            api_key,
383        )),
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    // ── is_localhost_url ────────────────────────────────────────────────

    #[test]
    fn test_is_localhost_url_localhost() {
        let urls = ["http://localhost:1234/v1", "HTTP://LOCALHOST:11434/api"];
        for url in urls {
            assert!(is_localhost_url(url));
        }
    }

    #[test]
    fn test_is_localhost_url_127() {
        let url = "http://127.0.0.1:8000/v1";
        assert!(is_localhost_url(url));
    }

    #[test]
    fn test_is_localhost_url_ipv6() {
        let url = "http://[::1]:1234/v1";
        assert!(is_localhost_url(url));
    }

    #[test]
    fn test_is_localhost_url_remote() {
        let remotes = ["https://api.openai.com/v1", "https://api.anthropic.com/v1"];
        assert!(remotes.iter().all(|url| !is_localhost_url(url)));
    }

    // ── redact_url_credentials ─────────────────────────────────────────

    #[test]
    fn test_redact_with_credentials() {
        let result = redact_url_credentials("http://user:secret@proxy.corp.com:8080");
        let checks = [
            (!result.contains("secret"), "credentials should be redacted"),
            (result.contains("***:***"), "should have redacted placeholder"),
            (result.contains("proxy.corp.com"), "host should be preserved"),
        ];
        for (ok, what) in checks {
            assert!(ok, "{what}: {result}");
        }
    }

    #[test]
    fn test_redact_without_credentials() {
        let url = "https://proxy.corp.com:8080";
        let redacted = redact_url_credentials(url);
        assert_eq!(redacted, url);
    }

    #[test]
    fn test_redact_empty_url() {
        let redacted = redact_url_credentials("");
        assert!(redacted.is_empty());
    }

    // ── ChatMessage::text ─────────────────────────────────────────────

    #[test]
    fn test_chat_message_text_builder() {
        let ChatMessage {
            role,
            content,
            tool_calls,
            tool_call_id,
            images,
        } = ChatMessage::text("user", "hello world");
        assert_eq!(role, "user");
        assert_eq!(content.as_deref(), Some("hello world"));
        assert!(tool_calls.is_none());
        assert!(tool_call_id.is_none());
        assert!(images.is_none());
    }

    #[test]
    fn test_chat_message_text_assistant() {
        let text = "I can help with that.";
        let msg = ChatMessage::text("assistant", text);
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content.as_deref(), Some(text));
    }

    // ── TokenUsage defaults ────────────────────────────────────────────

    #[test]
    fn test_token_usage_default() {
        let TokenUsage {
            prompt_tokens,
            completion_tokens,
            stop_reason,
            ..
        } = TokenUsage::default();
        assert_eq!(prompt_tokens, 0);
        assert_eq!(completion_tokens, 0);
        assert!(stop_reason.is_empty(), "default stop_reason should be empty");
    }
}