koda_core/providers/mod.rs

//! LLM provider abstraction layer.
//!
//! Defines the common `LlmProvider` trait for all backends and exposes
//! the concrete implementations as submodules.
//!
//! ## Supported providers
//!
//! | Provider | Module | API style | Local? |
//! |---|---|---|---|
//! | Anthropic Claude | `anthropic` | Native | No |
//! | Google Gemini | `gemini` | Native | No |
//! | OpenAI / GPT | `openai_compat` | OpenAI-compat | No |
//! | LM Studio | `openai_compat` | OpenAI-compat | Yes |
//! | Ollama | `openai_compat` | OpenAI-compat | Yes |
//! | Groq | `openai_compat` | OpenAI-compat | No |
//! | Grok (xAI) | `openai_compat` | OpenAI-compat | No |
//! | DeepSeek | `openai_compat` | OpenAI-compat | No |
//! | OpenRouter | `openai_compat` | OpenAI-compat | No |
//! | Together | `openai_compat` | OpenAI-compat | No |
//! | Mistral | `openai_compat` | OpenAI-compat | No |
//! | Cerebras | `openai_compat` | OpenAI-compat | No |
//! | Fireworks | `openai_compat` | OpenAI-compat | No |
//! | Custom | `openai_compat` | OpenAI-compat | Varies |
//!
//! All OpenAI-compatible providers share the same module with different
//! base URLs. Use `--base-url` to point at any compatible endpoint.
//!
//! ## Design (DESIGN.md)
//!
//! - **Any model, any provider (P1)**: No vendor lock-in.
//!   The tool serves the person, not the platform.
//! - **Context Window Auto-Detection (P1, P3)**: Capabilities are queried
//!   from the provider API at startup. Hardcoded lookup is the fallback,
//!   not the primary source.
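//!
//! As a sketch of that flexibility (the URL below is an assumption, and the
//! constructor call mirrors `create_provider` at the bottom of this module):
//!
//! ```ignore
//! // Point the shared OpenAI-compatible provider at a local LM Studio server.
//! let provider = openai_compat::OpenAiCompatProvider::new("http://localhost:1234/v1", None);
//! ```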

/// Anthropic Claude API provider.
pub mod anthropic;
/// Google Gemini API provider.
pub mod gemini;
/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
pub mod openai_compat;
/// Shared SSE stream collector for all providers.
pub mod stream_collector;
/// Streaming XML tag filter for think/reasoning tags.
pub mod stream_tag_filter;

/// Mock provider for deterministic testing.
pub mod mock;

use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// A tool call requested by the LLM.
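///
/// The `arguments` field is a raw JSON string. A minimal decoding sketch
/// (the tool name and arguments here are illustrative, not real values):
///
/// ```ignore
/// let call = ToolCall {
///     id: "call_1".to_string(),
///     function_name: "Read".to_string(),
///     arguments: r#"{"path": "src/main.rs"}"#.to_string(),
///     thought_signature: None,
/// };
/// let args: serde_json::Value = serde_json::from_str(&call.arguments).expect("valid JSON");
/// assert_eq!(args["path"], "src/main.rs");
/// ```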
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call ID (echoed back in tool results).
    pub id: String,
    /// Name of the tool to invoke.
    pub function_name: String,
    /// Raw JSON string of tool arguments.
    pub arguments: String,
    /// Gemini-specific: thought signature that must be echoed back in history.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thought_signature: Option<String>,
}

/// Token usage from an LLM response.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens sent to the model.
    pub prompt_tokens: i64,
    /// Output tokens generated by the model.
    pub completion_tokens: i64,
    /// Tokens read from provider cache (e.g. Anthropic prompt caching, Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written to provider cache on this request.
    pub cache_creation_tokens: i64,
    /// Tokens used for reasoning/thinking (e.g. OpenAI reasoning_tokens, Anthropic thinking).
    pub thinking_tokens: i64,
    /// Why the model stopped: "end_turn", "max_tokens", "stop_sequence", etc.
    /// Empty string means unknown (provider didn't report it).
    pub stop_reason: String,
}

/// The LLM's response: either text, tool calls, or both.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Text content of the response (may be `None` if tool-calls only).
    pub content: Option<String>,
    /// Tool calls requested by the model.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage statistics.
    pub usage: TokenUsage,
}

/// Base64-encoded image data for multi-modal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// MIME type (e.g. "image/png", "image/jpeg").
    pub media_type: String,
    /// Base64-encoded image bytes.
    pub base64: String,
}

/// A single message in the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role: `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Text content (may be `None` for tool-call-only messages).
    pub content: Option<String>,
    /// Tool calls requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// ID of the tool call this message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Attached images (only used in-flight, not persisted to DB).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub images: Option<Vec<ImageData>>,
}

impl ChatMessage {
    /// Create a simple text message (convenience for the common case).
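    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let msg = ChatMessage::text("user", "Summarize the diff");
    /// assert_eq!(msg.role, "user");
    /// assert!(msg.tool_calls.is_none());
    /// ```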
    pub fn text(role: &str, content: &str) -> Self {
        Self {
            role: role.to_string(),
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
            images: None,
        }
    }
}

/// Tool definition sent to the LLM.
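///
/// A sketch of what a definition might look like; the schema shown is
/// illustrative, not the actual schema shipped for the `Read` tool:
///
/// ```ignore
/// let def = ToolDefinition {
///     name: "Read".to_string(),
///     description: "Read a file from disk".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": { "path": { "type": "string" } },
///         "required": ["path"]
///     }),
/// };
/// ```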
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g. `"Read"`, `"Bash"`).
    pub name: String,
    /// Human-readable description for the LLM.
    pub description: String,
    /// JSON Schema for the tool's parameters.
    pub parameters: serde_json::Value,
}

/// A discovered model from a provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"claude-3-5-sonnet-20241022"`).
    pub id: String,
    /// Provider/organization that owns the model.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}

/// Model capabilities queried from the provider API.
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilities {
    /// Maximum context window in tokens (input + output).
    pub context_window: Option<usize>,
    /// Maximum output tokens the model supports.
    pub max_output_tokens: Option<usize>,
}

/// Is this URL pointing to a local address?
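///
/// A behavior sketch, derived from the substring checks below:
///
/// ```ignore
/// assert!(is_localhost_url("https://localhost:1234/v1"));
/// assert!(is_localhost_url("http://127.0.0.1:11434"));
/// assert!(!is_localhost_url("https://api.example.com/v1"));
/// ```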
fn is_localhost_url(url: &str) -> bool {
    let lower = url.to_lowercase();
    lower.contains("://localhost") || lower.contains("://127.0.0.1") || lower.contains("://[::1]")
}

/// Build a reqwest client with proper proxy configuration.
///
/// - Reads HTTPS_PROXY / HTTP_PROXY from env
/// - Supports proxy auth via URL (http://user:pass@proxy:port)
/// - Supports separate PROXY_USER / PROXY_PASS env vars
/// - Bypasses proxy for localhost (LM Studio)
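///
/// A usage sketch; the endpoint URLs are assumptions, and proxy settings come
/// from the runtime environment, so no extra arguments are needed here:
///
/// ```ignore
/// // Local endpoint: the proxy is bypassed and, if KODA_ACCEPT_INVALID_CERTS is set,
/// // self-signed certificates are accepted for this local provider's client.
/// let local_client = build_http_client(Some("http://localhost:1234/v1"));
///
/// // Remote endpoint: proxy settings apply and TLS validation is always enforced.
/// let remote_client = build_http_client(Some("https://api.openai.com/v1"));
/// ```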
pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
    let mut builder = reqwest::Client::builder();

    let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
        .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
        .or_else(|| crate::runtime_env::get("https_proxy"))
        .or_else(|| crate::runtime_env::get("http_proxy"));

    if let Some(ref url) = proxy_url
        && !url.is_empty()
    {
        match reqwest::Proxy::all(url) {
            Ok(mut proxy) => {
                // Bypass proxy for local addresses
                proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));

                // If URL doesn't contain creds, check env vars
                if !url.contains('@') {
                    let user = crate::runtime_env::get("PROXY_USER");
                    let pass = crate::runtime_env::get("PROXY_PASS");
                    if let (Some(u), Some(p)) = (user, pass) {
                        proxy = proxy.basic_auth(&u, &p);
                        tracing::debug!("Using proxy with basic auth (credentials redacted)");
                    }
                }

                builder = builder.proxy(proxy);
                tracing::debug!("Using proxy: {}", redact_url_credentials(url));
            }
            Err(e) => {
                tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
            }
        }
    }

    // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
    // The env var is still required, but it's now scoped to local addresses.
    let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
        .map(|v| v == "1" || v == "true")
        .unwrap_or(false);
    let is_local = base_url.is_some_and(is_localhost_url);
    if wants_skip_tls && is_local {
        tracing::info!("TLS certificate validation disabled for local provider.");
        builder = builder.danger_accept_invalid_certs(true);
    } else if wants_skip_tls {
        tracing::warn!(
            "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
             TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
        );
    }

    builder.build().unwrap_or_else(|_| reqwest::Client::new())
}

/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
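///
/// The same transformation as a (non-compiled) doc example:
///
/// ```ignore
/// assert_eq!(
///     redact_url_credentials("http://user:pass@proxy:8080"),
///     "http://***:***@proxy:8080"
/// );
/// ```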
fn redact_url_credentials(url: &str) -> String {
    // Pattern: scheme://user:pass@host...
    if let Some(at_pos) = url.find('@')
        && let Some(scheme_end) = url.find("://")
    {
        let prefix = &url[..scheme_end + 3]; // "http://"
        let host_part = &url[at_pos..]; // "@proxy:8080/..."
        return format!("{prefix}***:***{host_part}");
    }
    url.to_string()
}

/// A streaming chunk from the LLM.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A text delta (partial content).
    TextDelta(String),
    /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
    ThinkingDelta(String),
    /// A single tool call whose arguments finished streaming.
    ///
    /// Emitted by providers that support per-block completion events (Anthropic
    /// `content_block_stop`). Enables eager execution of read-only tools while
    /// subsequent tool calls are still being streamed.
    ///
    /// Providers that don't support per-block events (OpenAI, Gemini) never
    /// emit this — they only emit `ToolCalls` at stream end.
    ToolCallReady(ToolCall),
    /// All tool calls from the response (batch, emitted at stream end).
    ///
    /// For Anthropic, this only contains tool calls NOT already emitted via
    /// `ToolCallReady`. For other providers, this contains all tool calls.
    ToolCalls(Vec<ToolCall>),
    /// Stream finished with usage info.
    Done(TokenUsage),
    /// The underlying HTTP connection was dropped before the stream completed.
    ///
    /// Distinct from `Done` (clean finish) and user-initiated cancellation
    /// (Ctrl+C). The partial response MUST be discarded — it is incomplete
    /// and storing it would corrupt the session history on resume.
    NetworkError(String),
}

/// Trait for LLM provider backends.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a chat completion request (non-streaming).
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<LlmResponse>;

    /// Send a streaming chat completion request.
    ///
    /// Returns a channel receiver that yields chunks as they arrive.
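    ///
    /// A consumption sketch; `provider` stands in for any `LlmProvider`
    /// implementation, and the helper functions are hypothetical, named only to
    /// show where each `StreamChunk` variant would be handled:
    ///
    /// ```ignore
    /// let mut rx = provider.chat_stream(&messages, &tools, &settings).await?;
    /// while let Some(chunk) = rx.recv().await {
    ///     match chunk {
    ///         StreamChunk::TextDelta(text) => print!("{text}"),
    ///         StreamChunk::ThinkingDelta(_) => {} // optionally surface reasoning
    ///         StreamChunk::ToolCallReady(call) => start_early(call), // hypothetical eager-execution hook
    ///         StreamChunk::ToolCalls(calls) => queue_all(calls),     // hypothetical batch handler
    ///         StreamChunk::Done(_usage) => break,
    ///         StreamChunk::NetworkError(err) => {
    ///             discard_partial(); // hypothetical: drop the incomplete response
    ///             return Err(anyhow::anyhow!(err));
    ///         }
    ///     }
    /// }
    /// ```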
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamChunk>>;

    /// List available models from the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Query model capabilities (context window, max output tokens) from the API.
    ///
    /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
    /// when the API doesn't expose them. Callers should fall back to the
    /// hardcoded lookup table for any `None` fields.
    async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
        Ok(ModelCapabilities::default())
    }

    /// Provider display name (for UI).
    fn provider_name(&self) -> &str;
}

// ── Provider factory ──────────────────────────────────────────

use crate::config::{KodaConfig, ProviderType};

/// Create an LLM provider from the given configuration.
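///
/// A call-site sketch; how `KodaConfig` is produced (CLI flags, config file,
/// environment) is outside this module, so the loader below is hypothetical:
///
/// ```ignore
/// let config: KodaConfig = load_config()?; // hypothetical loader
/// let provider = create_provider(&config);
/// tracing::info!("Using provider: {}", provider.provider_name());
/// ```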
pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
    let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
    match config.provider_type {
        ProviderType::Anthropic => {
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No ANTHROPIC_API_KEY set");
                String::new()
            });
            Box::new(anthropic::AnthropicProvider::new(
                key,
                Some(&config.base_url),
            ))
        }
        ProviderType::Gemini => {
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No GEMINI_API_KEY set");
                String::new()
            });
            Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
        }
        ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
        _ => Box::new(openai_compat::OpenAiCompatProvider::new(
            &config.base_url,
            api_key,
        )),
    }
}