1use std::fmt;
2use std::io::{IsTerminal, Write, stderr};
3use std::path::PathBuf;
4use std::sync::{
5 Arc,
6 atomic::{AtomicBool, Ordering},
7};
8use std::thread;
9use std::time::Duration;
10
11use memvid_core::types::SearchHit;
12
/// A model's answer together with the model identifiers involved in producing it.
#[derive(Debug, Clone)]
pub struct ModelAnswer {
    /// The model name as requested by the caller (presumably before any normalization — TODO confirm).
    pub requested: String,
    /// The model identifier reported for this answer.
    pub model: String,
    /// The answer text.
    pub answer: String,
}
19
/// Result of one model inference: the answer plus the context and metadata behind it.
#[derive(Debug, Clone)]
pub struct ModelInference {
    /// The answer and the model identifiers.
    pub answer: ModelAnswer,
    /// The assembled context text (presumably the body placed in the prompt — TODO confirm).
    pub context_body: String,
    /// The individual context fragments the body was built from.
    pub context_fragments: Vec<ModelContextFragment>,
    /// Token usage and cost, when available.
    pub usage: Option<TokenUsage>,
    /// Grounding assessment of the answer against the context, when computed.
    pub grounding: Option<GroundingResult>,
    /// NOTE(review): presumably `true` when the answer came from the answer cache; verify against the producer.
    pub cached: bool,
}
30
/// Token accounting for a single model call.
///
/// `Default` yields zero counts and a zero cost.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub input_tokens: u32,
    /// Tokens produced by the completion.
    pub output_tokens: u32,
    /// Total tokens for the call (presumably `input_tokens + output_tokens`; not enforced here).
    pub total_tokens: u32,
    /// Cost of the call in US dollars, as computed by whoever fills this in.
    pub cost_usd: f64,
}
43
44pub mod cache {
47 use std::collections::HashMap;
48 use std::sync::Mutex;
49
50 #[derive(Debug, Clone)]
52 pub struct CacheEntry {
53 pub answer: String,
54 pub model: String,
55 pub input_tokens: u32,
56 pub output_tokens: u32,
57 pub cost_usd: f64,
58 pub grounding_score: f32,
59 pub created_at: std::time::SystemTime,
60 }
61
62 pub struct AnswerCache {
65 entries: Mutex<HashMap<[u8; 32], CacheEntry>>,
66 max_size: usize,
67 hits: std::sync::atomic::AtomicU64,
68 misses: std::sync::atomic::AtomicU64,
69 }
70
71 impl AnswerCache {
72 pub fn new(max_size: usize) -> Self {
74 Self {
75 entries: Mutex::new(HashMap::new()),
76 max_size,
77 hits: std::sync::atomic::AtomicU64::new(0),
78 misses: std::sync::atomic::AtomicU64::new(0),
79 }
80 }
81
82 pub fn make_key(query: &str, context: &str, model: &str) -> [u8; 32] {
84 use std::io::Write;
85 let mut hasher = blake3::Hasher::new();
86 let _ = write!(hasher, "{}|{}|{}", model, query, context);
87 *hasher.finalize().as_bytes()
88 }
89
90 pub fn get(&self, key: &[u8; 32]) -> Option<CacheEntry> {
92 let entries = self.entries.lock().ok()?;
93 let result = entries.get(key).cloned();
94 if result.is_some() {
95 self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
96 } else {
97 self.misses
98 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
99 }
100 result
101 }
102
103 pub fn insert(&self, key: [u8; 32], entry: CacheEntry) {
105 if let Ok(mut entries) = self.entries.lock() {
106 if entries.len() >= self.max_size {
108 let oldest_key = entries
109 .iter()
110 .min_by_key(|(_, v)| v.created_at)
111 .map(|(k, _)| *k);
112 if let Some(k) = oldest_key {
113 entries.remove(&k);
114 }
115 }
116 entries.insert(key, entry);
117 }
118 }
119
120 pub fn clear(&self) {
122 if let Ok(mut entries) = self.entries.lock() {
123 entries.clear();
124 }
125 }
126
127 pub fn stats(&self) -> CacheStats {
129 let entries = self.entries.lock().map(|e| e.len()).unwrap_or(0);
130 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
131 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
132 CacheStats {
133 entries,
134 hits,
135 misses,
136 hit_rate: if hits + misses > 0 {
137 hits as f64 / (hits + misses) as f64
138 } else {
139 0.0
140 },
141 }
142 }
143
144 pub fn estimated_savings(&self) -> f64 {
146 if let Ok(entries) = self.entries.lock() {
147 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
148 let avg_cost =
149 entries.values().map(|e| e.cost_usd).sum::<f64>() / entries.len().max(1) as f64;
150 hits as f64 * avg_cost
151 } else {
152 0.0
153 }
154 }
155 }
156
157 impl Default for AnswerCache {
158 fn default() -> Self {
159 Self::new(100) }
161 }
162
163 #[derive(Debug, Clone)]
164 pub struct CacheStats {
165 pub entries: usize,
166 pub hits: u64,
167 pub misses: u64,
168 pub hit_rate: f64,
169 }
170
171 lazy_static::lazy_static! {
173 pub static ref GLOBAL_CACHE: AnswerCache = AnswerCache::new(500);
174 }
175
176 pub fn check_cache(query: &str, context: &str, model: &str) -> Option<CacheEntry> {
178 let key = AnswerCache::make_key(query, context, model);
179 GLOBAL_CACHE.get(&key)
180 }
181
182 pub fn store_in_cache(query: &str, context: &str, model: &str, entry: CacheEntry) {
184 let key = AnswerCache::make_key(query, context, model);
185 GLOBAL_CACHE.insert(key, entry);
186 }
187
188 pub fn global_stats() -> CacheStats {
190 GLOBAL_CACHE.stats()
191 }
192
193 pub fn clear_global_cache() {
195 GLOBAL_CACHE.clear();
196 }
197}
198
/// Outcome of checking an answer's sentences against the retrieval context.
#[derive(Debug, Clone, Default)]
pub struct GroundingResult {
    /// Overall grounding score, bucketed by `grade`/`label` on a 0.0–1.0 scale.
    pub score: f32,
    /// Number of sentences assessed.
    pub sentence_count: usize,
    /// Number of sentences counted as grounded.
    pub grounded_sentences: usize,
    /// Per-sentence scores.
    pub sentence_scores: Vec<f32>,
    /// Whether a warning was raised.
    pub has_warning: bool,
    /// Reason for the warning, if any.
    pub warning_reason: Option<String>,
}

impl GroundingResult {
    /// Letter grade for `score`: A (>= 0.8), B (>= 0.6), C (>= 0.4), D (>= 0.2), else F.
    /// A NaN score fails every comparison and grades as "F".
    pub fn grade(&self) -> &'static str {
        const BANDS: [(f32, &str); 4] = [(0.8, "A"), (0.6, "B"), (0.4, "C"), (0.2, "D")];
        BANDS
            .iter()
            .find(|(floor, _)| self.score >= *floor)
            .map_or("F", |&(_, letter)| letter)
    }

    /// Confidence label for `score`: HIGH (>= 0.7), MEDIUM (>= 0.4), else LOW.
    pub fn label(&self) -> &'static str {
        let s = self.score;
        if s >= 0.7 {
            "HIGH"
        } else if s >= 0.4 {
            "MEDIUM"
        } else {
            "LOW"
        }
    }
}
238
/// One retrieved context fragment, converted from a `context::ContextRecord`.
#[derive(Debug, Clone)]
pub struct ModelContextFragment {
    /// Rank of the fragment among the retrieved results.
    pub rank: usize,
    /// URI of the source document.
    pub uri: String,
    /// Source title, if known.
    pub title: Option<String>,
    /// Retrieval score, if one was produced.
    pub score: Option<f32>,
    /// NOTE(review): presumably the number of query matches in the fragment — confirm.
    pub matches: usize,
    /// Identifier of the frame the text came from.
    pub frame_id: u64,
    /// `(start, end)` range of the text — NOTE(review): units (bytes vs chars) not visible here.
    pub range: (usize, usize),
    /// Optional chunk range within the frame.
    pub chunk_range: Option<(usize, usize)>,
    /// Fragment text.
    pub text: String,
    /// Whether `text` is full content or a summary.
    pub kind: ModelContextFragmentKind,
}
252
/// Whether a fragment carries full text or a summary; mirrors `context::ContextMode`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ModelContextFragmentKind {
    /// Converted from `context::ContextMode::Full`.
    Full,
    /// Converted from `context::ContextMode::Summary`.
    Summary,
}
258
259impl ModelContextFragment {
260 fn from_record(record: context::ContextRecord) -> Self {
261 let kind = match record.mode {
262 context::ContextMode::Full => ModelContextFragmentKind::Full,
263 context::ContextMode::Summary => ModelContextFragmentKind::Summary,
264 };
265 Self {
266 rank: record.rank,
267 uri: record.uri,
268 title: record.title,
269 score: record.score,
270 matches: record.matches,
271 frame_id: record.frame_id,
272 range: record.range,
273 chunk_range: record.chunk_range,
274 text: record.text,
275 kind,
276 }
277 }
278}
279
280#[derive(Debug)]
281pub enum ModelRunError {
282 UnsupportedModel(String),
283 AssetsMissing {
284 model: String,
285 missing: Vec<PathBuf>,
286 },
287 Runtime(anyhow::Error),
288}
289
290impl fmt::Display for ModelRunError {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 match self {
293 Self::UnsupportedModel(model) => write!(f, "unsupported model '{model}'"),
294 Self::AssetsMissing { model, missing } => {
295 let paths: Vec<_> = missing
296 .iter()
297 .map(|path| path.display().to_string())
298 .collect();
299 write!(
300 f,
301 "model '{model}' missing required assets: {}",
302 paths.join(", ")
303 )
304 }
305 Self::Runtime(err) => write!(f, "model runtime error: {err}"),
306 }
307 }
308}
309
310impl std::error::Error for ModelRunError {
311 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
312 match self {
313 Self::Runtime(err) => Some(err.root_cause()),
314 _ => None,
315 }
316 }
317}
318
/// Total prompt budget, in characters, for local models (TinyLlama and Ghost).
const LOCAL_CONTEXT_CHARS: usize = 32_768;
/// Upper bound on the question length used when building prompts
/// (applied through `ModelContextBudget::question_limit`).
const MAX_QUESTION_CHARS: usize = 512;
/// Output-token cap for local models (see `ModelKind::max_output_tokens`).
const LOCAL_MAX_OUTPUT_TOKENS: usize = 256;
/// Output-token cap for remote API models (see `ModelKind::max_output_tokens`).
const REMOTE_MAX_OUTPUT_TOKENS: usize = 768;
323const SYSTEM_PROMPT: &str = r#"You are a precise, intelligent assistant that answers questions using ONLY the provided retrieval context.
324
325## Core Principles
3261. GROUND EVERY CLAIM in the context. If asked for a number, quote it exactly.
3272. NEVER hallucinate or use external knowledge. If unsure, say "Based on the context..."
3283. BE CONCISE but complete. One clear answer is better than verbose hedging.
329
330## CRITICAL: Correction Handling (MANDATORY)
331**STOP AND READ THIS FIRST.** Before answering ANY question:
3321. Scan ALL sources for "Correction:" in the title or "mv2://correction/" in the URI
3332. If ANY correction exists that relates to the question, USE THAT ANSWER ONLY
3343. IGNORE ALL OTHER SOURCES that contradict the correction - they are OUTDATED
3354. If multiple corrections exist, use the FIRST one listed (most recent)
336
337**VIOLATION OF THIS RULE IS A CRITICAL ERROR.** Example:
338- Question: "Where does Ben live?"
339- Correction says: "Ben lives in Kenya"
340- Other doc says: "Ben lives in Germany"
341- CORRECT ANSWER: "Kenya" (from correction)
342- WRONG ANSWER: "Germany" (ignores correction = FAIL)
343
344## Answer Strategy
345- For NUMERIC questions: Extract the exact value. If multiple values exist, identify which is most relevant (usually the most recent or most specific match).
346- For YES/NO questions: Answer directly, then briefly explain why.
347- For COMPARISON questions: Present both sides with their values.
348- For LIST questions: Use bullet points or numbered lists.
349- For TEMPORAL questions: Note that later timestamps = more current information. State WHEN data is from.
350- For CALCULATION questions: Show your work step-by-step.
351- For ANALYTICAL/PATTERN questions (e.g., "reverted", "changed back", "any differences over time"):
352 1. TRACE each attribute's value across ALL time periods in the context
353 2. Look for A→B→A patterns where a value changes then returns to its original state
354 3. Terms like "consolidated", "same as", "unified", or "aligned" often indicate returning to a prior arrangement
355 4. Compare explicit state changes: if Period 1 says "X was same as Y", Period 2 says "X different from Y", and Period 3 says "X consolidated/same as Y again", that IS a reversion
356 5. Create a timeline table if helpful to track changes
357
358## Handling Ambiguity
359- If the question is ambiguous, interpret it reasonably and state your interpretation.
360- If multiple valid answers exist, present the most likely one first, then mention alternatives.
361- If context is insufficient, say what IS known, then note what's missing.
362
363## Quality Standards
364- PREFER specific answers over vague ones ("$1,234.56" not "around a thousand")
365- CITE context when helpful ("[Source: ...]")
366- CORRECT obvious typos in your interpretation (e.g., "teh" → "the")
367- For percentages/ratios, include the actual numbers when available"#;
368const TINYLLAMA_LABEL: &str = "tinyllama-1.1b";
369const LOCAL_PROMPT_MARGIN_CHARS: usize = 2_048;
370const REMOTE_PROMPT_MARGIN_CHARS: usize = 4_096;
371const OLLAMA_PROMPT_CHARS: usize = 110_000;
372const OPENAI_PROMPT_CHARS: usize = 240_000;
373const NVIDIA_PROMPT_CHARS: usize = 240_000;
374const GEMINI_PROMPT_CHARS: usize = 320_000;
375const CLAUDE_PROMPT_CHARS: usize = 360_000;
376const XAI_PROMPT_CHARS: usize = 260_000; const GROQ_PROMPT_CHARS: usize = 260_000; const MISTRAL_PROMPT_CHARS: usize = 260_000; #[derive(Debug, Clone, Copy)]
381struct ModelContextBudget {
382 total_chars: usize,
383 reserved_chars: usize,
384}
385
386impl ModelContextBudget {
387 const fn new(total_chars: usize, reserved_chars: usize) -> Self {
388 Self {
389 total_chars,
390 reserved_chars,
391 }
392 }
393
394 fn context_chars(&self) -> usize {
395 self.total_chars.saturating_sub(self.reserved_chars)
396 }
397
398 fn question_limit(&self) -> usize {
399 MAX_QUESTION_CHARS
400 .min(self.reserved_chars.max(1))
401 .min(self.total_chars.max(1))
402 }
403
404 fn apply_override(self, override_context_chars: usize) -> Self {
405 let total = override_context_chars.saturating_add(self.reserved_chars);
406 Self {
407 total_chars: total.max(self.reserved_chars + 1),
408 reserved_chars: self.reserved_chars,
409 }
410 }
411
412 fn prompt_ceiling(&self) -> usize {
413 self.total_chars
414 }
415}
416
/// A fully assembled prompt in two shapes plus the output-token cap to request.
///
/// The first shape is one completion-style string (system, context, question and
/// answer stub). The second is a chat-style user message.
pub struct PromptParts {
    completion_prompt: String,
    user_message: String,
    max_output_tokens: usize,
}

impl PromptParts {
    /// The single-string, completion-style prompt.
    pub fn completion_prompt(&self) -> &str {
        self.completion_prompt.as_str()
    }

    /// The chat-style user message (context, question and type hint).
    pub fn user_message(&self) -> &str {
        self.user_message.as_str()
    }

