1use std::fmt;
2use std::io::{IsTerminal, Write, stderr};
3use std::path::PathBuf;
4use std::sync::{
5 Arc,
6 atomic::{AtomicBool, Ordering},
7};
8use std::thread;
9use std::time::Duration;
10
11use memvid_core::types::SearchHit;
12
/// Final answer produced by a model run.
#[derive(Debug, Clone)]
pub struct ModelAnswer {
    /// The model string exactly as the caller requested it.
    pub requested: String,
    /// Resolved label of the backend that produced the answer (e.g. `openai:<model>`).
    pub model: String,
    /// Answer text returned by the backend.
    pub answer: String,
}
19
/// Result of `run_model_inference`: the answer plus the context that was sent with it.
#[derive(Debug, Clone)]
pub struct ModelInference {
    /// The model's answer and which model produced it.
    pub answer: ModelAnswer,
    /// Rendered retrieval-context body assembled for the prompt.
    pub context_body: String,
    /// Individual retrieval records that make up `context_body`, in body order.
    pub context_fragments: Vec<ModelContextFragment>,
}
26
/// One retrieval record that was included in the prompt context.
///
/// Most fields are copied from the originating search hit; `text` is the
/// rendered block that was written into the context body.
#[derive(Debug, Clone)]
pub struct ModelContextFragment {
    /// Rank of the originating search hit.
    pub rank: usize,
    /// URI of the originating search hit.
    pub uri: String,
    /// Title of the originating search hit, if it had one.
    pub title: Option<String>,
    /// Search score of the originating hit, if one was reported.
    pub score: Option<f32>,
    /// Match count reported by the originating hit.
    pub matches: usize,
    /// Frame identifier of the originating hit.
    pub frame_id: u64,
    /// `range` of the originating hit (units not visible here — TODO confirm).
    pub range: (usize, usize),
    /// `chunk_range` of the originating hit, if any.
    pub chunk_range: Option<(usize, usize)>,
    /// Rendered text of this fragment as inserted into the context body.
    pub text: String,
    /// Whether the fragment was rendered in full or as a condensed summary.
    pub kind: ModelContextFragmentKind,
}
40
/// How a context fragment was rendered into the prompt.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ModelContextFragmentKind {
    /// Full record: metadata header plus the complete snippet.
    Full,
    /// Condensed record: rank, URI and a shortened highlight only.
    Summary,
}
46
47impl ModelContextFragment {
48 fn from_record(record: context::ContextRecord) -> Self {
49 let kind = match record.mode {
50 context::ContextMode::Full => ModelContextFragmentKind::Full,
51 context::ContextMode::Summary => ModelContextFragmentKind::Summary,
52 };
53 Self {
54 rank: record.rank,
55 uri: record.uri,
56 title: record.title,
57 score: record.score,
58 matches: record.matches,
59 frame_id: record.frame_id,
60 range: record.range,
61 chunk_range: record.chunk_range,
62 text: record.text,
63 kind,
64 }
65 }
66}
67
68#[derive(Debug)]
69pub enum ModelRunError {
70 UnsupportedModel(String),
71 AssetsMissing {
72 model: String,
73 missing: Vec<PathBuf>,
74 },
75 Runtime(anyhow::Error),
76}
77
78impl fmt::Display for ModelRunError {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 Self::UnsupportedModel(model) => write!(f, "unsupported model '{model}'"),
82 Self::AssetsMissing { model, missing } => {
83 let paths: Vec<_> = missing
84 .iter()
85 .map(|path| path.display().to_string())
86 .collect();
87 write!(
88 f,
89 "model '{model}' missing required assets: {}",
90 paths.join(", ")
91 )
92 }
93 Self::Runtime(err) => write!(f, "model runtime error: {err}"),
94 }
95 }
96}
97
98impl std::error::Error for ModelRunError {
99 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
100 match self {
101 Self::Runtime(err) => Some(err.root_cause()),
102 _ => None,
103 }
104 }
105}
106
// Prompt-size budgets below are named in "chars" but are compared against
// `str::len()`, i.e. they are measured in UTF-8 bytes.

/// Total prompt budget used for the local TinyLlama backend.
const LOCAL_CONTEXT_CHARS: usize = 32_768;
/// Upper bound on the (trimmed) user question length.
const MAX_QUESTION_CHARS: usize = 512;
/// Maximum tokens requested from the local TinyLlama backend.
const LOCAL_MAX_OUTPUT_TOKENS: usize = 256;
/// Maximum tokens requested from the Ollama/OpenAI/Gemini/Claude backends.
const REMOTE_MAX_OUTPUT_TOKENS: usize = 768;
/// Default system instructions; embedded in the completion prompt and sent as
/// the system message when the caller supplies no override.
const SYSTEM_PROMPT: &str = "You are a helpful assistant that answers questions based on the provided context.\n\nGuidelines:\n1. Read all the provided context carefully before answering.\n2. PAY CLOSE ATTENTION TO DATES: When context has timestamps, note that later dates reflect current/updated values.\n3. For temporal questions (e.g., \"How many X when started vs now?\"), find both earlier and later values. The most recent reflects the current state.\n4. For counting questions (e.g., \"How many times...\"), count occurrences in the context.\n5. Only say \"not enough information\" if the context truly contains no relevant information.\n6. Base all answers on the context - do not use external knowledge.";
/// Label reported for the TinyLlama backend.
const TINYLLAMA_LABEL: &str = "tinyllama-1.1b";
/// Portion of the TinyLlama budget reserved for non-context prompt parts.
const LOCAL_PROMPT_MARGIN_CHARS: usize = 2_048;
/// Portion of each remote budget reserved for non-context prompt parts.
const REMOTE_PROMPT_MARGIN_CHARS: usize = 4_096;
/// Total prompt budget for Ollama models.
const OLLAMA_PROMPT_CHARS: usize = 110_000;
/// Total prompt budget for OpenAI models.
const OPENAI_PROMPT_CHARS: usize = 240_000;
/// Total prompt budget for Gemini models.
const GEMINI_PROMPT_CHARS: usize = 320_000;
/// Total prompt budget for Claude models.
const CLAUDE_PROMPT_CHARS: usize = 360_000;
119
120#[derive(Debug, Clone, Copy)]
121struct ModelContextBudget {
122 total_chars: usize,
123 reserved_chars: usize,
124}
125
126impl ModelContextBudget {
127 const fn new(total_chars: usize, reserved_chars: usize) -> Self {
128 Self {
129 total_chars,
130 reserved_chars,
131 }
132 }
133
134 fn context_chars(&self) -> usize {
135 self.total_chars.saturating_sub(self.reserved_chars)
136 }
137
138 fn question_limit(&self) -> usize {
139 MAX_QUESTION_CHARS
140 .min(self.reserved_chars.max(1))
141 .min(self.total_chars.max(1))
142 }
143
144 fn apply_override(self, override_context_chars: usize) -> Self {
145 let total = override_context_chars.saturating_add(self.reserved_chars);
146 Self {
147 total_chars: total.max(self.reserved_chars + 1),
148 reserved_chars: self.reserved_chars,
149 }
150 }
151
152 fn prompt_ceiling(&self) -> usize {
153 self.total_chars
154 }
155}
156
/// Prompt material prepared once and handed to whichever backend runs the model.
pub struct PromptParts {
    completion_prompt: String,
    user_message: String,
    max_output_tokens: usize,
}

impl PromptParts {
    /// Single-string prompt (system, context, question, answer stub) for
    /// completion-style backends.
    pub fn completion_prompt(&self) -> &str {
        self.completion_prompt.as_str()
    }

