atomcode_core/provider/
ollama.rs1use std::pin::Pin;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use futures::Stream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10
11use crate::config::provider::ProviderConfig;
12use crate::conversation::message::{Message, MessageContent, Role};
13use crate::stream::StreamEvent;
14use crate::tool::ToolDef;
15
16use super::LlmProvider;
17
18pub struct OllamaProvider {
19 client: Client,
20 model: String,
21 base_url: String,
22}
23
24impl OllamaProvider {
25 pub fn new(config: &ProviderConfig) -> Result<Self> {
26 Ok(Self {
27 client: super::build_http_client(config.user_agent.as_deref(), config.skip_tls_verify),
28 model: config.model.clone(),
29 base_url: config
30 .base_url
31 .clone()
32 .unwrap_or_else(|| "http://localhost:11434".to_string()),
33 })
34 }
35
36 fn format_messages(messages: &[Message]) -> Vec<serde_json::Value> {
37 messages
38 .iter()
39 .filter_map(|m| {
40 match &m.content {
41 MessageContent::Text(s) => {
42 let role = match m.role {
45 Role::System => "system",
46 Role::User => "user",
47 Role::Assistant => "assistant",
48 Role::Tool => return None,
49 };
50 if s.trim().is_empty() {
51 return None;
52 }
53 Some(json!({"role": role, "content": s}))
54 }
55 MessageContent::AssistantWithToolCalls { text, tool_calls, .. } => {
56 if tool_calls.is_empty() {
57 let t = text.as_deref().unwrap_or("");
58 if t.is_empty() { return None; }
59 return Some(json!({"role": "assistant", "content": t}));
60 }
61 let mut msg = json!({
62 "role": "assistant",
63 "content": text.as_deref().unwrap_or("")
64 });
65 msg["tool_calls"] = json!(tool_calls.iter().map(|tc| {
66 json!({
67 "function": {
68 "name": tc.name,
69 "arguments": serde_json::from_str::<serde_json::Value>(&tc.arguments)
70 .unwrap_or_else(|_| json!({"input": tc.arguments})),
71 }
72 })
73 }).collect::<Vec<_>>());
74 Some(msg)
75 }
76 MessageContent::ToolResult(r) => {
77 Some(json!({
78 "role": "tool",
79 "content": r.output,
80 }))
81 }
82 MessageContent::ToolResultRef(r) => {
83 Some(json!({
84 "role": "tool",
85 "content": r.summary,
86 }))
87 }
88 MessageContent::MultiPart { text, .. } => {
89 let t = text.as_deref().unwrap_or("");
90 if t.is_empty() { return None; }
91 Some(json!({"role": "user", "content": t}))
92 }
93 }
94 })
95 .collect()
96 }
97}
98
99#[derive(Deserialize, Debug)]
101struct OllamaToolCall {
102 function: OllamaFunction,
103}
104
105#[derive(Deserialize, Debug)]
106struct OllamaFunction {
107 name: String,
108 arguments: serde_json::Value,
109}
110
111#[derive(Deserialize)]
112struct OllamaChunk {
113 message: Option<OllamaMessage>,
114 done: bool,
115 #[serde(default)]
116 prompt_eval_count: usize,
117 #[serde(default)]
118 eval_count: usize,
119}
120
121#[derive(Deserialize)]
122struct OllamaMessage {
123 #[serde(default)]
124 content: String,
125 #[serde(default)]
126 tool_calls: Option<Vec<OllamaToolCall>>,
127}
128
129#[async_trait]
130impl LlmProvider for OllamaProvider {
131 fn chat_stream(
132 &self,
133 messages: &[Message],
134 tools: Option<&[ToolDef]>,
135 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
136 let url = format!("{}/api/chat", self.base_url);
137 let mut body = json!({
138 "model": self.model,
139 "messages": Self::format_messages(messages),
140 "stream": true,
141 });
142
143 if let Some(tool_defs) = tools {
145 if !tool_defs.is_empty() {
146 body["tools"] = json!(tool_defs.iter().map(|td| json!({
147 "type": "function",
148 "function": {
149 "name": td.name,
150 "description": td.description,
151 "parameters": td.parameters,
152 }
153 })).collect::<Vec<_>>());
154 }
155 }
156
157 let request = self
158 .client
159 .post(&url)
160 .header("Content-Type", "application/json")
161 .json(&body);
162
163 let policy = crate::provider::retry::RetryPolicy::default_policy();
164
165 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
166
167 tokio::spawn(async move {
168 let response = match crate::provider::retry::send_with_retry(request, &policy).await {
169 Ok(resp) => resp,
170 Err(e) => {
171 let _ = tx.send(Ok(StreamEvent::Error(format!("Connection failed: {}", e))));
172 return;
173 }
174 };
175
176 if !response.status().is_success() {
177 let status = response.status();
178 let body = response.text().await.unwrap_or_default();
179 let msg = super::extract_error_message(&body);
180 let _ = tx.send(Ok(StreamEvent::Error(format!(
181 "Ollama error ({}): {}",
182 status, msg
183 ))));
184 return;
185 }
186
187 let mut byte_buffer: Vec<u8> = Vec::with_capacity(4096);
189 let mut buffer = String::new();
190 let mut byte_stream = response.bytes_stream();
191 let mut tool_call_counter = 0u32;
192
193 while let Some(chunk) = byte_stream.next().await {
194 match chunk {
195 Ok(bytes) => {
196 byte_buffer.extend_from_slice(&bytes);
197 }
198 Err(e) => {
199 let _ = tx.send(Ok(StreamEvent::Error(e.to_string())));
200 return;
201 }
202 }
203
204 let text = match String::from_utf8(byte_buffer.clone()) {
206 Ok(s) => {
207 byte_buffer.clear();
208 s
209 }
210 Err(e) => {
211 let valid_len = e.utf8_error().valid_up_to();
212 if valid_len == 0 {
213 continue;
214 }
215 let valid = String::from_utf8_lossy(&byte_buffer[..valid_len]).to_string();
216 byte_buffer = byte_buffer[valid_len..].to_vec();
217 valid
218 }
219 };
220
221 buffer.push_str(&text);
222
223 while let Some(pos) = buffer.find('\n') {
224 let line = buffer[..pos].trim().to_string();
225 buffer = buffer[pos + 1..].to_string();
226
227 if line.is_empty() {
228 continue;
229 }
230
231 if let Ok(chunk) = serde_json::from_str::<OllamaChunk>(&line) {
232 if let Some(ref msg) = chunk.message {
234 if let Some(ref tcs) = msg.tool_calls {
235 for tc in tcs {
236 tool_call_counter += 1;
237 let call_id = format!("call_{}", tool_call_counter);
238 let args = tc.function.arguments.to_string();
239
240 let _ = tx.send(Ok(StreamEvent::ToolCallStart {
241 id: call_id.clone(),
242 name: tc.function.name.clone(),
243 }));
244 let _ = tx.send(Ok(StreamEvent::ToolCallDelta(args.clone())));
245 let _ = tx.send(Ok(StreamEvent::ToolCallDone(
246 crate::tool::ToolCall {
247 id: call_id,
248 name: tc.function.name.clone(),
249 arguments: args,
250 }
251 )));
252 }
253 }
254 }
255
256 if chunk.done {
257 if chunk.eval_count > 0 || chunk.prompt_eval_count > 0 {
258 let _ =
259 tx.send(Ok(StreamEvent::Usage(crate::stream::TokenUsage {
260 prompt_tokens: chunk.prompt_eval_count,
261 completion_tokens: chunk.eval_count,
262 cached_tokens: 0,
263 })));
264 }
265 let _ = tx.send(Ok(StreamEvent::Done { truncated: false }));
266 return;
267 } else if let Some(msg) = chunk.message {
268 if msg.tool_calls.is_none() && !msg.content.is_empty() {
270 let _ = tx.send(Ok(StreamEvent::Delta(msg.content)));
271 }
272 }
273 }
274 }
275 }
276
277 let _ = tx.send(Ok(StreamEvent::Done { truncated: false }));
278 });
279
280 Ok(Box::pin(
281 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
282 ))
283 }
284
285 fn model_name(&self) -> &str {
286 &self.model
287 }
288}