1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{LlmError, LlmProvider, Message, Response, ResponseChunk, Role, Usage};
7
8#[derive(Serialize)]
9struct OllamaRequest {
10 model: String,
11 messages: Vec<OllamaMessage>,
12 stream: bool,
13 options: Option<OllamaOptions>,
14}
15
16#[derive(Serialize, Deserialize)]
17struct OllamaMessage {
18 role: String,
19 content: String,
20}
21
22#[derive(Serialize)]
23struct OllamaOptions {
24 temperature: f64,
25 #[serde(rename = "num_predict")]
26 num_predict: i32,
27}
28
29#[derive(Deserialize)]
30struct OllamaResponse {
31 message: Option<OllamaMessage>,
32 done: bool,
33 #[serde(default)]
34 prompt_eval_count: Option<u32>,
35 #[serde(default)]
36 eval_count: Option<u32>,
37}
38
39pub struct OllamaProvider {
41 client: reqwest::Client,
42 base_url: String,
43 model: String,
44 temperature: f64,
45 max_tokens: i32,
46}
47
48impl OllamaProvider {
49 pub fn new(
50 base_url: &str,
51 model: &str,
52 temperature: f64,
53 max_tokens: i32,
54 ) -> Result<Self, LlmError> {
55 let client = reqwest::Client::builder()
56 .timeout(brain_core::timeouts::LLM_GENERATE)
57 .build()
58 .map_err(|e| {
59 LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
60 })?;
61
62 Ok(Self {
63 client,
64 base_url: base_url.trim_end_matches('/').to_string(),
65 model: model.to_string(),
66 temperature,
67 max_tokens,
68 })
69 }
70
71 pub fn default_config() -> Result<Self, LlmError> {
72 Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
73 }
74
75 fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
76 messages
77 .iter()
78 .map(|m| OllamaMessage {
79 role: match m.role {
80 Role::System => "system".to_string(),
81 Role::User => "user".to_string(),
82 Role::Assistant => "assistant".to_string(),
83 },
84 content: m.content.clone(),
85 })
86 .collect()
87 }
88}
89
90#[async_trait::async_trait]
91impl LlmProvider for OllamaProvider {
92 async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
93 let url = format!("{}/api/chat", self.base_url);
94 let request = OllamaRequest {
95 model: self.model.clone(),
96 messages: Self::convert_messages(messages),
97 stream: false,
98 options: Some(OllamaOptions {
99 temperature: self.temperature,
100 num_predict: self.max_tokens,
101 }),
102 };
103
104 let resp = self.client.post(&url).json(&request).send().await?;
105
106 if !resp.status().is_success() {
107 let status = resp.status();
108 let body = resp.text().await.unwrap_or_default();
109 return Err(LlmError::Api {
110 status: status.as_u16(),
111 message: body,
112 });
113 }
114
115 let data: OllamaResponse = resp.json().await?;
116
117 Ok(Response {
118 content: data.message.map(|m| m.content).unwrap_or_default(),
119 usage: Some(Usage {
120 prompt_tokens: data.prompt_eval_count.unwrap_or(0),
121 completion_tokens: data.eval_count.unwrap_or(0),
122 total_tokens: data.prompt_eval_count.unwrap_or(0) + data.eval_count.unwrap_or(0),
123 }),
124 })
125 }
126
127 async fn generate_stream(
128 &self,
129 messages: &[Message],
130 ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
131 use futures::stream::try_unfold;
132
133 let url = format!("{}/api/chat", self.base_url);
134 let request = OllamaRequest {
135 model: self.model.clone(),
136 messages: Self::convert_messages(messages),
137 stream: true,
138 options: Some(OllamaOptions {
139 temperature: self.temperature,
140 num_predict: self.max_tokens,
141 }),
142 };
143
144 let resp = self.client.post(&url).json(&request).send().await?;
145
146 if !resp.status().is_success() {
147 let status = resp.status();
148 let body = resp.text().await.unwrap_or_default();
149 return Err(LlmError::Api {
150 status: status.as_u16(),
151 message: body,
152 });
153 }
154
155 let byte_stream = resp.bytes_stream();
156 let stream = try_unfold(
157 (Box::pin(byte_stream), String::new(), false),
158 |(mut byte_stream, mut buf, done)| async move {
159 use futures::TryStreamExt;
160
161 if done {
162 return Ok(None);
163 }
164
165 loop {
166 if let Some(newline_pos) = buf.find('\n') {
167 let line: String = buf[..newline_pos].to_string();
168 buf = buf[newline_pos + 1..].to_string();
169
170 let line = line.trim();
171 if line.is_empty() {
172 continue;
173 }
174
175 match serde_json::from_str::<OllamaResponse>(line) {
176 Ok(data) => {
177 let is_done = data.done;
178 let content = data.message.map(|m| m.content).unwrap_or_default();
179 let chunk = ResponseChunk { content, is_done };
180 return Ok(Some((chunk, (byte_stream, buf, is_done))));
181 }
182 Err(e) => {
183 return Err(LlmError::InvalidFormat(format!(
184 "Failed to parse streaming response: {e}"
185 )));
186 }
187 }
188 }
189
190 match byte_stream.try_next().await {
191 Ok(Some(bytes)) => {
192 buf.push_str(&String::from_utf8_lossy(&bytes));
193 }
194 Ok(None) => {
195 let remaining = buf.trim();
196 if !remaining.is_empty() {
197 if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
198 {
199 let content =
200 data.message.map(|m| m.content).unwrap_or_default();
201 return Ok(Some((
202 ResponseChunk {
203 content,
204 is_done: true,
205 },
206 (byte_stream, String::new(), true),
207 )));
208 }
209 }
210 return Ok(None);
211 }
212 Err(e) => return Err(LlmError::Http(e)),
213 }
214 }
215 },
216 );
217
218 Ok(Box::pin(stream))
219 }
220
221 async fn health_check(&self) -> bool {
222 let url = format!("{}/api/tags", self.base_url);
223 match self.client.get(&url).send().await {
224 Ok(resp) => resp.status().is_success(),
225 Err(_) => false,
226 }
227 }
228
229 fn name(&self) -> &str {
230 "ollama"
231 }
232
233 fn model(&self) -> &str {
234 &self.model
235 }
236
237 async fn list_models(&self) -> Result<Vec<String>, LlmError> {
238 #[derive(Deserialize)]
239 struct Tag {
240 name: String,
241 }
242 #[derive(Deserialize)]
243 struct Tags {
244 models: Vec<Tag>,
245 }
246
247 let url = format!("{}/api/tags", self.base_url);
248 let resp = self.client.get(&url).send().await?;
249 if !resp.status().is_success() {
250 let status = resp.status();
251 let body = resp.text().await.unwrap_or_default();
252 return Err(LlmError::Api {
253 status: status.as_u16(),
254 message: body,
255 });
256 }
257 let data: Tags = resp.json().await?;
258 Ok(data.models.into_iter().map(|m| m.name).collect())
259 }
260}