    /// User-turn message (context plus question) for chat-style backends.
    pub fn user_message(&self) -> &str {
        self.user_message.as_str()
    }

    /// Maximum number of tokens the backend should generate.
    pub fn max_output_tokens(&self) -> usize {
        let Self {
            max_output_tokens, ..
        } = self;
        *max_output_tokens
    }
}
176
177fn build_prompt_parts(
178 question: &str,
179 context: &str,
180 budget: &ModelContextBudget,
181 max_output_tokens: usize,
182) -> PromptParts {
183 let mut context_section = context.to_string();
184 let trimmed_question = trim_to(question, budget.question_limit());
185
186 let system_section = format!("### System\n{SYSTEM_PROMPT}");
187 let question_section = format!("### Question\n{trimmed_question}");
188 let answer_stub = "### Answer\n";
189
190 let overhead = system_section.len() + 2 + question_section.len() + 2 + answer_stub.len();
191 if budget.prompt_ceiling() > overhead {
192 let max_context_len = budget
193 .prompt_ceiling()
194 .saturating_sub(overhead)
195 .min(budget.context_chars());
196 if context_section.len() > max_context_len {
197 context_section = clamp_to(&context_section, max_context_len);
198 }
199 } else {
200 context_section = String::new();
201 }
202
203 let completion_prompt =
204 format!("{system_section}\n\n{context_section}\n\n{question_section}\n\n### Answer\n");
205
206 let user_message = format!(
207 "{context_section}\n\nQuestion:\n{trimmed_question}\n\nRespond concisely using only information from the retrieval context."
208 );
209
210 PromptParts {
211 completion_prompt,
212 user_message,
213 max_output_tokens,
214 }
215}
216
/// Truncates `text` to at most `limit` bytes and appends `...` when anything was cut.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so multi-byte
/// characters are never split (a raw byte slice there would panic). Because the
/// ellipsis is appended after the cut, the result can be up to `limit + 3` bytes.
fn trim_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    let end = char_boundary_at_or_before(text, limit);
    let mut truncated = text[..end].to_string();
    truncated.push_str("...");
    truncated
}

/// Truncates `text` so the result, including a trailing `...`, fits in `limit` bytes.
///
/// For `limit <= 3` only (part of) the ellipsis is returned. Cuts are snapped back
/// to a UTF-8 character boundary, so the result may be slightly shorter than `limit`.
fn clamp_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        text.to_string()
    } else if limit <= 3 {
        "...".chars().take(limit).collect()
    } else {
        let end = char_boundary_at_or_before(text, limit - 3);
        let mut truncated = text[..end].to_string();
        truncated.push_str("...");
        truncated
    }
}

/// Returns the largest byte index `<= index` that is a char boundary of `text`.
///
/// A UTF-8 scalar is at most 4 bytes, so this inspects at most 4 positions.
fn char_boundary_at_or_before(text: &str, index: usize) -> usize {
    if index >= text.len() {
        return text.len();
    }
    (0..=index)
        .rev()
        .find(|&i| text.is_char_boundary(i))
        .unwrap_or(0)
}
239
/// Animated "Thinking..." indicator drawn on stderr while a model call runs.
///
/// The animation only runs when stderr is a terminal. Otherwise no thread is
/// spawned and `stop` is a no-op. The indicator is stopped by `stop` or on drop.
struct ThinkingSpinner {
    flag: Arc<AtomicBool>,
    handle: Option<thread::JoinHandle<()>>,
}

impl ThinkingSpinner {
    /// Delay between animation frames.
    const FRAME_INTERVAL: Duration = Duration::from_millis(200);

    /// Animation frames, drawn in a loop.
    const FRAMES: [&'static str; 6] = [
        "Thinking ",
        "Thinking. ",
        "Thinking.. ",
        "Thinking... ",
        "Thinking .. ",
        "Thinking . ",
    ];

    /// Starts the animation on a background thread if stderr is a terminal.
    fn start() -> Self {
        let flag = Arc::new(AtomicBool::new(true));

        // Nothing is drawn for non-terminal stderr, so don't spawn a thread
        // that would only sleep until it is told to stop.
        if !stderr().is_terminal() {
            return Self { flag, handle: None };
        }

        let thread_flag = Arc::clone(&flag);
        let handle = thread::spawn(move || Self::animate(&thread_flag));

        Self {
            flag,
            handle: Some(handle),
        }
    }

    /// Draws frames until `running` becomes false, then erases the line.
    fn animate(running: &AtomicBool) {
        let width = Self::FRAMES
            .iter()
            .map(|frame| frame.chars().count())
            .max()
            .unwrap_or(0);
        let mut err = stderr();

        for frame in Self::FRAMES.iter().cycle() {
            if !running.load(Ordering::Acquire) {
                break;
            }
            let _ = write!(err, "\r{frame}");
            let _ = err.flush();
            // `stop` unparks this thread, so shutdown does not wait a full interval.
            thread::park_timeout(Self::FRAME_INTERVAL);
        }

        // Overwrite the widest frame with spaces so no animation residue remains.
        let _ = write!(err, "\r{:width$}\r", "");
        let _ = err.flush();
    }

