koda_core/providers/mod.rs

//! LLM provider abstraction layer.
//!
//! Defines the common [`LlmProvider`] trait for all backends and re-exports
//! the concrete implementations.
//!
//! ## Supported providers
//!
//! | Provider | Module | API style | Local? |
//! |---|---|---|---|
//! | Anthropic Claude | `anthropic` | Native | No |
//! | Google Gemini | `gemini` | Native | No |
//! | OpenAI / GPT | `openai_compat` | OpenAI-compat | No |
//! | LM Studio | `openai_compat` | OpenAI-compat | Yes |
//! | Ollama | `openai_compat` | OpenAI-compat | Yes |
//! | Groq | `openai_compat` | OpenAI-compat | No |
//! | Grok (xAI) | `openai_compat` | OpenAI-compat | No |
//! | DeepSeek | `openai_compat` | OpenAI-compat | No |
//! | OpenRouter | `openai_compat` | OpenAI-compat | No |
//! | Together | `openai_compat` | OpenAI-compat | No |
//! | Mistral | `openai_compat` | OpenAI-compat | No |
//! | Cerebras | `openai_compat` | OpenAI-compat | No |
//! | Fireworks | `openai_compat` | OpenAI-compat | No |
//! | Custom | `openai_compat` | OpenAI-compat | Varies |
//!
//! All OpenAI-compatible providers share the same module with different
//! base URLs. Use `--base-url` to point at any compatible endpoint.
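//!
//! For example, a hypothetical invocation pointing at a local Ollama server
//! (the binary name is illustrative; only the `--base-url` flag itself is
//! taken from this doc):
//!
//! ```text
//! koda --base-url http://localhost:11434/v1
//! ```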
//!
//! ## Design (DESIGN.md)
//!
//! - **Any model, any provider (P1)**: No vendor lock-in.
//!   The tool serves the person, not the platform.
//! - **Context Window Auto-Detection (P1, P3)**: Capabilities are queried
//!   from the provider API at startup. Hardcoded lookup is the fallback,
//!   not the primary source.

/// Anthropic Claude API provider.
pub mod anthropic;
/// Google Gemini API provider.
pub mod gemini;
/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
pub mod openai_compat;
/// Shared SSE stream collector for all providers.
pub mod stream_collector;
/// Streaming XML tag filter for think/reasoning tags.
pub mod stream_tag_filter;

/// Mock provider for deterministic testing.
#[cfg(any(test, feature = "test-support"))]
pub mod mock;

use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// A tool call requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call ID (echoed back in tool results).
    pub id: String,
    /// Name of the tool to invoke.
    pub function_name: String,
    /// Raw JSON string of tool arguments.
    pub arguments: String,
    /// Gemini-specific: thought signature that must be echoed back in history.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thought_signature: Option<String>,
}

/// Token usage from an LLM response.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens sent to the model.
    pub prompt_tokens: i64,
    /// Output tokens generated by the model.
    pub completion_tokens: i64,
    /// Tokens read from provider cache (e.g. Anthropic prompt caching, Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written to provider cache on this request.
    pub cache_creation_tokens: i64,
    /// Tokens used for reasoning/thinking (e.g. OpenAI reasoning_tokens, Anthropic thinking).
    pub thinking_tokens: i64,
    /// Why the model stopped: "end_turn", "max_tokens", "stop_sequence", etc.
    /// Empty string means unknown (provider didn't report it).
    pub stop_reason: String,
}

/// The LLM's response: either text, tool calls, or both.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Text content of the response (may be `None` if tool-calls only).
    pub content: Option<String>,
    /// Tool calls requested by the model.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage statistics.
    pub usage: TokenUsage,
}

/// Base64-encoded image data for multi-modal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// MIME type (e.g. "image/png", "image/jpeg").
    pub media_type: String,
    /// Base64-encoded image bytes.
    pub base64: String,
}

/// A single message in the conversation.
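///
/// Tool results are represented as role `"tool"` messages whose
/// `tool_call_id` echoes the originating [`ToolCall::id`]; a minimal sketch
/// (the ID and content here are illustrative):
///
/// ```
/// use koda_core::providers::ChatMessage;
///
/// let result = ChatMessage {
///     role: "tool".to_string(),
///     content: Some("file contents...".to_string()),
///     tool_calls: None,
///     // Must match the ID from the assistant's earlier tool call.
///     tool_call_id: Some("call_123".to_string()),
///     images: None,
/// };
/// assert_eq!(result.role, "tool");
/// ```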
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role: `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Text content (may be `None` for tool-call-only messages).
    pub content: Option<String>,
    /// Tool calls requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// ID of the tool call this message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Attached images (only used in-flight, not persisted to DB).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub images: Option<Vec<ImageData>>,
}

impl ChatMessage {
    /// Create a simple text message (convenience for the common case).
    ///
    /// # Examples
    ///
    /// ```
    /// use koda_core::providers::ChatMessage;
    ///
    /// let msg = ChatMessage::text("user", "Hello!");
    /// assert_eq!(msg.role, "user");
    /// assert_eq!(msg.content.as_deref(), Some("Hello!"));
    /// assert!(msg.tool_calls.is_none());
    /// ```
    pub fn text(role: &str, content: &str) -> Self {
        Self {
            role: role.to_string(),
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
            images: None,
        }
    }
}

/// Tool definition sent to the LLM.
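///
/// # Examples
///
/// A minimal sketch; the tool name and schema shown are illustrative:
///
/// ```
/// use koda_core::providers::ToolDefinition;
///
/// let def = ToolDefinition {
///     name: "Read".to_string(),
///     description: "Read a file from disk.".to_string(),
///     // `parameters` is an ordinary JSON Schema object.
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": { "path": { "type": "string" } },
///         "required": ["path"]
///     }),
/// };
/// assert_eq!(def.name, "Read");
/// ```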
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g. `"Read"`, `"Bash"`).
    pub name: String,
    /// Human-readable description for the LLM.
    pub description: String,
    /// JSON Schema for the tool's parameters.
    pub parameters: serde_json::Value,
}

/// A discovered model from a provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"claude-3-5-sonnet-20241022"`).
    pub id: String,
    /// Provider/organization that owns the model.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}

/// Model capabilities queried from the provider API.
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilities {
    /// Maximum context window in tokens (input + output).
    pub context_window: Option<usize>,
    /// Maximum output tokens the model supports.
    pub max_output_tokens: Option<usize>,
}

/// Returns `true` if the URL points at a local address
/// (`localhost`, `127.0.0.1`, or `[::1]`).
fn is_localhost_url(url: &str) -> bool {
    let lower = url.to_lowercase();
    lower.contains("://localhost") || lower.contains("://127.0.0.1") || lower.contains("://[::1]")
}

/// Build a reqwest client with proper proxy configuration.
///
/// - Reads `HTTPS_PROXY` / `HTTP_PROXY` from the environment
/// - Supports proxy auth embedded in the URL (`http://user:pass@proxy:port`)
/// - Supports separate `PROXY_USER` / `PROXY_PASS` env vars
/// - Bypasses the proxy for localhost (LM Studio)
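///
/// # Examples
///
/// A minimal sketch; constructing the client alone does not touch the network:
///
/// ```no_run
/// use koda_core::providers::build_http_client;
///
/// // Local provider URL: any configured proxy is bypassed for localhost.
/// let _client = build_http_client(Some("http://localhost:1234/v1"));
/// ```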
pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
    let mut builder = reqwest::Client::builder();

    let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
        .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
        .or_else(|| crate::runtime_env::get("https_proxy"))
        .or_else(|| crate::runtime_env::get("http_proxy"));

    if let Some(ref url) = proxy_url
        && !url.is_empty()
    {
        match reqwest::Proxy::all(url) {
            Ok(mut proxy) => {
                // Bypass proxy for local addresses
                proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));

                // If URL doesn't contain creds, check env vars
                if !url.contains('@') {
                    let user = crate::runtime_env::get("PROXY_USER");
                    let pass = crate::runtime_env::get("PROXY_PASS");
                    if let (Some(u), Some(p)) = (user, pass) {
                        proxy = proxy.basic_auth(&u, &p);
                        tracing::debug!("Using proxy with basic auth (credentials redacted)");
                    }
                }

                builder = builder.proxy(proxy);
                tracing::debug!("Using proxy: {}", redact_url_credentials(url));
            }
            Err(e) => {
                tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
            }
        }
    }

    // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
    // The env var is still required, but it's now scoped to local addresses.
    let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
        .map(|v| v == "1" || v == "true")
        .unwrap_or(false);
    let is_local = base_url.is_some_and(is_localhost_url);
    if wants_skip_tls && is_local {
        tracing::info!("TLS certificate validation disabled for local provider.");
        builder = builder.danger_accept_invalid_certs(true);
    } else if wants_skip_tls {
        tracing::warn!(
            "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
             TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
        );
    }

    builder.build().unwrap_or_else(|_| reqwest::Client::new())
}

/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
fn redact_url_credentials(url: &str) -> String {
    // Pattern: scheme://user:pass@host...
    if let Some(scheme_end) = url.find("://") {
        let authority_start = scheme_end + 3;
        // Only treat '@' as a credential separator when it appears in the
        // authority component, i.e. before the first '/' after the scheme;
        // an '@' in the path must not trigger redaction.
        let authority_end = url[authority_start..]
            .find('/')
            .map_or(url.len(), |i| authority_start + i);
        if let Some(at) = url[authority_start..authority_end].rfind('@') {
            let prefix = &url[..authority_start]; // "http://"
            let host_part = &url[authority_start + at..]; // "@proxy:8080/..."
            return format!("{prefix}***:***{host_part}");
        }
    }
    url.to_string()
}

/// A streaming chunk from the LLM.
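///
/// # Examples
///
/// A minimal sketch of a consumer's dispatch over the variants (how chunks
/// arrive is up to [`stream_collector::SseCollector`]; only the variants
/// themselves are shown here):
///
/// ```
/// use koda_core::providers::StreamChunk;
///
/// fn handle(chunk: StreamChunk) {
///     match chunk {
///         StreamChunk::TextDelta(text) => print!("{text}"),
///         StreamChunk::ThinkingDelta(_) => { /* usually rendered separately */ }
///         StreamChunk::ToolCallReady(_call) => { /* may eager-execute read-only tools */ }
///         StreamChunk::ToolCalls(_calls) => { /* execute the remaining batch */ }
///         StreamChunk::Done(usage) => println!("({} output tokens)", usage.completion_tokens),
///         StreamChunk::NetworkError(err) => eprintln!("stream dropped: {err}"),
///     }
/// }
/// # handle(StreamChunk::TextDelta("hi".into()));
/// ```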
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A text delta (partial content).
    TextDelta(String),
    /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
    ThinkingDelta(String),
    /// A single tool call whose arguments finished streaming.
    ///
    /// Emitted by providers that support per-block completion events (Anthropic
    /// `content_block_stop`). Enables eager execution of read-only tools while
    /// subsequent tool calls are still being streamed.
    ///
    /// Providers that don't support per-block events (OpenAI, Gemini) never
    /// emit this — they only emit `ToolCalls` at stream end.
    ToolCallReady(ToolCall),
    /// All tool calls from the response (batch, emitted at stream end).
    ///
    /// For Anthropic, this only contains tool calls NOT already emitted via
    /// `ToolCallReady`. For other providers, this contains all tool calls.
    ToolCalls(Vec<ToolCall>),
    /// Stream finished with usage info.
    Done(TokenUsage),
    /// The underlying HTTP connection was dropped before the stream completed.
    ///
    /// Distinct from `Done` (clean finish) and user-initiated cancellation
    /// (Ctrl+C). The partial response MUST be discarded — it is incomplete
    /// and storing it would corrupt the session history on resume.
    NetworkError(String),
}

/// Trait for LLM provider backends.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a chat completion request (non-streaming).
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<LlmResponse>;

    /// Send a streaming chat completion request.
    ///
    /// Returns an [`stream_collector::SseCollector`] with the chunk receiver and
    /// a task handle that can be aborted to immediately kill the HTTP read (#825).
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<stream_collector::SseCollector>;

    /// List available models from the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Query model capabilities (context window, max output tokens) from the API.
    ///
    /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
    /// when the API doesn't expose them. Callers should fall back to the
    /// hardcoded lookup table for any `None` fields.
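    ///
    /// A sketch of the intended fallback pattern (`hardcoded_context_window`
    /// is a hypothetical lookup, not an API in this module):
    ///
    /// ```ignore
    /// let caps = provider.model_capabilities("some-model").await?;
    /// let ctx = caps
    ///     .context_window
    ///     .unwrap_or_else(|| hardcoded_context_window("some-model"));
    /// ```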
    async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
        Ok(ModelCapabilities::default())
    }

    /// Provider display name (for UI).
    fn provider_name(&self) -> &str;
}

// ── Provider factory ──────────────────────────────────────────

use crate::config::{KodaConfig, ProviderType};

/// Create an LLM provider from the given configuration.
pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
    let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
    match config.provider_type {
        ProviderType::Anthropic => {
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No ANTHROPIC_API_KEY set");
                String::new()
            });
            Box::new(anthropic::AnthropicProvider::new(
                key,
                Some(&config.base_url),
            ))
        }
        ProviderType::Gemini => {
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No GEMINI_API_KEY set");
                String::new()
            });
            Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
        }
        #[cfg(any(test, feature = "test-support"))]
        ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
        _ => Box::new(openai_compat::OpenAiCompatProvider::new(
            &config.base_url,
            api_key,
        )),
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    // ── is_localhost_url ────────────────────────────────────────────────

    #[test]
    fn test_is_localhost_url_localhost() {
        assert!(is_localhost_url("http://localhost:1234/v1"));
        assert!(is_localhost_url("HTTP://LOCALHOST:11434/api"));
    }

    #[test]
    fn test_is_localhost_url_127() {
        assert!(is_localhost_url("http://127.0.0.1:8000/v1"));
    }

    #[test]
    fn test_is_localhost_url_ipv6() {
        assert!(is_localhost_url("http://[::1]:1234/v1"));
    }

    #[test]
    fn test_is_localhost_url_remote() {
        assert!(!is_localhost_url("https://api.openai.com/v1"));
        assert!(!is_localhost_url("https://api.anthropic.com/v1"));
    }

    // ── redact_url_credentials ─────────────────────────────────────────

    #[test]
    fn test_redact_with_credentials() {
        let result = redact_url_credentials("http://user:secret@proxy.corp.com:8080");
        assert!(
            !result.contains("secret"),
            "credentials should be redacted: {result}"
        );
        assert!(
            result.contains("***:***"),
            "should have redacted placeholder: {result}"
        );
        assert!(
            result.contains("proxy.corp.com"),
            "host should be preserved: {result}"
        );
    }

    #[test]
    fn test_redact_without_credentials() {
        let url = "https://proxy.corp.com:8080";
        assert_eq!(redact_url_credentials(url), url);
    }

    #[test]
    fn test_redact_empty_url() {
        assert_eq!(redact_url_credentials(""), "");
    }
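
    // Exercises the authority-component guard in `redact_url_credentials`:
    // an '@' that appears only in the path is not a credential separator.
    // (Test added alongside that guard; the URL is illustrative.)
    #[test]
    fn test_redact_at_sign_in_path() {
        let url = "https://proxy.corp.com:8080/path@segment";
        assert_eq!(redact_url_credentials(url), url);
    }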

    // ── ChatMessage::text ─────────────────────────────────────────────

    #[test]
    fn test_chat_message_text_builder() {
        let msg = ChatMessage::text("user", "hello world");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content.as_deref(), Some("hello world"));
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.images.is_none());
    }

    #[test]
    fn test_chat_message_text_assistant() {
        let msg = ChatMessage::text("assistant", "I can help with that.");
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.content.as_deref(), Some("I can help with that."));
    }

    // ── TokenUsage defaults ────────────────────────────────────────────

    #[test]
    fn test_token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.prompt_tokens, 0);
        assert_eq!(usage.completion_tokens, 0);
        assert!(
            usage.stop_reason.is_empty(),
            "default stop_reason should be empty"
        );
    }
}