    /// Maximum number of output tokens to request.
    pub fn max_output_tokens(&self) -> usize {
        self.max_output_tokens
    }
}
436
437fn normalize_question(question: &str) -> String {
445 let trimmed = question.trim();
446 if trimmed.is_empty() {
447 return trimmed.to_string();
448 }
449
450 let mut normalized = fix_common_typos(trimmed);
452
453 normalized = expand_abbreviations(&normalized);
455
456 normalized = ensure_question_punctuation(&normalized);
458
459 normalized
460}
461
/// Corrects a fixed list of common English typos and collapses whitespace runs.
///
/// Each whitespace run is reduced to its first character.
///
/// Matching rules:
/// * Matches are ASCII case-insensitive and only start at a word boundary, so e.g.
///   "deprive" is never touched by the `prive` rule.
/// * Whole-word entries must also end at a word boundary: `proce` fixes the word
///   "proce" but leaves "process" alone.
/// * The remaining entries are misspelled stems that are also fixed at the start of
///   longer words (`recieved` → `received`).
/// * Every occurrence is corrected, and a capitalized typo yields a capitalized fix.
///
/// All offsets are taken from `text` itself, never from a lowercased copy. Lowercasing
/// can change byte lengths for non-ASCII input, so slicing with such offsets could
/// panic or corrupt text.
fn fix_common_typos(text: &str) -> String {
    // (typo, correction, whole_word); every entry is lowercase ASCII.
    const TYPOS: &[(&str, &str, bool)] = &[
        ("teh", "the", true),
        ("hte", "the", true),
        ("adn", "and", true),
        ("taht", "that", true),
        ("wiht", "with", true),
        ("thier", "their", true),
        ("recieve", "receive", false),
        ("occured", "occurred", false),
        ("seperate", "separate", false),
        ("waht", "what", true),
        ("hwat", "what", true),
        ("wehn", "when", true),
        ("whre", "where", true),
        ("wher", "where", true),
        ("howm", "how", true),
        ("hwo", "who", true),
        ("amoutn", "amount", false),
        ("totla", "total", false),
        ("nubmer", "number", false),
        ("vlaue", "value", false),
        ("prive", "price", true),
        ("proce", "price", true),
        ("revneue", "revenue", false),
        ("reveneu", "revenue", false),
    ];

    let mut fixed = String::with_capacity(text.len());
    let mut rest = text;
    let mut prev_is_word = false;

    while let Some(ch) = rest.chars().next() {
        if !prev_is_word && ch.is_alphanumeric() {
            let hit = TYPOS.iter().find(|&&(typo, _, whole_word)| {
                rest.get(..typo.len())
                    .is_some_and(|head| head.eq_ignore_ascii_case(typo))
                    && !(whole_word && rest[typo.len()..].starts_with(char::is_alphanumeric))
            });
            if let Some(&(typo, correction, _)) = hit {
                if ch.is_ascii_uppercase() {
                    fixed.push_str(&correction[..1].to_ascii_uppercase());
                    fixed.push_str(&correction[1..]);
                } else {
                    fixed.push_str(correction);
                }
                // The match is pure ASCII, so `typo.len()` is a char boundary.
                rest = &rest[typo.len()..];
                prev_is_word = true;
                continue;
            }
        }
        fixed.push(ch);
        prev_is_word = ch.is_alphanumeric();
        rest = &rest[ch.len_utf8()..];
    }

    // Collapse whitespace runs, keeping the first character of each run.
    let mut in_gap = false;
    fixed.retain(|c| {
        let is_space = c.is_whitespace();
        let keep = !(is_space && in_gap);
        in_gap = is_space;
        keep
    });
    fixed
}
528
529pub fn generate_search_query(
532 question: &str,
533 model: &str,
534 api_key: &str,
535) -> Result<String, ModelRunError> {
536 let prompt = format!(
540 r#"Extract 2 key search terms from this question.
541KEEP abbreviations exactly as written (QPS, API, SDK, etc.) - don't expand them.
542Output only the main topic and one key term.
543
544Question: {}
545
546Examples:
547- "What is the QPS for memvid?" → "memvid QPS"
548- "How many queries per second?" → "QPS throughput"
549- "What's the API rate limit?" → "API rate"
550- "How much does it cost?" → "cost pricing"
551
552Output exactly 2 words, nothing else."#,
553 question
554 );
555
556 let extraction_model =
559 if model.starts_with("gpt") || model.starts_with("o1") || model.contains("openai") {
560 "gpt-4o-mini"
561 } else if model.starts_with("claude") || model.contains("anthropic") {
562 "claude-haiku-4-5"
563 } else if model.contains("llama") || model.contains("groq") || model.contains("mixtral") {
564 "llama-3.1-8b-instant" } else if model.contains("grok") || model.contains("xai") {
566 "grok-4-fast"
567 } else if model.contains("mistral") {
568 "mistral-small-latest" } else {
570 return Ok(question.to_string());
572 };
573
574 let rewritten = call_llm_for_keywords(&prompt, extraction_model, api_key)?;
576
577 let rewritten = rewritten.trim();
579 if rewritten.is_empty() || rewritten.len() > 100 {
580 Ok(question.to_string())
582 } else {
583 Ok(rewritten.to_string())
585 }
586}
587
588fn call_llm_for_keywords(
590 prompt: &str,
591 model: &str,
592 api_key: &str,
593) -> Result<String, ModelRunError> {
594 use reqwest::blocking::Client;
595 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
596
597 let client = Client::builder()
598 .timeout(std::time::Duration::from_secs(10)) .build()
600 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("HTTP client error: {e}")))?;
601
602 let (url, is_anthropic) = if model.starts_with("gpt") || model.starts_with("o1") {
604 ("https://api.openai.com/v1/chat/completions", false)
605 } else if model.starts_with("claude") {
606 ("https://api.anthropic.com/v1/messages", true)
607 } else if model.contains("llama") || model.contains("mixtral") {
608 ("https://api.groq.com/openai/v1/chat/completions", false)
609 } else if model.contains("grok") {
610 ("https://api.x.ai/v1/chat/completions", false)
611 } else if model.contains("mistral") {
612 ("https://api.mistral.ai/v1/chat/completions", false)
613 } else {
614 return Err(ModelRunError::UnsupportedModel(model.to_string()));
615 };
616
617 let response = if is_anthropic {
618 let mut headers = HeaderMap::new();
619 headers.insert(
620 reqwest::header::HeaderName::from_static("x-api-key"),
621 HeaderValue::from_str(api_key)
622 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Invalid API key: {e}")))?,
623 );
624 headers.insert(
625 reqwest::header::HeaderName::from_static("anthropic-version"),
626 HeaderValue::from_static("2023-06-01"),
627 );
628 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
629
630 client
631 .post(url)
632 .headers(headers)
633 .json(&serde_json::json!({
634 "model": model,
635 "max_tokens": 100,
636 "messages": [{"role": "user", "content": prompt}]
637 }))
638 .send()
639 } else {
640 let mut headers = HeaderMap::new();
642 headers.insert(
643 AUTHORIZATION,
644 HeaderValue::from_str(&format!("Bearer {}", api_key))
645 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Invalid API key: {e}")))?,
646 );
647 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
648
649 client
650 .post(url)
651 .headers(headers)
652 .json(&serde_json::json!({
653 "model": model,
654 "messages": [{"role": "user", "content": prompt}],
655 "max_tokens": 100,
656 "temperature": 0.0
657 }))
658 .send()
659 };
660
661 match response {
662 Ok(resp) => {
663 let json: serde_json::Value = resp
664 .json()
665 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("JSON parse error: {e}")))?;
666
667 let text = if model.starts_with("claude") {
669 json["content"][0]["text"].as_str().unwrap_or("")
670 } else {
671 json["choices"][0]["message"]["content"]
672 .as_str()
673 .unwrap_or("")
674 };
675
676 Ok(text.to_string())
677 }
678 Err(_) => {
679 Ok(String::new())
682 }
683 }
684}
685
/// Abbreviation-expansion pass for questions; currently returns the text unchanged.
fn expand_abbreviations(text: &str) -> String {
    String::from(text)
}
691
/// Appends `?` to text that reads like a question or request but has no terminal
/// punctuation.
///
/// The text is trimmed first. Text already ending in `?`, `.` or `!` is returned
/// trimmed but otherwise unchanged. "Reads like a question" means the text starts
/// (case-insensitively) with one of a fixed set of interrogative or request words,
/// followed by end-of-text or a non-alphanumeric character.
fn ensure_question_punctuation(text: &str) -> String {
    const STARTERS: &[&str] = &[
        "how", "what", "where", "when", "why", "which", "who", "whom", "whose", "is", "are", "was",
        "were", "will", "would", "can", "could", "should", "do", "does", "did", "have", "has",
        "had", "may", "might", "shall", "tell me", "show me", "find", "list", "give me", "explain",
    ];

    let text = text.trim();
    if matches!(text.chars().next_back(), Some('?' | '.' | '!')) {
        return text.to_string();
    }

    let lower = text.to_lowercase();
    let looks_like_question = STARTERS.iter().any(|starter| match lower.strip_prefix(*starter) {
        // The starter must be a whole word: "whatever" does not count as "what".
        Some(tail) => !tail.starts_with(char::is_alphanumeric),
        None => false,
    });

    let mut out = text.to_string();
    if looks_like_question {
        out.push('?');
    }
    out
}
721
722fn build_prompt_parts(
723 question: &str,
724 context: &str,
725 budget: &ModelContextBudget,
726 max_output_tokens: usize,
727) -> PromptParts {
728 let mut context_section = context.to_string();
729 let normalized_question = normalize_question(question);
730 let trimmed_question = trim_to(&normalized_question, budget.question_limit());
731
732 let question_type = detect_question_type(&trimmed_question);
734 let type_hint = question_type.hint();
735
736 let system_section = format!("### System\n{SYSTEM_PROMPT}");
737 let question_section = format!("### Question\n{trimmed_question}");
738 let answer_stub = "### Answer\n";
739
740 let overhead = system_section.len() + 2 + question_section.len() + 2 + answer_stub.len();
741 if budget.prompt_ceiling() > overhead {
742 let max_context_len = budget
743 .prompt_ceiling()
744 .saturating_sub(overhead)
745 .min(budget.context_chars());
746 if context_section.len() > max_context_len {
747 context_section = clamp_to(&context_section, max_context_len);
748 }
749 } else {
750 context_section = String::new();
751 }
752
753 let context_instruction = if context_section.trim().is_empty() {
755 "Note: No relevant context was found. Answer based on what you know, but clearly state this limitation."
756 } else {
757 ""
758 };
759
760 let completion_prompt =
761 format!("{system_section}\n\n{context_section}\n\n{question_section}\n\n### Answer\n");
762
763 let user_message = if context_instruction.is_empty() {
765 format!(
766 "{context_section}\n\n---\nQuestion: {trimmed_question}\n{type_hint}\nProvide a direct, accurate answer using only the context above."
767 )
768 } else {
769 format!("{context_instruction}\n\nQuestion: {trimmed_question}\n{type_hint}")
770 };
771
772 PromptParts {
773 completion_prompt,
774 user_message,
775 max_output_tokens,
776 }
777}
778
/// Coarse category of a question, used to append an answer-format hint to the prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QuestionType {
    Numeric,
    YesNo,
    List,
    Comparison,
    Temporal,
    Explanation,
    Factual,
    Other,
}

impl QuestionType {
    /// Parenthesized hint describing the expected answer shape; empty for `Other`.
    fn hint(&self) -> &'static str {
        match self {
            Self::Numeric => "(Expected: a specific number or value)",
            Self::YesNo => "(Expected: yes/no with brief explanation)",
            Self::List => "(Expected: a list of items)",
            Self::Comparison => "(Expected: comparison of two or more items)",
            Self::Temporal => "(Expected: a date, time, or duration)",
            Self::Explanation => "(Expected: reasoning or explanation)",
            Self::Factual => "(Expected: a factual answer)",
            Self::Other => "",
        }
    }
}

/// Classifies `question` with keyword heuristics.
///
/// Categories are tried in priority order: numeric, yes/no, list, comparison, temporal,
/// explanation, factual. Anything else is [`QuestionType::Other`].
///
/// Single-word keywords must match a whole word, optionally with a plural `s`.
/// Substring matching misfired on words like "summarize" (`sum`), "accurate" (`rate`)
/// and "specialist" (`list`). Multi-word phrases are still matched as substrings.
fn detect_question_type(question: &str) -> QuestionType {
    const NUMERIC_PHRASES: &[&str] =
        &["how much", "how many", "what is the value", "what's the value"];
    const NUMERIC_WORDS: &[&str] = &[
        "total",
        "sum",
        "average",
        "percentage",
        "rate",
        "amount",
        "price",
        "cost",
        "revenue",
        "profit",
    ];
    const YES_NO_STARTERS: &[&str] = &[
        "is ", "are ", "does ", "do ", "can ", "will ", "has ", "have ", "was ", "were ",
    ];
    const LIST_PHRASES: &[&str] = &["show all", "what are the", "name all", "enumerate"];
    const COMPARISON_PHRASES: &[&str] = &[
        "compare",
        "difference between",
        " vs ",
        "versus",
        "better than",
        "worse than",
    ];
    const TEMPORAL_PHRASES: &[&str] = &["what date", "how long", "how old", "since when"];
    const EXPLANATION_PHRASES: &[&str] = &["how does", "reason for", "cause of"];

    let lower = question.to_lowercase();
    let words: Vec<&str> = lower
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| !w.is_empty())
        .collect();
    let has_word = |keyword: &str| {
        words
            .iter()
            .any(|&w| w == keyword || w.strip_suffix('s') == Some(keyword))
    };
    let has_any_phrase = |phrases: &[&str]| phrases.iter().any(|p| lower.contains(*p));

    if has_any_phrase(NUMERIC_PHRASES) || NUMERIC_WORDS.iter().any(|&w| has_word(w)) {
        return QuestionType::Numeric;
    }
    if YES_NO_STARTERS.iter().any(|s| lower.starts_with(*s)) {
        return QuestionType::YesNo;
    }
    if has_word("list") || has_any_phrase(LIST_PHRASES) {
        return QuestionType::List;
    }
    if has_any_phrase(COMPARISON_PHRASES) {
        return QuestionType::Comparison;
    }
    if lower.starts_with("when") || has_any_phrase(TEMPORAL_PHRASES) {
        return QuestionType::Temporal;
    }
    if lower.starts_with("why") || lower.starts_with("explain") || has_any_phrase(EXPLANATION_PHRASES)
    {
        return QuestionType::Explanation;
    }
    if lower.starts_with("what") || lower.starts_with("who") || lower.starts_with("where") {
        return QuestionType::Factual;
    }
    QuestionType::Other
}
885
/// Tidies a raw model answer for display.
///
/// * Trims the answer and strips leading boilerplate such as "Based on the provided
///   context,", checking each phrase once, in list order.
/// * Collapses runs of whitespace inside each line and drops trailing whitespace.
///   Line breaks and leading indentation are kept, so lists, tables and step-by-step
///   answers survive. Runs of blank lines are reduced to one.
/// * Capitalizes the first letter, including the pronoun "i". Camel-cased words whose
///   second character is already uppercase ("iPhone", "eBay") are left alone.
pub fn postprocess_answer(answer: &str) -> String {
    /// Uppercases the first character unless the text starts with a camel-cased word.
    fn capitalize_first(text: &str) -> String {
        let mut chars = text.chars();
        let Some(first) = chars.next() else {
            return String::new();
        };
        let second_is_upper = chars.clone().next().is_some_and(char::is_uppercase);
        if !first.is_lowercase() || second_is_upper {
            return text.to_string();
        }
        first.to_uppercase().chain(chars).collect()
    }

    /// Normalizes whitespace line by line and squeezes repeated blank lines.
    fn tidy_whitespace(text: &str) -> String {
        let mut lines: Vec<String> = Vec::new();
        for line in text.lines() {
            let body = line.trim_start();
            let indent = &line[..line.len() - body.len()];
            let words = body.split_whitespace().collect::<Vec<_>>().join(" ");
            let tidy = if words.is_empty() {
                String::new()
            } else {
                format!("{indent}{words}")
            };
            if tidy.is_empty() && lines.last().is_some_and(|prev| prev.is_empty()) {
                continue;
            }
            lines.push(tidy);
        }
        lines.join("\n")
    }

    const ARTIFACTS: [&str; 5] = [
        "Based on the provided context,",
        "According to the context,",
        "From the context provided,",
        "The context shows that",
        "Based on the information provided,",
    ];

    let mut result = answer.trim().to_string();
    for artifact in ARTIFACTS {
        if let Some(rest) = result.strip_prefix(artifact) {
            // Re-capitalize so a following (capitalized) artifact can still match.
            result = capitalize_first(rest.trim_start());
        }
    }

    capitalize_first(&tidy_whitespace(&result))
}
920
/// Truncates `text` to at most `limit` bytes, appending `...` when anything was cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary. Slicing in the
/// middle of a multi-byte character would panic.
fn trim_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    let mut end = limit;
    // Index 0 is always a boundary, so this terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &text[..end])
}
930
/// Shortens `text` so the result, including a trailing `...`, is at most `limit` bytes.
///
/// * Text already within the limit is returned unchanged.
/// * A limit of 3 or less yields only that many dots.
/// * Otherwise the cut is moved back to a character boundary. If no character fits
///   before the ellipsis, the result is just `...`.
fn clamp_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_owned();
    }
    if limit <= 3 {
        return ".".repeat(limit);
    }
    let cut = (1..=limit - 3).rev().find(|&i| text.is_char_boundary(i));
    match cut {
        Some(end) => format!("{}...", &text[..end]),
        None => String::from("..."),
    }
}
950
/// A "Thinking..." animation on stderr, driven by a background thread.
///
/// The animation stops via [`ThinkingSpinner::stop`] or on drop. When stderr is not a
/// terminal, nothing is drawn and no thread is started.
struct ThinkingSpinner {
    flag: Arc<AtomicBool>,
    handle: Option<thread::JoinHandle<()>>,
}