    /// Stops the animation and waits for the drawing thread to finish. Idempotent.
    fn stop(&mut self) {
        self.flag.store(false, Ordering::Release);
        if let Some(handle) = self.handle.take() {
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}

impl Drop for ThinkingSpinner {
    fn drop(&mut self) {
        self.stop();
    }
}
304
305#[derive(Debug, Clone)]
306enum ModelKind {
307 TinyLlama,
308 Ollama { model: String },
309 OpenAi { model: String },
310 Gemini { model: String },
311 Claude { model: String },
312}
313
314impl ModelKind {
315 fn parse(raw: &str) -> Option<Self> {
316 let trimmed = raw.trim();
317 if trimmed.is_empty() {
318 return None;
319 }
320
321 let (provider, explicit_model) = if let Some((p, rest)) = trimmed.split_once(':') {
322 let value = rest.trim();
323 let explicit = if value.is_empty() {
324 None
325 } else {
326 Some(value.to_string())
327 };
328 (p.trim().to_ascii_lowercase(), explicit)
329 } else {
330 (trimmed.to_ascii_lowercase(), None)
331 };
332
333 match provider.as_str() {
334 "tinyllama" | "tiny-llama" | "tinyllama-1.1b" => Some(Self::TinyLlama),
335 "ollama" => Some(Self::Ollama {
336 model: explicit_model.unwrap_or_else(|| "ollama1.5".to_string()),
337 }),
338 "ollama1.5" | "ollama1-5" => Some(Self::Ollama {
339 model: "ollama1.5".to_string(),
340 }),
341 "openai" => Some(Self::OpenAi {
342 model: normalize_openai_model(explicit_model),
343 }),
344 "gemini" => Some(Self::Gemini {
345 model: normalize_gemini_model(explicit_model),
346 }),
347 "claude" | "anthropic" => Some(Self::Claude {
348 model: normalize_claude_model(explicit_model),
349 }),
350 _ => Self::infer_from_model_name_full(trimmed, &provider),
354 }
355 }
356
357 fn infer_from_model_name_full(full_name: &str, prefix: &str) -> Option<Self> {
360 let lowered = prefix.to_ascii_lowercase();
361
362 if lowered.starts_with("gemini") || lowered.starts_with("models/gemini") {
364 return Some(Self::Gemini {
365 model: full_name.to_string(),
366 });
367 }
368
369 if lowered.starts_with("gpt-")
371 || lowered.starts_with("o1-")
372 || lowered.starts_with("o3-")
373 || lowered.starts_with("chatgpt-")
374 || lowered.starts_with("text-")
375 {
376 return Some(Self::OpenAi {
377 model: full_name.to_string(),
378 });
379 }
380
381 if lowered.starts_with("claude-") {
383 return Some(Self::Claude {
384 model: full_name.to_string(),
385 });
386 }
387
388 if lowered.starts_with("llama")
391 || lowered.starts_with("mistral")
392 || lowered.starts_with("phi")
393 || lowered.starts_with("codellama")
394 || lowered.starts_with("deepseek")
395 || lowered.starts_with("qwen")
396 || lowered.starts_with("gemma")
397 {
398 return Some(Self::Ollama {
399 model: full_name.to_string(),
400 });
401 }
402
403 None
404 }
405
406 fn label(&self) -> String {
407 match self {
408 Self::TinyLlama => TINYLLAMA_LABEL.to_string(),
409 Self::Ollama { model } => format!("ollama:{model}"),
410 Self::OpenAi { model } => format!("openai:{model}"),
411 Self::Gemini { model } => format!("gemini:{model}"),
412 Self::Claude { model } => format!("claude:{model}"),
413 }
414 }
415
416 fn context_budget(&self) -> ModelContextBudget {
417 match self {
418 Self::TinyLlama => {
419 ModelContextBudget::new(LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS)
420 }
421 Self::Ollama { .. } => {
422 ModelContextBudget::new(OLLAMA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
423 }
424 Self::OpenAi { .. } => {
425 ModelContextBudget::new(OPENAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
426 }
427 Self::Gemini { .. } => {
428 ModelContextBudget::new(GEMINI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
429 }
430 Self::Claude { .. } => {
431 ModelContextBudget::new(CLAUDE_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS)
432 }
433 }
434 }
435
436 fn max_output_tokens(&self) -> usize {
437 match self {
438 Self::TinyLlama => LOCAL_MAX_OUTPUT_TOKENS,
439 Self::Ollama { .. }
440 | Self::OpenAi { .. }
441 | Self::Gemini { .. }
442 | Self::Claude { .. } => REMOTE_MAX_OUTPUT_TOKENS,
443 }
444 }
445}
446
/// Returns the explicit OpenAI model id, or `gpt-4o-mini` when none (or a blank one) was given.
/// A non-blank id is returned verbatim, without trimming.
fn normalize_openai_model(explicit: Option<String>) -> String {
    explicit
        .filter(|candidate| !candidate.trim().is_empty())
        .unwrap_or_else(|| String::from("gpt-4o-mini"))
}
453
/// Returns the explicit Gemini model id, or `gemini-2.5-flash` when none (or a
/// blank one) was given. A non-blank id is passed through verbatim.
///
/// The previous version matched a list of known ids, but every arm returned the
/// input unchanged. That dead match is removed, the default is only allocated when
/// needed, and blank ids are handled like `normalize_openai_model` handles them.
fn normalize_gemini_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "gemini-2.5-flash".to_string())
}
466
/// Returns the explicit Claude model id, or `claude-3-5-sonnet-20241022` when
/// none (or a blank one) was given. A non-blank id is passed through verbatim.
///
/// The default is only allocated when it is actually used. Blank ids fall back
/// to the default, matching `normalize_openai_model`.
fn normalize_claude_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string())
}
475
476pub fn run_model_inference(
477 requested_model: &str,
478 question: &str,
479 fallback_context: &str,
480 hits: &[SearchHit],
481 context_override: Option<usize>,
482 api_key: Option<&str>,
483 system_prompt_override: Option<&str>,
484) -> Result<ModelInference, ModelRunError> {
485 let Some(model_kind) = ModelKind::parse(requested_model) else {
486 return Err(ModelRunError::UnsupportedModel(requested_model.to_string()));
487 };
488
489 let mut budget = model_kind.context_budget();
490 if let Some(override_chars) = context_override {
491 budget = budget.apply_override(override_chars);
492 }
493
494 let context_plan = context::assemble_context(hits, fallback_context, &budget);
495 let prompt = build_prompt_parts(
496 question,
497 &context_plan.body,
498 &budget,
499 model_kind.max_output_tokens(),
500 );
501
502 let answer = match &model_kind {
503 ModelKind::TinyLlama => {
504 #[cfg(feature = "llama-cpp")]
505 {
506 tinyllama::run(&prompt)?
507 }
508 #[cfg(not(feature = "llama-cpp"))]
509 {
510 return Err(ModelRunError::UnsupportedModel(
511 "tinyllama (llama-cpp feature not enabled)".to_string(),
512 ));
513 }
514 }
515 ModelKind::Ollama { model } => ollama::run(model, &prompt)?,
516 ModelKind::OpenAi { model } => {
517 openai::run(model, &prompt, api_key, system_prompt_override)?
518 }
519 ModelKind::Gemini { model } => {
520 gemini::run(model, &prompt, api_key, system_prompt_override)?
521 }
522 ModelKind::Claude { model } => {
523 claude::run(model, &prompt, api_key, system_prompt_override)?
524 }
525 };
526
527 let context::ContextAggregation {
528 body: context_body,
529 records,
530 } = context_plan;
531 let context_fragments = records
532 .into_iter()
533 .map(ModelContextFragment::from_record)
534 .collect();
535
536 Ok(ModelInference {
537 answer: ModelAnswer {
538 requested: requested_model.to_string(),
539 model: model_kind.label(),
540 answer,
541 },
542 context_body,
543 context_fragments,
544 })
545}
546
547mod context {
548 use super::{ModelContextBudget, clamp_to};
549 use memvid_core::types::SearchHit;
550
551 const CONTEXT_HEADER: &str = "## Retrieval Context\n";
552 const PRIMARY_HEADER: &str = "### Primary Hit\n";
553 const SUPPORT_HEADER: &str = "### Supporting Hits\n";
554 const SUMMARY_HEADER: &str = "### Overflow Summaries\n";
555 const SUMMARY_HIGHLIGHT_CHARS: usize = 240;
556
557 #[derive(Debug, Clone)]
558 pub(super) struct ContextAggregation {
559 pub body: String,
560 pub records: Vec<ContextRecord>,
561 }
562
563 impl ContextAggregation {
564 fn from_fallback(fallback: &str, limit: usize) -> Self {
565 let body = if limit == 0 || fallback.is_empty() {
566 String::new()
567 } else if fallback.len() <= limit {
568 fallback.to_string()
569 } else {
570 clamp_to(fallback, limit)
571 };
572 Self {
573 body,
574 records: Vec::new(),
575 }
576 }
577 }
578
579 #[derive(Debug, Clone)]
580 pub(super) struct ContextRecord {
581 pub rank: usize,
582 pub uri: String,
583 pub title: Option<String>,
584 pub score: Option<f32>,
585 pub matches: usize,
586 pub frame_id: u64,
587 pub range: (usize, usize),
588 pub chunk_range: Option<(usize, usize)>,
589 pub text: String,
590 pub mode: ContextMode,
591 }
592
593 #[derive(Debug, Clone, Copy, Eq, PartialEq)]
594 pub(super) enum ContextMode {
595 Full,
596 Summary,
597 }
598
599 #[derive(Debug, Clone)]
600 pub(super) struct ContextAssemblyPlan {
601 primary: Option<ContextRecord>,
602 supporting: Vec<ContextRecord>,
603 summaries: Vec<ContextRecord>,
604 }
605
606 pub(super) fn assemble_context(
607 hits: &[SearchHit],
608 fallback: &str,
609 budget: &ModelContextBudget,
610 ) -> ContextAggregation {
611 if hits.is_empty() {
612 return ContextAggregation::from_fallback(fallback, budget.context_chars());
613 }
614
615 let plan = assemble_plan(hits, budget.context_chars());
616 let mut body = String::new();
617 let mut records = Vec::new();
618
619 body.push_str(CONTEXT_HEADER);
620 if let Some(primary) = plan.primary {
621 body.push_str(PRIMARY_HEADER);
622 body.push_str(&primary.text);
623 body.push_str("\n\n");
624 records.push(primary);
625 }
626
627 if !plan.supporting.is_empty() {
628 body.push_str(SUPPORT_HEADER);
629 for record in plan.supporting {
630 body.push_str(&record.text);
631 body.push_str("\n\n");
632 records.push(record);
633 }
634 }
635
636 if !plan.summaries.is_empty() {
637 body.push_str(SUMMARY_HEADER);
638 for record in plan.summaries {
639 body.push_str(&record.text);
640 body.push_str("\n\n");
641 records.push(record);
642 }
643 }
644
645 ContextAggregation { body, records }
646 }
647
648 fn assemble_plan(hits: &[SearchHit], mut remaining_chars: usize) -> ContextAssemblyPlan {
649 let mut records = Vec::new();
650 for hit in hits.iter().take(32) {
651 let full_record = build_record(hit, render_full(hit), ContextMode::Full);
652 let summary_record = build_record(hit, render_summary(hit), ContextMode::Summary);
653 records.push((full_record, summary_record));
654 }
655
656 let mut plan = ContextAssemblyPlan {
657 primary: None,
658 supporting: Vec::new(),
659 summaries: Vec::new(),
660 };
661
662 if let Some((primary_full, _)) = records.first() {
663 if primary_full.text.len() <= remaining_chars {
664 remaining_chars = remaining_chars.saturating_sub(primary_full.text.len());
665 plan.primary = Some(primary_full.clone());
666 }
667 }
668
669 for (idx, (full, summary)) in records.iter().enumerate() {
670 if idx == 0 {
671 continue;
672 }
673
674 if full.text.len() <= remaining_chars {
675 remaining_chars = remaining_chars.saturating_sub(full.text.len());
676 plan.supporting.push(full.clone());
677 } else if summary.text.len() <= remaining_chars {
678 remaining_chars = remaining_chars.saturating_sub(summary.text.len());
679 plan.summaries.push(summary.clone());
680 }
681 }
682
683 plan
684 }
685
686 fn render_full(hit: &SearchHit) -> String {
687 format!(
688 "Rank: {}\nURI: {}\nTitle: {}\nMatches: {}\nScore: {:.3}\nSnippet:\n{}",
689 hit.rank,
690 hit.uri,
691 hit.title
692 .clone()
693 .unwrap_or_else(|| "(untitled)".to_string()),
694 hit.matches,
695 hit.score.unwrap_or_default(),
696 hit.chunk_text
697 .clone()
698 .or_else(|| Some(hit.text.clone()))
699 .unwrap_or_default()
700 )
701 }
702
703 fn render_summary(hit: &SearchHit) -> String {
704 let snippet = hit
705 .chunk_text
706 .clone()
707 .or_else(|| Some(hit.text.clone()))
708 .unwrap_or_default();
709 let snippet = trim_highlight(&snippet, SUMMARY_HIGHLIGHT_CHARS);
710 format!(
711 "Rank: {}\nURI: {}\nHighlight: {}",
712 hit.rank, hit.uri, snippet
713 )
714 }
715
716 fn trim_highlight(text: &str, limit: usize) -> String {
717 let clean = text.replace('\n', " ");
718 clamp_to(&clean, limit)
719 }
720
721 fn build_record(hit: &SearchHit, text: String, mode: ContextMode) -> ContextRecord {
722 ContextRecord {
723 rank: hit.rank,
724 uri: hit.uri.clone(),
725 title: hit.title.clone(),
726 score: hit.score,
727 matches: hit.matches,
728 frame_id: hit.frame_id,
729 range: hit.range,
730 chunk_range: hit.chunk_range,
731 text,
732 mode,
733 }
734 }
735}
736
/// Local TinyLlama backend, run in-process through llama.cpp.
#[cfg(feature = "llama-cpp")]
mod tinyllama {
    use super::{ModelRunError, PromptParts, TINYLLAMA_LABEL, ThinkingSpinner};
    use anyhow::anyhow;
    use llama_cpp::standard_sampler::StandardSampler;
    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
    use tokio::runtime::Builder;

