Skip to main content

memvid_ask_model/
lib.rs

1use std::fmt;
2use std::io::{IsTerminal, Write, stderr};
3use std::path::PathBuf;
4use std::sync::{
5    Arc,
6    atomic::{AtomicBool, Ordering},
7};
8use std::thread;
9use std::time::Duration;
10
11use memvid_core::types::SearchHit;
12
/// Answer text together with the model identifiers it is associated with.
#[derive(Debug, Clone)]
pub struct ModelAnswer {
    /// Model identifier as requested (presumably what the caller asked for; TODO confirm).
    pub requested: String,
    /// Model identifier recorded for the answer (presumably the resolved model; TODO confirm).
    pub model: String,
    /// The answer text.
    pub answer: String,
}
19
/// Outcome of one model inference: the answer plus the context and metadata attached to it.
#[derive(Debug, Clone)]
pub struct ModelInference {
    /// The answer and its model identifiers.
    pub answer: ModelAnswer,
    /// Context text associated with the inference
    /// (presumably the body that was sent to the model; TODO confirm).
    pub context_body: String,
    /// Individual context fragments associated with the inference.
    pub context_fragments: Vec<ModelContextFragment>,
    /// Token usage and cost, when available.
    pub usage: Option<TokenUsage>,
    /// Grounding verification result, when available.
    pub grounding: Option<GroundingResult>,
    /// True if this result came from the cache
    pub cached: bool,
}
30
/// Token usage and cost information from LLM inference.
///
/// `Default` yields all-zero counters and a zero cost.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input/prompt tokens
    pub input_tokens: u32,
    /// Output/completion tokens
    pub output_tokens: u32,
    /// Total tokens (input + output); stored as given, not recomputed by this type
    pub total_tokens: u32,
    /// Estimated cost in USD (based on model pricing)
    pub cost_usd: f64,
}
43
44/// Cache for LLM answers to avoid redundant API calls
45/// Uses Blake3 hash of (query + context) as the key
46pub mod cache {
47    use std::collections::HashMap;
48    use std::sync::Mutex;
49
50    /// Cached answer entry
51    #[derive(Debug, Clone)]
52    pub struct CacheEntry {
53        pub answer: String,
54        pub model: String,
55        pub input_tokens: u32,
56        pub output_tokens: u32,
57        pub cost_usd: f64,
58        pub grounding_score: f32,
59        pub created_at: std::time::SystemTime,
60    }
61
62    /// In-memory LRU cache for answers
63    /// Thread-safe with a simple mutex
64    pub struct AnswerCache {
65        entries: Mutex<HashMap<[u8; 32], CacheEntry>>,
66        max_size: usize,
67        hits: std::sync::atomic::AtomicU64,
68        misses: std::sync::atomic::AtomicU64,
69    }
70
71    impl AnswerCache {
72        /// Create a new cache with the specified maximum size
73        pub fn new(max_size: usize) -> Self {
74            Self {
75                entries: Mutex::new(HashMap::new()),
76                max_size,
77                hits: std::sync::atomic::AtomicU64::new(0),
78                misses: std::sync::atomic::AtomicU64::new(0),
79            }
80        }
81
82        /// Generate a cache key from query and context
83        pub fn make_key(query: &str, context: &str, model: &str) -> [u8; 32] {
84            use std::io::Write;
85            let mut hasher = blake3::Hasher::new();
86            let _ = write!(hasher, "{}|{}|{}", model, query, context);
87            *hasher.finalize().as_bytes()
88        }
89
90        /// Look up an entry in the cache
91        pub fn get(&self, key: &[u8; 32]) -> Option<CacheEntry> {
92            let entries = self.entries.lock().ok()?;
93            let result = entries.get(key).cloned();
94            if result.is_some() {
95                self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
96            } else {
97                self.misses
98                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
99            }
100            result
101        }
102
103        /// Insert an entry into the cache
104        pub fn insert(&self, key: [u8; 32], entry: CacheEntry) {
105            if let Ok(mut entries) = self.entries.lock() {
106                // Simple LRU: if at capacity, remove oldest entry
107                if entries.len() >= self.max_size {
108                    let oldest_key = entries
109                        .iter()
110                        .min_by_key(|(_, v)| v.created_at)
111                        .map(|(k, _)| *k);
112                    if let Some(k) = oldest_key {
113                        entries.remove(&k);
114                    }
115                }
116                entries.insert(key, entry);
117            }
118        }
119
120        /// Clear the cache
121        pub fn clear(&self) {
122            if let Ok(mut entries) = self.entries.lock() {
123                entries.clear();
124            }
125        }
126
127        /// Get cache statistics
128        pub fn stats(&self) -> CacheStats {
129            let entries = self.entries.lock().map(|e| e.len()).unwrap_or(0);
130            let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
131            let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
132            CacheStats {
133                entries,
134                hits,
135                misses,
136                hit_rate: if hits + misses > 0 {
137                    hits as f64 / (hits + misses) as f64
138                } else {
139                    0.0
140                },
141            }
142        }
143
144        /// Estimated cost savings from cache hits
145        pub fn estimated_savings(&self) -> f64 {
146            if let Ok(entries) = self.entries.lock() {
147                let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
148                let avg_cost =
149                    entries.values().map(|e| e.cost_usd).sum::<f64>() / entries.len().max(1) as f64;
150                hits as f64 * avg_cost
151            } else {
152                0.0
153            }
154        }
155    }
156
157    impl Default for AnswerCache {
158        fn default() -> Self {
159            Self::new(100) // Default to 100 entries
160        }
161    }
162
163    #[derive(Debug, Clone)]
164    pub struct CacheStats {
165        pub entries: usize,
166        pub hits: u64,
167        pub misses: u64,
168        pub hit_rate: f64,
169    }
170
171    // Global cache instance
172    lazy_static::lazy_static! {
173        pub static ref GLOBAL_CACHE: AnswerCache = AnswerCache::new(500);
174    }
175
176    /// Check cache and return cached result if available
177    pub fn check_cache(query: &str, context: &str, model: &str) -> Option<CacheEntry> {
178        let key = AnswerCache::make_key(query, context, model);
179        GLOBAL_CACHE.get(&key)
180    }
181
182    /// Store result in cache
183    pub fn store_in_cache(query: &str, context: &str, model: &str, entry: CacheEntry) {
184        let key = AnswerCache::make_key(query, context, model);
185        GLOBAL_CACHE.insert(key, entry);
186    }
187
188    /// Get global cache statistics
189    pub fn global_stats() -> CacheStats {
190        GLOBAL_CACHE.stats()
191    }
192
193    /// Clear the global cache
194    pub fn clear_global_cache() {
195        GLOBAL_CACHE.clear();
196    }
197}
198
/// Result of grounding/hallucination verification.
///
/// `Default` yields a zero score, no sentences, and no warning.
#[derive(Debug, Clone, Default)]
pub struct GroundingResult {
    /// Overall grounding score (0.0 to 1.0)
    /// Higher = more grounded in context, less likely to hallucinate
    pub score: f32,
    /// Number of sentences in the answer
    pub sentence_count: usize,
    /// Number of sentences with at least one grounded claim
    pub grounded_sentences: usize,
    /// Individual sentence scores
    /// (presumably one per sentence, in answer order; TODO confirm)
    pub sentence_scores: Vec<f32>,
    /// Warning flag: true if potential hallucination detected
    pub has_warning: bool,
    /// Explanation of the warning (if any)
    pub warning_reason: Option<String>,
}
216
217impl GroundingResult {
218    /// Returns a human-readable grade based on grounding score
219    pub fn grade(&self) -> &'static str {
220        match self.score {
221            s if s >= 0.8 => "A",
222            s if s >= 0.6 => "B",
223            s if s >= 0.4 => "C",
224            s if s >= 0.2 => "D",
225            _ => "F",
226        }
227    }
228
229    /// Returns a label like "HIGH", "MEDIUM", "LOW"
230    pub fn label(&self) -> &'static str {
231        match self.score {
232            s if s >= 0.7 => "HIGH",
233            s if s >= 0.4 => "MEDIUM",
234            _ => "LOW",
235        }
236    }
237}
238
/// One fragment of retrieval context, built field-for-field from a
/// `context::ContextRecord`.
#[derive(Debug, Clone)]
pub struct ModelContextFragment {
    /// Rank copied from the record (presumably its position in the result list; TODO confirm).
    pub rank: usize,
    /// URI of the source the fragment came from.
    pub uri: String,
    /// Title of the source, if any.
    pub title: Option<String>,
    /// Score copied from the record, if any (presumably a retrieval score; TODO confirm).
    pub score: Option<f32>,
    /// Match count copied from the record.
    pub matches: usize,
    /// Identifier of the frame the fragment belongs to.
    pub frame_id: u64,
    /// Range copied from the record (presumably start/end offsets; units TODO confirm).
    pub range: (usize, usize),
    /// Chunk range copied from the record, if any.
    pub chunk_range: Option<(usize, usize)>,
    /// Fragment text.
    pub text: String,
    /// Whether the fragment came from a full or a summary record.
    pub kind: ModelContextFragmentKind,
}
252
/// Kind of a [`ModelContextFragment`], mirroring `context::ContextMode`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ModelContextFragmentKind {
    /// Built from a record in `ContextMode::Full`.
    Full,
    /// Built from a record in `ContextMode::Summary`.
    Summary,
}
258
259impl ModelContextFragment {
260    fn from_record(record: context::ContextRecord) -> Self {
261        let kind = match record.mode {
262            context::ContextMode::Full => ModelContextFragmentKind::Full,
263            context::ContextMode::Summary => ModelContextFragmentKind::Summary,
264        };
265        Self {
266            rank: record.rank,
267            uri: record.uri,
268            title: record.title,
269            score: record.score,
270            matches: record.matches,
271            frame_id: record.frame_id,
272            range: record.range,
273            chunk_range: record.chunk_range,
274            text: record.text,
275            kind,
276        }
277    }
278}
279
/// Errors produced while running a model.
#[derive(Debug)]
pub enum ModelRunError {
    /// The named model is not supported; carries the model name.
    UnsupportedModel(String),
    /// The model is missing files it requires.
    AssetsMissing {
        /// Name of the model whose assets are missing.
        model: String,
        /// Paths of the missing assets.
        missing: Vec<PathBuf>,
    },
    /// Any other failure raised while running the model.
    Runtime(anyhow::Error),
}
289
290impl fmt::Display for ModelRunError {
291    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292        match self {
293            Self::UnsupportedModel(model) => write!(f, "unsupported model '{model}'"),
294            Self::AssetsMissing { model, missing } => {
295                let paths: Vec<_> = missing
296                    .iter()
297                    .map(|path| path.display().to_string())
298                    .collect();
299                write!(
300                    f,
301                    "model '{model}' missing required assets: {}",
302                    paths.join(", ")
303                )
304            }
305            Self::Runtime(err) => write!(f, "model runtime error: {err}"),
306        }
307    }
308}
309
310impl std::error::Error for ModelRunError {
311    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
312        match self {
313            Self::Runtime(err) => Some(err.root_cause()),
314            _ => None,
315        }
316    }
317}
318
/// Context size for local models (presumably in characters, per the name;
/// usage is not visible here — TODO confirm).
const LOCAL_CONTEXT_CHARS: usize = 32_768;
/// Upper bound on the question length; `ModelContextBudget::question_limit`
/// never returns more than this.
const MAX_QUESTION_CHARS: usize = 512;
/// Output-token cap, presumably for locally run models — TODO confirm against callers.
const LOCAL_MAX_OUTPUT_TOKENS: usize = 256;
/// Output-token cap, presumably for remote/API models — TODO confirm against callers.
const REMOTE_MAX_OUTPUT_TOKENS: usize = 768;
/// System prompt placed under the `### System` heading of the completion prompt.
///
/// It instructs the model to answer only from the retrieval context, to give
/// "Correction" sources priority over conflicting ones, and to shape the answer
/// by question type. The text is sent to the model verbatim, so any edit here
/// changes model behavior.
const SYSTEM_PROMPT: &str = r#"You are a precise, intelligent assistant that answers questions using ONLY the provided retrieval context.

## Core Principles
1. GROUND EVERY CLAIM in the context. If asked for a number, quote it exactly.
2. NEVER hallucinate or use external knowledge. If unsure, say "Based on the context..."
3. BE CONCISE but complete. One clear answer is better than verbose hedging.

## CRITICAL: Correction Handling (MANDATORY)
**STOP AND READ THIS FIRST.** Before answering ANY question:
1. Scan ALL sources for "Correction:" in the title or "mv2://correction/" in the URI
2. If ANY correction exists that relates to the question, USE THAT ANSWER ONLY
3. IGNORE ALL OTHER SOURCES that contradict the correction - they are OUTDATED
4. If multiple corrections exist, use the FIRST one listed (most recent)

**VIOLATION OF THIS RULE IS A CRITICAL ERROR.** Example:
- Question: "Where does Ben live?"
- Correction says: "Ben lives in Kenya"
- Other doc says: "Ben lives in Germany"
- CORRECT ANSWER: "Kenya" (from correction)
- WRONG ANSWER: "Germany" (ignores correction = FAIL)

## Answer Strategy
- For NUMERIC questions: Extract the exact value. If multiple values exist, identify which is most relevant (usually the most recent or most specific match).
- For YES/NO questions: Answer directly, then briefly explain why.
- For COMPARISON questions: Present both sides with their values.
- For LIST questions: Use bullet points or numbered lists.
- For TEMPORAL questions: Note that later timestamps = more current information. State WHEN data is from.
- For CALCULATION questions: Show your work step-by-step.
- For ANALYTICAL/PATTERN questions (e.g., "reverted", "changed back", "any differences over time"):
  1. TRACE each attribute's value across ALL time periods in the context
  2. Look for A→B→A patterns where a value changes then returns to its original state
  3. Terms like "consolidated", "same as", "unified", or "aligned" often indicate returning to a prior arrangement
  4. Compare explicit state changes: if Period 1 says "X was same as Y", Period 2 says "X different from Y", and Period 3 says "X consolidated/same as Y again", that IS a reversion
  5. Create a timeline table if helpful to track changes

## Handling Ambiguity
- If the question is ambiguous, interpret it reasonably and state your interpretation.
- If multiple valid answers exist, present the most likely one first, then mention alternatives.
- If context is insufficient, say what IS known, then note what's missing.

## Quality Standards
- PREFER specific answers over vague ones ("$1,234.56" not "around a thousand")
- CITE context when helpful ("[Source: ...]")
- CORRECT obvious typos in your interpretation (e.g., "teh" → "the")
- For percentages/ratios, include the actual numbers when available"#;
/// Label of the TinyLlama model (presumably the bundled local model; TODO confirm).
const TINYLLAMA_LABEL: &str = "tinyllama-1.1b";
/// Characters held back from the prompt budget, presumably for local models — TODO confirm.
const LOCAL_PROMPT_MARGIN_CHARS: usize = 2_048;
/// Characters held back from the prompt budget, presumably for remote models — TODO confirm.
const REMOTE_PROMPT_MARGIN_CHARS: usize = 4_096;

// Per-provider prompt size budgets (in characters, per the names).
// NOTE(review): how these map onto `ModelContextBudget` is not visible here.
const OLLAMA_PROMPT_CHARS: usize = 110_000;
const OPENAI_PROMPT_CHARS: usize = 240_000;
const NVIDIA_PROMPT_CHARS: usize = 240_000;
const GEMINI_PROMPT_CHARS: usize = 320_000;
const CLAUDE_PROMPT_CHARS: usize = 360_000;
const XAI_PROMPT_CHARS: usize = 260_000; // Grok models: ~131K tokens
const GROQ_PROMPT_CHARS: usize = 260_000; // LLaMA 3.3 70B: 128K tokens
const MISTRAL_PROMPT_CHARS: usize = 260_000; // Mistral Large: 128K tokens
379
/// Size budget for a prompt, expressed as a total and a reserved part.
#[derive(Debug, Clone, Copy)]
struct ModelContextBudget {
    /// Ceiling for the whole prompt; returned by `prompt_ceiling`.
    total_chars: usize,
    /// Portion of the total not available to context: `context_chars` is
    /// `total_chars - reserved_chars`, and it also caps the question limit.
    reserved_chars: usize,
}
385
386impl ModelContextBudget {
387    const fn new(total_chars: usize, reserved_chars: usize) -> Self {
388        Self {
389            total_chars,
390            reserved_chars,
391        }
392    }
393
394    fn context_chars(&self) -> usize {
395        self.total_chars.saturating_sub(self.reserved_chars)
396    }
397
398    fn question_limit(&self) -> usize {
399        MAX_QUESTION_CHARS
400            .min(self.reserved_chars.max(1))
401            .min(self.total_chars.max(1))
402    }
403
404    fn apply_override(self, override_context_chars: usize) -> Self {
405        let total = override_context_chars.saturating_add(self.reserved_chars);
406        Self {
407            total_chars: total.max(self.reserved_chars + 1),
408            reserved_chars: self.reserved_chars,
409        }
410    }
411
412    fn prompt_ceiling(&self) -> usize {
413        self.total_chars
414    }
415}
416
/// The pieces of a prompt, in both completion-style and chat-style form.
pub struct PromptParts {
    /// Single-string prompt with `### System`, context, `### Question` and
    /// `### Answer` sections.
    completion_prompt: String,
    /// Chat-style user message (context plus question, without the system text).
    user_message: String,
    /// Output-token limit to request from the model.
    max_output_tokens: usize,
}
422
423impl PromptParts {
424    pub fn completion_prompt(&self) -> &str {
425        &self.completion_prompt
426    }
427
428    pub fn user_message(&self) -> &str {
429        &self.user_message
430    }
431
432    pub fn max_output_tokens(&self) -> usize {
433        self.max_output_tokens
434    }
435}
436
437/// Normalize and enhance a question for optimal LLM interpretation.
438///
439/// This function:
440/// 1. Ensures questions end with `?` for consistent interpretation
441/// 2. Fixes common typos and abbreviations
442/// 3. Clarifies ambiguous phrasing
443/// 4. Expands common abbreviations for better matching
444fn normalize_question(question: &str) -> String {
445    let trimmed = question.trim();
446    if trimmed.is_empty() {
447        return trimmed.to_string();
448    }
449
450    // Step 1: Fix common typos and normalize spacing
451    let mut normalized = fix_common_typos(trimmed);
452
453    // Step 2: Expand common abbreviations for clarity
454    normalized = expand_abbreviations(&normalized);
455
456    // Step 3: Ensure proper punctuation
457    normalized = ensure_question_punctuation(&normalized);
458
459    normalized
460}
461
/// Fix common typos that affect query interpretation and collapse whitespace runs.
///
/// The text is scanned once. Each run of alphanumeric characters is treated as
/// a word and checked against a small typo table (see `correct_typos_in_word`).
/// Everything else is copied through, except that a run of whitespace is
/// reduced to its first character.
///
/// Matching is per word rather than per substring, so valid words that merely
/// contain a typo pattern ("process", "private", "healthier") are left alone.
/// Typos are still fixed at the end of the input and next to punctuation.
/// Working word by word also avoids slicing the original string with byte
/// offsets taken from a lowercased copy, which can differ in length for
/// non-ASCII text.
fn fix_common_typos(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut word = String::new();
    let mut prev_space = false;

    for c in text.chars() {
        if c.is_alphanumeric() {
            word.push(c);
            prev_space = false;
            continue;
        }

        // A non-word character ends the current word, if any.
        if !word.is_empty() {
            result.push_str(&correct_typos_in_word(&word));
            word.clear();
        }

        if c.is_whitespace() {
            // Keep only the first whitespace character of a run.
            if !prev_space {
                result.push(c);
            }
            prev_space = true;
        } else {
            result.push(c);
            prev_space = false;
        }
    }

    if !word.is_empty() {
        result.push_str(&correct_typos_in_word(&word));
    }

    result
}

