llm/backends/
ollama.rs

1//! Ollama API client implementation for chat and completion functionality.
2//!
3//! This module provides integration with Ollama's local LLM server through its API.
4
5use std::pin::Pin;
6
7use crate::{
8    chat::{
9        ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
10        Tool,
11    },
12    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
13    embedding::EmbeddingProvider,
14    error::LLMError,
15    models::ModelsProvider,
16    stt::SpeechToTextProvider,
17    tts::TextToSpeechProvider,
18    FunctionCall, ToolCall,
19};
20use async_trait::async_trait;
21use base64::{self, Engine};
22use futures::Stream;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26
/// Client for interacting with Ollama's API.
///
/// Bundles the server location, the generation settings and the HTTP client that the
/// chat, completion and embedding implementations in this module rely on.
pub struct Ollama {
    /// Base URL of the Ollama server; endpoint paths such as `/api/chat` are appended to it.
    pub base_url: String,
    /// Optional API key supplied at construction time.
    // NOTE(review): no request built in the visible code attaches this key — confirm intent.
    pub api_key: Option<String>,
    /// Model name placed in every request body built by this module.
    pub model: String,
    /// Maximum number of tokens to generate.
    // NOTE(review): not forwarded by the request builders visible in this file.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    // NOTE(review): not forwarded by the request builders visible in this file.
    pub temperature: Option<f32>,
    /// System prompt; when set it is sent as the first chat message with role `system`.
    pub system: Option<String>,
    /// Request timeout in seconds, applied to the HTTP client and to each chat request.
    pub timeout_seconds: Option<u64>,
    /// `top_p` sampling option forwarded in the chat request options.
    pub top_p: Option<f32>,
    /// `top_k` sampling option forwarded in the chat request options.
    pub top_k: Option<u32>,
    /// JSON schema for structured output
    pub json_schema: Option<StructuredOutputFormat>,
    /// Available tools for function calling
    pub tools: Option<Vec<Tool>>,
    /// Underlying HTTP client used for every request.
    client: Client,
}
46
/// Request payload for Ollama's chat API endpoint.
///
/// Borrows message text for the lifetime `'a` instead of copying it.
#[derive(Serialize)]
struct OllamaChatRequest<'a> {
    /// Name of the model that should answer.
    model: String,
    /// Conversation history, including the optional leading system message.
    messages: Vec<OllamaChatMessage<'a>>,
    /// Whether the server should stream the answer.
    stream: bool,
    /// Sampling options.
    options: Option<OllamaOptions>,
    /// Requested response format (structured output schema), if any.
    format: Option<OllamaResponseFormat>,
    /// Tool definitions offered to the model; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
}
58
/// Sampling options sent under the `options` key of a chat request.
#[derive(Serialize)]
struct OllamaOptions {
    /// `top_p` sampling value, copied from the client configuration.
    top_p: Option<f32>,
    /// `top_k` sampling value, copied from the client configuration.
    top_k: Option<u32>,
}
64
/// Individual message in an Ollama chat conversation.
#[derive(Serialize)]
struct OllamaChatMessage<'a> {
    /// Role of the author: `"system"`, `"user"` or `"assistant"` in the visible code.
    role: &'a str,
    /// Borrowed text content of the message.
    content: &'a str,
    /// Base64-encoded image payloads; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    images: Option<Vec<String>>,
}
73
74impl<'a> From<&'a ChatMessage> for OllamaChatMessage<'a> {
75    fn from(msg: &'a ChatMessage) -> Self {
76        Self {
77            role: match msg.role {
78                ChatRole::User => "user",
79                ChatRole::Assistant => "assistant",
80            },
81            content: &msg.content,
82            images: match &msg.message_type {
83                MessageType::Image((_mime, data)) => {
84                    Some(vec![base64::engine::general_purpose::STANDARD.encode(data)])
85                }
86                _ => None,
87            },
88        }
89    }
90}
91
/// Response from Ollama's API endpoints.
///
/// The same type is deserialized from both the chat and the generate endpoints in this
/// module, so every field is optional.
#[derive(Deserialize, Debug)]
struct OllamaResponse {
    /// Top-level `content` field, if the server sent one.
    content: Option<String>,
    /// Top-level `response` field; this is what `complete` reads first.
    response: Option<String>,
    /// Chat message object, which may carry text and tool calls.
    message: Option<OllamaChatResponseMessage>,
}
99
100impl std::fmt::Display for OllamaResponse {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        let empty = String::new();
103        let text = self
104            .content
105            .as_ref()
106            .or(self.response.as_ref())
107            .or(self.message.as_ref().map(|m| &m.content))
108            .unwrap_or(&empty);
109
110        // Write tool calls if present
111        if let Some(message) = &self.message {
112            if let Some(tool_calls) = &message.tool_calls {
113                for tc in tool_calls {
114                    writeln!(
115                        f,
116                        "{{\"name\": \"{}\", \"arguments\": {}}}",
117                        tc.function.name,
118                        serde_json::to_string_pretty(&tc.function.arguments).unwrap_or_default()
119                    )?;
120                }
121            }
122        }
123
124        write!(f, "{}", text)
125    }
126}
127
128impl ChatResponse for OllamaResponse {
129    fn text(&self) -> Option<String> {
130        self.content
131            .as_ref()
132            .or(self.response.as_ref())
133            .or(self.message.as_ref().map(|m| &m.content))
134            .map(|s| s.to_string())
135    }
136
137    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
138        self.message.as_ref().and_then(|msg| {
139            msg.tool_calls.as_ref().map(|tcs| {
140                tcs.iter()
141                    .map(|tc| ToolCall {
142                        id: format!("call_{}", tc.function.name),
143                        call_type: "function".to_string(),
144                        function: FunctionCall {
145                            name: tc.function.name.clone(),
146                            arguments: serde_json::to_string(&tc.function.arguments)
147                                .unwrap_or_default(),
148                        },
149                    })
150                    .collect()
151            })
152        })
153    }
154}
155
/// Message content within an Ollama chat API response.
#[derive(Deserialize, Debug)]
struct OllamaChatResponseMessage {
    /// Text produced by the model.
    content: String,
    /// Tool invocations requested by the model, when present.
    tool_calls: Option<Vec<OllamaToolCall>>,
}
162
/// One JSON object of a streamed chat response, as parsed by `parse_ollama_sse`.
#[derive(Deserialize, Debug)]
struct OllamaChatStreamResponse {
    /// Message fragment carried by this object.
    message: OllamaChatStreamMessage,
}
167
/// Message fragment inside a streamed chat response object.
#[derive(Deserialize, Debug)]
struct OllamaChatStreamMessage {
    /// Text of this fragment; the stream parser concatenates these.
    content: String,
}
172
/// Request payload for Ollama's generate API endpoint.
#[derive(Serialize)]
struct OllamaGenerateRequest<'a> {
    /// Name of the model that should answer.
    model: String,
    /// Prompt text, borrowed from the completion request.
    prompt: &'a str,
    /// Ollama's `raw` flag; `complete` always sends `true`.
    raw: bool,
    /// Whether the server should stream the answer; `complete` always sends `false`.
    stream: bool,
}
181
/// Request payload for Ollama's embed API endpoint.
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    /// Name of the model used to compute the embeddings.
    model: String,
    /// Texts to embed.
    input: Vec<String>,
}
187
/// Response body of Ollama's embed API endpoint.
#[derive(Deserialize, Debug)]
struct OllamaEmbeddingResponse {
    /// Embedding vectors returned by the server.
    embeddings: Vec<Vec<f32>>,
}
192
/// Kind of response format requested from Ollama.
///
/// Only `StructuredOutput` is constructed by the code visible in this file.
// NOTE(review): with `#[serde(untagged)]` a unit variant is (de)serialized as unit/null,
// so the `rename = "json"` below likely has no effect — confirm before relying on `Json`.
#[derive(Deserialize, Debug, Serialize)]
#[serde(untagged)]
enum OllamaResponseType {
    /// Plain JSON mode.
    #[serde(rename = "json")]
    Json,
    /// A JSON schema value describing the expected structured output.
    StructuredOutput(Value),
}
200
/// Wrapper for the `format` value of a chat request.
#[derive(Deserialize, Debug, Serialize)]
struct OllamaResponseFormat {
    /// Flattened so the wrapped value's entries are serialized directly in place of this
    /// struct, without an extra nested `format` key.
    #[serde(flatten)]
    format: OllamaResponseType,
}
206
/// Ollama's tool format
#[derive(Serialize, Debug)]
struct OllamaTool {
    /// Tool kind, serialized as `type`; the conversion in this file always sets `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,

