// synaptic_models/ollama.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError,
7    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Configuration for an Ollama chat model.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    // Model name sent as `"model"` in the request body.
    pub model: String,
    // Base URL of the Ollama server; `/api/chat` is appended when building requests.
    pub base_url: String,
    // Optional nucleus-sampling parameter, forwarded under `options.top_p`.
    pub top_p: Option<f64>,
    // Optional stop sequences, forwarded under `options.stop`.
    pub stop: Option<Vec<String>>,
    // Optional sampling seed, forwarded under `options.seed`.
    pub seed: Option<u64>,
}
20
21impl OllamaConfig {
22    pub fn new(model: impl Into<String>) -> Self {
23        Self {
24            model: model.into(),
25            base_url: "http://localhost:11434".to_string(),
26            top_p: None,
27            stop: None,
28            seed: None,
29        }
30    }
31
32    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
33        self.base_url = url.into();
34        self
35    }
36
37    pub fn with_top_p(mut self, top_p: f64) -> Self {
38        self.top_p = Some(top_p);
39        self
40    }
41
42    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
43        self.stop = Some(stop);
44        self
45    }
46
47    pub fn with_seed(mut self, seed: u64) -> Self {
48        self.seed = Some(seed);
49        self
50    }
51}
52
/// A chat model implementation backed by an Ollama server.
pub struct OllamaChatModel {
    // Model name, endpoint, and sampling options used to build each request.
    config: OllamaConfig,
    // Transport that actually sends the built requests.
    backend: Arc<dyn ProviderBackend>,
}
57
58impl OllamaChatModel {
59    pub fn new(config: OllamaConfig, backend: Arc<dyn ProviderBackend>) -> Self {
60        Self { config, backend }
61    }
62
63    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
64        let messages: Vec<Value> = request.messages.iter().map(message_to_ollama).collect();
65
66        let mut body = json!({
67            "model": self.config.model,
68            "messages": messages,
69            "stream": stream,
70        });
71
72        if !request.tools.is_empty() {
73            body["tools"] = json!(request
74                .tools
75                .iter()
76                .map(tool_def_to_ollama)
77                .collect::<Vec<_>>());
78        }
79        if let Some(ref choice) = request.tool_choice {
80            body["tool_choice"] = match choice {
81                ToolChoice::Auto => json!("auto"),
82                ToolChoice::Required => json!("required"),
83                ToolChoice::None => json!("none"),
84                ToolChoice::Specific(name) => json!({
85                    "type": "function",
86                    "function": {"name": name}
87                }),
88            };
89        }
90
91        {
92            let mut options = json!({});
93            let mut has_options = false;
94            if let Some(top_p) = self.config.top_p {
95                options["top_p"] = json!(top_p);
96                has_options = true;
97            }
98            if let Some(ref stop) = self.config.stop {
99                options["stop"] = json!(stop);
100                has_options = true;
101            }
102            if let Some(seed) = self.config.seed {
103                options["seed"] = json!(seed);
104                has_options = true;
105            }
106            if has_options {
107                body["options"] = options;
108            }
109        }
110
111        ProviderRequest {
112            url: format!("{}/api/chat", self.config.base_url),
113            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
114            body,
115        }
116    }
117}
118
/// Converts a synaptic [`Message`] into the JSON object shape Ollama's
/// `/api/chat` endpoint expects for a single chat message.
///
/// [`Message::Remove`] maps to JSON `null` as a "skip" sentinel — callers
/// assembling a request body are expected to drop null entries.
fn message_to_ollama(msg: &Message) -> Value {
    match msg {
        Message::System { content, .. } => json!({
            "role": "system",
            "content": content,
        }),
        Message::Human { content, .. } => json!({
            "role": "user",
            "content": content,
        }),
        Message::AI {
            content,
            tool_calls,
            ..
        } => {
            let mut obj = json!({
                "role": "assistant",
                "content": content,
            });
            // Attach tool calls only when present; arguments are forwarded
            // verbatim as JSON.
            if !tool_calls.is_empty() {
                obj["tool_calls"] = json!(tool_calls
                    .iter()
                    .map(|tc| json!({
                        "function": {
                            "name": tc.name,
                            "arguments": tc.arguments,
                        }
                    }))
                    .collect::<Vec<_>>());
            }
            obj
        }
        // NOTE(review): `tool_call_id` is deliberately dropped — presumably
        // Ollama has no field to carry it; confirm against the API version in use.
        Message::Tool {
            content,
            tool_call_id: _,
            ..
        } => json!({
            "role": "tool",
            "content": content,
        }),
        Message::Chat {
            custom_role,
            content,
            ..
        } => json!({
            "role": custom_role,
            "content": content,
        }),
        // Remove messages carry nothing for the provider; emit a null
        // sentinel for the caller to filter out.
        Message::Remove { .. } => json!(null),
    }
}
170
171fn tool_def_to_ollama(def: &ToolDefinition) -> Value {
172    json!({
173        "type": "function",
174        "function": {
175            "name": def.name,
176            "description": def.description,
177            "parameters": def.parameters,
178        }
179    })
180}
181
182fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapseError> {
183    check_error_status(resp)?;
184
185    let message_val = &resp.body["message"];
186    let content = message_val["content"].as_str().unwrap_or("").to_string();
187    let tool_calls = parse_tool_calls(message_val);
188
189    let usage = parse_usage(&resp.body);
190
191    let message = if tool_calls.is_empty() {
192        Message::ai(content)
193    } else {
194        Message::ai_with_tool_calls(content, tool_calls)
195    };
196
197    Ok(ChatResponse { message, usage })
198}
199
200fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapseError> {
201    if resp.status >= 400 {
202        let msg = resp.body["error"]
203            .as_str()
204            .unwrap_or("unknown Ollama error")
205            .to_string();
206        return Err(SynapseError::Model(format!(
207            "Ollama API error ({}): {}",
208            resp.status, msg
209        )));
210    }
211    Ok(())
212}
213
214fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
215    message["tool_calls"]
216        .as_array()
217        .map(|arr| {
218            arr.iter()
219                .enumerate()
220                .filter_map(|(i, tc)| {
221                    let name = tc["function"]["name"].as_str()?.to_string();
222                    let arguments = tc["function"]["arguments"].clone();
223                    Some(ToolCall {
224                        id: format!("ollama-{i}"),
225                        name,
226                        arguments,
227                    })
228                })
229                .collect()
230        })
231        .unwrap_or_default()
232}
233
234fn parse_usage(body: &Value) -> Option<TokenUsage> {
235    let prompt = body["prompt_eval_count"].as_u64();
236    let completion = body["eval_count"].as_u64();
237    match (prompt, completion) {
238        (Some(p), Some(c)) => Some(TokenUsage {
239            input_tokens: p as u32,
240            output_tokens: c as u32,
241            total_tokens: (p + c) as u32,
242            input_details: None,
243            output_details: None,
244        }),
245        _ => None,
246    }
247}
248
249fn parse_ndjson_chunk(line: &str) -> Option<AIMessageChunk> {
250    let v: Value = serde_json::from_str(line).ok()?;
251
252    // Ollama streaming: each line has {"message":{"role":"assistant","content":"..."}, "done":false}
253    let content = v["message"]["content"].as_str().unwrap_or("").to_string();
254    let tool_calls = parse_tool_calls(&v["message"]);
255    let done = v["done"].as_bool().unwrap_or(false);
256
257    let usage = if done { parse_usage(&v) } else { None };
258
259    Some(AIMessageChunk {
260        content,
261        tool_calls,
262        usage,
263        ..Default::default()
264    })
265}
266
267#[async_trait]
268impl ChatModel for OllamaChatModel {
269    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
270        let provider_req = self.build_request(&request, false);
271        let resp = self.backend.send(provider_req).await?;
272        parse_response(&resp)
273    }
274
275    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
276        Box::pin(async_stream::stream! {
277            let provider_req = self.build_request(&request, true);
278            let byte_stream = self.backend.send_stream(provider_req).await;
279
280            let byte_stream = match byte_stream {
281                Ok(s) => s,
282                Err(e) => {
283                    yield Err(e);
284                    return;
285                }
286            };
287
288            use futures::StreamExt;
289
290            // NDJSON: accumulate bytes and split on newlines
291            let mut buffer = String::new();
292            let mut byte_stream = std::pin::pin!(byte_stream);
293
294            while let Some(result) = byte_stream.next().await {
295                match result {
296                    Ok(bytes) => {
297                        buffer.push_str(&String::from_utf8_lossy(&bytes));
298                        while let Some(pos) = buffer.find('\n') {
299                            let line = buffer[..pos].trim().to_string();
300                            buffer = buffer[pos + 1..].to_string();
301                            if line.is_empty() {
302                                continue;
303                            }
304                            if let Some(chunk) = parse_ndjson_chunk(&line) {
305                                yield Ok(chunk);
306                            }
307                        }
308                    }
309                    Err(e) => {
310                        yield Err(e);
311                        break;
312                    }
313                }
314            }
315
316            // Process remaining buffer
317            let remaining = buffer.trim().to_string();
318            if !remaining.is_empty() {
319                if let Some(chunk) = parse_ndjson_chunk(&remaining) {
320                    yield Ok(chunk);
321                }
322            }
323        })
324    }
325}