    use std::path::{Path, PathBuf};

    /// Directory (relative to the working directory) searched for GGUF weights.
    const MODEL_DIR: &str = "models/tinyllama";
    /// Pattern reported in `AssetsMissing` when no GGUF file is found.
    const GGUF_HINT: &str = "*.gguf";

    /// Runs the completion-style prompt through a local TinyLlama model.
    ///
    /// Loads the first `.gguf` file (sorted by path) found in `MODEL_DIR` and
    /// primes a session with the tokenized `completion_prompt`. It then generates
    /// up to `max_output_tokens()` tokens.
    ///
    /// # Errors
    /// * `ModelRunError::AssetsMissing` when no GGUF file exists in `MODEL_DIR`.
    /// * `ModelRunError::Runtime` for load, session, tokenize, prime, completion
    ///   or tokio-runtime failures.
    pub(super) fn run(prompt: &PromptParts) -> Result<String, ModelRunError> {
        let base_dir = Path::new(MODEL_DIR);
        let assets = RequiredAssets::new(base_dir);

        if let Some(missing) = assets.missing_paths() {
            return Err(ModelRunError::AssetsMissing {
                model: TINYLLAMA_LABEL.to_string(),
                missing,
            });
        }

        // Defensive: `missing_paths` returned None, so `gguf_path` should be Some here.
        let gguf_path = assets.gguf_path.clone().ok_or_else(|| {
            ModelRunError::Runtime(anyhow!(
                "no GGUF model file found in {}",
                base_dir.display()
            ))
        })?;

        // SAFETY: `set_var` is unsafe because concurrent reads or writes of the
        // process environment from other threads are undefined behaviour. The
        // spinner thread is only started later in this function.
        // NOTE(review): this assumes no other thread in the process touches the
        // environment at this point — confirm for the surrounding application.
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
        }

        let model =
            LlamaModel::load_from_file(&gguf_path, LlamaParams::default()).map_err(|err| {
                ModelRunError::Runtime(anyhow!(
                    "failed to load TinyLlama weights from {}: {err}",
                    gguf_path.display()
                ))
            })?;

