memvid_ask_model/
lib.rs

1use std::fmt;
2use std::io::{IsTerminal, Write, stderr};
3use std::path::PathBuf;
4use std::sync::{
5    Arc,
6    atomic::{AtomicBool, Ordering},
7};
8use std::thread;
9use std::time::Duration;
10
11use memvid_core::types::SearchHit;
12
/// Answer text returned by a model, paired with the identifiers of that model.
#[derive(Clone, Debug)]
pub struct ModelAnswer {
    /// Model string exactly as the caller supplied it.
    pub requested: String,
    /// Resolved label of the model that handled the request.
    pub model: String,
    /// Generated answer text.
    pub answer: String,
}
19
20#[derive(Debug, Clone)]
21pub struct ModelInference {
22    pub answer: ModelAnswer,
23    pub context_body: String,
24    pub context_fragments: Vec<ModelContextFragment>,
25}
26
27#[derive(Debug, Clone)]
28pub struct ModelContextFragment {
29    pub rank: usize,
30    pub uri: String,
31    pub title: Option<String>,
32    pub score: Option<f32>,
33    pub matches: usize,
34    pub frame_id: u64,
35    pub range: (usize, usize),
36    pub chunk_range: Option<(usize, usize)>,
37    pub text: String,
38    pub kind: ModelContextFragmentKind,
39}
40
/// How a context fragment's text was rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelContextFragmentKind {
    /// The hit was rendered in full.
    Full,
    /// The hit was rendered as a shortened summary.
    Summary,
}
46
47impl ModelContextFragment {
48    fn from_record(record: context::ContextRecord) -> Self {
49        let kind = match record.mode {
50            context::ContextMode::Full => ModelContextFragmentKind::Full,
51            context::ContextMode::Summary => ModelContextFragmentKind::Summary,
52        };
53        Self {
54            rank: record.rank,
55            uri: record.uri,
56            title: record.title,
57            score: record.score,
58            matches: record.matches,
59            frame_id: record.frame_id,
60            range: record.range,
61            chunk_range: record.chunk_range,
62            text: record.text,
63            kind,
64        }
65    }
66}
67
68#[derive(Debug)]
69pub enum ModelRunError {
70    UnsupportedModel(String),
71    AssetsMissing {
72        model: String,
73        missing: Vec<PathBuf>,
74    },
75    Runtime(anyhow::Error),
76}
77
78impl fmt::Display for ModelRunError {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Self::UnsupportedModel(model) => write!(f, "unsupported model '{model}'"),
82            Self::AssetsMissing { model, missing } => {
83                let paths: Vec<_> = missing
84                    .iter()
85                    .map(|path| path.display().to_string())
86                    .collect();
87                write!(
88                    f,
89                    "model '{model}' missing required assets: {}",
90                    paths.join(", ")
91                )
92            }
93            Self::Runtime(err) => write!(f, "model runtime error: {err}"),
94        }
95    }
96}
97
98impl std::error::Error for ModelRunError {
99    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
100        match self {
101            Self::Runtime(err) => Some(err.root_cause()),
102            _ => None,
103        }
104    }
105}
106
// NOTE: every `*_CHARS` budget below is compared against `str::len`, i.e. UTF-8 byte lengths.

/// Label reported for the local TinyLlama model.
const TINYLLAMA_LABEL: &str = "tinyllama-1.1b";

/// Instructions placed in the system section of every prompt unless overridden.
const SYSTEM_PROMPT: &str = "You are a helpful assistant that answers questions based on the provided context.\n\nGuidelines:\n1. Read all the provided context carefully before answering.\n2. PAY CLOSE ATTENTION TO DATES: When context has timestamps, note that later dates reflect current/updated values.\n3. For temporal questions (e.g., \"How many X when started vs now?\"), find both earlier and later values. The most recent reflects the current state.\n4. For counting questions (e.g., \"How many times...\"), count occurrences in the context.\n5. Only say \"not enough information\" if the context truly contains no relevant information.\n6. Base all answers on the context - do not use external knowledge.";

/// Upper bound on the question length before it is truncated.
const MAX_QUESTION_CHARS: usize = 512;

/// Output-token cap for the local TinyLlama model.
const LOCAL_MAX_OUTPUT_TOKENS: usize = 256;
/// Output-token cap for every other model kind.
const REMOTE_MAX_OUTPUT_TOKENS: usize = 768;

/// Total prompt budget for the local TinyLlama model.
const LOCAL_CONTEXT_CHARS: usize = 32_768;
/// Portion of the TinyLlama budget reserved for non-context prompt text.
const LOCAL_PROMPT_MARGIN_CHARS: usize = 2_048;

/// Portion of every non-local budget reserved for non-context prompt text.
const REMOTE_PROMPT_MARGIN_CHARS: usize = 4_096;
/// Total prompt budget for Ollama models.
const OLLAMA_PROMPT_CHARS: usize = 110_000;
/// Total prompt budget for OpenAI models.
const OPENAI_PROMPT_CHARS: usize = 240_000;
/// Total prompt budget for Gemini models.
const GEMINI_PROMPT_CHARS: usize = 320_000;
/// Total prompt budget for Claude models.
const CLAUDE_PROMPT_CHARS: usize = 360_000;
119
120#[derive(Debug, Clone, Copy)]
121struct ModelContextBudget {
122    total_chars: usize,
123    reserved_chars: usize,
124}
125
126impl ModelContextBudget {
127    const fn new(total_chars: usize, reserved_chars: usize) -> Self {
128        Self {
129            total_chars,
130            reserved_chars,
131        }
132    }
133
134    fn context_chars(&self) -> usize {
135        self.total_chars.saturating_sub(self.reserved_chars)
136    }
137
138    fn question_limit(&self) -> usize {
139        MAX_QUESTION_CHARS
140            .min(self.reserved_chars.max(1))
141            .min(self.total_chars.max(1))
142    }
143
144    fn apply_override(self, override_context_chars: usize) -> Self {
145        let total = override_context_chars.saturating_add(self.reserved_chars);
146        Self {
147            total_chars: total.max(self.reserved_chars + 1),
148            reserved_chars: self.reserved_chars,
149        }
150    }
151
152    fn prompt_ceiling(&self) -> usize {
153        self.total_chars
154    }
155}
156
/// The rendered prompt in both of its forms, plus the output-token cap.
pub struct PromptParts {
    /// Single-string prompt containing system, context, question and answer stub.
    completion_prompt: String,
    /// Context and question only, without the system section.
    user_message: String,
    /// Maximum number of tokens the model should generate.
    max_output_tokens: usize,
}

impl PromptParts {
    /// Borrows the single-string completion prompt.
    pub fn completion_prompt(&self) -> &str {
        self.completion_prompt.as_str()
    }

    /// Borrows the user message (context plus question).
    pub fn user_message(&self) -> &str {
        self.user_message.as_str()
    }

