1use std::time::Duration;
24
25use crate::runtime::WordSuggestions;
26use crate::secrets;
27
/// Wire protocol family used to talk to the configured LLM provider.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Backend {
    /// Anthropic Messages API (fixed endpoint, authenticated via `x-api-key`).
    Anthropic,
    /// Google Gemini `generateContent` API (authenticated via `x-goog-api-key`).
    Gemini,
    /// Any endpoint that speaks the OpenAI `/chat/completions` dialect.
    OpenAiCompatible {
        /// Root URL that `/chat/completions` is appended to.
        base_url: String,
        /// `true` sends the token cap as `max_completion_tokens`; `false` sends it
        /// as `max_tokens`.
        max_completion_tokens: bool,
        /// Whether a missing API key is treated as an error when the provider is built.
        requires_key: bool,
    },
}
50
51fn openai_cloud(base: &str, completion_tokens: bool) -> Backend {
54 Backend::OpenAiCompatible {
55 base_url: base.to_string(),
56 max_completion_tokens: completion_tokens,
57 requires_key: true,
58 }
59}
60
61fn resolve_backend(backend: &str, base_url: Option<&str>) -> Option<Backend> {
66 match backend.trim().to_ascii_lowercase().as_str() {
67 "anthropic" => Some(Backend::Anthropic),
68 "gemini" => Some(Backend::Gemini),
69 "openai" => Some(openai_cloud("https://api.openai.com/v1", true)),
72 "openrouter" => Some(openai_cloud("https://openrouter.ai/api/v1", false)),
73 "mistral" => Some(openai_cloud("https://api.mistral.ai/v1", false)),
74 "groq" => Some(openai_cloud("https://api.groq.com/openai/v1", false)),
75 "deepseek" => Some(openai_cloud("https://api.deepseek.com/v1", false)),
76 "xai" => Some(openai_cloud("https://api.x.ai/v1", false)),
77 "openai-compatible" | "custom" => {
78 let base = base_url.map(str::trim).filter(|s| !s.is_empty())?;
79 Some(Backend::OpenAiCompatible {
80 base_url: base.trim_end_matches('/').to_string(),
81 max_completion_tokens: false,
82 requires_key: false,
83 })
84 }
85 _ => None,
86 }
87}
88
89pub fn is_backend_wired(backend: &str) -> bool {
95 resolve_backend(backend, Some("https://example.invalid")).is_some()
98}
99
/// Returns the secret-store entry name for `backend`: the id prefixed with `llm.`.
///
/// The id is used verbatim.
// NOTE(review): backend resolution trims and lowercases the id, this does not — confirm
// callers always pass the canonical lowercase id.
pub fn key_name(backend: &str) -> String {
    const PREFIX: &str = "llm.";
    let mut name = String::with_capacity(PREFIX.len() + backend.len());
    name.push_str(PREFIX);
    name.push_str(backend);
    name
}
107
/// Endpoint that Anthropic requests are POSTed to.
const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// URL prefix for Gemini calls; `/{model}:generateContent` is appended per request.
const GEMINI_URL_PREFIX: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Output-token cap sent with every request, whichever backend is in use.
const DEFAULT_MAX_TOKENS: u32 = 1024;
112
/// System prompt for whole-text correction (used by `LlmProvider::rewrite`).
const SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. Return ONLY the \
    corrected version of the user's text — no preamble, no commentary, no \
    quotation marks. Preserve the user's voice, register, and punctuation \
    style. If the text is already fine, return it unchanged.";

/// System prompt for single-word correction (used by `LlmProvider::fix_word_in_context`).
/// The matching user message is formatted as `SENTENCE: …` / `WORD: …`.
const WORD_SYSTEM_PROMPT: &str = "You correct ONE word at a time using sentence context. The \
    user gives you a SENTENCE and one WORD from it to correct. Return ONLY the corrected \
    version of that word — nothing else: no quotes, no punctuation, no commentary, no rest \
    of the sentence. Use the rest of the sentence to disambiguate homophones \
    (their/there/they're, its/it's, your/you're, etc.) and to pick the right fix for typos. \
    Preserve the original casing of the word's first letter. If the word is already correct \
    in context, return it unchanged.";