        // Fall back to a 2048-token context window and 512-sized batches when
        // the library defaults leave these unset (0).
        let mut session_params = SessionParams::default();
        if session_params.n_ctx == 0 {
            session_params.n_ctx = 2048;
        }
        session_params.n_batch = session_params.n_ctx.min(512);
        if session_params.n_ubatch == 0 {
            session_params.n_ubatch = 512;
        }
        let max_tokens = session_params.n_ctx as usize;
        let mut session = model.create_session(session_params).map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to create TinyLlama session: {err}"))
        })?;

        let mut priming_tokens = model
            .tokenize_bytes(prompt.completion_prompt().as_bytes(), true, true)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to tokenize TinyLlama prompt: {err}"))
            })?;

        // Leave room in the context window for the requested output plus a
        // 64-token slack. When the prompt is too long, keep only its TAIL: the
        // beginning of the prompt (the system section) is what gets dropped.
        let requested_tokens = prompt.max_output_tokens();
        if max_tokens > 0 {
            let reserved = requested_tokens + 64;
            if priming_tokens.len() >= max_tokens.saturating_sub(reserved) {
                let target = max_tokens.saturating_sub(reserved).max(1);
                let tail_start = priming_tokens.len().saturating_sub(target);
                priming_tokens = priming_tokens.split_off(tail_start);
            }
        }

        session
            .advance_context_with_tokens(&priming_tokens)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to prime TinyLlama context: {err}"))
            })?;

        let handle = session
            .start_completing_with(StandardSampler::default(), requested_tokens)
            .map_err(|err| ModelRunError::Runtime(anyhow!("completion failed to start: {err}")))?;

        // A single-threaded tokio runtime is built only to drive the async
        // completion stream to completion from this synchronous function.
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to build tokio runtime: {err}"))
            })?;

        let mut spinner = ThinkingSpinner::start();
        let generated = runtime.block_on(async { handle.into_string_async().await });
        spinner.stop();

        let answer = generated.trim().to_string();

        if answer.is_empty() {
            Ok("No answer generated by TinyLlama.".to_string())
        } else {
            Ok(answer)
        }
    }

    /// Result of scanning the model directory for weights.
    struct RequiredAssets {
        /// First `.gguf` file found (sorted by path), if any.
        gguf_path: Option<PathBuf>,
        /// Directory that was scanned.
        base_dir: PathBuf,
    }

    impl RequiredAssets {
        /// Scans `base_dir` for GGUF weights.
        fn new(base_dir: &Path) -> Self {
            let gguf_path = find_first_gguf(base_dir);
            Self {
                gguf_path,
                base_dir: base_dir.to_path_buf(),
            }
        }

        /// `None` when weights were found. Otherwise a single hint path
        /// (`<base_dir>/*.gguf`) describing what is missing.
        fn missing_paths(&self) -> Option<Vec<PathBuf>> {
            if self.gguf_path.is_some() {
                None
            } else {
                Some(vec![self.base_dir.join(GGUF_HINT)])
            }
        }
    }

    /// Returns the lexicographically first regular file with a `.gguf` extension
    /// in `base_dir`. Returns `None` if the directory is unreadable or has no such file.
    fn find_first_gguf(base_dir: &Path) -> Option<PathBuf> {
        let mut entries: Vec<PathBuf> = std::fs::read_dir(base_dir)
            .ok()?
            .filter_map(|entry| entry.ok().map(|e| e.path()))
            .filter(|path| path.is_file() && path.extension().map_or(false, |ext| ext == "gguf"))
            .collect();
        entries.sort();
        entries.into_iter().next()
    }
}
873
874mod ollama {
875 use super::{ModelRunError, PromptParts, ThinkingSpinner};
876 use anyhow::anyhow;
877 use reqwest::blocking::Client;
878 use serde::Deserialize;
879 use serde_json::json;
880
881 const ENDPOINT: &str = "http://127.0.0.1:11434/api/generate";
882
883 pub(super) fn run(model: &str, prompt: &PromptParts) -> Result<String, ModelRunError> {
884 let client = Client::builder()
885 .timeout(std::time::Duration::from_secs(60))
886 .build()
887 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
888
889 let mut spinner = ThinkingSpinner::start();
890 let response = client
891 .post(ENDPOINT)
892 .json(&json!({
893 "model": model,
894 "prompt": prompt.completion_prompt(),
895 "stream": false
896 }))
897 .send()
898 .map_err(|err| ModelRunError::Runtime(anyhow!("ollama request failed: {err}")))?
899 .error_for_status()
900 .map_err(|err| {
901 ModelRunError::Runtime(anyhow!("ollama returned error status: {err}"))
902 })?;
903
904 let body: GenerateResponse = response.json().map_err(|err| {
905 ModelRunError::Runtime(anyhow!("failed to decode ollama response: {err}"))
906 })?;
907 spinner.stop();
908
909 let text = body.response.trim().to_string();
910 if text.is_empty() {
911 Ok("No answer returned by Ollama.".to_string())
912 } else {
913 Ok(text)
914 }
915 }
916
917 #[derive(Debug, Deserialize)]
918 struct GenerateResponse {
919 #[serde(default)]
920 response: String,
921 }
922}
923
924mod openai {
925 use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
926 use anyhow::anyhow;
927 use reqwest::blocking::Client;
928 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
929 use serde::Deserialize;
930 use serde_json::json;
931
932 const CHAT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
933 const RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses";
934
935 pub(super) fn run(
936 model: &str,
937 prompt: &PromptParts,
938 override_key: Option<&str>,
939 system_prompt_override: Option<&str>,
940 ) -> Result<String, ModelRunError> {
941 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
942 let key = override_key
943 .map(|value| value.to_string())
944 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
945 .ok_or_else(|| {
946 ModelRunError::Runtime(anyhow!(
947 "OPENAI_API_KEY environment variable is required for OpenAI models"
948 ))
949 })?;
950
951 let mut headers = HeaderMap::new();
952 headers.insert(
953 AUTHORIZATION,
954 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
955 ModelRunError::Runtime(anyhow!("invalid OPENAI_API_KEY header value: {err}"))
956 })?,
957 );
958 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
959
960 let client = Client::builder()
961 .timeout(std::time::Duration::from_secs(60))
962 .default_headers(headers)
963 .build()
964 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
965
966 let mut spinner = ThinkingSpinner::start();
967 let text = if requires_responses_api(model) {
968 let combined_prompt = format!(
969 "System instructions:\n{}\n\nUser query:\n{}",
970 system_prompt,
971 prompt.user_message()
972 );
973 let payload = json!({
974 "model": model,
975 "input": combined_prompt,
976 "max_output_tokens": prompt.max_output_tokens() as u32,
977 "reasoning": {
978 "effort": "low"
979 }
980 });
981
982 let response = client
983 .post(RESPONSES_ENDPOINT)
984 .json(&payload)
985 .send()
986 .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
987
988 let status = response.status();
989 if !status.is_success() {
990 let body = response
991 .text()
992 .unwrap_or_else(|_| "<failed to read body>".to_string());
993 return Err(ModelRunError::Runtime(anyhow!(
994 "OpenAI returned error status {status}: {body}"
995 )));
996 }
997
998 let body: ResponsesResponse = response.json().map_err(|err| {
999 ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
1000 })?;
1001
1002 extract_responses_text(body)
1003 } else {
1004 let payload = json!({
1005 "model": model,
1006 "messages": [
1007 {"role": "system", "content": system_prompt},
1008 {"role": "user", "content": prompt.user_message()}
1009 ],
1010 "temperature": 0.2,
1011 "max_tokens": prompt.max_output_tokens() as u32
1012 });
1013
1014 let response = client
1015 .post(CHAT_ENDPOINT)
1016 .json(&payload)
1017 .send()
1018 .map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
1019
1020 let status = response.status();
1021 if !status.is_success() {
1022 let body = response
1023 .text()
1024 .unwrap_or_else(|_| "<failed to read body>".to_string());
1025 return Err(ModelRunError::Runtime(anyhow!(
1026 "OpenAI returned error status {status}: {body}"
1027 )));
1028 }
1029
1030 let body: ChatResponse = response.json().map_err(|err| {
1031 ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
1032 })?;
1033
1034 extract_chat_text(body)
1035 };
1036 spinner.stop();
1037 Ok(text)
1038 }
1039
1040 #[derive(Debug, Deserialize)]
1041 struct ChatResponse {
1042 choices: Vec<Choice>,
1043 }
1044
1045 #[derive(Debug, Deserialize)]
1046 struct Choice {
1047 message: ChatMessage,
1048 }
1049
1050 #[derive(Debug, Deserialize)]
1051 struct ChatMessage {
1052 #[serde(default)]
1053 content: Option<String>,
1054 }
1055
1056 #[derive(Debug, Deserialize)]
1057 struct ResponsesResponse {
1058 #[serde(default)]
1059 output: Vec<ResponseItem>,
1060 #[serde(default)]
1061 output_text: Vec<String>,
1062 }
1063
1064 #[derive(Debug, Deserialize)]
1065 struct ResponseItem {
1066 #[serde(default)]
1067 content: Vec<ResponseContent>,
1068 }
1069
1070 #[derive(Debug, Deserialize)]
1071 struct ResponseContent {
1072 #[serde(rename = "type")]
1073 kind: String,
1074 #[serde(default)]
1075 text: Option<String>,
1076 }
1077
1078 fn extract_chat_text(body: ChatResponse) -> String {
1079 body.choices
1080 .into_iter()
1081 .find_map(|choice| choice.message.content)
1082 .map(|value| value.trim().to_string())
1083 .unwrap_or_else(|| "No answer returned by OpenAI.".to_string())
1084 }
1085
1086 fn extract_responses_text(body: ResponsesResponse) -> String {
1087 if !body.output_text.is_empty() {
1088 let text = body
1089 .output_text
1090 .into_iter()
1091 .find(|value| !value.trim().is_empty());
1092 if let Some(text) = text {
1093 return text.trim().to_string();
1094 }
1095 }
1096 for item in body.output {
1097 for segment in item.content {
1098 match segment.kind.as_str() {
1099 "output_text" | "text" => {
1100 if let Some(text) = segment.text {
1101 let trimmed = text.trim();
1102 if !trimmed.is_empty() {
1103 return trimmed.to_string();
1104 }
1105 }
1106 }
1107 _ => {}
1108 }
1109 }
1110 }
1111 "No answer returned by OpenAI.".to_string()
1112 }
1113
1114 fn requires_responses_api(model: &str) -> bool {
1115 let lowered = model.to_ascii_lowercase();
1116 lowered.starts_with("gpt-5") || lowered.contains("gpt-4.1")
1117 }
1118}
1119
1120mod gemini {
1121 use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
1122 use anyhow::anyhow;
1123 use reqwest::blocking::Client;
1124 use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
1125 use serde::Deserialize;
1126 use serde_json::json;
1127
1128 pub(super) fn run(
1129 model: &str,
1130 prompt: &PromptParts,
1131 override_key: Option<&str>,
1132 system_prompt_override: Option<&str>,
1133 ) -> Result<String, ModelRunError> {
1134 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
1135 let key = override_key
1136 .map(|value| value.to_string())
1137 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
1138 .ok_or_else(|| {
1139 ModelRunError::Runtime(anyhow!(
1140 "GEMINI_API_KEY environment variable is required for Gemini models"
1141 ))
1142 })?;
1143
1144 let url = format!(
1145 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
1146 model
1147 );
1148
1149 let mut headers = HeaderMap::new();
1150 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
1151 headers.insert(
1152 HeaderName::from_static("x-goog-api-key"),
1153 HeaderValue::from_str(&key).map_err(|err| {
1154 ModelRunError::Runtime(anyhow!("invalid GEMINI_API_KEY header value: {err}"))
1155 })?,
1156 );
1157
1158 let client = Client::builder()
1159 .timeout(std::time::Duration::from_secs(60))
1160 .default_headers(headers)
1161 .build()
1162 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
1163
1164 let payload = json!({
1165 "contents": [{
1166 "parts": [
1167 { "text": system_prompt },
1168 { "text": prompt.user_message() }
1169 ]
1170 }],
1171 "generationConfig": {
1172 "temperature": 0.2,
1173 "maxOutputTokens": prompt.max_output_tokens() as u32,
1174 "topK": 40,
1175 "topP": 0.95
1176 }
1177 });
1178
1179 let mut spinner = ThinkingSpinner::start();
1180 let response = client
1181 .post(url)
1182 .json(&payload)
1183 .send()
1184 .map_err(|err| ModelRunError::Runtime(anyhow!("Gemini request failed: {err}")))?
1185 .error_for_status()
1186 .map_err(|err| {
1187 ModelRunError::Runtime(anyhow!("Gemini returned error status: {err}"))
1188 })?;
1189
1190 let body: GenerateResponse = response.json().map_err(|err| {
1191 ModelRunError::Runtime(anyhow!("failed to decode Gemini response: {err}"))
1192 })?;
1193 spinner.stop();
1194
1195 let text = body
1196 .candidates
1197 .into_iter()
1198 .flat_map(|candidate| candidate.content.parts)
1199 .find_map(|part| part.text)
1200 .map(|value| value.trim().to_string())
1201 .unwrap_or_else(|| "No answer returned by Gemini.".to_string());
1202
1203 Ok(text)
1204 }
1205
1206 #[derive(Debug, Deserialize)]
1207 struct GenerateResponse {
1208 candidates: Vec<Candidate>,
1209 }
1210
1211 #[derive(Debug, Deserialize)]
1212 struct Candidate {
1213 content: CandidateContent,
1214 }
1215
1216 #[derive(Debug, Deserialize)]
1217 struct CandidateContent {
1218 parts: Vec<CandidatePart>,
1219 }
1220
1221 #[derive(Debug, Deserialize)]
1222 struct CandidatePart {
1223 #[serde(default)]
1224 text: Option<String>,
1225 }
1226}
1227
1228mod claude {
1229 use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
1230 use anyhow::anyhow;
1231 use reqwest::blocking::Client;
1232 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
1233 use serde::Deserialize;
1234 use serde_json::json;
1235
1236 const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
1237 const API_VERSION: &str = "2023-06-01";
1238
1239 pub(super) fn run(
1240 model: &str,
1241 prompt: &PromptParts,
1242 override_key: Option<&str>,
1243 system_prompt_override: Option<&str>,
1244 ) -> Result<String, ModelRunError> {
1245 let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
1246 let key = override_key
1247 .map(|value| value.to_string())
1248 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
1249 .or_else(|| std::env::var("CLAUDE_API_KEY").ok())
1250 .ok_or_else(|| {
1251 ModelRunError::Runtime(anyhow!(
1252 "ANTHROPIC_API_KEY environment variable is required for Claude models"
1253 ))
1254 })?;
1255
1256 let mut headers = HeaderMap::new();
1257 headers.insert(
1258 AUTHORIZATION,
1259 HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
1260 ModelRunError::Runtime(anyhow!("invalid ANTHROPIC_API_KEY header value: {err}"))
1261 })?,
1262 );
1263 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
1264 headers.insert(
1265 HeaderName::from_static("anthropic-version"),
1266 HeaderValue::from_static(API_VERSION),
1267 );
1268
1269 let client = Client::builder()
1270 .timeout(std::time::Duration::from_secs(60))
1271 .default_headers(headers)
1272 .build()
1273 .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
1274
1275 let payload = json!({
1276 "model": model,
1277 "max_tokens": prompt.max_output_tokens() as u32,
1278 "temperature": 0.2,
1279 "system": system_prompt,
1280 "messages": [{
1281 "role": "user",
1282 "content": [{"type": "text", "text": prompt.user_message()}]
1283 }]
1284 });
1285
1286 let mut spinner = ThinkingSpinner::start();
1287 let response = client
1288 .post(ENDPOINT)
1289 .json(&payload)
1290 .send()
1291 .map_err(|err| ModelRunError::Runtime(anyhow!("Claude request failed: {err}")))?
1292 .error_for_status()
1293 .map_err(|err| {
1294 ModelRunError::Runtime(anyhow!("Claude returned error status: {err}"))
1295 })?;
1296
1297 let body: ClaudeResponse = response.json().map_err(|err| {
1298 ModelRunError::Runtime(anyhow!("failed to decode Claude response: {err}"))
1299 })?;
1300 spinner.stop();
1301
1302 let text = body
1303 .content
1304 .into_iter()
1305 .find_map(|part| match part {
1306 ContentBlock::Text { text } if !text.trim().is_empty() => {
1307 Some(text.trim().to_string())
1308 }
1309 _ => None,
1310 })
1311 .unwrap_or_else(|| "No answer returned by Claude.".to_string());
1312
1313 Ok(text)
1314 }
1315
1316 #[derive(Debug, Deserialize)]
1317 struct ClaudeResponse {
1318 #[serde(default)]
1319 content: Vec<ContentBlock>,
1320 }
1321
1322 #[derive(Debug, Deserialize)]
1323 #[serde(tag = "type", rename_all = "lowercase")]
1324 enum ContentBlock {
1325 Text {
1326 text: String,
1327 },
1328 #[serde(other)]
1329 Other,
1330 }
1331}
1332
/// Default system prompt used by [`extract_entities`] when the caller passes no override.
///
/// It asks the model for a JSON object of the form `{"entities": [...]}`, where each entry
/// carries `name`, `type` and `confidence`. This is the shape `parse_entity_response` tries
/// first.
pub const ENTITY_EXTRACTION_PROMPT: &str = r#"Extract named entities from the provided text. Return a JSON object with an "entities" array.