impl ThinkingSpinner {
    const FRAMES: [&'static str; 6] = [
        "Thinking ",
        "Thinking. ",
        "Thinking.. ",
        "Thinking... ",
        "Thinking .. ",
        "Thinking . ",
    ];
    const FRAME_INTERVAL: Duration = Duration::from_millis(200);

    /// Starts the animation (a no-op when stderr is not a terminal).
    fn start() -> Self {
        let flag = Arc::new(AtomicBool::new(true));
        if !stderr().is_terminal() {
            return Self { flag, handle: None };
        }

        let thread_flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            // Pad every frame to the widest one so a shorter frame fully overwrites a
            // longer one, and blank exactly that many columns when done.
            let width = ThinkingSpinner::FRAMES
                .iter()
                .map(|f| f.chars().count())
                .max()
                .unwrap_or(0);
            let mut err = stderr();
            for frame in ThinkingSpinner::FRAMES.iter().cycle() {
                if !thread_flag.load(Ordering::Acquire) {
                    break;
                }
                let _ = write!(err, "\r{frame:<width$}");
                let _ = err.flush();
                // `stop` unparks this thread, so shutdown does not wait out the interval.
                thread::park_timeout(ThinkingSpinner::FRAME_INTERVAL);
            }
            let _ = write!(err, "\r{:width$}\r", "");
            let _ = err.flush();
        });

        Self {
            flag,
            handle: Some(handle),
        }
    }

    /// Stops the animation, erases it, and waits for the thread to exit.
    /// Calling it again is a no-op.
    fn stop(&mut self) {
        if let Some(handle) = self.handle.take() {
            self.flag.store(false, Ordering::Release);
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}

impl Drop for ThinkingSpinner {
    fn drop(&mut self) {
        self.stop();
    }
}
1015
/// Backend selected for answering, with the provider-specific model name where relevant.
#[derive(Debug, Clone)]
enum ModelKind {
    /// Bundled TinyLlama model (labelled with `TINYLLAMA_LABEL`).
    TinyLlama,
    /// "Ghost" model loaded from the pack at `pack_path`.
    Ghost { pack_path: PathBuf },
    /// Ollama-served model.
    Ollama { model: String },
    /// OpenAI model.
    OpenAi { model: String },
    /// NVIDIA-hosted model; the name may come from the environment (see `normalize_nvidia_model`).
    Nvidia { model: String },
    /// Google Gemini model.
    Gemini { model: String },
    /// Anthropic Claude model.
    Claude { model: String },
    /// xAI Grok model.
    Xai { model: String },
    /// Groq-hosted model.
    Groq { model: String },
    /// Mistral model.
    Mistral { model: String },
}
1029
impl ModelKind {
    /// Parses a model spec of the form `provider` or `provider:model`.
    ///
    /// Parsing rules:
    /// * The spec is trimmed.
    /// * Only the first `:` separates provider from model, so the model part may itself
    ///   contain colons (e.g. `ollama:llama3:8b`).
    /// * The provider is matched ASCII-case-insensitively.
    /// * A blank model part is treated as absent.
    /// * Unknown providers fall back to [`ModelKind::infer_from_model_name_full`].
    ///
    /// Returns `None` for an empty spec, for `ghost` without a pack path, and for names
    /// that cannot be classified.
    fn parse(raw: &str) -> Option<Self> {
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            return None;
        }

        let (provider, explicit_model) = if let Some((p, rest)) = trimmed.split_once(':') {
            let value = rest.trim();
            let explicit = if value.is_empty() {
                None
            } else {
                Some(value.to_string())
            };
            (p.trim().to_ascii_lowercase(), explicit)
        } else {
            (trimmed.to_ascii_lowercase(), None)
        };

        match provider.as_str() {
            "tinyllama" | "tiny-llama" | "tinyllama-1.1b" => Some(Self::TinyLlama),
            // A ghost spec is only valid with a pack path.
            "ghost" => explicit_model.map(|value| Self::Ghost {
                pack_path: PathBuf::from(value),
            }),
            // Bare `ollama` defaults to the "ollama1.5" model name.
            "ollama" => Some(Self::Ollama {
                model: explicit_model.unwrap_or_else(|| "ollama1.5".to_string()),
            }),
            // NOTE(review): any explicit model after `ollama1.5:` is ignored here — confirm intended.
            "ollama1.5" | "ollama1-5" => Some(Self::Ollama {
                model: "ollama1.5".to_string(),
            }),
            "openai" => Some(Self::OpenAi {
                model: normalize_openai_model(explicit_model),
            }),
            "nvidia" | "nv" => Some(Self::Nvidia {
                model: normalize_nvidia_model(explicit_model),
            }),
            "gemini" | "google" => Some(Self::Gemini {
                model: normalize_gemini_model(explicit_model),
            }),
            "claude" | "anthropic" => Some(Self::Claude {
                model: normalize_claude_model(explicit_model),
            }),
            "xai" | "grok" => Some(Self::Xai {
                model: normalize_xai_model(explicit_model),
            }),
            "groq" => Some(Self::Groq {
                model: normalize_groq_model(explicit_model),
            }),
            "mistral" => Some(Self::Mistral {
                model: normalize_mistral_model(explicit_model),
            }),
            // Not a provider keyword: classify by the model name itself (e.g. `gpt-4o`,
            // `llama3:8b`), keeping the full trimmed spec as the model name.
            _ => Self::infer_from_model_name_full(trimmed, &provider),
        }
    }

    /// Guesses the provider from a bare model name.
    ///
    /// `prefix` (lowercased here; `parse` passes the part before any `:`) is matched
    /// against known model-name prefixes. `full_name` is stored verbatim as the model.
    /// Checks run in order, so `llama-…` / `mixtral-…` map to Groq before the broader
    /// `llama…` prefix maps to Ollama. Returns `None` when no prefix matches.
    fn infer_from_model_name_full(full_name: &str, prefix: &str) -> Option<Self> {
        let lowered = prefix.to_ascii_lowercase();

        if lowered.starts_with("gemini") || lowered.starts_with("models/gemini") {
            return Some(Self::Gemini {
                model: full_name.to_string(),
            });
        }

        if lowered.starts_with("gpt-")
            || lowered.starts_with("o1-")
            || lowered.starts_with("o3-")
            || lowered.starts_with("chatgpt-")
            || lowered.starts_with("text-")
        {
            return Some(Self::OpenAi {
                model: full_name.to_string(),
            });
        }

        if lowered.starts_with("claude-") {
            return Some(Self::Claude {
                model: full_name.to_string(),
            });
        }

        if lowered.starts_with("grok-") {
            return Some(Self::Xai {
                model: full_name.to_string(),
            });
        }

        if lowered.starts_with("mistral-") {
            return Some(Self::Mistral {
                model: full_name.to_string(),
            });
        }

        // Hyphenated llama/mixtral names go to Groq; this must precede the Ollama check.
        if lowered.starts_with("llama-") || lowered.starts_with("mixtral-") {
            return Some(Self::Groq {
                model: full_name.to_string(),
            });
        }

        // Common local model families are assumed to be served by Ollama.
        if lowered.starts_with("llama")
            || lowered.starts_with("phi")
            || lowered.starts_with("codellama")
            || lowered.starts_with("deepseek")
            || lowered.starts_with("qwen")
            || lowered.starts_with("gemma")
        {
            return Some(Self::Ollama {
                model: full_name.to_string(),
            });
        }

        None
    }

    /// Human-readable label: `provider:model`.
    ///
    /// TinyLlama uses `TINYLLAMA_LABEL`, and Ghost shows its pack path.
    fn label(&self) -> String {
        match self {
            Self::TinyLlama => TINYLLAMA_LABEL.to_string(),
            Self::Ghost { pack_path } => format!("ghost:{}", pack_path.display()),
            Self::Ollama { model } => format!("ollama:{model}"),
            Self::OpenAi { model } => format!("openai:{model}"),
            Self::Nvidia { model } => format!("nvidia:{model}"),
            Self::Gemini { model } => format!("gemini:{model}"),
            Self::Claude { model } => format!("claude:{model}"),
            Self::Xai { model } => format!("xai:{model}"),
            Self::Groq { model } => format!("groq:{model}"),
            Self::Mistral { model } => format!("mistral:{model}"),
        }
    }

    /// Character budget for this backend.
    ///
    /// Local models get `LOCAL_CONTEXT_CHARS` with the local margin. Remote providers
    /// get their per-provider total with the remote margin.
    fn context_budget(&self) -> ModelContextBudget {
        match self {
            Self::TinyLlama => {
                ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
            }
            Self::Ghost { .. } => {
                ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
            }
            Self::Ollama { .. } => {
                ModelContextBudget::new(OLLAMA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::OpenAi { .. } => {
                ModelContextBudget::new(OPENAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Nvidia { .. } => {
                ModelContextBudget::new(NVIDIA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Gemini { .. } => {
                ModelContextBudget::new(GEMINI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Claude { .. } => {
                ModelContextBudget::new(CLAUDE_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Xai { .. } => {
                ModelContextBudget::new(XAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Groq { .. } => {
                ModelContextBudget::new(GROQ_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
            Self::Mistral { .. } => {
                ModelContextBudget::new(MISTRAL_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
            }
        }
    }

    /// Output-token cap: `LOCAL_MAX_OUTPUT_TOKENS` for local backends,
    /// `REMOTE_MAX_OUTPUT_TOKENS` for all remote providers.
    fn max_output_tokens(&self) -> usize {
        match self {
            Self::TinyLlama => LOCAL_MAX_OUTPUT_TOKENS,
            Self::Ghost { .. } => LOCAL_MAX_OUTPUT_TOKENS,
            Self::Ollama { .. }
            | Self::OpenAi { .. }
            | Self::Nvidia { .. }
            | Self::Gemini { .. }
            | Self::Claude { .. }
            | Self::Xai { .. }
            | Self::Groq { .. }
            | Self::Mistral { .. } => REMOTE_MAX_OUTPUT_TOKENS,
        }
    }
}
1222
/// Resolves the OpenAI model name.
///
/// An explicit, non-blank name is used verbatim (not trimmed); otherwise the default
/// is `gpt-4o-mini`.
fn normalize_openai_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| String::from("gpt-4o-mini"))
}
1229
/// Resolves the NVIDIA model name.
///
/// Precedence:
/// 1. `explicit`, returned unchanged, unless it is empty or whitespace-only;
/// 2. the first of `NVIDIA_LLM_MODEL`, `NVIDIA_MODEL` whose value is non-blank
///    (returned trimmed);
/// 3. the empty string when nothing usable is configured.
///
/// A blank `NVIDIA_LLM_MODEL` no longer hides a valid `NVIDIA_MODEL`: every
/// variable is checked in order until one yields a non-blank value.
fn normalize_nvidia_model(explicit: Option<String>) -> String {
    if let Some(raw) = explicit.filter(|raw| !raw.trim().is_empty()) {
        return raw;
    }

    ["NVIDIA_LLM_MODEL", "NVIDIA_MODEL"]
        .iter()
        .filter_map(|&name| std::env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty())
        .unwrap_or_default()
}
1241
/// Resolves the Gemini model name.
///
/// `None`, or an empty / whitespace-only name, selects `gemini-2.5-flash`. Any
/// other name, including legacy ones such as `gemini-1.5-pro`, is passed
/// through unchanged. The previous implementation lowercased the name and
/// matched it against a list whose arms all returned the input, which was dead
/// code; it also returned blank names as-is.
fn normalize_gemini_model(explicit: Option<String>) -> String {
    explicit
        .filter(|raw| !raw.trim().is_empty())
        .unwrap_or_else(|| "gemini-2.5-flash".to_string())
}
1254
/// Resolves the Claude model name.
///
/// `None`, or an empty / whitespace-only name, selects `claude-sonnet-4-5`.
/// Retired Claude 3.x identifiers and the short aliases `sonnet`, `haiku` and
/// `opus` are upgraded to their newer counterparts. Alias matching is
/// case-insensitive, consistent with the xAI/Groq/Mistral normalizers. Any
/// other name is passed through unchanged, preserving its original casing.
fn normalize_claude_model(explicit: Option<String>) -> String {
    const DEFAULT_MODEL: &str = "claude-sonnet-4-5";

    let Some(raw) = explicit.filter(|raw| !raw.trim().is_empty()) else {
        return DEFAULT_MODEL.to_string();
    };

    match raw.to_lowercase().as_str() {
        "claude-3-5-sonnet-20241022" | "claude-3.5-sonnet" | "sonnet" => DEFAULT_MODEL.to_string(),
        "claude-3-haiku-20240307" | "claude-3-haiku" | "haiku" => "claude-haiku-4-5".to_string(),
        "claude-3-opus-20240229" | "claude-3-opus" | "opus" => "claude-opus-4".to_string(),
        _ => raw,
    }
}
1271
/// Resolves the xAI (Grok) model name.
///
/// `None`, or an empty / whitespace-only name, selects `grok-4-fast`, as do the
/// case-insensitive aliases `grok` and `grok-fast`. Every other name, including
/// concrete ids like `grok-4` or `grok-3`, is passed through unchanged. The
/// former explicit pass-through arm duplicated the catch-all and was removed.
fn normalize_xai_model(explicit: Option<String>) -> String {
    const DEFAULT_MODEL: &str = "grok-4-fast";

    let Some(raw) = explicit.filter(|raw| !raw.trim().is_empty()) else {
        return DEFAULT_MODEL.to_string();
    };

    match raw.to_lowercase().as_str() {
        "grok" | "grok-fast" => DEFAULT_MODEL.to_string(),
        _ => raw,
    }
}
1285
/// Resolves the Groq model name.
///
/// `None`, or an empty / whitespace-only name, selects
/// `llama-3.3-70b-versatile`. Short, case-insensitive aliases map to concrete
/// Groq ids:
/// * `llama`, `llama3`, `llama-3`, `llama-70b`, `llama3-70b` → `llama-3.3-70b-versatile`
/// * `llama-8b`, `llama3-8b` → `llama-3.1-8b-instant`
/// * `mixtral` → `mixtral-8x7b-32768`
///
/// Any other name is passed through unchanged.
fn normalize_groq_model(explicit: Option<String>) -> String {
    const DEFAULT_MODEL: &str = "llama-3.3-70b-versatile";

    let Some(raw) = explicit.filter(|raw| !raw.trim().is_empty()) else {
        return DEFAULT_MODEL.to_string();
    };

    match raw.to_lowercase().as_str() {
        "llama" | "llama3" | "llama-3" | "llama-70b" | "llama3-70b" => DEFAULT_MODEL.to_string(),
        "llama-8b" | "llama3-8b" => "llama-3.1-8b-instant".to_string(),
        "mixtral" => "mixtral-8x7b-32768".to_string(),
        _ => raw,
    }
}
1301
/// Resolves the Mistral model name.
///
/// `None`, or an empty / whitespace-only name, selects `mistral-large-latest`.
/// Case-insensitive size aliases map to the `*-latest` ids:
/// * `mistral`, `large`, `mistral-large` → `mistral-large-latest`
/// * `medium`, `mistral-medium` → `mistral-medium-latest`
/// * `small`, `mistral-small` → `mistral-small-latest`
///
/// Any other name is passed through unchanged.
fn normalize_mistral_model(explicit: Option<String>) -> String {
    const DEFAULT_MODEL: &str = "mistral-large-latest";

    let Some(raw) = explicit.filter(|raw| !raw.trim().is_empty()) else {
        return DEFAULT_MODEL.to_string();
    };

    match raw.to_lowercase().as_str() {
        "mistral" | "large" | "mistral-large" => DEFAULT_MODEL.to_string(),
        "medium" | "mistral-medium" => "mistral-medium-latest".to_string(),
        "small" | "mistral-small" => "mistral-small-latest".to_string(),
        _ => raw,
    }
}
1316
/// Estimates the USD cost of one request from its token counts.
///
/// `model` is lowercased and matched by substring against a hard-coded table of
/// (input, output) prices in USD per one million tokens. Arms are tried top to
/// bottom, so more specific names (e.g. `gpt-4o-mini`) must precede the names
/// they contain (`gpt-4o`). Unknown models are priced at $1.00 / $3.00.
///
/// Local runtimes (`ollama`, `tinyllama`) are free. They are matched first so
/// that a label such as `ollama:mistral` or `ollama:deepseek-r1` is not billed at
/// a hosted family's price by a later arm.
///
/// NOTE(review): the prices are static snapshots and will drift from the
/// providers' published rates; verify them periodically.
pub fn calculate_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    let normalized = model.to_lowercase();
    let (input_price, output_price) = match normalized.as_str() {
        // Local runtimes: always free, and must win over any hosted family name
        // that appears inside the label.
        m if m.contains("ollama") || m.contains("tinyllama") => (0.0, 0.0),

        // OpenAI
        m if m.contains("gpt-4o-mini") => (0.15, 0.60),
        m if m.contains("gpt-4o") => (2.50, 10.00),
        m if m.contains("gpt-4.5") => (75.00, 150.00),
        m if m.contains("gpt-4.1-mini") => (0.40, 1.60),
        m if m.contains("gpt-4.1") => (2.00, 8.00),
        // Covers gpt-5 and its point releases (e.g. gpt-5.2), which share a price.
        m if m.contains("gpt-5") => (1.75, 14.00),
        m if m.contains("gpt-4-turbo") => (10.00, 30.00),
        m if m.contains("gpt-4") => (30.00, 60.00),
        m if m.contains("gpt-3.5") => (0.50, 1.50),
        m if m.contains("o1") || m.contains("o3") => (15.00, 60.00),

        // Anthropic
        m if m.contains("claude-4-opus") || m.contains("claude-opus-4") => (15.00, 75.00),
        m if m.contains("claude-4-sonnet") || m.contains("claude-sonnet-4") => (3.00, 15.00),
        m if m.contains("claude-4-haiku") || m.contains("claude-haiku-4") => (0.25, 1.25),
        m if m.contains("claude-3-5-sonnet") || m.contains("claude-3.5-sonnet") => (3.00, 15.00),
        m if m.contains("claude-3-opus") => (15.00, 75.00),
        m if m.contains("claude-3-sonnet") => (3.00, 15.00),
        m if m.contains("claude-3-haiku") => (0.25, 1.25),
        m if m.contains("claude") => (3.00, 15.00),

        // Google
        m if m.contains("gemini-2.5-flash") => (0.15, 3.50),
        m if m.contains("gemini-2.5-pro") => (1.25, 10.00),
        m if m.contains("gemini-2.0") => (0.10, 0.40),
        m if m.contains("gemini-1.5-pro") => (1.25, 5.00),
        m if m.contains("gemini-1.5-flash") => (0.075, 0.30),
        m if m.contains("gemini") => (0.15, 3.50),

        // xAI: grok-4-fast is discounted; every other Grok model shares one price.
        m if m.contains("grok-4-fast") => (0.20, 0.50),
        m if m.contains("grok") => (3.00, 15.00),

        // Groq-hosted open models
        m if m.contains("llama-3.3-70b") => (0.59, 0.79),
        m if m.contains("llama-3.1-70b") => (0.59, 0.79),
        m if m.contains("llama-3.1-8b") => (0.05, 0.08),
        m if m.contains("mixtral-8x7b") => (0.24, 0.24),

        // Mistral
        m if m.contains("mistral-large-3") || m.contains("mistral-large-latest") => (0.50, 1.50),
        m if m.contains("mistral-large") => (2.00, 6.00),
        m if m.contains("mistral-medium") => (0.40, 1.20),
        m if m.contains("mistral-small") => (0.10, 0.30),
        m if m.contains("mistral") => (0.50, 1.50),

        // DeepSeek
        m if m.contains("deepseek") => (0.27, 1.10),

        // NVIDIA-hosted models without a more specific match
        m if m.contains("nvidia") => (1.00, 3.00),

        // Unknown model: conservative default.
        _ => (1.00, 3.00),
    };

    let input_cost = (f64::from(input_tokens) / 1_000_000.0) * input_price;
    let output_cost = (f64::from(output_tokens) / 1_000_000.0) * output_price;
    input_cost + output_cost
}
1388
/// What a provider backend hands back to `run_model_inference`.
struct ProviderResult {
    /// The model's answer text; `run_model_inference` caches and grounds this
    /// raw text and post-processes it only for the returned `ModelAnswer`.
    answer: String,
    /// Token counts and estimated cost when the backend reports them.
    /// `run_model_inference` sets this to `None` for TinyLlama, Ollama and NVIDIA.
    usage: Option<TokenUsage>,
}

/// Minimum score the first search hit must reach before any model is called.
///
/// `run_model_inference` reads a missing score as `0.0` and compares with `<`,
/// so with the value `0.0` only a negative top score triggers the canned
/// "no relevant information" reply.
const RELEVANCE_THRESHOLD: f32 = 0.0;
1398
1399pub fn run_model_inference(
1400 requested_model: &str,
1401 question: &str,
1402 fallback_context: &str,
1403 hits: &[SearchHit],
1404 context_override: Option<usize>,
1405 api_key: Option<&str>,
1406 system_prompt_override: Option<&str>,
1407) -> Result<ModelInference, ModelRunError> {
1408 let Some(model_kind) = ModelKind::parse(requested_model) else {
1409 return Err(ModelRunError::UnsupportedModel(requested_model.to_string()));
1410 };
1411
1412 let top_score = hits.first().and_then(|h| h.score).unwrap_or(0.0);
1414 if hits.is_empty() || top_score < RELEVANCE_THRESHOLD {
1415 let mut topics: Vec<String> = hits
1417 .iter()
1418 .take(5)
1419 .filter_map(|h| h.title.clone())
1420 .collect();
1421 topics.dedup();
1422
1423 let suggestions = if topics.is_empty() {
1424 "Try asking about the topics in your memory file.".to_string()
1425 } else {
1426 format!(
1427 "Your memory contains information about: {}. Try asking about these topics.",
1428 topics.join(", ")
1429 )
1430 };
1431
1432 let no_match_answer = format!(
1433 "No relevant information found for your question.\n\n{}\n\nRelevance score: {:.2} (threshold: {:.2})",
1434 suggestions, top_score, RELEVANCE_THRESHOLD
1435 );
1436
1437 return Ok(ModelInference {
1438 answer: ModelAnswer {
1439 requested: requested_model.to_string(),
1440 model: "none".to_string(),
1441 answer: no_match_answer,
1442 },
1443 context_body: String::new(),
1444 context_fragments: Vec::new(),
1445 usage: Some(TokenUsage {
1446 input_tokens: 0,
1447 output_tokens: 0,
1448 total_tokens: 0,
1449 cost_usd: 0.0,
1450 }),
1451 grounding: Some(GroundingResult {
1452 score: 0.0,
1453 sentence_count: 0,
1454 grounded_sentences: 0,
1455 sentence_scores: Vec::new(),
1456 has_warning: true,
1457 warning_reason: Some(
1458 "No relevant information found - retrieval score below threshold".to_string(),
1459 ),
1460 }),
1461 cached: false,
1462 });
1463 }
1464
1465 let mut budget = model_kind.context_budget();
1466 if let Some(override_chars) = context_override {
1467 budget = budget.apply_override(override_chars);
1468 }
1469
1470 let context_plan = context::assemble_context(hits, fallback_context, &budget);
1471
1472 if let Some(cached) = cache::check_cache(question, &context_plan.body, &model_kind.label()) {
1474 let grounding = Some(GroundingResult {
1475 score: cached.grounding_score,
1476 sentence_count: 0,
1477 grounded_sentences: 0,
1478 sentence_scores: Vec::new(),
1479 has_warning: cached.grounding_score < 0.4,
1480 warning_reason: if cached.grounding_score < 0.4 {
1481 Some("Cached answer - original grounding was low".to_string())
1482 } else {
1483 None
1484 },
1485 });
1486
1487 let context_fragments = context_plan
1488 .records
1489 .into_iter()
1490 .map(ModelContextFragment::from_record)
1491 .collect();
1492
1493 return Ok(ModelInference {
1494 answer: ModelAnswer {
1495 requested: requested_model.to_string(),
1496 model: cached.model.clone(),
1497 answer: cached.answer.clone(),
1498 },
1499 context_body: context_plan.body,
1500 context_fragments,
1501 usage: Some(TokenUsage {
1502 input_tokens: cached.input_tokens,
1503 output_tokens: cached.output_tokens,
1504 total_tokens: cached.input_tokens + cached.output_tokens,
1505 cost_usd: 0.0, }),
1507 grounding,
1508 cached: true,
1509 });
1510 }
1511
1512 let prompt = build_prompt_parts(
1513 question,
1514 &context_plan.body,
1515 &budget,
1516 model_kind.max_output_tokens(),
1517 );
1518
1519 let result = match &model_kind {
1520 ModelKind::TinyLlama => {
1521 #[cfg(feature = "llama-cpp")]
1522 {
1523 ProviderResult {
1524 answer: tinyllama::run(&prompt)?,
1525 usage: None, }
1527 }
1528 #[cfg(not(feature = "llama-cpp"))]
1529 {
1530 return Err(ModelRunError::UnsupportedModel(
1531 "tinyllama (llama-cpp feature not enabled)".to_string(),
1532 ));
1533 }
1534 }
1535 ModelKind::Ghost { pack_path } => {
1536 return Err(ModelRunError::UnsupportedModel(format!(
1537 "ghost model '{}' (ghost runtime not yet available)",
1538 pack_path.display()
1539 )));
1540 }
1541 ModelKind::Ollama { model } => ProviderResult {
1542 answer: ollama::run(model, &prompt)?,
1543 usage: None, },
1545 ModelKind::OpenAi { model } => {
1546 openai::run(model, &prompt, api_key, system_prompt_override)?
1547 }
1548 ModelKind::Nvidia { model } => ProviderResult {
1549 answer: nvidia::run(model, &prompt, api_key, system_prompt_override)?,
1550 usage: None, },
1552 ModelKind::Gemini { model } => {
1553 gemini::run(model, &prompt, api_key, system_prompt_override)?
1554 }
1555 ModelKind::Claude { model } => {
1556 claude::run(model, &prompt, api_key, system_prompt_override)?
1557 }
1558 ModelKind::Xai { model } => xai::run(model, &prompt, api_key, system_prompt_override)?,
1559 ModelKind::Groq { model } => groq::run(model, &prompt, api_key, system_prompt_override)?,
1560 ModelKind::Mistral { model } => {
1561 mistral::run(model, &prompt, api_key, system_prompt_override)?
1562 }
1563 };
1564
1565 let context::ContextAggregation {
1566 body: context_body,
1567 records,
1568 } = context_plan;
1569 let context_fragments = records
1570 .into_iter()
1571 .map(ModelContextFragment::from_record)
1572 .collect();
1573
1574 let grounding = Some(verify_grounding(&result.answer, &context_body));
1576
1577 let grounding_score = grounding.as_ref().map(|g| g.score).unwrap_or(0.5);
1579 let (input_tokens, output_tokens, cost_usd) = result
1580 .usage
1581 .as_ref()
1582 .map(|u| (u.input_tokens, u.output_tokens, u.cost_usd))
1583 .unwrap_or((0, 0, 0.0));
1584
1585 cache::store_in_cache(
1586 question,
1587 &context_body,
1588 &model_kind.label(),
1589 cache::CacheEntry {
1590 answer: result.answer.clone(),
1591 model: model_kind.label(),
1592 input_tokens,
1593 output_tokens,
1594 cost_usd,
1595 grounding_score,
1596 created_at: std::time::SystemTime::now(),
1597 },
1598 );
1599
1600 let processed_answer = postprocess_answer(&result.answer);
1602
1603 Ok(ModelInference {
1604 answer: ModelAnswer {
1605 requested: requested_model.to_string(),
1606 model: model_kind.label(),
1607 answer: processed_answer,
1608 },
1609 context_body,
1610 context_fragments,
1611 usage: result.usage,
1612 grounding,
1613 cached: false,
1614 })
1615}
1616
1617pub fn verify_grounding(answer: &str, context: &str) -> GroundingResult {
1621 use std::collections::HashSet;
1622
1623 if answer.is_empty() {
1624 return GroundingResult {
1625 score: 1.0, sentence_count: 0,
1627 grounded_sentences: 0,
1628 sentence_scores: Vec::new(),
1629 has_warning: false,
1630 warning_reason: None,
1631 };
1632 }
1633
1634 if context.is_empty() {
1635 return GroundingResult {
1636 score: 0.0,
1637 sentence_count: 1,
1638 grounded_sentences: 0,
1639 sentence_scores: vec![0.0],
1640 has_warning: true,
1641 warning_reason: Some("No context provided - answer may be hallucinated".to_string()),
1642 };
1643 }
1644
1645 let context_lower = context.to_lowercase();
1647 let context_words: HashSet<&str> = context_lower
1648 .split(|c: char| !c.is_alphanumeric())
1649 .filter(|w| w.len() > 2)
1650 .collect();
1651
1652 let sentences: Vec<&str> = answer
1654 .split(|c| c == '.' || c == '!' || c == '?')
1655 .map(|s| s.trim())
1656 .filter(|s| !s.is_empty() && s.len() > 10)
1657 .collect();
1658
1659 if sentences.is_empty() {
1660 return GroundingResult {
1661 score: 0.5, sentence_count: 0,
1663 grounded_sentences: 0,
1664 sentence_scores: Vec::new(),
1665 has_warning: false,
1666 warning_reason: None,
1667 };
1668 }
1669
1670 let mut sentence_scores = Vec::with_capacity(sentences.len());
1671 let mut grounded_count = 0;
1672
1673 for sentence in &sentences {
1674 let sentence_lower = sentence.to_lowercase();
1675 let sentence_words: HashSet<&str> = sentence_lower
1676 .split(|c: char| !c.is_alphanumeric())
1677 .filter(|w| w.len() > 2)
1678 .collect();
1679
1680 if sentence_words.is_empty() {
1681 sentence_scores.push(0.5);
1682 continue;
1683 }
1684
1685 let overlap: usize = sentence_words.intersection(&context_words).count();
1687 let score = (overlap as f32) / (sentence_words.len() as f32).max(1.0);
1688
1689 let phrase_bonus = if context_lower.contains(&sentence_lower) {
1691 0.3
1692 } else {
1693 let words: Vec<&str> = sentence_lower.split_whitespace().collect();
1695 if words.len() >= 3 {
1696 let phrase = words[..3.min(words.len())].join(" ");
1697 if context_lower.contains(&phrase) {
1698 0.15
1699 } else {
1700 0.0
1701 }
1702 } else {
1703 0.0
1704 }
1705 };
1706
1707 let final_score = (score + phrase_bonus).min(1.0);
1708 sentence_scores.push(final_score);
1709
1710 if final_score >= 0.3 {
1711 grounded_count += 1;
1712 }
1713 }
1714
1715 let overall_score = if sentence_scores.is_empty() {
1716 0.5
1717 } else {
1718 sentence_scores.iter().sum::<f32>() / sentence_scores.len() as f32
1719 };
1720
1721 let (has_warning, warning_reason) = if overall_score < 0.2 {
1723 (
1724 true,
1725 Some("Answer appears to be poorly grounded in context".to_string()),
1726 )
1727 } else if overall_score < 0.4 && grounded_count < sentences.len() / 2 {
1728 (
1729 true,
1730 Some("Some statements may not be supported by context".to_string()),
1731 )
1732 } else {
1733 (false, None)
1734 };
1735
1736 GroundingResult {
1737 score: overall_score,
1738 sentence_count: sentences.len(),
1739 grounded_sentences: grounded_count,
1740 sentence_scores,
1741 has_warning,
1742 warning_reason,
1743 }
1744}
1745
1746mod context {
1747 use super::{ModelContextBudget, clamp_to};
1748 use memvid_core::types::SearchHit;
1749
1750 const CONTEXT_HEADER: &str = "## Retrieval Context\n";
1751 const PRIMARY_HEADER: &str = "### Primary Hit\n";
1752 const CORRECTION_WARNING: &str = r#"
1753╔══════════════════════════════════════════════════════════════════╗
1754║ 🔴 USER CORRECTION - THIS IS THE AUTHORITATIVE ANSWER ║
1755║ Any contradicting information below is OUTDATED and WRONG. ║
1756║ YOU MUST USE THE ANSWER FROM THIS CORRECTION. ║
1757╚══════════════════════════════════════════════════════════════════╝
1758"#;
1759 const SUPPORT_HEADER: &str = "### Supporting Hits\n";
1760 const SUMMARY_HEADER: &str = "### Overflow Summaries\n";
1761 const SUMMARY_HIGHLIGHT_CHARS: usize = 240;
1762 #[allow(dead_code)]
1764 const MICRO_SUMMARY_CHARS: usize = 80;
1765
1766 #[derive(Debug, Clone)]
1767 pub(super) struct ContextAggregation {
1768 pub body: String,
1769 pub records: Vec<ContextRecord>,
1770 }
1771
1772 impl ContextAggregation {
1773 fn from_fallback(fallback: &str, limit: usize) -> Self {
1774 let body = if limit == 0 || fallback.is_empty() {
1775 String::new()
1776 } else if fallback.len() <= limit {
1777 fallback.to_string()
1778 } else {
1779 clamp_to(fallback, limit)
1780 };
1781 Self {
1782 body,
1783 records: Vec::new(),
1784 }
1785 }
1786 }
1787
1788 #[derive(Debug, Clone)]
1789 pub(super) struct ContextRecord {
1790 pub rank: usize,
1791 pub uri: String,
1792 pub title: Option<String>,
1793 pub score: Option<f32>,
1794 pub matches: usize,
1795 pub frame_id: u64,
1796 pub range: (usize, usize),
1797 pub chunk_range: Option<(usize, usize)>,
1798 pub text: String,
1799 pub mode: ContextMode,
1800 }
1801
1802 #[derive(Debug, Clone, Copy, Eq, PartialEq)]
1803 pub(super) enum ContextMode {
1804 Full,
1805 Summary,
1806 }
1807
1808 #[derive(Debug, Clone)]
1809 pub(super) struct ContextAssemblyPlan {
1810 primary: Option<ContextRecord>,
1811 supporting: Vec<ContextRecord>,
1812 summaries: Vec<ContextRecord>,
1813 }
1814
1815 pub(super) fn assemble_context(
1816 hits: &[SearchHit],
1817 fallback: &str,
1818 budget: &ModelContextBudget,
1819 ) -> ContextAggregation {
1820 if hits.is_empty() {
1821 return ContextAggregation::from_fallback(fallback, budget.context_chars());
1822 }
1823
1824 let plan = assemble_plan(hits, budget.context_chars());
1825 let mut body = String::new();
1826 let mut records = Vec::new();
1827
1828 body.push_str(CONTEXT_HEADER);
1829 let primary_is_correction = plan
1831 .primary
1832 .as_ref()
1833 .map(|p| p.uri.contains("mv2://correction/"))
1834 .unwrap_or(false);
1835 if let Some(primary) = plan.primary {
1836 body.push_str(PRIMARY_HEADER);
1837 if primary_is_correction {
1839 body.push_str(CORRECTION_WARNING);
1840 }
1841 body.push_str(&primary.text);
1842 body.push_str("\n\n");
1843 records.push(primary);
1844 }
1845
1846 if !plan.supporting.is_empty() {
1847 body.push_str(SUPPORT_HEADER);
1848 if primary_is_correction {
1849 body.push_str("⚠️ **WARNING: The following sources may contain OUTDATED information. Use the correction above.**\n\n");
1850 }
1851 for record in plan.supporting {
1852 if primary_is_correction && record.uri.contains("mv2://correction/") {
1855 continue;
1856 }
1857 body.push_str(&record.text);
1858 body.push_str("\n\n");
1859 records.push(record);
1860 }
1861 }
1862
1863 if !plan.summaries.is_empty() {
1864 body.push_str(SUMMARY_HEADER);
1865 for record in plan.summaries {
1866 body.push_str(&record.text);
1867 body.push_str("\n\n");
1868 records.push(record);
1869 }
1870 }
1871
1872 ContextAggregation { body, records }
1873 }
1874
1875 fn assemble_plan(hits: &[SearchHit], mut remaining_chars: usize) -> ContextAssemblyPlan {
1876 let mut records = Vec::new();
1877 for hit in hits.iter().take(32) {
1878 let full_record = build_record(hit, render_full(hit), ContextMode::Full);
1879 let summary_record = build_record(hit, render_summary(hit), ContextMode::Summary);
1880 let micro_record = build_record(hit, render_micro_summary(hit), ContextMode::Summary);
1881 records.push((full_record, summary_record, micro_record));
1882 }
1883
1884 let mut plan = ContextAssemblyPlan {
1885 primary: None,
1886 supporting: Vec::new(),
1887 summaries: Vec::new(),
1888 };
1889
1890 if let Some((primary_full, primary_summary, primary_micro)) = records.first() {
1892 if primary_full.text.len() <= remaining_chars {
1893 remaining_chars = remaining_chars.saturating_sub(primary_full.text.len());
1894 plan.primary = Some(primary_full.clone());
1895 } else if primary_summary.text.len() <= remaining_chars {
1896 remaining_chars = remaining_chars.saturating_sub(primary_summary.text.len());
1898 plan.primary = Some(primary_summary.clone());
1899 } else if primary_micro.text.len() <= remaining_chars {
1900 remaining_chars = remaining_chars.saturating_sub(primary_micro.text.len());
1902 plan.primary = Some(primary_micro.clone());
1903 }
1904 }
1905
1906 for (idx, (full, summary, micro)) in records.iter().enumerate() {
1908 if idx == 0 {
1909 continue;
1910 }
1911
1912 if full.text.len() <= remaining_chars {
1913 remaining_chars = remaining_chars.saturating_sub(full.text.len());
1914 plan.supporting.push(full.clone());
1915 } else if summary.text.len() <= remaining_chars {
1916 remaining_chars = remaining_chars.saturating_sub(summary.text.len());
1917 plan.summaries.push(summary.clone());
1918 } else if micro.text.len() <= remaining_chars {
1919 remaining_chars = remaining_chars.saturating_sub(micro.text.len());
1922 plan.summaries.push(micro.clone());
1923 }
1924 }
1926
1927 plan
1928 }
1929
1930 fn render_full(hit: &SearchHit) -> String {
1931 let content = hit
1932 .chunk_text
1933 .clone()
1934 .or_else(|| Some(hit.text.clone()))
1935 .unwrap_or_default();
1936
1937 let clean_content = clean_text_for_llm(&content);
1939
1940 let title = hit.title.clone().unwrap_or_default();
1942 let source_info = if title.is_empty() {
1943 format!("[Source #{}]", hit.rank)
1944 } else {
1945 format!("[Source #{}: {}]", hit.rank, title)
1946 };
1947
1948 let relevance = match hit.score {
1950 Some(s) if s > 0.8 => "⬤ High relevance",
1951 Some(s) if s > 0.5 => "◐ Medium relevance",
1952 _ => "",
1953 };
1954
1955 if relevance.is_empty() {
1956 format!("{}\n{}", source_info, clean_content)
1957 } else {
1958 format!("{} ({})\n{}", source_info, relevance, clean_content)
1959 }
1960 }
1961
1962 fn render_summary(hit: &SearchHit) -> String {
1963 let snippet = hit
1964 .chunk_text
1965 .clone()
1966 .or_else(|| Some(hit.text.clone()))
1967 .unwrap_or_default();
1968 let snippet = trim_highlight(&snippet, SUMMARY_HIGHLIGHT_CHARS);
1969 let clean_snippet = clean_text_for_llm(&snippet);
1970 format!("[Source #{}] {}", hit.rank, clean_snippet)
1971 }
1972
1973 fn render_micro_summary(hit: &SearchHit) -> String {
1976 let title = hit.title.clone().unwrap_or_else(|| "untitled".to_string());
1977 let title_truncated = clamp_to(&title, 40);
1978 format!("[#{}: {}] ...", hit.rank, title_truncated)
1980 }
1981
1982 fn clean_text_for_llm(text: &str) -> String {
1984 let mut result = text.to_string();
1985
1986 result = result
1988 .lines()
1989 .map(|line| line.trim())
1990 .filter(|line| !line.is_empty())
1991 .collect::<Vec<_>>()
1992 .join("\n");
1993
1994 result = result
1996 .replace("\u{2018}", "'") .replace("\u{2019}", "'") .replace("\u{201C}", "\"") .replace("\u{201D}", "\"") .replace("\u{2013}", "-") .replace("\u{2014}", "-"); result = result
2005 .chars()
2006 .filter(|c| !c.is_control() || *c == '\n' || *c == '\t')
2007 .collect();
2008
2009 result
2010 }
2011
2012 fn trim_highlight(text: &str, limit: usize) -> String {
2013 let clean = text.replace('\n', " ");
2014 clamp_to(&clean, limit)
2015 }
2016
2017 fn build_record(hit: &SearchHit, text: String, mode: ContextMode) -> ContextRecord {
2018 ContextRecord {
2019 rank: hit.rank,
2020 uri: hit.uri.clone(),
2021 title: hit.title.clone(),
2022 score: hit.score,
2023 matches: hit.matches,
2024 frame_id: hit.frame_id,
2025 range: hit.range,
2026 chunk_range: hit.chunk_range,
2027 text,
2028 mode,
2029 }
2030 }
2031}
2032
#[cfg(feature = "llama-cpp")]
mod tinyllama {
    //! Local TinyLlama inference through `llama_cpp`, using the first `*.gguf`
    //! file (by sorted path) found in `models/tinyllama`, relative to the
    //! current working directory.

    use super::{ModelRunError, PromptParts, TINYLLAMA_LABEL, ThinkingSpinner};
    use anyhow::anyhow;
    use llama_cpp::standard_sampler::StandardSampler;
    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
    use tokio::runtime::Builder;

    use std::path::{Path, PathBuf};

    /// Directory searched for model weights.
    const MODEL_DIR: &str = "models/tinyllama";
    /// Pattern reported as missing when no weights file is present.
    const GGUF_HINT: &str = "*.gguf";
    /// Extra tokens reserved on top of the requested output when trimming the prompt.
    const OUTPUT_SLACK_TOKENS: usize = 64;

    /// Runs one completion of `prompt` on TinyLlama and returns the trimmed text.
    ///
    /// The weights are loaded on every call. If the tokenized prompt would not
    /// leave room for `max_output_tokens() + OUTPUT_SLACK_TOKENS` in the context
    /// window, only its *tail* is kept, so the beginning of the prompt is what
    /// gets dropped. An empty generation yields a fixed placeholder message.
    ///
    /// # Errors
    /// `AssetsMissing` when no `.gguf` file is found; `Runtime` for load,
    /// session, tokenization, priming, completion-start or runtime-build failures.
    pub(super) fn run(prompt: &PromptParts) -> Result<String, ModelRunError> {
        let base_dir = Path::new(MODEL_DIR);
        let assets = RequiredAssets::new(base_dir);

        if let Some(missing) = assets.missing_paths() {
            return Err(ModelRunError::AssetsMissing {
                model: TINYLLAMA_LABEL.to_string(),
                missing,
            });
        }

        let gguf_path = assets.gguf_path.clone().ok_or_else(|| {
            ModelRunError::Runtime(anyhow!(
                "no GGUF model file found in {}",
                base_dir.display()
            ))
        })?;

        // SAFETY: `std::env::set_var` is sound only while no other thread reads or
        // writes the process environment. Nothing here enforces that.
        // NOTE(review): this assumes `run` is invoked while no other threads touch
        // the environment; confirm for multi-threaded callers.
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
        }

        let model =
            LlamaModel::load_from_file(&gguf_path, LlamaParams::default()).map_err(|err| {
                ModelRunError::Runtime(anyhow!(
                    "failed to load TinyLlama weights from {}: {err}",
                    gguf_path.display()
                ))
            })?;

        // Fill in defaults the library may leave at zero.
        let mut session_params = SessionParams::default();
        if session_params.n_ctx == 0 {
            session_params.n_ctx = 2048;
        }
        session_params.n_batch = session_params.n_ctx.min(512);
        if session_params.n_ubatch == 0 {
            session_params.n_ubatch = 512;
        }
        let max_tokens = session_params.n_ctx as usize;
        let mut session = model.create_session(session_params).map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to create TinyLlama session: {err}"))
        })?;

        let mut priming_tokens = model
            .tokenize_bytes(prompt.completion_prompt().as_bytes(), true, true)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to tokenize TinyLlama prompt: {err}"))
            })?;

        // Keep the prompt's tail so the output budget still fits in the window.
        let requested_tokens = prompt.max_output_tokens();
        if max_tokens > 0 {
            let reserved = requested_tokens + OUTPUT_SLACK_TOKENS;
            if priming_tokens.len() >= max_tokens.saturating_sub(reserved) {
                let target = max_tokens.saturating_sub(reserved).max(1);
                let tail_start = priming_tokens.len().saturating_sub(target);
                priming_tokens = priming_tokens.split_off(tail_start);
            }
        }

        session
            .advance_context_with_tokens(&priming_tokens)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to prime TinyLlama context: {err}"))
            })?;

        let handle = session
            .start_completing_with(StandardSampler::default(), requested_tokens)
            .map_err(|err| ModelRunError::Runtime(anyhow!("completion failed to start: {err}")))?;

        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to build tokio runtime: {err}"))
            })?;

        let mut spinner = ThinkingSpinner::start();
        let generated = runtime.block_on(async { handle.into_string_async().await });
        spinner.stop();

        let answer = generated.trim().to_string();

        if answer.is_empty() {
            Ok("No answer generated by TinyLlama.".to_string())
        } else {
            Ok(answer)
        }
    }

    /// Locates the model weights inside `base_dir`.
    struct RequiredAssets {
        gguf_path: Option<PathBuf>,
        base_dir: PathBuf,
    }

    impl RequiredAssets {
        fn new(base_dir: &Path) -> Self {
            let gguf_path = find_first_gguf(base_dir);
            Self {
                gguf_path,
                base_dir: base_dir.to_path_buf(),
            }
        }

        /// `None` when weights were found, otherwise the expected pattern path.
        fn missing_paths(&self) -> Option<Vec<PathBuf>> {
            if self.gguf_path.is_some() {
                None
            } else {
                Some(vec![self.base_dir.join(GGUF_HINT)])
            }
        }
    }

    /// First regular file with a `gguf` extension (case-sensitive), by sorted path.
    /// An unreadable directory yields `None`.
    fn find_first_gguf(base_dir: &Path) -> Option<PathBuf> {
        let mut entries: Vec<PathBuf> = std::fs::read_dir(base_dir)
            .ok()?
            .filter_map(|entry| entry.ok().map(|e| e.path()))
            .filter(|path| path.is_file() && path.extension().is_some_and(|ext| ext == "gguf"))
            .collect();
        entries.sort();
        entries.into_iter().next()
    }
}
2169
2170mod ollama {
2171 use super::{ModelRunError, PromptParts, ThinkingSpinner};
2172 use anyhow::anyhow;
2173 use reqwest::blocking::Client;
2174 use serde::Deserialize;
2175 use serde_json::json;
2176
2177 const ENDPOINT: &str = "http://127.0.0.1:11434/api/generate";
2178
2179 pub(super) fn run(model: &str, prompt: &PromptParts) -> Result<String, ModelRunError> {
2180 let client = Client::builder()
2181 .timeout(std::time::Duration::from_secs(60))
2182 .build()
2183 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2184
2185 let mut spinner = ThinkingSpinner::start();
2186 let response = client
2187 .post(ENDPOINT)
2188 .json(&json!({
2189 "model": model,
2190 "prompt": prompt.completion_prompt(),
2191 "stream": false
2192 }))
2193 .send()
2194 .map_err(|err| ModelRunError::Runtime(anyhow!("ollama request failed: {err}")))?
2195 .error_for_status()
2196 .map_err(|err| {
2197 ModelRunError::Runtime(anyhow!("ollama returned error status: {err}"))
2198 })?;
2199
2200 let body: GenerateResponse = response.json().map_err(|err| {
2201 ModelRunError::Runtime(anyhow!("failed to decode ollama response: {err}"))
2202 })?;
2203 spinner.stop();
2204
2205 let text = body.response.trim().to_string();
2206 if text.is_empty() {
2207 Ok("No answer returned by Ollama.".to_string())
2208 } else {
2209 Ok(text)
2210 }
2211 }
2212
2213 #[derive(Debug, Deserialize)]
2214 struct GenerateResponse {
2215 #[serde(default)]
2216 response: String,
2217 }
2218}
2219
2220mod openai {
2221 use super::{
2222 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2223 calculate_cost,
2224 };
2225 use anyhow::anyhow;
2226 use reqwest::blocking::Client;
2227 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2228 use serde::Deserialize;
2229 use serde_json::json;
2230
2231 const CHAT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
2232 const RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses";
2233
2234 pub(super) fn run(
2235 model: &str,
2236 prompt: &PromptParts,
2237 override_key: Option<&str>,
2238 system_prompt_override: Option<&str>,
2239 ) -> Result<ProviderResult, ModelRunError> {
2240 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2241 let key = override_key
2242 .map(|value| value.to_string())
2243 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
2244 .ok_or_else(|| {
2245 ModelRunError::Runtime(anyhow!(
2246 "OPENAI_API_KEY environment variable is required for OpenAI models"
2247 ))
2248 })?;
2249
2250 let mut headers = HeaderMap::new();
2251 headers.insert(
2252 AUTHORIZATION,
2253 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2254 ModelRunError::Runtime(anyhow!("invalid OPENAI_API_KEY header value: {err}"))
2255 })?,
2256 );
2257 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2258
2259 let client = Client::builder()
2260 .no_proxy()
2261 .timeout(std::time::Duration::from_secs(60))
2262 .default_headers(headers)
2263 .build()
2264 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2265
2266 let mut spinner = ThinkingSpinner::start();
2267 let (text, usage) = if requires_responses_api(model) {
2268 let combined_prompt = format!(
2269 "System instructions:\n{}\n\nUser query:\n{}",
2270 system_prompt,
2271 prompt.user_message()
2272 );
2273 let payload = json!({
2274 "model": model,
2275 "input": combined_prompt,
2276 "max_output_tokens": prompt.max_output_tokens() as u32,
2277 "reasoning": {
2278 "effort": "low"
2279 }
2280 });
2281
2282 let response = client
2283 .post(RESPONSES_ENDPOINT)
2284 .json(&payload)
2285 .send()
2286 .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
2287
2288 let status = response.status();
2289 if !status.is_success() {
2290 let body = response
2291 .text()
2292 .unwrap_or_else(|_| "<failed to read body>".to_string());
2293 return Err(ModelRunError::Runtime(anyhow!(
2294 "OpenAI returned error status {status}: {body}"
2295 )));
2296 }
2297
2298 let body: ResponsesResponse = response.json().map_err(|err| {
2299 ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
2300 })?;
2301
2302 let usage = body.usage.as_ref().map(|u| {
2303 let input = u.input_tokens.unwrap_or(0);
2304 let output = u.output_tokens.unwrap_or(0);
2305 TokenUsage {
2306 input_tokens: input,
2307 output_tokens: output,
2308 total_tokens: input + output,
2309 cost_usd: calculate_cost(model, input, output),
2310 }
2311 });
2312 (extract_responses_text(&body), usage)
2313 } else {
2314 let payload = json!({
2315 "model": model,
2316 "messages": [
2317 {"role": "system", "content": system_prompt},
2318 {"role": "user", "content": prompt.user_message()}
2319 ],
2320 "temperature": 0.2,
2321 "max_tokens": prompt.max_output_tokens() as u32
2322 });
2323
2324 let response = client
2325 .post(CHAT_ENDPOINT)
2326 .json(&payload)
2327 .send()
2328 .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
2329
2330 let status = response.status();
2331 if !status.is_success() {
2332 let body = response
2333 .text()
2334 .unwrap_or_else(|_| "<failed to read body>".to_string());
2335 return Err(ModelRunError::Runtime(anyhow!(
2336 "OpenAI returned error status {status}: {body}"
2337 )));
2338 }
2339
2340 let body: ChatResponse = response.json().map_err(|err| {
2341 ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
2342 })?;
2343
2344 let usage = body.usage.as_ref().map(|u| TokenUsage {
2345 input_tokens: u.prompt_tokens,
2346 output_tokens: u.completion_tokens,
2347 total_tokens: u.total_tokens,
2348 cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
2349 });
2350 (extract_chat_text(&body), usage)
2351 };
2352 spinner.stop();
2353 Ok(ProviderResult {
2354 answer: text,
2355 usage,
2356 })
2357 }
2358
    /// Subset of the Chat Completions response body that this module reads.
    #[derive(Debug, Deserialize)]
    struct ChatResponse {
        /// Candidate completions; `extract_chat_text` uses the first one that has content.
        choices: Vec<Choice>,
        /// Token accounting; a body without `usage` deserializes to `None`.
        #[serde(default)]
        usage: Option<ChatUsage>,
    }

    /// Token counts reported by the Chat Completions endpoint (all required when present).
    #[derive(Debug, Deserialize)]
    struct ChatUsage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }

    /// One entry of `choices`.
    #[derive(Debug, Deserialize)]
    struct Choice {
        message: ChatMessage,
    }

    /// The assistant message inside a choice.
    #[derive(Debug, Deserialize)]
    struct ChatMessage {
        /// Message text; missing or `null` content deserializes to `None`.
        #[serde(default)]
        content: Option<String>,
    }
2383
    /// Subset of the Responses API body that this module reads.
    #[derive(Debug, Deserialize)]
    struct ResponsesResponse {
        /// Output items; `extract_responses_text` scans their `content` segments.
        #[serde(default)]
        output: Vec<ResponseItem>,
        /// Pre-aggregated text, checked first by `extract_responses_text`.
        /// NOTE(review): declared as a list of strings; if the API ever sends this
        /// field as a single string, decoding the whole body fails — confirm the
        /// wire format.
        #[serde(default)]
        output_text: Vec<String>,
        /// Token accounting, if reported.
        #[serde(default)]
        usage: Option<ResponsesUsage>,
    }

    /// Token counts reported by the Responses API; either count may be absent.
    #[derive(Debug, Deserialize)]
    struct ResponsesUsage {
        #[serde(default)]
        input_tokens: Option<u32>,
        #[serde(default)]
        output_tokens: Option<u32>,
    }

    /// One element of `output`.
    #[derive(Debug, Deserialize)]
    struct ResponseItem {
        /// Content segments; items without `content` deserialize to an empty list.
        #[serde(default)]
        content: Vec<ResponseContent>,
    }

    /// A content segment, discriminated by its JSON `type` field.
    #[derive(Debug, Deserialize)]
    struct ResponseContent {
        /// Value of the JSON `type` key (e.g. `"output_text"` or `"text"` are read as text).
        #[serde(rename = "type")]
        kind: String,
        /// Segment text, if any.
        #[serde(default)]
        text: Option<String>,
    }
2415
2416 fn extract_chat_text(body: &ChatResponse) -> String {
2417 body.choices
2418 .iter()
2419 .find_map(|choice| choice.message.content.clone())
2420 .map(|value| value.trim().to_string())
2421 .unwrap_or_else(|| "No answer returned by OpenAI.".to_string())
2422 }
2423
2424 fn extract_responses_text(body: &ResponsesResponse) -> String {
2425 if !body.output_text.is_empty() {
2426 let text = body
2427 .output_text
2428 .iter()
2429 .find(|value| !value.trim().is_empty());
2430 if let Some(text) = text {
2431 return text.trim().to_string();
2432 }
2433 }
2434 for item in &body.output {
2435 for segment in &item.content {
2436 match segment.kind.as_str() {
2437 "output_text" | "text" => {
2438 if let Some(text) = &segment.text {
2439 let trimmed = text.trim();
2440 if !trimmed.is_empty() {
2441 return trimmed.to_string();
2442 }
2443 }
2444 }
2445 _ => {}
2446 }
2447 }
2448 }
2449 "No answer returned by OpenAI.".to_string()
2450 }
2451
2452 fn requires_responses_api(model: &str) -> bool {
2453 let lowered = model.to_ascii_lowercase();
2454 lowered.starts_with("gpt-5") || lowered.contains("gpt-4.1")
2455 }
2456}
2457
2458mod nvidia {
2459 use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
2460 use anyhow::anyhow;
2461 use reqwest::blocking::Client;
2462 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2463 use serde::Deserialize;
2464 use serde_json::json;
2465
2466 pub(super) fn run(
2467 model: &str,
2468 prompt: &PromptParts,
2469 override_key: Option<&str>,
2470 system_prompt_override: Option<&str>,
2471 ) -> Result<String, ModelRunError> {
2472 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2473 let key = override_key
2474 .map(|value| value.to_string())
2475 .or_else(|| std::env::var("NVIDIA_API_KEY").ok())
2476 .ok_or_else(|| {
2477 ModelRunError::Runtime(anyhow!(
2478 "NVIDIA_API_KEY environment variable is required for NVIDIA models"
2479 ))
2480 })?;
2481
2482 let model = model.trim();
2483 if model.is_empty() {
2484 return Err(ModelRunError::Runtime(anyhow!(
2485 "NVIDIA model name required. Use `nvidia:<model>` or set NVIDIA_LLM_MODEL."
2486 )));
2487 }
2488
2489 let base_url = std::env::var("NVIDIA_BASE_URL")
2490 .ok()
2491 .map(|value| value.trim().trim_end_matches('/').to_string())
2492 .filter(|value| !value.is_empty())
2493 .unwrap_or_else(|| "https://integrate.api.nvidia.com".to_string());
2494 let endpoint = format!("{base_url}/v1/chat/completions");
2495
2496 let mut headers = HeaderMap::new();
2497 headers.insert(
2498 AUTHORIZATION,
2499 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2500 ModelRunError::Runtime(anyhow!("invalid NVIDIA_API_KEY header value: {err}"))
2501 })?,
2502 );
2503 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2504
2505 let client = Client::builder()
2506 .timeout(std::time::Duration::from_secs(60))
2507 .default_headers(headers)
2508 .build()
2509 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2510
2511 let payload = json!({
2512 "model": model,
2513 "messages": [
2514 {"role": "system", "content": system_prompt},
2515 {"role": "user", "content": prompt.user_message()}
2516 ],
2517 "temperature": 0.2,
2518 "max_tokens": prompt.max_output_tokens() as u32
2519 });
2520
2521 let mut spinner = ThinkingSpinner::start();
2522 let response = client
2523 .post(endpoint)
2524 .json(&payload)
2525 .send()
2526 .map_err(|err| ModelRunError::Runtime(anyhow!("NVIDIA request failed: {err}")))?;
2527
2528 let status = response.status();
2529 if !status.is_success() {
2530 let body = response
2531 .text()
2532 .unwrap_or_else(|_| "<failed to read body>".to_string());
2533 spinner.stop();
2534 return Err(ModelRunError::Runtime(anyhow!(
2535 "NVIDIA returned error status {status}: {body}"
2536 )));
2537 }
2538
2539 let body: ChatResponse = response.json().map_err(|err| {
2540 ModelRunError::Runtime(anyhow!("failed to decode NVIDIA response: {err}"))
2541 })?;
2542 spinner.stop();
2543
2544 let text = body
2545 .choices
2546 .into_iter()
2547 .find_map(|choice| choice.message.content)
2548 .map(|value| value.trim().to_string())
2549 .unwrap_or_else(|| "No answer returned by NVIDIA.".to_string());
2550
2551 Ok(text)
2552 }
2553
2554 #[derive(Debug, Deserialize)]
2555 struct ChatResponse {
2556 choices: Vec<Choice>,
2557 }
2558
2559 #[derive(Debug, Deserialize)]
2560 struct Choice {
2561 message: ChatMessage,
2562 }
2563
2564 #[derive(Debug, Deserialize)]
2565 struct ChatMessage {
2566 #[serde(default)]
2567 content: Option<String>,
2568 }
2569}
2570
2571mod gemini {
2572 use super::{
2573 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2574 calculate_cost,
2575 };
2576 use anyhow::anyhow;
2577 use reqwest::blocking::Client;
2578 use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
2579 use serde::Deserialize;
2580 use serde_json::json;
2581
2582 pub(super) fn run(
2583 model: &str,
2584 prompt: &PromptParts,
2585 override_key: Option<&str>,
2586 system_prompt_override: Option<&str>,
2587 ) -> Result<ProviderResult, ModelRunError> {
2588 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2589 let key = override_key
2590 .map(|value| value.to_string())
2591 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
2592 .ok_or_else(|| {
2593 ModelRunError::Runtime(anyhow!(
2594 "GEMINI_API_KEY environment variable is required for Gemini models"
2595 ))
2596 })?;
2597
2598 let url = format!(
2599 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
2600 model
2601 );
2602
2603 let mut headers = HeaderMap::new();
2604 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2605 headers.insert(
2606 HeaderName::from_static("x-goog-api-key"),
2607 HeaderValue::from_str(&key).map_err(|err| {
2608 ModelRunError::Runtime(anyhow!("invalid GEMINI_API_KEY header value: {err}"))
2609 })?,
2610 );
2611
2612 let client = Client::builder()
2613 .timeout(std::time::Duration::from_secs(60))
2614 .default_headers(headers)
2615 .build()
2616 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2617
2618 let payload = json!({
2619 "contents": [{
2620 "parts": [
2621 { "text": system_prompt },
2622 { "text": prompt.user_message() }
2623 ]
2624 }],
2625 "generationConfig": {
2626 "temperature": 0.2,
2627 "maxOutputTokens": prompt.max_output_tokens() as u32,
2628 "topK": 40,
2629 "topP": 0.95
2630 }
2631 });
2632
2633 let mut spinner = ThinkingSpinner::start();
2634 let response = client
2635 .post(url)
2636 .json(&payload)
2637 .send()
2638 .map_err(|err| ModelRunError::Runtime(anyhow!("Gemini request failed: {err}")))?
2639 .error_for_status()
2640 .map_err(|err| {
2641 ModelRunError::Runtime(anyhow!("Gemini returned error status: {err}"))
2642 })?;
2643
2644 let body: GenerateResponse = response.json().map_err(|err| {
2645 ModelRunError::Runtime(anyhow!("failed to decode Gemini response: {err}"))
2646 })?;
2647 spinner.stop();
2648
2649 let text = body
2650 .candidates
2651 .iter()
2652 .flat_map(|candidate| candidate.content.parts.iter())
2653 .find_map(|part| part.text.clone())
2654 .map(|value| value.trim().to_string())
2655 .unwrap_or_else(|| "No answer returned by Gemini.".to_string());
2656
2657 let usage = body.usage_metadata.as_ref().map(|u| {
2658 let input = u.prompt_token_count.unwrap_or(0);
2659 let output = u.candidates_token_count.unwrap_or(0);
2660 TokenUsage {
2661 input_tokens: input,
2662 output_tokens: output,
2663 total_tokens: input + output,
2664 cost_usd: calculate_cost(model, input, output),
2665 }
2666 });
2667
2668 Ok(ProviderResult {
2669 answer: text,
2670 usage,
2671 })
2672 }
2673
2674 #[derive(Debug, Deserialize)]
2675 struct GenerateResponse {
2676 candidates: Vec<Candidate>,
2677 #[serde(default, rename = "usageMetadata")]
2678 usage_metadata: Option<GeminiUsage>,
2679 }
2680
2681 #[derive(Debug, Deserialize)]
2682 struct GeminiUsage {
2683 #[serde(default, rename = "promptTokenCount")]
2684 prompt_token_count: Option<u32>,
2685 #[serde(default, rename = "candidatesTokenCount")]
2686 candidates_token_count: Option<u32>,
2687 }
2688
2689 #[derive(Debug, Deserialize)]
2690 struct Candidate {
2691 content: CandidateContent,
2692 }
2693
2694 #[derive(Debug, Deserialize)]
2695 struct CandidateContent {
2696 parts: Vec<CandidatePart>,
2697 }
2698
2699 #[derive(Debug, Deserialize)]
2700 struct CandidatePart {
2701 #[serde(default)]
2702 text: Option<String>,
2703 }
2704}
2705
2706mod claude {
2707 use super::{
2708 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2709 calculate_cost,
2710 };
2711 use anyhow::anyhow;
2712 use reqwest::blocking::Client;
2713 use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
2714 use serde::Deserialize;
2715 use serde_json::json;
2716
2717 const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
2718 const API_VERSION: &str = "2023-06-01";
2719
2720 pub(super) fn run(
2721 model: &str,
2722 prompt: &PromptParts,
2723 override_key: Option<&str>,
2724 system_prompt_override: Option<&str>,
2725 ) -> Result<ProviderResult, ModelRunError> {
2726 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2727 let key = override_key
2728 .map(|value| value.to_string())
2729 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
2730 .or_else(|| std::env::var("CLAUDE_API_KEY").ok())
2731 .ok_or_else(|| {
2732 ModelRunError::Runtime(anyhow!(
2733 "ANTHROPIC_API_KEY environment variable is required for Claude models"
2734 ))
2735 })?;
2736
2737 let mut headers = HeaderMap::new();
2738 headers.insert(
2739 HeaderName::from_static("x-api-key"),
2740 HeaderValue::from_str(&key).map_err(|err| {
2741 ModelRunError::Runtime(anyhow!("invalid ANTHROPIC_API_KEY header value: {err}"))
2742 })?,
2743 );
2744 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2745 headers.insert(
2746 HeaderName::from_static("anthropic-version"),
2747 HeaderValue::from_static(API_VERSION),
2748 );
2749
2750 let client = Client::builder()
2751 .timeout(std::time::Duration::from_secs(60))
2752 .default_headers(headers)
2753 .build()
2754 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2755
2756 let payload = json!({
2757 "model": model,
2758 "max_tokens": prompt.max_output_tokens() as u32,
2759 "temperature": 0.2,
2760 "system": system_prompt,
2761 "messages": [{
2762 "role": "user",
2763 "content": [{"type": "text", "text": prompt.user_message()}]
2764 }]
2765 });
2766
2767 let mut spinner = ThinkingSpinner::start();
2768 let response = client
2769 .post(ENDPOINT)
2770 .json(&payload)
2771 .send()
2772 .map_err(|err| ModelRunError::Runtime(anyhow!("Claude request failed: {err}")))?
2773 .error_for_status()
2774 .map_err(|err| {
2775 ModelRunError::Runtime(anyhow!("Claude returned error status: {err}"))
2776 })?;
2777
2778 let body: ClaudeResponse = response.json().map_err(|err| {
2779 ModelRunError::Runtime(anyhow!("failed to decode Claude response: {err}"))
2780 })?;
2781 spinner.stop();
2782
2783 let text = body
2784 .content
2785 .iter()
2786 .find_map(|part| match part {
2787 ContentBlock::Text { text } if !text.trim().is_empty() => {
2788 Some(text.trim().to_string())
2789 }
2790 _ => None,
2791 })
2792 .unwrap_or_else(|| "No answer returned by Claude.".to_string());
2793
2794 let usage = body.usage.as_ref().map(|u| TokenUsage {
2795 input_tokens: u.input_tokens,
2796 output_tokens: u.output_tokens,
2797 total_tokens: u.input_tokens + u.output_tokens,
2798 cost_usd: calculate_cost(model, u.input_tokens, u.output_tokens),
2799 });
2800
2801 Ok(ProviderResult {
2802 answer: text,
2803 usage,
2804 })
2805 }
2806
2807 #[derive(Debug, Deserialize)]
2808 struct ClaudeResponse {
2809 #[serde(default)]
2810 content: Vec<ContentBlock>,
2811 #[serde(default)]
2812 usage: Option<ClaudeUsage>,
2813 }
2814
2815 #[derive(Debug, Deserialize)]
2816 struct ClaudeUsage {
2817 input_tokens: u32,
2818 output_tokens: u32,
2819 }
2820
2821 #[derive(Debug, Deserialize)]
2822 #[serde(tag = "type", rename_all = "lowercase")]
2823 enum ContentBlock {
2824 Text {
2825 text: String,
2826 },
2827 #[serde(other)]
2828 Other,
2829 }
2830}
2831
2832mod xai {
2833 use super::{
2834 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2835 calculate_cost,
2836 };
2837 use anyhow::anyhow;
2838 use reqwest::blocking::Client;
2839 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2840 use serde::Deserialize;
2841 use serde_json::json;
2842
2843 const ENDPOINT: &str = "https://api.x.ai/v1/chat/completions";
2844
2845 pub(super) fn run(
2846 model: &str,
2847 prompt: &PromptParts,
2848 override_key: Option<&str>,
2849 system_prompt_override: Option<&str>,
2850 ) -> Result<ProviderResult, ModelRunError> {
2851 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2852 let key = override_key
2853 .map(|value| value.to_string())
2854 .or_else(|| std::env::var("XAI_API_KEY").ok())
2855 .or_else(|| std::env::var("GROK_API_KEY").ok())
2856 .ok_or_else(|| {
2857 ModelRunError::Runtime(anyhow!(
2858 "XAI_API_KEY environment variable is required for Grok models"
2859 ))
2860 })?;
2861
2862 let mut headers = HeaderMap::new();
2863 headers.insert(
2864 AUTHORIZATION,
2865 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2866 ModelRunError::Runtime(anyhow!("invalid XAI_API_KEY header value: {err}"))
2867 })?,
2868 );
2869 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2870
2871 let client = Client::builder()
2872 .timeout(std::time::Duration::from_secs(120))
2873 .default_headers(headers)
2874 .build()
2875 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2876
2877 let payload = json!({
2878 "model": model,
2879 "max_tokens": prompt.max_output_tokens() as u32,
2880 "temperature": 0.2,
2881 "messages": [
2882 {"role": "system", "content": system_prompt},
2883 {"role": "user", "content": prompt.user_message()}
2884 ]
2885 });
2886
2887 let mut spinner = ThinkingSpinner::start();
2888 let response = client
2889 .post(ENDPOINT)
2890 .json(&payload)
2891 .send()
2892 .map_err(|err| ModelRunError::Runtime(anyhow!("xAI request failed: {err}")))?
2893 .error_for_status()
2894 .map_err(|err| ModelRunError::Runtime(anyhow!("xAI returned error status: {err}")))?;
2895
2896 let body: XaiResponse = response.json().map_err(|err| {
2897 ModelRunError::Runtime(anyhow!("failed to decode xAI response: {err}"))
2898 })?;
2899 spinner.stop();
2900
2901 let text = body
2902 .choices
2903 .first()
2904 .and_then(|c| c.message.content.as_ref())
2905 .map(|s| s.trim().to_string())
2906 .unwrap_or_else(|| "No answer returned by Grok.".to_string());
2907
2908 let usage = body.usage.as_ref().map(|u| TokenUsage {
2909 input_tokens: u.prompt_tokens,
2910 output_tokens: u.completion_tokens,
2911 total_tokens: u
2912 .total_tokens
2913 .unwrap_or(u.prompt_tokens + u.completion_tokens),
2914 cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
2915 });
2916
2917 Ok(ProviderResult {
2918 answer: text,
2919 usage,
2920 })
2921 }
2922
2923 #[derive(Debug, Deserialize)]
2924 struct XaiResponse {
2925 #[serde(default)]
2926 choices: Vec<XaiChoice>,
2927 #[serde(default)]
2928 usage: Option<XaiUsage>,
2929 }
2930
2931 #[derive(Debug, Deserialize)]
2932 struct XaiChoice {
2933 message: XaiMessage,
2934 }
2935
2936 #[derive(Debug, Deserialize)]
2937 struct XaiMessage {
2938 content: Option<String>,
2939 }
2940
2941 #[derive(Debug, Deserialize)]
2942 struct XaiUsage {
2943 prompt_tokens: u32,
2944 completion_tokens: u32,
2945 total_tokens: Option<u32>,
2946 }
2947}
2948
2949mod groq {
2950 use super::{
2951 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2952 calculate_cost,
2953 };
2954 use anyhow::anyhow;
2955 use reqwest::blocking::Client;
2956 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2957 use serde::Deserialize;
2958 use serde_json::json;
2959
2960 const ENDPOINT: &str = "https://api.groq.com/openai/v1/chat/completions";
2961
2962 pub(super) fn run(
2963 model: &str,
2964 prompt: &PromptParts,
2965 override_key: Option<&str>,
2966 system_prompt_override: Option<&str>,
2967 ) -> Result<ProviderResult, ModelRunError> {
2968 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2969 let key = override_key
2970 .map(|value| value.to_string())
2971 .or_else(|| std::env::var("GROQ_API_KEY").ok())
2972 .ok_or_else(|| {
2973 ModelRunError::Runtime(anyhow!(
2974 "GROQ_API_KEY environment variable is required for Groq models"
2975 ))
2976 })?;
2977
2978 let mut headers = HeaderMap::new();
2979 headers.insert(
2980 AUTHORIZATION,
2981 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2982 ModelRunError::Runtime(anyhow!("invalid GROQ_API_KEY header value: {err}"))
2983 })?,
2984 );
2985 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2986
2987 let client = Client::builder()
2988 .timeout(std::time::Duration::from_secs(60))
2989 .default_headers(headers)
2990 .build()
2991 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2992
2993 let payload = json!({
2994 "model": model,
2995 "max_tokens": prompt.max_output_tokens() as u32,
2996 "temperature": 0.2,
2997 "messages": [
2998 {"role": "system", "content": system_prompt},
2999 {"role": "user", "content": prompt.user_message()}
3000 ]
3001 });
3002
3003 let mut spinner = ThinkingSpinner::start();
3004 let response = client
3005 .post(ENDPOINT)
3006 .json(&payload)
3007 .send()
3008 .map_err(|err| ModelRunError::Runtime(anyhow!("Groq request failed: {err}")))?
3009 .error_for_status()
3010 .map_err(|err| ModelRunError::Runtime(anyhow!("Groq returned error status: {err}")))?;
3011
3012 let body: GroqResponse = response.json().map_err(|err| {
3013 ModelRunError::Runtime(anyhow!("failed to decode Groq response: {err}"))
3014 })?;
3015 spinner.stop();
3016
3017 let text = body
3018 .choices
3019 .first()
3020 .and_then(|c| c.message.content.as_ref())
3021 .map(|s| s.trim().to_string())
3022 .unwrap_or_else(|| "No answer returned by Groq.".to_string());
3023
3024 let usage = body.usage.as_ref().map(|u| TokenUsage {
3025 input_tokens: u.prompt_tokens,
3026 output_tokens: u.completion_tokens,
3027 total_tokens: u
3028 .total_tokens
3029 .unwrap_or(u.prompt_tokens + u.completion_tokens),
3030 cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
3031 });
3032
3033 Ok(ProviderResult {
3034 answer: text,
3035 usage,
3036 })
3037 }
3038
3039 #[derive(Debug, Deserialize)]
3040 struct GroqResponse {
3041 #[serde(default)]
3042 choices: Vec<GroqChoice>,
3043 #[serde(default)]
3044 usage: Option<GroqUsage>,
3045 }
3046
3047 #[derive(Debug, Deserialize)]
3048 struct GroqChoice {
3049 message: GroqMessage,
3050 }
3051
3052 #[derive(Debug, Deserialize)]
3053 struct GroqMessage {
3054 content: Option<String>,
3055 }
3056
3057 #[derive(Debug, Deserialize)]
3058 struct GroqUsage {
3059 prompt_tokens: u32,
3060 completion_tokens: u32,
3061 total_tokens: Option<u32>,
3062 }
3063}
3064
3065mod mistral {
3066 use super::{
3067 ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
3068 calculate_cost,
3069 };
3070 use anyhow::anyhow;
3071 use reqwest::blocking::Client;
3072 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
3073 use serde::Deserialize;
3074 use serde_json::json;
3075
3076 const ENDPOINT: &str = "https://api.mistral.ai/v1/chat/completions";
3077
3078 pub(super) fn run(
3079 model: &str,
3080 prompt: &PromptParts,
3081 override_key: Option<&str>,
3082 system_prompt_override: Option<&str>,
3083 ) -> Result<ProviderResult, ModelRunError> {
3084 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
3085 let key = override_key
3086 .map(|value| value.to_string())
3087 .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
3088 .ok_or_else(|| {
3089 ModelRunError::Runtime(anyhow!(
3090 "MISTRAL_API_KEY environment variable is required for Mistral models"
3091 ))
3092 })?;
3093
3094 let mut headers = HeaderMap::new();
3095 headers.insert(
3096 AUTHORIZATION,
3097 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
3098 ModelRunError::Runtime(anyhow!("invalid MISTRAL_API_KEY header value: {err}"))
3099 })?,
3100 );
3101 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
3102
3103 let client = Client::builder()
3104 .timeout(std::time::Duration::from_secs(60))
3105 .default_headers(headers)
3106 .build()
3107 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
3108
3109 let payload = json!({
3110 "model": model,
3111 "max_tokens": prompt.max_output_tokens() as u32,
3112 "temperature": 0.2,
3113 "messages": [
3114 {"role": "system", "content": system_prompt},
3115 {"role": "user", "content": prompt.user_message()}
3116 ]
3117 });
3118
3119 let mut spinner = ThinkingSpinner::start();
3120 let response = client
3121 .post(ENDPOINT)
3122 .json(&payload)
3123 .send()
3124 .map_err(|err| ModelRunError::Runtime(anyhow!("Mistral request failed: {err}")))?
3125 .error_for_status()
3126 .map_err(|err| {
3127 ModelRunError::Runtime(anyhow!("Mistral returned error status: {err}"))
3128 })?;
3129
3130 let body: MistralResponse = response.json().map_err(|err| {
3131 ModelRunError::Runtime(anyhow!("failed to decode Mistral response: {err}"))
3132 })?;
3133 spinner.stop();
3134
3135 let text = body
3136 .choices
3137 .first()
3138 .and_then(|c| c.message.content.as_ref())
3139 .map(|s| s.trim().to_string())
3140 .unwrap_or_else(|| "No answer returned by Mistral.".to_string());
3141
3142 let usage = body.usage.as_ref().map(|u| TokenUsage {
3143 input_tokens: u.prompt_tokens,
3144 output_tokens: u.completion_tokens,
3145 total_tokens: u
3146 .total_tokens
3147 .unwrap_or(u.prompt_tokens + u.completion_tokens),
3148 cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
3149 });
3150
3151 Ok(ProviderResult {
3152 answer: text,
3153 usage,
3154 })
3155 }
3156
3157 #[derive(Debug, Deserialize)]
3158 struct MistralResponse {
3159 #[serde(default)]
3160 choices: Vec<MistralChoice>,
3161 #[serde(default)]
3162 usage: Option<MistralUsage>,
3163 }
3164
3165 #[derive(Debug, Deserialize)]
3166 struct MistralChoice {
3167 message: MistralMessage,
3168 }
3169
3170 #[derive(Debug, Deserialize)]
3171 struct MistralMessage {
3172 content: Option<String>,
3173 }
3174
3175 #[derive(Debug, Deserialize)]
3176 struct MistralUsage {
3177 prompt_tokens: u32,
3178 completion_tokens: u32,
3179 total_tokens: Option<u32>,
3180 }
3181}
3182
/// Default system prompt for `extract_entities`, used when the caller passes
/// no `system_prompt`.
///
/// It asks the model for a JSON object of the form `{"entities": [...]}`,
/// which is the first shape `parse_entity_response` tries to decode. The
/// entity types and confidence threshold listed here are requests to the
/// model only; nothing in this file enforces them.
pub const ENTITY_EXTRACTION_PROMPT: &str = r#"Extract named entities from the provided text. Return a JSON object with an "entities" array.