/// System prompt for correction plus ranked per-word alternatives (used by
/// `LlmProvider::rewrite_with_alternatives`). The JSON shape requested here is the one
/// `parse_alternatives` reads back: a `corrected` string and an `alternatives` array of
/// `{word, options}` objects.
const ALTERNATIVES_SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. \
    Correct the user's text and reply with ONLY a JSON object — no preamble, no commentary, no code \
    fences — shaped exactly like: {\"corrected\": \"<the corrected text>\", \"alternatives\": \
    [{\"word\": \"<a word you changed>\", \"options\": [\"best\", \"next\", \"...\"]}]}. Include an \
    `alternatives` entry only for words you changed; give 3 to 5 ranked options each, best first, with \
    the option you actually used in `corrected` listed first. Use sentence context for homophones \
    (their/there/they're, its/it's, your/you're). Preserve the user's voice, register, casing, and \
    punctuation. If the text is already correct, return it unchanged with an empty `alternatives` array.";
134
/// Failures surfaced while configuring or calling an LLM provider.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The selected backend requires an API key and the secret store holds none.
    #[error("no API key for the LLM provider — set one in Preferences → Providers")]
    NoApiKey,
    /// The secret-store lookup itself failed; carries the underlying error text.
    #[error("keychain: {0}")]
    Keychain(String),
    /// The configured backend id could not be resolved (unknown id, or a custom
    /// endpoint without a base URL); carries the id as configured.
    #[error("unsupported LLM backend: {0}")]
    UnsupportedBackend(String),
    /// Sending the HTTP request failed; carries the transport error text.
    #[error("LLM request failed: {0}")]
    Request(String),
    /// The response body was not JSON or lacked the expected fields.
    #[error("LLM response was unparseable: {0}")]
    Response(String),
}
154
/// A configured LLM client: which backend to call, with which credentials and model.
///
/// `Debug` is implemented by hand so the API key is never printed.
pub struct LlmProvider {
    /// Resolved protocol family and endpoint details.
    backend: Backend,
    /// Secret sent with each request; may be empty for backends that do not require one.
    api_key: String,
    /// Model identifier passed through to the provider as configured.
    model: String,
}
161
162impl std::fmt::Debug for LlmProvider {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("LlmProvider")
166 .field("backend", &self.backend)
167 .field("model", &self.model)
168 .field("api_key", &"[redacted]")
169 .finish()
170 }
171}
172
173impl LlmProvider {
174 pub fn from_config(llm: &crate::LlmConfig) -> Result<Self, LlmError> {
185 let backend = resolve_backend(&llm.backend, llm.base_url.as_deref())
186 .ok_or_else(|| LlmError::UnsupportedBackend(llm.backend.clone()))?;
187 let requires_key = match &backend {
188 Backend::OpenAiCompatible { requires_key, .. } => *requires_key,
189 Backend::Anthropic | Backend::Gemini => true,
191 };
192 let api_key = secrets::get(&key_name(&llm.backend))
193 .map_err(|e| LlmError::Keychain(e.to_string()))?
194 .unwrap_or_default();
195 if requires_key && api_key.is_empty() {
196 return Err(LlmError::NoApiKey);
197 }
198 Ok(Self {
199 backend,
200 api_key,
201 model: llm.model.clone(),
202 })
203 }
204
205 pub fn rewrite(&self, text: &str) -> Result<String, LlmError> {
213 if text.trim().is_empty() {
214 return Ok(text.to_string());
215 }
216 self.request(SYSTEM_PROMPT, text.to_string())
217 }
218
219 pub fn fix_word_in_context(&self, sentence: &str, word: &str) -> Result<String, LlmError> {
230 if word.trim().is_empty() {
231 return Ok(word.to_string());
232 }
233 let content = format!("SENTENCE: {sentence}\nWORD: {word}");
234 let corrected = self.request(WORD_SYSTEM_PROMPT, content)?;
235 Ok(corrected
239 .trim()
240 .trim_matches(|c: char| c == '"' || c == '\'')
241 .to_string())
242 }
243
244 pub fn rewrite_with_alternatives(
256 &self,
257 text: &str,
258 ) -> Result<(String, Vec<WordSuggestions>), LlmError> {
259 if text.trim().is_empty() {
260 return Ok((text.to_string(), Vec::new()));
261 }
262 let reply = self.request(ALTERNATIVES_SYSTEM_PROMPT, text.to_string())?;
263 parse_alternatives(&reply)
264 }
265
266 fn request(&self, system: &str, content: String) -> Result<String, LlmError> {
271 match &self.backend {
272 Backend::Anthropic => self.request_anthropic(system, content),
273 Backend::Gemini => self.request_gemini(system, content),
274 Backend::OpenAiCompatible {
275 base_url,
276 max_completion_tokens,
277 ..
278 } => self.request_openai(base_url, *max_completion_tokens, system, content),
279 }
280 }
281
282 fn request_anthropic(&self, system: &str, content: String) -> Result<String, LlmError> {
283 let body = serde_json::json!({
284 "model": self.model,
285 "max_tokens": DEFAULT_MAX_TOKENS,
286 "system": system,
287 "messages": [{ "role": "user", "content": content }],
288 });
289 let json = agent()
290 .post(ANTHROPIC_URL)
291 .set("x-api-key", &self.api_key)
292 .set("anthropic-version", ANTHROPIC_VERSION)
293 .set("content-type", "application/json")
294 .send_json(body)
295 .map_err(|e| LlmError::Request(e.to_string()))?
296 .into_json::<serde_json::Value>()
297 .map_err(|e| LlmError::Response(e.to_string()))?;
298 parse_anthropic_reply(&json)
299 }
300
301 fn request_openai(
305 &self,
306 base: &str,
307 max_completion_tokens: bool,
308 system: &str,
309 content: String,
310 ) -> Result<String, LlmError> {
311 let token_field = if max_completion_tokens {
312 "max_completion_tokens"
313 } else {
314 "max_tokens"
315 };
316 let mut body = serde_json::json!({
317 "model": self.model,
318 "messages": [
319 { "role": "system", "content": system },
320 { "role": "user", "content": content },
321 ],
322 });
323 body[token_field] = DEFAULT_MAX_TOKENS.into();
324
325 let url = format!("{base}/chat/completions");
326 let mut req = agent().post(&url).set("content-type", "application/json");
327 if !self.api_key.is_empty() {
330 req = req.set("authorization", &format!("Bearer {}", self.api_key));
331 }
332 let json = req
333 .send_json(body)
334 .map_err(|e| LlmError::Request(e.to_string()))?
335 .into_json::<serde_json::Value>()
336 .map_err(|e| LlmError::Response(e.to_string()))?;
337 parse_openai_reply(&json)
338 }
339
340 fn request_gemini(&self, system: &str, content: String) -> Result<String, LlmError> {
341 let url = format!("{GEMINI_URL_PREFIX}/{}:generateContent", self.model);
343 let body = serde_json::json!({
344 "system_instruction": { "parts": [{ "text": system }] },
345 "contents": [{ "parts": [{ "text": content }] }],
346 "generationConfig": { "maxOutputTokens": DEFAULT_MAX_TOKENS },
347 });
348 let json = agent()
349 .post(&url)
350 .set("x-goog-api-key", &self.api_key)
351 .set("content-type", "application/json")
352 .send_json(body)
353 .map_err(|e| LlmError::Request(e.to_string()))?
354 .into_json::<serde_json::Value>()
355 .map_err(|e| LlmError::Response(e.to_string()))?;
356 parse_gemini_reply(&json)
357 }
358}
359
360fn agent() -> ureq::Agent {
363 ureq::AgentBuilder::new()
364 .timeout(Duration::from_secs(20))
365 .build()
366}
367
368fn parse_anthropic_reply(json: &serde_json::Value) -> Result<String, LlmError> {
371 let text = json["content"]
372 .as_array()
373 .and_then(|parts| {
374 parts
375 .iter()
376 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
377 .next()
378 })
379 .ok_or_else(|| LlmError::Response("no `content[*].text` in response".into()))?;
380 Ok(text.trim_end_matches('\n').to_string())
381}
382
383fn parse_openai_reply(json: &serde_json::Value) -> Result<String, LlmError> {
386 let text = json["choices"][0]["message"]["content"]
387 .as_str()
388 .ok_or_else(|| LlmError::Response("no `choices[0].message.content` in response".into()))?;
389 Ok(text.trim_end_matches('\n').to_string())
390}
391
392fn parse_gemini_reply(json: &serde_json::Value) -> Result<String, LlmError> {
395 let text = json["candidates"][0]["content"]["parts"]
396 .as_array()
397 .and_then(|parts| {
398 parts
399 .iter()
400 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
401 .next()
402 })
403 .ok_or_else(|| {
404 LlmError::Response("no `candidates[0].content.parts[*].text` in response".into())
405 })?;
406 Ok(text.trim_end_matches('\n').to_string())
407}
408
409fn parse_alternatives(reply: &str) -> Result<(String, Vec<WordSuggestions>), LlmError> {
414 let json = json_object_slice(reply);
415 let v: serde_json::Value = serde_json::from_str(json)
416 .map_err(|e| LlmError::Response(format!("alternatives JSON: {e}")))?;
417 let corrected = v["corrected"]
418 .as_str()
419 .ok_or_else(|| LlmError::Response("no `corrected` string in response".into()))?
420 .to_string();
421 let mut alternatives = Vec::new();
422 if let Some(arr) = v["alternatives"].as_array() {
423 for item in arr {
424 let Some(word) = item["word"].as_str() else {
425 continue;
426 };
427 let options: Vec<String> = item["options"]
428 .as_array()
429 .into_iter()
430 .flatten()
431 .filter_map(|o| o.as_str().map(str::to_string))
432 .collect();
433 if !options.is_empty() {
434 alternatives.push(WordSuggestions {
435 word: word.to_string(),
436 options,
437 });
438 }
439 }
440 }
441 Ok((corrected, alternatives))
442}
443
/// Narrows `s` to the span from its first `{` through its last `}`, inclusive.
///
/// When either brace is missing, or the last `}` precedes the first `{`, `s` is
/// returned untouched.
fn json_object_slice(s: &str) -> &str {
    s.find('{')
        .zip(s.rfind('}'))
        .filter(|(open, close)| open <= close)
        .map_or(s, |(open, close)| &s[open..=close])
}
452
/// Unit tests for backend resolution, key naming and reply parsing. None of them
/// performs a network request.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmConfig;

    /// A well-formed alternatives object yields the corrected text plus one entry per word.
    #[test]
    fn parses_alternatives_reply() {
        let reply = r#"{"corrected":"the quick brown fox",
 "alternatives":[
 {"word":"the","options":["the","then","they"]},
 {"word":"brown","options":["brown","browne","crown"]}
 ]}"#;
        let (corrected, alts) = parse_alternatives(reply).unwrap();
        assert_eq!(corrected, "the quick brown fox");
        assert_eq!(alts.len(), 2);
        assert_eq!(alts[0].word, "the");
        assert_eq!(alts[0].options, vec!["the", "then", "they"]);
        assert_eq!(alts[1].word, "brown");
    }

    /// Preamble text and Markdown code fences around the JSON object are ignored.
    #[test]
    fn tolerates_code_fences_and_preamble() {
        let reply = "Here you go:\n```json\n{\"corrected\":\"hi there\",\"alternatives\":[]}\n```";
        let (corrected, alts) = parse_alternatives(reply).unwrap();
        assert_eq!(corrected, "hi there");
        assert!(alts.is_empty());
    }

    /// A reply containing no JSON at all is reported as an error, not passed through.
    #[test]
    fn non_json_reply_is_an_error() {
        assert!(parse_alternatives("sorry, I cannot do that").is_err());
    }

    /// An unknown backend id fails with `UnsupportedBackend` carrying the id as configured.
    #[test]
    fn unsupported_backend_is_rejected_cleanly() {
        let cfg = LlmConfig {
            backend: "made-up-vendor".into(),
            model: "whatever".into(),
            base_url: None,
        };
        match LlmProvider::from_config(&cfg) {
            Err(LlmError::UnsupportedBackend(name)) => assert_eq!(name, "made-up-vendor"),
            other => panic!("expected UnsupportedBackend, got {other:?}"),
        }
    }

    /// The custom-endpoint backend cannot be built without a base URL.
    #[test]
    fn custom_endpoint_without_base_url_is_unsupported() {
        let cfg = LlmConfig {
            backend: "openai-compatible".into(),
            model: "llama3.1".into(),
            base_url: None,
        };
        assert!(matches!(
            LlmProvider::from_config(&cfg),
            Err(LlmError::UnsupportedBackend(_))
        ));
    }

    /// Pins the secret-entry naming scheme and the set of backend ids reported as wired.
    #[test]
    fn key_name_and_wiring_are_stable() {
        assert_eq!(key_name("anthropic"), "llm.anthropic");
        assert_eq!(key_name("openai"), "llm.openai");
        for b in [
            "anthropic",
            "openai",
            "gemini",
            "openrouter",
            "mistral",
            "groq",
            "deepseek",
            "xai",
            "openai-compatible",
        ] {
            assert!(is_backend_wired(b), "{b} should be wired");
        }
        // Matching is case-insensitive.
        assert!(is_backend_wired("OpenAI"));
        assert!(!is_backend_wired("made-up-vendor"));
    }

    /// Each backend id resolves to the expected variant, base URL and flags.
    #[test]
    fn resolve_backend_picks_the_right_shape_and_url() {
        assert_eq!(
            resolve_backend("openai", None),
            Some(Backend::OpenAiCompatible {
                base_url: "https://api.openai.com/v1".into(),
                max_completion_tokens: true,
                requires_key: true,
            })
        );
        assert_eq!(
            resolve_backend("groq", None),
            Some(Backend::OpenAiCompatible {
                base_url: "https://api.groq.com/openai/v1".into(),
                max_completion_tokens: false,
                requires_key: true,
            })
        );
        assert_eq!(resolve_backend("anthropic", None), Some(Backend::Anthropic));
        assert_eq!(resolve_backend("gemini", None), Some(Backend::Gemini));
        // A custom endpoint keeps the caller's URL minus the trailing slash and does
        // not require a key.
        assert_eq!(
            resolve_backend("openai-compatible", Some("http://localhost:11434/v1/")),
            Some(Backend::OpenAiCompatible {
                base_url: "http://localhost:11434/v1".into(),
                max_completion_tokens: false,
                requires_key: false,
            })
        );
        // A blank base URL counts as missing; an unknown id stays unknown even with a URL.
        assert_eq!(resolve_backend("openai-compatible", Some(" ")), None);
        assert_eq!(resolve_backend("nope", Some("http://x")), None);
    }

    /// Each parser accepts its own provider's response shape and rejects another's.
    #[test]
    fn parses_each_provider_reply_shape() {
        let anthropic = serde_json::json!({
            "content": [{ "type": "text", "text": "fixed\n" }]
        });
        assert_eq!(parse_anthropic_reply(&anthropic).unwrap(), "fixed");

        let openai = serde_json::json!({
            "choices": [{ "message": { "role": "assistant", "content": "fixed\n" } }]
        });
        assert_eq!(parse_openai_reply(&openai).unwrap(), "fixed");

        let gemini = serde_json::json!({
            "candidates": [{ "content": { "parts": [{ "text": "fixed\n" }] } }]
        });
        assert_eq!(parse_gemini_reply(&gemini).unwrap(), "fixed");

        // Feeding one provider's shape to another provider's parser must fail.
        assert!(parse_openai_reply(&anthropic).is_err());
        assert!(parse_gemini_reply(&openai).is_err());
    }
}