1use std::time::Duration;
24
25use crate::runtime::WordSuggestions;
26use crate::secrets;
27
/// The wire protocol (and, for OpenAI-style servers, the endpoint) used to
/// talk to an LLM provider.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Backend {
    /// Anthropic's messages API (requests go to `ANTHROPIC_URL`).
    Anthropic,
    /// Google Gemini's `generateContent` API (requests go under `GEMINI_URL_PREFIX`).
    Gemini,
    /// Any server that speaks the OpenAI chat-completions shape.
    OpenAiCompatible {
        /// Base URL; `/chat/completions` and `/models` are appended to it.
        base_url: String,
        /// When `true` the output-token cap is sent as `max_completion_tokens`,
        /// otherwise as `max_tokens`.
        max_completion_tokens: bool,
        /// Whether a missing API key is an error (`true`) or the request may
        /// simply be sent without an `authorization` header (`false`).
        requires_key: bool,
    },
}
50
51fn openai_cloud(base: &str, completion_tokens: bool) -> Backend {
54 Backend::OpenAiCompatible {
55 base_url: base.to_string(),
56 max_completion_tokens: completion_tokens,
57 requires_key: true,
58 }
59}
60
61fn resolve_backend(backend: &str, base_url: Option<&str>) -> Option<Backend> {
66 match backend.trim().to_ascii_lowercase().as_str() {
67 "anthropic" => Some(Backend::Anthropic),
68 "gemini" => Some(Backend::Gemini),
69 "openai" => Some(openai_cloud("https://api.openai.com/v1", true)),
72 "openrouter" => Some(openai_cloud("https://openrouter.ai/api/v1", false)),
73 "mistral" => Some(openai_cloud("https://api.mistral.ai/v1", false)),
74 "groq" => Some(openai_cloud("https://api.groq.com/openai/v1", false)),
75 "deepseek" => Some(openai_cloud("https://api.deepseek.com/v1", false)),
76 "xai" => Some(openai_cloud("https://api.x.ai/v1", false)),
77 "openai-compatible" | "custom" => {
78 let base = base_url.map(str::trim).filter(|s| !s.is_empty())?;
79 Some(Backend::OpenAiCompatible {
80 base_url: base.trim_end_matches('/').to_string(),
81 max_completion_tokens: false,
82 requires_key: false,
83 })
84 }
85 _ => None,
86 }
87}
88
89pub fn is_backend_wired(backend: &str) -> bool {
95 resolve_backend(backend, Some("https://example.invalid")).is_some()
98}
99
/// Returns the secret-store entry name for `backend`'s API key: `llm.<backend>`.
///
/// The backend id is trimmed and lower-cased first — the same normalisation
/// `resolve_backend` applies when matching ids — so a configuration that
/// spells the backend `"OpenAI"` looks up the same entry as `"openai"`
/// instead of silently missing the stored key. Canonical lower-case ids map
/// to exactly the same names as before. Aliases are NOT merged:
/// `"custom"` and `"openai-compatible"` still produce distinct names.
pub fn key_name(backend: &str) -> String {
    format!("llm.{}", backend.trim().to_ascii_lowercase())
}
107
/// Anthropic endpoint that correction requests are POSTed to.
const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";
/// Anthropic endpoint fetched with a GET to check that an API key is accepted.
const ANTHROPIC_MODELS_URL: &str = "https://api.anthropic.com/v1/models";
/// Value sent in the `anthropic-version` header on every Anthropic call.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Gemini models collection URL; `/<model>:generateContent` is appended for
/// requests, and the bare URL is fetched to check an API key.
const GEMINI_URL_PREFIX: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Output-token cap sent with every request, whichever backend is used.
const DEFAULT_MAX_TOKENS: u32 = 1024;
113
/// System prompt for whole-text correction (`LlmProvider::rewrite`): the
/// model is asked to reply with the corrected text and nothing else.
const SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. Return ONLY the \
    corrected version of the user's text — no preamble, no commentary, no \
    quotation marks. Preserve the user's voice, register, and punctuation \
    style. If the text is already fine, return it unchanged.";

/// System prompt for single-word correction
/// (`LlmProvider::fix_word_in_context`): the user message carries a
/// `SENTENCE:` line and a `WORD:` line, and the reply should be just the word.
const WORD_SYSTEM_PROMPT: &str = "You correct ONE word at a time using sentence context. The \
    user gives you a SENTENCE and one WORD from it to correct. Return ONLY the corrected \
    version of that word — nothing else: no quotes, no punctuation, no commentary, no rest \
    of the sentence. Use the rest of the sentence to disambiguate homophones \
    (their/there/they're, its/it's, your/you're, etc.) and to pick the right fix for typos. \
    Preserve the original casing of the word's first letter. If the word is already correct \
    in context, return it unchanged.";

