// brainos_cortex/llm/ollama.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{LlmError, LlmProvider, Message, Response, ResponseChunk, Role, Usage};
7
8#[derive(Serialize)]
9struct OllamaRequest {
10    model: String,
11    messages: Vec<OllamaMessage>,
12    stream: bool,
13    options: Option<OllamaOptions>,
14}
15
16#[derive(Serialize, Deserialize)]
17struct OllamaMessage {
18    role: String,
19    content: String,
20}
21
22#[derive(Serialize)]
23struct OllamaOptions {
24    temperature: f64,
25    #[serde(rename = "num_predict")]
26    num_predict: i32,
27}
28
29#[derive(Deserialize)]
30struct OllamaResponse {
31    message: Option<OllamaMessage>,
32    done: bool,
33    #[serde(default)]
34    prompt_eval_count: Option<u32>,
35    #[serde(default)]
36    eval_count: Option<u32>,
37}
38
39/// Ollama LLM provider.
40pub struct OllamaProvider {
41    client: reqwest::Client,
42    base_url: String,
43    model: String,
44    temperature: f64,
45    max_tokens: i32,
46}
47
48impl OllamaProvider {
49    pub fn new(
50        base_url: &str,
51        model: &str,
52        temperature: f64,
53        max_tokens: i32,
54    ) -> Result<Self, LlmError> {
55        let client = reqwest::Client::builder()
56            .timeout(brain_core::timeouts::LLM_GENERATE)
57            .build()
58            .map_err(|e| {
59                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
60            })?;
61
62        Ok(Self {
63            client,
64            base_url: base_url.trim_end_matches('/').to_string(),
65            model: model.to_string(),
66            temperature,
67            max_tokens,
68        })
69    }
70
71    pub fn default_config() -> Result<Self, LlmError> {
72        Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
73    }
74
75    fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
76        messages
77            .iter()
78            .map(|m| OllamaMessage {
79                role: match m.role {
80                    Role::System => "system".to_string(),
81                    Role::User => "user".to_string(),
82                    Role::Assistant => "assistant".to_string(),
83                },
84                content: m.content.clone(),
85            })
86            .collect()
87    }
88}
89
90#[async_trait::async_trait]
91impl LlmProvider for OllamaProvider {
92    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
93        let url = format!("{}/api/chat", self.base_url);
94        let request = OllamaRequest {
95            model: self.model.clone(),
96            messages: Self::convert_messages(messages),
97            stream: false,
98            options: Some(OllamaOptions {
99                temperature: self.temperature,
100                num_predict: self.max_tokens,
101            }),
102        };
103
104        let resp = self.client.post(&url).json(&request).send().await?;
105
106        if !resp.status().is_success() {
107            let status = resp.status();
108            let body = resp.text().await.unwrap_or_default();
109            return Err(LlmError::Api {
110                status: status.as_u16(),
111                message: body,
112            });
113        }
114
115        let data: OllamaResponse = resp.json().await?;
116
117        Ok(Response {
118            content: data.message.map(|m| m.content).unwrap_or_default(),
119            usage: Some(Usage {
120                prompt_tokens: data.prompt_eval_count.unwrap_or(0),
121                completion_tokens: data.eval_count.unwrap_or(0),
122                total_tokens: data.prompt_eval_count.unwrap_or(0) + data.eval_count.unwrap_or(0),
123            }),
124        })
125    }
126
127    async fn generate_stream(
128        &self,
129        messages: &[Message],
130    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
131        use futures::stream::try_unfold;
132
133        let url = format!("{}/api/chat", self.base_url);
134        let request = OllamaRequest {
135            model: self.model.clone(),
136            messages: Self::convert_messages(messages),
137            stream: true,
138            options: Some(OllamaOptions {
139                temperature: self.temperature,
140                num_predict: self.max_tokens,
141            }),
142        };
143
144        let resp = self.client.post(&url).json(&request).send().await?;
145
146        if !resp.status().is_success() {
147            let status = resp.status();
148            let body = resp.text().await.unwrap_or_default();
149            return Err(LlmError::Api {
150                status: status.as_u16(),
151                message: body,
152            });
153        }
154
155        let byte_stream = resp.bytes_stream();
156        let stream = try_unfold(
157            (Box::pin(byte_stream), String::new(), false),
158            |(mut byte_stream, mut buf, done)| async move {
159                use futures::TryStreamExt;
160
161                if done {
162                    return Ok(None);
163                }
164
165                loop {
166                    if let Some(newline_pos) = buf.find('\n') {
167                        let line: String = buf[..newline_pos].to_string();
168                        buf = buf[newline_pos + 1..].to_string();
169
170                        let line = line.trim();
171                        if line.is_empty() {
172                            continue;
173                        }
174
175                        match serde_json::from_str::<OllamaResponse>(line) {
176                            Ok(data) => {
177                                let is_done = data.done;
178                                let content = data.message.map(|m| m.content).unwrap_or_default();
179                                let chunk = ResponseChunk { content, is_done };
180                                return Ok(Some((chunk, (byte_stream, buf, is_done))));
181                            }
182                            Err(e) => {
183                                return Err(LlmError::InvalidFormat(format!(
184                                    "Failed to parse streaming response: {e}"
185                                )));
186                            }
187                        }
188                    }
189
190                    match byte_stream.try_next().await {
191                        Ok(Some(bytes)) => {
192                            buf.push_str(&String::from_utf8_lossy(&bytes));
193                        }
194                        Ok(None) => {
195                            let remaining = buf.trim();
196                            if !remaining.is_empty() {
197                                if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
198                                {
199                                    let content =
200                                        data.message.map(|m| m.content).unwrap_or_default();
201                                    return Ok(Some((
202                                        ResponseChunk {
203                                            content,
204                                            is_done: true,
205                                        },
206                                        (byte_stream, String::new(), true),
207                                    )));
208                                }
209                            }
210                            return Ok(None);
211                        }
212                        Err(e) => return Err(LlmError::Http(e)),
213                    }
214                }
215            },
216        );
217
218        Ok(Box::pin(stream))
219    }
220
221    async fn health_check(&self) -> bool {
222        let url = format!("{}/api/tags", self.base_url);
223        match self.client.get(&url).send().await {
224            Ok(resp) => resp.status().is_success(),
225            Err(_) => false,
226        }
227    }
228
229    fn name(&self) -> &str {
230        "ollama"
231    }
232
233    fn model(&self) -> &str {
234        &self.model
235    }
236
237    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
238        #[derive(Deserialize)]
239        struct Tag {
240            name: String,
241        }
242        #[derive(Deserialize)]
243        struct Tags {
244            models: Vec<Tag>,
245        }
246
247        let url = format!("{}/api/tags", self.base_url);
248        let resp = self.client.get(&url).send().await?;
249        if !resp.status().is_success() {
250            let status = resp.status();
251            let body = resp.text().await.unwrap_or_default();
252            return Err(LlmError::Api {
253                status: status.as_u16(),
254                message: body,
255            });
256        }
257        let data: Tags = resp.json().await?;
258        Ok(data.models.into_iter().map(|m| m.name).collect())
259    }
260}