use std::fmt;
use std::io::{IsTerminal, Write, stderr};
use std::path::PathBuf;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use std::thread;
use std::time::Duration;
use memvid_core::types::SearchHit;
/// Final answer produced by a model run.
#[derive(Debug, Clone)]
pub struct ModelAnswer {
    /// The model selector exactly as the caller supplied it.
    pub requested: String,
    /// Normalized label of the model that actually ran (e.g. `openai:gpt-4o-mini`).
    pub model: String,
    /// Answer text returned by the backend.
    pub answer: String,
}
/// Result of [`run_model_inference`]: the answer plus the context it was grounded in.
#[derive(Debug, Clone)]
pub struct ModelInference {
    /// The answer and model identification.
    pub answer: ModelAnswer,
    /// Assembled retrieval-context text as built for the prompt, before any
    /// prompt-level clamping.
    pub context_body: String,
    /// One entry per search hit that was included in `context_body`.
    pub context_fragments: Vec<ModelContextFragment>,
}
/// A single search hit as it was rendered into the model's retrieval context.
#[derive(Debug, Clone)]
pub struct ModelContextFragment {
    /// Rank of the originating search hit.
    pub rank: usize,
    /// URI of the originating document.
    pub uri: String,
    /// Document title, if the hit had one.
    pub title: Option<String>,
    /// Search score, if the hit had one.
    pub score: Option<f32>,
    /// Match count reported by the search hit.
    pub matches: usize,
    /// Frame identifier of the hit.
    pub frame_id: u64,
    /// Range copied from the hit. NOTE(review): unit (bytes vs chars) is defined by
    /// `memvid_core::types::SearchHit` — confirm there.
    pub range: (usize, usize),
    /// Chunk range copied from the hit, when the hit came from a chunk.
    pub chunk_range: Option<(usize, usize)>,
    /// The exact text inserted into the prompt for this hit.
    pub text: String,
    /// Whether the hit was rendered in full or as a short summary.
    pub kind: ModelContextFragmentKind,
}
/// How a fragment was rendered into the context.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ModelContextFragmentKind {
    /// Full metadata plus the whole snippet.
    Full,
    /// Rank, URI and a short single-line highlight (used when the full text did not fit).
    Summary,
}
impl ModelContextFragment {
    /// Convert an internal context record into its public fragment form.
    ///
    /// Every field is moved across unchanged; only the render mode is mapped
    /// onto the public [`ModelContextFragmentKind`].
    fn from_record(record: context::ContextRecord) -> Self {
        let context::ContextRecord {
            rank,
            uri,
            title,
            score,
            matches,
            frame_id,
            range,
            chunk_range,
            text,
            mode,
        } = record;
        let kind = match mode {
            context::ContextMode::Summary => ModelContextFragmentKind::Summary,
            context::ContextMode::Full => ModelContextFragmentKind::Full,
        };
        Self {
            rank,
            uri,
            title,
            score,
            matches,
            frame_id,
            range,
            chunk_range,
            text,
            kind,
        }
    }
}
/// Errors returned by [`run_model_inference`].
#[derive(Debug)]
pub enum ModelRunError {
    /// The model selector could not be parsed, or the backend is not compiled in.
    UnsupportedModel(String),
    /// Local model files required by the backend were not found.
    AssetsMissing {
        /// Label of the model whose assets are missing.
        model: String,
        /// Paths (or path patterns) that were expected but absent.
        missing: Vec<PathBuf>,
    },
    /// Any other failure while preparing or running the model (I/O, HTTP, decoding, ...).
    Runtime(anyhow::Error),
}
impl fmt::Display for ModelRunError {
    /// Render a one-line, user-facing description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedModel(name) => write!(f, "unsupported model '{name}'"),
            Self::AssetsMissing { model, missing } => {
                // Comma-separated list of every missing path.
                let listing = missing
                    .iter()
                    .map(|entry| entry.display().to_string())
                    .collect::<Vec<String>>()
                    .join(", ");
                write!(f, "model '{model}' missing required assets: {listing}")
            }
            Self::Runtime(inner) => write!(f, "model runtime error: {inner}"),
        }
    }
}
impl std::error::Error for ModelRunError {
    /// Only runtime failures carry an underlying cause; the innermost error of the
    /// wrapped `anyhow::Error` chain is exposed as the source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Runtime(inner) = self {
            Some(inner.root_cause())
        } else {
            None
        }
    }
}
/// Total prompt budget (in bytes of text) for the local TinyLlama model.
const LOCAL_CONTEXT_CHARS: usize = 32_768;
/// Upper bound on the length of the user question placed in a prompt.
const MAX_QUESTION_CHARS: usize = 512;
/// Output-token cap passed to the local TinyLlama backend.
const LOCAL_MAX_OUTPUT_TOKENS: usize = 256;
/// Output-token cap passed to the remote/HTTP backends.
const REMOTE_MAX_OUTPUT_TOKENS: usize = 768;
/// Default system prompt; chat backends accept an override, completion backends always use this.
const SYSTEM_PROMPT: &str = "You are a helpful assistant that answers questions based on the provided context.\n\nGuidelines:\n1. Read all the provided context carefully before answering.\n2. PAY CLOSE ATTENTION TO DATES: When context has timestamps, note that later dates reflect current/updated values.\n3. For temporal questions (e.g., \"How many X when started vs now?\"), find both earlier and later values. The most recent reflects the current state.\n4. For counting questions (e.g., \"How many times...\"), count occurrences in the context.\n5. Only say \"not enough information\" if the context truly contains no relevant information.\n6. Base all answers on the context - do not use external knowledge.";
/// Label reported for the local TinyLlama model.
const TINYLLAMA_LABEL: &str = "tinyllama-1.1b";
/// Part of the local budget reserved for non-context prompt text.
const LOCAL_PROMPT_MARGIN_CHARS: usize = 2_048;
/// Part of each remote budget reserved for non-context prompt text.
const REMOTE_PROMPT_MARGIN_CHARS: usize = 4_096;
/// Total prompt budgets per remote provider, measured in text length (not tokens).
const OLLAMA_PROMPT_CHARS: usize = 110_000;
const OPENAI_PROMPT_CHARS: usize = 240_000;
const GEMINI_PROMPT_CHARS: usize = 320_000;
const CLAUDE_PROMPT_CHARS: usize = 360_000;
/// Prompt-size budget for one model, measured in text length.
///
/// `total_chars` is the ceiling for the whole prompt; `reserved_chars` is held back
/// for everything that is not retrieval context (system prompt, question, markers).
#[derive(Debug, Clone, Copy)]
struct ModelContextBudget {
    total_chars: usize,
    reserved_chars: usize,
}