    /// Returns the output-token cap.
    pub fn max_output_tokens(&self) -> usize {
        self.max_output_tokens
    }
}
176
177fn build_prompt_parts(
178    question: &str,
179    context: &str,
180    budget: &ModelContextBudget,
181    max_output_tokens: usize,
182) -> PromptParts {
183    let mut context_section = context.to_string();
184    let trimmed_question = trim_to(question, budget.question_limit());
185
186    let system_section = format!("### System\n{SYSTEM_PROMPT}");
187    let question_section = format!("### Question\n{trimmed_question}");
188    let answer_stub = "### Answer\n";
189
190    let overhead = system_section.len() + 2 + question_section.len() + 2 + answer_stub.len();
191    if budget.prompt_ceiling() > overhead {
192        let max_context_len = budget
193            .prompt_ceiling()
194            .saturating_sub(overhead)
195            .min(budget.context_chars());
196        if context_section.len() > max_context_len {
197            context_section = clamp_to(&context_section, max_context_len);
198        }
199    } else {
200        context_section = String::new();
201    }
202
203    let completion_prompt =
204        format!("{system_section}\n\n{context_section}\n\n{question_section}\n\n### Answer\n");
205
206    let user_message = format!(
207        "{context_section}\n\nQuestion:\n{trimmed_question}\n\nRespond concisely using only information from the retrieval context."
208    );
209
210    PromptParts {
211        completion_prompt,
212        user_message,
213        max_output_tokens,
214    }
215}
216
/// Returns `text` unchanged when it is at most `limit` bytes long; otherwise returns
/// its leading portion followed by `"..."`.
///
/// The cut is moved back to the nearest UTF-8 character boundary at or below `limit`,
/// so non-ASCII input can never cause a slicing panic. The ellipsis is appended on top
/// of the kept prefix, so a truncated result may be up to `limit + 3` bytes long.
fn trim_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }

    // `limit < text.len()` here, and index 0 is always a boundary, so this terminates.
    let mut end = limit;
    while !text.is_char_boundary(end) {
        end -= 1;
    }

    let mut truncated = String::with_capacity(end + 3);
    truncated.push_str(&text[..end]);
    truncated.push_str("...");
    truncated
}
226
/// Shortens `text` so the result is never longer than `limit` bytes.
///
/// * Text that already fits is returned unchanged.
/// * For `limit <= 3` the result is `limit` dots.
/// * Otherwise the result is a prefix of `text` followed by `"..."`, with the prefix
///   cut at the nearest UTF-8 character boundary at or below `limit - 3`, so non-ASCII
///   input can never cause a slicing panic.
fn clamp_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    if limit <= 3 {
        return ".".repeat(limit);
    }

    // `end < text.len()` here, and index 0 is always a boundary, so this terminates.
    let mut end = limit - 3;
    while !text.is_char_boundary(end) {
        end -= 1;
    }

    let mut truncated = String::with_capacity(end + 3);
    truncated.push_str(&text[..end]);
    truncated.push_str("...");
    truncated
}
239
/// Animated "Thinking" indicator written to stderr while a model call is in flight.
///
/// The animation only runs when stderr is an interactive terminal. When stderr is
/// redirected, no thread is spawned and the spinner is a no-op. Dropping the spinner
/// stops it.
struct ThinkingSpinner {
    /// `true` while the animation should keep running.
    flag: Arc<AtomicBool>,
    /// Animation thread, if one was started and has not been joined yet.
    handle: Option<thread::JoinHandle<()>>,
}

impl ThinkingSpinner {
    /// Animation frames, all of equal width so each one fully overwrites the last.
    const FRAMES: [&'static str; 6] = [
        "Thinking    ",
        "Thinking.   ",
        "Thinking..  ",
        "Thinking... ",
        "Thinking .. ",
        "Thinking  . ",
    ];

    /// Delay between two frames.
    const FRAME_INTERVAL: Duration = Duration::from_millis(200);

