use std::time::Duration;
use crate::runtime::WordSuggestions;
use crate::secrets;
/// Wire protocol used to talk to the configured LLM provider.
///
/// Each variant corresponds to one request/response shape handled by
/// `LlmProvider`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Backend {
    /// Requests go to the fixed Anthropic messages endpoint.
    Anthropic,
    /// Requests go to the Gemini `generateContent` endpoint for the model.
    Gemini,
    /// Any endpoint that speaks the OpenAI `chat/completions` shape.
    OpenAiCompatible {
        /// Base URL; `/chat/completions` is appended when sending.
        base_url: String,
        /// When `true` the token cap is sent as `max_completion_tokens`,
        /// otherwise as `max_tokens`.
        max_completion_tokens: bool,
        /// When `true`, building a provider fails unless an API key is stored.
        requires_key: bool,
    },
}
/// Describes a hosted OpenAI-compatible provider.
///
/// `base` is stored verbatim as the base URL, `completion_tokens` selects the
/// `max_completion_tokens` request field, and an API key is always required.
fn openai_cloud(base: &str, completion_tokens: bool) -> Backend {
    let base_url = String::from(base);
    Backend::OpenAiCompatible {
        base_url,
        max_completion_tokens: completion_tokens,
        requires_key: true,
    }
}
/// Maps a configured backend name to its [`Backend`] description.
///
/// The name is matched after trimming surrounding whitespace and
/// ASCII-lowercasing it. `base_url` is consulted only for the
/// `"openai-compatible"` / `"custom"` names; for those it must be present and
/// non-blank, and any trailing `/` characters are removed from it.
///
/// Returns `None` for an unknown name, or for a custom backend that has no
/// usable base URL.
fn resolve_backend(backend: &str, base_url: Option<&str>) -> Option<Backend> {
    let name = backend.trim().to_ascii_lowercase();
    let resolved = match name.as_str() {
        "anthropic" => Backend::Anthropic,
        "gemini" => Backend::Gemini,
        "openai" => openai_cloud("https://api.openai.com/v1", true),
        "openrouter" => openai_cloud("https://openrouter.ai/api/v1", false),
        "mistral" => openai_cloud("https://api.mistral.ai/v1", false),
        "groq" => openai_cloud("https://api.groq.com/openai/v1", false),
        "deepseek" => openai_cloud("https://api.deepseek.com/v1", false),
        "xai" => openai_cloud("https://api.x.ai/v1", false),
        "openai-compatible" | "custom" => {
            // A missing or whitespace-only URL means the backend is not usable.
            let endpoint = base_url.map(str::trim).filter(|url| !url.is_empty())?;
            Backend::OpenAiCompatible {
                base_url: endpoint.trim_end_matches('/').to_string(),
                max_completion_tokens: false,
                requires_key: false,
            }
        }
        _ => return None,
    };
    Some(resolved)
}
/// Reports whether `backend` names a provider this module can talk to.
///
/// A placeholder base URL is supplied so that the custom
/// (`"openai-compatible"` / `"custom"`) names count as wired even though no
/// real endpoint is known at this point.
pub fn is_backend_wired(backend: &str) -> bool {
    let placeholder_url = Some("https://example.invalid");
    matches!(resolve_backend(backend, placeholder_url), Some(_))
}
/// Returns the secret-store entry name that holds the API key for `backend`.
///
/// The backend name is normalised the same way `resolve_backend` matches it:
/// surrounding whitespace is trimmed and the name is ASCII-lowercased. Without
/// this, a configuration spelling such as `"OpenAI"` would resolve to a
/// backend but look its key up under a different entry than `"openai"`.
///
/// Already-canonical names (lowercase, no padding) map to the same entry as
/// before, e.g. `"anthropic"` gives `"llm.anthropic"`.
pub fn key_name(backend: &str) -> String {
    let canonical = backend.trim().to_ascii_lowercase();
    format!("llm.{canonical}")
}
// Endpoint that `request_anthropic` posts to.
const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";
// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
// Prefix of the Gemini URL; `request_gemini` appends `/<model>:generateContent`.
const GEMINI_URL_PREFIX: &str = "https://generativelanguage.googleapis.com/v1beta/models";
// Output-token cap sent with every request (as `max_tokens`,
// `max_completion_tokens`, or `maxOutputTokens`, depending on the backend).
const DEFAULT_MAX_TOKENS: u32 = 1024;
// System prompt used by `LlmProvider::rewrite` (whole-text correction).
// The trailing `\` on each line is a string continuation: the line break and
// the next line's leading whitespace are not part of the prompt.
const SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. Return ONLY the \
corrected version of the user's text — no preamble, no commentary, no \
quotation marks. Preserve the user's voice, register, and punctuation \
style. If the text is already fine, return it unchanged.";
// System prompt used by `LlmProvider::fix_word_in_context`; the matching user
// message has the form "SENTENCE: <sentence>\nWORD: <word>".
const WORD_SYSTEM_PROMPT: &str = "You correct ONE word at a time using sentence context. The \
user gives you a SENTENCE and one WORD from it to correct. Return ONLY the corrected \
version of that word — nothing else: no quotes, no punctuation, no commentary, no rest \
of the sentence. Use the rest of the sentence to disambiguate homophones \
(their/there/they're, its/it's, your/you're, etc.) and to pick the right fix for typos. \
Preserve the original casing of the word's first letter. If the word is already correct \
in context, return it unchanged.";
// System prompt used by `LlmProvider::rewrite_with_alternatives`; the JSON
// shape it requests is the one `parse_alternatives` reads (`corrected` plus
// an `alternatives` array of `{word, options}` objects).
const ALTERNATIVES_SYSTEM_PROMPT: &str = "You are a spelling, typo, and minor-grammar corrector. \
Correct the user's text and reply with ONLY a JSON object — no preamble, no commentary, no code \
fences — shaped exactly like: {\"corrected\": \"<the corrected text>\", \"alternatives\": \
[{\"word\": \"<a word you changed>\", \"options\": [\"best\", \"next\", \"...\"]}]}. Include an \
`alternatives` entry only for words you changed; give 3 to 5 ranked options each, best first, with \
the option you actually used in `corrected` listed first. Use sentence context for homophones \
(their/there/they're, its/it's, your/you're). Preserve the user's voice, register, casing, and \
punctuation. If the text is already correct, return it unchanged with an empty `alternatives` array.";
/// Failures surfaced while configuring or calling an LLM backend.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
/// The backend requires an API key but the stored key is missing or empty.
#[error("no API key for the LLM provider — set one in Preferences → Providers")]
NoApiKey,
/// Reading the key from the secret store failed; carries the underlying
/// error rendered as text.
#[error("keychain: {0}")]
Keychain(String),
/// The configured backend name could not be resolved (unknown name, or a
/// custom backend without a usable base URL); carries the name as configured.
#[error("unsupported LLM backend: {0}")]
UnsupportedBackend(String),
/// Sending the HTTP request failed; carries the underlying error as text.
#[error("LLM request failed: {0}")]
Request(String),
/// The response body was not valid JSON or lacked the expected field;
/// carries a description of what was wrong.
#[error("LLM response was unparseable: {0}")]
Response(String),
}
/// A configured LLM client: which backend to call, with which key and model.
///
/// `Debug` is implemented by hand so that the API key is never printed.
pub struct LlmProvider {
// Request/response shape and, for OpenAI-compatible backends, the base URL.
backend: Backend,
// API key read from the secret store; may be empty when the backend does
// not require one.
api_key: String,
// Model identifier, copied from the configuration and sent as-is.
model: String,
}
/// Manual `Debug` so the API key is replaced by a fixed `[redacted]` marker.
impl std::fmt::Debug for LlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LlmProvider");
        out.field("backend", &self.backend);
        out.field("model", &self.model);
        // Never print the real key, only a placeholder.
        out.field("api_key", &"[redacted]");
        out.finish()
    }
}
impl LlmProvider {
pub fn from_config(llm: &crate::LlmConfig) -> Result<Self, LlmError> {
let backend = resolve_backend(&llm.backend, llm.base_url.as_deref())
.ok_or_else(|| LlmError::UnsupportedBackend(llm.backend.clone()))?;
let requires_key = match &backend {
Backend::OpenAiCompatible { requires_key, .. } => *requires_key,
Backend::Anthropic | Backend::Gemini => true,
};
let api_key = secrets::get(&key_name(&llm.backend))
.map_err(|e| LlmError::Keychain(e.to_string()))?
.unwrap_or_default();
if requires_key && api_key.is_empty() {
return Err(LlmError::NoApiKey);
}
Ok(Self {
backend,
api_key,
model: llm.model.clone(),
})
}
pub fn rewrite(&self, text: &str) -> Result<String, LlmError> {
if text.trim().is_empty() {
return Ok(text.to_string());
}
self.request(SYSTEM_PROMPT, text.to_string())
}
pub fn fix_word_in_context(&self, sentence: &str, word: &str) -> Result<String, LlmError> {
if word.trim().is_empty() {
return Ok(word.to_string());
}
let content = format!("SENTENCE: {sentence}\nWORD: {word}");
let corrected = self.request(WORD_SYSTEM_PROMPT, content)?;
Ok(corrected
.trim()
.trim_matches(|c: char| c == '"' || c == '\'')
.to_string())
}
pub fn rewrite_with_alternatives(
&self,
text: &str,
) -> Result<(String, Vec<WordSuggestions>), LlmError> {
if text.trim().is_empty() {
return Ok((text.to_string(), Vec::new()));
}
let reply = self.request(ALTERNATIVES_SYSTEM_PROMPT, text.to_string())?;
parse_alternatives(&reply)
}
fn request(&self, system: &str, content: String) -> Result<String, LlmError> {
match &self.backend {
Backend::Anthropic => self.request_anthropic(system, content),
Backend::Gemini => self.request_gemini(system, content),
Backend::OpenAiCompatible {
base_url,
max_completion_tokens,
..
} => self.request_openai(base_url, *max_completion_tokens, system, content),
}
}
fn request_anthropic(&self, system: &str, content: String) -> Result<String, LlmError> {
let body = serde_json::json!({
"model": self.model,
"max_tokens": DEFAULT_MAX_TOKENS,
"system": system,
"messages": [{ "role": "user", "content": content }],
});
let json = agent()
.post(ANTHROPIC_URL)
.set("x-api-key", &self.api_key)
.set("anthropic-version", ANTHROPIC_VERSION)
.set("content-type", "application/json")
.send_json(body)
.map_err(|e| LlmError::Request(e.to_string()))?
.into_json::<serde_json::Value>()
.map_err(|e| LlmError::Response(e.to_string()))?;
parse_anthropic_reply(&json)
}
fn request_openai(
&self,
base: &str,
max_completion_tokens: bool,
system: &str,
content: String,
) -> Result<String, LlmError> {
let token_field = if max_completion_tokens {
"max_completion_tokens"
} else {
"max_tokens"
};
let mut body = serde_json::json!({
"model": self.model,
"messages": [
{ "role": "system", "content": system },
{ "role": "user", "content": content },
],
});
body[token_field] = DEFAULT_MAX_TOKENS.into();
let url = format!("{base}/chat/completions");
let mut req = agent().post(&url).set("content-type", "application/json");
if !self.api_key.is_empty() {
req = req.set("authorization", &format!("Bearer {}", self.api_key));
}
let json = req
.send_json(body)
.map_err(|e| LlmError::Request(e.to_string()))?
.into_json::<serde_json::Value>()
.map_err(|e| LlmError::Response(e.to_string()))?;
parse_openai_reply(&json)
}
fn request_gemini(&self, system: &str, content: String) -> Result<String, LlmError> {
let url = format!("{GEMINI_URL_PREFIX}/{}:generateContent", self.model);
let body = serde_json::json!({
"system_instruction": { "parts": [{ "text": system }] },
"contents": [{ "parts": [{ "text": content }] }],
"generationConfig": { "maxOutputTokens": DEFAULT_MAX_TOKENS },
});
let json = agent()
.post(&url)
.set("x-goog-api-key", &self.api_key)
.set("content-type", "application/json")
.send_json(body)
.map_err(|e| LlmError::Request(e.to_string()))?
.into_json::<serde_json::Value>()
.map_err(|e| LlmError::Response(e.to_string()))?;
parse_gemini_reply(&json)
}
}
/// Creates a fresh HTTP agent with a 20-second timeout.
fn agent() -> ureq::Agent {
    let timeout = Duration::from_secs(20);
    let builder = ureq::AgentBuilder::new().timeout(timeout);
    builder.build()
}
/// Extracts the reply text from an Anthropic response body.
///
/// Uses the first element of the `content` array that has a string `text`
/// field and strips trailing newline characters from it.
///
/// # Errors
///
/// [`LlmError::Response`] when `content` is not an array or no element has a
/// string `text` field.
fn parse_anthropic_reply(json: &serde_json::Value) -> Result<String, LlmError> {
    let first_text = json["content"].as_array().and_then(|blocks| {
        blocks
            .iter()
            .find_map(|block| block.get("text").and_then(|t| t.as_str()))
    });
    let Some(text) = first_text else {
        return Err(LlmError::Response(
            "no `content[*].text` in response".into(),
        ));
    };
    Ok(text.trim_end_matches('\n').to_string())
}
/// Extracts the reply text from an OpenAI-style chat-completions body.
///
/// Reads `choices[0].message.content` and strips trailing newline characters.
///
/// # Errors
///
/// [`LlmError::Response`] when that path is missing or not a string.
fn parse_openai_reply(json: &serde_json::Value) -> Result<String, LlmError> {
    match json["choices"][0]["message"]["content"].as_str() {
        Some(text) => Ok(text.trim_end_matches('\n').to_string()),
        None => Err(LlmError::Response(
            "no `choices[0].message.content` in response".into(),
        )),
    }
}
/// Extracts the reply text from a Gemini `generateContent` response body.
///
/// Uses the first element of `candidates[0].content.parts` that has a string
/// `text` field and strips trailing newline characters from it.
///
/// # Errors
///
/// [`LlmError::Response`] when the parts array is missing or no part has a
/// string `text` field.
fn parse_gemini_reply(json: &serde_json::Value) -> Result<String, LlmError> {
    let first_text = json["candidates"][0]["content"]["parts"]
        .as_array()
        .and_then(|parts| {
            parts
                .iter()
                .find_map(|part| part.get("text").and_then(|t| t.as_str()))
        });
    match first_text {
        Some(text) => Ok(text.trim_end_matches('\n').to_string()),
        None => Err(LlmError::Response(
            "no `candidates[0].content.parts[*].text` in response".into(),
        )),
    }
}
/// Parses the model's JSON reply into the corrected text and per-word options.
///
/// Text outside the outermost `{ … }` span is ignored before parsing.
/// Entries of `alternatives` without a string `word`, or whose `options`
/// hold no strings, are skipped; non-string options are dropped. A missing
/// or non-array `alternatives` yields an empty list.
///
/// # Errors
///
/// [`LlmError::Response`] when the reply is not valid JSON or has no string
/// `corrected` field.
fn parse_alternatives(reply: &str) -> Result<(String, Vec<WordSuggestions>), LlmError> {
    let payload = json_object_slice(reply);
    let parsed: serde_json::Value = match serde_json::from_str(payload) {
        Ok(value) => value,
        Err(e) => return Err(LlmError::Response(format!("alternatives JSON: {e}"))),
    };

    let Some(corrected) = parsed["corrected"].as_str() else {
        return Err(LlmError::Response(
            "no `corrected` string in response".into(),
        ));
    };

    let alternatives: Vec<WordSuggestions> = parsed["alternatives"]
        .as_array()
        .map(|entries| {
            entries
                .iter()
                .filter_map(|entry| {
                    let word = entry["word"].as_str()?;
                    let options: Vec<String> = entry["options"]
                        .as_array()
                        .map(|opts| {
                            opts.iter()
                                .filter_map(|opt| opt.as_str())
                                .map(str::to_string)
                                .collect()
                        })
                        .unwrap_or_default();
                    // An entry with nothing to offer is not worth surfacing.
                    (!options.is_empty()).then(|| WordSuggestions {
                        word: word.to_string(),
                        options,
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    Ok((corrected.to_string(), alternatives))
}
/// Returns the span from the first `{` to the last `}` in `s`, inclusive.
///
/// If either brace is missing, or the last `}` comes before the first `{`,
/// `s` is returned unchanged.
fn json_object_slice(s: &str) -> &str {
    let Some(open) = s.find('{') else {
        return s;
    };
    let Some(close) = s.rfind('}') else {
        return s;
    };
    if close < open {
        return s;
    }
    // Both braces are single-byte ASCII, so these are valid char boundaries.
    &s[open..=close]
}
// Unit tests for backend resolution, key naming, and reply parsing. None of
// them perform network requests.
#[cfg(test)]
mod tests {
use super::*;
use crate::LlmConfig;
// A well-formed reply yields the corrected text and every alternatives entry,
// in order.
#[test]
fn parses_alternatives_reply() {
let reply = r#"{"corrected":"the quick brown fox",
"alternatives":[
{"word":"the","options":["the","then","they"]},
{"word":"brown","options":["brown","browne","crown"]}
]}"#;
let (corrected, alts) = parse_alternatives(reply).unwrap();
assert_eq!(corrected, "the quick brown fox");
assert_eq!(alts.len(), 2);
assert_eq!(alts[0].word, "the");
assert_eq!(alts[0].options, vec!["the", "then", "they"]);
assert_eq!(alts[1].word, "brown");
}
// Text before the JSON object and surrounding code fences are ignored.
#[test]
fn tolerates_code_fences_and_preamble() {
let reply = "Here you go:\n```json\n{\"corrected\":\"hi there\",\"alternatives\":[]}\n```";
let (corrected, alts) = parse_alternatives(reply).unwrap();
assert_eq!(corrected, "hi there");
assert!(alts.is_empty());
}
// A reply with no JSON at all is reported as an error rather than passed on.
#[test]
fn non_json_reply_is_an_error() {
assert!(parse_alternatives("sorry, I cannot do that").is_err());
}
// An unknown backend name fails with `UnsupportedBackend` carrying the name
// exactly as configured.
#[test]
fn unsupported_backend_is_rejected_cleanly() {
let cfg = LlmConfig {
backend: "made-up-vendor".into(),
model: "whatever".into(),
base_url: None,
};
match LlmProvider::from_config(&cfg) {
Err(LlmError::UnsupportedBackend(name)) => assert_eq!(name, "made-up-vendor"),
other => panic!("expected UnsupportedBackend, got {other:?}"),
}
}
// A custom backend with no base URL cannot be resolved, so it is reported as
// unsupported.
#[test]
fn custom_endpoint_without_base_url_is_unsupported() {
let cfg = LlmConfig {
backend: "openai-compatible".into(),
model: "llama3.1".into(),
base_url: None,
};
assert!(matches!(
LlmProvider::from_config(&cfg),
Err(LlmError::UnsupportedBackend(_))
));
}
// Pins the `llm.<backend>` key format and the set of backend names that are
// considered wired, including case-insensitive matching.
#[test]
fn key_name_and_wiring_are_stable() {
assert_eq!(key_name("anthropic"), "llm.anthropic");
assert_eq!(key_name("openai"), "llm.openai");
for b in [
"anthropic",
"openai",
"gemini",
"openrouter",
"mistral",
"groq",
"deepseek",
"xai",
"openai-compatible",
] {
assert!(is_backend_wired(b), "{b} should be wired");
}
assert!(is_backend_wired("OpenAI"));
assert!(!is_backend_wired("made-up-vendor"));
}
// Checks the exact `Backend` value for hosted and custom names, including
// trailing-slash trimming and rejection of a blank base URL.
#[test]
fn resolve_backend_picks_the_right_shape_and_url() {
assert_eq!(
resolve_backend("openai", None),
Some(Backend::OpenAiCompatible {
base_url: "https://api.openai.com/v1".into(),
max_completion_tokens: true,
requires_key: true,
})
);
assert_eq!(
resolve_backend("groq", None),
Some(Backend::OpenAiCompatible {
base_url: "https://api.groq.com/openai/v1".into(),
max_completion_tokens: false,
requires_key: true,
})
);
assert_eq!(resolve_backend("anthropic", None), Some(Backend::Anthropic));
assert_eq!(resolve_backend("gemini", None), Some(Backend::Gemini));
assert_eq!(
resolve_backend("openai-compatible", Some("http://localhost:11434/v1/")),
Some(Backend::OpenAiCompatible {
base_url: "http://localhost:11434/v1".into(),
max_completion_tokens: false,
requires_key: false,
})
);
assert_eq!(resolve_backend("openai-compatible", Some(" ")), None);
assert_eq!(resolve_backend("nope", Some("http://x")), None);
}
// Each parser accepts its own provider's response shape (stripping the
// trailing newline) and rejects another provider's shape.
#[test]
fn parses_each_provider_reply_shape() {
let anthropic = serde_json::json!({
"content": [{ "type": "text", "text": "fixed\n" }]
});
assert_eq!(parse_anthropic_reply(&anthropic).unwrap(), "fixed");
let openai = serde_json::json!({
"choices": [{ "message": { "role": "assistant", "content": "fixed\n" } }]
});
assert_eq!(parse_openai_reply(&openai).unwrap(), "fixed");
let gemini = serde_json::json!({
"candidates": [{ "content": { "parts": [{ "text": "fixed\n" }] } }]
});
assert_eq!(parse_gemini_reply(&gemini).unwrap(), "fixed");
assert!(parse_openai_reply(&anthropic).is_err());
assert!(parse_gemini_reply(&openai).is_err());
}
}