Skip to main content

jamjet_models/
ollama.rs

1//! Ollama adapter — local model inference via Ollama's HTTP API.
2//!
3//! Supports any model available via `ollama pull`: qwen3, llama3, gemma2, phi3, etc.
4//! Reads `OLLAMA_HOST` from the environment (defaults to http://localhost:11434).
5//! Uses Ollama's native /api/chat endpoint for accurate token counts.
6
7use crate::adapter::{
8    ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9    StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Upper bound on generated tokens, sent as `options.num_predict` when the
/// request's config does not set `max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 4_096;

/// Model used when neither the request config nor the adapter names one.
const DEFAULT_MODEL: &str = "llama3.2:3b";

/// Server base URL used when `OLLAMA_HOST` is not set.
const OLLAMA_DEFAULT_HOST: &str = "http://localhost:11434";
18
19/// Ollama adapter for local model inference.
20///
21/// Connects to a running Ollama server and uses its native /api/chat endpoint.
22/// All inference is free (local GPU/CPU), making this ideal for development,
23/// testing, and cost-sensitive workloads.
24pub struct OllamaAdapter {
25    client: reqwest::Client,
26    host: String,
27    default_model: String,
28}
29
30impl OllamaAdapter {
31    pub fn new(host: impl Into<String>) -> Self {
32        Self {
33            client: reqwest::Client::new(),
34            host: host.into(),
35            default_model: DEFAULT_MODEL.into(),
36        }
37    }
38
39    /// Create adapter from `OLLAMA_HOST` env var (defaults to localhost:11434).
40    pub fn from_env() -> Result<Self, ModelError> {
41        let host = std::env::var("OLLAMA_HOST").unwrap_or_else(|_| OLLAMA_DEFAULT_HOST.to_string());
42
43        // Quick check: if Ollama is not reachable, fail fast.
44        // We skip the actual health check here to keep construction sync;
45        // errors will surface on first call_api() instead.
46        Ok(Self::new(host))
47    }
48
49    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
50        self.default_model = model.into();
51        self
52    }
53
54    async fn call_api(&self, body: Value) -> Result<Value, ModelError> {
55        let resp = self
56            .client
57            .post(format!("{}/api/chat", self.host))
58            .json(&body)
59            .send()
60            .await
61            .map_err(|e| ModelError::Network(format!("Ollama unreachable: {e}")))?;
62
63        let status = resp.status().as_u16();
64        let body_text = resp
65            .text()
66            .await
67            .map_err(|e| ModelError::Network(e.to_string()))?;
68
69        if status != 200 {
70            return Err(ModelError::Api {
71                status,
72                body: body_text,
73            });
74        }
75
76        serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
77    }
78
79    fn build_request_body(
80        &self,
81        messages: &[ChatMessage],
82        config: &ModelConfig,
83        format: Option<&str>,
84    ) -> Value {
85        let model = config.model.as_deref().unwrap_or(&self.default_model);
86        let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
87
88        let mut ollama_messages: Vec<Value> = Vec::new();
89
90        // Prepend system prompt if provided in config.
91        if let Some(sys) = &config.system_prompt {
92            ollama_messages.push(json!({ "role": "system", "content": sys }));
93        }
94
95        for m in messages {
96            let role = match m.role {
97                ChatRole::System => "system",
98                ChatRole::User => "user",
99                ChatRole::Assistant => "assistant",
100                ChatRole::Tool => "tool",
101            };
102            ollama_messages.push(json!({ "role": role, "content": m.content }));
103        }
104
105        let mut body = json!({
106            "model": model,
107            "messages": ollama_messages,
108            "stream": false,
109            "options": {
110                "num_predict": max_tokens,
111            },
112        });
113
114        if let Some(temp) = config.temperature {
115            body["options"]["temperature"] = json!(temp);
116        }
117        if let Some(stops) = &config.stop_sequences {
118            body["options"]["stop"] = json!(stops);
119        }
120        if let Some(fmt) = format {
121            body["format"] = json!(fmt);
122        }
123
124        body
125    }
126
127    fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
128        let model = resp["model"]
129            .as_str()
130            .unwrap_or(&self.default_model)
131            .to_string();
132
133        let content = resp["message"]["content"]
134            .as_str()
135            .unwrap_or("")
136            .to_string();
137
138        // Ollama provides token counts in prompt_eval_count / eval_count.
139        let input_tokens = resp["prompt_eval_count"].as_u64().unwrap_or(0);
140        let output_tokens = resp["eval_count"].as_u64().unwrap_or(0);
141
142        // Ollama uses "done_reason" (not "finish_reason").
143        let finish_reason = resp["done_reason"].as_str().unwrap_or("stop").to_string();
144
145        Ok(ModelResponse {
146            content,
147            model,
148            finish_reason,
149            input_tokens,
150            output_tokens,
151            structured: None,
152        })
153    }
154}
155
156#[async_trait]
157impl ModelAdapter for OllamaAdapter {
158    fn system_name(&self) -> &'static str {
159        "ollama"
160    }
161
162    fn default_model(&self) -> &str {
163        &self.default_model
164    }
165
166    #[instrument(skip(self, request), fields(
167        gen_ai.system = "ollama",
168        gen_ai.request.model = tracing::field::Empty,
169        gen_ai.usage.input_tokens = tracing::field::Empty,
170        gen_ai.usage.output_tokens = tracing::field::Empty,
171    ))]
172    async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
173        let model = request
174            .config
175            .model
176            .as_deref()
177            .unwrap_or(&self.default_model)
178            .to_string();
179        tracing::Span::current().record("gen_ai.request.model", model.as_str());
180
181        debug!(model = %model, host = %self.host, "Calling Ollama /api/chat");
182
183        let body = self.build_request_body(&request.messages, &request.config, None);
184        let resp_json = self.call_api(body).await?;
185        let response = self.parse_response(resp_json)?;
186
187        tracing::Span::current()
188            .record("gen_ai.usage.input_tokens", response.input_tokens)
189            .record("gen_ai.usage.output_tokens", response.output_tokens);
190
191        Ok(response)
192    }
193
194    #[instrument(skip(self, request), fields(
195        gen_ai.system = "ollama",
196        gen_ai.request.model = tracing::field::Empty,
197    ))]
198    async fn structured_output(
199        &self,
200        request: StructuredRequest,
201    ) -> Result<ModelResponse, ModelError> {
202        let model = request
203            .config
204            .model
205            .as_deref()
206            .unwrap_or(&self.default_model)
207            .to_string();
208        tracing::Span::current().record("gen_ai.request.model", model.as_str());
209
210        // Ollama supports format: "json" for JSON mode.
211        // Append the schema to the system prompt so the model knows the structure.
212        let mut config = request.config.clone();
213        let schema_str = serde_json::to_string_pretty(&request.output_schema)
214            .map_err(|e| ModelError::Serialization(e.to_string()))?;
215        let system = config.system_prompt.get_or_insert_with(String::new);
216        system.push_str(&format!(
217            "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
218        ));
219
220        let body = self.build_request_body(&request.messages, &config, Some("json"));
221        let resp_json = self.call_api(body).await?;
222        let mut response = self.parse_response(resp_json)?;
223
224        // Parse JSON from response content.
225        let structured =
226            serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
227                ModelError::Serialization(format!("structured output parse error: {e}"))
228            })?;
229        response.structured = Some(structured);
230
231        Ok(response)
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_request_body() {
        let adapter = OllamaAdapter::new("http://localhost:11434");
        let messages = vec![ChatMessage::user("Hello")];
        let config = ModelConfig {
            model: Some("qwen3:8b".into()),
            max_tokens: Some(100),
            temperature: Some(0.7),
            ..Default::default()
        };
        let body = adapter.build_request_body(&messages, &config, None);

        assert_eq!(body["model"], "qwen3:8b");
        assert_eq!(body["stream"], false);
        assert_eq!(body["options"]["num_predict"], 100);
        let temp = body["options"]["temperature"].as_f64().unwrap();
        assert!((temp - 0.7).abs() < 0.01);
    }

    /// An empty config falls back to the adapter defaults and sends no
    /// optional keys.
    #[test]
    fn test_build_request_body_defaults() {
        let adapter = OllamaAdapter::new("http://localhost:11434");
        let body = adapter.build_request_body(
            &[ChatMessage::user("Hi")],
            &ModelConfig::default(),
            None,
        );

        assert_eq!(body["model"], DEFAULT_MODEL);
        assert_eq!(body["options"]["num_predict"], DEFAULT_MAX_TOKENS);
        assert!(body["options"].get("temperature").is_none());
        assert!(body["options"].get("stop").is_none());
        assert!(body.get("format").is_none());
        assert_eq!(body["messages"].as_array().unwrap().len(), 1);
        assert_eq!(body["messages"][0]["role"], "user");
        assert_eq!(body["messages"][0]["content"], "Hi");
    }

    /// System prompt is prepended; stop sequences and format are forwarded.
    #[test]
    fn test_build_request_body_system_stop_and_format() {
        let adapter = OllamaAdapter::new("http://localhost:11434");
        let config = ModelConfig {
            system_prompt: Some("Be terse.".into()),
            stop_sequences: Some(vec!["END".to_string()]),
            ..Default::default()
        };
        let body = adapter.build_request_body(&[ChatMessage::user("Hi")], &config, Some("json"));

        assert_eq!(body["messages"][0]["role"], "system");
        assert_eq!(body["messages"][0]["content"], "Be terse.");
        assert_eq!(body["messages"][1]["role"], "user");
        assert_eq!(body["options"]["stop"], json!(["END"]));
        assert_eq!(body["format"], "json");
    }

    #[test]
    fn test_with_default_model_is_used_when_config_has_none() {
        let adapter =
            OllamaAdapter::new("http://localhost:11434").with_default_model("qwen3:8b");
        let body = adapter.build_request_body(
            &[ChatMessage::user("Hi")],
            &ModelConfig::default(),
            None,
        );
        assert_eq!(body["model"], "qwen3:8b");
    }

    #[test]
    fn test_parse_response() {
        let adapter = OllamaAdapter::new("http://localhost:11434");
        let resp = json!({
            "model": "qwen3:8b",
            "message": {"role": "assistant", "content": "Hello!"},
            "done": true,
            "done_reason": "stop",
            "prompt_eval_count": 42,
            "eval_count": 5,
        });

        let parsed = adapter.parse_response(resp).unwrap();
        assert_eq!(parsed.content, "Hello!");
        assert_eq!(parsed.model, "qwen3:8b");
        assert_eq!(parsed.input_tokens, 42);
        assert_eq!(parsed.output_tokens, 5);
        assert_eq!(parsed.finish_reason, "stop");
    }

    /// Missing fields (e.g. `prompt_eval_count` omitted on a cached prompt)
    /// fall back to lenient defaults instead of failing.
    #[test]
    fn test_parse_response_missing_fields_use_defaults() {
        let adapter = OllamaAdapter::new("http://localhost:11434");
        let parsed = adapter.parse_response(json!({})).unwrap();

        assert_eq!(parsed.content, "");
        assert_eq!(parsed.model, DEFAULT_MODEL);
        assert_eq!(parsed.input_tokens, 0);
        assert_eq!(parsed.output_tokens, 0);
        assert_eq!(parsed.finish_reason, "stop");
        assert!(parsed.structured.is_none());
    }
}