    /// Description of the callable function.
    pub function: OllamaFunctionTool,
}
215
/// Function description nested inside an [`OllamaTool`].
#[derive(Serialize, Debug)]
struct OllamaFunctionTool {
    /// Name of the tool
    name: String,
    /// Description of what the tool does
    description: String,
    /// Parameters for the tool
    parameters: OllamaParameters,
}
225
226impl From<&crate::chat::Tool> for OllamaTool {
227    fn from(tool: &crate::chat::Tool) -> Self {
228        let properties_value = tool
229            .function
230            .parameters
231            .get("properties")
232            .cloned()
233            .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
234
235        let required_fields = tool
236            .function
237            .parameters
238            .get("required")
239            .and_then(|v| v.as_array())
240            .map(|arr| {
241                arr.iter()
242                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
243                    .collect::<Vec<String>>()
244            })
245            .unwrap_or_default();
246
247        OllamaTool {
248            tool_type: "function".to_owned(),
249            function: OllamaFunctionTool {
250                name: tool.function.name.clone(),
251                description: tool.function.description.clone(),
252                parameters: OllamaParameters {
253                    schema_type: "object".to_string(),
254                    properties: properties_value,
255                    required: required_fields,
256                },
257            },
258        }
259    }
260}
261
/// Ollama's parameters schema
#[derive(Serialize, Debug)]
struct OllamaParameters {
    /// The type of parameters object (usually "object"); serialized as `type`.
    #[serde(rename = "type")]
    schema_type: String,
    /// Map of parameter names to their properties
    properties: Value,
    /// List of required parameter names
    required: Vec<String>,
}
273
/// Ollama's tool call response
#[derive(Deserialize, Debug)]
struct OllamaToolCall {
    /// The function the model asked to invoke.
    function: OllamaFunctionCall,
}
279
/// Function name and arguments of a single tool call.
#[derive(Deserialize, Debug)]
struct OllamaFunctionCall {
    /// Name of the tool that was called
    name: String,
    /// Arguments provided to the tool, kept as an arbitrary JSON value
    arguments: Value,
}
287
288impl Ollama {
289    /// Creates a new Ollama client with the specified configuration.
290    ///
291    /// # Arguments
292    ///
293    /// * `base_url` - Base URL of the Ollama server
294    /// * `api_key` - Optional API key for authentication
295    /// * `model` - Model name to use (defaults to "llama3.1")
296    /// * `max_tokens` - Maximum tokens to generate
297    /// * `temperature` - Sampling temperature
298    /// * `timeout_seconds` - Request timeout in seconds
299    /// * `system` - System prompt
300    /// * `stream` - Whether to stream responses
301    /// * `json_schema` - JSON schema for structured output
302    /// * `tools` - Function tools that the model can use
303    #[allow(clippy::too_many_arguments)]
304    pub fn new(
305        base_url: impl Into<String>,
306        api_key: Option<String>,
307        model: Option<String>,
308        max_tokens: Option<u32>,
309        temperature: Option<f32>,
310        timeout_seconds: Option<u64>,
311        system: Option<String>,
312        stream: Option<bool>,
313        top_p: Option<f32>,
314        top_k: Option<u32>,
315        json_schema: Option<StructuredOutputFormat>,
316        tools: Option<Vec<Tool>>,
317    ) -> Self {
318        let mut builder = Client::builder();
319        if let Some(sec) = timeout_seconds {
320            builder = builder.timeout(std::time::Duration::from_secs(sec));
321        }
322        Self {
323            base_url: base_url.into(),
324            api_key,
325            model: model.unwrap_or("llama3.1".to_string()),
326            temperature,
327            max_tokens,
328            timeout_seconds,
329            system,
330            top_p,
331            top_k,
332            json_schema,
333            tools,
334            client: builder.build().expect("Failed to build reqwest Client"),
335        }
336    }
337
338    fn make_chat_request<'a>(
339        &'a self,
340        messages: &'a [ChatMessage],
341        tools: Option<&'a [Tool]>,
342        stream: bool,
343    ) -> OllamaChatRequest<'a> {
344        let mut chat_messages: Vec<OllamaChatMessage> =
345            messages.iter().map(OllamaChatMessage::from).collect();
346
347        if let Some(system) = &self.system {
348            chat_messages.insert(
349                0,
350                OllamaChatMessage {
351                    role: "system",
352                    content: system,
353                    images: None,
354                },
355            );
356        }
357
358        // Convert tools to Ollama format if provided
359        let ollama_tools = tools.map(|t| t.iter().map(OllamaTool::from).collect());
360
361        // Ollama doesn't require the "name" field in the schema, so we just use the schema itself
362        let format = if let Some(schema) = &self.json_schema {
363            schema.schema.as_ref().map(|schema| OllamaResponseFormat {
364                format: OllamaResponseType::StructuredOutput(schema.clone()),
365            })
366        } else {
367            None
368        };
369
370        OllamaChatRequest {
371            model: self.model.clone(),
372            messages: chat_messages,
373            stream: stream,
374            options: Some(OllamaOptions {
375                top_p: self.top_p,
376                top_k: self.top_k,
377            }),
378            format,
379            tools: ollama_tools,
380        }
381    }
382}
383
384#[async_trait]
385impl ChatProvider for Ollama {
386    async fn chat_with_tools(
387        &self,
388        messages: &[ChatMessage],
389        tools: Option<&[Tool]>,
390    ) -> Result<Box<dyn ChatResponse>, LLMError> {
391        if self.base_url.is_empty() {
392            return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
393        }
394
395        let req_body = self.make_chat_request(messages, tools, false);
396
397        if log::log_enabled!(log::Level::Trace) {
398            if let Ok(json) = serde_json::to_string(&req_body) {
399                log::trace!("Ollama request payload (tools): {}", json);
400            }
401        }
402
403        let url = format!("{}/api/chat", self.base_url);
404
405        let mut request = self.client.post(&url).json(&req_body);
406
407        if let Some(timeout) = self.timeout_seconds {
408            request = request.timeout(std::time::Duration::from_secs(timeout));
409        }
410
411        let resp = request.send().await?;
412
413        log::debug!("Ollama HTTP status (tools): {}", resp.status());
414
415        let resp = resp.error_for_status()?;
416        let json_resp = resp.json::<OllamaResponse>().await?;
417
418        Ok(Box::new(json_resp))
419    }
420
421    async fn chat_stream(
422        &self,
423        messages: &[ChatMessage],
424    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
425        let req_body = self.make_chat_request(messages, None, true);
426
427        let url = format!("{}/api/chat", self.base_url);
428        let mut request = self.client.post(&url).json(&req_body);
429
430        if let Some(timeout) = self.timeout_seconds {
431            request = request.timeout(std::time::Duration::from_secs(timeout));
432        }
433
434        let resp = request.send().await?;
435        log::debug!("Ollama HTTP status: {}", resp.status());
436
437        let resp = resp.error_for_status()?;
438
439        Ok(crate::chat::create_sse_stream(resp, parse_ollama_sse))
440    }
441}
442
443#[async_trait]
444impl CompletionProvider for Ollama {
445    /// Sends a completion request to Ollama's API.
446    ///
447    /// # Arguments
448    ///
449    /// * `req` - The completion request containing the prompt
450    ///
451    /// # Returns
452    ///
453    /// The completion response containing the generated text or an error
454    async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
455        if self.base_url.is_empty() {
456            return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
457        }
458        let url = format!("{}/api/generate", self.base_url);
459
460        let req_body = OllamaGenerateRequest {
461            model: self.model.clone(),
462            prompt: &req.prompt,
463            raw: true,
464            stream: false,
465        };
466
467        let resp = self
468            .client
469            .post(&url)
470            .json(&req_body)
471            .send()
472            .await?
473            .error_for_status()?;
474        let json_resp: OllamaResponse = resp.json().await?;
475
476        if let Some(answer) = json_resp.response.or(json_resp.content) {
477            Ok(CompletionResponse { text: answer })
478        } else {
479            Err(LLMError::ProviderError(
480                "No answer returned by Ollama".to_string(),
481            ))
482        }
483    }
484}
485
486#[async_trait]
487impl EmbeddingProvider for Ollama {
488    async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
489        if self.base_url.is_empty() {
490            return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
491        }
492        let url = format!("{}/api/embed", self.base_url);
493
494        let body = OllamaEmbeddingRequest {
495            model: self.model.clone(),
496            input: text,
497        };
498
499        let resp = self
500            .client
501            .post(&url)
502            .json(&body)
503            .send()
504            .await?
505            .error_for_status()?;
506
507        let json_resp: OllamaEmbeddingResponse = resp.json().await?;
508        Ok(json_resp.embeddings)
509    }
510}
511
512#[async_trait]
513impl SpeechToTextProvider for Ollama {
514    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
515        Err(LLMError::ProviderError(
516            "Ollama does not implement speech to text endpoint yet.".into(),
517        ))
518    }
519}
520
// No methods are overridden here; this backend relies on whatever the trait provides.
#[async_trait]
impl ModelsProvider for Ollama {}
523
524impl crate::LLMProvider for Ollama {
525    fn tools(&self) -> Option<&[Tool]> {
526        self.tools.as_deref()
527    }
528}
529
// No methods are overridden here; this backend relies on whatever the trait provides.
#[async_trait]
impl TextToSpeechProvider for Ollama {}
532
533/// Parses a Server-Sent Events (SSE) chunk from Ollama's streaming API.
534/// Ollama events differ from other providers because it uses json lines instead of the expected SSE format.
535/// # Arguments
536///
537/// * `chunk` - The raw SSE chunk text
538///
539/// # Returns
540///
541/// * `Ok(Some(String))` - Content token if found
542/// * `Ok(None)` - If chunk should be skipped (e.g., ping, done signal)
543/// * `Err(LLMError)` - If parsing fails
544fn parse_ollama_sse(chunk: &str) -> Result<Option<String>, LLMError> {
545    let mut collected_content = String::new();
546
547    for line in chunk.lines() {
548        let line = line.trim();
549
550        match serde_json::from_str::<OllamaChatStreamResponse>(line) {
551            Ok(data) => {
552                collected_content.push_str(&data.message.content);
553            }
554            Err(e) => return Err(LLMError::JsonError(e.to_string())),
555        }
556    }
557
558    if collected_content.is_empty() {
559        Ok(None)
560    } else {
561        Ok(Some(collected_content))
562    }
563}