1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{
7 build_http_client, ensure_ok, LlmError, LlmProvider, Message, ProposedToolCall, Response,
8 ResponseChunk, ToolDef, Usage,
9};
10
11#[derive(Serialize)]
12struct OllamaRequest {
13 model: String,
14 messages: Vec<OllamaMessage>,
15 stream: bool,
16 options: Option<OllamaOptions>,
17 #[serde(skip_serializing_if = "Option::is_none")]
20 tools: Option<Vec<OllamaTool>>,
21}
22
23#[derive(Serialize, Deserialize)]
24struct OllamaMessage {
25 role: String,
26 content: String,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
28 tool_calls: Option<Vec<OllamaToolCall>>,
29}
30
31#[derive(Serialize)]
34struct OllamaTool {
35 #[serde(rename = "type")]
36 kind: &'static str,
37 function: OllamaFunctionDef,
38}
39
40#[derive(Serialize)]
41struct OllamaFunctionDef {
42 name: String,
43 description: String,
44 parameters: serde_json::Value,
45}
46
47#[derive(Serialize, Deserialize)]
50struct OllamaToolCall {
51 function: OllamaFunctionCall,
52}
53
54#[derive(Serialize, Deserialize)]
55struct OllamaFunctionCall {
56 name: String,
57 #[serde(default)]
58 arguments: serde_json::Value,
59}
60
61#[derive(Serialize)]
62struct OllamaOptions {
63 temperature: f64,
64 #[serde(rename = "num_predict")]
65 num_predict: i32,
66}
67
68#[derive(Deserialize)]
69struct OllamaResponse {
70 message: Option<OllamaMessage>,
71 done: bool,
72 #[serde(default)]
73 prompt_eval_count: Option<u32>,
74 #[serde(default)]
75 eval_count: Option<u32>,
76}
77
78pub struct OllamaProvider {
80 client: reqwest::Client,
81 base_url: String,
82 model: String,
83 temperature: f64,
84 max_tokens: i32,
85}
86
87impl OllamaProvider {
88 pub fn new(
89 base_url: &str,
90 model: &str,
91 temperature: f64,
92 max_tokens: i32,
93 ) -> Result<Self, LlmError> {
94 let client = build_http_client(brain::timeouts::LLM_GENERATE)?;
95 Ok(Self {
96 client,
97 base_url: base_url.trim_end_matches('/').to_string(),
98 model: model.to_string(),
99 temperature,
100 max_tokens,
101 })
102 }
103
104 pub fn default_config() -> Result<Self, LlmError> {
105 Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
106 }
107
108 fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
109 messages
110 .iter()
111 .map(|m| OllamaMessage {
112 role: m.role.as_wire_str().to_string(),
113 content: m.content.clone(),
114 tool_calls: (!m.tool_calls.is_empty())
115 .then(|| m.tool_calls.iter().map(convert_proposed_call).collect()),
116 })
117 .collect()
118 }
119
120 fn convert_tools(tools: &[ToolDef]) -> Vec<OllamaTool> {
123 tools
124 .iter()
125 .map(|t| OllamaTool {
126 kind: "function",
127 function: OllamaFunctionDef {
128 name: t.name.clone(),
129 description: t.description.clone(),
130 parameters: t.parameters.clone(),
131 },
132 })
133 .collect()
134 }
135
136 fn extract_tool_calls(message: &OllamaMessage) -> Vec<ProposedToolCall> {
140 message
141 .tool_calls
142 .iter()
143 .flatten()
144 .map(|tc| ProposedToolCall {
145 id: None,
146 name: tc.function.name.clone(),
147 arguments: tc.function.arguments.clone(),
148 })
149 .collect()
150 }
151}
152
153#[async_trait::async_trait]
154impl LlmProvider for OllamaProvider {
155 async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
156 let url = format!("{}/api/chat", self.base_url);
157 let request = OllamaRequest {
158 model: self.model.clone(),
159 messages: Self::convert_messages(messages),
160 stream: false,
161 options: Some(OllamaOptions {
162 temperature: self.temperature,
163 num_predict: self.max_tokens,
164 }),
165 tools: None,
166 };
167
168 let resp = self.client.post(&url).json(&request).send().await?;
169 let resp = ensure_ok(resp).await?;
170
171 let data: OllamaResponse = resp.json().await?;
172 let usage = usage_from(&data);
173
174 Ok(Response::text(
175 data.message.map(|m| m.content).unwrap_or_default(),
176 usage,
177 ))
178 }
179
180 async fn generate_with_tools(
181 &self,
182 messages: &[Message],
183 tools: &[ToolDef],
184 ) -> Result<Response, LlmError> {
185 if tools.is_empty() {
187 return self.generate(messages).await;
188 }
189
190 let url = format!("{}/api/chat", self.base_url);
191 let request = OllamaRequest {
192 model: self.model.clone(),
193 messages: Self::convert_messages(messages),
194 stream: false,
195 options: Some(OllamaOptions {
196 temperature: self.temperature,
197 num_predict: self.max_tokens,
198 }),
199 tools: Some(Self::convert_tools(tools)),
200 };
201
202 let resp = self.client.post(&url).json(&request).send().await?;
203 let resp = ensure_ok(resp).await?;
204
205 let data: OllamaResponse = resp.json().await?;
206 let usage = usage_from(&data);
207 let (content, tool_calls) = match data.message {
208 Some(ref m) => (m.content.clone(), Self::extract_tool_calls(m)),
209 None => (String::new(), Vec::new()),
210 };
211
212 Ok(Response {
213 content,
214 usage,
215 tool_calls,
216 })
217 }
218
219 async fn generate_stream(
220 &self,
221 messages: &[Message],
222 ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
223 use futures::stream::try_unfold;
224
225 let url = format!("{}/api/chat", self.base_url);
226 let request = OllamaRequest {
227 model: self.model.clone(),
228 messages: Self::convert_messages(messages),
229 stream: true,
230 options: Some(OllamaOptions {
231 temperature: self.temperature,
232 num_predict: self.max_tokens,
233 }),
234 tools: None,
235 };
236
237 let resp = self.client.post(&url).json(&request).send().await?;
238 let resp = ensure_ok(resp).await?;
239
240 let byte_stream = resp.bytes_stream();
241 let stream = try_unfold(
242 (Box::pin(byte_stream), String::new(), false),
243 |(mut byte_stream, mut buf, done)| async move {
244 use futures::TryStreamExt;
245
246 if done {
247 return Ok(None);
248 }
249
250 loop {
251 if let Some(newline_pos) = buf.find('\n') {
252 let line: String = buf[..newline_pos].to_string();
253 buf = buf[newline_pos + 1..].to_string();
254
255 let line = line.trim();
256 if line.is_empty() {
257 continue;
258 }
259
260 match serde_json::from_str::<OllamaResponse>(line) {
261 Ok(data) => {
262 let is_done = data.done;
263 let content = data.message.map(|m| m.content).unwrap_or_default();
264 let chunk = ResponseChunk { content, is_done };
265 return Ok(Some((chunk, (byte_stream, buf, is_done))));
266 }
267 Err(e) => {
268 return Err(LlmError::InvalidFormat(format!(
269 "Failed to parse streaming response: {e}"
270 )));
271 }
272 }
273 }
274
275 match byte_stream.try_next().await {
276 Ok(Some(bytes)) => {
277 buf.push_str(&String::from_utf8_lossy(&bytes));
278 }
279 Ok(None) => {
280 let remaining = buf.trim();
281 if !remaining.is_empty() {
282 if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
283 {
284 let content =
285 data.message.map(|m| m.content).unwrap_or_default();
286 return Ok(Some((
287 ResponseChunk {
288 content,
289 is_done: true,
290 },
291 (byte_stream, String::new(), true),
292 )));
293 }
294 }
295 return Ok(None);
296 }
297 Err(e) => return Err(LlmError::Http(e)),
298 }
299 }
300 },
301 );
302
303 Ok(Box::pin(stream))
304 }
305
306 async fn health_check(&self) -> bool {
307 let url = format!("{}/api/tags", self.base_url);
308 match self.client.get(&url).send().await {
309 Ok(resp) => resp.status().is_success(),
310 Err(_) => false,
311 }
312 }
313
314 fn name(&self) -> &str {
315 "ollama"
316 }
317
318 fn model(&self) -> &str {
319 &self.model
320 }
321
322 async fn list_models(&self) -> Result<Vec<String>, LlmError> {
323 #[derive(Deserialize)]
324 struct Tag {
325 name: String,
326 }
327 #[derive(Deserialize)]
328 struct Tags {
329 models: Vec<Tag>,
330 }
331
332 let url = format!("{}/api/tags", self.base_url);
333 let resp = self.client.get(&url).send().await?;
334 let resp = ensure_ok(resp).await?;
335 let data: Tags = resp.json().await?;
336 Ok(data.models.into_iter().map(|m| m.name).collect())
337 }
338
339 async fn fetch_context_window(&self) -> Option<usize> {
340 #[derive(Deserialize)]
342 struct ModelInfo {
343 #[serde(default)]
344 model_info: std::collections::HashMap<String, serde_json::Value>,
345 }
346
347 let from_api = (async {
348 let url = format!("{}/api/show", self.base_url);
349 let body = serde_json::json!({ "model": self.model });
350 let resp = self.client.post(&url).json(&body).send().await.ok()?;
351 let resp = ensure_ok(resp).await.ok()?;
352 let data: ModelInfo = resp.json().await.ok()?;
353
354 for key in &[
357 "llama.context_length",
358 "gptneox.context_length",
359 "llama2.context_length",
360 ] {
361 if let Some(val) = data.model_info.get(*key) {
362 if let Some(n) = val.as_u64().or_else(|| val.as_f64().map(|f| f as u64)) {
363 let n = n as usize;
364 if n >= 512 {
366 return Some(n);
367 }
368 }
369 }
370 }
371 None
372 })
373 .await;
374 if from_api.is_some() {
375 return from_api;
376 }
377
378 super::known_context_window(self.model())
380 }
381}
382
383fn convert_proposed_call(call: &ProposedToolCall) -> OllamaToolCall {
387 OllamaToolCall {
388 function: OllamaFunctionCall {
389 name: call.name.clone(),
390 arguments: call.arguments.clone(),
391 },
392 }
393}
394
395fn usage_from(data: &OllamaResponse) -> Option<Usage> {
397 let prompt = data.prompt_eval_count.unwrap_or(0);
398 let completion = data.eval_count.unwrap_or(0);
399 Some(Usage {
400 prompt_tokens: prompt,
401 completion_tokens: completion,
402 total_tokens: prompt + completion,
403 })
404}