impl ModelContextBudget {
    /// Create a budget from a total ceiling and a reserved margin.
    const fn new(total_chars: usize, reserved_chars: usize) -> Self {
        Self {
            total_chars,
            reserved_chars,
        }
    }

    /// Room left for retrieval context; zero when the reservation exceeds the total.
    fn context_chars(&self) -> usize {
        if self.total_chars > self.reserved_chars {
            self.total_chars - self.reserved_chars
        } else {
            0
        }
    }

    /// Maximum question length: `MAX_QUESTION_CHARS`, further capped by the reserved
    /// margin and the total budget (each floored at one).
    fn question_limit(&self) -> usize {
        let reserved_floor = self.reserved_chars.max(1);
        let total_floor = self.total_chars.max(1);
        std::cmp::min(MAX_QUESTION_CHARS, std::cmp::min(reserved_floor, total_floor))
    }

    /// Replace the context allowance with `override_context_chars`, keeping the same
    /// reservation and always leaving at least one character of context.
    fn apply_override(self, override_context_chars: usize) -> Self {
        let reserved_chars = self.reserved_chars;
        let requested_total = override_context_chars.saturating_add(reserved_chars);
        let minimum_total = reserved_chars + 1;
        Self {
            total_chars: requested_total.max(minimum_total),
            reserved_chars,
        }
    }

    /// Hard ceiling for the complete prompt text.
    fn prompt_ceiling(&self) -> usize {
        let Self { total_chars, .. } = *self;
        total_chars
    }
}
/// Prompt text prepared for a backend, in both completion and chat form.
pub struct PromptParts {
    completion_prompt: String,
    user_message: String,
    max_output_tokens: usize,
}

impl PromptParts {
    /// Single-string prompt (system, context, question, answer stub) for
    /// completion-style backends.
    pub fn completion_prompt(&self) -> &str {
        self.completion_prompt.as_str()
    }

    /// User turn (context plus question) for chat-style backends, which send the
    /// system prompt separately.
    pub fn user_message(&self) -> &str {
        self.user_message.as_str()
    }

    /// Maximum number of tokens the backend should generate.
    pub fn max_output_tokens(&self) -> usize {
        let Self {
            max_output_tokens, ..
        } = self;
        *max_output_tokens
    }
}
/// Build both prompt renderings for `question` grounded in `context`.
///
/// * `completion_prompt` is one Markdown-sectioned string (`### System`, context,
///   `### Question`, `### Answer`), used by the completion-style backends.
/// * `user_message` is context plus question, used by chat-style backends that send
///   the system prompt separately.
///
/// The question is trimmed to `budget.question_limit()`. The context is clamped so
/// that the full completion prompt fits within `budget.prompt_ceiling()` and never
/// exceeds `budget.context_chars()`; if the fixed sections alone leave no room, the
/// context is dropped. All lengths are byte lengths of the UTF-8 text.
fn build_prompt_parts(
    question: &str,
    context: &str,
    budget: &ModelContextBudget,
    max_output_tokens: usize,
) -> PromptParts {
    const SECTION_SEPARATOR: &str = "\n\n";
    const ANSWER_STUB: &str = "### Answer\n";

    let trimmed_question = trim_to(question, budget.question_limit());
    let system_section = format!("### System\n{SYSTEM_PROMPT}");
    let question_section = format!("### Question\n{trimmed_question}");
    // The completion prompt is four sections joined by THREE separators
    // (system | context | question | answer); everything except the context is fixed.
    let overhead = system_section.len()
        + question_section.len()
        + ANSWER_STUB.len()
        + 3 * SECTION_SEPARATOR.len();
    let context_section = match budget.prompt_ceiling().checked_sub(overhead) {
        Some(room) if room > 0 => clamp_to(context, room.min(budget.context_chars())),
        _ => String::new(),
    };
    let completion_prompt = format!(
        "{system_section}{SECTION_SEPARATOR}{context_section}{SECTION_SEPARATOR}{question_section}{SECTION_SEPARATOR}{ANSWER_STUB}"
    );
    let user_message = format!(
        "{context_section}\n\nQuestion:\n{trimmed_question}\n\nRespond concisely using only information from the retrieval context."
    );
    PromptParts {
        completion_prompt,
        user_message,
        max_output_tokens,
    }
}
/// Truncate `text` to at most `limit` bytes and append `"..."` when anything was cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so multi-byte
/// input never panics. The marker is appended *after* the limit, so a truncated result
/// may be up to `limit + 3` bytes long.
fn trim_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    // `text.len() > limit`, so every index in 0..=limit is in range; index 0 is always
    // a char boundary, so the search always succeeds.
    let end = (0..=limit)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    let mut truncated = String::with_capacity(end + 3);
    truncated.push_str(&text[..end]);
    truncated.push_str("...");
    truncated
}
/// Shorten `text` so the result, including a trailing `"..."`, is at most `limit` bytes.
///
/// Text that already fits is returned unchanged. For `limit <= 3` only (part of) the
/// ellipsis is returned. The cut point is moved back to a UTF-8 character boundary, so
/// multi-byte text never panics (the result may then be a few bytes shorter than `limit`).
fn clamp_to(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    if limit <= 3 {
        return "..."[..limit].to_string();
    }
    let keep = limit - 3;
    // `keep < text.len()` here; index 0 is always a boundary, so this always succeeds.
    let end = (0..=keep)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    let mut truncated = String::with_capacity(end + 3);
    truncated.push_str(&text[..end]);
    truncated.push_str("...");
    truncated
}
/// Animated "Thinking..." indicator written to stderr while a model call is running.
///
/// The animation is only drawn when stderr is an interactive terminal, so redirected
/// output (e.g. `2>&1`) never receives carriage returns; in that case no thread is
/// spawned at all. Call [`ThinkingSpinner::stop`] or drop the value to stop it;
/// stopping is idempotent and returns promptly.
struct ThinkingSpinner {
    /// Cleared to ask the animation thread to exit.
    flag: Arc<AtomicBool>,
    /// Animation thread; `None` when stderr is not a terminal or after `stop`.
    handle: Option<thread::JoinHandle<()>>,
}

impl ThinkingSpinner {
    /// Frames cycled by the animation thread.
    const FRAMES: [&'static str; 6] = [
        "Thinking ",
        "Thinking. ",
        "Thinking.. ",
        "Thinking... ",
        "Thinking .. ",
        "Thinking . ",
    ];
    /// Delay between frames.
    const FRAME_INTERVAL: Duration = Duration::from_millis(200);

