Skip to main content

koda_core/providers/
mod.rs

1//! LLM provider abstraction layer.
2//!
3//! Defines a common trait for all providers and re-exports the default.
4
5/// Anthropic Claude API provider.
6pub mod anthropic;
7/// Google Gemini API provider.
8pub mod gemini;
9/// OpenAI-compatible provider (LM Studio, Ollama, vLLM, OpenRouter, etc.).
10pub mod openai_compat;
11/// Streaming XML tag filter for think/reasoning tags.
12pub mod stream_tag_filter;
13/// Deprecated: use `stream_tag_filter` instead.
14pub mod think_tag_filter;
15
16/// Mock provider for deterministic testing.
17pub mod mock;
18
19use anyhow::Result;
20use async_trait::async_trait;
21use serde::{Deserialize, Serialize};
22
/// A tool call requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call ID (echoed back in tool results).
    pub id: String,
    /// Name of the tool to invoke.
    pub function_name: String,
    /// Raw JSON string of tool arguments.
    pub arguments: String,
    /// Gemini-specific: thought signature that must be echoed back in history.
    /// Omitted from serialized output when `None` (and defaults to `None` on
    /// deserialize) so non-Gemini payloads/history entries stay unchanged.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thought_signature: Option<String>,
}
36
/// Token usage from an LLM response.
///
/// `Default` yields all-zero counts and an empty `stop_reason`, used when a
/// provider doesn't report usage.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens sent to the model.
    pub prompt_tokens: i64,
    /// Output tokens generated by the model.
    pub completion_tokens: i64,
    /// Tokens read from provider cache (e.g. Anthropic prompt caching, Gemini cached content).
    pub cache_read_tokens: i64,
    /// Tokens written to provider cache on this request.
    pub cache_creation_tokens: i64,
    /// Tokens used for reasoning/thinking (e.g. OpenAI reasoning_tokens, Anthropic thinking).
    pub thinking_tokens: i64,
    /// Why the model stopped: "end_turn", "max_tokens", "stop_sequence", etc.
    /// Empty string means unknown (provider didn't report it).
    pub stop_reason: String,
}
54
/// The LLM's response: either text, tool calls, or both.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Text content of the response (may be `None` if tool-calls only).
    pub content: Option<String>,
    /// Tool calls requested by the model (empty when the response is text-only).
    pub tool_calls: Vec<ToolCall>,
    /// Token usage statistics reported by the provider.
    pub usage: TokenUsage,
}
65
/// Base64-encoded image data for multi-modal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// MIME type (e.g. "image/png", "image/jpeg").
    pub media_type: String,
    /// Base64-encoded image bytes (no data-URI prefix).
    pub base64: String,
}
74
/// A single message in the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role: `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Text content (may be `None` for tool-call-only messages).
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    /// Tool calls requested by the assistant.
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    /// ID of the tool call this message responds to (set on `"tool"` role messages).
    pub tool_call_id: Option<String>,
    /// Attached images (only used in-flight, not persisted to DB — hence the
    /// skip-on-serialize and default-on-deserialize).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub images: Option<Vec<ImageData>>,
}
92
93impl ChatMessage {
94    /// Create a simple text message (convenience for the common case).
95    pub fn text(role: &str, content: &str) -> Self {
96        Self {
97            role: role.to_string(),
98            content: Some(content.to_string()),
99            tool_calls: None,
100            tool_call_id: None,
101            images: None,
102        }
103    }
104}
105
/// Tool definition sent to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g. `"Read"`, `"Bash"`).
    pub name: String,
    /// Human-readable description for the LLM.
    pub description: String,
    /// JSON Schema for the tool's parameters (passed through verbatim).
    pub parameters: serde_json::Value,
}
116
/// A discovered model from a provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"claude-3-5-sonnet-20241022"`).
    pub id: String,
    /// Provider/organization that owns the model.
    // NOTE(review): marked dead_code — apparently populated but not read
    // anywhere yet; confirm before removing the allow.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}
