sparrow/provider/
ollama.rs1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5
6use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
7
8pub struct OllamaAdapter {
11 model: String,
12 base_url: String,
13 client: Client,
14 caps: ModelCaps,
15}
16
17impl OllamaAdapter {
18 pub fn new(model: &str, base_url: &str) -> Self {
19 Self {
20 model: model.to_string(),
21 base_url: base_url
22 .trim_end_matches("/v1")
23 .trim_end_matches('/')
24 .to_string(),
25 client: Client::new(),
26 caps: ModelCaps {
27 context_window: 32_768,
28 max_output: 8_000,
29 tools: true,
30 vision: false,
31 cost_input_per_mtok: 0.0,
32 cost_output_per_mtok: 0.0,
33 latency: LatencyClass::Medium,
34 },
35 }
36 }
37
38 pub fn with_caps(mut self, caps: ModelCaps) -> Self {
39 self.caps = caps;
40 self
41 }
42
43 fn build_ollama_messages(req: &BrainRequest) -> Vec<serde_json::Value> {
45 let mut messages: Vec<serde_json::Value> = Vec::new();
46
47 if let Some(sys) = &req.system {
48 messages.push(json!({"role": "system", "content": sys}));
49 }
50
51 for msg in &req.messages {
52 let role = match msg.role.as_str() {
53 "assistant" => "assistant",
54 _ => "user",
55 };
56
57 let mut content = String::new();
58 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
59
60 for block in &msg.content {
61 match block {
62 ContentBlock::Text { text } => {
63 content.push_str(text);
64 }
65 ContentBlock::ToolUse { id: _, name, input } => {
66 tool_calls.push(json!({
67 "function": {
68 "name": name,
69 "arguments": input,
70 }
71 }));
72 }
73 ContentBlock::ToolResult {
74 tool_use_id,
75 content: blocks,
76 is_error: _,
77 } => {
78 let text: String = blocks
79 .iter()
80 .filter_map(|b| match b {
81 ContentBlock::Text { text } => Some(text.as_str()),
82 _ => None,
83 })
84 .collect::<Vec<_>>()
85 .join("\n");
86 messages.push(json!({
88 "role": "tool",
89 "content": text,
90 "tool_call_id": tool_use_id,
91 }));
92 }
93 _ => {}
94 }
95 }
96
97 if !content.is_empty() || tool_calls.is_empty() {
98 let mut msg_json = json!({"role": role, "content": content});
99 if !tool_calls.is_empty() {
100 msg_json["tool_calls"] = json!(tool_calls);
101 }
102 messages.push(msg_json);
103 }
104 }
105
106 messages
107 }
108
109 fn build_ollama_tools(tools: &[super::ToolSpec]) -> Vec<serde_json::Value> {
111 if tools.is_empty() {
112 return vec![];
113 }
114 tools
115 .iter()
116 .map(|t| {
117 json!({
118 "type": "function",
119 "function": {
120 "name": t.name,
121 "description": t.description,
122 "parameters": t.input_schema,
123 }
124 })
125 })
126 .collect()
127 }
128}
129
130#[async_trait]
131impl Brain for OllamaAdapter {
132 fn id(&self) -> &str {
133 &self.model
134 }
135
136 fn caps(&self) -> ModelCaps {
137 self.caps.clone()
138 }
139
140 async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
141 let messages = Self::build_ollama_messages(&req);
142 let tools = Self::build_ollama_tools(&req.tools);
143
144 let mut body = json!({
145 "model": self.model,
146 "messages": messages,
147 "stream": true,
148 "options": {
149 "temperature": req.temperature as f64,
150 }
151 });
152
153 if req.max_tokens > 0 {
154 body["options"]["num_predict"] = json!(req.max_tokens);
155 }
156 if !tools.is_empty() {
157 body["tools"] = json!(tools);
158 }
159
160 let url = format!("{}/api/chat", self.base_url);
161
162 let response = self.client.post(&url).json(&body).send().await?;
163
164 if !response.status().is_success() {
165 let status = response.status().as_u16();
166 let body = response.text().await.unwrap_or_default();
167 return Err(anyhow::anyhow!("Ollama API error {}: {}", status, body));
168 }
169
170 let stream = response.bytes_stream();
171
172 let event_stream = stream
176 .scan(super::sse_buffer::LineBuffer::new(), |line_buf, chunk| {
177 let events: Vec<BrainEvent> = match chunk {
178 Ok(bytes) => {
179 let lines = line_buf.push(&bytes);
180 let mut parsed = Vec::new();
181 for line in lines {
182 let line = line.trim();
183 if line.is_empty() {
184 continue;
185 }
186 let event: serde_json::Value = match serde_json::from_str(line) {
187 Ok(v) => v,
188 Err(_) => continue,
189 };
190
191 if let Some(msg) = event.get("message") {
193 if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
195 if !content.is_empty() {
196 parsed.push(BrainEvent::TextDelta(content.to_string()));
197 }
198 }
199 if let Some(tc_array) =
201 msg.get("tool_calls").and_then(|v| v.as_array())
202 {
203 for tc in tc_array {
204 if let Some(func) = tc.get("function") {
205 let name = func
206 .get("name")
207 .and_then(|v| v.as_str())
208 .unwrap_or("");
209 let args = func.get("arguments");
210 let id = format!("tc_{}", name);
212 parsed.push(BrainEvent::ToolUseStart {
213 id: id.clone(),
214 name: name.to_string(),
215 });
216 if let Some(args) = args {
217 parsed.push(BrainEvent::ToolUseDelta {
218 id: id.clone(),
219 json: args.to_string(),
220 });
221 }
222 parsed.push(BrainEvent::ToolUseEnd { id });
223 }
224 }
225 }
226 }
227
228 if let (Some(prompt), Some(completion)) = (
230 event.get("prompt_eval_count").and_then(|v| v.as_u64()),
231 event.get("eval_count").and_then(|v| v.as_u64()),
232 ) {
233 parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
234 input: prompt,
235 output: completion,
236 }));
237 }
238
239 if event.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
241 let reason = event
242 .get("done_reason")
243 .and_then(|v| v.as_str())
244 .unwrap_or("stop");
245 let stop = match reason {
246 "stop" => crate::event::StopReason::EndTurn,
247 "length" => crate::event::StopReason::MaxTokens,
248 "tool_calls" => crate::event::StopReason::ToolUse,
249 s => crate::event::StopReason::StopSequence(s.to_string()),
250 };
251 parsed.push(BrainEvent::Done(stop));
252 }
253 }
254 parsed
255 }
256 Err(e) => vec![BrainEvent::Error(format!("Ollama stream error: {}", e))],
257 };
258 async move { Some(stream::iter(events)) }
259 })
260 .flatten();
261
262 Ok(Box::pin(event_stream))
263 }
264}