    /// Start the spinner (a no-op when stderr is not a terminal).
    fn start() -> Self {
        let flag = Arc::new(AtomicBool::new(true));
        // Nothing would be drawn without a terminal, so don't spend a thread on it.
        if !stderr().is_terminal() {
            return Self { flag, handle: None };
        }
        let thread_flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            let mut err = stderr();
            for frame in Self::FRAMES.iter().cycle() {
                if !thread_flag.load(Ordering::Acquire) {
                    break;
                }
                let _ = write!(err, "\r{frame}");
                let _ = err.flush();
                // Wait for the next frame; `stop` unparks this thread so shutdown
                // does not have to wait out the full interval.
                thread::park_timeout(Self::FRAME_INTERVAL);
            }
            // Blank out the widest frame so no fragment of the animation is left behind.
            let width = Self::FRAMES
                .iter()
                .map(|frame| frame.len())
                .max()
                .unwrap_or(0);
            let _ = write!(err, "\r{}\r", " ".repeat(width));
            let _ = err.flush();
        });
        Self {
            flag,
            handle: Some(handle),
        }
    }

    /// Stop the animation, clear its line and wait for the thread to finish.
    fn stop(&mut self) {
        if let Some(handle) = self.handle.take() {
            self.flag.store(false, Ordering::Release);
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}

impl Drop for ThinkingSpinner {
    fn drop(&mut self) {
        self.stop();
    }
}
/// Backend selected for a request, with the provider-specific model name where relevant.
#[derive(Debug, Clone)]
enum ModelKind {
    /// Local TinyLlama GGUF model (requires the `llama-cpp` feature).
    TinyLlama,
    /// Model served by a local Ollama instance.
    Ollama { model: String },
    /// OpenAI API model.
    OpenAi { model: String },
    /// Google Gemini API model.
    Gemini { model: String },
    /// Anthropic Claude API model.
    Claude { model: String },
}
impl ModelKind {
    /// Parse a model selector such as `openai`, `claude:claude-3-opus`, `gpt-4o`
    /// or `qwen2.5:1.5b`.
    ///
    /// A recognised provider keyword before the first `:` selects the backend and
    /// anything after the colon is the explicit model name. Otherwise the provider
    /// is inferred from the model-name prefix, keeping the full original text
    /// (including any `:tag`). Returns `None` for blank or unrecognised input.
    fn parse(raw: &str) -> Option<Self> {
        let spec = raw.trim();
        if spec.is_empty() {
            return None;
        }
        let (provider, explicit_model) = match spec.split_once(':') {
            Some((head, tail)) => {
                let tail = tail.trim();
                let explicit = (!tail.is_empty()).then(|| tail.to_string());
                (head.trim().to_ascii_lowercase(), explicit)
            }
            None => (spec.to_ascii_lowercase(), None),
        };
        let kind = match provider.as_str() {
            "tinyllama" | "tiny-llama" | "tinyllama-1.1b" => Self::TinyLlama,
            "ollama" => Self::Ollama {
                model: explicit_model.unwrap_or_else(|| "ollama1.5".to_string()),
            },
            "ollama1.5" | "ollama1-5" => Self::Ollama {
                model: "ollama1.5".to_string(),
            },
            "openai" => Self::OpenAi {
                model: normalize_openai_model(explicit_model),
            },
            "gemini" => Self::Gemini {
                model: normalize_gemini_model(explicit_model),
            },
            "claude" | "anthropic" => Self::Claude {
                model: normalize_claude_model(explicit_model),
            },
            // Not a provider keyword: infer from the name itself. Ollama tags such as
            // "qwen2.5:1.5b" contain a colon, so pass the untouched full spec along.
            _ => return Self::infer_from_model_name_full(spec, &provider),
        };
        Some(kind)
    }

    /// Infer the provider from a bare model name.
    ///
    /// `prefix` (the text before any colon) is matched case-insensitively against
    /// known model families; the returned variant stores `full_name` verbatim so
    /// Ollama version tags like `:1.5b` survive.
    fn infer_from_model_name_full(full_name: &str, prefix: &str) -> Option<Self> {
        const GEMINI_PREFIXES: &[&str] = &["gemini", "models/gemini"];
        const OPENAI_PREFIXES: &[&str] = &["gpt-", "o1-", "o3-", "chatgpt-", "text-"];
        const CLAUDE_PREFIXES: &[&str] = &["claude-"];
        const OLLAMA_PREFIXES: &[&str] = &[
            "llama",
            "mistral",
            "phi",
            "codellama",
            "deepseek",
            "qwen",
            "gemma",
        ];

        fn starts_with_any(name: &str, prefixes: &[&str]) -> bool {
            prefixes.iter().any(|candidate| name.starts_with(*candidate))
        }

        let lowered = prefix.to_ascii_lowercase();
        let model = full_name.to_string();
        if starts_with_any(&lowered, GEMINI_PREFIXES) {
            Some(Self::Gemini { model })
        } else if starts_with_any(&lowered, OPENAI_PREFIXES) {
            Some(Self::OpenAi { model })
        } else if starts_with_any(&lowered, CLAUDE_PREFIXES) {
            Some(Self::Claude { model })
        } else if starts_with_any(&lowered, OLLAMA_PREFIXES) {
            Some(Self::Ollama { model })
        } else {
            None
        }
    }

    /// Human-readable `provider:model` label (TinyLlama uses its fixed label).
    fn label(&self) -> String {
        let (provider, model) = match self {
            Self::TinyLlama => return TINYLLAMA_LABEL.to_string(),
            Self::Ollama { model } => ("ollama", model),
            Self::OpenAi { model } => ("openai", model),
            Self::Gemini { model } => ("gemini", model),
            Self::Claude { model } => ("claude", model),
        };
        format!("{provider}:{model}")
    }

    /// Default prompt budget for this backend.
    fn context_budget(&self) -> ModelContextBudget {
        let (total_chars, reserved_chars) = match self {
            Self::TinyLlama => (LOCAL_CONTEXT_CHARS, LOCAL_PROMPT_MARGIN_CHARS),
            Self::Ollama { .. } => (OLLAMA_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS),
            Self::OpenAi { .. } => (OPENAI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS),
            Self::Gemini { .. } => (GEMINI_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS),
            Self::Claude { .. } => (CLAUDE_PROMPT_CHARS, REMOTE_PROMPT_MARGIN_CHARS),
        };
        ModelContextBudget::new(total_chars, reserved_chars)
    }

    /// Output-token cap: smaller for the local model, larger for remote ones.
    fn max_output_tokens(&self) -> usize {
        match self {
            Self::Ollama { .. }
            | Self::OpenAi { .. }
            | Self::Gemini { .. }
            | Self::Claude { .. } => REMOTE_MAX_OUTPUT_TOKENS,
            Self::TinyLlama => LOCAL_MAX_OUTPUT_TOKENS,
        }
    }
}
/// Resolve the OpenAI model name: an explicit, non-blank name is kept verbatim,
/// otherwise `gpt-4o-mini` is used.
fn normalize_openai_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "gpt-4o-mini".to_string())
}
/// Resolve the Gemini model name.
///
/// An explicit, non-blank name is passed through verbatim (case preserved);
/// `None` or a blank name falls back to `gemini-2.5-flash`, matching how the
/// OpenAI selector treats blank input.
fn normalize_gemini_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "gemini-2.5-flash".to_string())
}
/// Resolve the Claude model name.
///
/// An explicit, non-blank name is passed through verbatim; `None` or a blank name
/// falls back to `claude-3-5-sonnet-20241022`, matching the OpenAI selector.
fn normalize_claude_model(explicit: Option<String>) -> String {
    explicit
        .filter(|name| !name.trim().is_empty())
        .unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string())
}
pub fn run_model_inference(
requested_model: &str,
question: &str,
fallback_context: &str,
hits: &[SearchHit],
context_override: Option<usize>,
api_key: Option<&str>,
system_prompt_override: Option<&str>,
) -> Result<ModelInference, ModelRunError> {
let Some(model_kind) = ModelKind::parse(requested_model) else {
return Err(ModelRunError::UnsupportedModel(requested_model.to_string()));
};
let mut budget = model_kind.context_budget();
if let Some(override_chars) = context_override {
budget = budget.apply_override(override_chars);
}
let context_plan = context::assemble_context(hits, fallback_context, &budget);
let prompt = build_prompt_parts(
question,
&context_plan.body,
&budget,
model_kind.max_output_tokens(),
);
let answer = match &model_kind {
ModelKind::TinyLlama => {
#[cfg(feature = "llama-cpp")]
{
tinyllama::run(&prompt)?
}
#[cfg(not(feature = "llama-cpp"))]
{
return Err(ModelRunError::UnsupportedModel(
"tinyllama (llama-cpp feature not enabled)".to_string(),
));
}
}
ModelKind::Ollama { model } => ollama::run(model, &prompt)?,
ModelKind::OpenAi { model } => {
openai::run(model, &prompt, api_key, system_prompt_override)?
}
ModelKind::Gemini { model } => {
gemini::run(model, &prompt, api_key, system_prompt_override)?
}
ModelKind::Claude { model } => {
claude::run(model, &prompt, api_key, system_prompt_override)?
}
};
let context::ContextAggregation {
body: context_body,
records,
} = context_plan;
let context_fragments = records
.into_iter()
.map(ModelContextFragment::from_record)
.collect();
Ok(ModelInference {
answer: ModelAnswer {
requested: requested_model.to_string(),
model: model_kind.label(),
answer,
},
context_body,
context_fragments,
})
}
mod context {
    //! Assembles the retrieval-context section of a prompt from search hits while
    //! staying within a character budget.
    use super::{ModelContextBudget, clamp_to};
    use memvid_core::types::SearchHit;