126
/// Model capabilities queried from the provider API.
///
/// Fields are `None` when the provider doesn't expose the value; `Default`
/// yields all-`None` (see `LlmProvider::model_capabilities`).
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilities {
    /// Maximum context window in tokens (input + output).
    pub context_window: Option<usize>,
    /// Maximum output tokens the model supports.
    pub max_output_tokens: Option<usize>,
}
135
/// Is this URL pointing to a local address?
///
/// Checks the URL's *host* component (case-insensitive) rather than doing a
/// substring scan, so `https://localhost.evil.com` is NOT considered local —
/// important because this predicate gates the TLS-bypass in
/// `build_http_client`. Recognized hosts: `localhost`, `127.0.0.1`, `[::1]`.
fn is_localhost_url(url: &str) -> bool {
    let lower = url.to_lowercase();
    // No scheme separator → not a URL we can classify; treat as non-local.
    let Some((_, rest)) = lower.split_once("://") else {
        return false;
    };
    // Authority ends at the first path/query/fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    // Drop any userinfo ("user:pass@host:port" → "host:port").
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, h)| h);
    // IPv6 literals are bracketed, e.g. "[::1]:8080".
    if let Some(inner) = host_port.strip_prefix('[') {
        return inner.split(']').next() == Some("::1");
    }
    // Strip an optional ":port".
    let host = host_port.split(':').next().unwrap_or(host_port);
    matches!(host, "localhost" | "127.0.0.1")
}
141
142/// Build a reqwest client with proper proxy configuration.
143///
144/// - Reads HTTPS_PROXY / HTTP_PROXY from env
145/// - Supports proxy auth via URL (http://user:pass@proxy:port)
146/// - Supports separate PROXY_USER / PROXY_PASS env vars
147/// - Bypasses proxy for localhost (LM Studio)
148pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
149    let mut builder = reqwest::Client::builder();
150
151    let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
152        .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
153        .or_else(|| crate::runtime_env::get("https_proxy"))
154        .or_else(|| crate::runtime_env::get("http_proxy"));
155
156    if let Some(ref url) = proxy_url
157        && !url.is_empty()
158    {
159        match reqwest::Proxy::all(url) {
160            Ok(mut proxy) => {
161                // Bypass proxy for local addresses
162                proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));
163
164                // If URL doesn't contain creds, check env vars
165                if !url.contains('@') {
166                    let user = crate::runtime_env::get("PROXY_USER");
167                    let pass = crate::runtime_env::get("PROXY_PASS");
168                    if let (Some(u), Some(p)) = (user, pass) {
169                        proxy = proxy.basic_auth(&u, &p);
170                        tracing::debug!("Using proxy with basic auth (credentials redacted)");
171                    }
172                }
173
174                builder = builder.proxy(proxy);
175                tracing::debug!("Using proxy: {}", redact_url_credentials(url));
176            }
177            Err(e) => {
178                tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
179            }
180        }
181    }
182
183    // Accept self-signed certs only for localhost (LM Studio, Ollama, vLLM).
184    // The env var is still required, but it's now scoped to local addresses.
185    let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
186        .map(|v| v == "1" || v == "true")
187        .unwrap_or(false);
188    let is_local = base_url.is_some_and(is_localhost_url);
189    if wants_skip_tls && is_local {
190        tracing::info!("TLS certificate validation disabled for local provider.");
191        builder = builder.danger_accept_invalid_certs(true);
192    } else if wants_skip_tls {
193        tracing::warn!(
194            "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
195             TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
196        );
197    }
198
199    builder.build().unwrap_or_else(|_| reqwest::Client::new())
200}
201
/// Redact embedded credentials from a URL.
///
/// `http://user:pass@proxy:8080` → `http://***:***@proxy:8080`
///
/// Only an `@` inside the *authority* component triggers redaction; an `@`
/// in the path or query (e.g. `http://host/file@v2`) leaves the URL intact
/// instead of mangling the host.
fn redact_url_credentials(url: &str) -> String {
    // Without a scheme separator there is no authority to redact.
    let Some(scheme_end) = url.find("://") else {
        return url.to_string();
    };
    let auth_start = scheme_end + 3; // first byte after "://"
    let rest = &url[auth_start..];
    // Authority ends at the first path/query/fragment delimiter.
    let auth_len = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    // rfind: RFC 3986 userinfo ends at the LAST '@' in the authority.
    match rest[..auth_len].rfind('@') {
        Some(at) => format!("{}***:***{}", &url[..auth_start], &url[auth_start + at..]),
        None => url.to_string(),
    }
}
216
/// A streaming chunk from the LLM.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A text delta (partial content).
    TextDelta(String),
    /// A thinking/reasoning delta from native API (Anthropic extended thinking, OpenAI reasoning).
    ThinkingDelta(String),
    /// A tool call was returned (streaming ends, need full response).
    ToolCalls(Vec<ToolCall>),
    /// Stream finished; carries the final usage info for the request.
    Done(TokenUsage),
}
229
/// Trait for LLM provider backends.
///
/// Object-safe (used as `Box<dyn LlmProvider>` by `create_provider`);
/// `Send + Sync` so a provider can be shared across async tasks.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a chat completion request (non-streaming).
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<LlmResponse>;

    /// Send a streaming chat completion request.
    /// Returns a channel receiver that yields chunks as they arrive.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamChunk>>;

    /// List available models from the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Query model capabilities (context window, max output tokens) from the API.
    ///
    /// Returns `Ok(caps)` with whatever the provider reports. Fields are `None`
    /// when the API doesn't expose them. Callers should fall back to the
    /// hardcoded lookup table for any `None` fields.
    ///
    /// Default implementation reports nothing (all-`None`).
    async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
        Ok(ModelCapabilities::default())
    }

    /// Provider display name (for UI).
    fn provider_name(&self) -> &str;
}
265
266// ── Provider factory ──────────────────────────────────────────
267
268use crate::config::{KodaConfig, ProviderType};
269
270/// Create an LLM provider from the given configuration.
271pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
272    let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
273    match config.provider_type {
274        ProviderType::Anthropic => {
275            let key = api_key.unwrap_or_else(|| {
276                tracing::warn!("No ANTHROPIC_API_KEY set");
277                String::new()
278            });
279            Box::new(anthropic::AnthropicProvider::new(
280                key,
281                Some(&config.base_url),
282            ))
283        }
284        ProviderType::Gemini => {
285            let key = api_key.unwrap_or_else(|| {
286                tracing::warn!("No GEMINI_API_KEY set");
287                String::new()
288            });
289            Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
290        }
291        ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
292        _ => Box::new(openai_compat::OpenAiCompatProvider::new(
293            &config.base_url,
294            api_key,
295        )),
296    }
297}