    /// Starts the spinner.
    ///
    /// Control characters are only emitted when stderr is a TTY, so redirected or
    /// combined output (e.g. `2>&1`) is never polluted. In that case no background
    /// thread is created at all.
    fn start() -> Self {
        let flag = Arc::new(AtomicBool::new(true));

        if !stderr().is_terminal() {
            return Self { flag, handle: None };
        }

        let thread_flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            let mut err = stderr();
            let mut idx: usize = 0;
            while thread_flag.load(Ordering::Relaxed) {
                let frame = Self::FRAMES[idx % Self::FRAMES.len()];
                let _ = write!(err, "\r{frame}");
                let _ = err.flush();
                idx = idx.wrapping_add(1);
                // Parks instead of sleeping so `stop` can wake the thread immediately.
                // An early wake-up merely redraws the next frame sooner.
                thread::park_timeout(Self::FRAME_INTERVAL);
            }
            // Erase the last frame and return the cursor to the start of the line.
            let _ = write!(err, "\r             \r");
            let _ = err.flush();
        });

        Self {
            flag,
            handle: Some(handle),
        }
    }

    /// Stops the animation and waits for the background thread (if any) to finish.
    ///
    /// Calling this more than once is harmless.
    fn stop(&mut self) {
        self.flag.store(false, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}

impl Drop for ThinkingSpinner {
    /// Ensures the animation thread is stopped on every exit path, including early `?` returns.
    fn drop(&mut self) {
        self.stop();
    }
}
304
/// Backend selected for a request, with the concrete model name where one applies.
#[derive(Clone, Debug)]
enum ModelKind {
    /// Local TinyLlama model.
    TinyLlama,
    /// Model served by Ollama.
    Ollama {
        /// Model name passed to Ollama.
        model: String,
    },
    /// OpenAI model.
    OpenAi {
        /// Model name passed to OpenAI.
        model: String,
    },
    /// Gemini model.
    Gemini {
        /// Model name passed to Gemini.
        model: String,
    },
    /// Claude model.
    Claude {
        /// Model name passed to Claude.
        model: String,
    },
}
313
314impl ModelKind {
315    fn parse(raw: &str) -> Option<Self> {
316        let trimmed = raw.trim();
317        if trimmed.is_empty() {
318            return None;
319        }
320
321        let (provider, explicit_model) = if let Some((p, rest)) = trimmed.split_once(':') {
322            let value = rest.trim();
323            let explicit = if value.is_empty() {
324                None
325            } else {
326                Some(value.to_string())
327            };
328            (p.trim().to_ascii_lowercase(), explicit)
329        } else {
330            (trimmed.to_ascii_lowercase(), None)
331        };
332
333        match provider.as_str() {
334            "tinyllama" | "tiny-llama" | "tinyllama-1.1b" => Some(Self::TinyLlama),
335            "ollama" => Some(Self::Ollama {
336                model: explicit_model.unwrap_or_else(|| "ollama1.5".to_string()),
337            }),
338            "ollama1.5" | "ollama1-5" => Some(Self::Ollama {
339                model: "ollama1.5".to_string(),
340            }),
341            "openai" => Some(Self::OpenAi {
342                model: normalize_openai_model(explicit_model),
343            }),
344            "gemini" => Some(Self::Gemini {
345                model: normalize_gemini_model(explicit_model),
346            }),
347            "claude" | "anthropic" => Some(Self::Claude {
348                model: normalize_claude_model(explicit_model),
349            }),
350            // Auto-detect provider from model name prefix
351            // For Ollama models with colons in the name (e.g., qwen2.5:1.5b),
352            // we need to use the full original name, not just the provider prefix
353            _ => Self::infer_from_model_name_full(trimmed, &provider),
354        }
355    }
356
357    /// Infer the provider from a model name, using the full original name for Ollama models.
358    /// This handles model names with colons like "qwen2.5:1.5b" by using the full name.
359    fn infer_from_model_name_full(full_name: &str, prefix: &str) -> Option<Self> {
360        let lowered = prefix.to_ascii_lowercase();
361
362        // Gemini models: gemini-*, models/gemini-*
363        if lowered.starts_with("gemini") || lowered.starts_with("models/gemini") {
364            return Some(Self::Gemini {
365                model: full_name.to_string(),
366            });
367        }
368
369        // OpenAI models: gpt-*, o1-*, chatgpt-*, text-davinci-*, etc.
370        if lowered.starts_with("gpt-")
371            || lowered.starts_with("o1-")
372            || lowered.starts_with("o3-")
373            || lowered.starts_with("chatgpt-")
374            || lowered.starts_with("text-")
375        {
376            return Some(Self::OpenAi {
377                model: full_name.to_string(),
378            });
379        }
380
381        // Claude/Anthropic models: claude-*
382        if lowered.starts_with("claude-") {
383            return Some(Self::Claude {
384                model: full_name.to_string(),
385            });
386        }
387
388        // Ollama models: llama*, mistral*, phi*, qwen*, gemma*, etc.
389        // Use the full name to preserve version tags like ":1.5b"
390        if lowered.starts_with("llama")
391            || lowered.starts_with("mistral")
392            || lowered.starts_with("phi")
393            || lowered.starts_with("codellama")
394            || lowered.starts_with("deepseek")
395            || lowered.starts_with("qwen")
396            || lowered.starts_with("gemma")
397        {
398            return Some(Self::Ollama {
399                model: full_name.to_string(),
400            });
401        }
402
403        None
404    }
405
406    fn label(&self) -> String {
407        match self {
408            Self::TinyLlama => TINYLLAMA_LABEL.to_string(),
409            Self::Ollama { model } => format!("ollama:{model}"),
410            Self::OpenAi { model } => format!("openai:{model}"),
411            Self::Gemini { model } => format!("gemini:{model}"),
412            Self::Claude { model } => format!("claude:{model}"),
413        }
414    }
415
416    fn context_budget(&self) -> ModelContextBudget {
417        match self {
418            Self::TinyLlama => {
419                ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
420            }
421            Self::Ollama { .. } => {
422                ModelContextBudget::new(OLLAMA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
423            }
424            Self::OpenAi { .. } => {
425                ModelContextBudget::new(OPENAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
426            }
427            Self::Gemini { .. } => {
428                ModelContextBudget::new(GEMINI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
429            }
430            Self::Claude { .. } => {
431                ModelContextBudget::new(CLAUDE_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
432            }
433        }
434    }
435
436    fn max_output_tokens(&self) -> usize {
437        match self {
438            Self::TinyLlama => LOCAL_MAX_OUTPUT_TOKENS,
439            Self::Ollama { .. }
440            | Self::OpenAi { .. }
441            | Self::Gemini { .. }
442            | Self::Claude { .. } => REMOTE_MAX_OUTPUT_TOKENS,
443        }
444    }
445}
446
/// Picks the OpenAI model name: the explicit value when it contains anything other
/// than whitespace (returned untrimmed), otherwise `gpt-4o-mini`.
fn normalize_openai_model(explicit: Option<String>) -> String {
    explicit
        .filter(|candidate| !candidate.trim().is_empty())
        .unwrap_or_else(|| "gpt-4o-mini".to_string())
}
453
/// Picks the Gemini model name: the explicit value unchanged when one is given,
/// otherwise `gemini-2.5-flash`.
///
/// No aliasing or case-folding is applied to an explicit name.
fn normalize_gemini_model(explicit: Option<String>) -> String {
    explicit.unwrap_or_else(|| "gemini-2.5-flash".to_string())
}
466
/// Picks the Claude model name: the explicit value unchanged when one is given,
/// otherwise `claude-3-5-sonnet-20241022`.
fn normalize_claude_model(explicit: Option<String>) -> String {
    match explicit {
        Some(chosen) => chosen,
        None => "claude-3-5-sonnet-20241022".to_string(),
    }
}
475
476pub fn run_model_inference(
477    requested_model: &str,
478    question: &str,
479    fallback_context: &str,
480    hits: &[SearchHit],
481    context_override: Option<usize>,
482    api_key: Option<&str>,
483    system_prompt_override: Option<&str>,
484) -> Result<ModelInference, ModelRunError> {
485    let Some(model_kind) = ModelKind::parse(requested_model) else {
486        return Err(ModelRunError::UnsupportedModel(requested_model.to_string()));
487    };
488
489    let mut budget = model_kind.context_budget();
490    if let Some(override_chars) = context_override {
491        budget = budget.apply_override(override_chars);
492    }
493
494    let context_plan = context::assemble_context(hits, fallback_context, &budget);
495    let prompt = build_prompt_parts(
496        question,
497        &context_plan.body,
498        &budget,
499        model_kind.max_output_tokens(),
500    );
501
502    let answer = match &model_kind {
503        ModelKind::TinyLlama => {
504            #[cfg(feature = "llama-cpp")]
505            {
506                tinyllama::run(&prompt)?
507            }
508            #[cfg(not(feature = "llama-cpp"))]
509            {
510                return Err(ModelRunError::UnsupportedModel(
511                    "tinyllama (llama-cpp feature not enabled)".to_string(),
512                ));
513            }
514        }
515        ModelKind::Ollama { model } => ollama::run(model, &prompt)?,
516        ModelKind::OpenAi { model } => {
517            openai::run(model, &prompt, api_key, system_prompt_override)?
518        }
519        ModelKind::Gemini { model } => {
520            gemini::run(model, &prompt, api_key, system_prompt_override)?
521        }
522        ModelKind::Claude { model } => {
523            claude::run(model, &prompt, api_key, system_prompt_override)?
524        }
525    };
526
527    let context::ContextAggregation {
528        body: context_body,
529        records,
530    } = context_plan;
531    let context_fragments = records
532        .into_iter()
533        .map(ModelContextFragment::from_record)
534        .collect();
535
536    Ok(ModelInference {
537        answer: ModelAnswer {
538            requested: requested_model.to_string(),
539            model: model_kind.label(),
540            answer,
541        },
542        context_body,
543        context_fragments,
544    })
545}
546
547mod context {
548    use super::{ModelContextBudget, clamp_to};
549    use memvid_core::types::SearchHit;
550
551    const CONTEXT_HEADER: &str = "## Retrieval Context\n";
552    const PRIMARY_HEADER: &str = "### Primary Hit\n";
553    const SUPPORT_HEADER: &str = "### Supporting Hits\n";
554    const SUMMARY_HEADER: &str = "### Overflow Summaries\n";
555    const SUMMARY_HIGHLIGHT_CHARS: usize = 240;
556
557    #[derive(Debug, Clone)]
558    pub(super) struct ContextAggregation {
559        pub body: String,
560        pub records: Vec<ContextRecord>,
561    }
562
563    impl ContextAggregation {
564        fn from_fallback(fallback: &str, limit: usize) -> Self {
565            let body = if limit == 0 || fallback.is_empty() {
566                String::new()
567            } else if fallback.len() <= limit {
568                fallback.to_string()
569            } else {
570                clamp_to(fallback, limit)
571            };
572            Self {
573                body,
574                records: Vec::new(),
575            }
576        }
577    }
578
579    #[derive(Debug, Clone)]
580    pub(super) struct ContextRecord {
581        pub rank: usize,
582        pub uri: String,
583        pub title: Option<String>,
584        pub score: Option<f32>,
585        pub matches: usize,
586        pub frame_id: u64,
587        pub range: (usize, usize),
588        pub chunk_range: Option<(usize, usize)>,
589        pub text: String,
590        pub mode: ContextMode,
591    }
592
593    #[derive(Debug, Clone, Copy, Eq, PartialEq)]
594    pub(super) enum ContextMode {
595        Full,
596        Summary,
597    }
598
599    #[derive(Debug, Clone)]
600    pub(super) struct ContextAssemblyPlan {
601        primary: Option<ContextRecord>,
602        supporting: Vec<ContextRecord>,
603        summaries: Vec<ContextRecord>,
604    }
605
606    pub(super) fn assemble_context(
607        hits: &[SearchHit],
608        fallback: &str,
609        budget: &ModelContextBudget,
610    ) -> ContextAggregation {
611        if hits.is_empty() {
612            return ContextAggregation::from_fallback(fallback, budget.context_chars());
613        }
614
615        let plan = assemble_plan(hits, budget.context_chars());
616        let mut body = String::new();
617        let mut records = Vec::new();
618
619        body.push_str(CONTEXT_HEADER);
620        if let Some(primary) = plan.primary {
621            body.push_str(PRIMARY_HEADER);
622            body.push_str(&primary.text);
623            body.push_str("\n\n");
624            records.push(primary);
625        }
626
627        if !plan.supporting.is_empty() {
628            body.push_str(SUPPORT_HEADER);
629            for record in plan.supporting {
630                body.push_str(&record.text);
631                body.push_str("\n\n");
632                records.push(record);
633            }
634        }
635
636        if !plan.summaries.is_empty() {
637            body.push_str(SUMMARY_HEADER);
638            for record in plan.summaries {
639                body.push_str(&record.text);
640                body.push_str("\n\n");
641                records.push(record);
642            }
643        }
644
645        ContextAggregation { body, records }
646    }
647
648    fn assemble_plan(hits: &[SearchHit], mut remaining_chars: usize) -> ContextAssemblyPlan {
649        let mut records = Vec::new();
650        for hit in hits.iter().take(32) {
651            let full_record = build_record(hit, render_full(hit), ContextMode::Full);
652            let summary_record = build_record(hit, render_summary(hit), ContextMode::Summary);
653            records.push((full_record, summary_record));
654        }
655
656        let mut plan = ContextAssemblyPlan {
657            primary: None,
658            supporting: Vec::new(),
659            summaries: Vec::new(),
660        };
661
662        if let Some((primary_full, _)) = records.first() {
663            if primary_full.text.len() <= remaining_chars {
664                remaining_chars = remaining_chars.saturating_sub(primary_full.text.len());
665                plan.primary = Some(primary_full.clone());
666            }
667        }
668
669        for (idx, (full, summary)) in records.iter().enumerate() {
670            if idx == 0 {
671                continue;
672            }
673
674            if full.text.len() <= remaining_chars {
675                remaining_chars = remaining_chars.saturating_sub(full.text.len());
676                plan.supporting.push(full.clone());
677            } else if summary.text.len() <= remaining_chars {
678                remaining_chars = remaining_chars.saturating_sub(summary.text.len());
679                plan.summaries.push(summary.clone());
680            }
681        }
682
683        plan
684    }
685
686    fn render_full(hit: &SearchHit) -> String {
687        format!(
688            "Rank: {}\nURI: {}\nTitle: {}\nMatches: {}\nScore: {:.3}\nSnippet:\n{}",
689            hit.rank,
690            hit.uri,
691            hit.title
692                .clone()
693                .unwrap_or_else(|| "(untitled)".to_string()),
694            hit.matches,
695            hit.score.unwrap_or_default(),
696            hit.chunk_text
697                .clone()
698                .or_else(|| Some(hit.text.clone()))
699                .unwrap_or_default()
700        )
701    }
702
703    fn render_summary(hit: &SearchHit) -> String {
704        let snippet = hit
705            .chunk_text
706            .clone()
707            .or_else(|| Some(hit.text.clone()))
708            .unwrap_or_default();
709        let snippet = trim_highlight(&snippet, SUMMARY_HIGHLIGHT_CHARS);
710        format!(
711            "Rank: {}\nURI: {}\nHighlight: {}",
712            hit.rank, hit.uri, snippet
713        )
714    }
715
716    fn trim_highlight(text: &str, limit: usize) -> String {
717        let clean = text.replace('\n', " ");
718        clamp_to(&clean, limit)
719    }
720
721    fn build_record(hit: &SearchHit, text: String, mode: ContextMode) -> ContextRecord {
722        ContextRecord {
723            rank: hit.rank,
724            uri: hit.uri.clone(),
725            title: hit.title.clone(),
726            score: hit.score,
727            matches: hit.matches,
728            frame_id: hit.frame_id,
729            range: hit.range,
730            chunk_range: hit.chunk_range,
731            text,
732            mode,
733        }
734    }
735}
736
737#[cfg(feature = "llama-cpp")]
738mod tinyllama {
739    use super::{ModelRunError, PromptParts, TINYLLAMA_LABEL, ThinkingSpinner};
740    use anyhow::anyhow;
741    use llama_cpp::standard_sampler::StandardSampler;
742    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
743    use tokio::runtime::Builder;
744
745    use std::path::{Path, PathBuf};
746
747    const MODEL_DIR: &str = "models/tinyllama";
748    const GGUF_HINT: &str = "*.gguf";
749
750    pub(super) fn run(prompt: &PromptParts) -> Result<String, ModelRunError> {
751        let base_dir = Path::new(MODEL_DIR);
752        let assets = RequiredAssets::new(base_dir);
753
754        if let Some(missing) = assets.missing_paths() {
755            return Err(ModelRunError::AssetsMissing {
756                model: TINYLLAMA_LABEL.to_string(),
757                missing,
758            });
759        }
760
761        let gguf_path = assets.gguf_path.clone().ok_or_else(|| {
762            ModelRunError::Runtime(anyhow!(
763                "no GGUF model file found in {}",
764                base_dir.display()
765            ))
766        })?;
767
768        unsafe {
769            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
770            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
771        }
772
773        let model =
774            LlamaModel::load_from_file(&gguf_path, LlamaParams::default()).map_err(|err| {
775                ModelRunError::Runtime(anyhow!(
776                    "failed to load TinyLlama weights from {}: {err}",
777                    gguf_path.display()
778                ))
779            })?;
780
781        let mut session_params = SessionParams::default();
782        if session_params.n_ctx == 0 {
783            session_params.n_ctx = 2048;
784        }
785        session_params.n_batch = session_params.n_ctx.min(512);
786        if session_params.n_ubatch == 0 {
787            session_params.n_ubatch = 512;
788        }
789        let max_tokens = session_params.n_ctx as usize;
790        let mut session = model.create_session(session_params).map_err(|err| {
791            ModelRunError::Runtime(anyhow!("failed to create TinyLlama session: {err}"))
792        })?;
793
794        let mut priming_tokens = model
795            .tokenize_bytes(prompt.completion_prompt().as_bytes(), true, true)
796            .map_err(|err| {
797                ModelRunError::Runtime(anyhow!("failed to tokenize TinyLlama prompt: {err}"))
798            })?;
799
800        let requested_tokens = prompt.max_output_tokens();
801        if max_tokens > 0 {
802            let reserved = requested_tokens + 64;
803            if priming_tokens.len() >= max_tokens.saturating_sub(reserved) {
804                let target = max_tokens.saturating_sub(reserved).max(1);
805                let tail_start = priming_tokens.len().saturating_sub(target);
806                priming_tokens = priming_tokens.split_off(tail_start);
807            }
808        }
809
810        session
811            .advance_context_with_tokens(&priming_tokens)
812            .map_err(|err| {
813                ModelRunError::Runtime(anyhow!("failed to prime TinyLlama context: {err}"))
814            })?;
815
816        let handle = session
817            .start_completing_with(StandardSampler::default(), requested_tokens)
818            .map_err(|err| ModelRunError::Runtime(anyhow!("completion failed to start: {err}")))?;
819
820        let runtime = Builder::new_current_thread()
821            .enable_all()
822            .build()
823            .map_err(|err| {
824                ModelRunError::Runtime(anyhow!("failed to build tokio runtime: {err}"))
825            })?;
826
827        let mut spinner = ThinkingSpinner::start();
828        let generated = runtime.block_on(async { handle.into_string_async().await });
829        spinner.stop();
830
831        let answer = generated.trim().to_string();
832
833        if answer.is_empty() {
834            Ok("No answer generated by TinyLlama.".to_string())
835        } else {
836            Ok(answer)
837        }
838    }
839
840    struct RequiredAssets {
841        gguf_path: Option<PathBuf>,
842        base_dir: PathBuf,
843    }
844
845    impl RequiredAssets {
846        fn new(base_dir: &Path) -> Self {
847            let gguf_path = find_first_gguf(base_dir);
848            Self {
849                gguf_path,
850                base_dir: base_dir.to_path_buf(),
851            }
852        }
853
854        fn missing_paths(&self) -> Option<Vec<PathBuf>> {
855            if self.gguf_path.is_some() {
856                None
857            } else {
858                Some(vec![self.base_dir.join(GGUF_HINT)])
859            }
860        }
861    }
862
863    fn find_first_gguf(base_dir: &Path) -> Option<PathBuf> {
864        let mut entries: Vec<PathBuf> = std::fs::read_dir(base_dir)
865            .ok()?
866            .filter_map(|entry| entry.ok().map(|e| e.path()))
867            .filter(|path| path.is_file() && path.extension().map_or(false, |ext| ext == "gguf"))
868            .collect();
869        entries.sort();
870        entries.into_iter().next()
871    }
872}
873
874mod ollama {
875    use super::{ModelRunError, PromptParts, ThinkingSpinner};
876    use anyhow::anyhow;
877    use reqwest::blocking::Client;
878    use serde::Deserialize;
879    use serde_json::json;
880
881    const ENDPOINT: &str = "http://127.0.0.1:11434/api/generate";
882
883    pub(super) fn run(model: &str, prompt: &PromptParts) -> Result<String, ModelRunError> {
884        let client = Client::builder()
885            .timeout(std::time::Duration::from_secs(60))
886            .build()
887            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
888
889        let mut spinner = ThinkingSpinner::start();
890        let response = client
891            .post(ENDPOINT)
892            .json(&json!({
893                "model": model,
894                "prompt": prompt.completion_prompt(),
895                "stream": false
896            }))
897            .send()
898            .map_err(|err| ModelRunError::Runtime(anyhow!("ollama request failed: {err}")))?
899            .error_for_status()
900            .map_err(|err| {
901                ModelRunError::Runtime(anyhow!("ollama returned error status: {err}"))
902            })?;
903
904        let body: GenerateResponse = response.json().map_err(|err| {
905            ModelRunError::Runtime(anyhow!("failed to decode ollama response: {err}"))
906        })?;
907        spinner.stop();
908
909        let text = body.response.trim().to_string();
910        if text.is_empty() {
911            Ok("No answer returned by Ollama.".to_string())
912        } else {
913            Ok(text)
914        }
915    }
916
917    #[derive(Debug, Deserialize)]
918    struct GenerateResponse {
919        #[serde(default)]
920        response: String,
921    }
922}
923
924mod openai {
925    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
926    use anyhow::anyhow;
927    use reqwest::blocking::Client;
928    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
929    use serde::Deserialize;
930    use serde_json::json;
931
932    const CHAT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
933    const RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses";
934
935    pub(super) fn run(
936        model: &str,
937        prompt: &PromptParts,
938        override_key: Option<&str>,
939        system_prompt_override: Option<&str>,
940    ) -> Result<String, ModelRunError> {
941        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
942        let key = override_key
943            .map(|value| value.to_string())
944            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
945            .ok_or_else(|| {
946                ModelRunError::Runtime(anyhow!(
947                    "OPENAI_API_KEY environment variable is required for OpenAI models"
948                ))
949            })?;
950
951        let mut headers = HeaderMap::new();
952        headers.insert(
953            AUTHORIZATION,
954            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
955                ModelRunError::Runtime(anyhow!("invalid OPENAI_API_KEY header value: {err}"))
956            })?,
957        );
958        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
959
960        let client = Client::builder()
961            .timeout(std::time::Duration::from_secs(60))
962            .default_headers(headers)
963            .build()
964            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
965
966        let mut spinner = ThinkingSpinner::start();
967        let text = if requires_responses_api(model) {
968            let combined_prompt = format!(
969                "System instructions:\n{}\n\nUser query:\n{}",
970                system_prompt,
971                prompt.user_message()
972            );
973            let payload = json!({
974                "model": model,
975                "input": combined_prompt,
976                "max_output_tokens": prompt.max_output_tokens() as u32,
977                "reasoning": {
978                    "effort": "low"
979                }
980            });
981
982            let response = client
983                .post(RESPONSES_ENDPOINT)
984                .json(&payload)
985                .send()
986                .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
987
988            let status = response.status();
989            if !status.is_success() {
990                let body = response
991                    .text()
992                    .unwrap_or_else(|_| "<failed to read body>".to_string());
993                return Err(ModelRunError::Runtime(anyhow!(
994                    "OpenAI returned error status {status}: {body}"
995                )));
996            }
997
998            let body: ResponsesResponse = response.json().map_err(|err| {
999                ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
1000            })?;
1001
1002            extract_responses_text(body)
1003        } else {
1004            let payload = json!({
1005                "model": model,
1006                "messages": [
1007                    {"role": "system", "content": system_prompt},
1008                    {"role": "user", "content": prompt.user_message()}
1009                ],
1010                "temperature": 0.2,
1011                "max_tokens": prompt.max_output_tokens() as u32
1012            });
1013
1014            let response = client
1015                .post(CHAT_ENDPOINT)
1016                .json(&payload)
1017                .send()
1018                .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
1019
1020            let status = response.status();
1021            if !status.is_success() {
1022                let body = response
1023                    .text()
1024                    .unwrap_or_else(|_| "<failed to read body>".to_string());
1025                return Err(ModelRunError::Runtime(anyhow!(
1026                    "OpenAI returned error status {status}: {body}"
1027                )));
1028            }
1029
1030            let body: ChatResponse = response.json().map_err(|err| {
1031                ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
1032            })?;
1033
1034            extract_chat_text(body)
1035        };
1036        spinner.stop();
1037        Ok(text)
1038    }
1039
1040    #[derive(Debug, Deserialize)]
1041    struct ChatResponse {
1042        choices: Vec<Choice>,
1043    }
1044
1045    #[derive(Debug, Deserialize)]
1046    struct Choice {
1047        message: ChatMessage,
1048    }
1049
1050    #[derive(Debug, Deserialize)]
1051    struct ChatMessage {
1052        #[serde(default)]
1053        content: Option<String>,
1054    }
1055
1056    #[derive(Debug, Deserialize)]
1057    struct ResponsesResponse {
1058        #[serde(default)]
1059        output: Vec<ResponseItem>,
1060        #[serde(default)]
1061        output_text: Vec<String>,
1062    }
1063
1064    #[derive(Debug, Deserialize)]
1065    struct ResponseItem {
1066        #[serde(default)]
1067        content: Vec<ResponseContent>,
1068    }
1069
1070    #[derive(Debug, Deserialize)]
1071    struct ResponseContent {
1072        #[serde(rename = "type")]
1073        kind: String,
1074        #[serde(default)]
1075        text: Option<String>,
1076    }
1077
1078    fn extract_chat_text(body: ChatResponse) -> String {
1079        body.choices
1080            .into_iter()
1081            .find_map(|choice| choice.message.content)
1082            .map(|value| value.trim().to_string())
1083            .unwrap_or_else(|| "No answer returned by OpenAI.".to_string())
1084    }
1085
1086    fn extract_responses_text(body: ResponsesResponse) -> String {
1087        if !body.output_text.is_empty() {
1088            let text = body
1089                .output_text
1090                .into_iter()
1091                .find(|value| !value.trim().is_empty());
1092            if let Some(text) = text {
1093                return text.trim().to_string();
1094            }
1095        }
1096        for item in body.output {
1097            for segment in item.content {
1098                match segment.kind.as_str() {
1099                    "output_text" | "text" => {
1100                        if let Some(text) = segment.text {
1101                            let trimmed = text.trim();
1102                            if !trimmed.is_empty() {
1103                                return trimmed.to_string();
1104                            }
1105                        }
1106                    }
1107                    _ => {}
1108                }
1109            }
1110        }
1111        "No answer returned by OpenAI.".to_string()
1112    }
1113
1114    fn requires_responses_api(model: &str) -> bool {
1115        let lowered = model.to_ascii_lowercase();
1116        lowered.starts_with("gpt-5") || lowered.contains("gpt-4.1")
1117    }
1118}
1119
1120mod gemini {
1121    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
1122    use anyhow::anyhow;
1123    use reqwest::blocking::Client;
1124    use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
1125    use serde::Deserialize;
1126    use serde_json::json;
1127
1128    pub(super) fn run(
1129        model: &str,
1130        prompt: &PromptParts,
1131        override_key: Option<&str>,
1132        system_prompt_override: Option<&str>,
1133    ) -> Result<String, ModelRunError> {
1134        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
1135        let key = override_key
1136            .map(|value| value.to_string())
1137            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
1138            .ok_or_else(|| {
1139                ModelRunError::Runtime(anyhow!(
1140                    "GEMINI_API_KEY environment variable is required for Gemini models"
1141                ))
1142            })?;
1143
1144        let url = format!(
1145            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
1146            model
1147        );
1148
1149        let mut headers = HeaderMap::new();
1150        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
1151        headers.insert(
1152            HeaderName::from_static("x-goog-api-key"),
1153            HeaderValue::from_str(&key).map_err(|err| {
1154                ModelRunError::Runtime(anyhow!("invalid GEMINI_API_KEY header value: {err}"))
1155            })?,
1156        );
1157
1158        let client = Client::builder()
1159            .timeout(std::time::Duration::from_secs(60))
1160            .default_headers(headers)
1161            .build()
1162            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
1163
1164        let payload = json!({
1165            "contents": [{
1166                "parts": [
1167                    { "text": system_prompt },
1168                    { "text": prompt.user_message() }
1169                ]
1170            }],
1171            "generationConfig": {
1172                "temperature": 0.2,
1173                "maxOutputTokens": prompt.max_output_tokens() as u32,
1174                "topK": 40,
1175                "topP": 0.95
1176            }
1177        });
1178
1179        let mut spinner = ThinkingSpinner::start();
1180        let response = client
1181            .post(url)
1182            .json(&payload)
1183            .send()
1184            .map_err(|err| ModelRunError::Runtime(anyhow!("Gemini request failed: {err}")))?
1185            .error_for_status()
1186            .map_err(|err| {
1187                ModelRunError::Runtime(anyhow!("Gemini returned error status: {err}"))
1188            })?;
1189
1190        let body: GenerateResponse = response.json().map_err(|err| {
1191            ModelRunError::Runtime(anyhow!("failed to decode Gemini response: {err}"))
1192        })?;
1193        spinner.stop();
1194
1195        let text = body
1196            .candidates
1197            .into_iter()
1198            .flat_map(|candidate| candidate.content.parts)
1199            .find_map(|part| part.text)
1200            .map(|value| value.trim().to_string())
1201            .unwrap_or_else(|| "No answer returned by Gemini.".to_string());
1202
1203        Ok(text)
1204    }
1205
1206    #[derive(Debug, Deserialize)]
1207    struct GenerateResponse {
1208        candidates: Vec<Candidate>,
1209    }
1210
1211    #[derive(Debug, Deserialize)]
1212    struct Candidate {
1213        content: CandidateContent,
1214    }
1215
1216    #[derive(Debug, Deserialize)]
1217    struct CandidateContent {
1218        parts: Vec<CandidatePart>,
1219    }
1220
1221    #[derive(Debug, Deserialize)]
1222    struct CandidatePart {
1223        #[serde(default)]
1224        text: Option<String>,
1225    }
1226}
1227
1228mod claude {
1229    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
1230    use anyhow::anyhow;
1231    use reqwest::blocking::Client;
1232    use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
1233    use serde::Deserialize;
1234    use serde_json::json;
1235
1236    const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
1237    const API_VERSION: &str = "2023-06-01";
1238
1239    pub(super) fn run(
1240        model: &str,
1241        prompt: &PromptParts,
1242        override_key: Option<&str>,
1243        system_prompt_override: Option<&str>,
1244    ) -> Result<String, ModelRunError> {
1245        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
1246        let key = override_key
1247            .map(|value| value.to_string())
1248            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
1249            .or_else(|| std::env::var("CLAUDE_API_KEY").ok())
1250            .ok_or_else(|| {
1251                ModelRunError::Runtime(anyhow!(
1252                    "ANTHROPIC_API_KEY environment variable is required for Claude models"
1253                ))
1254            })?;
1255
1256        let mut headers = HeaderMap::new();
1257        headers.insert(
1258            AUTHORIZATION,
1259            HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
1260                ModelRunError::Runtime(anyhow!("invalid ANTHROPIC_API_KEY header value: {err}"))
1261            })?,
1262        );
1263        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
1264        headers.insert(
1265            HeaderName::from_static("anthropic-version"),
1266            HeaderValue::from_static(API_VERSION),
1267        );
1268
1269        let client = Client::builder()
1270            .timeout(std::time::Duration::from_secs(60))
1271            .default_headers(headers)
1272            .build()
1273            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
1274
1275        let payload = json!({
1276            "model": model,
1277            "max_tokens": prompt.max_output_tokens() as u32,
1278            "temperature": 0.2,
1279            "system": system_prompt,
1280            "messages": [{
1281                "role": "user",
1282                "content": [{"type": "text", "text": prompt.user_message()}]
1283            }]
1284        });
1285
1286        let mut spinner = ThinkingSpinner::start();
1287        let response = client
1288            .post(ENDPOINT)
1289            .json(&payload)
1290            .send()
1291            .map_err(|err| ModelRunError::Runtime(anyhow!("Claude request failed: {err}")))?
1292            .error_for_status()
1293            .map_err(|err| {
1294                ModelRunError::Runtime(anyhow!("Claude returned error status: {err}"))
1295            })?;
1296
1297        let body: ClaudeResponse = response.json().map_err(|err| {
1298            ModelRunError::Runtime(anyhow!("failed to decode Claude response: {err}"))
1299        })?;
1300        spinner.stop();
1301
1302        let text = body
1303            .content
1304            .into_iter()
1305            .find_map(|part| match part {
1306                ContentBlock::Text { text } if !text.trim().is_empty() => {
1307                    Some(text.trim().to_string())
1308                }
1309                _ => None,
1310            })
1311            .unwrap_or_else(|| "No answer returned by Claude.".to_string());
1312
1313        Ok(text)
1314    }
1315
1316    #[derive(Debug, Deserialize)]
1317    struct ClaudeResponse {
1318        #[serde(default)]
1319        content: Vec<ContentBlock>,
1320    }
1321
1322    #[derive(Debug, Deserialize)]
1323    #[serde(tag = "type", rename_all = "lowercase")]
1324    enum ContentBlock {
1325        Text {
1326            text: String,
1327        },
1328        #[serde(other)]
1329        Other,
1330    }
1331}
1332
1333// ============================================================================
1334// Entity Extraction API
1335// ============================================================================
1336
/// Default system prompt for entity extraction.
///
/// [`extract_entities`] uses this text when the caller passes no custom system prompt.
/// It asks the model for a JSON object with an `entities` array whose items carry a
/// `name`, a `type` and a `confidence`.
pub const ENTITY_EXTRACTION_PROMPT: &str = r#"Extract named entities from the provided text. Return a JSON object with an "entities" array.

Each entity should have:
- "name": The entity name as it appears in the text
- "type": One of "PERSON", "ORG", "LOCATION", "DATE", "PRODUCT", "EVENT", or "OTHER"
- "confidence": A number between 0.0 and 1.0 indicating your confidence

Guidelines:
1. Only include entities you're confident about (confidence >= 0.7)
2. Preserve the original capitalization of entity names
3. For organizations, include full names (e.g., "S&P Global" not just "S&P")
4. For people, include full names when available
5. Deduplicate: if an entity appears multiple times, include it only once

Return format:
{"entities": [{"name": "...", "type": "...", "confidence": 0.9}, ...]}"#;
1354
/// A single named entity extracted from text by a model.
///
/// Mirrors one item of the `entities` array in the model's JSON answer.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity name as returned by the model.
    pub name: String,
    /// Entity category; (de)serialized under the JSON key `type`.
    /// The default prompt asks for PERSON, ORG, LOCATION, DATE, PRODUCT, EVENT or OTHER,
    /// but the value is stored as a free-form string.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Model-reported confidence; the default prompt asks for a value between 0.0 and 1.0.
    pub confidence: f32,
}
1363
/// Result of an [`extract_entities`] call.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EntityExtractionResponse {
    /// Entities parsed from the model's JSON answer.
    pub entities: Vec<ExtractedEntity>,
    /// Model identifier exactly as it was passed to [`extract_entities`].
    pub model: String,
    /// Size of the analyzed input text, as computed by [`extract_entities`].
    pub text_chars: usize,
}
1371
1372/// Extract entities from text using an LLM
1373///
1374/// # Arguments
1375/// * `model` - Model identifier (e.g., "openai:gpt-4o-mini", "claude:claude-3-5-sonnet")
1376/// * `text` - The text to extract entities from
1377/// * `system_prompt` - Optional custom system prompt (uses default if None)
1378/// * `api_key` - Optional API key (uses environment variable if None)
1379///
1380/// # Returns
1381/// An `EntityExtractionResponse` with the extracted entities
1382///
1383/// # Example
1384/// ```ignore
1385/// let response = extract_entities(
1386///     "openai:gpt-4o-mini",
1387///     "John Smith met with Microsoft CEO Satya Nadella in Seattle.",
1388///     None,  // use default prompt
1389///     None,  // use OPENAI_API_KEY env var
1390/// )?;
1391/// for entity in response.entities {
1392///     println!("{}: {} ({:.0}%)", entity.name, entity.entity_type, entity.confidence * 100.0);
1393/// }
1394/// ```
1395pub fn extract_entities(
1396    model: &str,
1397    text: &str,
1398    system_prompt: Option<&str>,
1399    api_key: Option<&str>,
1400) -> Result<EntityExtractionResponse, ModelRunError> {
1401    let prompt = system_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
1402    let text_chars = text.len();
1403
1404    // Determine model provider and make API call
1405    let (provider, model_name) = parse_model_spec(model);
1406
1407    let json_response = match provider.as_str() {
1408        "openai" => extract_entities_openai(&model_name, text, prompt, api_key)?,
1409        "claude" | "anthropic" => extract_entities_claude(&model_name, text, prompt, api_key)?,
1410        "gemini" | "google" => extract_entities_gemini(&model_name, text, prompt, api_key)?,
1411        _ => {
1412            return Err(ModelRunError::UnsupportedModel(format!(
1413                "Entity extraction not supported for provider '{}'. Use openai:, claude:, or gemini:",
1414                provider
1415            )));
1416        }
1417    };
1418
1419    // Parse the JSON response
1420    let entities = parse_entity_response(&json_response)?;
1421
1422    Ok(EntityExtractionResponse {
1423        entities,
1424        model: model.to_string(),
1425        text_chars,
1426    })
1427}
1428
/// Splits a `provider:name` model spec at the first `:` into `(provider, name)`.
///
/// The provider is lowercased; the name is returned unchanged. A spec without a colon is
/// treated as an OpenAI model name.
fn parse_model_spec(model: &str) -> (String, String) {
    match model.split_once(':') {
        Some((provider, name)) => (provider.to_lowercase(), name.to_owned()),
        // Default to OpenAI if no provider specified
        None => (String::from("openai"), model.to_owned()),
    }
}
1437
1438fn parse_entity_response(json_str: &str) -> Result<Vec<ExtractedEntity>, ModelRunError> {
1439    // Try to parse the response, handling various formats
1440    let trimmed = json_str.trim();
1441
1442    // Handle markdown code blocks
1443    let clean_json = if trimmed.starts_with("```json") {
1444        trimmed
1445            .strip_prefix("```json")
1446            .and_then(|s| s.strip_suffix("```"))
1447            .unwrap_or(trimmed)
1448            .trim()
1449    } else if trimmed.starts_with("```") {
1450        trimmed
1451            .strip_prefix("```")
1452            .and_then(|s| s.strip_suffix("```"))
1453            .unwrap_or(trimmed)
1454            .trim()
1455    } else {
1456        trimmed
1457    };
1458
1459    // Try parsing as {"entities": [...]}
1460    #[derive(serde::Deserialize)]
1461    struct EntityResponse {
1462        entities: Vec<ExtractedEntity>,
1463    }
1464
1465    if let Ok(response) = serde_json::from_str::<EntityResponse>(clean_json) {
1466        return Ok(response.entities);
1467    }
1468
1469    // Try parsing as a direct array [...]
1470    if let Ok(entities) = serde_json::from_str::<Vec<ExtractedEntity>>(clean_json) {
1471        return Ok(entities);
1472    }
1473
1474    Err(ModelRunError::Runtime(anyhow::anyhow!(
1475        "Failed to parse entity extraction response as JSON: {}",
1476        &clean_json[..clean_json.len().min(200)]
1477    )))
1478}
1479
1480fn extract_entities_openai(
1481    model: &str,
1482    text: &str,
1483    system_prompt: &str,
1484    api_key: Option<&str>,
1485) -> Result<String, ModelRunError> {
1486    use serde_json::json;
1487
1488    let api_key = api_key
1489        .map(|s| s.to_string())
1490        .or_else(|| std::env::var("OPENAI_API_KEY").ok())
1491        .ok_or_else(|| {
1492            ModelRunError::Runtime(anyhow::anyhow!(
1493                "OpenAI API key required. Set OPENAI_API_KEY or pass api_key parameter."
1494            ))
1495        })?;
1496
1497    let model_name = if model.is_empty() { "gpt-4o-mini" } else { model };
1498
1499    let client = reqwest::blocking::Client::new();
1500    let payload = json!({
1501        "model": model_name,
1502        "messages": [
1503            {"role": "system", "content": system_prompt},
1504            {"role": "user", "content": text}
1505        ],
1506        "response_format": {"type": "json_object"},
1507        "temperature": 0.1
1508    });
1509
1510    let response = client
1511        .post("https://api.openai.com/v1/chat/completions")
1512        .header("Authorization", format!("Bearer {}", api_key))
1513        .json(&payload)
1514        .send()
1515        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI request failed: {}", e)))?
1516        .error_for_status()
1517        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI returned error: {}", e)))?;
1518
1519    #[derive(serde::Deserialize)]
1520    struct OpenAIResponse {
1521        choices: Vec<OpenAIChoice>,
1522    }
1523    #[derive(serde::Deserialize)]
1524    struct OpenAIChoice {
1525        message: OpenAIMessage,
1526    }
1527    #[derive(serde::Deserialize)]
1528    struct OpenAIMessage {
1529        content: String,
1530    }
1531
1532    let body: OpenAIResponse = response.json().map_err(|e| {
1533        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse OpenAI response: {}", e))
1534    })?;
1535
1536    body.choices
1537        .into_iter()
1538        .next()
1539        .map(|c| c.message.content)
1540        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No response from OpenAI")))
1541}
1542
1543fn extract_entities_claude(
1544    model: &str,
1545    text: &str,
1546    system_prompt: &str,
1547    api_key: Option<&str>,
1548) -> Result<String, ModelRunError> {
1549    use serde_json::json;
1550
1551    let api_key = api_key
1552        .map(|s| s.to_string())
1553        .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
1554        .ok_or_else(|| {
1555            ModelRunError::Runtime(anyhow::anyhow!(
1556                "Anthropic API key required. Set ANTHROPIC_API_KEY or pass api_key parameter."
1557            ))
1558        })?;
1559
1560    let model_name = if model.is_empty() { "claude-3-5-sonnet-20241022" } else { model };
1561
1562    let client = reqwest::blocking::Client::new();
1563    let payload = json!({
1564        "model": model_name,
1565        "max_tokens": 4096,
1566        "system": format!("{}\n\nRespond with valid JSON only.", system_prompt),
1567        "messages": [
1568            {"role": "user", "content": text}
1569        ]
1570    });
1571
1572    let response = client
1573        .post("https://api.anthropic.com/v1/messages")
1574        .header("x-api-key", &api_key)
1575        .header("anthropic-version", "2023-06-01")
1576        .header("content-type", "application/json")
1577        .json(&payload)
1578        .send()
1579        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude request failed: {}", e)))?
1580        .error_for_status()
1581        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude returned error: {}", e)))?;
1582
1583    #[derive(serde::Deserialize)]
1584    struct ClaudeResponse {
1585        content: Vec<ClaudeContent>,
1586    }
1587    #[derive(serde::Deserialize)]
1588    struct ClaudeContent {
1589        text: Option<String>,
1590    }
1591
1592    let body: ClaudeResponse = response.json().map_err(|e| {
1593        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Claude response: {}", e))
1594    })?;
1595
1596    body.content
1597        .into_iter()
1598        .find_map(|c| c.text)
1599        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Claude")))
1600}
1601
1602fn extract_entities_gemini(
1603    model: &str,
1604    text: &str,
1605    system_prompt: &str,
1606    api_key: Option<&str>,
1607) -> Result<String, ModelRunError> {
1608    use serde_json::json;
1609
1610    let api_key = api_key
1611        .map(|s| s.to_string())
1612        .or_else(|| std::env::var("GEMINI_API_KEY").ok())
1613        .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
1614        .ok_or_else(|| {
1615            ModelRunError::Runtime(anyhow::anyhow!(
1616                "Gemini API key required. Set GEMINI_API_KEY or pass api_key parameter."
1617            ))
1618        })?;
1619
1620    let model_name = if model.is_empty() { "gemini-2.0-flash" } else { model };
1621    let url = format!(
1622        "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
1623        model_name, api_key
1624    );
1625
1626    let client = reqwest::blocking::Client::new();
1627    let payload = json!({
1628        "contents": [{
1629            "parts": [{"text": format!("{}\n\nText to analyze:\n{}", system_prompt, text)}]
1630        }],
1631        "generationConfig": {
1632            "temperature": 0.1,
1633            "responseMimeType": "application/json"
1634        }
1635    });
1636
1637    let response = client
1638        .post(&url)
1639        .json(&payload)
1640        .send()
1641        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini request failed: {}", e)))?
1642        .error_for_status()
1643        .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini returned error: {}", e)))?;
1644
1645    #[derive(serde::Deserialize)]
1646    struct GeminiResponse {
1647        candidates: Vec<GeminiCandidate>,
1648    }
1649    #[derive(serde::Deserialize)]
1650    struct GeminiCandidate {
1651        content: GeminiContent,
1652    }
1653    #[derive(serde::Deserialize)]
1654    struct GeminiContent {
1655        parts: Vec<GeminiPart>,
1656    }
1657    #[derive(serde::Deserialize)]
1658    struct GeminiPart {
1659        text: Option<String>,
1660    }
1661
1662    let body: GeminiResponse = response.json().map_err(|e| {
1663        ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Gemini response: {}", e))
1664    })?;
1665
1666    body.candidates
1667        .into_iter()
1668        .next()
1669        .and_then(|c| c.content.parts.into_iter().find_map(|p| p.text))
1670        .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Gemini")))
1671}
1672
#[cfg(test)]
mod tests {
    use super::*;