    const CONTEXT_HEADER: &str = "## Retrieval Context\n";
    const PRIMARY_HEADER: &str = "### Primary Hit\n";
    const SUPPORT_HEADER: &str = "### Supporting Hits\n";
    const SUMMARY_HEADER: &str = "### Overflow Summaries\n";
    /// Maximum length of the single-line highlight in an overflow summary.
    const SUMMARY_HIGHLIGHT_CHARS: usize = 240;
    /// Maximum number of search hits considered when building the context.
    const MAX_CONTEXT_HITS: usize = 32;
    /// Separator written after every record in the assembled body.
    const RECORD_SEPARATOR: &str = "\n\n";

    /// Assembled context text plus the records it was built from.
    #[derive(Debug, Clone)]
    pub(super) struct ContextAggregation {
        pub body: String,
        pub records: Vec<ContextRecord>,
    }

    impl ContextAggregation {
        /// Context built from raw fallback text (no hits): clamped to `limit`,
        /// empty when `limit` is zero, and with no records.
        fn from_fallback(fallback: &str, limit: usize) -> Self {
            let body = if limit == 0 {
                String::new()
            } else {
                clamp_to(fallback, limit)
            };
            Self {
                body,
                records: Vec::new(),
            }
        }
    }

    /// One search hit rendered for the prompt, with the hit's metadata.
    #[derive(Debug, Clone)]
    pub(super) struct ContextRecord {
        pub rank: usize,
        pub uri: String,
        pub title: Option<String>,
        pub score: Option<f32>,
        pub matches: usize,
        pub frame_id: u64,
        pub range: (usize, usize),
        pub chunk_range: Option<(usize, usize)>,
        pub text: String,
        pub mode: ContextMode,
    }

    /// How a record was rendered.
    #[derive(Debug, Clone, Copy, Eq, PartialEq)]
    pub(super) enum ContextMode {
        Full,
        Summary,
    }

    /// Which records go into which section of the body.
    #[derive(Debug, Clone)]
    pub(super) struct ContextAssemblyPlan {
        primary: Option<ContextRecord>,
        supporting: Vec<ContextRecord>,
        summaries: Vec<ContextRecord>,
    }

    /// Build the retrieval context for `hits`, or fall back to `fallback` text when
    /// there are no hits. The body (headers, records and separators) is planned to
    /// fit within `budget.context_chars()`.
    pub(super) fn assemble_context(
        hits: &[SearchHit],
        fallback: &str,
        budget: &ModelContextBudget,
    ) -> ContextAggregation {
        if hits.is_empty() {
            return ContextAggregation::from_fallback(fallback, budget.context_chars());
        }
        let plan = assemble_plan(hits, budget.context_chars());
        let mut body = String::from(CONTEXT_HEADER);
        let mut records = Vec::new();
        append_section(&mut body, &mut records, PRIMARY_HEADER, plan.primary);
        append_section(&mut body, &mut records, SUPPORT_HEADER, plan.supporting);
        append_section(&mut body, &mut records, SUMMARY_HEADER, plan.summaries);
        ContextAggregation { body, records }
    }