/// Returns `word` with a known typo corrected, or unchanged if none applies.
fn correct_typos_in_word(word: &str) -> String {
    // Misspellings that must match the whole word, because the same letters
    // also occur inside valid words ("proce" in "process", "thier" in "healthier").
    const WHOLE_WORD_TYPOS: &[(&str, &str)] = &[
        // Common misspellings
        ("teh", "the"),
        ("hte", "the"),
        ("adn", "and"),
        ("taht", "that"),
        ("wiht", "with"),
        ("thier", "their"),
        // Question word typos
        ("waht", "what"),
        ("hwat", "what"),
        ("wehn", "when"),
        ("whre", "where"),
        ("wher", "where"),
        ("howm", "how"),
        ("hwo", "who"),
        // Finger slips that collide with real words when matched loosely
        ("prive", "price"),
        ("proce", "price"),
    ];
    // Misspellings that may appear inside a longer word ("recieved", "nubmers").
    const EMBEDDED_TYPOS: &[(&str, &str)] = &[
        ("recieve", "receive"),
        ("occured", "occurred"),
        ("seperate", "separate"),
        ("amoutn", "amount"),
        ("totla", "total"),
        ("nubmer", "number"),
        ("vlaue", "value"),
        ("revneue", "revenue"),
        ("reveneu", "revenue"),
    ];

    // ASCII lowercasing never changes byte lengths, so offsets found in `lower`
    // are valid offsets into `word` as well.
    let lower = word.to_ascii_lowercase();

    if let Some((_, correction)) = WHOLE_WORD_TYPOS.iter().find(|(typo, _)| *typo == lower) {
        return match_typo_case(word, correction);
    }

    for (typo, correction) in EMBEDDED_TYPOS {
        if let Some(start) = lower.find(*typo) {
            let end = start + typo.len();
            return format!(
                "{}{}{}",
                &word[..start],
                match_typo_case(&word[start..end], correction),
                &word[end..]
            );
        }
    }

    word.to_string()
}

/// Applies the capitalisation of `original` (lowercase, Capitalised, or ALL CAPS)
/// to `correction`.
fn match_typo_case(original: &str, correction: &str) -> String {
    let mut letters = original.chars();
    let starts_upper = letters.next().map_or(false, char::is_uppercase);
    if !starts_upper {
        return correction.to_string();
    }

    let rest = letters.as_str();
    if !rest.is_empty() && rest.chars().all(char::is_uppercase) {
        return correction.to_uppercase();
    }

    let mut fixed = correction.chars();
    match fixed.next() {
        Some(first) => first.to_uppercase().chain(fixed).collect(),
        None => String::new(),
    }
}
528
529/// Generate optimized search keywords from a question using LLM
530/// Returns the original question plus extracted search terms for better retrieval
531pub fn generate_search_query(
532    question: &str,
533    model: &str,
534    api_key: &str,
535) -> Result<String, ModelRunError> {
536    // For lexical search, we need to be careful - adding too many terms can hurt
537    // because Tantivy uses AND logic. We need a very short, focused query.
538    // IMPORTANT: Keep abbreviations as-is since documents often use the abbreviation form.
539    let prompt = format!(
540        r#"Extract 2 key search terms from this question.
541KEEP abbreviations exactly as written (QPS, API, SDK, etc.) - don't expand them.
542Output only the main topic and one key term.
543
544Question: {}
545
546Examples:
547- "What is the QPS for memvid?" → "memvid QPS"
548- "How many queries per second?" → "QPS throughput"
549- "What's the API rate limit?" → "API rate"
550- "How much does it cost?" → "cost pricing"
551
552Output exactly 2 words, nothing else."#,
553        question
554    );
555
556    // Use a fast model for keyword extraction
557    // The model passed in is already the fast variant (gpt-4o-mini or claude-haiku)
558    let extraction_model =
559        if model.starts_with("gpt") || model.starts_with("o1") || model.contains("openai") {
560            "gpt-4o-mini"
561        } else if model.starts_with("claude") || model.contains("anthropic") {
562            "claude-haiku-4-5"
563        } else if model.contains("llama") || model.contains("groq") || model.contains("mixtral") {
564            "llama-3.1-8b-instant" // Fast Groq model for keyword extraction
565        } else if model.contains("grok") || model.contains("xai") {
566            "grok-4-fast"
567        } else if model.contains("mistral") {
568            "mistral-small-latest" // Fast Mistral model
569        } else {
570            // For other models, just return the original question
571            return Ok(question.to_string());
572        };
573
574    // Make a quick API call for query rewriting
575    let rewritten = call_llm_for_keywords(&prompt, extraction_model, api_key)?;
576
577    // If we got a good rewritten query, use it; otherwise fall back to original
578    let rewritten = rewritten.trim();
579    if rewritten.is_empty() || rewritten.len() > 100 {
580        // LLM returned empty or too long - use original
581        Ok(question.to_string())
582    } else {
583        // Use the short, focused rewritten query for search
584        Ok(rewritten.to_string())
585    }
586}
587
588/// Quick LLM call for keyword extraction (lightweight, fast model)
589fn call_llm_for_keywords(
590    prompt: &str,
591    model: &str,
592    api_key: &str,
593) -> Result<String, ModelRunError> {
594    use reqwest::blocking::Client;
595    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
596
597    let client = Client::builder()
598        .timeout(std::time::Duration::from_secs(10)) // Fast timeout for keyword extraction
599        .build()
600        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("HTTP client error: {e}")))?;
601
602    // Determine API endpoint and headers based on model
603    let (url, is_anthropic) = if model.starts_with("gpt") || model.starts_with("o1") {
604        ("https://api.openai.com/v1/chat/completions", false)
605    } else if model.starts_with("claude") {
606        ("https://api.anthropic.com/v1/messages", true)
607    } else if model.contains("llama") || model.contains("mixtral") {
608        ("https://api.groq.com/openai/v1/chat/completions", false)
609    } else if model.contains("grok") {
610        ("https://api.x.ai/v1/chat/completions", false)
611    } else if model.contains("mistral") {
612        ("https://api.mistral.ai/v1/chat/completions", false)
613    } else {
614        return Err(ModelRunError::UnsupportedModel(model.to_string()));
615    };
616
617    let response = if is_anthropic {
618        let mut headers = HeaderMap::new();
619        headers.insert(
620            reqwest::header::HeaderName::from_static("x-api-key"),
621            HeaderValue::from_str(api_key)
622                .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Invalid API key: {e}")))?,
623        );
624        headers.insert(
625            reqwest::header::HeaderName::from_static("anthropic-version"),
626            HeaderValue::from_static("2023-06-01"),
627        );
628        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
629
630        client
631            .post(url)
632            .headers(headers)
633            .json(&serde_json::json!({
634                "model": model,
635                "max_tokens": 100,
636                "messages": [{"role": "user", "content": prompt}]
637            }))
638            .send()
639    } else {
640        // OpenAI-compatible API (OpenAI, Groq, XAI, Mistral)
641        let mut headers = HeaderMap::new();
642        headers.insert(
643            AUTHORIZATION,
644            HeaderValue::from_str(&format!("Bearer {}", api_key))
645                .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Invalid API key: {e}")))?,
646        );
647        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
648
649        client
650            .post(url)
651            .headers(headers)
652            .json(&serde_json::json!({
653                "model": model,
654                "messages": [{"role": "user", "content": prompt}],
655                "max_tokens": 100,
656                "temperature": 0.0
657            }))
658            .send()
659    };
660
661    match response {
662        Ok(resp) => {
663            let json: serde_json::Value = resp
664                .json()
665                .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("JSON parse error: {e}")))?;
666
667            // Extract text from response
668            let text = if model.starts_with("claude") {
669                json["content"][0]["text"].as_str().unwrap_or("")
670            } else {
671                json["choices"][0]["message"]["content"]
672                    .as_str()
673                    .unwrap_or("")
674            };
675
676            Ok(text.to_string())
677        }
678        Err(_) => {
679            // On error, just return empty (fall back to original query)
680            // This is not a fatal error - we just use the original query
681            Ok(String::new())
682        }
683    }
684}
685
/// Abbreviation-expansion stage of question normalization.
///
/// Currently a pass-through that returns the text unchanged; it is kept as a
/// fallback hook because the main expansion happens via the LLM.
fn expand_abbreviations(text: &str) -> String {
    String::from(text)
}
691
/// Append `?` to text that reads like a question but has no terminal punctuation.
///
/// The text is trimmed first. If it already ends in `?`, `.` or `!` it is
/// returned as is. Otherwise a `?` is added when the text begins with a
/// question word, an auxiliary verb, or a request phrase such as "tell me".
/// The opener must be a complete word: "what now" qualifies, "whatever" does not.
fn ensure_question_punctuation(text: &str) -> String {
    const QUESTION_STARTERS: &[&str] = &[
        "how", "what", "where", "when", "why", "which", "who", "whom", "whose", "is", "are", "was",
        "were", "will", "would", "can", "could", "should", "do", "does", "did", "have", "has",
        "had", "may", "might", "shall", "tell me", "show me", "find", "list", "give me", "explain",
    ];

    let text = text.trim();

    // Existing terminal punctuation is respected.
    if matches!(text.chars().last(), Some('?' | '.' | '!')) {
        return text.to_string();
    }

    let lowered = text.to_lowercase();
    let looks_like_question = QUESTION_STARTERS.iter().any(|starter| {
        // The opener must be followed by a non-alphanumeric character or by nothing.
        lowered
            .strip_prefix(*starter)
            .map_or(false, |rest| !rest.starts_with(char::is_alphanumeric))
    });

    let mut punctuated = text.to_string();
    if looks_like_question {
        punctuated.push('?');
    }
    punctuated
}
721
722fn build_prompt_parts(
723    question: &str,
724    context: &str,
725    budget: &ModelContextBudget,
726    max_output_tokens: usize,
727) -> PromptParts {
728    let mut context_section = context.to_string();
729    let normalized_question = normalize_question(question);
730    let trimmed_question = trim_to(&normalized_question, budget.question_limit());
731
732    // Detect question type for better prompting
733    let question_type = detect_question_type(&trimmed_question);
734    let type_hint = question_type.hint();
735
736    let system_section = format!("### System\n{SYSTEM_PROMPT}");
737    let question_section = format!("### Question\n{trimmed_question}");
738    let answer_stub = "### Answer\n";
739
740    let overhead = system_section.len() + 2 + question_section.len() + 2 + answer_stub.len();
741    if budget.prompt_ceiling() > overhead {
742        let max_context_len = budget
743            .prompt_ceiling()
744            .saturating_sub(overhead)
745            .min(budget.context_chars());
746        if context_section.len() > max_context_len {
747            context_section = clamp_to(&context_section, max_context_len);
748        }
749    } else {
750        context_section = String::new();
751    }
752
753    // Handle empty context gracefully
754    let context_instruction = if context_section.trim().is_empty() {
755        "Note: No relevant context was found. Answer based on what you know, but clearly state this limitation."
756    } else {
757        ""
758    };
759
760    let completion_prompt =
761        format!("{system_section}\n\n{context_section}\n\n{question_section}\n\n### Answer\n");
762
763    // Enhanced user message with type-specific guidance
764    let user_message = if context_instruction.is_empty() {
765        format!(
766            "{context_section}\n\n---\nQuestion: {trimmed_question}\n{type_hint}\nProvide a direct, accurate answer using only the context above."
767        )
768    } else {
769        format!("{context_instruction}\n\nQuestion: {trimmed_question}\n{type_hint}")
770    };
771
772    PromptParts {
773        completion_prompt,
774        user_message,
775        max_output_tokens,
776    }
777}
778
/// Question type classification for better prompting.
///
/// Each variant maps to a short expectation hint (see `hint`) that is appended
/// to the user message. The phrases listed are examples of what
/// `detect_question_type` looks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QuestionType {
    Numeric,     // "how much", "how many", "what is the value"
    YesNo,       // "is", "are", "does", "can", "will"
    List,        // "list", "what are the", "show all"
    Comparison,  // "compare", "difference", "vs", "versus"
    Temporal,    // "when", "what date", "how long"
    Explanation, // "why", "explain", "how does"
    Factual,     // "what", "who", "where"
    Other,       // anything that matches none of the above; carries no hint
}
791
792impl QuestionType {
793    fn hint(&self) -> &'static str {
794        match self {
795            Self::Numeric => "(Expected: a specific number or value)",
796            Self::YesNo => "(Expected: yes/no with brief explanation)",
797            Self::List => "(Expected: a list of items)",
798            Self::Comparison => "(Expected: comparison of two or more items)",
799            Self::Temporal => "(Expected: a date, time, or duration)",
800            Self::Explanation => "(Expected: reasoning or explanation)",
801            Self::Factual => "(Expected: a factual answer)",
802            Self::Other => "",
803        }
804    }
805}
806
807fn detect_question_type(question: &str) -> QuestionType {
808    let lower = question.to_lowercase();
809
810    // Numeric patterns
811    if lower.contains("how much")
812        || lower.contains("how many")
813        || lower.contains("what is the value")
814        || lower.contains("what's the value")
815        || lower.contains("total")
816        || lower.contains("sum")
817        || lower.contains("average")
818        || lower.contains("percentage")
819        || lower.contains("rate")
820        || lower.contains("amount")
821        || lower.contains("price")
822        || lower.contains("cost")
823        || lower.contains("revenue")
824        || lower.contains("profit")
825    {
826        return QuestionType::Numeric;
827    }
828
829    // Yes/No patterns
830    let yes_no_starters = [
831        "is ", "are ", "does ", "do ", "can ", "will ", "has ", "have ", "was ", "were ",
832    ];
833    if yes_no_starters.iter().any(|s| lower.starts_with(s)) {
834        return QuestionType::YesNo;
835    }
836
837    // List patterns
838    if lower.contains("list")
839        || lower.contains("show all")
840        || lower.contains("what are the")
841        || lower.contains("name all")
842        || lower.contains("enumerate")
843    {
844        return QuestionType::List;
845    }
846
847    // Comparison patterns
848    if lower.contains("compare")
849        || lower.contains("difference between")
850        || lower.contains(" vs ")
851        || lower.contains("versus")
852        || lower.contains("better than")
853        || lower.contains("worse than")
854    {
855        return QuestionType::Comparison;
856    }
857
858    // Temporal patterns
859    if lower.starts_with("when")
860        || lower.contains("what date")
861        || lower.contains("how long")
862        || lower.contains("how old")
863        || lower.contains("since when")
864    {
865        return QuestionType::Temporal;
866    }
867
868    // Explanation patterns
869    if lower.starts_with("why")
870        || lower.starts_with("explain")
871        || lower.contains("how does")
872        || lower.contains("reason for")
873        || lower.contains("cause of")
874    {
875        return QuestionType::Explanation;
876    }
877
878    // Factual patterns
879    if lower.starts_with("what") || lower.starts_with("who") || lower.starts_with("where") {
880        return QuestionType::Factual;
881    }
882
883    QuestionType::Other
884}
885
/// Cleans up a raw LLM answer for display.
///
/// Steps, in order:
/// 1. Trim surrounding whitespace.
/// 2. Strip each known boilerplate opener (e.g. "Based on the provided
///    context,") found at the start, re-capitalising what remains.
/// 3. Collapse every run of whitespace (including newlines) to one space.
/// 4. Upper-case a lowercase first character, unless the text starts with
///    the standalone word "i ".
pub fn postprocess_answer(answer: &str) -> String {
    const BOILERPLATE_OPENERS: &[&str] = &[
        "Based on the provided context,",
        "According to the context,",
        "From the context provided,",
        "The context shows that",
        "Based on the information provided,",
    ];

    /// Returns `text` with its first character upper-cased (may expand to
    /// several characters, e.g. 'ß' -> "SS"); empty input stays empty.
    fn capitalize_first(text: &str) -> String {
        let mut rest = text.chars();
        match rest.next() {
            Some(first) => first.to_uppercase().chain(rest).collect(),
            None => String::new(),
        }
    }

    let mut cleaned = answer.trim().to_string();

    // Openers are tried once each, in table order, against the current text,
    // so an opener exposed by stripping an earlier one is removed as well.
    for opener in BOILERPLATE_OPENERS.iter() {
        if let Some(remainder) = cleaned.strip_prefix(*opener) {
            cleaned = capitalize_first(remainder.trim_start());
        }
    }

    // Re-join the whitespace-separated words with single spaces.
    let mut collapsed = String::with_capacity(cleaned.len());
    for word in cleaned.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }

    let starts_lowercase = collapsed.chars().next().map_or(false, char::is_lowercase);
    if starts_lowercase && !collapsed.starts_with("i ") {
        capitalize_first(&collapsed)
    } else {
        collapsed
    }
}
920
/// Truncates `text` to at most `limit` bytes and appends "..." when it was cut.
///
/// Text that already fits within `limit` bytes is returned unchanged. When
/// truncation is needed and `limit` falls inside a multi-byte UTF-8 character,
/// the cut backs up to the previous character boundary, so slicing never
/// panics. The returned string may be up to `limit + 3` bytes long because
/// the ellipsis is appended after the cut.
fn trim_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }

    // `limit < text.len()` here; index 0 is always a boundary, so this stops.
    let mut end = limit;
    while !text.is_char_boundary(end) {
        end -= 1;
    }

    let mut truncated = String::with_capacity(end + 3);
    truncated.push_str(&text[..end]);
    truncated.push_str("...");
    truncated
}
930
/// Shortens `text` so the result, ellipsis included, is at most `limit` bytes.
///
/// * Text that already fits is returned unchanged.
/// * For `limit <= 3` only the first `limit` dots of "..." are returned.
/// * Otherwise the text is cut at the last character boundary at or before
///   `limit - 3` bytes and "..." is appended; if no character fits before
///   that point the result is just "...".
fn clamp_to(text: &str, limit: usize) -> String {
    const ELLIPSIS: &str = "...";

    if text.len() <= limit {
        return text.to_string();
    }
    if limit <= ELLIPSIS.len() {
        return ELLIPSIS[..limit].to_string();
    }

    // Walk backwards from the byte budget to a valid UTF-8 boundary
    // (index 0 always qualifies, so the search cannot fail).
    let cut = (0..=limit - ELLIPSIS.len())
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);

    match cut {
        0 => ELLIPSIS.to_string(),
        _ => format!("{}{}", &text[..cut], ELLIPSIS),
    }
}
950
/// Animated "Thinking..." indicator written to stderr while work is in flight.
///
/// The animation runs on a background thread and is stopped by [`stop`] or
/// automatically when the value is dropped.
///
/// [`stop`]: ThinkingSpinner::stop
struct ThinkingSpinner {
    /// `true` while the animation should keep running.
    flag: Arc<AtomicBool>,
    /// Animation thread; `None` when none was started or after it was joined.
    handle: Option<thread::JoinHandle<()>>,
}