Each entity should have:
- "name": The entity name as it appears in the text
- "type": One of "PERSON", "ORG", "LOCATION", "DATE", "PRODUCT", "EVENT", or "OTHER"
- "confidence": A number between 0.0 and 1.0 indicating your confidence

Guidelines:
1. Only include entities you're confident about (confidence >= 0.7)
2. Preserve the original capitalization of entity names
3. For organizations, include full names (e.g., "S&P Global" not just "S&P")
4. For people, include full names when available
5. Deduplicate: if an entity appears multiple times, include it only once

Return format:
{"entities": [{"name": "...", "type": "...", "confidence": 0.9}, ...]}"#;
1354
/// One named entity reported by the model.
///
/// The JSON keys match those requested by [`ENTITY_EXTRACTION_PROMPT`]. `entity_type` is
/// (de)serialized under the key `"type"`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity text as the model reported it.
    pub name: String,
    /// Category label such as `PERSON` or `ORG`. It is not checked against the prompt's list.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Model-reported confidence. The prompt asks for 0.0–1.0, but the value is not validated.
    pub confidence: f32,
}
1363
/// Result returned by [`extract_entities`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EntityExtractionResponse {
    /// Entities parsed from the model's reply.
    pub entities: Vec<ExtractedEntity>,
    /// The model spec exactly as the caller passed it, including any `provider:` prefix.
    pub model: String,
    /// Length of the analysed input text, as computed by [`extract_entities`].
    pub text_chars: usize,
}
1371
1372pub fn extract_entities(
1396 model: &str,
1397 text: &str,
1398 system_prompt: Option<&str>,
1399 api_key: Option<&str>,
1400) -> Result<EntityExtractionResponse, ModelRunError> {
1401 let prompt = system_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
1402 let text_chars = text.len();
1403
1404 let (provider, model_name) = parse_model_spec(model);
1406
1407 let json_response = match provider.as_str() {
1408 "openai" => extract_entities_openai(&model_name, text, prompt, api_key)?,
1409 "claude" | "anthropic" => extract_entities_claude(&model_name, text, prompt, api_key)?,
1410 "gemini" | "google" => extract_entities_gemini(&model_name, text, prompt, api_key)?,
1411 _ => {
1412 return Err(ModelRunError::UnsupportedModel(format!(
1413 "Entity extraction not supported for provider '{}'. Use openai:, claude:, or gemini:",
1414 provider
1415 )));
1416 }
1417 };
1418
1419 let entities = parse_entity_response(&json_response)?;
1421
1422 Ok(EntityExtractionResponse {
1423 entities,
1424 model: model.to_string(),
1425 text_chars,
1426 })
1427}
1428
/// Splits a `provider:model` spec into `(lowercased provider, model name)`.
///
/// Only the first `:` separates the two parts, so the model name may itself contain colons.
/// A spec without any `:` is treated as an OpenAI model name.
fn parse_model_spec(model: &str) -> (String, String) {
    match model.split_once(':') {
        Some((provider, name)) => (provider.to_lowercase(), name.to_owned()),
        None => (String::from("openai"), model.to_owned()),
    }
}
1437
1438fn parse_entity_response(json_str: &str) -> Result<Vec<ExtractedEntity>, ModelRunError> {
1439 let trimmed = json_str.trim();
1441
1442 let clean_json = if trimmed.starts_with("```json") {
1444 trimmed
1445 .strip_prefix("```json")
1446 .and_then(|s| s.strip_suffix("```"))
1447 .unwrap_or(trimmed)
1448 .trim()
1449 } else if trimmed.starts_with("```") {
1450 trimmed
1451 .strip_prefix("```")
1452 .and_then(|s| s.strip_suffix("```"))
1453 .unwrap_or(trimmed)
1454 .trim()
1455 } else {
1456 trimmed
1457 };
1458
1459 #[derive(serde::Deserialize)]
1461 struct EntityResponse {
1462 entities: Vec<ExtractedEntity>,
1463 }
1464
1465 if let Ok(response) = serde_json::from_str::<EntityResponse>(clean_json) {
1466 return Ok(response.entities);
1467 }
1468
1469 if let Ok(entities) = serde_json::from_str::<Vec<ExtractedEntity>>(clean_json) {
1471 return Ok(entities);
1472 }
1473
1474 Err(ModelRunError::Runtime(anyhow::anyhow!(
1475 "Failed to parse entity extraction response as JSON: {}",
1476 &clean_json[..clean_json.len().min(200)]
1477 )))
1478}
1479
1480fn extract_entities_openai(
1481 model: &str,
1482 text: &str,
1483 system_prompt: &str,
1484 api_key: Option<&str>,
1485) -> Result<String, ModelRunError> {
1486 use serde_json::json;
1487
1488 let api_key = api_key
1489 .map(|s| s.to_string())
1490 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
1491 .ok_or_else(|| {
1492 ModelRunError::Runtime(anyhow::anyhow!(
1493 "OpenAI API key required. Set OPENAI_API_KEY or pass api_key parameter."
1494 ))
1495 })?;
1496
1497 let model_name = if model.is_empty() { "gpt-4o-mini" } else { model };
1498
1499 let client = reqwest::blocking::Client::new();
1500 let payload = json!({
1501 "model": model_name,
1502 "messages": [
1503 {"role": "system", "content": system_prompt},
1504 {"role": "user", "content": text}
1505 ],
1506 "response_format": {"type": "json_object"},
1507 "temperature": 0.1
1508 });
1509
1510 let response = client
1511 .post("https://api.openai.com/v1/chat/completions")
1512 .header("Authorization", format!("Bearer {}", api_key))
1513 .json(&payload)
1514 .send()
1515 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI request failed: {}", e)))?
1516 .error_for_status()
1517 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("OpenAI returned error: {}", e)))?;
1518
1519 #[derive(serde::Deserialize)]
1520 struct OpenAIResponse {
1521 choices: Vec<OpenAIChoice>,
1522 }
1523 #[derive(serde::Deserialize)]
1524 struct OpenAIChoice {
1525 message: OpenAIMessage,
1526 }
1527 #[derive(serde::Deserialize)]
1528 struct OpenAIMessage {
1529 content: String,
1530 }
1531
1532 let body: OpenAIResponse = response.json().map_err(|e| {
1533 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse OpenAI response: {}", e))
1534 })?;
1535
1536 body.choices
1537 .into_iter()
1538 .next()
1539 .map(|c| c.message.content)
1540 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No response from OpenAI")))
1541}
1542
1543fn extract_entities_claude(
1544 model: &str,
1545 text: &str,
1546 system_prompt: &str,
1547 api_key: Option<&str>,
1548) -> Result<String, ModelRunError> {
1549 use serde_json::json;
1550
1551 let api_key = api_key
1552 .map(|s| s.to_string())
1553 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
1554 .ok_or_else(|| {
1555 ModelRunError::Runtime(anyhow::anyhow!(
1556 "Anthropic API key required. Set ANTHROPIC_API_KEY or pass api_key parameter."
1557 ))
1558 })?;
1559
1560 let model_name = if model.is_empty() { "claude-3-5-sonnet-20241022" } else { model };
1561
1562 let client = reqwest::blocking::Client::new();
1563 let payload = json!({
1564 "model": model_name,
1565 "max_tokens": 4096,
1566 "system": format!("{}\n\nRespond with valid JSON only.", system_prompt),
1567 "messages": [
1568 {"role": "user", "content": text}
1569 ]
1570 });
1571
1572 let response = client
1573 .post("https://api.anthropic.com/v1/messages")
1574 .header("x-api-key", &api_key)
1575 .header("anthropic-version", "2023-06-01")
1576 .header("content-type", "application/json")
1577 .json(&payload)
1578 .send()
1579 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude request failed: {}", e)))?
1580 .error_for_status()
1581 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Claude returned error: {}", e)))?;
1582
1583 #[derive(serde::Deserialize)]
1584 struct ClaudeResponse {
1585 content: Vec<ClaudeContent>,
1586 }
1587 #[derive(serde::Deserialize)]
1588 struct ClaudeContent {
1589 text: Option<String>,
1590 }
1591
1592 let body: ClaudeResponse = response.json().map_err(|e| {
1593 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Claude response: {}", e))
1594 })?;
1595
1596 body.content
1597 .into_iter()
1598 .find_map(|c| c.text)
1599 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Claude")))
1600}
1601
1602fn extract_entities_gemini(
1603 model: &str,
1604 text: &str,
1605 system_prompt: &str,
1606 api_key: Option<&str>,
1607) -> Result<String, ModelRunError> {
1608 use serde_json::json;
1609
1610 let api_key = api_key
1611 .map(|s| s.to_string())
1612 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
1613 .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
1614 .ok_or_else(|| {
1615 ModelRunError::Runtime(anyhow::anyhow!(
1616 "Gemini API key required. Set GEMINI_API_KEY or pass api_key parameter."
1617 ))
1618 })?;
1619
1620 let model_name = if model.is_empty() { "gemini-2.0-flash" } else { model };
1621 let url = format!(
1622 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
1623 model_name, api_key
1624 );
1625
1626 let client = reqwest::blocking::Client::new();
1627 let payload = json!({
1628 "contents": [{
1629 "parts": [{"text": format!("{}\n\nText to analyze:\n{}", system_prompt, text)}]
1630 }],
1631 "generationConfig": {
1632 "temperature": 0.1,
1633 "responseMimeType": "application/json"
1634 }
1635 });
1636
1637 let response = client
1638 .post(&url)
1639 .json(&payload)
1640 .send()
1641 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini request failed: {}", e)))?
1642 .error_for_status()
1643 .map_err(|e| ModelRunError::Runtime(anyhow::anyhow!("Gemini returned error: {}", e)))?;
1644
1645 #[derive(serde::Deserialize)]
1646 struct GeminiResponse {
1647 candidates: Vec<GeminiCandidate>,
1648 }
1649 #[derive(serde::Deserialize)]
1650 struct GeminiCandidate {
1651 content: GeminiContent,
1652 }
1653 #[derive(serde::Deserialize)]
1654 struct GeminiContent {
1655 parts: Vec<GeminiPart>,
1656 }
1657 #[derive(serde::Deserialize)]
1658 struct GeminiPart {
1659 text: Option<String>,
1660 }
1661
1662 let body: GeminiResponse = response.json().map_err(|e| {
1663 ModelRunError::Runtime(anyhow::anyhow!("Failed to parse Gemini response: {}", e))
1664 })?;
1665
1666 body.candidates
1667 .into_iter()
1668 .next()
1669 .and_then(|c| c.content.parts.into_iter().find_map(|p| p.text))
1670 .ok_or_else(|| ModelRunError::Runtime(anyhow::anyhow!("No text response from Gemini")))
1671}
1672
#[cfg(test)]
mod tests {
    use super::*;

