sparrow/provider/
ollama.rs1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5
6use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
7
8pub struct OllamaAdapter {
11 model: String,
12 base_url: String,
13 client: Client,
14 caps: ModelCaps,
15}
16
17impl OllamaAdapter {
18 pub fn new(model: &str, base_url: &str) -> Self {
19 Self {
20 model: model.to_string(),
21 base_url: base_url
22 .trim_end_matches("/v1")
23 .trim_end_matches('/')
24 .to_string(),
25 client: Client::new(),
26 caps: ModelCaps {
27 context_window: 32_768,
28 max_output: 8_000,
29 tools: true,
30 vision: false,
31 cost_input_per_mtok: 0.0,
32 cost_output_per_mtok: 0.0,
33 latency: LatencyClass::Medium,
34 },
35 }
36 }
37
38 pub fn with_caps(mut self, caps: ModelCaps) -> Self {
39 self.caps = caps;
40 self
41 }
42
43 async fn live_caps(&self) -> ModelCaps {
44 let mut caps = self.caps.clone();
45 let url = format!("{}/api/show", self.base_url);
46 let Ok(response) = self
47 .client
48 .post(&url)
49 .json(&json!({ "model": self.model }))
50 .send()
51 .await
52 else {
53 return caps;
54 };
55 if !response.status().is_success() {
56 return caps;
57 }
58 let Ok(payload) = response.json::<serde_json::Value>().await else {
59 return caps;
60 };
61
62 if let Some(capabilities) = payload.get("capabilities").and_then(|v| v.as_array()) {
63 caps.tools = capabilities.iter().any(|cap| cap.as_str() == Some("tools"));
64 }
65 if let Some(ctx) = find_context_window(&payload) {
66 caps.context_window = ctx;
67 caps.max_output = (ctx / 8).clamp(4_096, 32_000);
68 }
69 caps
70 }
71
72 fn build_ollama_messages(req: &BrainRequest) -> Vec<serde_json::Value> {
74 let mut messages: Vec<serde_json::Value> = Vec::new();
75
76 if let Some(sys) = &req.system {
77 messages.push(json!({"role": "system", "content": sys}));
78 }
79
80 for msg in &req.messages {
81 let role = match msg.role.as_str() {
82 "assistant" => "assistant",
83 _ => "user",
84 };
85
86 let mut content = String::new();
87 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
88
89 for block in &msg.content {
90 match block {
91 ContentBlock::Text { text } => {
92 content.push_str(text);
93 }
94 ContentBlock::ToolUse { id: _, name, input } => {
95 tool_calls.push(json!({
96 "function": {
97 "name": name,
98 "arguments": input,
99 }
100 }));
101 }
102 ContentBlock::ToolResult {
103 tool_use_id,
104 content: blocks,
105 is_error: _,
106 } => {
107 let text: String = blocks
108 .iter()
109 .filter_map(|b| match b {
110 ContentBlock::Text { text } => Some(text.as_str()),
111 _ => None,
112 })
113 .collect::<Vec<_>>()
114 .join("\n");
115 messages.push(json!({
117 "role": "tool",
118 "content": text,
119 "tool_call_id": tool_use_id,
120 }));
121 }
122 _ => {}
123 }
124 }
125
126 if !content.is_empty() || tool_calls.is_empty() {
127 let mut msg_json = json!({"role": role, "content": content});
128 if !tool_calls.is_empty() {
129 msg_json["tool_calls"] = json!(tool_calls);
130 }
131 messages.push(msg_json);
132 }
133 }
134
135 messages
136 }
137
138 fn build_ollama_tools(tools: &[super::ToolSpec]) -> Vec<serde_json::Value> {
140 if tools.is_empty() {
141 return vec![];
142 }
143 tools
144 .iter()
145 .map(|t| {
146 json!({
147 "type": "function",
148 "function": {
149 "name": t.name,
150 "description": t.description,
151 "parameters": t.input_schema,
152 }
153 })
154 })
155 .collect()
156 }
157}
158
159#[async_trait]
160impl Brain for OllamaAdapter {
161 fn id(&self) -> &str {
162 &self.model
163 }
164
165 fn caps(&self) -> ModelCaps {
166 self.caps.clone()
167 }
168
169 async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
170 let caps = self.live_caps().await;
171 let messages = Self::build_ollama_messages(&req);
172 let tools = if caps.tools {
173 Self::build_ollama_tools(&req.tools)
174 } else {
175 Vec::new()
176 };
177
178 let mut body = json!({
179 "model": self.model,
180 "messages": messages,
181 "stream": true,
182 "options": {
183 "temperature": req.temperature as f64,
184 }
185 });
186
187 if req.max_tokens > 0 {
188 body["options"]["num_predict"] = json!(req.max_tokens);
189 }
190 if caps.context_window > 0 {
191 body["options"]["num_ctx"] = json!(caps.context_window);
192 }
193 if !tools.is_empty() {
194 body["tools"] = json!(tools);
195 }
196
197 let url = format!("{}/api/chat", self.base_url);
198
199 let response = self.client.post(&url).json(&body).send().await?;
200
201 if !response.status().is_success() {
202 let status = response.status().as_u16();
203 let body = response.text().await.unwrap_or_default();
204 return Err(anyhow::anyhow!("Ollama API error {}: {}", status, body));
205 }
206
207 let stream = response.bytes_stream();
208
209 let event_stream = stream
213 .scan(super::sse_buffer::LineBuffer::new(), |line_buf, chunk| {
214 let events: Vec<BrainEvent> = match chunk {
215 Ok(bytes) => {
216 let lines = line_buf.push(&bytes);
217 let mut parsed = Vec::new();
218 for line in lines {
219 let line = line.trim();
220 if line.is_empty() {
221 continue;
222 }
223 let event: serde_json::Value = match serde_json::from_str(line) {
224 Ok(v) => v,
225 Err(_) => continue,
226 };
227
228 if let Some(msg) = event.get("message") {
230 if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
232 if !content.is_empty() {
233 parsed.push(BrainEvent::TextDelta(content.to_string()));
234 }
235 }
236 if let Some(tc_array) =
238 msg.get("tool_calls").and_then(|v| v.as_array())
239 {
240 for tc in tc_array {
241 if let Some(func) = tc.get("function") {
242 let name = func
243 .get("name")
244 .and_then(|v| v.as_str())
245 .unwrap_or("");
246 let args = func.get("arguments");
247 let id = format!("tc_{}", name);
249 parsed.push(BrainEvent::ToolUseStart {
250 id: id.clone(),
251 name: name.to_string(),
252 });
253 if let Some(args) = args {
254 parsed.push(BrainEvent::ToolUseDelta {
255 id: id.clone(),
256 json: args.to_string(),
257 });
258 }
259 parsed.push(BrainEvent::ToolUseEnd { id });
260 }
261 }
262 }
263 }
264
265 if let (Some(prompt), Some(completion)) = (
267 event.get("prompt_eval_count").and_then(|v| v.as_u64()),
268 event.get("eval_count").and_then(|v| v.as_u64()),
269 ) {
270 parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
271 input: prompt,
272 output: completion,
273 }));
274 }
275
276 if event.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
278 let reason = event
279 .get("done_reason")
280 .and_then(|v| v.as_str())
281 .unwrap_or("stop");
282 let stop = match reason {
283 "stop" => crate::event::StopReason::EndTurn,
284 "length" => crate::event::StopReason::MaxTokens,
285 "tool_calls" => crate::event::StopReason::ToolUse,
286 s => crate::event::StopReason::StopSequence(s.to_string()),
287 };
288 parsed.push(BrainEvent::Done(stop));
289 }
290 }
291 parsed
292 }
293 Err(e) => vec![BrainEvent::Error(format!("Ollama stream error: {}", e))],
294 };
295 async move { Some(stream::iter(events)) }
296 })
297 .flatten();
298
299 Ok(Box::pin(event_stream))
300 }
301}
302
303fn find_context_window(value: &serde_json::Value) -> Option<u64> {
304 fn visit(value: &serde_json::Value, best: &mut Option<u64>) {
305 match value {
306 serde_json::Value::Object(map) => {
307 for (key, child) in map {
308 let key = key.to_ascii_lowercase();
309 if (key.ends_with("context_length")
310 || key == "num_ctx"
311 || key == "context_window")
312 && child.as_u64().is_some()
313 {
314 *best = (*best).max(child.as_u64());
315 }
316 visit(child, best);
317 }
318 }
319 serde_json::Value::Array(items) => {
320 for item in items {
321 visit(item, best);
322 }
323 }
324 _ => {}
325 }
326 }
327 let mut best = None;
328 visit(value, &mut best);
329 best
330}