impl ThinkingSpinner {
    /// Starts the spinner.
    ///
    /// The animation is only shown when stderr is an interactive terminal, so
    /// control characters never pollute redirected or combined output (e.g.
    /// `2>&1`). When stderr is not a terminal no thread is spawned at all.
    fn start() -> Self {
        const FRAMES: [&str; 6] = [
            "Thinking    ",
            "Thinking.   ",
            "Thinking..  ",
            "Thinking... ",
            "Thinking .. ",
            "Thinking  . ",
        ];
        const FRAME_INTERVAL: Duration = Duration::from_millis(200);

        let flag = Arc::new(AtomicBool::new(true));

        if !stderr().is_terminal() {
            // Nothing to draw: skip the thread instead of parking one that
            // would only poll the flag.
            return Self { flag, handle: None };
        }

        let thread_flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            let mut err = stderr();
            for frame in FRAMES.iter().cycle() {
                // The flag guards no other data, so Relaxed is sufficient.
                if !thread_flag.load(Ordering::Relaxed) {
                    break;
                }
                let _ = write!(err, "\r{frame}");
                let _ = err.flush();
                // Parking (rather than sleeping) lets `stop` wake this thread
                // immediately; a spurious wake-up merely advances one frame.
                thread::park_timeout(FRAME_INTERVAL);
            }
            // Erase the spinner text and return the cursor to column 0.
            let _ = write!(err, "\r             \r");
            let _ = err.flush();
        });

        Self {
            flag,
            handle: Some(handle),
        }
    }

    /// Stops the animation and waits for the spinner thread to clean up.
    ///
    /// Safe to call more than once; later calls do nothing.
    fn stop(&mut self) {
        self.flag.store(false, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            // Wake the thread so it notices the flag without waiting out the
            // remainder of the current frame interval.
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}

impl Drop for ThinkingSpinner {
    /// Ensures the spinner line is cleared even if `stop` was never called.
    fn drop(&mut self) {
        self.stop();
    }
}
1015
/// A parsed model selection: which backend to use and, where applicable,
/// which model name (or pack path) to hand to it.
#[derive(Clone, Debug)]
enum ModelKind {
    /// The built-in TinyLlama model; carries no extra configuration.
    TinyLlama,
    /// A "ghost" pack identified by its path.
    Ghost {
        /// Path given after the `ghost:` prefix.
        pack_path: PathBuf,
    },
    /// A model served by Ollama.
    Ollama {
        /// Model name, including any `:tag` suffix.
        model: String,
    },
    /// An OpenAI model.
    OpenAi {
        /// Model name passed to the provider.
        model: String,
    },
    /// An NVIDIA-hosted model.
    Nvidia {
        /// Model name passed to the provider.
        model: String,
    },
    /// A Google Gemini model.
    Gemini {
        /// Model name passed to the provider.
        model: String,
    },
    /// An Anthropic Claude model.
    Claude {
        /// Model name passed to the provider.
        model: String,
    },
    /// An xAI (Grok) model.
    Xai {
        /// Model name passed to the provider.
        model: String,
    },
    /// A model served by Groq.
    Groq {
        /// Model name passed to the provider.
        model: String,
    },
    /// A Mistral API model.
    Mistral {
        /// Model name passed to the provider.
        model: String,
    },
}
1029
1030impl ModelKind {
1031    fn parse(raw: &str) -> Option<Self> {
1032        let trimmed = raw.trim();
1033        if trimmed.is_empty() {
1034            return None;
1035        }
1036
1037        let (provider, explicit_model) = if let Some((p, rest)) = trimmed.split_once(':') {
1038            let value = rest.trim();
1039            let explicit = if value.is_empty() {
1040                None
1041            } else {
1042                Some(value.to_string())
1043            };
1044            (p.trim().to_ascii_lowercase(), explicit)
1045        } else {
1046            (trimmed.to_ascii_lowercase(), None)
1047        };
1048
1049        match provider.as_str() {
1050            "tinyllama" | "tiny-llama" | "tinyllama-1.1b" => Some(Self::TinyLlama),
1051            "ghost" => explicit_model.map(|value| Self::Ghost {
1052                pack_path: PathBuf::from(value),
1053            }),
1054            "ollama" => Some(Self::Ollama {
1055                model: explicit_model.unwrap_or_else(|| "ollama1.5".to_string()),
1056            }),
1057            "ollama1.5" | "ollama1-5" => Some(Self::Ollama {
1058                model: "ollama1.5".to_string(),
1059            }),
1060            "openai" => Some(Self::OpenAi {
1061                model: normalize_openai_model(explicit_model),
1062            }),
1063            "nvidia" | "nv" => Some(Self::Nvidia {
1064                model: normalize_nvidia_model(explicit_model),
1065            }),
1066            "gemini" | "google" => Some(Self::Gemini {
1067                model: normalize_gemini_model(explicit_model),
1068            }),
1069            "claude" | "anthropic" => Some(Self::Claude {
1070                model: normalize_claude_model(explicit_model),
1071            }),
1072            "xai" | "grok" => Some(Self::Xai {
1073                model: normalize_xai_model(explicit_model),
1074            }),
1075            "groq" => Some(Self::Groq {
1076                model: normalize_groq_model(explicit_model),
1077            }),
1078            "mistral" => Some(Self::Mistral {
1079                model: normalize_mistral_model(explicit_model),
1080            }),
1081            // Auto-detect provider from model name prefix
1082            // For Ollama models with colons in the name (e.g., qwen2.5:1.5b),
1083            // we need to use the full original name, not just the provider prefix
1084            _ => Self::infer_from_model_name_full(trimmed, &provider),
1085        }
1086    }
1087
1088    /// Infer the provider from a model name, using the full original name for Ollama models.
1089    /// This handles model names with colons like "qwen2.5:1.5b" by using the full name.
1090    fn infer_from_model_name_full(full_name: &str, prefix: &str) -> Option<Self> {
1091        let lowered = prefix.to_ascii_lowercase();
1092
1093        // Gemini models: gemini-*, models/gemini-*
1094        if lowered.starts_with("gemini") || lowered.starts_with("models/gemini") {
1095            return Some(Self::Gemini {
1096                model: full_name.to_string(),
1097            });
1098        }
1099
1100        // OpenAI models: gpt-*, o1-*, chatgpt-*, text-davinci-*, etc.
1101        if lowered.starts_with("gpt-")
1102            || lowered.starts_with("o1-")
1103            || lowered.starts_with("o3-")
1104            || lowered.starts_with("chatgpt-")
1105            || lowered.starts_with("text-")
1106        {
1107            return Some(Self::OpenAi {
1108                model: full_name.to_string(),
1109            });
1110        }
1111
1112        // Claude/Anthropic models: claude-*
1113        if lowered.starts_with("claude-") {
1114            return Some(Self::Claude {
1115                model: full_name.to_string(),
1116            });
1117        }
1118
1119        // xAI Grok models: grok-*
1120        if lowered.starts_with("grok-") {
1121            return Some(Self::Xai {
1122                model: full_name.to_string(),
1123            });
1124        }
1125
1126        // Mistral API models: mistral-*
1127        if lowered.starts_with("mistral-") {
1128            return Some(Self::Mistral {
1129                model: full_name.to_string(),
1130            });
1131        }
1132
1133        // Groq models: llama-* (via Groq), mixtral-*
1134        if lowered.starts_with("llama-") || lowered.starts_with("mixtral-") {
1135            return Some(Self::Groq {
1136                model: full_name.to_string(),
1137            });
1138        }
1139
1140        // Ollama models: llama*, phi*, qwen*, gemma*, etc.
1141        // Use the full name to preserve version tags like ":1.5b"
1142        if lowered.starts_with("llama")
1143            || lowered.starts_with("phi")
1144            || lowered.starts_with("codellama")
1145            || lowered.starts_with("deepseek")
1146            || lowered.starts_with("qwen")
1147            || lowered.starts_with("gemma")
1148        {
1149            return Some(Self::Ollama {
1150                model: full_name.to_string(),
1151            });
1152        }
1153
1154        None
1155    }
1156
1157    fn label(&self) -> String {
1158        match self {
1159            Self::TinyLlama => TINYLLAMA_LABEL.to_string(),
1160            Self::Ghost { pack_path } => format!("ghost:{}", pack_path.display()),
1161            Self::Ollama { model } => format!("ollama:{model}"),
1162            Self::OpenAi { model } => format!("openai:{model}"),
1163            Self::Nvidia { model } => format!("nvidia:{model}"),
1164            Self::Gemini { model } => format!("gemini:{model}"),
1165            Self::Claude { model } => format!("claude:{model}"),
1166            Self::Xai { model } => format!("xai:{model}"),
1167            Self::Groq { model } => format!("groq:{model}"),
1168            Self::Mistral { model } => format!("mistral:{model}"),
1169        }
1170    }
1171
1172    fn context_budget(&self) -> ModelContextBudget {
1173        match self {
1174            Self::TinyLlama => {
1175                ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
1176            }
1177            Self::Ghost { .. } => {
1178                ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
1179            }
1180            Self::Ollama { .. } => {
1181                ModelContextBudget::new(OLLAMA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1182            }
1183            Self::OpenAi { .. } => {
1184                ModelContextBudget::new(OPENAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1185            }
1186            Self::Nvidia { .. } => {
1187                ModelContextBudget::new(NVIDIA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1188            }
1189            Self::Gemini { .. } => {
1190                ModelContextBudget::new(GEMINI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1191            }
1192            Self::Claude { .. } => {
1193                ModelContextBudget::new(CLAUDE_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1194            }
1195            Self::Xai { .. } => {
1196                ModelContextBudget::new(XAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1197            }
1198            Self::Groq { .. } => {
1199                ModelContextBudget::new(GROQ_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1200            }
1201            Self::Mistral { .. } => {
1202                ModelContextBudget::new(MISTRAL_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
1203            }
1204        }
1205    }
1206
1207    fn max_output_tokens(&self) -> usize {
1208        match self {
1209            Self::TinyLlama => LOCAL_MAX_OUTPUT_TOKENS,
1210            Self::Ghost { .. } => LOCAL_MAX_OUTPUT_TOKENS,
1211            Self::Ollama { .. }
1212            | Self::OpenAi { .. }
1213            | Self::Nvidia { .. }
1214            | Self::Gemini { .. }
1215            | Self::Claude { .. }
1216            | Self::Xai { .. }
1217            | Self::Groq { .. }
1218            | Self::Mistral { .. } => REMOTE_MAX_OUTPUT_TOKENS,
1219        }
1220    }
1221}
1222
/// Resolves the OpenAI model name: the explicit value when it is not blank
/// (returned exactly as given, without trimming), otherwise `gpt-4o-mini`.
fn normalize_openai_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "gpt-4o-mini".to_string())
}
1229
/// Resolves the NVIDIA model name.
///
/// Resolution order:
/// 1. `explicit`, when it is not blank (returned exactly as given);
/// 2. the `NVIDIA_LLM_MODEL` environment variable;
/// 3. the `NVIDIA_MODEL` environment variable.
///
/// Environment values are trimmed, and a variable that is unset, not valid
/// Unicode, or blank is skipped so it cannot mask the next candidate.
/// Returns an empty string when nothing is configured.
fn normalize_nvidia_model(explicit: Option<String>) -> String {
    if let Some(name) = explicit.filter(|name| !name.trim().is_empty()) {
        return name;
    }

    // Lazy chain: the second variable is only read if the first yields nothing.
    ["NVIDIA_LLM_MODEL", "NVIDIA_MODEL"]
        .iter()
        .filter_map(|key| std::env::var(key).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty())
        .unwrap_or_default()
}
1241
/// Resolves the Gemini model name: the explicit value exactly as given, or
/// `gemini-2.5-flash` when none was supplied.
///
/// No aliases are rewritten for Gemini, so the name is passed through
/// untouched and the default is only allocated when it is needed.
fn normalize_gemini_model(explicit: Option<String>) -> String {
    explicit.unwrap_or_else(|| "gemini-2.5-flash".to_string())
}
1254
/// Resolves the Claude model name.
///
/// Returns `claude-sonnet-4-5` when no model is given. Legacy names and short
/// aliases (`sonnet`, `haiku`, `opus`, and the Claude 3 / 3.5 identifiers)
/// are mapped to their current replacements. Alias matching ignores case,
/// consistent with the xAI, Groq and Mistral normalisers; any other name is
/// returned exactly as given.
fn normalize_claude_model(explicit: Option<String>) -> String {
    let Some(raw) = explicit else {
        return "claude-sonnet-4-5".to_string();
    };

    // Map old model names to new ones.
    match raw.to_lowercase().as_str() {
        "claude-3-5-sonnet-20241022" | "claude-3.5-sonnet" | "sonnet" => {
            "claude-sonnet-4-5".to_string()
        }
        "claude-3-haiku-20240307" | "claude-3-haiku" | "haiku" => "claude-haiku-4-5".to_string(),
        "claude-3-opus-20240229" | "claude-3-opus" | "opus" => "claude-opus-4".to_string(),
        _ => raw,
    }
}
1271
/// Resolves the xAI model name.
///
/// A missing name and the case-insensitive aliases `grok` / `grok-fast`
/// resolve to `grok-4-fast`; every other name (including explicit versions
/// such as `grok-3` or `grok-4`) is returned exactly as given.
fn normalize_xai_model(explicit: Option<String>) -> String {
    const DEFAULT_MODEL: &str = "grok-4-fast";

    match explicit {
        Some(raw) if !matches!(raw.to_lowercase().as_str(), "grok" | "grok-fast") => raw,
        _ => DEFAULT_MODEL.to_string(),
    }
}
1285
/// Resolves the Groq model name.
///
/// A missing name resolves to `llama-3.3-70b-versatile`. Short aliases are
/// matched case-insensitively and expanded to full model identifiers; any
/// other name is returned exactly as given.
fn normalize_groq_model(explicit: Option<String>) -> String {
    const LLAMA_70B: &str = "llama-3.3-70b-versatile";
    const LLAMA_8B: &str = "llama-3.1-8b-instant";
    const MIXTRAL: &str = "mixtral-8x7b-32768";

    let raw = match explicit {
        Some(raw) => raw,
        None => return LLAMA_70B.to_string(),
    };

    // Expand common aliases; unknown names fall through untouched.
    let resolved = match raw.to_lowercase().as_str() {
        "llama" | "llama3" | "llama-3" | "llama-70b" | "llama3-70b" => LLAMA_70B,
        "llama-8b" | "llama3-8b" => LLAMA_8B,
        "mixtral" => MIXTRAL,
        _ => return raw,
    };
    resolved.to_string()
}
1301
/// Resolves the Mistral model name.
///
/// A missing name resolves to `mistral-large-latest`. The size aliases
/// (`large`, `medium`, `small`, with or without the `mistral-` prefix, plus
/// plain `mistral`) are matched case-insensitively and expanded to the
/// corresponding `-latest` identifier; any other name is returned exactly
/// as given.
fn normalize_mistral_model(explicit: Option<String>) -> String {
    const LARGE: &str = "mistral-large-latest";
    const MEDIUM: &str = "mistral-medium-latest";
    const SMALL: &str = "mistral-small-latest";

    let raw = match explicit {
        Some(raw) => raw,
        None => return LARGE.to_string(),
    };

    // Expand common aliases; unknown names fall through untouched.
    let resolved = match raw.to_lowercase().as_str() {
        "mistral" | "large" | "mistral-large" => LARGE,
        "medium" | "mistral-medium" => MEDIUM,
        "small" | "mistral-small" => SMALL,
        _ => return raw,
    };
    resolved.to_string()
}
1316
/// Calculate cost for a given model based on token usage.
///
/// `model` is matched case-insensitively by substring against an ordered
/// price table; the first matching row wins, so more specific names are
/// listed before the general ones they contain (e.g. `gpt-4o-mini` before
/// `gpt-4o`). Local backends (`ollama`, `tinyllama`) are checked first so a
/// locally served model whose name also contains a hosted model family
/// (for example `ollama:mistral`) is free rather than billed at hosted rates.
/// Unrecognised models use a conservative default of $1.00 / $3.00.
///
/// Prices are per 1M tokens in USD (December 2025 pricing).
///
/// Returns the estimated cost in USD for `input_tokens` prompt tokens plus
/// `output_tokens` completion tokens.
pub fn calculate_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    // (name fragment, input USD per 1M tokens, output USD per 1M tokens)
    const PRICE_TABLE: &[(&str, f64, f64)] = &[
        // Local/free models - must precede hosted families they may contain.
        ("ollama", 0.0, 0.0),
        ("tinyllama", 0.0, 0.0),
        // OpenAI pricing - Dec 2025
        ("gpt-4o-mini", 0.15, 0.60),
        ("gpt-4o", 2.50, 10.00),
        ("gpt-4.5", 75.00, 150.00),
        ("gpt-4.1-mini", 0.40, 1.60),
        ("gpt-4.1", 2.00, 8.00),
        ("gpt-5.2", 1.75, 14.00),
        ("gpt-5", 1.75, 14.00),
        ("gpt-4-turbo", 10.00, 30.00),
        ("gpt-4", 30.00, 60.00),
        ("gpt-3.5", 0.50, 1.50),
        ("o1", 15.00, 60.00),
        ("o3", 15.00, 60.00),
        // Claude/Anthropic pricing - Dec 2025
        ("claude-4-opus", 15.00, 75.00),
        ("claude-opus-4", 15.00, 75.00),
        ("claude-4-sonnet", 3.00, 15.00),
        ("claude-sonnet-4", 3.00, 15.00),
        ("claude-4-haiku", 0.25, 1.25),
        ("claude-haiku-4", 0.25, 1.25),
        ("claude-3-5-sonnet", 3.00, 15.00),
        ("claude-3.5-sonnet", 3.00, 15.00),
        ("claude-3-opus", 15.00, 75.00),
        ("claude-3-sonnet", 3.00, 15.00),
        ("claude-3-haiku", 0.25, 1.25),
        ("claude", 3.00, 15.00), // Default to Sonnet pricing
        // Gemini/Google pricing - Dec 2025
        ("gemini-2.5-flash", 0.15, 3.50),
        ("gemini-2.5-pro", 1.25, 10.00),
        ("gemini-2.0", 0.10, 0.40),
        ("gemini-1.5-pro", 1.25, 5.00),
        ("gemini-1.5-flash", 0.075, 0.30),
        ("gemini", 0.15, 3.50), // Default to 2.5 Flash
        // xAI Grok pricing - Dec 2025
        ("grok-4-fast", 0.20, 0.50),
        ("grok-4", 3.00, 15.00),
        ("grok-3", 3.00, 15.00),
        ("grok", 3.00, 15.00),
        // Groq pricing - Dec 2025
        ("llama-3.3-70b", 0.59, 0.79),
        ("llama-3.1-70b", 0.59, 0.79),
        ("llama-3.1-8b", 0.05, 0.08),
        ("mixtral-8x7b", 0.24, 0.24),
        // Mistral pricing - Dec 2025
        ("mistral-large-3", 0.50, 1.50),
        ("mistral-large-latest", 0.50, 1.50),
        ("mistral-large", 2.00, 6.00),
        ("mistral-medium", 0.40, 1.20),
        ("mistral-small", 0.10, 0.30),
        ("mistral", 0.50, 1.50),
        // DeepSeek pricing - Dec 2025 (covers deepseek-v3 and other variants)
        ("deepseek", 0.27, 1.10),
        // NVIDIA NIM pricing
        ("nvidia", 1.00, 3.00),
    ];
    // Conservative estimate for models missing from the table.
    const DEFAULT_PRICE: (f64, f64) = (1.00, 3.00);
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    let name = model.to_lowercase();
    let (input_price, output_price) = PRICE_TABLE
        .iter()
        .find(|(fragment, _, _)| name.contains(*fragment))
        .map_or(DEFAULT_PRICE, |&(_, input, output)| (input, output));

    let input_cost = (f64::from(input_tokens) / TOKENS_PER_PRICE_UNIT) * input_price;
    let output_cost = (f64::from(output_tokens) / TOKENS_PER_PRICE_UNIT) * output_price;
    input_cost + output_cost
}
1388
1389/// Internal result type for provider runs
1390struct ProviderResult {
1391    answer: String,
1392    usage: Option<TokenUsage>,
1393}
1394
/// Lowest top-hit score that still counts as relevant; when the best hit
/// scores below this, the caller gets a "no relevant information" answer
/// instead of a model call.
///
/// Kept at 0.0 so that only clearly irrelevant queries (negative scores)
/// are blocked.
const RELEVANCE_THRESHOLD: f32 = 0.0;
1398
1399pub fn run_model_inference(
1400    requested_model: &str,
1401    question: &str,
1402    fallback_context: &str,
1403    hits: &[SearchHit],
1404    context_override: Option<usize>,
1405    api_key: Option<&str>,
1406    system_prompt_override: Option<&str>,
1407) -> Result<ModelInference, ModelRunError> {
1408    let Some(model_kind) = ModelKind::parse(requested_model) else {
1409        return Err(ModelRunError::UnsupportedModel(requested_model.to_string()));
1410    };
1411
1412    // Check if top hit score is below relevance threshold
1413    let top_score = hits.first().and_then(|h| h.score).unwrap_or(0.0);
1414    if hits.is_empty() || top_score < RELEVANCE_THRESHOLD {
1415        // Extract unique topics from available hits for suggestions
1416        let mut topics: Vec<String> = hits
1417            .iter()
1418            .take(5)
1419            .filter_map(|h| h.title.clone())
1420            .collect();
1421        topics.dedup();
1422
1423        let suggestions = if topics.is_empty() {
1424            "Try asking about the topics in your memory file.".to_string()
1425        } else {
1426            format!(
1427                "Your memory contains information about: {}. Try asking about these topics.",
1428                topics.join(", ")
1429            )
1430        };
1431
1432        let no_match_answer = format!(
1433            "No relevant information found for your question.\n\n{}\n\nRelevance score: {:.2} (threshold: {:.2})",
1434            suggestions, top_score, RELEVANCE_THRESHOLD
1435        );
1436
1437        return Ok(ModelInference {
1438            answer: ModelAnswer {
1439                requested: requested_model.to_string(),
1440                model: "none".to_string(),
1441                answer: no_match_answer,
1442            },
1443            context_body: String::new(),
1444            context_fragments: Vec::new(),
1445            usage: Some(TokenUsage {
1446                input_tokens: 0,
1447                output_tokens: 0,
1448                total_tokens: 0,
1449                cost_usd: 0.0,
1450            }),
1451            grounding: Some(GroundingResult {
1452                score: 0.0,
1453                sentence_count: 0,
1454                grounded_sentences: 0,
1455                sentence_scores: Vec::new(),
1456                has_warning: true,
1457                warning_reason: Some(
1458                    "No relevant information found - retrieval score below threshold".to_string(),
1459                ),
1460            }),
1461            cached: false,
1462        });
1463    }
1464
1465    let mut budget = model_kind.context_budget();
1466    if let Some(override_chars) = context_override {
1467        budget = budget.apply_override(override_chars);
1468    }
1469
1470    let context_plan = context::assemble_context(hits, fallback_context, &budget);
1471
1472    // Check cache first
1473    if let Some(cached) = cache::check_cache(question, &context_plan.body, &model_kind.label()) {
1474        let grounding = Some(GroundingResult {
1475            score: cached.grounding_score,
1476            sentence_count: 0,
1477            grounded_sentences: 0,
1478            sentence_scores: Vec::new(),
1479            has_warning: cached.grounding_score < 0.4,
1480            warning_reason: if cached.grounding_score < 0.4 {
1481                Some("Cached answer - original grounding was low".to_string())
1482            } else {
1483                None
1484            },
1485        });
1486
1487        let context_fragments = context_plan
1488            .records
1489            .into_iter()
1490            .map(ModelContextFragment::from_record)
1491            .collect();
1492
1493        return Ok(ModelInference {
1494            answer: ModelAnswer {
1495                requested: requested_model.to_string(),
1496                model: cached.model.clone(),
1497                answer: cached.answer.clone(),
1498            },
1499            context_body: context_plan.body,
1500            context_fragments,
1501            usage: Some(TokenUsage {
1502                input_tokens: cached.input_tokens,
1503                output_tokens: cached.output_tokens,
1504                total_tokens: cached.input_tokens + cached.output_tokens,
1505                cost_usd: 0.0, // Cached = no cost
1506            }),
1507            grounding,
1508            cached: true,
1509        });
1510    }
1511
1512    let prompt = build_prompt_parts(
1513        question,
1514        &context_plan.body,
1515        &budget,
1516        model_kind.max_output_tokens(),
1517    );
1518
1519    let result = match &model_kind {
1520        ModelKind::TinyLlama => {
1521            #[cfg(feature = "llama-cpp")]
1522            {
1523                ProviderResult {
1524                    answer: tinyllama::run(&prompt)?,
1525                    usage: None, // Local models don't track tokens
1526                }
1527            }
1528            #[cfg(not(feature = "llama-cpp"))]
1529            {
1530                return Err(ModelRunError::UnsupportedModel(
1531                    "tinyllama (llama-cpp feature not enabled)".to_string(),
1532                ));
1533            }
1534        }
1535        ModelKind::Ghost { pack_path } => {
1536            return Err(ModelRunError::UnsupportedModel(format!(
1537                "ghost model '{}' (ghost runtime not yet available)",
1538                pack_path.display()
1539            )));
1540        }
1541        ModelKind::Ollama { model } => ProviderResult {
1542            answer: ollama::run(model, &prompt)?,
1543            usage: None, // Ollama doesn't always return usage
1544        },
1545        ModelKind::OpenAi { model } => {
1546            openai::run(model, &prompt, api_key, system_prompt_override)?
1547        }
1548        ModelKind::Nvidia { model } => ProviderResult {
1549            answer: nvidia::run(model, &prompt, api_key, system_prompt_override)?,
1550            usage: None, // NVIDIA NIM doesn't consistently return usage
1551        },
1552        ModelKind::Gemini { model } => {
1553            gemini::run(model, &prompt, api_key, system_prompt_override)?
1554        }
1555        ModelKind::Claude { model } => {
1556            claude::run(model, &prompt, api_key, system_prompt_override)?
1557        }
1558        ModelKind::Xai { model } => xai::run(model, &prompt, api_key, system_prompt_override)?,
1559        ModelKind::Groq { model } => groq::run(model, &prompt, api_key, system_prompt_override)?,
1560        ModelKind::Mistral { model } => {
1561            mistral::run(model, &prompt, api_key, system_prompt_override)?
1562        }
1563    };
1564
1565    let context::ContextAggregation {
1566        body: context_body,
1567        records,
1568    } = context_plan;
1569    let context_fragments = records
1570        .into_iter()
1571        .map(ModelContextFragment::from_record)
1572        .collect();
1573
1574    // Verify grounding of the answer against the context
1575    let grounding = Some(verify_grounding(&result.answer, &context_body));
1576
1577    // Store in cache for future use
1578    let grounding_score = grounding.as_ref().map(|g| g.score).unwrap_or(0.5);
1579    let (input_tokens, output_tokens, cost_usd) = result
1580        .usage
1581        .as_ref()
1582        .map(|u| (u.input_tokens, u.output_tokens, u.cost_usd))
1583        .unwrap_or((0, 0, 0.0));
1584
1585    cache::store_in_cache(
1586        question,
1587        &context_body,
1588        &model_kind.label(),
1589        cache::CacheEntry {
1590            answer: result.answer.clone(),
1591            model: model_kind.label(),
1592            input_tokens,
1593            output_tokens,
1594            cost_usd,
1595            grounding_score,
1596            created_at: std::time::SystemTime::now(),
1597        },
1598    );
1599
1600    // Apply post-processing to clean up the answer
1601    let processed_answer = postprocess_answer(&result.answer);
1602
1603    Ok(ModelInference {
1604        answer: ModelAnswer {
1605            requested: requested_model.to_string(),
1606            model: model_kind.label(),
1607            answer: processed_answer,
1608        },
1609        context_body,
1610        context_fragments,
1611        usage: result.usage,
1612        grounding,
1613        cached: false,
1614    })
1615}
1616
1617/// Verify how well the answer is grounded in the provided context.
1618/// Returns a GroundingResult with a score (0.0 to 1.0) indicating
1619/// how well the answer is supported by the context.
1620pub fn verify_grounding(answer: &str, context: &str) -> GroundingResult {
1621    use std::collections::HashSet;
1622
1623    if answer.is_empty() {
1624        return GroundingResult {
1625            score: 1.0, // Empty answer = no hallucination
1626            sentence_count: 0,
1627            grounded_sentences: 0,
1628            sentence_scores: Vec::new(),
1629            has_warning: false,
1630            warning_reason: None,
1631        };
1632    }
1633
1634    if context.is_empty() {
1635        return GroundingResult {
1636            score: 0.0,
1637            sentence_count: 1,
1638            grounded_sentences: 0,
1639            sentence_scores: vec![0.0],
1640            has_warning: true,
1641            warning_reason: Some("No context provided - answer may be hallucinated".to_string()),
1642        };
1643    }
1644
1645    // Normalize context for comparison
1646    let context_lower = context.to_lowercase();
1647    let context_words: HashSet<&str> = context_lower
1648        .split(|c: char| !c.is_alphanumeric())
1649        .filter(|w| w.len() > 2)
1650        .collect();
1651
1652    // Split answer into sentences
1653    let sentences: Vec<&str> = answer
1654        .split(|c| c == '.' || c == '!' || c == '?')
1655        .map(|s| s.trim())
1656        .filter(|s| !s.is_empty() && s.len() > 10)
1657        .collect();
1658
1659    if sentences.is_empty() {
1660        return GroundingResult {
1661            score: 0.5, // Can't verify
1662            sentence_count: 0,
1663            grounded_sentences: 0,
1664            sentence_scores: Vec::new(),
1665            has_warning: false,
1666            warning_reason: None,
1667        };
1668    }
1669
1670    let mut sentence_scores = Vec::with_capacity(sentences.len());
1671    let mut grounded_count = 0;
1672
1673    for sentence in &sentences {
1674        let sentence_lower = sentence.to_lowercase();
1675        let sentence_words: HashSet<&str> = sentence_lower
1676            .split(|c: char| !c.is_alphanumeric())
1677            .filter(|w| w.len() > 2)
1678            .collect();
1679
1680        if sentence_words.is_empty() {
1681            sentence_scores.push(0.5);
1682            continue;
1683        }
1684
1685        // Calculate word overlap
1686        let overlap: usize = sentence_words.intersection(&context_words).count();
1687        let score = (overlap as f32) / (sentence_words.len() as f32).max(1.0);
1688
1689        // Also check for exact phrase matches (stronger signal)
1690        let phrase_bonus = if context_lower.contains(&sentence_lower) {
1691            0.3
1692        } else {
1693            // Check for significant substring matches
1694            let words: Vec<&str> = sentence_lower.split_whitespace().collect();
1695            if words.len() >= 3 {
1696                let phrase = words[..3.min(words.len())].join(" ");
1697                if context_lower.contains(&phrase) {
1698                    0.15
1699                } else {
1700                    0.0
1701                }
1702            } else {
1703                0.0
1704            }
1705        };
1706
1707        let final_score = (score + phrase_bonus).min(1.0);
1708        sentence_scores.push(final_score);
1709
1710        if final_score >= 0.3 {
1711            grounded_count += 1;
1712        }
1713    }
1714
1715    let overall_score = if sentence_scores.is_empty() {
1716        0.5
1717    } else {
1718        sentence_scores.iter().sum::<f32>() / sentence_scores.len() as f32
1719    };
1720
1721    // Determine warning
1722    let (has_warning, warning_reason) = if overall_score < 0.2 {
1723        (
1724            true,
1725            Some("Answer appears to be poorly grounded in context".to_string()),
1726        )
1727    } else if overall_score < 0.4 && grounded_count < sentences.len() / 2 {
1728        (
1729            true,
1730            Some("Some statements may not be supported by context".to_string()),
1731        )
1732    } else {
1733        (false, None)
1734    };
1735
1736    GroundingResult {
1737        score: overall_score,
1738        sentence_count: sentences.len(),
1739        grounded_sentences: grounded_count,
1740        sentence_scores,
1741        has_warning,
1742        warning_reason,
1743    }
1744}
1745
1746mod context {
1747    use super::{ModelContextBudget, clamp_to};
1748    use memvid_core::types::SearchHit;
1749
    /// Heading written at the top of every context body built from search hits.
    const CONTEXT_HEADER: &str = "## Retrieval Context\n";
    /// Heading that introduces the first (primary) hit.
    const PRIMARY_HEADER: &str = "### Primary Hit\n";
    /// Banner inserted ahead of the primary hit's text when its URI contains
    /// `mv2://correction/`.
    const CORRECTION_WARNING: &str = r#"
╔══════════════════════════════════════════════════════════════════╗
║  🔴 USER CORRECTION - THIS IS THE AUTHORITATIVE ANSWER          ║
║  Any contradicting information below is OUTDATED and WRONG.     ║
║  YOU MUST USE THE ANSWER FROM THIS CORRECTION.                  ║
╚══════════════════════════════════════════════════════════════════╝
"#;
    /// Heading that introduces later hits included with their full rendering.
    const SUPPORT_HEADER: &str = "### Supporting Hits\n";
    /// Heading that introduces later hits included only in summarised form.
    const SUMMARY_HEADER: &str = "### Overflow Summaries\n";
    /// Length limit handed to `trim_highlight` when rendering a summary snippet.
    const SUMMARY_HIGHLIGHT_CHARS: usize = 240;
    /// Minimum chars for a micro-summary when budget is tight (title + rank info)
    // NOTE(review): nothing in this module reads this constant, which is why the
    // `dead_code` lint is silenced; `render_micro_summary` does not enforce it.
    #[allow(dead_code)]
    const MICRO_SUMMARY_CHARS: usize = 80;
1765
1766    #[derive(Debug, Clone)]
1767    pub(super) struct ContextAggregation {
1768        pub body: String,
1769        pub records: Vec<ContextRecord>,
1770    }
1771
1772    impl ContextAggregation {
1773        fn from_fallback(fallback: &str, limit: usize) -> Self {
1774            let body = if limit == 0 || fallback.is_empty() {
1775                String::new()
1776            } else if fallback.len() <= limit {
1777                fallback.to_string()
1778            } else {
1779                clamp_to(fallback, limit)
1780            };
1781            Self {
1782                body,
1783                records: Vec::new(),
1784            }
1785        }
1786    }
1787
    /// One search hit as it is carried through context assembly.
    ///
    /// Every field except `text` and `mode` is copied straight from the
    /// originating `SearchHit` by `build_record`.
    #[derive(Debug, Clone)]
    pub(super) struct ContextRecord {
        /// Rank reported by the hit; rendered as `#<rank>` in source labels.
        pub rank: usize,
        /// URI of the hit; checked for the `mv2://correction/` marker during assembly.
        pub uri: String,
        /// Optional title of the hit.
        pub title: Option<String>,
        /// Optional relevance score of the hit.
        pub score: Option<f32>,
        /// Match count copied from the hit.
        pub matches: usize,
        /// Frame identifier copied from the hit.
        pub frame_id: u64,
        /// Range copied from the hit (presumably offsets into the frame — TODO confirm).
        pub range: (usize, usize),
        /// Optional chunk range copied from the hit.
        pub chunk_range: Option<(usize, usize)>,
        /// Rendered text for this hit, as it is written into the context body.
        pub text: String,
        /// Whether `text` is the full rendering or a summarised one.
        pub mode: ContextMode,
    }
1801
    /// Fidelity at which a hit was rendered into the context.
    #[derive(Debug, Clone, Copy, Eq, PartialEq)]
    pub(super) enum ContextMode {
        /// Text produced by `render_full`.
        Full,
        /// Text produced by `render_summary` or `render_micro_summary`.
        Summary,
    }
1807
    /// Result of budgeting: which records go into which section of the context.
    #[derive(Debug, Clone)]
    pub(super) struct ContextAssemblyPlan {
        /// Record for the first hit, or `None` when none of its renderings fit the budget.
        primary: Option<ContextRecord>,
        /// Later hits whose full rendering fit the remaining budget.
        supporting: Vec<ContextRecord>,
        /// Later hits that only fit as a summary or micro-summary.
        summaries: Vec<ContextRecord>,
    }
1814
1815    pub(super) fn assemble_context(
1816        hits: &[SearchHit],
1817        fallback: &str,
1818        budget: &ModelContextBudget,
1819    ) -> ContextAggregation {
1820        if hits.is_empty() {
1821            return ContextAggregation::from_fallback(fallback, budget.context_chars());
1822        }
1823
1824        let plan = assemble_plan(hits, budget.context_chars());
1825        let mut body = String::new();
1826        let mut records = Vec::new();
1827
1828        body.push_str(CONTEXT_HEADER);
1829        // Check if primary is a correction BEFORE moving it
1830        let primary_is_correction = plan
1831            .primary
1832            .as_ref()
1833            .map(|p| p.uri.contains("mv2://correction/"))
1834            .unwrap_or(false);
1835        if let Some(primary) = plan.primary {
1836            body.push_str(PRIMARY_HEADER);
1837            // Add correction warning if primary hit is a correction
1838            if primary_is_correction {
1839                body.push_str(CORRECTION_WARNING);
1840            }
1841            body.push_str(&primary.text);
1842            body.push_str("\n\n");
1843            records.push(primary);
1844        }
1845
1846        if !plan.supporting.is_empty() {
1847            body.push_str(SUPPORT_HEADER);
1848            if primary_is_correction {
1849                body.push_str("⚠️ **WARNING: The following sources may contain OUTDATED information. Use the correction above.**\n\n");
1850            }
1851            for record in plan.supporting {
1852                // If primary is a correction, skip older corrections in supporting hits
1853                // to avoid confusing the LLM with conflicting correction data
1854                if primary_is_correction && record.uri.contains("mv2://correction/") {
1855                    continue;
1856                }
1857                body.push_str(&record.text);
1858                body.push_str("\n\n");
1859                records.push(record);
1860            }
1861        }
1862
1863        if !plan.summaries.is_empty() {
1864            body.push_str(SUMMARY_HEADER);
1865            for record in plan.summaries {
1866                body.push_str(&record.text);
1867                body.push_str("\n\n");
1868                records.push(record);
1869            }
1870        }
1871
1872        ContextAggregation { body, records }
1873    }
1874
1875    fn assemble_plan(hits: &[SearchHit], mut remaining_chars: usize) -> ContextAssemblyPlan {
1876        let mut records = Vec::new();
1877        for hit in hits.iter().take(32) {
1878            let full_record = build_record(hit, render_full(hit), ContextMode::Full);
1879            let summary_record = build_record(hit, render_summary(hit), ContextMode::Summary);
1880            let micro_record = build_record(hit, render_micro_summary(hit), ContextMode::Summary);
1881            records.push((full_record, summary_record, micro_record));
1882        }
1883
1884        let mut plan = ContextAssemblyPlan {
1885            primary: None,
1886            supporting: Vec::new(),
1887            summaries: Vec::new(),
1888        };
1889
1890        // Handle primary hit (rank #1) - always include at least a summary
1891        if let Some((primary_full, primary_summary, primary_micro)) = records.first() {
1892            if primary_full.text.len() <= remaining_chars {
1893                remaining_chars = remaining_chars.saturating_sub(primary_full.text.len());
1894                plan.primary = Some(primary_full.clone());
1895            } else if primary_summary.text.len() <= remaining_chars {
1896                // Primary as summary (unusual but possible with very tight budget)
1897                remaining_chars = remaining_chars.saturating_sub(primary_summary.text.len());
1898                plan.primary = Some(primary_summary.clone());
1899            } else if primary_micro.text.len() <= remaining_chars {
1900                // At minimum, include micro-summary for primary
1901                remaining_chars = remaining_chars.saturating_sub(primary_micro.text.len());
1902                plan.primary = Some(primary_micro.clone());
1903            }
1904        }
1905
1906        // Process remaining hits with fallback to micro-summaries
1907        for (idx, (full, summary, micro)) in records.iter().enumerate() {
1908            if idx == 0 {
1909                continue;
1910            }
1911
1912            if full.text.len() <= remaining_chars {
1913                remaining_chars = remaining_chars.saturating_sub(full.text.len());
1914                plan.supporting.push(full.clone());
1915            } else if summary.text.len() <= remaining_chars {
1916                remaining_chars = remaining_chars.saturating_sub(summary.text.len());
1917                plan.summaries.push(summary.clone());
1918            } else if micro.text.len() <= remaining_chars {
1919                // Fallback: include at least a micro-summary to preserve ranking info
1920                // This ensures high-ranked hits are never completely dropped
1921                remaining_chars = remaining_chars.saturating_sub(micro.text.len());
1922                plan.summaries.push(micro.clone());
1923            }
1924            // If even micro-summary doesn't fit, the budget is truly exhausted
1925        }
1926
1927        plan
1928    }
1929
1930    fn render_full(hit: &SearchHit) -> String {
1931        let content = hit
1932            .chunk_text
1933            .clone()
1934            .or_else(|| Some(hit.text.clone()))
1935            .unwrap_or_default();
1936
1937        // Clean up the content for better LLM comprehension
1938        let clean_content = clean_text_for_llm(&content);
1939
1940        // Use a cleaner format that LLMs parse better
1941        let title = hit.title.clone().unwrap_or_default();
1942        let source_info = if title.is_empty() {
1943            format!("[Source #{}]", hit.rank)
1944        } else {
1945            format!("[Source #{}: {}]", hit.rank, title)
1946        };
1947
1948        // Include relevance indicator for context
1949        let relevance = match hit.score {
1950            Some(s) if s > 0.8 => "⬤ High relevance",
1951            Some(s) if s > 0.5 => "◐ Medium relevance",
1952            _ => "",
1953        };
1954
1955        if relevance.is_empty() {
1956            format!("{}\n{}", source_info, clean_content)
1957        } else {
1958            format!("{} ({})\n{}", source_info, relevance, clean_content)
1959        }
1960    }
1961
1962    fn render_summary(hit: &SearchHit) -> String {
1963        let snippet = hit
1964            .chunk_text
1965            .clone()
1966            .or_else(|| Some(hit.text.clone()))
1967            .unwrap_or_default();
1968        let snippet = trim_highlight(&snippet, SUMMARY_HIGHLIGHT_CHARS);
1969        let clean_snippet = clean_text_for_llm(&snippet);
1970        format!("[Source #{}] {}", hit.rank, clean_snippet)
1971    }
1972
1973    /// Create a minimal summary when budget is very tight.
1974    /// Always fits within MICRO_SUMMARY_CHARS to ensure important hits are never dropped.
1975    fn render_micro_summary(hit: &SearchHit) -> String {
1976        let title = hit.title.clone().unwrap_or_else(|| "untitled".to_string());
1977        let title_truncated = clamp_to(&title, 40);
1978        // Format: "[#2: document.pdf] ..." - ~50-60 chars max
1979        format!("[#{}: {}] ...", hit.rank, title_truncated)
1980    }
1981
1982    /// Clean text to improve LLM comprehension
1983    fn clean_text_for_llm(text: &str) -> String {
1984        let mut result = text.to_string();
1985
1986        // Remove excessive whitespace while preserving paragraph structure
1987        result = result
1988            .lines()
1989            .map(|line| line.trim())
1990            .filter(|line| !line.is_empty())
1991            .collect::<Vec<_>>()
1992            .join("\n");
1993
1994        // Normalize unicode quotes and dashes (using string patterns for multi-byte chars)
1995        result = result
1996            .replace("\u{2018}", "'") // Left single quote '
1997            .replace("\u{2019}", "'") // Right single quote '
1998            .replace("\u{201C}", "\"") // Left double quote "
1999            .replace("\u{201D}", "\"") // Right double quote "
2000            .replace("\u{2013}", "-") // En dash –
2001            .replace("\u{2014}", "-"); // Em dash —
2002
2003        // Remove null bytes and other control characters
2004        result = result
2005            .chars()
2006            .filter(|c| !c.is_control() || *c == '\n' || *c == '\t')
2007            .collect();
2008
2009        result
2010    }
2011
2012    fn trim_highlight(text: &str, limit: usize) -> String {
2013        let clean = text.replace('\n', " ");
2014        clamp_to(&clean, limit)
2015    }
2016
2017    fn build_record(hit: &SearchHit, text: String, mode: ContextMode) -> ContextRecord {
2018        ContextRecord {
2019            rank: hit.rank,
2020            uri: hit.uri.clone(),
2021            title: hit.title.clone(),
2022            score: hit.score,
2023            matches: hit.matches,
2024            frame_id: hit.frame_id,
2025            range: hit.range,
2026            chunk_range: hit.chunk_range,
2027            text,
2028            mode,
2029        }
2030    }
2031}
2032
2033#[cfg(feature = "llama-cpp")]
2034mod tinyllama {
2035    use super::{ModelRunError, PromptParts, TINYLLAMA_LABEL, ThinkingSpinner};
2036    use anyhow::anyhow;
2037    use llama_cpp::standard_sampler::StandardSampler;
2038    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
2039    use tokio::runtime::Builder;
2040
2041    use std::path::{Path, PathBuf};
2042
2043    const MODEL_DIR: &str = "models/tinyllama";
2044    const GGUF_HINT: &str = "*.gguf";
2045
2046    pub(super) fn run(prompt: &PromptParts) -> Result<String, ModelRunError> {
2047        let base_dir = Path::new(MODEL_DIR);
2048        let assets = RequiredAssets::new(base_dir);
2049
2050        if let Some(missing) = assets.missing_paths() {
2051            return Err(ModelRunError::AssetsMissing {
2052                model: TINYLLAMA_LABEL.to_string(),
2053                missing,
2054            });
2055        }
2056
2057        let gguf_path = assets.gguf_path.clone().ok_or_else(|| {
2058            ModelRunError::Runtime(anyhow!(
2059                "no GGUF model file found in {}",
2060                base_dir.display()
2061            ))
2062        })?;
2063
2064        unsafe {
2065            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
2066            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
2067        }
2068
2069        let model =
2070            LlamaModel::load_from_file(&gguf_path, LlamaParams::default()).map_err(|err| {
2071                ModelRunError::Runtime(anyhow!(
2072                    "failed to load TinyLlama weights from {}: {err}",
2073                    gguf_path.display()
2074                ))
2075            })?;
2076
2077        let mut session_params = SessionParams::default();
2078        if session_params.n_ctx == 0 {
2079            session_params.n_ctx = 2048;
2080        }
2081        session_params.n_batch = session_params.n_ctx.min(512);
2082        if session_params.n_ubatch == 0 {
2083            session_params.n_ubatch = 512;
2084        }
2085        let max_tokens = session_params.n_ctx as usize;
2086        let mut session = model.create_session(session_params).map_err(|err| {
2087            ModelRunError::Runtime(anyhow!("failed to create TinyLlama session: {err}"))
2088        })?;
2089
2090        let mut priming_tokens = model
2091            .tokenize_bytes(prompt.completion_prompt().as_bytes(), true, true)
2092            .map_err(|err| {
2093                ModelRunError::Runtime(anyhow!("failed to tokenize TinyLlama prompt: {err}"))
2094            })?;
2095
2096        let requested_tokens = prompt.max_output_tokens();
2097        if max_tokens > 0 {
2098            let reserved = requested_tokens + 64;
2099            if priming_tokens.len() >= max_tokens.saturating_sub(reserved) {
2100                let target = max_tokens.saturating_sub(reserved).max(1);
2101                let tail_start = priming_tokens.len().saturating_sub(target);
2102                priming_tokens = priming_tokens.split_off(tail_start);
2103            }
2104        }
2105
2106        session
2107            .advance_context_with_tokens(&priming_tokens)
2108            .map_err(|err| {
2109                ModelRunError::Runtime(anyhow!("failed to prime TinyLlama context: {err}"))
2110            })?;
2111
2112        let handle = session
2113            .start_completing_with(StandardSampler::default(), requested_tokens)
2114            .map_err(|err| ModelRunError::Runtime(anyhow!("completion failed to start: {err}")))?;
2115
2116        let runtime = Builder::new_current_thread()
2117            .enable_all()
2118            .build()
2119            .map_err(|err| {
2120                ModelRunError::Runtime(anyhow!("failed to build tokio runtime: {err}"))
2121            })?;
2122
2123        let mut spinner = ThinkingSpinner::start();
2124        let generated = runtime.block_on(async { handle.into_string_async().await });
2125        spinner.stop();
2126
2127        let answer = generated.trim().to_string();
2128
2129        if answer.is_empty() {
2130            Ok("No answer generated by TinyLlama.".to_string())
2131        } else {
2132            Ok(answer)
2133        }
2134    }
2135
2136    struct RequiredAssets {
2137        gguf_path: Option<PathBuf>,
2138        base_dir: PathBuf,
2139    }
2140
2141    impl RequiredAssets {
2142        fn new(base_dir: &Path) -> Self {
2143            let gguf_path = find_first_gguf(base_dir);
2144            Self {
2145                gguf_path,
2146                base_dir: base_dir.to_path_buf(),
2147            }
2148        }
2149
2150        fn missing_paths(&self) -> Option<Vec<PathBuf>> {
2151            if self.gguf_path.is_some() {
2152                None
2153            } else {
2154                Some(vec![self.base_dir.join(GGUF_HINT)])
2155            }
2156        }
2157    }
2158
2159    fn find_first_gguf(base_dir: &Path) -> Option<PathBuf> {
2160        let mut entries: Vec<PathBuf> = std::fs::read_dir(base_dir)
2161            .ok()?
2162            .filter_map(|entry| entry.ok().map(|e| e.path()))
2163            .filter(|path| path.is_file() && path.extension().map_or(false, |ext| ext == "gguf"))
2164            .collect();
2165        entries.sort();
2166        entries.into_iter().next()
2167    }
2168}
2169
2170mod ollama {
2171    use super::{ModelRunError, PromptParts, ThinkingSpinner};
2172    use anyhow::anyhow;
2173    use reqwest::blocking::Client;
2174    use serde::Deserialize;
2175    use serde_json::json;
2176
2177    const ENDPOINT: &str = "http://127.0.0.1:11434/api/generate";
2178
2179    pub(super) fn run(model: &str, prompt: &PromptParts) -> Result<String, ModelRunError> {
2180        let client = Client::builder()
2181            .timeout(std::time::Duration::from_secs(60))
2182            .build()
2183            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2184
2185        let mut spinner = ThinkingSpinner::start();
2186        let response = client
2187            .post(ENDPOINT)
2188            .json(&json!({
2189                "model": model,
2190                "prompt": prompt.completion_prompt(),
2191                "stream": false
2192            }))
2193            .send()
2194            .map_err(|err| ModelRunError::Runtime(anyhow!("ollama request failed: {err}")))?
2195            .error_for_status()
2196            .map_err(|err| {
2197                ModelRunError::Runtime(anyhow!("ollama returned error status: {err}"))
2198            })?;
2199
2200        let body: GenerateResponse = response.json().map_err(|err| {
2201            ModelRunError::Runtime(anyhow!("failed to decode ollama response: {err}"))
2202        })?;
2203        spinner.stop();
2204
2205        let text = body.response.trim().to_string();
2206        if text.is_empty() {
2207            Ok("No answer returned by Ollama.".to_string())
2208        } else {
2209            Ok(text)
2210        }
2211    }
2212
    /// The part of Ollama's generate reply that `run` reads.
    #[derive(Debug, Deserialize)]
    struct GenerateResponse {
        /// Generated text; deserializes to an empty string when the field is absent.
        #[serde(default)]
        response: String,
    }
2218}
2219
2220mod openai {
2221    use super::{
2222        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2223        calculate_cost,
2224    };
2225    use anyhow::anyhow;
2226    use reqwest::blocking::Client;
2227    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2228    use serde::Deserialize;
2229    use serde_json::json;
2230
2231    const CHAT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
2232    const RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses";
2233
2234    pub(super) fn run(
2235        model: &str,
2236        prompt: &PromptParts,
2237        override_key: Option<&str>,
2238        system_prompt_override: Option<&str>,
2239    ) -> Result<ProviderResult, ModelRunError> {
2240        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2241        let key = override_key
2242            .map(|value| value.to_string())
2243            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
2244            .ok_or_else(|| {
2245                ModelRunError::Runtime(anyhow!(
2246                    "OPENAI_API_KEY environment variable is required for OpenAI models"
2247                ))
2248            })?;
2249
2250        let mut headers = HeaderMap::new();
2251        headers.insert(
2252            AUTHORIZATION,
2253            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2254                ModelRunError::Runtime(anyhow!("invalid OPENAI_API_KEY header value: {err}"))
2255            })?,
2256        );
2257        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2258
2259        let client = Client::builder()
2260            .no_proxy()
2261            .timeout(std::time::Duration::from_secs(60))
2262            .default_headers(headers)
2263            .build()
2264            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2265
2266        let mut spinner = ThinkingSpinner::start();
2267        let (text, usage) = if requires_responses_api(model) {
2268            let combined_prompt = format!(
2269                "System instructions:\n{}\n\nUser query:\n{}",
2270                system_prompt,
2271                prompt.user_message()
2272            );
2273            let payload = json!({
2274                "model": model,
2275                "input": combined_prompt,
2276                "max_output_tokens": prompt.max_output_tokens() as u32,
2277                "reasoning": {
2278                    "effort": "low"
2279                }
2280            });
2281
2282            let response = client
2283                .post(RESPONSES_ENDPOINT)
2284                .json(&payload)
2285                .send()
2286                .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
2287
2288            let status = response.status();
2289            if !status.is_success() {
2290                let body = response
2291                    .text()
2292                    .unwrap_or_else(|_| "<failed to read body>".to_string());
2293                return Err(ModelRunError::Runtime(anyhow!(
2294                    "OpenAI returned error status {status}: {body}"
2295                )));
2296            }
2297
2298            let body: ResponsesResponse = response.json().map_err(|err| {
2299                ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
2300            })?;
2301
2302            let usage = body.usage.as_ref().map(|u| {
2303                let input = u.input_tokens.unwrap_or(0);
2304                let output = u.output_tokens.unwrap_or(0);
2305                TokenUsage {
2306                    input_tokens: input,
2307                    output_tokens: output,
2308                    total_tokens: input + output,
2309                    cost_usd: calculate_cost(model, input, output),
2310                }
2311            });
2312            (extract_responses_text(&body), usage)
2313        } else {
2314            let payload = json!({
2315                "model": model,
2316                "messages": [
2317                    {"role": "system", "content": system_prompt},
2318                    {"role": "user", "content": prompt.user_message()}
2319                ],
2320                "temperature": 0.2,
2321                "max_tokens": prompt.max_output_tokens() as u32
2322            });
2323
2324            let response = client
2325                .post(CHAT_ENDPOINT)
2326                .json(&payload)
2327                .send()
2328                .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
2329
2330            let status = response.status();
2331            if !status.is_success() {
2332                let body = response
2333                    .text()
2334                    .unwrap_or_else(|_| "<failed to read body>".to_string());
2335                return Err(ModelRunError::Runtime(anyhow!(
2336                    "OpenAI returned error status {status}: {body}"
2337                )));
2338            }
2339
2340            let body: ChatResponse = response.json().map_err(|err| {
2341                ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
2342            })?;
2343
2344            let usage = body.usage.as_ref().map(|u| TokenUsage {
2345                input_tokens: u.prompt_tokens,
2346                output_tokens: u.completion_tokens,
2347                total_tokens: u.total_tokens,
2348                cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
2349            });
2350            (extract_chat_text(&body), usage)
2351        };
2352        spinner.stop();
2353        Ok(ProviderResult {
2354            answer: text,
2355            usage,
2356        })
2357    }
2358
    /// The part of a Chat Completions reply that this module reads.
    #[derive(Debug, Deserialize)]
    struct ChatResponse {
        /// Candidate completions; the answer text is taken from these.
        choices: Vec<Choice>,
        /// Token accounting; `None` when the reply carries no usage object.
        #[serde(default)]
        usage: Option<ChatUsage>,
    }

    /// Token counts reported with a Chat Completions reply.
    #[derive(Debug, Deserialize)]
    struct ChatUsage {
        /// Reported as `input_tokens` in `TokenUsage`.
        prompt_tokens: u32,
        /// Reported as `output_tokens` in `TokenUsage`.
        completion_tokens: u32,
        /// Total as reported by the API (not recomputed locally).
        total_tokens: u32,
    }

    /// A single candidate completion.
    #[derive(Debug, Deserialize)]
    struct Choice {
        /// Message produced for this candidate.
        message: ChatMessage,
    }

    /// Message payload of a candidate completion.
    #[derive(Debug, Deserialize)]
    struct ChatMessage {
        /// Message text; `None` when the field is absent or null.
        #[serde(default)]
        content: Option<String>,
    }
2383
    /// The part of a Responses API reply that this module reads.
    /// Every field falls back to its default when absent.
    #[derive(Debug, Deserialize)]
    struct ResponsesResponse {
        /// Output items; searched for text when `output_text` yields nothing.
        #[serde(default)]
        output: Vec<ResponseItem>,
        /// Top-level text entries, consulted first by `extract_responses_text`.
        // NOTE(review): modelled as a list of strings — confirm against the wire format.
        #[serde(default)]
        output_text: Vec<String>,
        /// Token accounting, if the reply carries it.
        #[serde(default)]
        usage: Option<ResponsesUsage>,
    }

    /// Token counts reported with a Responses API reply; missing counts are
    /// treated as zero by the caller.
    #[derive(Debug, Deserialize)]
    struct ResponsesUsage {
        /// Tokens consumed by the input.
        #[serde(default)]
        input_tokens: Option<u32>,
        /// Tokens produced as output.
        #[serde(default)]
        output_tokens: Option<u32>,
    }

    /// One entry of the `output` array.
    #[derive(Debug, Deserialize)]
    struct ResponseItem {
        /// Content segments of this item; empty when the field is absent.
        #[serde(default)]
        content: Vec<ResponseContent>,
    }

    /// One content segment inside an output item.
    #[derive(Debug, Deserialize)]
    struct ResponseContent {
        /// Segment type (JSON key `type`); only `output_text` and `text` are read.
        #[serde(rename = "type")]
        kind: String,
        /// Segment text, if any.
        #[serde(default)]
        text: Option<String>,
    }
2415
2416    fn extract_chat_text(body: &ChatResponse) -> String {
2417        body.choices
2418            .iter()
2419            .find_map(|choice| choice.message.content.clone())
2420            .map(|value| value.trim().to_string())
2421            .unwrap_or_else(|| "No answer returned by OpenAI.".to_string())
2422    }
2423
2424    fn extract_responses_text(body: &ResponsesResponse) -> String {
2425        if !body.output_text.is_empty() {
2426            let text = body
2427                .output_text
2428                .iter()
2429                .find(|value| !value.trim().is_empty());
2430            if let Some(text) = text {
2431                return text.trim().to_string();
2432            }
2433        }
2434        for item in &body.output {
2435            for segment in &item.content {
2436                match segment.kind.as_str() {
2437                    "output_text" | "text" => {
2438                        if let Some(text) = &segment.text {
2439                            let trimmed = text.trim();
2440                            if !trimmed.is_empty() {
2441                                return trimmed.to_string();
2442                            }
2443                        }
2444                    }
2445                    _ => {}
2446                }
2447            }
2448        }
2449        "No answer returned by OpenAI.".to_string()
2450    }
2451
2452    fn requires_responses_api(model: &str) -> bool {
2453        let lowered = model.to_ascii_lowercase();
2454        lowered.starts_with("gpt-5") || lowered.contains("gpt-4.1")
2455    }
2456}
2457
2458mod nvidia {
2459    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
2460    use anyhow::anyhow;
2461    use reqwest::blocking::Client;
2462    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2463    use serde::Deserialize;
2464    use serde_json::json;
2465
2466    pub(super) fn run(
2467        model: &str,
2468        prompt: &PromptParts,
2469        override_key: Option<&str>,
2470        system_prompt_override: Option<&str>,
2471    ) -> Result<String, ModelRunError> {
2472        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2473        let key = override_key
2474            .map(|value| value.to_string())
2475            .or_else(|| std::env::var("NVIDIA_API_KEY").ok())
2476            .ok_or_else(|| {
2477                ModelRunError::Runtime(anyhow!(
2478                    "NVIDIA_API_KEY environment variable is required for NVIDIA models"
2479                ))
2480            })?;
2481
2482        let model = model.trim();
2483        if model.is_empty() {
2484            return Err(ModelRunError::Runtime(anyhow!(
2485                "NVIDIA model name required. Use `nvidia:<model>` or set NVIDIA_LLM_MODEL."
2486            )));
2487        }
2488
2489        let base_url = std::env::var("NVIDIA_BASE_URL")
2490            .ok()
2491            .map(|value| value.trim().trim_end_matches('/').to_string())
2492            .filter(|value| !value.is_empty())
2493            .unwrap_or_else(|| "https://integrate.api.nvidia.com".to_string());
2494        let endpoint = format!("{base_url}/v1/chat/completions");
2495
2496        let mut headers = HeaderMap::new();
2497        headers.insert(
2498            AUTHORIZATION,
2499            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2500                ModelRunError::Runtime(anyhow!("invalid NVIDIA_API_KEY header value: {err}"))
2501            })?,
2502        );
2503        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2504
2505        let client = Client::builder()
2506            .timeout(std::time::Duration::from_secs(60))
2507            .default_headers(headers)
2508            .build()
2509            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2510
2511        let payload = json!({
2512            "model": model,
2513            "messages": [
2514                {"role": "system", "content": system_prompt},
2515                {"role": "user", "content": prompt.user_message()}
2516            ],
2517            "temperature": 0.2,
2518            "max_tokens": prompt.max_output_tokens() as u32
2519        });
2520
2521        let mut spinner = ThinkingSpinner::start();
2522        let response = client
2523            .post(endpoint)
2524            .json(&payload)
2525            .send()
2526            .map_err(|err| ModelRunError::Runtime(anyhow!("NVIDIA request failed: {err}")))?;
2527
2528        let status = response.status();
2529        if !status.is_success() {
2530            let body = response
2531                .text()
2532                .unwrap_or_else(|_| "<failed to read body>".to_string());
2533            spinner.stop();
2534            return Err(ModelRunError::Runtime(anyhow!(
2535                "NVIDIA returned error status {status}: {body}"
2536            )));
2537        }
2538
2539        let body: ChatResponse = response.json().map_err(|err| {
2540            ModelRunError::Runtime(anyhow!("failed to decode NVIDIA response: {err}"))
2541        })?;
2542        spinner.stop();
2543
2544        let text = body
2545            .choices
2546            .into_iter()
2547            .find_map(|choice| choice.message.content)
2548            .map(|value| value.trim().to_string())
2549            .unwrap_or_else(|| "No answer returned by NVIDIA.".to_string());
2550
2551        Ok(text)
2552    }
2553
    /// Top-level chat-completions response body; only the field read by `run` is modelled.
    #[derive(Debug, Deserialize)]
    struct ChatResponse {
        /// Completion choices returned by the API; `run` uses the first one that has content.
        choices: Vec<Choice>,
    }
2558
    /// One entry of the `choices` array in a chat-completions response.
    #[derive(Debug, Deserialize)]
    struct Choice {
        /// The message generated for this choice.
        message: ChatMessage,
    }
2563
    /// Message payload of a choice.
    #[derive(Debug, Deserialize)]
    struct ChatMessage {
        /// Generated text; deserializes to `None` when the field is absent or `null`.
        #[serde(default)]
        content: Option<String>,
    }
2569}
2570
2571mod gemini {
2572    use super::{
2573        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2574        calculate_cost,
2575    };
2576    use anyhow::anyhow;
2577    use reqwest::blocking::Client;
2578    use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
2579    use serde::Deserialize;
2580    use serde_json::json;
2581
2582    pub(super) fn run(
2583        model: &str,
2584        prompt: &PromptParts,
2585        override_key: Option<&str>,
2586        system_prompt_override: Option<&str>,
2587    ) -> Result<ProviderResult, ModelRunError> {
2588        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2589        let key = override_key
2590            .map(|value| value.to_string())
2591            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
2592            .ok_or_else(|| {
2593                ModelRunError::Runtime(anyhow!(
2594                    "GEMINI_API_KEY environment variable is required for Gemini models"
2595                ))
2596            })?;
2597
2598        let url = format!(
2599            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
2600            model
2601        );
2602
2603        let mut headers = HeaderMap::new();
2604        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2605        headers.insert(
2606            HeaderName::from_static("x-goog-api-key"),
2607            HeaderValue::from_str(&key).map_err(|err| {
2608                ModelRunError::Runtime(anyhow!("invalid GEMINI_API_KEY header value: {err}"))
2609            })?,
2610        );
2611
2612        let client = Client::builder()
2613            .timeout(std::time::Duration::from_secs(60))
2614            .default_headers(headers)
2615            .build()
2616            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2617
2618        let payload = json!({
2619            "contents": [{
2620                "parts": [
2621                    { "text": system_prompt },
2622                    { "text": prompt.user_message() }
2623                ]
2624            }],
2625            "generationConfig": {
2626                "temperature": 0.2,
2627                "maxOutputTokens": prompt.max_output_tokens() as u32,
2628                "topK": 40,
2629                "topP": 0.95
2630            }
2631        });
2632
2633        let mut spinner = ThinkingSpinner::start();
2634        let response = client
2635            .post(url)
2636            .json(&payload)
2637            .send()
2638            .map_err(|err| ModelRunError::Runtime(anyhow!("Gemini request failed: {err}")))?
2639            .error_for_status()
2640            .map_err(|err| {
2641                ModelRunError::Runtime(anyhow!("Gemini returned error status: {err}"))
2642            })?;
2643
2644        let body: GenerateResponse = response.json().map_err(|err| {
2645            ModelRunError::Runtime(anyhow!("failed to decode Gemini response: {err}"))
2646        })?;
2647        spinner.stop();
2648
2649        let text = body
2650            .candidates
2651            .iter()
2652            .flat_map(|candidate| candidate.content.parts.iter())
2653            .find_map(|part| part.text.clone())
2654            .map(|value| value.trim().to_string())
2655            .unwrap_or_else(|| "No answer returned by Gemini.".to_string());
2656
2657        let usage = body.usage_metadata.as_ref().map(|u| {
2658            let input = u.prompt_token_count.unwrap_or(0);
2659            let output = u.candidates_token_count.unwrap_or(0);
2660            TokenUsage {
2661                input_tokens: input,
2662                output_tokens: output,
2663                total_tokens: input + output,
2664                cost_usd: calculate_cost(model, input, output),
2665            }
2666        });
2667
2668        Ok(ProviderResult {
2669            answer: text,
2670            usage,
2671        })
2672    }
2673
2674    #[derive(Debug, Deserialize)]
2675    struct GenerateResponse {
2676        candidates: Vec<Candidate>,
2677        #[serde(default, rename = "usageMetadata")]
2678        usage_metadata: Option<GeminiUsage>,
2679    }
2680
2681    #[derive(Debug, Deserialize)]
2682    struct GeminiUsage {
2683        #[serde(default, rename = "promptTokenCount")]
2684        prompt_token_count: Option<u32>,
2685        #[serde(default, rename = "candidatesTokenCount")]
2686        candidates_token_count: Option<u32>,
2687    }
2688
2689    #[derive(Debug, Deserialize)]
2690    struct Candidate {
2691        content: CandidateContent,
2692    }
2693
2694    #[derive(Debug, Deserialize)]
2695    struct CandidateContent {
2696        parts: Vec<CandidatePart>,
2697    }
2698
2699    #[derive(Debug, Deserialize)]
2700    struct CandidatePart {
2701        #[serde(default)]
2702        text: Option<String>,
2703    }
2704}
2705
2706mod claude {
2707    use super::{
2708        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2709        calculate_cost,
2710    };
2711    use anyhow::anyhow;
2712    use reqwest::blocking::Client;
2713    use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
2714    use serde::Deserialize;
2715    use serde_json::json;
2716
2717    const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
2718    const API_VERSION: &str = "2023-06-01";
2719
2720    pub(super) fn run(
2721        model: &str,
2722        prompt: &PromptParts,
2723        override_key: Option<&str>,
2724        system_prompt_override: Option<&str>,
2725    ) -> Result<ProviderResult, ModelRunError> {
2726        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2727        let key = override_key
2728            .map(|value| value.to_string())
2729            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
2730            .or_else(|| std::env::var("CLAUDE_API_KEY").ok())
2731            .ok_or_else(|| {
2732                ModelRunError::Runtime(anyhow!(
2733                    "ANTHROPIC_API_KEY environment variable is required for Claude models"
2734                ))
2735            })?;
2736
2737        let mut headers = HeaderMap::new();
2738        headers.insert(
2739            HeaderName::from_static("x-api-key"),
2740            HeaderValue::from_str(&key).map_err(|err| {
2741                ModelRunError::Runtime(anyhow!("invalid ANTHROPIC_API_KEY header value: {err}"))
2742            })?,
2743        );
2744        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2745        headers.insert(
2746            HeaderName::from_static("anthropic-version"),
2747            HeaderValue::from_static(API_VERSION),
2748        );
2749
2750        let client = Client::builder()
2751            .timeout(std::time::Duration::from_secs(60))
2752            .default_headers(headers)
2753            .build()
2754            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2755
2756        let payload = json!({
2757            "model": model,
2758            "max_tokens": prompt.max_output_tokens() as u32,
2759            "temperature": 0.2,
2760            "system": system_prompt,
2761            "messages": [{
2762                "role": "user",
2763                "content": [{"type": "text", "text": prompt.user_message()}]
2764            }]
2765        });
2766
2767        let mut spinner = ThinkingSpinner::start();
2768        let response = client
2769            .post(ENDPOINT)
2770            .json(&payload)
2771            .send()
2772            .map_err(|err| ModelRunError::Runtime(anyhow!("Claude request failed: {err}")))?
2773            .error_for_status()
2774            .map_err(|err| {
2775                ModelRunError::Runtime(anyhow!("Claude returned error status: {err}"))
2776            })?;
2777
2778        let body: ClaudeResponse = response.json().map_err(|err| {
2779            ModelRunError::Runtime(anyhow!("failed to decode Claude response: {err}"))
2780        })?;
2781        spinner.stop();
2782
2783        let text = body
2784            .content
2785            .iter()
2786            .find_map(|part| match part {
2787                ContentBlock::Text { text } if !text.trim().is_empty() => {
2788                    Some(text.trim().to_string())
2789                }
2790                _ => None,
2791            })
2792            .unwrap_or_else(|| "No answer returned by Claude.".to_string());
2793
2794        let usage = body.usage.as_ref().map(|u| TokenUsage {
2795            input_tokens: u.input_tokens,
2796            output_tokens: u.output_tokens,
2797            total_tokens: u.input_tokens + u.output_tokens,
2798            cost_usd: calculate_cost(model, u.input_tokens, u.output_tokens),
2799        });
2800
2801        Ok(ProviderResult {
2802            answer: text,
2803            usage,
2804        })
2805    }
2806
2807    #[derive(Debug, Deserialize)]
2808    struct ClaudeResponse {
2809        #[serde(default)]
2810        content: Vec<ContentBlock>,
2811        #[serde(default)]
2812        usage: Option<ClaudeUsage>,
2813    }
2814
2815    #[derive(Debug, Deserialize)]
2816    struct ClaudeUsage {
2817        input_tokens: u32,
2818        output_tokens: u32,
2819    }
2820
2821    #[derive(Debug, Deserialize)]
2822    #[serde(tag = "type", rename_all = "lowercase")]
2823    enum ContentBlock {
2824        Text {
2825            text: String,
2826        },
2827        #[serde(other)]
2828        Other,
2829    }
2830}
2831
2832mod xai {
2833    use super::{
2834        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2835        calculate_cost,
2836    };
2837    use anyhow::anyhow;
2838    use reqwest::blocking::Client;
2839    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2840    use serde::Deserialize;
2841    use serde_json::json;
2842
2843    const ENDPOINT: &str = "https://api.x.ai/v1/chat/completions";
2844
2845    pub(super) fn run(
2846        model: &str,
2847        prompt: &PromptParts,
2848        override_key: Option<&str>,
2849        system_prompt_override: Option<&str>,
2850    ) -> Result<ProviderResult, ModelRunError> {
2851        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2852        let key = override_key
2853            .map(|value| value.to_string())
2854            .or_else(|| std::env::var("XAI_API_KEY").ok())
2855            .or_else(|| std::env::var("GROK_API_KEY").ok())
2856            .ok_or_else(|| {
2857                ModelRunError::Runtime(anyhow!(
2858                    "XAI_API_KEY environment variable is required for Grok models"
2859                ))
2860            })?;
2861
2862        let mut headers = HeaderMap::new();
2863        headers.insert(
2864            AUTHORIZATION,
2865            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2866                ModelRunError::Runtime(anyhow!("invalid XAI_API_KEY header value: {err}"))
2867            })?,
2868        );
2869        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2870
2871        let client = Client::builder()
2872            .timeout(std::time::Duration::from_secs(120))
2873            .default_headers(headers)
2874            .build()
2875            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2876
2877        let payload = json!({
2878            "model": model,
2879            "max_tokens": prompt.max_output_tokens() as u32,
2880            "temperature": 0.2,
2881            "messages": [
2882                {"role": "system", "content": system_prompt},
2883                {"role": "user", "content": prompt.user_message()}
2884            ]
2885        });
2886
2887        let mut spinner = ThinkingSpinner::start();
2888        let response = client
2889            .post(ENDPOINT)
2890            .json(&payload)
2891            .send()
2892            .map_err(|err| ModelRunError::Runtime(anyhow!("xAI request failed: {err}")))?
2893            .error_for_status()
2894            .map_err(|err| ModelRunError::Runtime(anyhow!("xAI returned error status: {err}")))?;
2895
2896        let body: XaiResponse = response.json().map_err(|err| {
2897            ModelRunError::Runtime(anyhow!("failed to decode xAI response: {err}"))
2898        })?;
2899        spinner.stop();
2900
2901        let text = body
2902            .choices
2903            .first()
2904            .and_then(|c| c.message.content.as_ref())
2905            .map(|s| s.trim().to_string())
2906            .unwrap_or_else(|| "No answer returned by Grok.".to_string());
2907
2908        let usage = body.usage.as_ref().map(|u| TokenUsage {
2909            input_tokens: u.prompt_tokens,
2910            output_tokens: u.completion_tokens,
2911            total_tokens: u
2912                .total_tokens
2913                .unwrap_or(u.prompt_tokens + u.completion_tokens),
2914            cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
2915        });
2916
2917        Ok(ProviderResult {
2918            answer: text,
2919            usage,
2920        })
2921    }
2922
2923    #[derive(Debug, Deserialize)]
2924    struct XaiResponse {
2925        #[serde(default)]
2926        choices: Vec<XaiChoice>,
2927        #[serde(default)]
2928        usage: Option<XaiUsage>,
2929    }
2930
2931    #[derive(Debug, Deserialize)]
2932    struct XaiChoice {
2933        message: XaiMessage,
2934    }
2935
2936    #[derive(Debug, Deserialize)]
2937    struct XaiMessage {
2938        content: Option<String>,
2939    }
2940
2941    #[derive(Debug, Deserialize)]
2942    struct XaiUsage {
2943        prompt_tokens: u32,
2944        completion_tokens: u32,
2945        total_tokens: Option<u32>,
2946    }
2947}
2948
2949mod groq {
2950    use super::{
2951        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
2952        calculate_cost,
2953    };
2954    use anyhow::anyhow;
2955    use reqwest::blocking::Client;
2956    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
2957    use serde::Deserialize;
2958    use serde_json::json;
2959
2960    const ENDPOINT: &str = "https://api.groq.com/openai/v1/chat/completions";
2961
2962    pub(super) fn run(
2963        model: &str,
2964        prompt: &PromptParts,
2965        override_key: Option<&str>,
2966        system_prompt_override: Option<&str>,
2967    ) -> Result<ProviderResult, ModelRunError> {
2968        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
2969        let key = override_key
2970            .map(|value| value.to_string())
2971            .or_else(|| std::env::var("GROQ_API_KEY").ok())
2972            .ok_or_else(|| {
2973                ModelRunError::Runtime(anyhow!(
2974                    "GROQ_API_KEY environment variable is required for Groq models"
2975                ))
2976            })?;
2977
2978        let mut headers = HeaderMap::new();
2979        headers.insert(
2980            AUTHORIZATION,
2981            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
2982                ModelRunError::Runtime(anyhow!("invalid GROQ_API_KEY header value: {err}"))
2983            })?,
2984        );
2985        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
2986
2987        let client = Client::builder()
2988            .timeout(std::time::Duration::from_secs(60))
2989            .default_headers(headers)
2990            .build()
2991            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
2992
2993        let payload = json!({
2994            "model": model,
2995            "max_tokens": prompt.max_output_tokens() as u32,
2996            "temperature": 0.2,
2997            "messages": [
2998                {"role": "system", "content": system_prompt},
2999                {"role": "user", "content": prompt.user_message()}
3000            ]
3001        });
3002
3003        let mut spinner = ThinkingSpinner::start();
3004        let response = client
3005            .post(ENDPOINT)
3006            .json(&payload)
3007            .send()
3008            .map_err(|err| ModelRunError::Runtime(anyhow!("Groq request failed: {err}")))?
3009            .error_for_status()
3010            .map_err(|err| ModelRunError::Runtime(anyhow!("Groq returned error status: {err}")))?;
3011
3012        let body: GroqResponse = response.json().map_err(|err| {
3013            ModelRunError::Runtime(anyhow!("failed to decode Groq response: {err}"))
3014        })?;
3015        spinner.stop();
3016
3017        let text = body
3018            .choices
3019            .first()
3020            .and_then(|c| c.message.content.as_ref())
3021            .map(|s| s.trim().to_string())
3022            .unwrap_or_else(|| "No answer returned by Groq.".to_string());
3023
3024        let usage = body.usage.as_ref().map(|u| TokenUsage {
3025            input_tokens: u.prompt_tokens,
3026            output_tokens: u.completion_tokens,
3027            total_tokens: u
3028                .total_tokens
3029                .unwrap_or(u.prompt_tokens + u.completion_tokens),
3030            cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
3031        });
3032
3033        Ok(ProviderResult {
3034            answer: text,
3035            usage,
3036        })
3037    }
3038
3039    #[derive(Debug, Deserialize)]
3040    struct GroqResponse {
3041        #[serde(default)]
3042        choices: Vec<GroqChoice>,
3043        #[serde(default)]
3044        usage: Option<GroqUsage>,
3045    }
3046
3047    #[derive(Debug, Deserialize)]
3048    struct GroqChoice {
3049        message: GroqMessage,
3050    }
3051
3052    #[derive(Debug, Deserialize)]
3053    struct GroqMessage {
3054        content: Option<String>,
3055    }
3056
3057    #[derive(Debug, Deserialize)]
3058    struct GroqUsage {
3059        prompt_tokens: u32,
3060        completion_tokens: u32,
3061        total_tokens: Option<u32>,
3062    }
3063}
3064
3065mod mistral {
3066    use super::{
3067        ModelRunError, PromptParts, ProviderResult, SYSTEM_PROMPT, ThinkingSpinner, TokenUsage,
3068        calculate_cost,
3069    };
3070    use anyhow::anyhow;
3071    use reqwest::blocking::Client;
3072    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
3073    use serde::Deserialize;
3074    use serde_json::json;
3075
3076    const ENDPOINT: &str = "https://api.mistral.ai/v1/chat/completions";
3077
3078    pub(super) fn run(
3079        model: &str,
3080        prompt: &PromptParts,
3081        override_key: Option<&str>,
3082        system_prompt_override: Option<&str>,
3083    ) -> Result<ProviderResult, ModelRunError> {
3084        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
3085        let key = override_key
3086            .map(|value| value.to_string())
3087            .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
3088            .ok_or_else(|| {
3089                ModelRunError::Runtime(anyhow!(
3090                    "MISTRAL_API_KEY environment variable is required for Mistral models"
3091                ))
3092            })?;
3093
3094        let mut headers = HeaderMap::new();
3095        headers.insert(
3096            AUTHORIZATION,
3097            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
3098                ModelRunError::Runtime(anyhow!("invalid MISTRAL_API_KEY header value: {err}"))
3099            })?,
3100        );
3101        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
3102
3103        let client = Client::builder()
3104            .timeout(std::time::Duration::from_secs(60))
3105            .default_headers(headers)
3106            .build()
3107            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
3108
3109        let payload = json!({
3110            "model": model,
3111            "max_tokens": prompt.max_output_tokens() as u32,
3112            "temperature": 0.2,
3113            "messages": [
3114                {"role": "system", "content": system_prompt},
3115                {"role": "user", "content": prompt.user_message()}
3116            ]
3117        });
3118
3119        let mut spinner = ThinkingSpinner::start();
3120        let response = client
3121            .post(ENDPOINT)
3122            .json(&payload)
3123            .send()
3124            .map_err(|err| ModelRunError::Runtime(anyhow!("Mistral request failed: {err}")))?
3125            .error_for_status()
3126            .map_err(|err| {
3127                ModelRunError::Runtime(anyhow!("Mistral returned error status: {err}"))
3128            })?;
3129
3130        let body: MistralResponse = response.json().map_err(|err| {
3131            ModelRunError::Runtime(anyhow!("failed to decode Mistral response: {err}"))
3132        })?;
3133        spinner.stop();
3134
3135        let text = body
3136            .choices
3137            .first()
3138            .and_then(|c| c.message.content.as_ref())
3139            .map(|s| s.trim().to_string())
3140            .unwrap_or_else(|| "No answer returned by Mistral.".to_string());
3141
3142        let usage = body.usage.as_ref().map(|u| TokenUsage {
3143            input_tokens: u.prompt_tokens,
3144            output_tokens: u.completion_tokens,
3145            total_tokens: u
3146                .total_tokens
3147                .unwrap_or(u.prompt_tokens + u.completion_tokens),
3148            cost_usd: calculate_cost(model, u.prompt_tokens, u.completion_tokens),
3149        });
3150
3151        Ok(ProviderResult {
3152            answer: text,
3153            usage,
3154        })
3155    }
3156
3157    #[derive(Debug, Deserialize)]
3158    struct MistralResponse {
3159        #[serde(default)]
3160        choices: Vec<MistralChoice>,
3161        #[serde(default)]
3162        usage: Option<MistralUsage>,
3163    }
3164
3165    #[derive(Debug, Deserialize)]
3166    struct MistralChoice {
3167        message: MistralMessage,
3168    }
3169
3170    #[derive(Debug, Deserialize)]
3171    struct MistralMessage {
3172        content: Option<String>,
3173    }
3174
3175    #[derive(Debug, Deserialize)]
3176    struct MistralUsage {
3177        prompt_tokens: u32,
3178        completion_tokens: u32,
3179        total_tokens: Option<u32>,
3180    }
3181}
3182
3183// ============================================================================
3184// Entity Extraction API
3185// ============================================================================
3186
/// Default system prompt for entity extraction.
///
/// Instructs the model to reply with a JSON object holding an `entities` array whose elements
/// carry `name`, `type` and `confidence`. `extract_entities` uses it whenever the caller does
/// not pass its own system prompt.
pub const ENTITY_EXTRACTION_PROMPT: &str = r#"Extract named entities from the provided text. Return a JSON object with an "entities" array.

