//! Ollama LLM provider (`brainos_cortex/llm/ollama.rs`).

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{
7    build_http_client, ensure_ok, LlmError, LlmProvider, Message, ProposedToolCall, Response,
8    ResponseChunk, ToolDef, Usage,
9};
10
11#[derive(Serialize)]
12struct OllamaRequest {
13    model: String,
14    messages: Vec<OllamaMessage>,
15    stream: bool,
16    options: Option<OllamaOptions>,
17    /// Advertised tools. Omitted from a plain-text request so behaviour is
18    /// unchanged when no tools channel is in play.
19    #[serde(skip_serializing_if = "Option::is_none")]
20    tools: Option<Vec<OllamaTool>>,
21}
22
23#[derive(Serialize, Deserialize)]
24struct OllamaMessage {
25    role: String,
26    content: String,
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    tool_calls: Option<Vec<OllamaToolCall>>,
29}
30
31/// One advertised tool in the request (`{"type":"function", ...}` — Ollama
32/// mirrors the OpenAI function-calling shape).
33#[derive(Serialize)]
34struct OllamaTool {
35    #[serde(rename = "type")]
36    kind: &'static str,
37    function: OllamaFunctionDef,
38}
39
40#[derive(Serialize)]
41struct OllamaFunctionDef {
42    name: String,
43    description: String,
44    parameters: serde_json::Value,
45}
46
47/// A tool call in the response. Unlike OpenAI, Ollama sends
48/// `function.arguments` as a JSON *object*, not a string.
49#[derive(Serialize, Deserialize)]
50struct OllamaToolCall {
51    function: OllamaFunctionCall,
52}
53
54#[derive(Serialize, Deserialize)]
55struct OllamaFunctionCall {
56    name: String,
57    #[serde(default)]
58    arguments: serde_json::Value,
59}
60
61#[derive(Serialize)]
62struct OllamaOptions {
63    temperature: f64,
64    #[serde(rename = "num_predict")]
65    num_predict: i32,
66}
67
68#[derive(Deserialize)]
69struct OllamaResponse {
70    message: Option<OllamaMessage>,
71    done: bool,
72    #[serde(default)]
73    prompt_eval_count: Option<u32>,
74    #[serde(default)]
75    eval_count: Option<u32>,
76}
77
78/// Ollama LLM provider.
79pub struct OllamaProvider {
80    client: reqwest::Client,
81    base_url: String,
82    model: String,
83    temperature: f64,
84    max_tokens: i32,
85}
86
87impl OllamaProvider {
88    pub fn new(
89        base_url: &str,
90        model: &str,
91        temperature: f64,
92        max_tokens: i32,
93    ) -> Result<Self, LlmError> {
94        let client = build_http_client(brain::timeouts::LLM_GENERATE)?;
95        Ok(Self {
96            client,
97            base_url: base_url.trim_end_matches('/').to_string(),
98            model: model.to_string(),
99            temperature,
100            max_tokens,
101        })
102    }
103
104    pub fn default_config() -> Result<Self, LlmError> {
105        Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
106    }
107
108    fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
109        messages
110            .iter()
111            .map(|m| OllamaMessage {
112                role: m.role.as_wire_str().to_string(),
113                content: m.content.clone(),
114                tool_calls: (!m.tool_calls.is_empty())
115                    .then(|| m.tool_calls.iter().map(convert_proposed_call).collect()),
116            })
117            .collect()
118    }
119
120    /// Translate the kernel's provider-agnostic [`ToolDef`]s into Ollama's
121    /// function-calling request shape.
122    fn convert_tools(tools: &[ToolDef]) -> Vec<OllamaTool> {
123        tools
124            .iter()
125            .map(|t| OllamaTool {
126                kind: "function",
127                function: OllamaFunctionDef {
128                    name: t.name.clone(),
129                    description: t.description.clone(),
130                    parameters: t.parameters.clone(),
131                },
132            })
133            .collect()
134    }
135
136    /// Map a response message's `tool_calls` into provider-agnostic
137    /// [`ProposedToolCall`]s. Ollama supplies no call id and sends
138    /// arguments as an object, which we pass through unchanged.
139    fn extract_tool_calls(message: &OllamaMessage) -> Vec<ProposedToolCall> {
140        message
141            .tool_calls
142            .iter()
143            .flatten()
144            .map(|tc| ProposedToolCall {
145                id: None,
146                name: tc.function.name.clone(),
147                arguments: tc.function.arguments.clone(),
148            })
149            .collect()
150    }
151}
152
153#[async_trait::async_trait]
154impl LlmProvider for OllamaProvider {
155    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
156        let url = format!("{}/api/chat", self.base_url);
157        let request = OllamaRequest {
158            model: self.model.clone(),
159            messages: Self::convert_messages(messages),
160            stream: false,
161            options: Some(OllamaOptions {
162                temperature: self.temperature,
163                num_predict: self.max_tokens,
164            }),
165            tools: None,
166        };
167
168        let resp = self.client.post(&url).json(&request).send().await?;
169        let resp = ensure_ok(resp).await?;
170
171        let data: OllamaResponse = resp.json().await?;
172        let usage = usage_from(&data);
173
174        Ok(Response::text(
175            data.message.map(|m| m.content).unwrap_or_default(),
176            usage,
177        ))
178    }
179
180    async fn generate_with_tools(
181        &self,
182        messages: &[Message],
183        tools: &[ToolDef],
184    ) -> Result<Response, LlmError> {
185        // No tools to advertise → identical to a plain generate.
186        if tools.is_empty() {
187            return self.generate(messages).await;
188        }
189
190        let url = format!("{}/api/chat", self.base_url);
191        let request = OllamaRequest {
192            model: self.model.clone(),
193            messages: Self::convert_messages(messages),
194            stream: false,
195            options: Some(OllamaOptions {
196                temperature: self.temperature,
197                num_predict: self.max_tokens,
198            }),
199            tools: Some(Self::convert_tools(tools)),
200        };
201
202        let resp = self.client.post(&url).json(&request).send().await?;
203        let resp = ensure_ok(resp).await?;
204
205        let data: OllamaResponse = resp.json().await?;
206        let usage = usage_from(&data);
207        let (content, tool_calls) = match data.message {
208            Some(ref m) => (m.content.clone(), Self::extract_tool_calls(m)),
209            None => (String::new(), Vec::new()),
210        };
211
212        Ok(Response {
213            content,
214            usage,
215            tool_calls,
216        })
217    }
218
219    async fn generate_stream(
220        &self,
221        messages: &[Message],
222    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
223        use futures::stream::try_unfold;
224
225        let url = format!("{}/api/chat", self.base_url);
226        let request = OllamaRequest {
227            model: self.model.clone(),
228            messages: Self::convert_messages(messages),
229            stream: true,
230            options: Some(OllamaOptions {
231                temperature: self.temperature,
232                num_predict: self.max_tokens,
233            }),
234            tools: None,
235        };
236
237        let resp = self.client.post(&url).json(&request).send().await?;
238        let resp = ensure_ok(resp).await?;
239
240        let byte_stream = resp.bytes_stream();
241        let stream = try_unfold(
242            (Box::pin(byte_stream), String::new(), false),
243            |(mut byte_stream, mut buf, done)| async move {
244                use futures::TryStreamExt;
245
246                if done {
247                    return Ok(None);
248                }
249
250                loop {
251                    if let Some(newline_pos) = buf.find('\n') {
252                        let line: String = buf[..newline_pos].to_string();
253                        buf = buf[newline_pos + 1..].to_string();
254
255                        let line = line.trim();
256                        if line.is_empty() {
257                            continue;
258                        }
259
260                        match serde_json::from_str::<OllamaResponse>(line) {
261                            Ok(data) => {
262                                let is_done = data.done;
263                                let content = data.message.map(|m| m.content).unwrap_or_default();
264                                let chunk = ResponseChunk { content, is_done };
265                                return Ok(Some((chunk, (byte_stream, buf, is_done))));
266                            }
267                            Err(e) => {
268                                return Err(LlmError::InvalidFormat(format!(
269                                    "Failed to parse streaming response: {e}"
270                                )));
271                            }
272                        }
273                    }
274
275                    match byte_stream.try_next().await {
276                        Ok(Some(bytes)) => {
277                            buf.push_str(&String::from_utf8_lossy(&bytes));
278                        }
279                        Ok(None) => {
280                            let remaining = buf.trim();
281                            if !remaining.is_empty() {
282                                if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
283                                {
284                                    let content =
285                                        data.message.map(|m| m.content).unwrap_or_default();
286                                    return Ok(Some((
287                                        ResponseChunk {
288                                            content,
289                                            is_done: true,
290                                        },
291                                        (byte_stream, String::new(), true),
292                                    )));
293                                }
294                            }
295                            return Ok(None);
296                        }
297                        Err(e) => return Err(LlmError::Http(e)),
298                    }
299                }
300            },
301        );
302
303        Ok(Box::pin(stream))
304    }
305
306    async fn health_check(&self) -> bool {
307        let url = format!("{}/api/tags", self.base_url);
308        match self.client.get(&url).send().await {
309            Ok(resp) => resp.status().is_success(),
310            Err(_) => false,
311        }
312    }
313
314    fn name(&self) -> &str {
315        "ollama"
316    }
317
318    fn model(&self) -> &str {
319        &self.model
320    }
321
322    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
323        #[derive(Deserialize)]
324        struct Tag {
325            name: String,
326        }
327        #[derive(Deserialize)]
328        struct Tags {
329            models: Vec<Tag>,
330        }
331
332        let url = format!("{}/api/tags", self.base_url);
333        let resp = self.client.get(&url).send().await?;
334        let resp = ensure_ok(resp).await?;
335        let data: Tags = resp.json().await?;
336        Ok(data.models.into_iter().map(|m| m.name).collect())
337    }
338
339    async fn fetch_context_window(&self) -> Option<usize> {
340        // 1. API-based detection via /api/show (works for most Ollama models).
341        #[derive(Deserialize)]
342        struct ModelInfo {
343            #[serde(default)]
344            model_info: std::collections::HashMap<String, serde_json::Value>,
345        }
346
347        let from_api = (async {
348            let url = format!("{}/api/show", self.base_url);
349            let body = serde_json::json!({ "model": self.model });
350            let resp = self.client.post(&url).json(&body).send().await.ok()?;
351            let resp = ensure_ok(resp).await.ok()?;
352            let data: ModelInfo = resp.json().await.ok()?;
353
354            // Ollama exposes context length under various keys depending
355            // on the backend. Try known patterns.
356            for key in &[
357                "llama.context_length",
358                "gptneox.context_length",
359                "llama2.context_length",
360            ] {
361                if let Some(val) = data.model_info.get(*key) {
362                    if let Some(n) = val.as_u64().or_else(|| val.as_f64().map(|f| f as u64)) {
363                        let n = n as usize;
364                        // Sanity: reject anything below 512 (parse artifact).
365                        if n >= 512 {
366                            return Some(n);
367                        }
368                    }
369                }
370            }
371            None
372        })
373        .await;
374        if from_api.is_some() {
375            return from_api;
376        }
377
378        // 2. Model-name heuristics.
379        super::known_context_window(self.model())
380    }
381}
382
383/// Reverse of [`OllamaProvider::extract_tool_calls`]: render a kernel
384/// [`ProposedToolCall`] back into Ollama's request shape for an assistant
385/// tool-call turn. Arguments stay an object (Ollama's wire format).
386fn convert_proposed_call(call: &ProposedToolCall) -> OllamaToolCall {
387    OllamaToolCall {
388        function: OllamaFunctionCall {
389            name: call.name.clone(),
390            arguments: call.arguments.clone(),
391        },
392    }
393}
394
395/// Build the kernel's [`Usage`] from an Ollama response's eval counts.
396fn usage_from(data: &OllamaResponse) -> Option<Usage> {
397    let prompt = data.prompt_eval_count.unwrap_or(0);
398    let completion = data.eval_count.unwrap_or(0);
399    Some(Usage {
400        prompt_tokens: prompt,
401        completion_tokens: completion,
402        total_tokens: prompt + completion,
403    })
404}