    /// Append `header` and each record's text to `body` (nothing at all when the
    /// section is empty), moving the records into `records` in order.
    fn append_section(
        body: &mut String,
        records: &mut Vec<ContextRecord>,
        header: &str,
        section: impl IntoIterator<Item = ContextRecord>,
    ) {
        let mut section = section.into_iter().peekable();
        if section.peek().is_none() {
            return;
        }
        body.push_str(header);
        for record in section {
            body.push_str(&record.text);
            body.push_str(RECORD_SEPARATOR);
            records.push(record);
        }
    }

    /// Decide which hits are rendered in full and which only as summaries.
    ///
    /// Hits are visited in order (at most `MAX_CONTEXT_HITS`). Each hit is included in
    /// full when it fits, otherwise as a summary when that fits, otherwise skipped. The
    /// first hit becomes the primary hit when its full text fits; if it does not, its
    /// summary is still considered. Section headers are reserved up front and each
    /// record is charged its separator, so the assembled body stays within
    /// `available_chars`.
    fn assemble_plan(hits: &[SearchHit], available_chars: usize) -> ContextAssemblyPlan {
        let header_overhead = CONTEXT_HEADER.len()
            + PRIMARY_HEADER.len()
            + SUPPORT_HEADER.len()
            + SUMMARY_HEADER.len();
        let mut remaining_chars = available_chars.saturating_sub(header_overhead);
        let mut plan = ContextAssemblyPlan {
            primary: None,
            supporting: Vec::new(),
            summaries: Vec::new(),
        };
        for (idx, hit) in hits.iter().take(MAX_CONTEXT_HITS).enumerate() {
            let full = build_record(hit, render_full(hit), ContextMode::Full);
            if let Some(left) = remaining_chars.checked_sub(record_cost(&full)) {
                remaining_chars = left;
                if idx == 0 {
                    plan.primary = Some(full);
                } else {
                    plan.supporting.push(full);
                }
                continue;
            }
            let summary = build_record(hit, render_summary(hit), ContextMode::Summary);
            if let Some(left) = remaining_chars.checked_sub(record_cost(&summary)) {
                remaining_chars = left;
                plan.summaries.push(summary);
            }
        }
        plan
    }

    /// Space a record occupies in the body: its text plus the trailing separator.
    fn record_cost(record: &ContextRecord) -> usize {
        record.text.len() + RECORD_SEPARATOR.len()
    }

    /// Snippet text for a hit: the matched chunk when present, else the hit's text.
    fn snippet_text(hit: &SearchHit) -> &str {
        hit.chunk_text.as_deref().unwrap_or(hit.text.as_str())
    }

    /// Full rendering: metadata block followed by the whole snippet.
    fn render_full(hit: &SearchHit) -> String {
        format!(
            "Rank: {}\nURI: {}\nTitle: {}\nMatches: {}\nScore: {:.3}\nSnippet:\n{}",
            hit.rank,
            hit.uri,
            hit.title.as_deref().unwrap_or("(untitled)"),
            hit.matches,
            hit.score.unwrap_or_default(),
            snippet_text(hit),
        )
    }

    /// Compact rendering: rank, URI and a short single-line highlight.
    fn render_summary(hit: &SearchHit) -> String {
        let highlight = trim_highlight(snippet_text(hit), SUMMARY_HIGHLIGHT_CHARS);
        format!(
            "Rank: {}\nURI: {}\nHighlight: {}",
            hit.rank, hit.uri, highlight
        )
    }

    /// Collapse newlines to spaces and clamp to `limit`.
    fn trim_highlight(text: &str, limit: usize) -> String {
        let clean = text.replace('\n', " ");
        clamp_to(&clean, limit)
    }