    /// Each provider's normalizer falls back to a fixed default model when given `None`.
    #[test]
    fn normalize_models() {
        let openai_default = normalize_openai_model(None);
        assert_eq!(openai_default, "gpt-4o-mini");

        let gemini_default = normalize_gemini_model(None);
        assert_eq!(gemini_default, "gemini-2.5-flash");

        let claude_default = normalize_claude_model(None);
        assert_eq!(claude_default, "claude-3-5-sonnet-20241022");
    }

    /// A plain `{"entities": [...]}` object parses into its entities.
    #[test]
    fn parse_entity_json() {
        let payload = r#"{"entities": [{"name": "John", "type": "PERSON", "confidence": 0.95}]}"#;
        let parsed = parse_entity_response(payload).expect("plain JSON object should parse");
        let [only] = parsed.as_slice() else {
            panic!("expected exactly one entity, got {}", parsed.len());
        };
        assert_eq!(only.name, "John");
    }

    /// A reply wrapped in a ```json fence is unwrapped before parsing.
    #[test]
    fn parse_entity_json_with_markdown() {
        let fenced = r#"```json
{"entities": [{"name": "Microsoft", "type": "ORG", "confidence": 0.99}]}
```"#;
        let parsed = parse_entity_response(fenced).expect("fenced JSON should parse");
        let [only] = parsed.as_slice() else {
            panic!("expected exactly one entity, got {}", parsed.len());
        };
        assert_eq!(only.name, "Microsoft");
    }

    /// Prefixed specs split on the colon; bare names default to the OpenAI provider.
    #[test]
    fn parse_model_spec_test() {
        let cases = [
            ("openai:gpt-4o", "openai", "gpt-4o"),
            ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ];
        for (spec, want_provider, want_model) in cases {
            let (provider, model) = parse_model_spec(spec);
            assert_eq!(
                (provider.as_str(), model.as_str()),
                (want_provider, want_model),
                "spec {spec}"
            );
        }
    }
}