Each entity should have:
- "name": The entity name as it appears in the text
- "type": One of "PERSON", "ORG", "LOCATION", "DATE", "PRODUCT", "EVENT", or "OTHER"
- "confidence": A number between 0.0 and 1.0 indicating your confidence

Guidelines:
1. Only include entities you're confident about (confidence >= 0.7)
2. Preserve the original capitalization of entity names
3. For organizations, include full names (e.g., "S&P Global" not just "S&P")
4. For people, include full names when available
5. Deduplicate: if an entity appears multiple times, include it only once

Return format:
{"entities": [{"name": "...", "type": "...", "confidence": 0.9}, ...]}"#;
3204
/// One named entity returned by an LLM entity-extraction call.
///
/// Deserialized from the model's JSON reply by `parse_entity_response`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity name as reported by the model.
    pub name: String,
    /// Entity category, (de)serialized under the JSON key `"type"`. The default
    /// prompt asks for PERSON, ORG, LOCATION, DATE, PRODUCT, EVENT or OTHER, but
    /// the value is not validated on parse.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Model-reported confidence; the default prompt asks for 0.0–1.0, but the
    /// range is not validated on parse.
    pub confidence: f32,
}
3213
/// Result of `extract_entities`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EntityExtractionResponse {
    /// Entities parsed from the model's reply.
    pub entities: Vec<ExtractedEntity>,
    /// The model spec exactly as passed to `extract_entities`, including any
    /// `provider:` prefix.
    pub model: String,
    /// Size of the input text, as computed by `extract_entities`.
    pub text_chars: usize,
}
3221
3222pub fn extract_entities(
3246 model: &str,
3247 text: &str,
3248 system_prompt: Option<&str>,
3249 api_key: Option<&str>,
3250) -> Result<EntityExtractionResponse, ModelRunError> {
3251 let prompt = system_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
3252 let text_chars = text.len();
3253
3254 let (provider, model_name) = parse_model_spec(model);
3256
3257 let json_response = match provider.as_str() {
3258 "openai" => extract_entities_openai(&model_name, text, prompt, api_key)?,
3259 "claude" | "anthropic" => extract_entities_claude(&model_name, text, prompt, api_key)?,
3260 "gemini" | "google" => extract_entities_gemini(&model_name, text, prompt, api_key)?,
3261 _ => {
3262 return Err(ModelRunError::UnsupportedModel(format!(
3263 "Entity extraction not supported for provider '{}'. Use openai:, claude:, or gemini:",
3264 provider
3265 )));
3266 }
3267 };
3268
3269 let entities = parse_entity_response(&json_response)?;
3271
3272 Ok(EntityExtractionResponse {
3273 entities,
3274 model: model.to_string(),
3275 text_chars,
3276 })
3277}
3278
/// Splits a `provider:model` spec at the first `:` into `(provider, model)`.
///
/// The provider part is lowercased; everything after the first `:` (including
/// further colons) is the model name. A spec with no `:` maps to the
/// `"openai"` provider with the whole input as the model name.
fn parse_model_spec(model: &str) -> (String, String) {
    match model.split_once(':') {
        Some((prefix, rest)) => (prefix.to_lowercase(), String::from(rest)),
        None => (String::from("openai"), String::from(model)),
    }
}
3287
3288fn parse_entity_response(json_str: &str) -> Result<Vec<ExtractedEntity>, ModelRunError> {
3289 let trimmed = json_str.trim();
3291
3292 let clean_json = if trimmed.starts_with("```json") {
3294 trimmed
3295 .strip_prefix("```json")
3296 .and_then(|s| s.strip_suffix("```"))
3297 .unwrap_or(trimmed)
3298 .trim()
3299 } else if trimmed.starts_with("```") {
3300 trimmed
3301 .strip_prefix("```")
3302 .and_then(|s| s.strip_suffix("```"))
3303 .unwrap_or(trimmed)
3304 .trim()
3305 } else {
3306 trimmed
3307 };
3308
3309 #[derive(serde::Deserialize)]
3311 struct EntityResponse {
3312 entities: Vec<ExtractedEntity>,
3313 }
3314
3315 if let Ok(response) = serde_json::from_str::<EntityResponse>(clean_json) {
3316 return Ok(response.entities);
3317 }
3318
3319 if let Ok(entities) = serde_json::from_str::<Vec<ExtractedEntity>>(clean_json) {
3321 return Ok(entities);
3322 }
3323
3324 Err(ModelRunError::Runtime(anyhow::anyhow!(
3325 "Failed to parse entity extraction response as JSON: {}",
3326 &clean_json[..clean_json.len().min(200)]
3327 )))
3328}
3329
3330fn extract_entities_openai(
3331 model: &str,
3332 text: &str,
3333 system_prompt: &str,
3334 api_key: Option<&str>,
3335) -> Result<String, ModelRunError> {
3336 use serde_json::json;
3337
3338 let api_key = api_key
3339 .map(|s| s.to_string())
3340 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
3341 .ok_or_else(|| {
3342 ModelRunError::Runtime(anyhow::anyhow!(
3343 "OpenAI API key required. Set OPENAI_API_KEY or pass api_key parameter."
3344 ))
3345 })?;
3346
3347 let model_name = if model.is_empty() {
3348 "gpt-4o-mini"
3349 } else {
3350 model
3351 };
3352
3353 let client = reqwest::blocking::Client::builder()
3354 .no_proxy()
3355 .build()
3356 .map_err(|err| {
3357 ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3358 })?;
3359 let payload = json!({
3360 "model": model_name,
3361 "messages": [
3362 {"role": "system", "content": system_prompt},
3363 {"role": "user", "content": text}
3364 ],
3365 "response_format": {"type": "json_object"},
3366 "temperature": 0.1
3367 });
3368
3369 let response = client
3370 .post("https://api.openai.com/v1/chat/completions")
3371 .header("Authorization", format!("Bearer {}", api_key))
3372 .json(&payload)
3373 .send()
3374 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI request failed: {}", e)))?
3375 .error_for_status()
3376 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI returned error: {}", e)))?;
3377
3378 #[derive(serde::Deserialize)]
3379 struct OpenAIResponse {
3380 choices: Vec<OpenAIChoice>,
3381 }
3382 #[derive(serde::Deserialize)]
3383 struct OpenAIChoice {
3384 message: OpenAIMessage,
3385 }
3386 #[derive(serde::Deserialize)]
3387 struct OpenAIMessage {
3388 content: String,
3389 }
3390
3391 let body: OpenAIResponse = response.json().map_err(|e| {
3392 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse OpenAI response: {}", e))
3393 })?;
3394
3395 body.choices
3396 .into_iter()
3397 .next()
3398 .map(|c| c.message.content)
3399 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No response from OpenAI")))
3400}
3401
3402fn extract_entities_claude(
3403 model: &str,
3404 text: &str,
3405 system_prompt: &str,
3406 api_key: Option<&str>,
3407) -> Result<String, ModelRunError> {
3408 use serde_json::json;
3409
3410 let api_key = api_key
3411 .map(|s| s.to_string())
3412 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
3413 .ok_or_else(|| {
3414 ModelRunError::Runtime(anyhow::anyhow!(
3415 "Anthropic API key required. Set ANTHROPIC_API_KEY or pass api_key parameter."
3416 ))
3417 })?;
3418
3419 let model_name = if model.is_empty() {
3420 "claude-3-5-sonnet-20241022"
3421 } else {
3422 model
3423 };
3424
3425 let client = reqwest::blocking::Client::builder()
3426 .no_proxy()
3427 .build()
3428 .map_err(|err| {
3429 ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3430 })?;
3431 let payload = json!({
3432 "model": model_name,
3433 "max_tokens": 4096,
3434 "system": format!("{}\n\nRespond with valid JSON only.", system_prompt),
3435 "messages": [
3436 {"role": "user", "content": text}
3437 ]
3438 });
3439
3440 let response = client
3441 .post("https://api.anthropic.com/v1/messages")
3442 .header("x-api-key", &api_key)
3443 .header("anthropic-version", "2023-06-01")
3444 .header("content-type", "application/json")
3445 .json(&payload)
3446 .send()
3447 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude request failed: {}", e)))?
3448 .error_for_status()
3449 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude returned error: {}", e)))?;
3450
3451 #[derive(serde::Deserialize)]
3452 struct ClaudeResponse {
3453 content: Vec<ClaudeContent>,
3454 }
3455 #[derive(serde::Deserialize)]
3456 struct ClaudeContent {
3457 text: Option<String>,
3458 }
3459
3460 let body: ClaudeResponse = response.json().map_err(|e| {
3461 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Claude response: {}", e))
3462 })?;
3463
3464 body.content
3465 .into_iter()
3466 .find_map(|c| c.text)
3467 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Claude")))
3468}
3469
3470fn extract_entities_gemini(
3471 model: &str,
3472 text: &str,
3473 system_prompt: &str,
3474 api_key: Option<&str>,
3475) -> Result<String, ModelRunError> {
3476 use serde_json::json;
3477
3478 let api_key = api_key
3479 .map(|s| s.to_string())
3480 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
3481 .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
3482 .ok_or_else(|| {
3483 ModelRunError::Runtime(anyhow::anyhow!(
3484 "Gemini API key required. Set GEMINI_API_KEY or pass api_key parameter."
3485 ))
3486 })?;
3487
3488 let model_name = if model.is_empty() {
3489 "gemini-2.0-flash"
3490 } else {
3491 model
3492 };
3493 let url = format!(
3494 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
3495 model_name, api_key
3496 );
3497
3498 let client = reqwest::blocking::Client::builder()
3499 .no_proxy()
3500 .build()
3501 .map_err(|err| {
3502 ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3503 })?;
3504 let payload = json!({
3505 "contents": [{
3506 "parts": [{"text": format!("{}\n\nText to analyze:\n{}", system_prompt, text)}]
3507 }],
3508 "generationConfig": {
3509 "temperature": 0.1,
3510 "responseMimeType": "application/json"
3511 }
3512 });
3513
3514 let response = client
3515 .post(&url)
3516 .json(&payload)
3517 .send()
3518 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini request failed: {}", e)))?
3519 .error_for_status()
3520 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini returned error: {}", e)))?;
3521
3522 #[derive(serde::Deserialize)]
3523 struct GeminiResponse {
3524 candidates: Vec<GeminiCandidate>,
3525 }
3526 #[derive(serde::Deserialize)]
3527 struct GeminiCandidate {
3528 content: GeminiContent,
3529 }
3530 #[derive(serde::Deserialize)]
3531 struct GeminiContent {
3532 parts: Vec<GeminiPart>,
3533 }
3534 #[derive(serde::Deserialize)]
3535 struct GeminiPart {
3536 text: Option<String>,
3537 }
3538
3539 let body: GeminiResponse = response.json().map_err(|e| {
3540 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Gemini response: {}", e))
3541 })?;
3542
3543 body.candidates
3544 .into_iter()
3545 .next()
3546 .and_then(|c| c.content.parts.into_iter().find_map(|p| p.text))
3547 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Gemini")))
3548}
3549
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests in this module that read or mutate process environment variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering the guard if a previous holder panicked.
    ///
    /// Without the recovery, one failing env test would poison the mutex and turn every later
    /// env test into a spurious failure.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn normalize_models() {
        // Take the lock before calling any normalizer, since defaults may depend on env
        // overrides. `normalize_nvidia_model` presumably reads NVIDIA_LLM_MODEL /
        // NVIDIA_MODEL, which is why they are cleared below.
        let _lock = env_lock();
        assert_eq!(normalize_openai_model(None), "gpt-4o-mini");
        assert_eq!(
            normalize_nvidia_model(Some("meta/llama3-8b-instruct".to_string())),
            "meta/llama3-8b-instruct"
        );
        // SAFETY: mutating the environment is unsound if another thread reads it concurrently.
        // ENV_LOCK serializes this against other tests in this module that take the lock.
        // NOTE(review): tests elsewhere that read env without this lock are not covered.
        unsafe {
            std::env::remove_var("NVIDIA_LLM_MODEL");
            std::env::remove_var("NVIDIA_MODEL");
        }
        assert_eq!(normalize_nvidia_model(None), "");
        assert_eq!(normalize_gemini_model(None), "gemini-2.5-flash");
        assert_eq!(normalize_claude_model(None), "claude-sonnet-4-5");
        assert_eq!(normalize_xai_model(None), "grok-4-fast");
        assert_eq!(normalize_groq_model(None), "llama-3.3-70b-versatile");
        assert_eq!(normalize_mistral_model(None), "mistral-large-latest");
    }

    #[test]
    fn parse_entity_json() {
        let json = r#"{"entities": [{"name": "John", "type": "PERSON", "confidence": 0.95}]}"#;
        let entities = parse_entity_response(json).unwrap();
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "John");
    }

    #[test]
    fn parse_entity_json_with_markdown() {
        // Models often wrap JSON in a fenced code block; the parser must strip the fence.
        let json = r#"```json
{"entities": [{"name": "Microsoft", "type": "ORG", "confidence": 0.99}]}
```"#;
        let entities = parse_entity_response(json).unwrap();
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Microsoft");
    }

    #[test]
    fn parse_model_spec_test() {
        let (provider, model) = parse_model_spec("openai:gpt-4o");
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");

        // A bare model name without a provider prefix resolves to OpenAI.
        let (provider, model) = parse_model_spec("gpt-4o-mini");
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o-mini");
    }

    #[test]
    fn modelkind_parses_ghost_spec() {
        let kind = ModelKind::parse("ghost:/tmp/model.ghostpack").unwrap();
        match kind {
            ModelKind::Ghost { pack_path } => {
                assert_eq!(pack_path, PathBuf::from("/tmp/model.ghostpack"));
            }
            other => panic!("expected ghost, got {other:?}"),
        }
    }

    #[test]
    fn run_model_inference_ghost_returns_unsupported() {
        let hit = SearchHit {
            rank: 0,
            frame_id: 0,
            uri: "mv2://test".to_string(),
            title: Some("Test".to_string()),
            range: (0, 3),
            text: "ctx".to_string(),
            matches: 1,
            chunk_range: None,
            chunk_text: None,
            score: Some(1.0),
            metadata: None,
        };

        let err = run_model_inference(
            "ghost:/tmp/fake.ghostpack",
            "hello?",
            "",
            &[hit],
            None,
            None,
            None,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ghost") || msg.contains("Ghost"),
            "error should mention ghost: {msg}"
        );
    }

    #[test]
    fn normalize_question_adds_question_mark() {
        // Only the trailing '?' is pinned here; other normalization of this input may vary.
        let result = normalize_question("how much is the LP rate");
        assert!(result.ends_with('?'), "should end with ?");

        assert_eq!(
            normalize_question("what is the total revenue"),
            "what is the total revenue?"
        );
        assert_eq!(
            normalize_question("where does John live"),
            "where does John live?"
        );
        assert_eq!(normalize_question("is this correct"), "is this correct?");
        assert_eq!(normalize_question("can you help me"), "can you help me?");
    }

    #[test]
    fn normalize_question_preserves_existing_punctuation() {
        // Input that already ends in terminal punctuation is returned unchanged.
        assert_eq!(normalize_question("how much is X?"), "how much is X?");
        assert_eq!(
            normalize_question("Tell me about the project."),
            "Tell me about the project."
        );
        assert_eq!(normalize_question("Do it now!"), "Do it now!");
    }

    #[test]
    fn normalize_question_ignores_non_questions() {
        // Keyword-style queries are not questions and must not gain a '?'.
        assert_eq!(
            normalize_question("revenue for Q1 2024"),
            "revenue for Q1 2024"
        );
        assert_eq!(normalize_question("total sales"), "total sales");
        // "how" inside a longer word is not a question word.
        assert_eq!(
            normalize_question("howitzer specifications"),
            "howitzer specifications"
        );
    }

    #[test]
    fn normalize_question_handles_edge_cases() {
        assert_eq!(normalize_question(""), "");
        assert_eq!(normalize_question("   "), "");
        // Surrounding whitespace must not prevent question detection.
        let result = normalize_question("  how much  ");
        assert!(result.ends_with('?'), "should end with ?");
    }

    #[test]
    fn fix_typos_corrects_common_errors() {
        assert!(fix_common_typos("teh quick brown fox").contains("the"));
        assert!(fix_common_typos("waht is this").contains("what"));
        assert!(fix_common_typos("totla revenue").contains("total"));
    }

    #[test]
    fn expand_abbreviations_works() {
        let result = expand_abbreviations("what is the irr");
        // NOTE(review): the `|| contains("irr")` branch makes this pass even when nothing is
        // expanded. Tighten it once the intended output (expanded form replacing vs.
        // accompanying the abbreviation) is confirmed.
        assert!(result.contains("internal rate of return") || result.contains("irr"));
    }

    #[test]
    fn question_type_detection() {
        assert_eq!(
            detect_question_type("how much is X?"),
            QuestionType::Numeric
        );
        assert_eq!(
            detect_question_type("is this correct?"),
            QuestionType::YesNo
        );
        assert_eq!(detect_question_type("list all items"), QuestionType::List);
        assert_eq!(
            detect_question_type("when was it created?"),
            QuestionType::Temporal
        );
        assert_eq!(
            detect_question_type("why did this happen?"),
            QuestionType::Explanation
        );
        assert_eq!(
            detect_question_type("what is the name?"),
            QuestionType::Factual
        );
    }

    #[test]
    fn postprocess_removes_artifacts() {
        let answer = "Based on the provided context, the value is 42.";
        let processed = postprocess_answer(answer);
        assert!(!processed.starts_with("Based on"));
        assert!(processed.contains("42"));
    }

    #[test]
    fn postprocess_capitalizes() {
        let answer = "the answer is yes";
        let processed = postprocess_answer(answer);
        assert!(processed.starts_with('T'), "should start with capital T");
    }

    #[test]
    fn postprocess_normalizes_whitespace() {
        // The input deliberately contains runs of spaces. The check is for *double* spaces:
        // single spaces between words are expected to survive.
        let answer = "too  many   spaces    here";
        let processed = postprocess_answer(answer);
        assert!(!processed.contains("  "), "should not have double spaces");
    }
}