pub mod anthropic;
pub mod gemini;
pub mod openai_compat;
pub mod stream_collector;
pub mod stream_tag_filter;
#[cfg(any(test, feature = "test-support"))]
pub mod mock;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single tool/function invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back via `ChatMessage::tool_call_id`.
    pub id: String,
    /// Name of the function the model wants invoked.
    pub function_name: String,
    /// Raw argument payload as produced by the model (presumably JSON-encoded
    /// — confirm against each provider module).
    pub arguments: String,
    /// Opaque signature some providers attach to tool calls made during
    /// "thinking"; omitted from serialization when absent and defaults to
    /// `None` on deserialization (backward compatible with older data).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thought_signature: Option<String>,
}
/// Token accounting for one model response. `Default` yields all-zero
/// counters and an empty `stop_reason`.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: i64,
    /// Tokens generated in the completion/output.
    pub completion_tokens: i64,
    /// Prompt tokens served from the provider's cache.
    pub cache_read_tokens: i64,
    /// Prompt tokens written into the provider's cache.
    pub cache_creation_tokens: i64,
    /// Tokens spent on hidden reasoning ("thinking"), where the provider reports them.
    pub thinking_tokens: i64,
    /// Provider-reported stop reason; empty when unknown.
    pub stop_reason: String,
}
/// A complete (non-streaming) model response.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Assistant text, if any was produced.
    pub content: Option<String>,
    /// Tool calls requested by the model; empty when none.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage and stop reason for this response.
    pub usage: TokenUsage,
}
/// An inline image attachment carried on a chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// Media type of the image (presumably a MIME type such as
    /// `image/png` — confirm against the provider encoders).
    pub media_type: String,
    /// Base64-encoded image bytes.
    pub base64: String,
}
/// One turn in a conversation, in a provider-neutral shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role (e.g. "user"/"assistant" per the tests below; other
    /// values such as tool roles are provider-dependent — confirm).
    pub role: String,
    /// Text content; `None` for messages that carry only tool calls.
    pub content: Option<String>,
    /// Tool calls issued by an assistant message, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-result messages: the `ToolCall::id` being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Inline image attachments; `default` keeps deserialization of older
    /// serialized messages (without this field) working.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub images: Option<Vec<ImageData>>,
}
impl ChatMessage {
pub fn text(role: &str, content: &str) -> Self {
Self {
role: role.to_string(),
content: Some(content.to_string()),
tool_calls: None,
tool_call_id: None,
images: None,
}
}
}
/// Schema of a tool exposed to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name the model uses to reference it.
    pub name: String,
    /// Human/model-readable description of what the tool does.
    pub description: String,
    /// Parameter schema as free-form JSON (presumably JSON Schema —
    /// confirm against the provider request builders).
    pub parameters: serde_json::Value,
}
/// A model entry as returned by `LlmProvider::list_models`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (the string passed back when selecting a model).
    pub id: String,
    /// Owning organization, when the provider reports one; currently
    /// unread anywhere visible — hence the dead_code allowance.
    #[allow(dead_code)]
    pub owned_by: Option<String>,
}
/// Per-model limits; `Default` (both `None`) means "unknown".
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilities {
    /// Maximum total context size in tokens, when known.
    pub context_window: Option<usize>,
    /// Maximum output tokens per response, when known.
    pub max_output_tokens: Option<usize>,
}
/// Returns true when `url`'s host component is exactly a loopback host
/// (`localhost`, `127.0.0.1`, or `[::1]`), case-insensitively.
///
/// The host is extracted structurally — scheme, optional userinfo, optional
/// port, and path/query/fragment are all stripped — instead of using a
/// substring search. The previous `contains("://localhost")` check also
/// matched spoof hosts like `localhost.evil.com` or `127.0.0.1.evil.com`,
/// which matters because this result gates the TLS-bypass escape hatch in
/// `build_http_client`.
fn is_localhost_url(url: &str) -> bool {
    let lower = url.to_lowercase();
    // No scheme separator: not a URL we can classify as local.
    let Some((_, rest)) = lower.split_once("://") else {
        return false;
    };
    // The authority component ends at the first path/query/fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    // Drop any userinfo ("user:pass@..."): the host starts after the last '@'.
    let hostport = authority.rsplit('@').next().unwrap_or(authority);
    let host = if let Some(v6) = hostport.strip_prefix('[') {
        // Bracketed IPv6 literal: host runs up to the closing ']'.
        match v6.split_once(']') {
            Some((h, _)) => h,
            None => return false, // unterminated bracket — malformed
        }
    } else {
        // Regular host with an optional ":port" suffix.
        hostport.split(':').next().unwrap_or(hostport)
    };
    matches!(host, "localhost" | "127.0.0.1" | "::1")
}
/// Builds the `reqwest` HTTP client, honoring proxy environment variables
/// and a localhost-only TLS-bypass flag.
///
/// Proxy: the first set value among `HTTPS_PROXY`, `HTTP_PROXY`,
/// `https_proxy`, `http_proxy` (read via `crate::runtime_env`) is used, with
/// loopback hosts always excluded from proxying. If the proxy URL carries no
/// inline credentials, `PROXY_USER`/`PROXY_PASS` (when both set) supply
/// basic auth.
///
/// TLS: `KODA_ACCEPT_INVALID_CERTS=1|true` disables certificate validation,
/// but only when `base_url` is a localhost URL; otherwise the flag is logged
/// and ignored.
///
/// Never fails: if the configured builder cannot build, falls back to a
/// default `reqwest::Client`.
pub fn build_http_client(base_url: Option<&str>) -> reqwest::Client {
    let mut builder = reqwest::Client::builder();
    // First match wins: uppercase HTTPS/HTTP first, then lowercase variants.
    let proxy_url = crate::runtime_env::get("HTTPS_PROXY")
        .or_else(|| crate::runtime_env::get("HTTP_PROXY"))
        .or_else(|| crate::runtime_env::get("https_proxy"))
        .or_else(|| crate::runtime_env::get("http_proxy"));
    if let Some(ref url) = proxy_url
        && !url.is_empty()
    {
        match reqwest::Proxy::all(url) {
            Ok(mut proxy) => {
                // Never route loopback traffic (e.g. local model servers) through the proxy.
                proxy = proxy.no_proxy(reqwest::NoProxy::from_string("localhost,127.0.0.1,::1"));
                // Only fall back to env credentials when the URL itself has
                // no inline "user:pass@" part ('@' used as the heuristic).
                if !url.contains('@') {
                    let user = crate::runtime_env::get("PROXY_USER");
                    let pass = crate::runtime_env::get("PROXY_PASS");
                    if let (Some(u), Some(p)) = (user, pass) {
                        proxy = proxy.basic_auth(&u, &p);
                        tracing::debug!("Using proxy with basic auth (credentials redacted)");
                    }
                }
                builder = builder.proxy(proxy);
                tracing::debug!("Using proxy: {}", redact_url_credentials(url));
            }
            Err(e) => {
                // A malformed proxy URL is non-fatal: warn and continue without a proxy.
                tracing::warn!("Invalid proxy URL '{}': {e}", redact_url_credentials(url));
            }
        }
    }
    // TLS bypass is opt-in via env var AND restricted to local providers.
    let wants_skip_tls = crate::runtime_env::get("KODA_ACCEPT_INVALID_CERTS")
        .map(|v| v == "1" || v == "true")
        .unwrap_or(false);
    let is_local = base_url.is_some_and(is_localhost_url);
    if wants_skip_tls && is_local {
        tracing::info!("TLS certificate validation disabled for local provider.");
        builder = builder.danger_accept_invalid_certs(true);
    } else if wants_skip_tls {
        tracing::warn!(
            "KODA_ACCEPT_INVALID_CERTS is set but provider URL is not localhost — ignoring. \
            TLS bypass is only allowed for local providers (localhost/127.0.0.1)."
        );
    }
    // Last resort: a default client rather than propagating a builder error.
    builder.build().unwrap_or_else(|_| reqwest::Client::new())
}
/// Redacts the userinfo (`user:password@`) portion of a URL for safe logging.
///
/// Only an `@` inside the authority component (between `://` and the first
/// `/`) is treated as a credentials delimiter; an `@` in the path or query
/// (e.g. `https://host/path@thing`) leaves the URL untouched. The previous
/// implementation keyed on the first `@` anywhere in the string and could
/// mangle such URLs. URLs without a scheme or without credentials are
/// returned unchanged.
fn redact_url_credentials(url: &str) -> String {
    if let Some(scheme_end) = url.find("://") {
        let after_scheme = &url[scheme_end + 3..];
        // The authority runs up to the first '/' (or the end of the string).
        let authority_len = after_scheme.find('/').unwrap_or(after_scheme.len());
        // rfind: userinfo may itself contain '@'; the host starts after the last one.
        if let Some(at) = after_scheme[..authority_len].rfind('@') {
            let prefix = &url[..scheme_end + 3];
            // `&after_scheme[at..]` starts at the '@' and keeps host, port, and path.
            return format!("{prefix}***:***{}", &after_scheme[at..]);
        }
    }
    url.to_string()
}
/// Incremental events produced while streaming a model response.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of visible assistant text.
    TextDelta(String),
    /// A fragment of hidden reasoning ("thinking") text.
    ThinkingDelta(String),
    /// A single tool call whose arguments are fully received.
    ToolCallReady(ToolCall),
    /// A batch of tool calls (NOTE(review): per-call vs. batched emission
    /// presumably depends on the provider — confirm in the stream collector).
    ToolCalls(Vec<ToolCall>),
    /// Terminal event carrying the final token usage for the turn.
    Done(TokenUsage),
    /// Transport-level failure; payload is the error description.
    NetworkError(String),
}
/// Common interface implemented by every model backend
/// (Anthropic, Gemini, OpenAI-compatible, and the test mock).
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends the full conversation and returns a complete (non-streaming)
    /// response.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<LlmResponse>;
    /// Like [`chat`](Self::chat), but returns a collector over the
    /// provider's SSE stream instead of a finished response.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        settings: &crate::config::ModelSettings,
    ) -> Result<stream_collector::SseCollector>;
    /// Lists the models available from this provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
    /// Reports per-model limits. Default implementation returns all fields
    /// unset ("unknown"); providers override when they have real data.
    async fn model_capabilities(&self, _model: &str) -> Result<ModelCapabilities> {
        Ok(ModelCapabilities::default())
    }
    /// Short identifier for this provider (presumably for logs/diagnostics
    /// — confirm at call sites).
    fn provider_name(&self) -> &str;
}
use crate::config::{KodaConfig, ProviderType};
/// Constructs the provider selected by `config.provider_type`, reading its
/// API key from the environment variable named by `env_key_name()`.
///
/// A missing key for Anthropic/Gemini logs a warning and proceeds with an
/// empty key (requests then fail at the API rather than at startup). Any
/// provider type not matched explicitly falls through to the
/// OpenAI-compatible client, which treats the key as optional.
pub fn create_provider(config: &KodaConfig) -> Box<dyn LlmProvider> {
    let api_key = crate::runtime_env::get(config.provider_type.env_key_name());
    match config.provider_type {
        ProviderType::Anthropic => {
            // Warn rather than abort so keyless local/dev flows can still start.
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No ANTHROPIC_API_KEY set");
                String::new()
            });
            Box::new(anthropic::AnthropicProvider::new(
                key,
                Some(&config.base_url),
            ))
        }
        ProviderType::Gemini => {
            let key = api_key.unwrap_or_else(|| {
                tracing::warn!("No GEMINI_API_KEY set");
                String::new()
            });
            Box::new(gemini::GeminiProvider::new(key, Some(&config.base_url)))
        }
        // The mock provider only exists in test builds / with "test-support".
        #[cfg(any(test, feature = "test-support"))]
        ProviderType::Mock => Box::new(mock::MockProvider::from_env()),
        // Every remaining provider type speaks the OpenAI-compatible protocol.
        _ => Box::new(openai_compat::OpenAiCompatProvider::new(
            &config.base_url,
            api_key,
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_localhost_url_localhost() {
        // Hostname "localhost" must match regardless of case, port, or path.
        for url in ["http://localhost:1234/v1", "HTTP://LOCALHOST:11434/api"] {
            assert!(is_localhost_url(url), "expected localhost match: {url}");
        }
    }

    #[test]
    fn test_is_localhost_url_127() {
        assert!(is_localhost_url("http://127.0.0.1:8000/v1"));
    }

    #[test]
    fn test_is_localhost_url_ipv6() {
        assert!(is_localhost_url("http://[::1]:1234/v1"));
    }

    #[test]
    fn test_is_localhost_url_remote() {
        // Real provider endpoints must never be classified as local.
        for url in ["https://api.openai.com/v1", "https://api.anthropic.com/v1"] {
            assert!(!is_localhost_url(url), "expected remote: {url}");
        }
    }

    #[test]
    fn test_redact_with_credentials() {
        let redacted = redact_url_credentials("http://user:secret@proxy.corp.com:8080");
        assert!(
            !redacted.contains("secret"),
            "credentials should be redacted: {redacted}"
        );
        assert!(
            redacted.contains("***:***"),
            "should have redacted placeholder: {redacted}"
        );
        assert!(
            redacted.contains("proxy.corp.com"),
            "host should be preserved: {redacted}"
        );
    }

    #[test]
    fn test_redact_without_credentials() {
        // No userinfo present: the URL passes through untouched.
        let plain = "https://proxy.corp.com:8080";
        assert_eq!(redact_url_credentials(plain), plain);
    }

    #[test]
    fn test_redact_empty_url() {
        assert_eq!(redact_url_credentials(""), "");
    }

    #[test]
    fn test_chat_message_text_builder() {
        let msg = ChatMessage::text("user", "hello world");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content.as_deref(), Some("hello world"));
        // All optional fields start unset.
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.images.is_none());
    }

    #[test]
    fn test_chat_message_text_assistant() {
        let reply = ChatMessage::text("assistant", "I can help with that.");
        assert_eq!(reply.role, "assistant");
        assert_eq!(reply.content.as_deref(), Some("I can help with that."));
    }

    #[test]
    fn test_token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens),
            (0, 0),
            "default counters should be zero"
        );
        assert!(
            usage.stop_reason.is_empty(),
            "default stop_reason should be empty"
        );
    }
}