Each entity should have:
- "name": The entity name as it appears in the text
- "type": One of "PERSON", "ORG", "LOCATION", "DATE", "PRODUCT", "EVENT", or "OTHER"
- "confidence": A number between 0.0 and 1.0 indicating your confidence

Guidelines:
1. Only include entities you're confident about (confidence >= 0.7)
2. Preserve the original capitalization of entity names
3. For organizations, include full names (e.g., "S&P Global" not just "S&P")
4. For people, include full names when available
5. Deduplicate: if an entity appears multiple times, include it only once

Return format:
{"entities": [{"name": "...", "type": "...", "confidence": 0.9}, ...]}"#;
3204
/// A single named entity extracted from text.
///
/// Mirrors one element of the `entities` array requested by [`ENTITY_EXTRACTION_PROMPT`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity name as reported by the model.
    pub name: String,
    /// Entity category; (de)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Model-reported confidence; the default prompt asks for a value between 0.0 and 1.0.
    pub confidence: f32,
}
3213
/// Response from entity extraction, as returned by `extract_entities`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EntityExtractionResponse {
    /// Entities parsed from the model's JSON reply.
    pub entities: Vec<ExtractedEntity>,
    /// The model identifier exactly as the caller passed it (including any provider prefix).
    pub model: String,
    /// Size of the input text that was analysed, as measured by `extract_entities`.
    pub text_chars: usize,
}
3221
3222/// Extract entities from text using an LLM
3223///
3224/// # Arguments
3225/// * `model` - Model identifier (e.g., "openai:gpt-4o-mini", "claude:claude-3-5-sonnet")
3226/// * `text` - The text to extract entities from
3227/// * `system_prompt` - Optional custom system prompt (uses default if None)
3228/// * `api_key` - Optional API key (uses environment variable if None)
3229///
3230/// # Returns
3231/// An `EntityExtractionResponse` with the extracted entities
3232///
3233/// # Example
3234/// ```ignore
3235/// let response = extract_entities(
3236///     "openai:gpt-4o-mini",
3237///     "John Smith met with Microsoft CEO Satya Nadella in Seattle.",
3238///     None,  // use default prompt
3239///     None,  // use OPENAI_API_KEY env var
3240/// )?;
3241/// for entity in response.entities {
3242///     println!("{}: {} ({:.0}%)", entity.name, entity.entity_type, entity.confidence * 100.0);
3243/// }
3244/// ```
3245pub fn extract_entities(
3246    model: &str,
3247    text: &str,
3248    system_prompt: Option<&str>,
3249    api_key: Option<&str>,
3250) -> Result<EntityExtractionResponse, ModelRunError> {
3251    let prompt = system_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
3252    let text_chars = text.len();
3253
3254    // Determine model provider and make API call
3255    let (provider, model_name) = parse_model_spec(model);
3256
3257    let json_response = match provider.as_str() {
3258        "openai" => extract_entities_openai(&model_name, text, prompt, api_key)?,
3259        "claude" | "anthropic" => extract_entities_claude(&model_name, text, prompt, api_key)?,
3260        "gemini" | "google" => extract_entities_gemini(&model_name, text, prompt, api_key)?,
3261        _ => {
3262            return Err(ModelRunError::UnsupportedModel(format!(
3263                "Entity extraction not supported for provider '{}'. Use openai:, claude:, or gemini:",
3264                provider
3265            )));
3266        }
3267    };
3268
3269    // Parse the JSON response
3270    let entities = parse_entity_response(&json_response)?;
3271
3272    Ok(EntityExtractionResponse {
3273        entities,
3274        model: model.to_string(),
3275        text_chars,
3276    })
3277}
3278
/// Splits a `provider:model` spec into `(provider, model)`.
///
/// The split happens at the first `:` only, so model names that themselves contain a colon are
/// preserved. The provider is lower-cased; a spec without any `:` is attributed to `"openai"`.
fn parse_model_spec(model: &str) -> (String, String) {
    match model.split_once(':') {
        Some((prefix, rest)) => (prefix.to_lowercase(), rest.to_owned()),
        // No provider prefix: default to OpenAI.
        None => (String::from("openai"), model.to_owned()),
    }
}
3287
3288fn parse_entity_response(json_str: &str) -> Result<Vec<ExtractedEntity>, ModelRunError> {
3289    // Try to parse the response, handling various formats
3290    let trimmed = json_str.trim();
3291
3292    // Handle markdown code blocks
3293    let clean_json = if trimmed.starts_with("```json") {
3294        trimmed
3295            .strip_prefix("```json")
3296            .and_then(|s| s.strip_suffix("```"))
3297            .unwrap_or(trimmed)
3298            .trim()
3299    } else if trimmed.starts_with("```") {
3300        trimmed
3301            .strip_prefix("```")
3302            .and_then(|s| s.strip_suffix("```"))
3303            .unwrap_or(trimmed)
3304            .trim()
3305    } else {
3306        trimmed
3307    };
3308
3309    // Try parsing as {"entities": [...]}
3310    #[derive(serde::Deserialize)]
3311    struct EntityResponse {
3312        entities: Vec<ExtractedEntity>,
3313    }
3314
3315    if let Ok(response) = serde_json::from_str::<EntityResponse>(clean_json) {
3316        return Ok(response.entities);
3317    }
3318
3319    // Try parsing as a direct array [...]
3320    if let Ok(entities) = serde_json::from_str::<Vec<ExtractedEntity>>(clean_json) {
3321        return Ok(entities);
3322    }
3323
3324    Err(ModelRunError::Runtime(anyhow::anyhow!(
3325        "Failed to parse entity extraction response as JSON: {}",
3326        &clean_json[..clean_json.len().min(200)]
3327    )))
3328}
3329
3330fn extract_entities_openai(
3331    model: &str,
3332    text: &str,
3333    system_prompt: &str,
3334    api_key: Option<&str>,
3335) -> Result<String, ModelRunError> {
3336    use serde_json::json;
3337
3338    let api_key = api_key
3339        .map(|s| s.to_string())
3340        .or_else(|| std::env::var("OPENAI_API_KEY").ok())
3341        .ok_or_else(|| {
3342            ModelRunError::Runtime(anyhow::anyhow!(
3343                "OpenAI API key required. Set OPENAI_API_KEY or pass api_key parameter."
3344            ))
3345        })?;
3346
3347    let model_name = if model.is_empty() {
3348        "gpt-4o-mini"
3349    } else {
3350        model
3351    };
3352
3353    let client = reqwest::blocking::Client::builder()
3354        .no_proxy()
3355        .build()
3356        .map_err(|err| {
3357            ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3358        })?;
3359    let payload = json!({
3360        "model": model_name,
3361        "messages": [
3362            {"role": "system", "content": system_prompt},
3363            {"role": "user", "content": text}
3364        ],
3365        "response_format": {"type": "json_object"},
3366        "temperature": 0.1
3367    });
3368
3369    let response = client
3370        .post("https://api.openai.com/v1/chat/completions")
3371        .header("Authorization", format!("Bearer {}", api_key))
3372        .json(&payload)
3373        .send()
3374        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI request failed: {}", e)))?
3375        .error_for_status()
3376        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI returned error: {}", e)))?;
3377
3378    #[derive(serde::Deserialize)]
3379    struct OpenAIResponse {
3380        choices: Vec<OpenAIChoice>,
3381    }
3382    #[derive(serde::Deserialize)]
3383    struct OpenAIChoice {
3384        message: OpenAIMessage,
3385    }
3386    #[derive(serde::Deserialize)]
3387    struct OpenAIMessage {
3388        content: String,
3389    }
3390
3391    let body: OpenAIResponse = response.json().map_err(|e| {
3392        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse OpenAI response: {}", e))
3393    })?;
3394
3395    body.choices
3396        .into_iter()
3397        .next()
3398        .map(|c| c.message.content)
3399        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No response from OpenAI")))
3400}
3401
3402fn extract_entities_claude(
3403    model: &str,
3404    text: &str,
3405    system_prompt: &str,
3406    api_key: Option<&str>,
3407) -> Result<String, ModelRunError> {
3408    use serde_json::json;
3409
3410    let api_key = api_key
3411        .map(|s| s.to_string())
3412        .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
3413        .ok_or_else(|| {
3414            ModelRunError::Runtime(anyhow::anyhow!(
3415                "Anthropic API key required. Set ANTHROPIC_API_KEY or pass api_key parameter."
3416            ))
3417        })?;
3418
3419    let model_name = if model.is_empty() {
3420        "claude-3-5-sonnet-20241022"
3421    } else {
3422        model
3423    };
3424
3425    let client = reqwest::blocking::Client::builder()
3426        .no_proxy()
3427        .build()
3428        .map_err(|err| {
3429            ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3430        })?;
3431    let payload = json!({
3432        "model": model_name,
3433        "max_tokens": 4096,
3434        "system": format!("{}\n\nRespond with valid JSON only.", system_prompt),
3435        "messages": [
3436            {"role": "user", "content": text}
3437        ]
3438    });
3439
3440    let response = client
3441        .post("https://api.anthropic.com/v1/messages")
3442        .header("x-api-key", &api_key)
3443        .header("anthropic-version", "2023-06-01")
3444        .header("content-type", "application/json")
3445        .json(&payload)
3446        .send()
3447        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude request failed: {}", e)))?
3448        .error_for_status()
3449        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude returned error: {}", e)))?;
3450
3451    #[derive(serde::Deserialize)]
3452    struct ClaudeResponse {
3453        content: Vec<ClaudeContent>,
3454    }
3455    #[derive(serde::Deserialize)]
3456    struct ClaudeContent {
3457        text: Option<String>,
3458    }
3459
3460    let body: ClaudeResponse = response.json().map_err(|e| {
3461        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Claude response: {}", e))
3462    })?;
3463
3464    body.content
3465        .into_iter()
3466        .find_map(|c| c.text)
3467        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Claude")))
3468}
3469
3470fn extract_entities_gemini(
3471    model: &str,
3472    text: &str,
3473    system_prompt: &str,
3474    api_key: Option<&str>,
3475) -> Result<String, ModelRunError> {
3476    use serde_json::json;
3477
3478    let api_key = api_key
3479        .map(|s| s.to_string())
3480        .or_else(|| std::env::var("GEMINI_API_KEY").ok())
3481        .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
3482        .ok_or_else(|| {
3483            ModelRunError::Runtime(anyhow::anyhow!(
3484                "Gemini API key required. Set GEMINI_API_KEY or pass api_key parameter."
3485            ))
3486        })?;
3487
3488    let model_name = if model.is_empty() {
3489        "gemini-2.0-flash"
3490    } else {
3491        model
3492    };
3493    let url = format!(
3494        "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
3495        model_name, api_key
3496    );
3497
3498    let client = reqwest::blocking::Client::builder()
3499        .no_proxy()
3500        .build()
3501        .map_err(|err| {
3502            ModelRunError::Runtime(anyhow::anyhow!("failed to build HTTP client: {err}"))
3503        })?;
3504    let payload = json!({
3505        "contents": [{
3506            "parts": [{"text": format!("{}\n\nText to analyze:\n{}", system_prompt, text)}]
3507        }],
3508        "generationConfig": {
3509            "temperature": 0.1,
3510            "responseMimeType": "application/json"
3511        }
3512    });
3513
3514    let response = client
3515        .post(&url)
3516        .json(&payload)
3517        .send()
3518        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini request failed: {}", e)))?
3519        .error_for_status()
3520        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini returned error: {}", e)))?;
3521
3522    #[derive(serde::Deserialize)]
3523    struct GeminiResponse {
3524        candidates: Vec<GeminiCandidate>,
3525    }
3526    #[derive(serde::Deserialize)]
3527    struct GeminiCandidate {
3528        content: GeminiContent,
3529    }
3530    #[derive(serde::Deserialize)]
3531    struct GeminiContent {
3532        parts: Vec<GeminiPart>,
3533    }
3534    #[derive(serde::Deserialize)]
3535    struct GeminiPart {
3536        text: Option<String>,
3537    }
3538
3539    let body: GeminiResponse = response.json().map_err(|e| {
3540        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Gemini response: {}", e))
3541    })?;
3542
3543    body.candidates
3544        .into_iter()
3545        .next()
3546        .and_then(|c| c.content.parts.into_iter().find_map(|p| p.text))
3547        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Gemini")))
3548}
3549
#[cfg(test)]
mod tests {
    use super::*;

