1use crate::adapter::{
8 ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9 StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Base URL that `OllamaAdapter::from_env` falls back to when the
/// `OLLAMA_HOST` environment variable cannot be read.
const OLLAMA_DEFAULT_HOST: &str = "http://localhost:11434";
/// Model name a freshly constructed adapter uses when a request does not name one.
const DEFAULT_MODEL: &str = "llama3.2:3b";
/// Value sent as `options.num_predict` when the request config has no `max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
18
19pub struct OllamaAdapter {
25 client: reqwest::Client,
26 host: String,
27 default_model: String,
28}
29
30impl OllamaAdapter {
31 pub fn new(host: impl Into<String>) -> Self {
32 Self {
33 client: reqwest::Client::new(),
34 host: host.into(),
35 default_model: DEFAULT_MODEL.into(),
36 }
37 }
38
39 pub fn from_env() -> Result<Self, ModelError> {
41 let host = std::env::var("OLLAMA_HOST").unwrap_or_else(|_| OLLAMA_DEFAULT_HOST.to_string());
42
43 Ok(Self::new(host))
47 }
48
49 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
50 self.default_model = model.into();
51 self
52 }
53
54 async fn call_api(&self, body: Value) -> Result<Value, ModelError> {
55 let resp = self
56 .client
57 .post(format!("{}/api/chat", self.host))
58 .json(&body)
59 .send()
60 .await
61 .map_err(|e| ModelError::Network(format!("Ollama unreachable: {e}")))?;
62
63 let status = resp.status().as_u16();
64 let body_text = resp
65 .text()
66 .await
67 .map_err(|e| ModelError::Network(e.to_string()))?;
68
69 if status != 200 {
70 return Err(ModelError::Api {
71 status,
72 body: body_text,
73 });
74 }
75
76 serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
77 }
78
79 fn build_request_body(
80 &self,
81 messages: &[ChatMessage],
82 config: &ModelConfig,
83 format: Option<&str>,
84 ) -> Value {
85 let model = config.model.as_deref().unwrap_or(&self.default_model);
86 let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
87
88 let mut ollama_messages: Vec<Value> = Vec::new();
89
90 if let Some(sys) = &config.system_prompt {
92 ollama_messages.push(json!({ "role": "system", "content": sys }));
93 }
94
95 for m in messages {
96 let role = match m.role {
97 ChatRole::System => "system",
98 ChatRole::User => "user",
99 ChatRole::Assistant => "assistant",
100 ChatRole::Tool => "tool",
101 };
102 ollama_messages.push(json!({ "role": role, "content": m.content }));
103 }
104
105 let mut body = json!({
106 "model": model,
107 "messages": ollama_messages,
108 "stream": false,
109 "options": {
110 "num_predict": max_tokens,
111 },
112 });
113
114 if let Some(temp) = config.temperature {
115 body["options"]["temperature"] = json!(temp);
116 }
117 if let Some(stops) = &config.stop_sequences {
118 body["options"]["stop"] = json!(stops);
119 }
120 if let Some(fmt) = format {
121 body["format"] = json!(fmt);
122 }
123
124 body
125 }
126
127 fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
128 let model = resp["model"]
129 .as_str()
130 .unwrap_or(&self.default_model)
131 .to_string();
132
133 let content = resp["message"]["content"]
134 .as_str()
135 .unwrap_or("")
136 .to_string();
137
138 let input_tokens = resp["prompt_eval_count"].as_u64().unwrap_or(0);
140 let output_tokens = resp["eval_count"].as_u64().unwrap_or(0);
141
142 let finish_reason = resp["done_reason"].as_str().unwrap_or("stop").to_string();
144
145 Ok(ModelResponse {
146 content,
147 model,
148 finish_reason,
149 input_tokens,
150 output_tokens,
151 structured: None,
152 })
153 }
154}
155
156#[async_trait]
157impl ModelAdapter for OllamaAdapter {
158 fn system_name(&self) -> &'static str {
159 "ollama"
160 }
161
162 fn default_model(&self) -> &str {
163 &self.default_model
164 }
165
166 #[instrument(skip(self, request), fields(
167 gen_ai.system = "ollama",
168 gen_ai.request.model = tracing::field::Empty,
169 gen_ai.usage.input_tokens = tracing::field::Empty,
170 gen_ai.usage.output_tokens = tracing::field::Empty,
171 ))]
172 async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
173 let model = request
174 .config
175 .model
176 .as_deref()
177 .unwrap_or(&self.default_model)
178 .to_string();
179 tracing::Span::current().record("gen_ai.request.model", model.as_str());
180
181 debug!(model = %model, host = %self.host, "Calling Ollama /api/chat");
182
183 let body = self.build_request_body(&request.messages, &request.config, None);
184 let resp_json = self.call_api(body).await?;
185 let response = self.parse_response(resp_json)?;
186
187 tracing::Span::current()
188 .record("gen_ai.usage.input_tokens", response.input_tokens)
189 .record("gen_ai.usage.output_tokens", response.output_tokens);
190
191 Ok(response)
192 }
193
194 #[instrument(skip(self, request), fields(
195 gen_ai.system = "ollama",
196 gen_ai.request.model = tracing::field::Empty,
197 ))]
198 async fn structured_output(
199 &self,
200 request: StructuredRequest,
201 ) -> Result<ModelResponse, ModelError> {
202 let model = request
203 .config
204 .model
205 .as_deref()
206 .unwrap_or(&self.default_model)
207 .to_string();
208 tracing::Span::current().record("gen_ai.request.model", model.as_str());
209
210 let mut config = request.config.clone();
213 let schema_str = serde_json::to_string_pretty(&request.output_schema)
214 .map_err(|e| ModelError::Serialization(e.to_string()))?;
215 let system = config.system_prompt.get_or_insert_with(String::new);
216 system.push_str(&format!(
217 "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
218 ));
219
220 let body = self.build_request_body(&request.messages, &config, Some("json"));
221 let resp_json = self.call_api(body).await?;
222 let mut response = self.parse_response(resp_json)?;
223
224 let structured =
226 serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
227 ModelError::Serialization(format!("structured output parse error: {e}"))
228 })?;
229 response.structured = Some(structured);
230
231 Ok(response)
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Host handed to the adapter under test; these tests never send a request.
    const TEST_HOST: &str = "http://localhost:11434";

    /// Builds the adapter shared by the tests below.
    fn adapter_under_test() -> OllamaAdapter {
        OllamaAdapter::new(TEST_HOST)
    }

    /// Explicit config values must be copied into the request payload.
    #[test]
    fn test_build_request_body() {
        let config = ModelConfig {
            model: Some("qwen3:8b".into()),
            max_tokens: Some(100),
            temperature: Some(0.7),
            ..Default::default()
        };
        let conversation = [ChatMessage::user("Hello")];

        let payload = adapter_under_test().build_request_body(&conversation, &config, None);

        assert_eq!(payload["model"], "qwen3:8b");
        assert_eq!(payload["stream"], false);

        let options = &payload["options"];
        assert_eq!(options["num_predict"], 100);
        // The temperature may have been widened from a narrower float type,
        // so compare with a tolerance instead of exact equality.
        let drift = (options["temperature"].as_f64().unwrap() - 0.7).abs();
        assert!(drift < 0.01);
    }

    /// Every field of a complete reply must land in the parsed response.
    #[test]
    fn test_parse_response() {
        let reply = json!({
            "model": "qwen3:8b",
            "message": {"role": "assistant", "content": "Hello!"},
            "done": true,
            "done_reason": "stop",
            "prompt_eval_count": 42,
            "eval_count": 5,
        });

        let parsed = adapter_under_test().parse_response(reply).unwrap();

        assert_eq!(parsed.content, "Hello!");
        assert_eq!(parsed.model, "qwen3:8b");
        assert_eq!(parsed.input_tokens, 42);
        assert_eq!(parsed.output_tokens, 5);
        assert_eq!(parsed.finish_reason, "stop");
    }
}