    /// Each provider's normalizer falls back to its default model when given `None`.
    #[test]
    fn normalize_models() {
        let openai_default = normalize_openai_model(None);
        let gemini_default = normalize_gemini_model(None);
        let claude_default = normalize_claude_model(None);

        assert_eq!(openai_default, "gpt-4o-mini");
        assert_eq!(gemini_default, "gemini-2.5-flash");
        assert_eq!(claude_default, "claude-3-5-sonnet-20241022");
    }

    /// A plain `{"entities": [...]}` object is parsed into its entities.
    #[test]
    fn parse_entity_json() {
        let payload = r#"{"entities": [{"name": "John", "type": "PERSON", "confidence": 0.95}]}"#;
        let parsed = parse_entity_response(payload).expect("plain JSON object should parse");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "John");
    }

    /// The same object wrapped in a Markdown `json` code fence is unwrapped before parsing.
    #[test]
    fn parse_entity_json_with_markdown() {
        let payload = "```json\n{\"entities\": [{\"name\": \"Microsoft\", \"type\": \"ORG\", \"confidence\": 0.99}]}\n```";
        let parsed = parse_entity_response(payload).expect("fenced JSON should parse");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "Microsoft");
    }

    /// Specs with and without a `provider:` prefix resolve to the expected pair.
    #[test]
    fn parse_model_spec_test() {
        let cases = [
            ("openai:gpt-4o", "openai", "gpt-4o"),
            ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ];

        for (spec, expected_provider, expected_model) in cases {
            let (provider, model) = parse_model_spec(spec);
            assert_eq!(provider, expected_provider);
            assert_eq!(model, expected_model);
        }
    }
}