    /// Copy a hit's metadata into a record carrying the rendered `text`.
    fn build_record(hit: &SearchHit, text: String, mode: ContextMode) -> ContextRecord {
        ContextRecord {
            rank: hit.rank,
            uri: hit.uri.clone(),
            title: hit.title.clone(),
            score: hit.score,
            matches: hit.matches,
            frame_id: hit.frame_id,
            range: hit.range,
            chunk_range: hit.chunk_range,
            text,
            mode,
        }
    }
}
#[cfg(feature = "llama-cpp")]
mod tinyllama {
    //! Local TinyLlama backend driven through the `llama_cpp` crate.
    use super::{ModelRunError, PromptParts, TINYLLAMA_LABEL, ThinkingSpinner};
    use anyhow::anyhow;
    use llama_cpp::standard_sampler::StandardSampler;
    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
    use tokio::runtime::Builder;
    use std::path::{Path, PathBuf};
    /// Directory searched for GGUF weights (a relative path, so resolved against the
    /// process's current working directory).
    const MODEL_DIR: &str = "models/tinyllama";
    /// Pattern reported as "missing" when no GGUF file is found.
    const GGUF_HINT: &str = "*.gguf";
    /// Run `prompt.completion_prompt()` through the first (lexicographically sorted)
    /// `.gguf` file in `MODEL_DIR` and return the trimmed completion.
    ///
    /// # Errors
    /// `ModelRunError::AssetsMissing` when no GGUF file exists; `ModelRunError::Runtime`
    /// for load, tokenize, session, priming or runtime-construction failures.
    pub(super) fn run(prompt: &PromptParts) -> Result<String, ModelRunError> {
        let base_dir = Path::new(MODEL_DIR);
        let assets = RequiredAssets::new(base_dir);
        if let Some(missing) = assets.missing_paths() {
            return Err(ModelRunError::AssetsMissing {
                model: TINYLLAMA_LABEL.to_string(),
                missing,
            });
        }
        let gguf_path = assets.gguf_path.clone().ok_or_else(|| {
            ModelRunError::Runtime(anyhow!(
                "no GGUF model file found in {}",
                base_dir.display()
            ))
        })?;
        // Quiet the native library's logging.
        // SAFETY: `set_var` is unsafe because another thread reading the environment
        // concurrently is undefined behaviour. This function has not spawned any thread
        // yet, but NOTE(review): confirm no other thread in the process reads the
        // environment while this runs.
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
        }
        let model =
            LlamaModel::load_from_file(&gguf_path, LlamaParams::default()).map_err(|err| {
                ModelRunError::Runtime(anyhow!(
                    "failed to load TinyLlama weights from {}: {err}",
                    gguf_path.display()
                ))
            })?;
        let mut session_params = SessionParams::default();
        // Fall back to a 2048-token context window when the default reports 0.
        if session_params.n_ctx == 0 {
            session_params.n_ctx = 2048;
        }
        session_params.n_batch = session_params.n_ctx.min(512);
        if session_params.n_ubatch == 0 {
            session_params.n_ubatch = 512;
        }
        let max_tokens = session_params.n_ctx as usize;
        let mut session = model.create_session(session_params).map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to create TinyLlama session: {err}"))
        })?;
        // NOTE(review): the two `true` flags presumably request a BOS token and special
        // token parsing — confirm against the llama_cpp crate docs.
        let mut priming_tokens = model
            .tokenize_bytes(prompt.completion_prompt().as_bytes(), true, true)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to tokenize TinyLlama prompt: {err}"))
            })?;
        let requested_tokens = prompt.max_output_tokens();
        // If the prompt would leave no room for the answer (+64 tokens of slack), keep
        // only its tail: the question and "### Answer" stub survive while the leading
        // system/context text (and any leading special token) is dropped.
        if max_tokens > 0 {
            let reserved = requested_tokens + 64;
            if priming_tokens.len() >= max_tokens.saturating_sub(reserved) {
                let target = max_tokens.saturating_sub(reserved).max(1);
                let tail_start = priming_tokens.len().saturating_sub(target);
                priming_tokens = priming_tokens.split_off(tail_start);
            }
        }
        session
            .advance_context_with_tokens(&priming_tokens)
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to prime TinyLlama context: {err}"))
            })?;
        let handle = session
            .start_completing_with(StandardSampler::default(), requested_tokens)
            .map_err(|err| ModelRunError::Runtime(anyhow!("completion failed to start: {err}")))?;
        // A single-threaded runtime is enough to drive the async completion to its end.
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| {
                ModelRunError::Runtime(anyhow!("failed to build tokio runtime: {err}"))
            })?;
        let mut spinner = ThinkingSpinner::start();
        let generated = runtime.block_on(async { handle.into_string_async().await });
        spinner.stop();
        let answer = generated.trim().to_string();
        if answer.is_empty() {
            Ok("No answer generated by TinyLlama.".to_string())
        } else {
            Ok(answer)
        }
    }
    /// Locates the model files needed by the backend.
    struct RequiredAssets {
        /// First `.gguf` file found in `base_dir`, if any.
        gguf_path: Option<PathBuf>,
        base_dir: PathBuf,
    }
    impl RequiredAssets {
        fn new(base_dir: &Path) -> Self {
            let gguf_path = find_first_gguf(base_dir);
            Self {
                gguf_path,
                base_dir: base_dir.to_path_buf(),
            }
        }
        /// `None` when everything is present; otherwise the expected-but-missing paths.
        fn missing_paths(&self) -> Option<Vec<PathBuf>> {
            if self.gguf_path.is_some() {
                None
            } else {
                Some(vec![self.base_dir.join(GGUF_HINT)])
            }
        }
    }
    /// Return the lexicographically first regular file with a `.gguf` extension in
    /// `base_dir`; `None` if the directory is unreadable or has no such file.
    fn find_first_gguf(base_dir: &Path) -> Option<PathBuf> {
        let mut entries: Vec<PathBuf> = std::fs::read_dir(base_dir)
            .ok()?
            .filter_map(|entry| entry.ok().map(|e| e.path()))
            .filter(|path| path.is_file() && path.extension().map_or(false, |ext| ext == "gguf"))
            .collect();
        entries.sort();
        entries.into_iter().next()
    }
}
mod ollama {
    //! Backend for a local Ollama server using the non-streaming `/api/generate` call.
    use super::{ModelRunError, PromptParts, ThinkingSpinner};
    use anyhow::anyhow;
    use reqwest::blocking::Client;
    use serde::Deserialize;
    use serde_json::json;

    const ENDPOINT: &str = "http://127.0.0.1:11434/api/generate";

    /// Send `prompt.completion_prompt()` to the local Ollama `model` and return the
    /// trimmed answer.
    ///
    /// The output is capped with Ollama's `num_predict` option at
    /// `prompt.max_output_tokens()`, like every other backend.
    ///
    /// # Errors
    /// `ModelRunError::Runtime` when the request fails, Ollama returns a non-success
    /// status (its response body, which carries the reason, is included), or the
    /// response cannot be decoded.
    pub(super) fn run(model: &str, prompt: &PromptParts) -> Result<String, ModelRunError> {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
        let payload = json!({
            "model": model,
            "prompt": prompt.completion_prompt(),
            "stream": false,
            "options": { "num_predict": prompt.max_output_tokens() }
        });
        let mut spinner = ThinkingSpinner::start();
        let response = client
            .post(ENDPOINT)
            .json(&payload)
            .send()
            .map_err(|err| ModelRunError::Runtime(anyhow!("ollama request failed: {err}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = response
                .text()
                .unwrap_or_else(|_| "<failed to read body>".to_string());
            return Err(ModelRunError::Runtime(anyhow!(
                "ollama returned error status {status}: {body}"
            )));
        }
        let body: GenerateResponse = response.json().map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to decode ollama response: {err}"))
        })?;
        spinner.stop();
        let text = body.response.trim();
        if text.is_empty() {
            Ok("No answer returned by Ollama.".to_string())
        } else {
            Ok(text.to_string())
        }
    }

