1use crate::exceptions::AicoError;
2use crate::llm::api_models::{ChatCompletionChunk, ChatCompletionRequest};
3use crate::models::Provider;
4use reqwest::Client as HttpClient;
5use std::env;
6use std::str::FromStr;
7
/// Parsed form of a model string such as `provider/model+key=value`.
#[derive(Debug)]
struct ModelSpec {
    /// Provider parsed from the text before the first `/`.
    provider: Provider,
    /// Model name: the text after the first `/`, up to (not including) the first `+`.
    model_id_short: String,
    /// Extra request parameters as a JSON object, or `None` when there are none.
    extra_params: Option<serde_json::Value>,
}
14
15impl FromStr for ModelSpec {
16 type Err = AicoError;
17
18 fn from_str(s: &str) -> Result<Self, Self::Err> {
19 let (base, params_part) = s.split_once('+').unwrap_or((s, ""));
20
21 let Some((provider_str, model_name)) = base.split_once('/') else {
22 return Err(AicoError::Configuration(format!(
23 "Invalid model format '{}'. Expected 'provider/model'.",
24 base
25 )));
26 };
27
28 let provider: Provider = provider_str.parse()?;
29 let mut extra_map = serde_json::Map::new();
30
31 if matches!(provider, Provider::OpenRouter) {
33 extra_map.insert("usage".into(), serde_json::json!({ "include": true }));
34 }
35
36 if !params_part.is_empty() {
38 for param in params_part.split('+') {
39 let (k, v) = param.split_once('=').unwrap_or((param, "true"));
40
41 let val = serde_json::from_str(v)
43 .unwrap_or_else(|_| serde_json::Value::String(v.to_string()));
44
45 if matches!(provider, Provider::OpenRouter) && k == "reasoning_effort" {
47 extra_map.insert("reasoning".into(), serde_json::json!({ "effort": val }));
48 } else {
49 extra_map.insert(k.to_string(), val);
50 }
51 }
52 }
53
54 Ok(Self {
55 provider,
56 model_id_short: model_name.to_string(),
57 extra_params: if extra_map.is_empty() {
58 None
59 } else {
60 Some(serde_json::Value::Object(extra_map))
61 },
62 })
63 }
64}
65
/// Client holding the HTTP handle, credentials and model settings used to call
/// a provider's chat-completions endpoint.
#[derive(Debug)]
pub struct LlmClient {
    /// HTTP client used to send requests.
    http: HttpClient,
    /// API key sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// Base URL to which `/chat/completions` is appended.
    base_url: String,
    /// Model name with the provider prefix and any `+` parameters removed.
    pub model_id: String,
    /// Extra parameters parsed from the model string, if any.
    extra_params: Option<serde_json::Value>,
}
74
75impl Provider {
76 fn base_url_env_var(&self) -> &'static str {
77 match self {
78 Provider::OpenAI => "OPENAI_BASE_URL",
79 Provider::OpenRouter => "OPENROUTER_BASE_URL",
80 }
81 }
82
83 fn default_base_url(&self) -> &'static str {
84 match self {
85 Provider::OpenAI => "https://api.openai.com/v1",
86 Provider::OpenRouter => "https://openrouter.ai/api/v1",
87 }
88 }
89
90 fn api_key_env_var(&self) -> &'static str {
91 match self {
92 Provider::OpenAI => "OPENAI_API_KEY",
93 Provider::OpenRouter => "OPENROUTER_API_KEY",
94 }
95 }
96}
97
98impl LlmClient {
99 pub fn new(full_model_string: &str) -> Result<Self, AicoError> {
100 Self::new_with_env(full_model_string, |k| env::var(k).ok())
101 }
102
103 pub fn new_with_env<F>(full_model_string: &str, env_get: F) -> Result<Self, AicoError>
104 where
105 F: Fn(&str) -> Option<String>,
106 {
107 let spec: ModelSpec = full_model_string.parse()?;
108
109 let api_key_var = spec.provider.api_key_env_var();
110 let api_key = env_get(api_key_var)
111 .ok_or_else(|| AicoError::Configuration(format!("{} is required.", api_key_var)))?;
112
113 let base_url = env_get(spec.provider.base_url_env_var())
114 .unwrap_or_else(|| spec.provider.default_base_url().to_string());
115
116 Ok(Self {
117 http: crate::utils::setup_http_client(),
118 api_key,
119 base_url,
120 model_id: spec.model_id_short,
121 extra_params: spec.extra_params,
122 })
123 }
124
125 pub fn base_url(&self) -> &str {
126 &self.base_url
127 }
128
129 pub fn get_extra_params(&self) -> Option<serde_json::Value> {
130 self.extra_params.clone()
131 }
132
133 pub async fn stream_chat(
136 &self,
137 req: ChatCompletionRequest,
138 ) -> Result<reqwest::Response, AicoError> {
139 let url = format!("{}/chat/completions", self.base_url);
140
141 let request_builder = self
142 .http
143 .post(&url)
144 .header("Authorization", format!("Bearer {}", self.api_key))
145 .header("Content-Type", "application/json")
146 .json(&req);
147
148 let response = request_builder
149 .send()
150 .await
151 .map_err(|e| AicoError::Provider(e.to_string()))?;
152
153 if !response.status().is_success() {
154 let status = response.status();
155 let text = response.text().await.unwrap_or_default();
156
157 let error_msg = if text.trim().is_empty() {
158 format!("API Error (Status: {}): [Empty Body]", status)
159 } else {
160 format!("API Error (Status: {}): {}", status, text)
161 };
162 return Err(AicoError::Provider(error_msg));
163 }
164
165 Ok(response)
166 }
167}
168
169pub fn parse_sse_line(line: &str) -> Option<ChatCompletionChunk> {
171 let trimmed = line.trim();
172 if !trimmed.starts_with("data: ") {
173 return None;
174 }
175 let content = &trimmed[6..];
176 if content == "[DONE]" {
177 return None;
178 }
179 serde_json::from_str(content).ok()
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Fake environment lookup that only knows the two provider API keys.
    fn mock_env(key: &str) -> Option<String> {
        let value = match key {
            "OPENAI_API_KEY" => "sk-test",
            "OPENROUTER_API_KEY" => "sk-or-test",
            _ => return None,
        };
        Some(value.to_string())
    }

    /// OpenRouter nests `reasoning_effort` under `reasoning.effort` and always
    /// requests usage information.
    #[test]
    fn test_get_extra_params_openrouter_nesting() {
        let model = "openrouter/openai/o1+reasoning_effort=medium";
        let extras = LlmClient::new_with_env(model, mock_env)
            .unwrap()
            .get_extra_params()
            .unwrap();

        assert_eq!(extras["usage"]["include"], true);
        assert_eq!(extras["reasoning"]["effort"], "medium");
        assert!(extras.get("reasoning_effort").is_none());
    }

    /// OpenAI keeps `reasoning_effort` as a top-level key and adds no `usage` entry.
    #[test]
    fn test_get_extra_params_openai_flattened() {
        let model = "openai/o1+reasoning_effort=medium";
        let extras = LlmClient::new_with_env(model, mock_env)
            .unwrap()
            .get_extra_params()
            .unwrap();

        assert_eq!(extras["reasoning_effort"], "medium");
        assert!(extras.get("usage").is_none());
    }
}