// agentic_llm/ollama.rs
//! Ollama adapter for local LLM models.

use async_trait::async_trait;
use futures::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use tracing::{debug, instrument};

use crate::{
    error::LLMError,
    traits::{FinishReason, LLMAdapter, LLMMessage, LLMResponse, Role, StreamChunk, TokenUsage},
};
14
/// Ollama adapter for local models.
pub struct OllamaAdapter {
    /// Shared HTTP client reused for all requests.
    client: Client,
    /// Base URL of the Ollama server, without a trailing slash.
    base_url: String,
    /// Name of the model to run (e.g., "llama3.2").
    model: String,
    /// Sampling temperature sent with every request.
    temperature: f32,
}
22
23impl OllamaAdapter {
24    /// Create a new Ollama adapter.
25    ///
26    /// # Arguments
27    ///
28    /// * `model` - Model to use (e.g., "llama3.2", "qwen2.5-coder")
29    #[must_use]
30    pub fn new(model: impl Into<String>) -> Self {
31        Self {
32            client: Client::new(),
33            base_url: "http://localhost:11434".to_string(),
34            model: model.into(),
35            temperature: 0.7,
36        }
37    }
38
39    /// Set the base URL for Ollama server.
40    #[must_use]
41    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
42        self.base_url = base_url.into();
43        self
44    }
45
46    /// Set the temperature for generation.
47    #[must_use]
48    pub const fn with_temperature(mut self, temperature: f32) -> Self {
49        self.temperature = temperature;
50        self
51    }
52}
53
/// Request body for Ollama's `POST /api/chat` endpoint.
#[derive(Serialize)]
struct OllamaChatRequest {
    /// Model name to run.
    model: String,
    /// Conversation history in Ollama's wire format.
    messages: Vec<OllamaMessage>,
    /// When true, the server replies with newline-delimited JSON chunks.
    stream: bool,
    /// Generation options forwarded to the model.
    options: OllamaOptions,
}
61
/// A single chat message in Ollama's wire format.
#[derive(Serialize)]
struct OllamaMessage {
    /// One of "system", "user", or "assistant".
    role: String,
    /// Message text.
    content: String,
}
67
/// Generation options sent under the `options` key of a chat request.
#[derive(Serialize)]
struct OllamaOptions {
    /// Sampling temperature.
    temperature: f32,
}
72
/// One response object from `/api/chat` — the whole body when
/// `stream: false`, or a single NDJSON frame when streaming.
#[derive(Deserialize)]
struct OllamaChatResponse {
    /// The (possibly partial) assistant message.
    message: OllamaResponseMessage,
    /// True on the final frame of a stream (always true when not streaming).
    done: bool,
    /// Prompt token count; only present on the final/complete response.
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    /// Completion token count; only present on the final/complete response.
    #[serde(default)]
    eval_count: Option<u32>,
}
82
/// Assistant message payload inside an [`OllamaChatResponse`].
#[derive(Deserialize)]
struct OllamaResponseMessage {
    /// Generated text (a delta when streaming, full text otherwise).
    content: String,
}
87
88impl From<&LLMMessage> for OllamaMessage {
89    fn from(msg: &LLMMessage) -> Self {
90        Self {
91            role: match msg.role {
92                Role::System => "system".to_string(),
93                Role::User => "user".to_string(),
94                Role::Assistant => "assistant".to_string(),
95            },
96            content: msg.content.clone(),
97        }
98    }
99}
100
101#[async_trait]
102impl LLMAdapter for OllamaAdapter {
103    fn provider(&self) -> &'static str {
104        "ollama"
105    }
106
107    fn model(&self) -> &str {
108        &self.model
109    }
110
111    #[instrument(skip(self, messages), fields(provider = "ollama", model = %self.model))]
112    async fn generate(&self, messages: &[LLMMessage]) -> Result<LLMResponse, LLMError> {
113        debug!("Generating completion with {} messages", messages.len());
114
115        let request = OllamaChatRequest {
116            model: self.model.clone(),
117            messages: messages.iter().map(OllamaMessage::from).collect(),
118            stream: false,
119            options: OllamaOptions {
120                temperature: self.temperature,
121            },
122        };
123
124        let response = self
125            .client
126            .post(format!("{}/api/chat", self.base_url))
127            .json(&request)
128            .send()
129            .await
130            .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
131
132        if !response.status().is_success() {
133            return Err(LLMError::ApiError(format!(
134                "Ollama returned status {}",
135                response.status()
136            )));
137        }
138
139        let chat_response: OllamaChatResponse = response
140            .json()
141            .await
142            .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
143
144        let prompt_tokens = chat_response.prompt_eval_count.unwrap_or(0);
145        let completion_tokens = chat_response.eval_count.unwrap_or(0);
146
147        Ok(LLMResponse {
148            content: chat_response.message.content,
149            tokens_used: TokenUsage {
150                prompt: prompt_tokens,
151                completion: completion_tokens,
152                total: prompt_tokens + completion_tokens,
153            },
154            finish_reason: FinishReason::Stop,
155            model: self.model.clone(),
156        })
157    }
158
159    fn generate_stream(
160        &self,
161        messages: &[LLMMessage],
162    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send + '_>> {
163        let request = OllamaChatRequest {
164            model: self.model.clone(),
165            messages: messages.iter().map(OllamaMessage::from).collect(),
166            stream: true,
167            options: OllamaOptions {
168                temperature: self.temperature,
169            },
170        };
171
172        let client = self.client.clone();
173        let url = format!("{}/api/chat", self.base_url);
174
175        Box::pin(async_stream::try_stream! {
176            let response = client
177                .post(&url)
178                .json(&request)
179                .send()
180                .await
181                .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
182
183            let mut stream = response.bytes_stream();
184
185            use futures::StreamExt;
186            while let Some(chunk) = stream.next().await {
187                let bytes = chunk.map_err(|e| LLMError::ConnectionError(e.to_string()))?;
188                let text = String::from_utf8_lossy(&bytes);
189
190                for line in text.lines() {
191                    if line.is_empty() {
192                        continue;
193                    }
194
195                    if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
196                        yield StreamChunk {
197                            content: response.message.content,
198                            done: response.done,
199                            tokens_used: if response.done {
200                                Some(TokenUsage {
201                                    prompt: response.prompt_eval_count.unwrap_or(0),
202                                    completion: response.eval_count.unwrap_or(0),
203                                    total: response.prompt_eval_count.unwrap_or(0)
204                                        + response.eval_count.unwrap_or(0),
205                                })
206                            } else {
207                                None
208                            },
209                            finish_reason: if response.done {
210                                Some(FinishReason::Stop)
211                            } else {
212                                None
213                            },
214                        };
215                    }
216                }
217            }
218        })
219    }
220
221    async fn health_check(&self) -> Result<bool, LLMError> {
222        let response = self
223            .client
224            .get(format!("{}/api/tags", self.base_url))
225            .send()
226            .await
227            .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
228
229        Ok(response.status().is_success())
230    }
231}