    /// Subset of Ollama's non-streaming generate response that we use.
    #[derive(Debug, Deserialize)]
    struct GenerateResponse {
        #[serde(default)]
        response: String,
    }
}
mod openai {
use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
use anyhow::anyhow;
use reqwest::blocking::Client;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
use serde::Deserialize;
use serde_json::json;
const CHAT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
const RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses";
pub(super) fn run(
model: &str,
prompt: &PromptParts,
override_key: Option<&str>,
system_prompt_override: Option<&str>,
) -> Result<String, ModelRunError> {
let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
let key = override_key
.map(|value| value.to_string())
.or_else(|| std::env::var("OPENAI_API_KEY").ok())
.ok_or_else(|| {
ModelRunError::Runtime(anyhow!(
"OPENAI_API_KEY environment variable is required for OpenAI models"
))
})?;
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {key}")).map_err(|err| {
ModelRunError::Runtime(anyhow!("invalid OPENAI_API_KEY header value: {err}"))
})?,
);
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let client = Client::builder()
.timeout(std::time::Duration::from_secs(60))
.default_headers(headers)
.build()
.map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
let mut spinner = ThinkingSpinner::start();
let text = if requires_responses_api(model) {
let combined_prompt = format!(
"System instructions:\n{}\n\nUser query:\n{}",
system_prompt,
prompt.user_message()
);
let payload = json!({
"model": model,
"input": combined_prompt,
"max_output_tokens": prompt.max_output_tokens() as u32,
"reasoning": {
"effort": "low"
}
});
let response = client
.post(RESPONSES_ENDPOINT)
.json(&payload)
.send()
.map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
let status = response.status();
if !status.is_success() {
let body = response
.text()
.unwrap_or_else(|_| "<failed to read body>".to_string());
return Err(ModelRunError::Runtime(anyhow!(
"OpenAI returned error status {status}: {body}"
)));
}
let body: ResponsesResponse = response.json().map_err(|err| {
ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
})?;
extract_responses_text(body)
} else {
let payload = json!({
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt.user_message()}
],
"temperature": 0.2,
"max_tokens": prompt.max_output_tokens() as u32
});
let response = client
.post(CHAT_ENDPOINT)
.json(&payload)
.send()
.map_err(|err| ModelRunError::Runtime(anyhow!("OpenAI request failed: {err}")))?;
let status = response.status();
if !status.is_success() {
let body = response
.text()
.unwrap_or_else(|_| "<failed to read body>".to_string());
return Err(ModelRunError::Runtime(anyhow!(
"OpenAI returned error status {status}: {body}"
)));
}
let body: ChatResponse = response.json().map_err(|err| {
ModelRunError::Runtime(anyhow!("failed to decode OpenAI response: {err}"))
})?;
extract_chat_text(body)
};
spinner.stop();
Ok(text)
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ChatMessage,
}
#[derive(Debug, Deserialize)]
struct ChatMessage {
#[serde(default)]
content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ResponsesResponse {
#[serde(default)]
output: Vec<ResponseItem>,
#[serde(default)]
output_text: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct ResponseItem {
#[serde(default)]
content: Vec<ResponseContent>,
}
#[derive(Debug, Deserialize)]
struct ResponseContent {
#[serde(rename = "type")]
kind: String,
#[serde(default)]
text: Option<String>,
}
fn extract_chat_text(body: ChatResponse) -> String {
body.choices
.into_iter()
.find_map(|choice| choice.message.content)
.map(|value| value.trim().to_string())
.unwrap_or_else(|| "No answer returned by OpenAI.".to_string())
}
fn extract_responses_text(body: ResponsesResponse) -> String {
if !body.output_text.is_empty() {
let text = body
.output_text
.into_iter()
.find(|value| !value.trim().is_empty());
if let Some(text) = text {
return text.trim().to_string();
}
}
for item in body.output {
for segment in item.content {
match segment.kind.as_str() {
"output_text" | "text" => {
if let Some(text) = segment.text {
let trimmed = text.trim();
if !trimmed.is_empty() {
return trimmed.to_string();
}
}
}
_ => {}
}
}
}
"No answer returned by OpenAI.".to_string()
}
fn requires_responses_api(model: &str) -> bool {
let lowered = model.to_ascii_lowercase();
lowered.starts_with("gpt-5") || lowered.contains("gpt-4.1")
}
}
mod gemini {
    //! Google Gemini backend using the v1beta `generateContent` REST call.
    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
    use anyhow::anyhow;
    use reqwest::blocking::Client;
    use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
    use serde::Deserialize;
    use serde_json::json;

    const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
    /// Resource-name prefix some callers include in the model id (`models/gemini-...`).
    const RESOURCE_PREFIX: &str = "models/";

    /// Send `prompt` to Gemini `model` and return the first non-blank text part.
    ///
    /// The key comes from `override_key`, else `GEMINI_API_KEY`;
    /// `system_prompt_override` replaces `SYSTEM_PROMPT`.
    ///
    /// # Errors
    /// `ModelRunError::Runtime` when no key is available, the request fails, the API
    /// returns a non-success status (the body is included), or decoding fails.
    pub(super) fn run(
        model: &str,
        prompt: &PromptParts,
        override_key: Option<&str>,
        system_prompt_override: Option<&str>,
    ) -> Result<String, ModelRunError> {
        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
        let key = override_key
            .map(|value| value.to_string())
            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
            .ok_or_else(|| {
                ModelRunError::Runtime(anyhow!(
                    "GEMINI_API_KEY environment variable is required for Gemini models"
                ))
            })?;
        // The URL already contains `/models/`, so a `models/...` id must not repeat it.
        let url = format!("{API_BASE}/{}:generateContent", bare_model_id(model));
        let mut api_key = HeaderValue::from_str(&key).map_err(|err| {
            ModelRunError::Runtime(anyhow!("invalid GEMINI_API_KEY header value: {err}"))
        })?;
        api_key.set_sensitive(true);
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers.insert(HeaderName::from_static("x-goog-api-key"), api_key);
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .default_headers(headers)
            .build()
            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
        let payload = json!({
            "contents": [{
                "parts": [
                    { "text": system_prompt },
                    { "text": prompt.user_message() }
                ]
            }],
            "generationConfig": {
                "temperature": 0.2,
                "maxOutputTokens": u32::try_from(prompt.max_output_tokens()).unwrap_or(u32::MAX),
                "topK": 40,
                "topP": 0.95
            }
        });
        let mut spinner = ThinkingSpinner::start();
        let response = client
            .post(url)
            .json(&payload)
            .send()
            .map_err(|err| ModelRunError::Runtime(anyhow!("Gemini request failed: {err}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = response
                .text()
                .unwrap_or_else(|_| "<failed to read body>".to_string());
            return Err(ModelRunError::Runtime(anyhow!(
                "Gemini returned error status {status}: {body}"
            )));
        }
        let body: GenerateResponse = response.json().map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to decode Gemini response: {err}"))
        })?;
        spinner.stop();
        let text = body
            .candidates
            .into_iter()
            .flat_map(|candidate| candidate.content.parts)
            .filter_map(|part| part.text)
            .map(|value| value.trim().to_string())
            .find(|value| !value.is_empty())
            .unwrap_or_else(|| "No answer returned by Gemini.".to_string());
        Ok(text)
    }

    /// Strip a leading `models/` (ASCII case-insensitive) from a model id.
    fn bare_model_id(model: &str) -> &str {
        match model.get(..RESOURCE_PREFIX.len()) {
            Some(head) if head.eq_ignore_ascii_case(RESOURCE_PREFIX) => {
                &model[RESOURCE_PREFIX.len()..]
            }
            _ => model,
        }
    }

