1use crate::exceptions::AicoError;
2use crate::llm::api_models::{ChatCompletionChunk, ChatCompletionRequest};
3use reqwest::Client as HttpClient;
4use std::env;
5
6#[derive(Debug)]
7struct ModelSpec {
8 api_key_env: &'static str,
9 default_base_url: &'static str,
10 model_id_short: String,
11 extra_params: Option<serde_json::Value>,
12}
13
14impl ModelSpec {
15 fn parse(full_str: &str) -> Result<Self, AicoError> {
16 let (base_model, params_part) = full_str.split_once('+').unwrap_or((full_str, ""));
17 let (provider, model_name) = base_model.split_once('/').ok_or_else(|| {
18 AicoError::Configuration(format!(
19 "Invalid model format '{}'. Expected 'provider/model'.",
20 base_model
21 ))
22 })?;
23
24 let (api_key_env, default_base_url) = match provider {
25 "openrouter" => ("OPENROUTER_API_KEY", "https://openrouter.ai/api/v1"),
26 "openai" => ("OPENAI_API_KEY", "https://api.openai.com/v1"),
27 _ => {
28 return Err(AicoError::Configuration(format!(
29 "Unrecognized provider prefix in '{}'. Use 'openai/' or 'openrouter/'.",
30 full_str
31 )));
32 }
33 };
34
35 let mut extra_map: Option<serde_json::Map<String, serde_json::Value>> = None;
36
37 if provider == "openrouter" {
38 extra_map
39 .get_or_insert_default()
40 .insert("usage".to_string(), serde_json::json!({ "include": true }));
41 }
42
43 if !params_part.is_empty() {
44 for param in params_part.split('+') {
45 let m = extra_map.get_or_insert_default();
46 if let Some((k, v)) = param.split_once('=') {
47 let val = serde_json::from_str::<serde_json::Value>(v)
48 .unwrap_or_else(|_| serde_json::Value::String(v.to_string()));
49
50 if provider == "openrouter" && k == "reasoning_effort" {
51 m.insert(
52 "reasoning".to_string(),
53 serde_json::json!({ "effort": val }),
54 );
55 } else {
56 m.insert(k.to_string(), val);
57 }
58 } else {
59 m.insert(param.to_string(), serde_json::Value::Bool(true));
60 }
61 }
62 }
63
64 let extra_params = extra_map.map(serde_json::Value::Object);
65
66 Ok(Self {
67 api_key_env,
68 default_base_url,
69 model_id_short: model_name.to_string(),
70 extra_params,
71 })
72 }
73}
74
75#[derive(Debug)]
76pub struct LlmClient {
77 http: HttpClient,
78 api_key: String,
79 base_url: String,
80 pub model_id: String,
81 extra_params: Option<serde_json::Value>,
82}
83
84impl LlmClient {
85 pub fn new(full_model_string: &str) -> Result<Self, AicoError> {
86 let spec = ModelSpec::parse(full_model_string)?;
87
88 let api_key = env::var(spec.api_key_env)
89 .map_err(|_| AicoError::Configuration(format!("{} is required.", spec.api_key_env)))?;
90
91 let base_url =
92 env::var("OPENAI_BASE_URL").unwrap_or_else(|_| spec.default_base_url.to_string());
93
94 Ok(Self {
95 http: crate::utils::setup_http_client(),
96 api_key,
97 base_url,
98 model_id: spec.model_id_short,
99 extra_params: spec.extra_params,
100 })
101 }
102
103 pub fn get_extra_params(&self) -> Option<serde_json::Value> {
104 self.extra_params.clone()
105 }
106
107 pub async fn stream_chat(
110 &self,
111 req: ChatCompletionRequest,
112 ) -> Result<reqwest::Response, AicoError> {
113 let url = format!("{}/chat/completions", self.base_url);
114
115 let request_builder = self
116 .http
117 .post(&url)
118 .header("Authorization", format!("Bearer {}", self.api_key))
119 .header("Content-Type", "application/json")
120 .json(&req);
121
122 let response = request_builder
123 .send()
124 .await
125 .map_err(|e| AicoError::Provider(e.to_string()))?;
126
127 if !response.status().is_success() {
128 let status = response.status();
129 let text = response.text().await.unwrap_or_default();
130
131 let error_msg = if text.trim().is_empty() {
132 format!("API Error (Status: {}): [Empty Body]", status)
133 } else {
134 format!("API Error (Status: {}): {}", status, text)
135 };
136 return Err(AicoError::Provider(error_msg));
137 }
138
139 Ok(response)
140 }
141}
142
143pub fn parse_sse_line(line: &str) -> Option<ChatCompletionChunk> {
145 let trimmed = line.trim();
146 if !trimmed.starts_with("data: ") {
147 return None;
148 }
149 let content = &trimmed[6..];
150 if content == "[DONE]" {
151 return None;
152 }
153 serde_json::from_str(content).ok()
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    // These tests call `ModelSpec::parse` directly instead of `LlmClient::new`.
    // `new` reads API keys from the process environment, and mutating the environment with
    // `std::env::set_var` is unsound while the test harness runs other tests on parallel
    // threads. `LlmClient::new` copies `spec.extra_params` into the client unchanged.

    #[test]
    fn test_get_extra_params_openrouter_nesting() {
        let spec = ModelSpec::parse("openrouter/openai/o1+reasoning_effort=medium").unwrap();
        assert_eq!(spec.api_key_env, "OPENROUTER_API_KEY");
        assert_eq!(spec.model_id_short, "openai/o1");

        let params = spec.extra_params.unwrap();
        assert_eq!(params["usage"]["include"], true);
        assert_eq!(params["reasoning"]["effort"], "medium");
        assert!(params.get("reasoning_effort").is_none());
    }

    #[test]
    fn test_get_extra_params_openai_flattened() {
        let spec = ModelSpec::parse("openai/o1+reasoning_effort=medium").unwrap();
        assert_eq!(spec.api_key_env, "OPENAI_API_KEY");
        assert_eq!(spec.model_id_short, "o1");

        let params = spec.extra_params.unwrap();
        assert_eq!(params["reasoning_effort"], "medium");
        assert!(params.get("usage").is_none());
    }

    #[test]
    fn test_openai_without_params_has_no_extras() {
        let spec = ModelSpec::parse("openai/gpt-4o").unwrap();
        assert!(spec.extra_params.is_none());
    }

    #[test]
    fn test_param_values_keep_json_types_and_flags_default_true() {
        let spec = ModelSpec::parse("openai/o1+max_tokens=100+mode=fast+verbose").unwrap();
        let params = spec.extra_params.unwrap();

        assert_eq!(params["max_tokens"], serde_json::json!(100));
        assert_eq!(params["mode"], "fast");
        assert_eq!(params["verbose"], true);
    }

    #[test]
    fn test_parse_rejects_malformed_model_strings() {
        assert!(matches!(
            ModelSpec::parse("gpt-4o"),
            Err(AicoError::Configuration(_))
        ));
        assert!(matches!(
            ModelSpec::parse("anthropic/claude-3"),
            Err(AicoError::Configuration(_))
        ));
    }

    #[test]
    fn test_parse_sse_line_ignores_non_payload_lines() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line("").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("event: message").is_none());
        assert!(parse_sse_line("data: not-json").is_none());
    }
}