//! aico/llm/client.rs — HTTP client for OpenAI-compatible chat-completion providers.
1use crate::exceptions::AicoError;
2use crate::llm::api_models::{ChatCompletionChunk, ChatCompletionRequest};
3use crate::models::Provider;
4use reqwest::Client as HttpClient;
5use std::env;
6use std::str::FromStr;
7
/// A parsed `provider/model[+key=value...]` model specification.
#[derive(Debug)]
struct ModelSpec {
    // Which backend serves this model (e.g. OpenAI, OpenRouter).
    provider: Provider,
    // Model identifier with the provider prefix stripped.
    model_id_short: String,
    // Extra JSON request-body parameters parsed from `+key=value` suffixes, if any.
    extra_params: Option<serde_json::Value>,
}
14
15impl FromStr for ModelSpec {
16    type Err = AicoError;
17
18    fn from_str(s: &str) -> Result<Self, Self::Err> {
19        let (base, params_part) = s.split_once('+').unwrap_or((s, ""));
20
21        let Some((provider_str, model_name)) = base.split_once('/') else {
22            return Err(AicoError::Configuration(format!(
23                "Invalid model format '{}'. Expected 'provider/model'.",
24                base
25            )));
26        };
27
28        let provider: Provider = provider_str.parse()?;
29        let mut extra_map = serde_json::Map::new();
30
31        // Defaults
32        if matches!(provider, Provider::OpenRouter) {
33            extra_map.insert("usage".into(), serde_json::json!({ "include": true }));
34        }
35
36        // Parse params
37        if !params_part.is_empty() {
38            for param in params_part.split('+') {
39                let (k, v) = param.split_once('=').unwrap_or((param, "true"));
40
41                // Attempt to parse as JSON (numbers/bools), fallback to string
42                let val = serde_json::from_str(v)
43                    .unwrap_or_else(|_| serde_json::Value::String(v.to_string()));
44
45                // Logic specific to provider can remain here or be extracted
46                if matches!(provider, Provider::OpenRouter) && k == "reasoning_effort" {
47                    extra_map.insert("reasoning".into(), serde_json::json!({ "effort": val }));
48                } else {
49                    extra_map.insert(k.to_string(), val);
50                }
51            }
52        }
53
54        Ok(Self {
55            provider,
56            model_id_short: model_name.to_string(),
57            extra_params: if extra_map.is_empty() {
58                None
59            } else {
60                Some(serde_json::Value::Object(extra_map))
61            },
62        })
63    }
64}
65
/// Minimal client for OpenAI-compatible `/chat/completions` endpoints.
#[derive(Debug)]
pub struct LlmClient {
    // Reusable HTTP client (built by `crate::utils::setup_http_client`).
    http: HttpClient,
    // Bearer token read from the provider's API-key env var.
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; overridable via env.
    base_url: String,
    // Model id sent in requests (provider prefix already stripped).
    pub model_id: String,
    // Extra JSON body parameters parsed from the model string, if any.
    extra_params: Option<serde_json::Value>,
}
74
75impl Provider {
76    fn base_url_env_var(&self) -> &'static str {
77        match self {
78            Provider::OpenAI => "OPENAI_BASE_URL",
79            Provider::OpenRouter => "OPENROUTER_BASE_URL",
80        }
81    }
82
83    fn default_base_url(&self) -> &'static str {
84        match self {
85            Provider::OpenAI => "https://api.openai.com/v1",
86            Provider::OpenRouter => "https://openrouter.ai/api/v1",
87        }
88    }
89
90    fn api_key_env_var(&self) -> &'static str {
91        match self {
92            Provider::OpenAI => "OPENAI_API_KEY",
93            Provider::OpenRouter => "OPENROUTER_API_KEY",
94        }
95    }
96}
97
98impl LlmClient {
99    pub fn new(full_model_string: &str) -> Result<Self, AicoError> {
100        Self::new_with_env(full_model_string, |k| env::var(k).ok())
101    }
102
103    pub fn new_with_env<F>(full_model_string: &str, env_get: F) -> Result<Self, AicoError>
104    where
105        F: Fn(&str) -> Option<String>,
106    {
107        let spec: ModelSpec = full_model_string.parse()?;
108
109        let api_key_var = spec.provider.api_key_env_var();
110        let api_key = env_get(api_key_var)
111            .ok_or_else(|| AicoError::Configuration(format!("{} is required.", api_key_var)))?;
112
113        let base_url = env_get(spec.provider.base_url_env_var())
114            .unwrap_or_else(|| spec.provider.default_base_url().to_string());
115
116        Ok(Self {
117            http: crate::utils::setup_http_client(),
118            api_key,
119            base_url,
120            model_id: spec.model_id_short,
121            extra_params: spec.extra_params,
122        })
123    }
124
125    pub fn base_url(&self) -> &str {
126        &self.base_url
127    }
128
129    pub fn get_extra_params(&self) -> Option<serde_json::Value> {
130        self.extra_params.clone()
131    }
132
133    /// Sends a streaming request and returns a channel or iterator of chunks.
134    /// For simplicity with 'minimal deps', we return the response and let the caller iterate.
135    pub async fn stream_chat(
136        &self,
137        req: ChatCompletionRequest,
138    ) -> Result<reqwest::Response, AicoError> {
139        let url = format!("{}/chat/completions", self.base_url);
140
141        let request_builder = self
142            .http
143            .post(&url)
144            .header("Authorization", format!("Bearer {}", self.api_key))
145            .header("Content-Type", "application/json")
146            .json(&req);
147
148        let response = request_builder
149            .send()
150            .await
151            .map_err(|e| AicoError::Provider(e.to_string()))?;
152
153        if !response.status().is_success() {
154            let status = response.status();
155            let text = response.text().await.unwrap_or_default();
156
157            let error_msg = if text.trim().is_empty() {
158                format!("API Error (Status: {}): [Empty Body]", status)
159            } else {
160                format!("API Error (Status: {}): {}", status, text)
161            };
162            return Err(AicoError::Provider(error_msg));
163        }
164
165        Ok(response)
166    }
167}
168
169/// Helper to parse an SSE line: "data: {json}"
170pub fn parse_sse_line(line: &str) -> Option<ChatCompletionChunk> {
171    let trimmed = line.trim();
172    if !trimmed.starts_with("data: ") {
173        return None;
174    }
175    let content = &trimmed[6..];
176    if content == "[DONE]" {
177        return None;
178    }
179    serde_json::from_str(content).ok()
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Fake environment lookup: only the two API-key variables are set.
    fn mock_env(key: &str) -> Option<String> {
        if key == "OPENAI_API_KEY" {
            Some("sk-test".to_string())
        } else if key == "OPENROUTER_API_KEY" {
            Some("sk-or-test".to_string())
        } else {
            None
        }
    }

    #[test]
    fn test_get_extra_params_openrouter_nesting() {
        let model = "openrouter/openai/o1+reasoning_effort=medium";
        let client = LlmClient::new_with_env(model, mock_env).unwrap();
        let params = client.get_extra_params().unwrap();

        // OpenRouter injects a usage default and nests reasoning_effort
        // under the "reasoning" object.
        assert_eq!(params["usage"]["include"], true);
        assert_eq!(params["reasoning"]["effort"], "medium");
        assert!(params.get("reasoning_effort").is_none());
    }

    #[test]
    fn test_get_extra_params_openai_flattened() {
        let model = "openai/o1+reasoning_effort=medium";
        let client = LlmClient::new_with_env(model, mock_env).unwrap();
        let params = client.get_extra_params().unwrap();

        // OpenAI keeps the param flat and gets no usage default.
        assert_eq!(params["reasoning_effort"], "medium");
        assert!(params.get("usage").is_none());
    }
}
215}