    // Every level is defaulted: blocked or truncated responses may omit `candidates`,
    // `content` or `parts`, which should yield "No answer" rather than a decode error.
    #[derive(Debug, Deserialize)]
    struct GenerateResponse {
        #[serde(default)]
        candidates: Vec<Candidate>,
    }
    #[derive(Debug, Deserialize)]
    struct Candidate {
        #[serde(default)]
        content: CandidateContent,
    }
    #[derive(Debug, Default, Deserialize)]
    struct CandidateContent {
        #[serde(default)]
        parts: Vec<CandidatePart>,
    }
    #[derive(Debug, Deserialize)]
    struct CandidatePart {
        #[serde(default)]
        text: Option<String>,
    }
}
mod claude {
    //! Anthropic Claude backend using the Messages API.
    use super::{ModelRunError, PromptParts, SYSTEM_PROMPT, ThinkingSpinner};
    use anyhow::anyhow;
    use reqwest::blocking::Client;
    use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
    use serde::Deserialize;
    use serde_json::json;

    const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
    const API_VERSION: &str = "2023-06-01";

    /// Send `prompt` to Claude `model` and return the first non-blank text block.
    ///
    /// The key comes from `override_key`, else `ANTHROPIC_API_KEY`, else
    /// `CLAUDE_API_KEY`; `system_prompt_override` replaces `SYSTEM_PROMPT`.
    ///
    /// # Errors
    /// `ModelRunError::Runtime` when no key is available, the request fails, the API
    /// returns a non-success status (the body is included), or decoding fails.
    pub(super) fn run(
        model: &str,
        prompt: &PromptParts,
        override_key: Option<&str>,
        system_prompt_override: Option<&str>,
    ) -> Result<String, ModelRunError> {
        let system_prompt = system_prompt_override.unwrap_or(SYSTEM_PROMPT);
        let key = override_key
            .map(|value| value.to_string())
            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
            .or_else(|| std::env::var("CLAUDE_API_KEY").ok())
            .ok_or_else(|| {
                ModelRunError::Runtime(anyhow!(
                    "ANTHROPIC_API_KEY environment variable is required for Claude models"
                ))
            })?;
        // The Messages API authenticates with an `x-api-key` header; a bearer
        // `Authorization` header is rejected.
        let mut api_key = HeaderValue::from_str(&key).map_err(|err| {
            ModelRunError::Runtime(anyhow!("invalid ANTHROPIC_API_KEY header value: {err}"))
        })?;
        api_key.set_sensitive(true);
        let mut headers = HeaderMap::new();
        headers.insert(HeaderName::from_static("x-api-key"), api_key);
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers.insert(
            HeaderName::from_static("anthropic-version"),
            HeaderValue::from_static(API_VERSION),
        );
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .default_headers(headers)
            .build()
            .map_err(|err| ModelRunError::Runtime(anyhow!("failed to build HTTP client: {err}")))?;
        let payload = json!({
            "model": model,
            "max_tokens": u32::try_from(prompt.max_output_tokens()).unwrap_or(u32::MAX),
            "temperature": 0.2,
            "system": system_prompt,
            "messages": [{
                "role": "user",
                "content": [{"type": "text", "text": prompt.user_message()}]
            }]
        });
        let mut spinner = ThinkingSpinner::start();
        let response = client
            .post(ENDPOINT)
            .json(&payload)
            .send()
            .map_err(|err| ModelRunError::Runtime(anyhow!("Claude request failed: {err}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = response
                .text()
                .unwrap_or_else(|_| "<failed to read body>".to_string());
            return Err(ModelRunError::Runtime(anyhow!(
                "Claude returned error status {status}: {body}"
            )));
        }
        let body: ClaudeResponse = response.json().map_err(|err| {
            ModelRunError::Runtime(anyhow!("failed to decode Claude response: {err}"))
        })?;
        spinner.stop();
        let text = body
            .content
            .into_iter()
            .find_map(|part| match part {
                ContentBlock::Text { text } if !text.trim().is_empty() => {
                    Some(text.trim().to_string())
                }
                _ => None,
            })
            .unwrap_or_else(|| "No answer returned by Claude.".to_string());
        Ok(text)
    }

    #[derive(Debug, Deserialize)]
    struct ClaudeResponse {
        #[serde(default)]
        content: Vec<ContentBlock>,
    }

    /// Content block of a Messages API response; non-text blocks are ignored.
    #[derive(Debug, Deserialize)]
    #[serde(tag = "type", rename_all = "lowercase")]
    enum ContentBlock {
        Text {
            text: String,
        },
        #[serde(other)]
        Other,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a selector and return its label, for compact assertions.
    fn label_of(selector: &str) -> Option<String> {
        ModelKind::parse(selector).map(|kind| kind.label())
    }

    #[test]
    fn normalize_models() {
        assert_eq!(normalize_openai_model(None), "gpt-4o-mini");
        assert_eq!(normalize_gemini_model(None), "gemini-2.5-flash");
        assert_eq!(normalize_claude_model(None), "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn parse_rejects_blank_and_unknown_selectors() {
        assert!(ModelKind::parse("").is_none());
        assert!(ModelKind::parse("   ").is_none());
        assert!(ModelKind::parse("not-a-model").is_none());
    }

    #[test]
    fn parse_provider_keywords() {
        assert_eq!(label_of("openai").as_deref(), Some("openai:gpt-4o-mini"));
        assert_eq!(
            label_of("Claude:").as_deref(),
            Some("claude:claude-3-5-sonnet-20241022")
        );
        assert_eq!(
            label_of("ollama:qwen2.5:1.5b").as_deref(),
            Some("ollama:qwen2.5:1.5b")
        );
        assert_eq!(label_of("tinyllama").as_deref(), Some("tinyllama-1.1b"));
    }

    #[test]
    fn parse_infers_provider_from_model_name() {
        assert_eq!(
            label_of("qwen2.5:1.5b").as_deref(),
            Some("ollama:qwen2.5:1.5b")
        );
        assert_eq!(label_of("gpt-4o").as_deref(), Some("openai:gpt-4o"));
        assert_eq!(
            label_of("gemini-2.5-pro").as_deref(),
            Some("gemini:gemini-2.5-pro")
        );
        assert_eq!(
            label_of("claude-3-opus").as_deref(),
            Some("claude:claude-3-opus")
        );
    }

    #[test]
    fn budget_override_keeps_at_least_one_context_char() {
        let budget = ModelContextBudget::new(100, 50).apply_override(0);
        assert_eq!(budget.context_chars(), 1);
        assert_eq!(budget.prompt_ceiling(), 51);
    }

    #[test]
    fn clamp_to_respects_limit_on_ascii() {
        assert_eq!(clamp_to("abcdef", 10), "abcdef");
        assert_eq!(clamp_to("abcdef", 5), "ab...");
        assert_eq!(clamp_to("abcdef", 2), "..");
    }
}