    /// Held while a test mutates process environment variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Checks the model name each provider normalizer returns for the inputs below.
    #[test]
    fn normalize_models() {
        assert_eq!(normalize_openai_model(None), "gpt-4o-mini");

        let explicit = String::from("meta/llama3-8b-instruct");
        assert_eq!(
            normalize_nvidia_model(Some(explicit)),
            "meta/llama3-8b-instruct"
        );

        // Clear the NVIDIA model variables before checking the no-argument
        // result (presumably the normalizer consults them - confirm against
        // its definition).
        let _env_guard = ENV_LOCK.lock().unwrap();
        for var in ["NVIDIA_LLM_MODEL", "NVIDIA_MODEL"] {
            unsafe {
                std::env::remove_var(var);
            }
        }

        assert_eq!(normalize_nvidia_model(None), "");
        assert_eq!(normalize_gemini_model(None), "gemini-2.5-flash");
        assert_eq!(normalize_claude_model(None), "claude-sonnet-4-5");
        assert_eq!(normalize_xai_model(None), "grok-4-fast");
        assert_eq!(normalize_groq_model(None), "llama-3.3-70b-versatile");
        assert_eq!(normalize_mistral_model(None), "mistral-large-latest");
    }

    /// A bare `{"entities": [...]}` object is parsed into its entity list.
    #[test]
    fn parse_entity_json() {
        let parsed = parse_entity_response(
            r#"{"entities": [{"name": "John", "type": "PERSON", "confidence": 0.95}]}"#,
        )
        .unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "John");
    }

    /// The same object wrapped in a Markdown code fence is still parsed.
    #[test]
    fn parse_entity_json_with_markdown() {
        let fenced = r#"```json
{"entities": [{"name": "Microsoft", "type": "ORG", "confidence": 0.99}]}
```"#;
        let parsed = parse_entity_response(fenced).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "Microsoft");
    }

    /// Model specs split into provider and model; a bare model name maps to `openai`.
    #[test]
    fn parse_model_spec_test() {
        let prefixed = parse_model_spec("openai:gpt-4o");
        assert_eq!(prefixed.0, "openai");
        assert_eq!(prefixed.1, "gpt-4o");

        let bare = parse_model_spec("gpt-4o-mini");
        assert_eq!(bare.0, "openai");
        assert_eq!(bare.1, "gpt-4o-mini");
    }

    /// A `ghost:` spec parses into `ModelKind::Ghost` carrying the pack path.
    #[test]
    fn modelkind_parses_ghost_spec() {
        match ModelKind::parse("ghost:/tmp/model.ghostpack").unwrap() {
            ModelKind::Ghost { pack_path } => {
                assert_eq!(pack_path, PathBuf::from("/tmp/model.ghostpack"));
            }
            other => panic!("expected ghost, got {other:?}"),
        }
    }

    /// Running inference against a `ghost:` model yields an error that mentions ghost.
    #[test]
    fn run_model_inference_ghost_returns_unsupported() {
        let hits = [SearchHit {
            rank: 0,
            frame_id: 0,
            uri: "mv2://test".to_string(),
            title: Some("Test".to_string()),
            range: (0, 3),
            text: "ctx".to_string(),
            matches: 1,
            chunk_range: None,
            chunk_text: None,
            score: Some(1.0),
            metadata: None,
        }];

        let outcome = run_model_inference(
            "ghost:/tmp/fake.ghostpack",
            "hello?",
            "",
            &hits,
            None,
            None,
            None,
        );
        let msg = outcome.unwrap_err().to_string();
        let mentions_ghost = ["ghost", "Ghost"]
            .into_iter()
            .any(|needle| msg.contains(needle));
        assert!(mentions_ghost, "error should mention ghost: {msg}");
    }

    /// Question-like input without trailing punctuation gains a `?`.
    #[test]
    fn normalize_question_adds_question_mark() {
        // Abbreviation expansion may also rewrite this input, so only the
        // trailing character is checked here.
        let expanded = normalize_question("how much is the LP rate");
        assert!(expanded.ends_with('?'), "should end with ?");

        let exact_cases = [
            ("what is the total revenue", "what is the total revenue?"),
            ("where does John live", "where does John live?"),
            ("is this correct", "is this correct?"),
            ("can you help me", "can you help me?"),
        ];
        for (input, expected) in exact_cases {
            assert_eq!(normalize_question(input), expected);
        }
    }

    /// Input that already ends in punctuation is returned unchanged.
    #[test]
    fn normalize_question_preserves_existing_punctuation() {
        let already_punctuated = [
            "how much is X?",
            "Tell me about the project.",
            "Do it now!",
        ];
        for input in already_punctuated {
            assert_eq!(normalize_question(input), input);
        }
    }

    /// Statements that are not questions get no `?` appended.
    #[test]
    fn normalize_question_ignores_non_questions() {
        let statements = [
            "revenue for Q1 2024",
            "total sales",
            // A question word appearing only as a prefix of a longer word
            // must not count as a question.
            "howitzer specifications",
        ];
        for input in statements {
            assert_eq!(normalize_question(input), input);
        }
    }

    /// Empty and whitespace-only input normalize to the empty string.
    #[test]
    fn normalize_question_handles_edge_cases() {
        for blank in ["", "  "] {
            assert_eq!(normalize_question(blank), "");
        }

        // Typo correction and expansion may alter the text, so only the
        // trailing character is checked.
        let padded = normalize_question("  how much  ");
        assert!(padded.ends_with('?'), "should end with ?");
    }

    /// Common misspellings are replaced by the intended word.
    #[test]
    fn fix_typos_corrects_common_errors() {
        let cases = [
            ("teh quick brown fox", "the"),
            ("waht is this", "what"),
            ("totla revenue", "total"),
        ];
        for (input, corrected_word) in cases {
            assert!(fix_common_typos(input).contains(corrected_word));
        }
    }

    /// The output contains either the expanded phrase or the original abbreviation.
    #[test]
    fn expand_abbreviations_works() {
        let expanded = expand_abbreviations("what is the irr");
        let acceptable = ["internal rate of return", "irr"];
        assert!(acceptable.into_iter().any(|needle| expanded.contains(needle)));
    }

    /// Each sample question maps to the expected `QuestionType`.
    #[test]
    fn question_type_detection() {
        let cases = [
            ("how much is X?", QuestionType::Numeric),
            ("is this correct?", QuestionType::YesNo),
            ("list all items", QuestionType::List),
            ("when was it created?", QuestionType::Temporal),
            ("why did this happen?", QuestionType::Explanation),
            ("what is the name?", QuestionType::Factual),
        ];
        for (question, expected) in cases {
            assert_eq!(detect_question_type(question), expected);
        }
    }

    /// A leading "Based on ..." preamble is stripped while the payload survives.
    #[test]
    fn postprocess_removes_artifacts() {
        let cleaned = postprocess_answer("Based on the provided context, the value is 42.");
        assert!(!cleaned.starts_with("Based on"));
        assert!(cleaned.contains("42"));
    }

    /// The first letter of the answer is capitalized.
    #[test]
    fn postprocess_capitalizes() {
        let cleaned = postprocess_answer("the answer is yes");
        assert!(cleaned.starts_with('T'), "should start with capital T");
    }

    /// Runs of spaces are collapsed so no double space remains.
    #[test]
    fn postprocess_normalizes_whitespace() {
        let cleaned = postprocess_answer("too    many     spaces    here");
        assert!(!cleaned.contains("  "), "should not have double spaces");
    }
}