use anyhow::{anyhow, Result};
use super::{Backend, GenAnswer, GenRequest, Generator};
/// Text-generation backend that talks to an Ollama server over HTTP.
#[derive(Debug, Clone)]
pub struct OllamaGenerator {
    /// Backend identifier of the form `ollama:<model>`.
    id: String,
    /// Model name sent in every generate request.
    model: String,
    /// Base URL of the server: always carries a scheme, never a trailing slash.
    host: String,
}

impl OllamaGenerator {
    /// Base URL used when neither the caller nor `OLLAMA_HOST` supplies one.
    const DEFAULT_HOST: &'static str = "http://localhost:11434";
    /// Model used when the caller passes an empty model name.
    const DEFAULT_MODEL: &'static str = "llama3";

    /// Creates a generator for `model` served at `host`.
    ///
    /// The host is resolved in this order, skipping values that are empty or
    /// blank: the explicit `host` argument, the `OLLAMA_HOST` environment
    /// variable, then `http://localhost:11434`. The chosen value is trimmed,
    /// stripped of trailing slashes, and given an `http://` scheme when it has
    /// none, so a bare `127.0.0.1:11434` (a common way to set `OLLAMA_HOST`)
    /// still produces valid request URLs.
    ///
    /// An empty `model` falls back to `llama3`.
    pub fn new(model: &str, host: Option<String>) -> Self {
        let host = host
            .as_deref()
            .and_then(Self::normalize_host)
            .or_else(|| {
                std::env::var("OLLAMA_HOST")
                    .ok()
                    .as_deref()
                    .and_then(Self::normalize_host)
            })
            .unwrap_or_else(|| Self::DEFAULT_HOST.to_string());
        let model = if model.is_empty() {
            Self::DEFAULT_MODEL.to_string()
        } else {
            model.to_string()
        };
        Self {
            id: format!("ollama:{model}"),
            model,
            host,
        }
    }

    /// Turns a user-supplied host into a usable base URL, or `None` if it is blank.
    fn normalize_host(raw: &str) -> Option<String> {
        let trimmed = raw.trim();
        // Decide on the scheme before stripping slashes so that a value with
        // an explicit scheme is never given a second one.
        let has_scheme = trimmed.contains("://");
        let bare = trimmed.trim_end_matches('/');
        if bare.is_empty() {
            return None;
        }
        Some(if has_scheme {
            bare.to_string()
        } else {
            format!("http://{bare}")
        })
    }
}
impl Backend for OllamaGenerator {
    /// Returns the backend identifier built at construction time.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Probes the server by requesting `<host>/api/tags`.
    ///
    /// Reports `true` only when the request succeeds within 500 ms; any
    /// error returned by the HTTP call counts as "not available".
    fn available(&self) -> bool {
        let probe_timeout = std::time::Duration::from_millis(500);
        let tags_url = format!("{}/api/tags", self.host);
        let outcome = ureq::get(&tags_url).timeout(probe_timeout).call();
        outcome.is_ok()
    }
}
impl Generator for OllamaGenerator {
fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
let url = format!("{}/api/generate", self.host);
let mut body = serde_json::json!({
"model": self.model,
"prompt": req.prompt,
"stream": false,
"options": {
"temperature": req.temperature,
"num_predict": req.max_tokens,
},
});
if let Some(sys) = &req.system {
body["system"] = serde_json::Value::String(sys.clone());
}
if !req.stop.is_empty() {
body["options"]["stop"] = serde_json::json!(req.stop);
}
let body_str = serde_json::to_string(&body)?;
let started = std::time::Instant::now();
let resp = ureq::post(&url)
.set("Content-Type", "application/json")
.send_string(&body_str)
.map_err(|e| anyhow!("ollama POST {url} failed: {e}"))?;
let txt = resp.into_string().map_err(|e| anyhow!("ollama response not readable: {e}"))?;
let v: serde_json::Value =
serde_json::from_str(&txt).map_err(|e| anyhow!("ollama response not JSON: {e}"))?;
let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
let text = v.get("response").and_then(|x| x.as_str()).unwrap_or("").to_string();
let tokens_in = v.get("prompt_eval_count").and_then(|x| x.as_i64()).unwrap_or(0);
let tokens_out = v.get("eval_count").and_then(|x| x.as_i64()).unwrap_or(0);
let tokens_per_s = match v.get("eval_duration").and_then(|x| x.as_i64()) {
Some(ns) if ns > 0 => tokens_out as f64 / (ns as f64 / 1e9),
_ if latency_ms > 0.0 => tokens_out as f64 / (latency_ms / 1000.0),
_ => 0.0,
};
Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The id embeds the model name and an explicit host loses its trailing slash.
    #[test]
    fn id_and_default_host() {
        let explicit_host = String::from("http://example:11434/");
        let generator = OllamaGenerator::new("qwen2", Some(explicit_host));
        assert_eq!(generator.id(), "ollama:qwen2");
        assert_eq!(generator.host, "http://example:11434", "trailing slash trimmed");
    }

    /// An empty model name falls back to the default model.
    #[test]
    fn empty_model_defaults() {
        let generator = OllamaGenerator::new("", Some(String::from("http://h:1")));
        assert_eq!(generator.id(), "ollama:llama3");
    }

    /// The availability probe reports `false` for this loopback port-1 address.
    #[test]
    fn available_is_false_against_a_dead_host() {
        let generator = OllamaGenerator::new("m", Some(String::from("http://127.0.0.1:1")));
        assert!(!generator.available());
    }
}