// atomcode_core/provider/ollama.rs

1use std::pin::Pin;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use futures::Stream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10
11use crate::config::provider::ProviderConfig;
12use crate::conversation::message::{Message, MessageContent, Role};
13use crate::stream::StreamEvent;
14use crate::tool::ToolDef;
15
16use super::LlmProvider;
17
18pub struct OllamaProvider {
19    client: Client,
20    model: String,
21    base_url: String,
22}
23
24impl OllamaProvider {
25    pub fn new(config: &ProviderConfig) -> Result<Self> {
26        Ok(Self {
27            client: super::build_http_client(config.user_agent.as_deref(), config.skip_tls_verify),
28            model: config.model.clone(),
29            base_url: config
30                .base_url
31                .clone()
32                .unwrap_or_else(|| "http://localhost:11434".to_string()),
33        })
34    }
35
36    fn format_messages(messages: &[Message]) -> Vec<serde_json::Value> {
37        messages
38            .iter()
39            .filter_map(|m| {
40                match &m.content {
41                    MessageContent::Text(s) => {
42                        // Tool role with plain Text is invalid for the tool-call
43                        // protocol — tool results must use MessageContent::ToolResult.
44                        let role = match m.role {
45                            Role::System => "system",
46                            Role::User => "user",
47                            Role::Assistant => "assistant",
48                            Role::Tool => return None,
49                        };
50                        if s.trim().is_empty() {
51                            return None;
52                        }
53                        Some(json!({"role": role, "content": s}))
54                    }
55                    MessageContent::AssistantWithToolCalls { text, tool_calls, .. } => {
56                        if tool_calls.is_empty() {
57                            let t = text.as_deref().unwrap_or("");
58                            if t.is_empty() { return None; }
59                            return Some(json!({"role": "assistant", "content": t}));
60                        }
61                        let mut msg = json!({
62                            "role": "assistant",
63                            "content": text.as_deref().unwrap_or("")
64                        });
65                        msg["tool_calls"] = json!(tool_calls.iter().map(|tc| {
66                            json!({
67                                "function": {
68                                    "name": tc.name,
69                                    "arguments": serde_json::from_str::<serde_json::Value>(&tc.arguments)
70                                        .unwrap_or_else(|_| json!({"input": tc.arguments})),
71                                }
72                            })
73                        }).collect::<Vec<_>>());
74                        Some(msg)
75                    }
76                    MessageContent::ToolResult(r) => {
77                        Some(json!({
78                            "role": "tool",
79                            "content": r.output,
80                        }))
81                    }
82                    MessageContent::ToolResultRef(r) => {
83                        Some(json!({
84                            "role": "tool",
85                            "content": r.summary,
86                        }))
87                    }
88                    MessageContent::MultiPart { text, .. } => {
89                        let t = text.as_deref().unwrap_or("");
90                        if t.is_empty() { return None; }
91                        Some(json!({"role": "user", "content": t}))
92                    }
93                }
94            })
95            .collect()
96    }
97}
98
99/// A single tool call from Ollama response.
100#[derive(Deserialize, Debug)]
101struct OllamaToolCall {
102    function: OllamaFunction,
103}
104
105#[derive(Deserialize, Debug)]
106struct OllamaFunction {
107    name: String,
108    arguments: serde_json::Value,
109}
110
111#[derive(Deserialize)]
112struct OllamaChunk {
113    message: Option<OllamaMessage>,
114    done: bool,
115    #[serde(default)]
116    prompt_eval_count: usize,
117    #[serde(default)]
118    eval_count: usize,
119}
120
121#[derive(Deserialize)]
122struct OllamaMessage {
123    #[serde(default)]
124    content: String,
125    #[serde(default)]
126    tool_calls: Option<Vec<OllamaToolCall>>,
127}
128
129#[async_trait]
130impl LlmProvider for OllamaProvider {
131    fn chat_stream(
132        &self,
133        messages: &[Message],
134        tools: Option<&[ToolDef]>,
135    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
136        let url = format!("{}/api/chat", self.base_url);
137        let mut body = json!({
138            "model": self.model,
139            "messages": Self::format_messages(messages),
140            "stream": true,
141        });
142
143        // Pass tool definitions to Ollama (supported since v0.3+)
144        if let Some(tool_defs) = tools {
145            if !tool_defs.is_empty() {
146                body["tools"] = json!(tool_defs.iter().map(|td| json!({
147                    "type": "function",
148                    "function": {
149                        "name": td.name,
150                        "description": td.description,
151                        "parameters": td.parameters,
152                    }
153                })).collect::<Vec<_>>());
154            }
155        }
156
157        let request = self
158            .client
159            .post(&url)
160            .header("Content-Type", "application/json")
161            .json(&body);
162
163        let policy = crate::provider::retry::RetryPolicy::default_policy();
164
165        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
166
167        tokio::spawn(async move {
168            let response = match crate::provider::retry::send_with_retry(request, &policy).await {
169                Ok(resp) => resp,
170                Err(e) => {
171                    let _ = tx.send(Ok(StreamEvent::Error(format!("Connection failed: {}", e))));
172                    return;
173                }
174            };
175
176            if !response.status().is_success() {
177                let status = response.status();
178                let body = response.text().await.unwrap_or_default();
179                let msg = super::extract_error_message(&body);
180                let _ = tx.send(Ok(StreamEvent::Error(format!(
181                    "Ollama error ({}): {}",
182                    status, msg
183                ))));
184                return;
185            }
186
187            // Use byte buffer to properly handle UTF-8 characters that span chunk boundaries
188            let mut byte_buffer: Vec<u8> = Vec::with_capacity(4096);
189            let mut buffer = String::new();
190            let mut byte_stream = response.bytes_stream();
191            let mut tool_call_counter = 0u32;
192
193            while let Some(chunk) = byte_stream.next().await {
194                match chunk {
195                    Ok(bytes) => {
196                        byte_buffer.extend_from_slice(&bytes);
197                    }
198                    Err(e) => {
199                        let _ = tx.send(Ok(StreamEvent::Error(e.to_string())));
200                        return;
201                    }
202                }
203
204                // Convert bytes to string, keeping incomplete UTF-8 sequences for next chunk
205                let text = match String::from_utf8(byte_buffer.clone()) {
206                    Ok(s) => {
207                        byte_buffer.clear();
208                        s
209                    }
210                    Err(e) => {
211                        let valid_len = e.utf8_error().valid_up_to();
212                        if valid_len == 0 {
213                            continue;
214                        }
215                        let valid = String::from_utf8_lossy(&byte_buffer[..valid_len]).to_string();
216                        byte_buffer = byte_buffer[valid_len..].to_vec();
217                        valid
218                    }
219                };
220
221                buffer.push_str(&text);
222
223                while let Some(pos) = buffer.find('\n') {
224                    let line = buffer[..pos].trim().to_string();
225                    buffer = buffer[pos + 1..].to_string();
226
227                    if line.is_empty() {
228                        continue;
229                    }
230
231                    if let Ok(chunk) = serde_json::from_str::<OllamaChunk>(&line) {
232                        // Handle tool calls from message
233                        if let Some(ref msg) = chunk.message {
234                            if let Some(ref tcs) = msg.tool_calls {
235                                for tc in tcs {
236                                    tool_call_counter += 1;
237                                    let call_id = format!("call_{}", tool_call_counter);
238                                    let args = tc.function.arguments.to_string();
239
240                                    let _ = tx.send(Ok(StreamEvent::ToolCallStart {
241                                        id: call_id.clone(),
242                                        name: tc.function.name.clone(),
243                                    }));
244                                    let _ = tx.send(Ok(StreamEvent::ToolCallDelta(args.clone())));
245                                    let _ = tx.send(Ok(StreamEvent::ToolCallDone(
246                                        crate::tool::ToolCall {
247                                            id: call_id,
248                                            name: tc.function.name.clone(),
249                                            arguments: args,
250                                        }
251                                    )));
252                                }
253                            }
254                        }
255
256                        if chunk.done {
257                            if chunk.eval_count > 0 || chunk.prompt_eval_count > 0 {
258                                let _ =
259                                    tx.send(Ok(StreamEvent::Usage(crate::stream::TokenUsage {
260                                        prompt_tokens: chunk.prompt_eval_count,
261                                        completion_tokens: chunk.eval_count,
262                                        cached_tokens: 0,
263                                    })));
264                            }
265                            let _ = tx.send(Ok(StreamEvent::Done { truncated: false }));
266                            return;
267                        } else if let Some(msg) = chunk.message {
268                            // Only send text delta if no tool calls in this chunk
269                            if msg.tool_calls.is_none() && !msg.content.is_empty() {
270                                let _ = tx.send(Ok(StreamEvent::Delta(msg.content)));
271                            }
272                        }
273                    }
274                }
275            }
276
277            let _ = tx.send(Ok(StreamEvent::Done { truncated: false }));
278        });
279
280        Ok(Box::pin(
281            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
282        ))
283    }
284
285    fn model_name(&self) -> &str {
286        &self.model
287    }
288}