/// System prompt for correction with ranked per-word alternatives
/// (`LlmProvider::rewrite_with_alternatives`): the reply is expected to be a
/// JSON object with `corrected` and `alternatives`, which `parse_alternatives`
/// decodes.
const ALTERNATIVES_SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. \
    Correct the user's text and reply with ONLY a JSON object — no preamble, no commentary, no code \
    fences — shaped exactly like: {\"corrected\": \"<the corrected text>\", \"alternatives\": \
    [{\"word\": \"<a word you changed>\", \"options\": [\"best\", \"next\", \"...\"]}]}. Include an \
    `alternatives` entry only for words you changed; give 3 to 5 ranked options each, best first, with \
    the option you actually used in `corrected` listed first. Use sentence context for homophones \
    (their/there/they're, its/it's, your/you're). Preserve the user's voice, register, casing, and \
    punctuation. If the text is already correct, return it unchanged with an empty `alternatives` array.";
135
/// Everything that can go wrong while configuring or calling an LLM backend.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The backend needs an API key and none (or an empty one) was available.
    #[error("no API key for the LLM provider — set one in Preferences → Providers")]
    NoApiKey,
    /// Reading the key from the secret store failed; carries the store's error text.
    #[error("keychain: {0}")]
    Keychain(String),
    /// The configured backend id could not be resolved; carries that id.
    #[error("unsupported LLM backend: {0}")]
    UnsupportedBackend(String),
    /// The HTTP call failed (or, during key validation, was answered with an
    /// error status); carries a description of the failure.
    #[error("LLM request failed: {0}")]
    Request(String),
    /// The reply body could not be decoded or lacked the expected field;
    /// carries a description of what was missing.
    #[error("LLM response was unparseable: {0}")]
    Response(String),
}
155
/// A configured LLM client: which backend to call, with which key and model.
///
/// `Debug` is implemented by hand so the API key is never printed.
pub struct LlmProvider {
    /// Protocol/endpoint the requests are sent to.
    backend: Backend,
    /// API key read from the secret store; may be empty for backends that do
    /// not require one.
    api_key: String,
    /// Model identifier copied from the configuration and sent with each request.
    model: String,
}
162
163impl std::fmt::Debug for LlmProvider {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("LlmProvider")
167 .field("backend", &self.backend)
168 .field("model", &self.model)
169 .field("api_key", &"[redacted]")
170 .finish()
171 }
172}
173
174impl LlmProvider {
175 pub fn from_config(llm: &crate::LlmConfig) -> Result<Self, LlmError> {
186 let backend = resolve_backend(&llm.backend, llm.base_url.as_deref())
187 .ok_or_else(|| LlmError::UnsupportedBackend(llm.backend.clone()))?;
188 let requires_key = match &backend {
189 Backend::OpenAiCompatible { requires_key, .. } => *requires_key,
190 Backend::Anthropic | Backend::Gemini => true,
192 };
193 let api_key = secrets::get(&key_name(&llm.backend))
194 .map_err(|e| LlmError::Keychain(e.to_string()))?
195 .unwrap_or_default();
196 if requires_key && api_key.is_empty() {
197 return Err(LlmError::NoApiKey);
198 }
199 Ok(Self {
200 backend,
201 api_key,
202 model: llm.model.clone(),
203 })
204 }
205
206 pub fn rewrite(&self, text: &str) -> Result<String, LlmError> {
214 if text.trim().is_empty() {
215 return Ok(text.to_string());
216 }
217 self.request(SYSTEM_PROMPT, text.to_string())
218 }
219
220 pub fn fix_word_in_context(&self, sentence: &str, word: &str) -> Result<String, LlmError> {
231 if word.trim().is_empty() {
232 return Ok(word.to_string());
233 }
234 let content = format!("SENTENCE: {sentence}\nWORD: {word}");
235 let corrected = self.request(WORD_SYSTEM_PROMPT, content)?;
236 Ok(corrected
240 .trim()
241 .trim_matches(|c: char| c == '"' || c == '\'')
242 .to_string())
243 }
244
245 pub fn rewrite_with_alternatives(
257 &self,
258 text: &str,
259 ) -> Result<(String, Vec<WordSuggestions>), LlmError> {
260 if text.trim().is_empty() {
261 return Ok((text.to_string(), Vec::new()));
262 }
263 let reply = self.request(ALTERNATIVES_SYSTEM_PROMPT, text.to_string())?;
264 parse_alternatives(&reply)
265 }
266
267 fn request(&self, system: &str, content: String) -> Result<String, LlmError> {
272 match &self.backend {
273 Backend::Anthropic => self.request_anthropic(system, content),
274 Backend::Gemini => self.request_gemini(system, content),
275 Backend::OpenAiCompatible {
276 base_url,
277 max_completion_tokens,
278 ..
279 } => self.request_openai(base_url, *max_completion_tokens, system, content),
280 }
281 }
282
283 fn request_anthropic(&self, system: &str, content: String) -> Result<String, LlmError> {
284 let body = serde_json::json!({
285 "model": self.model,
286 "max_tokens": DEFAULT_MAX_TOKENS,
287 "system": system,
288 "messages": [{ "role": "user", "content": content }],
289 });
290 let json = agent()
291 .post(ANTHROPIC_URL)
292 .set("x-api-key", &self.api_key)
293 .set("anthropic-version", ANTHROPIC_VERSION)
294 .set("content-type", "application/json")
295 .send_json(body)
296 .map_err(|e| LlmError::Request(e.to_string()))?
297 .into_json::<serde_json::Value>()
298 .map_err(|e| LlmError::Response(e.to_string()))?;
299 parse_anthropic_reply(&json)
300 }
301
302 fn request_openai(
306 &self,
307 base: &str,
308 max_completion_tokens: bool,
309 system: &str,
310 content: String,
311 ) -> Result<String, LlmError> {
312 let token_field = if max_completion_tokens {
313 "max_completion_tokens"
314 } else {
315 "max_tokens"
316 };
317 let mut body = serde_json::json!({
318 "model": self.model,
319 "messages": [
320 { "role": "system", "content": system },
321 { "role": "user", "content": content },
322 ],
323 });
324 body[token_field] = DEFAULT_MAX_TOKENS.into();
325
326 let url = format!("{base}/chat/completions");
327 let mut req = agent().post(&url).set("content-type", "application/json");
328 if !self.api_key.is_empty() {
331 req = req.set("authorization", &format!("Bearer {}", self.api_key));
332 }
333 let json = req
334 .send_json(body)
335 .map_err(|e| LlmError::Request(e.to_string()))?
336 .into_json::<serde_json::Value>()
337 .map_err(|e| LlmError::Response(e.to_string()))?;
338 parse_openai_reply(&json)
339 }
340
341 fn request_gemini(&self, system: &str, content: String) -> Result<String, LlmError> {
342 let url = format!("{GEMINI_URL_PREFIX}/{}:generateContent", self.model);
344 let body = serde_json::json!({
345 "system_instruction": { "parts": [{ "text": system }] },
346 "contents": [{ "parts": [{ "text": content }] }],
347 "generationConfig": { "maxOutputTokens": DEFAULT_MAX_TOKENS },
348 });
349 let json = agent()
350 .post(&url)
351 .set("x-goog-api-key", &self.api_key)
352 .set("content-type", "application/json")
353 .send_json(body)
354 .map_err(|e| LlmError::Request(e.to_string()))?
355 .into_json::<serde_json::Value>()
356 .map_err(|e| LlmError::Response(e.to_string()))?;
357 parse_gemini_reply(&json)
358 }
359}
360
361pub fn validate_key(backend: &str, base_url: Option<&str>, key: &str) -> Result<(), LlmError> {
369 let resolved = resolve_backend(backend, base_url)
370 .ok_or_else(|| LlmError::UnsupportedBackend(backend.to_string()))?;
371 let key = key.trim();
372 match resolved {
373 Backend::Anthropic => validate_get(
374 ANTHROPIC_MODELS_URL,
375 &[("x-api-key", key), ("anthropic-version", ANTHROPIC_VERSION)],
376 ),
377 Backend::Gemini => validate_get(GEMINI_URL_PREFIX, &[("x-goog-api-key", key)]),
378 Backend::OpenAiCompatible {
379 base_url,
380 requires_key,
381 ..
382 } => {
383 let url = format!("{base_url}/models");
384 if key.is_empty() {
385 if requires_key {
386 return Err(LlmError::NoApiKey);
387 }
388 validate_get(&url, &[])
390 } else {
391 validate_get(&url, &[("authorization", &format!("Bearer {key}"))])
392 }
393 }
394 }
395}
396
397fn validate_get(url: &str, headers: &[(&str, &str)]) -> Result<(), LlmError> {
400 let mut req = agent().get(url);
401 for (k, v) in headers {
402 req = req.set(k, v);
403 }
404 match req.call() {
405 Ok(_) => Ok(()),
406 Err(ureq::Error::Status(code, _)) if (400..500).contains(&code) => {
407 Err(LlmError::Request(format!("key rejected (HTTP {code})")))
408 }
409 Err(e) => Err(LlmError::Request(e.to_string())),
410 }
411}
412
413fn agent() -> ureq::Agent {
414 ureq::AgentBuilder::new()
415 .timeout(Duration::from_secs(20))
416 .build()
417}
418
419fn parse_anthropic_reply(json: &serde_json::Value) -> Result<String, LlmError> {
422 let text = json["content"]
423 .as_array()
424 .and_then(|parts| {
425 parts
426 .iter()
427 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
428 .next()
429 })
430 .ok_or_else(|| LlmError::Response("no `content[*].text` in response".into()))?;
431 Ok(text.trim_end_matches('\n').to_string())
432}
433
434fn parse_openai_reply(json: &serde_json::Value) -> Result<String, LlmError> {
437 let text = json["choices"][0]["message"]["content"]
438 .as_str()
439 .ok_or_else(|| LlmError::Response("no `choices[0].message.content` in response".into()))?;
440 Ok(text.trim_end_matches('\n').to_string())
441}
442
443fn parse_gemini_reply(json: &serde_json::Value) -> Result<String, LlmError> {
446 let text = json["candidates"][0]["content"]["parts"]
447 .as_array()
448 .and_then(|parts| {
449 parts
450 .iter()
451 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
452 .next()
453 })
454 .ok_or_else(|| {
455 LlmError::Response("no `candidates[0].content.parts[*].text` in response".into())
456 })?;
457 Ok(text.trim_end_matches('\n').to_string())
458}
459
460fn parse_alternatives(reply: &str) -> Result<(String, Vec<WordSuggestions>), LlmError> {
465 let json = json_object_slice(reply);
466 let v: serde_json::Value = serde_json::from_str(json)
467 .map_err(|e| LlmError::Response(format!("alternatives JSON: {e}")))?;
468 let corrected = v["corrected"]
469 .as_str()
470 .ok_or_else(|| LlmError::Response("no `corrected` string in response".into()))?
471 .to_string();
472 let mut alternatives = Vec::new();
473 if let Some(arr) = v["alternatives"].as_array() {
474 for item in arr {
475 let Some(word) = item["word"].as_str() else {
476 continue;
477 };
478 let options: Vec<String> = item["options"]
479 .as_array()
480 .into_iter()
481 .flatten()
482 .filter_map(|o| o.as_str().map(str::to_string))
483 .collect();
484 if !options.is_empty() {
485 alternatives.push(WordSuggestions {
486 word: word.to_string(),
487 options,
488 });
489 }
490 }
491 }
492 Ok((corrected, alternatives))
493}
494
/// Narrows `s` to the span from its first `{` to its last `}` (inclusive).
///
/// Returns `s` unchanged when either brace is missing or the last `}` comes
/// before the first `{`.
fn json_object_slice(s: &str) -> &str {
    let Some(open) = s.find('{') else {
        return s;
    };
    let Some(close) = s.rfind('}') else {
        return s;
    };
    if close < open {
        return s;
    }
    &s[open..=close]
}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmConfig;

    /// A well-formed reply yields the corrected text and one entry per word.
    #[test]
    fn parses_alternatives_reply() {
        let reply = r#"{"corrected":"the quick brown fox",
            "alternatives":[
                {"word":"the","options":["the","then","they"]},
                {"word":"brown","options":["brown","browne","crown"]}
            ]}"#;
        let (corrected, alts) = parse_alternatives(reply).unwrap();
        let words: Vec<&str> = alts.iter().map(|entry| entry.word.as_str()).collect();
        assert_eq!(corrected, "the quick brown fox");
        assert_eq!(words, ["the", "brown"]);
        assert_eq!(alts[0].options, vec!["the", "then", "they"]);
    }

    /// Text around the JSON object (preamble, code fences) is ignored.
    #[test]
    fn tolerates_code_fences_and_preamble() {
        let reply = "Here you go:\n```json\n{\"corrected\":\"hi there\",\"alternatives\":[]}\n```";
        let parsed = parse_alternatives(reply).unwrap();
        assert_eq!(parsed.0, "hi there");
        assert!(parsed.1.is_empty());
    }

    /// A reply with no JSON object at all is reported as an error.
    #[test]
    fn non_json_reply_is_an_error() {
        let outcome = parse_alternatives("sorry, I cannot do that");
        assert!(outcome.is_err());
    }

    /// An unknown backend id is rejected with the id echoed back.
    #[test]
    fn unsupported_backend_is_rejected_cleanly() {
        let cfg = LlmConfig {
            backend: "made-up-vendor".into(),
            model: "whatever".into(),
            base_url: None,
        };
        let outcome = LlmProvider::from_config(&cfg);
        match outcome {
            Err(LlmError::UnsupportedBackend(name)) => assert_eq!(name, "made-up-vendor"),
            other => panic!("expected UnsupportedBackend, got {other:?}"),
        }
    }

    /// A custom endpoint cannot be built without a base URL.
    #[test]
    fn custom_endpoint_without_base_url_is_unsupported() {
        let cfg = LlmConfig {
            backend: "openai-compatible".into(),
            model: "llama3.1".into(),
            base_url: None,
        };
        let outcome = LlmProvider::from_config(&cfg);
        assert!(matches!(outcome, Err(LlmError::UnsupportedBackend(_))));
    }

    /// Key names keep their `llm.<backend>` shape and every known id is wired.
    #[test]
    fn key_name_and_wiring_are_stable() {
        const WIRED: [&str; 9] = [
            "anthropic",
            "openai",
            "gemini",
            "openrouter",
            "mistral",
            "groq",
            "deepseek",
            "xai",
            "openai-compatible",
        ];

        assert_eq!(key_name("anthropic"), "llm.anthropic");
        assert_eq!(key_name("openai"), "llm.openai");
        for b in WIRED {
            assert!(is_backend_wired(b), "{b} should be wired");
        }
        // Matching is case-insensitive; unknown ids stay unwired.
        assert!(is_backend_wired("OpenAI"));
        assert!(!is_backend_wired("made-up-vendor"));
    }

    /// Each id resolves to the expected protocol, base URL and flags.
    #[test]
    fn resolve_backend_picks_the_right_shape_and_url() {
        let compat = |url: &str, completion_tokens: bool, needs_key: bool| {
            Some(Backend::OpenAiCompatible {
                base_url: url.into(),
                max_completion_tokens: completion_tokens,
                requires_key: needs_key,
            })
        };

        assert_eq!(
            resolve_backend("openai", None),
            compat("https://api.openai.com/v1", true, true)
        );
        assert_eq!(
            resolve_backend("groq", None),
            compat("https://api.groq.com/openai/v1", false, true)
        );
        assert_eq!(resolve_backend("anthropic", None), Some(Backend::Anthropic));
        assert_eq!(resolve_backend("gemini", None), Some(Backend::Gemini));
        // The trailing slash on a custom base URL is dropped.
        assert_eq!(
            resolve_backend("openai-compatible", Some("http://localhost:11434/v1/")),
            compat("http://localhost:11434/v1", false, false)
        );
        assert_eq!(resolve_backend("openai-compatible", Some(" ")), None);
        assert_eq!(resolve_backend("nope", Some("http://x")), None);
    }

    /// Each parser reads its own provider's shape and rejects the others'.
    #[test]
    fn parses_each_provider_reply_shape() {
        let anthropic = serde_json::json!({
            "content": [{ "type": "text", "text": "fixed\n" }]
        });
        let openai = serde_json::json!({
            "choices": [{ "message": { "role": "assistant", "content": "fixed\n" } }]
        });
        let gemini = serde_json::json!({
            "candidates": [{ "content": { "parts": [{ "text": "fixed\n" }] } }]
        });

        assert_eq!(parse_anthropic_reply(&anthropic).unwrap(), "fixed");
        assert_eq!(parse_openai_reply(&openai).unwrap(), "fixed");
        assert_eq!(parse_gemini_reply(&gemini).unwrap(), "fixed");

        // Feeding a parser another provider's shape must fail, not mis-read.
        assert!(parse_openai_reply(&anthropic).is_err());
        assert!(parse_gemini_reply(&openai).is_err());
    }
}