aico/llm/client.rs
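
//! Minimal OpenAI-compatible chat client: parses `provider/model` strings,
//! sends streaming chat completion requests, and decodes SSE chunks for the
//! `openai` and `openrouter` providers.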

use crate::exceptions::AicoError;
use crate::llm::api_models::{ChatCompletionChunk, ChatCompletionRequest};
use reqwest::Client as HttpClient;
use std::env;
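
/// Provider-specific settings derived from a full model string.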
#[derive(Debug)]
struct ModelSpec {
    api_key_env: &'static str,
    default_base_url: &'static str,
    model_id_short: String,
    extra_params: Option<serde_json::Value>,
}

impl ModelSpec {
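    /// Parses a model string of the form `provider/model[+key=value[+flag]...]`,
    /// e.g. `openai/o1+reasoning_effort=medium` or
    /// `openrouter/openai/o1+reasoning_effort=medium`. Parameter values are
    /// parsed as JSON where possible (falling back to plain strings), and a
    /// bare `+flag` becomes boolean `true`.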
    fn parse(full_str: &str) -> Result<Self, AicoError> {
        let (base_model, params_part) = full_str.split_once('+').unwrap_or((full_str, ""));
        let (provider, model_name) = base_model.split_once('/').ok_or_else(|| {
            AicoError::Configuration(format!(
                "Invalid model format '{}'. Expected 'provider/model'.",
                base_model
            ))
        })?;

        let (api_key_env, default_base_url) = match provider {
            "openrouter" => ("OPENROUTER_API_KEY", "https://openrouter.ai/api/v1"),
            "openai" => ("OPENAI_API_KEY", "https://api.openai.com/v1"),
            _ => {
                return Err(AicoError::Configuration(format!(
                    "Unrecognized provider prefix in '{}'. Use 'openai/' or 'openrouter/'.",
                    full_str
                )));
            }
        };

        let mut extra_map: Option<serde_json::Map<String, serde_json::Value>> = None;
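
        // Ask OpenRouter to include token usage accounting in its responses.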
        if provider == "openrouter" {
            extra_map
                .get_or_insert_default()
                .insert("usage".to_string(), serde_json::json!({ "include": true }));
        }

        if !params_part.is_empty() {
            for param in params_part.split('+') {
                let m = extra_map.get_or_insert_default();
                if let Some((k, v)) = param.split_once('=') {
                    let val = serde_json::from_str::<serde_json::Value>(v)
                        .unwrap_or_else(|_| serde_json::Value::String(v.to_string()));
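
                    // OpenRouter nests reasoning settings under a "reasoning"
                    // object, whereas the OpenAI API takes a flat
                    // "reasoning_effort" field, so remap the key here.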
                    if provider == "openrouter" && k == "reasoning_effort" {
                        m.insert(
                            "reasoning".to_string(),
                            serde_json::json!({ "effort": val }),
                        );
                    } else {
                        m.insert(k.to_string(), val);
                    }
                } else {
                    m.insert(param.to_string(), serde_json::Value::Bool(true));
                }
            }
        }

        let extra_params = extra_map.map(serde_json::Value::Object);

        Ok(Self {
            api_key_env,
            default_base_url,
            model_id_short: model_name.to_string(),
            extra_params,
        })
    }
}

#[derive(Debug)]
pub struct LlmClient {
    http: HttpClient,
    api_key: String,
    base_url: String,
    pub model_id: String,
    extra_params: Option<serde_json::Value>,
}

impl LlmClient {
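    /// Builds a client from a full model string (see `ModelSpec::parse`).
    /// Reads the provider's API key from the environment; if `OPENAI_BASE_URL`
    /// is set it overrides the default endpoint for either provider.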
    pub fn new(full_model_string: &str) -> Result<Self, AicoError> {
        let spec = ModelSpec::parse(full_model_string)?;

        let api_key = env::var(spec.api_key_env).map_err(|_| {
            AicoError::Configuration(format!(
                "Environment variable {} is required.",
                spec.api_key_env
            ))
        })?;
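
        // An explicit OPENAI_BASE_URL takes precedence over the provider
        // default, e.g. to target a proxy or another OpenAI-compatible server.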
        let base_url =
            env::var("OPENAI_BASE_URL").unwrap_or_else(|_| spec.default_base_url.to_string());

        Ok(Self {
            http: crate::utils::setup_http_client(),
            api_key,
            base_url,
            model_id: spec.model_id_short,
            extra_params: spec.extra_params,
        })
    }

    pub fn get_extra_params(&self) -> Option<serde_json::Value> {
        self.extra_params.clone()
    }

    /// Sends a streaming chat completion request and returns the raw HTTP
    /// response; the caller iterates over the SSE body (e.g. with
    /// `parse_sse_line`). Returning the response directly keeps this client
    /// free of extra streaming dependencies.
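    ///
    /// A consumption sketch (illustrative only; assumes `futures_util` for
    /// `StreamExt`, reqwest's `stream` feature for `bytes_stream`, and that
    /// the crate's modules are publicly reachable under these paths):
    ///
    /// ```no_run
    /// # use futures_util::StreamExt;
    /// # async fn demo(
    /// #     client: &aico::llm::client::LlmClient,
    /// #     req: aico::llm::api_models::ChatCompletionRequest,
    /// # ) -> Result<(), aico::exceptions::AicoError> {
    /// let response = client.stream_chat(req).await?;
    /// let mut stream = response.bytes_stream();
    /// let mut buf = String::new();
    /// while let Some(bytes) = stream.next().await {
    ///     buf.push_str(&String::from_utf8_lossy(&bytes.expect("stream read failed")));
    ///     // Drain complete lines and decode each one as an SSE event.
    ///     while let Some(pos) = buf.find('\n') {
    ///         let line: String = buf.drain(..=pos).collect();
    ///         if let Some(_chunk) = aico::llm::client::parse_sse_line(&line) {
    ///             // handle the decoded ChatCompletionChunk here
    ///         }
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```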
    pub async fn stream_chat(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<reqwest::Response, AicoError> {
        let url = format!("{}/chat/completions", self.base_url);

        let request_builder = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&req);

        let response = request_builder
            .send()
            .await
            .map_err(|e| AicoError::Provider(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            let error_msg = if text.trim().is_empty() {
                format!("API Error (Status: {}): [Empty Body]", status)
            } else {
                format!("API Error (Status: {}): {}", status, text)
            };
            return Err(AicoError::Provider(error_msg));
        }

        Ok(response)
    }
}

/// Helper to parse an SSE line: "data: {json}"
pub fn parse_sse_line(line: &str) -> Option<ChatCompletionChunk> {
    let content = line.trim().strip_prefix("data: ")?;
    if content == "[DONE]" {
        return None;
    }
    serde_json::from_str(content).ok()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_extra_params_openrouter_nesting() {
        unsafe { std::env::set_var("OPENROUTER_API_KEY", "sk-test") };
        let client = LlmClient::new("openrouter/openai/o1+reasoning_effort=medium").unwrap();
        let params = client.get_extra_params().unwrap();

        assert_eq!(params["usage"]["include"], true);
        assert_eq!(params["reasoning"]["effort"], "medium");
        assert!(params.get("reasoning_effort").is_none());
    }

    #[test]
    fn test_get_extra_params_openai_flattened() {
        unsafe { std::env::set_var("OPENAI_API_KEY", "sk-test") };
        let client = LlmClient::new("openai/o1+reasoning_effort=medium").unwrap();
        let params = client.get_extra_params().unwrap();

        assert_eq!(params["reasoning_effort"], "medium");
        assert!(params.get("usage").is_none());
    }
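
    // Edge cases for parse_sse_line; asserts only behavior visible in this
    // file (non-data lines, the [DONE] sentinel, undecodable payloads).
    #[test]
    fn test_parse_sse_line_ignores_non_data_and_done() {
        assert!(parse_sse_line("event: ping").is_none());
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line("data: not json").is_none());
    }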
}