1use serde_json::{json, Value as JsonValue};
66use std::time::Duration;
67
/// Base URL of the local Ollama server, used when `OLLAMA_HOST` is not provided.
const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";
/// Model name sent to Ollama when `LEX_LLM_LOCAL_MODEL` is not provided.
const DEFAULT_LOCAL_MODEL: &str = "llama3";
/// OpenAI-compatible API base URL, used when neither `LEX_LLM_CLOUD_BASE_URL`
/// nor `OPENAI_BASE_URL` is provided.
const DEFAULT_OPENAI_BASE: &str = "https://api.openai.com/v1";
/// Model name sent to the cloud backend when `LEX_LLM_CLOUD_MODEL` is not provided.
const DEFAULT_CLOUD_MODEL: &str = "gpt-4o-mini";
/// Overall per-request timeout in seconds, applied as the HTTP agent's global timeout.
const HTTP_TIMEOUT_SECS: u64 = 120;
73
74fn local_config() -> (String, String) {
77 let host = std::env::var("OLLAMA_HOST")
78 .unwrap_or_else(|_| DEFAULT_OLLAMA_HOST.to_string());
79 let model = std::env::var("LEX_LLM_LOCAL_MODEL")
80 .unwrap_or_else(|_| DEFAULT_LOCAL_MODEL.to_string());
81 (host, model)
82}
83
84fn cloud_config() -> Result<(String, String, String), String> {
102 let key = pick_env(&["LEX_LLM_CLOUD_API_KEY", "OPENAI_API_KEY"])
103 .ok_or_else(||
104 "agent.cloud_complete: neither LEX_LLM_CLOUD_API_KEY nor OPENAI_API_KEY env var set"
105 .to_string())?;
106 let base = pick_env(&["LEX_LLM_CLOUD_BASE_URL", "OPENAI_BASE_URL"])
107 .unwrap_or_else(|| DEFAULT_OPENAI_BASE.to_string());
108 let model = std::env::var("LEX_LLM_CLOUD_MODEL")
109 .unwrap_or_else(|_| DEFAULT_CLOUD_MODEL.to_string());
110 Ok((base, model, key))
111}
112
/// Looks up environment variables in priority order.
///
/// Returns the value of the first name in `names` whose variable is set,
/// valid Unicode, and non-empty. Returns `None` when no name qualifies,
/// including when `names` is empty.
fn pick_env(names: &[&str]) -> Option<String> {
    names
        .iter()
        .find_map(|name| std::env::var(name).ok().filter(|value| !value.is_empty()))
}
122
123pub(crate) fn ollama_request_body(model: &str, prompt: &str) -> JsonValue {
127 json!({
128 "model": model,
129 "prompt": prompt,
130 "stream": false,
134 })
135}
136
137pub(crate) fn openai_request_body(model: &str, prompt: &str) -> JsonValue {
142 json!({
143 "model": model,
144 "messages": [{ "role": "user", "content": prompt }],
145 })
146}
147
148fn ollama_extract(resp: &JsonValue) -> Result<String, String> {
151 resp.get("response")
152 .and_then(|v| v.as_str())
153 .map(String::from)
154 .ok_or_else(|| format!(
155 "ollama: response missing `response` field: {}",
156 resp.to_string().chars().take(200).collect::<String>()
157 ))
158}
159
160fn openai_extract(resp: &JsonValue) -> Result<String, String> {
163 resp.pointer("/choices/0/message/content")
164 .and_then(|v| v.as_str())
165 .map(String::from)
166 .ok_or_else(|| format!(
167 "openai: response missing choices[0].message.content: {}",
168 resp.to_string().chars().take(200).collect::<String>()
169 ))
170}
171
172fn http_agent() -> ureq::Agent {
176 ureq::Agent::config_builder()
177 .timeout_global(Some(Duration::from_secs(HTTP_TIMEOUT_SECS)))
178 .http_status_as_error(false)
179 .build()
180 .new_agent()
181}
182
183fn read_body_json(mut resp: ureq::http::Response<ureq::Body>) -> Result<JsonValue, String> {
184 let bytes = resp.body_mut().read_to_vec()
185 .map_err(|e| format!("read response body: {e}"))?;
186 serde_json::from_slice(&bytes)
187 .map_err(|e| format!("parse response JSON: {e}"))
188}
189
190pub fn local_complete(prompt: &str) -> Result<String, String> {
193 let (host, model) = local_config();
194 let url = format!("{}/api/generate", host.trim_end_matches('/'));
195 let body = serde_json::to_vec(&ollama_request_body(&model, prompt))
196 .map_err(|e| format!("serialize ollama request: {e}"))?;
197 let resp = http_agent().post(&url)
198 .header("content-type", "application/json")
199 .send(&body[..])
200 .map_err(|e| format!("ollama POST {url}: {e}"))?;
201 let json = read_body_json(resp).map_err(|e| format!("ollama: {e}"))?;
202 ollama_extract(&json)
203}
204
205pub fn cloud_complete(prompt: &str) -> Result<String, String> {
208 let (base, model, key) = cloud_config()?;
209 let url = format!("{}/chat/completions", base.trim_end_matches('/'));
210 let body = serde_json::to_vec(&openai_request_body(&model, prompt))
211 .map_err(|e| format!("serialize cloud request: {e}"))?;
212 let resp = http_agent().post(&url)
213 .header("content-type", "application/json")
214 .header("Authorization", &format!("Bearer {key}"))
215 .send(&body[..])
216 .map_err(|e| format!("cloud POST {url}: {e}"))?;
217 let json = read_body_json(resp).map_err(|e| format!("cloud: {e}"))?;
218 openai_extract(&json)
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that read or mutate process environment variables.
    ///
    /// The test harness runs tests on several threads, and the environment is
    /// process-global, so every env-touching test holds this lock. Callers
    /// recover a poisoned lock (`into_inner`) so one failed test does not fail
    /// the rest.
    fn env_lock() -> &'static std::sync::Mutex<()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| std::sync::Mutex::new(()))
    }

    /// Snapshot of selected environment variables, restored when dropped.
    ///
    /// Restoring in `Drop` puts the environment back even when an assertion or
    /// `unwrap` panics. Without it, a failing test could leak modified
    /// variables into the next test that takes `env_lock`. Declare it after the
    /// lock guard so it drops (restores) while the lock is still held.
    struct EnvSnapshot(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl EnvSnapshot {
        fn capture(names: &[&'static str]) -> Self {
            EnvSnapshot(names.iter().map(|&n| (n, std::env::var_os(n))).collect())
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (name, prior) in &self.0 {
                match prior {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn ollama_body_is_non_streaming() {
        let b = ollama_request_body("llama3", "hello");
        assert_eq!(b["model"], "llama3");
        assert_eq!(b["prompt"], "hello");
        assert_eq!(b["stream"], false);
    }

    #[test]
    fn openai_body_uses_user_role() {
        let b = openai_request_body("gpt-4o-mini", "hello");
        assert_eq!(b["model"], "gpt-4o-mini");
        assert_eq!(b["messages"][0]["role"], "user");
        assert_eq!(b["messages"][0]["content"], "hello");
    }

    #[test]
    fn ollama_extract_pulls_response_field() {
        let r = json!({"model": "llama3", "response": "hi back", "done": true});
        assert_eq!(ollama_extract(&r).unwrap(), "hi back");
    }

    #[test]
    fn ollama_extract_errors_on_missing_field() {
        let r = json!({"error": "model not found"});
        let e = ollama_extract(&r).unwrap_err();
        assert!(e.contains("missing `response`"));
    }

    #[test]
    fn openai_extract_pulls_choices_zero_message_content() {
        let r = json!({
            "id": "x",
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": "hi back" },
                "finish_reason": "stop"
            }]
        });
        assert_eq!(openai_extract(&r).unwrap(), "hi back");
    }

    #[test]
    fn openai_extract_errors_on_missing_path() {
        let r = json!({"error": {"message": "invalid api key"}});
        let e = openai_extract(&r).unwrap_err();
        assert!(e.contains("missing"));
    }

    #[test]
    fn pick_env_skips_unset_and_empty_values() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvSnapshot::capture(&[
            "LEX_TEST_PICK_UNSET",
            "LEX_TEST_PICK_EMPTY",
            "LEX_TEST_PICK_SET",
        ]);
        std::env::remove_var("LEX_TEST_PICK_UNSET");
        std::env::set_var("LEX_TEST_PICK_EMPTY", "");
        std::env::set_var("LEX_TEST_PICK_SET", "value");
        assert_eq!(
            pick_env(&["LEX_TEST_PICK_UNSET", "LEX_TEST_PICK_EMPTY", "LEX_TEST_PICK_SET"]),
            Some("value".to_string())
        );
        assert_eq!(pick_env(&["LEX_TEST_PICK_UNSET", "LEX_TEST_PICK_EMPTY"]), None);
    }

    #[test]
    fn cloud_config_fails_without_api_key() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvSnapshot::capture(&["LEX_LLM_CLOUD_API_KEY", "OPENAI_API_KEY"]);
        std::env::remove_var("LEX_LLM_CLOUD_API_KEY");
        std::env::remove_var("OPENAI_API_KEY");
        let e = cloud_config().unwrap_err();
        assert!(e.contains("LEX_LLM_CLOUD_API_KEY"));
    }

    #[test]
    fn cloud_config_prefers_lex_prefix_then_falls_back_to_openai() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvSnapshot::capture(&[
            "LEX_LLM_CLOUD_API_KEY",
            "LEX_LLM_CLOUD_BASE_URL",
            "OPENAI_API_KEY",
            "OPENAI_BASE_URL",
        ]);

        // Both prefixes present: the LEX_* values win.
        std::env::set_var("LEX_LLM_CLOUD_API_KEY", "k-lex");
        std::env::set_var("OPENAI_API_KEY", "k-openai");
        std::env::set_var("LEX_LLM_CLOUD_BASE_URL", "https://api.mistral.ai/v1");
        std::env::remove_var("OPENAI_BASE_URL");
        let (base, _model, key) = cloud_config().unwrap();
        assert_eq!(key, "k-lex");
        assert_eq!(base, "https://api.mistral.ai/v1");

        // LEX_* absent: fall back to the OPENAI_* names.
        std::env::remove_var("LEX_LLM_CLOUD_API_KEY");
        std::env::remove_var("LEX_LLM_CLOUD_BASE_URL");
        std::env::set_var("OPENAI_BASE_URL", "https://example.invalid/v1");
        let (base, _model, key) = cloud_config().unwrap();
        assert_eq!(key, "k-openai");
        assert_eq!(base, "https://example.invalid/v1");

        // No base URL variable at all: built-in default.
        std::env::remove_var("OPENAI_BASE_URL");
        let (base, _model, _key) = cloud_config().unwrap();
        assert_eq!(base, DEFAULT_OPENAI_BASE);
    }

    #[test]
    fn local_config_uses_defaults_without_env() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvSnapshot::capture(&["OLLAMA_HOST", "LEX_LLM_LOCAL_MODEL"]);
        std::env::remove_var("OLLAMA_HOST");
        std::env::remove_var("LEX_LLM_LOCAL_MODEL");
        let (host, model) = local_config();
        assert_eq!(host, DEFAULT_OLLAMA_HOST);
        assert_eq!(model, DEFAULT_LOCAL_MODEL);
    }
}
345}