use serde_json::{json, Value as JsonValue};
use std::time::Duration;

const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";
const DEFAULT_LOCAL_MODEL: &str = "llama3";
const DEFAULT_OPENAI_BASE: &str = "https://api.openai.com/v1";
const DEFAULT_CLOUD_MODEL: &str = "gpt-4o-mini";
const HTTP_TIMEOUT_SECS: u64 = 120;

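/// Resolve the local backend's host and model from the environment,
/// honoring `OLLAMA_HOST` and `LEX_LLM_LOCAL_MODEL` and falling back to
/// the defaults above.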
fn local_config() -> (String, String) {
    let host = std::env::var("OLLAMA_HOST")
        .unwrap_or_else(|_| DEFAULT_OLLAMA_HOST.to_string());
    let model = std::env::var("LEX_LLM_LOCAL_MODEL")
        .unwrap_or_else(|_| DEFAULT_LOCAL_MODEL.to_string());
    (host, model)
}

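/// Resolve the cloud backend's base URL, model, and API key from the
/// environment. The `LEX_LLM_CLOUD_*` variables take precedence over
/// their `OPENAI_*` counterparts; errors if no API key is set.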
fn cloud_config() -> Result<(String, String, String), String> {
    let key = pick_env(&["LEX_LLM_CLOUD_API_KEY", "OPENAI_API_KEY"]).ok_or_else(|| {
        "agent.cloud_complete: neither LEX_LLM_CLOUD_API_KEY nor OPENAI_API_KEY env var set"
            .to_string()
    })?;
    let base = pick_env(&["LEX_LLM_CLOUD_BASE_URL", "OPENAI_BASE_URL"])
        .unwrap_or_else(|| DEFAULT_OPENAI_BASE.to_string());
    let model = std::env::var("LEX_LLM_CLOUD_MODEL")
        .unwrap_or_else(|_| DEFAULT_CLOUD_MODEL.to_string());
    Ok((base, model, key))
}

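/// Return the value of the first environment variable in `names` that is
/// set to a non-empty string.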
fn pick_env(names: &[&str]) -> Option<String> {
    for n in names {
        if let Ok(v) = std::env::var(n) {
            if !v.is_empty() {
                return Some(v);
            }
        }
    }
    None
}

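/// Build a non-streaming request body for Ollama's `/api/generate` endpoint.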
pub(crate) fn ollama_request_body(model: &str, prompt: &str) -> JsonValue {
    json!({
        "model": model,
        "prompt": prompt,
        "stream": false,
    })
}

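/// Build a single-turn chat request body for an OpenAI-compatible
/// `/chat/completions` endpoint.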
pub(crate) fn openai_request_body(model: &str, prompt: &str) -> JsonValue {
    json!({
        "model": model,
        "messages": [{ "role": "user", "content": prompt }],
    })
}

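/// Extract the generated text from an Ollama response, reporting a
/// truncated dump of the raw JSON when the `response` field is absent.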
fn ollama_extract(resp: &JsonValue) -> Result<String, String> {
    resp.get("response")
        .and_then(|v| v.as_str())
        .map(String::from)
        .ok_or_else(|| {
            format!(
                "ollama: response missing `response` field: {}",
                resp.to_string().chars().take(200).collect::<String>()
            )
        })
}

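/// Extract `choices[0].message.content` from an OpenAI-style response,
/// reporting a truncated dump of the raw JSON when the path is absent.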
fn openai_extract(resp: &JsonValue) -> Result<String, String> {
    resp.pointer("/choices/0/message/content")
        .and_then(|v| v.as_str())
        .map(String::from)
        .ok_or_else(|| {
            format!(
                "openai: response missing choices[0].message.content: {}",
                resp.to_string().chars().take(200).collect::<String>()
            )
        })
}

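/// Build a `ureq` agent with a global request timeout. HTTP error
/// statuses are not mapped to transport errors, so 4xx/5xx bodies can
/// still be read and surfaced through the extract functions above.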
fn http_agent() -> ureq::Agent {
    ureq::Agent::config_builder()
        .timeout_global(Some(Duration::from_secs(HTTP_TIMEOUT_SECS)))
        .http_status_as_error(false)
        .build()
        .new_agent()
}

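/// Read a response body to completion and parse it as JSON.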
fn read_body_json(mut resp: ureq::http::Response<ureq::Body>) -> Result<JsonValue, String> {
    let bytes = resp
        .body_mut()
        .read_to_vec()
        .map_err(|e| format!("read response body: {e}"))?;
    serde_json::from_slice(&bytes).map_err(|e| format!("parse response JSON: {e}"))
}

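/// Complete `prompt` against the configured local Ollama instance and
/// return the generated text.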
pub fn local_complete(prompt: &str) -> Result<String, String> {
    let (host, model) = local_config();
    let url = format!("{}/api/generate", host.trim_end_matches('/'));
    let body = serde_json::to_vec(&ollama_request_body(&model, prompt))
        .map_err(|e| format!("serialize ollama request: {e}"))?;
    let resp = http_agent()
        .post(&url)
        .header("content-type", "application/json")
        .send(&body[..])
        .map_err(|e| format!("ollama POST {url}: {e}"))?;
    let json = read_body_json(resp).map_err(|e| format!("ollama: {e}"))?;
    ollama_extract(&json)
}

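/// Complete `prompt` against the configured OpenAI-compatible cloud
/// endpoint and return the generated text.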
pub fn cloud_complete(prompt: &str) -> Result<String, String> {
    let (base, model, key) = cloud_config()?;
    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
    let body = serde_json::to_vec(&openai_request_body(&model, prompt))
        .map_err(|e| format!("serialize cloud request: {e}"))?;
    let resp = http_agent()
        .post(&url)
        .header("content-type", "application/json")
        .header("Authorization", &format!("Bearer {key}"))
        .send(&body[..])
        .map_err(|e| format!("cloud POST {url}: {e}"))?;
    let json = read_body_json(resp).map_err(|e| format!("cloud: {e}"))?;
    openai_extract(&json)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Several tests below mutate process-wide environment variables.
    // Serialize them so the default parallel test runner cannot interleave
    // the set/remove calls and observe each other's state.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn ollama_body_is_non_streaming() {
        let b = ollama_request_body("llama3", "hello");
        assert_eq!(b["model"], "llama3");
        assert_eq!(b["prompt"], "hello");
        assert_eq!(b["stream"], false);
    }

    #[test]
    fn openai_body_uses_user_role() {
        let b = openai_request_body("gpt-4o-mini", "hello");
        assert_eq!(b["model"], "gpt-4o-mini");
        assert_eq!(b["messages"][0]["role"], "user");
        assert_eq!(b["messages"][0]["content"], "hello");
    }

    #[test]
    fn ollama_extract_pulls_response_field() {
        let r = json!({"model": "llama3", "response": "hi back", "done": true});
        assert_eq!(ollama_extract(&r).unwrap(), "hi back");
    }

    #[test]
    fn ollama_extract_errors_on_missing_field() {
        let r = json!({"error": "model not found"});
        let e = ollama_extract(&r).unwrap_err();
        assert!(e.contains("missing `response`"));
    }

    #[test]
    fn openai_extract_pulls_choices_zero_message_content() {
        let r = json!({
            "id": "x",
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": "hi back" },
                "finish_reason": "stop"
            }]
        });
        assert_eq!(openai_extract(&r).unwrap(), "hi back");
    }

    #[test]
    fn openai_extract_errors_on_missing_path() {
        let r = json!({"error": {"message": "invalid api key"}});
        let e = openai_extract(&r).unwrap_err();
        assert!(e.contains("missing"));
    }

    #[test]
    fn cloud_config_fails_without_api_key() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prior_lex = std::env::var("LEX_LLM_CLOUD_API_KEY").ok();
        let prior_oai = std::env::var("OPENAI_API_KEY").ok();
        std::env::remove_var("LEX_LLM_CLOUD_API_KEY");
        std::env::remove_var("OPENAI_API_KEY");
        let r = cloud_config();
        if let Some(v) = prior_lex { std::env::set_var("LEX_LLM_CLOUD_API_KEY", v); }
        if let Some(v) = prior_oai { std::env::set_var("OPENAI_API_KEY", v); }
        let e = r.unwrap_err();
        assert!(e.contains("LEX_LLM_CLOUD_API_KEY"));
    }

    #[test]
    fn cloud_config_prefers_lex_prefix_then_falls_back_to_openai() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prior_lex_key = std::env::var("LEX_LLM_CLOUD_API_KEY").ok();
        let prior_lex_url = std::env::var("LEX_LLM_CLOUD_BASE_URL").ok();
        let prior_oai_key = std::env::var("OPENAI_API_KEY").ok();
        let prior_oai_url = std::env::var("OPENAI_BASE_URL").ok();
        std::env::set_var("LEX_LLM_CLOUD_API_KEY", "k-lex");
        std::env::set_var("OPENAI_API_KEY", "k-openai");
        std::env::set_var("LEX_LLM_CLOUD_BASE_URL", "https://api.mistral.ai/v1");
        std::env::remove_var("OPENAI_BASE_URL");
        let (base, _model, key) = cloud_config().unwrap();
        let restore = |name: &str, v: Option<String>| match v {
            Some(s) => std::env::set_var(name, s),
            None => std::env::remove_var(name),
        };
        restore("LEX_LLM_CLOUD_API_KEY", prior_lex_key);
        restore("LEX_LLM_CLOUD_BASE_URL", prior_lex_url);
        restore("OPENAI_API_KEY", prior_oai_key);
        restore("OPENAI_BASE_URL", prior_oai_url);
        assert_eq!(key, "k-lex");
        assert_eq!(base, "https://api.mistral.ai/v1");
    }

    #[test]
    fn local_config_uses_defaults_without_env() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prior_h = std::env::var("OLLAMA_HOST").ok();
        let prior_m = std::env::var("LEX_LLM_LOCAL_MODEL").ok();
        std::env::remove_var("OLLAMA_HOST");
        std::env::remove_var("LEX_LLM_LOCAL_MODEL");
        let (host, model) = local_config();
        if let Some(v) = prior_h { std::env::set_var("OLLAMA_HOST", v); }
        if let Some(v) = prior_m { std::env::set_var("LEX_LLM_LOCAL_MODEL", v); }
        assert_eq!(host, DEFAULT_OLLAMA_HOST);
        assert_eq!(model, DEFAULT_LOCAL_MODEL);
    }
}