1use async_trait::async_trait;
4use futures::Stream;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::pin::Pin;
8use tracing::{debug, instrument};
9
10use crate::{
11 error::LLMError,
12 traits::{FinishReason, LLMAdapter, LLMMessage, LLMResponse, Role, StreamChunk, TokenUsage},
13};
14
/// Adapter for a local Ollama server, speaking its `/api/chat` HTTP API.
pub struct OllamaAdapter {
    // Shared HTTP client reused for every request.
    client: Client,
    // Server root, e.g. "http://localhost:11434"; endpoint paths are
    // appended with a leading '/', so no trailing slash is expected here.
    base_url: String,
    // Model name sent verbatim in each request (e.g. "llama3").
    model: String,
    // Sampling temperature forwarded under the request's `options` key.
    temperature: f32,
}
22
23impl OllamaAdapter {
24 #[must_use]
30 pub fn new(model: impl Into<String>) -> Self {
31 Self {
32 client: Client::new(),
33 base_url: "http://localhost:11434".to_string(),
34 model: model.into(),
35 temperature: 0.7,
36 }
37 }
38
39 #[must_use]
41 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
42 self.base_url = base_url.into();
43 self
44 }
45
46 #[must_use]
48 pub const fn with_temperature(mut self, temperature: f32) -> Self {
49 self.temperature = temperature;
50 self
51 }
52}
53
/// Request body for Ollama's `POST /api/chat` endpoint.
#[derive(Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    // true => newline-delimited JSON stream; false => single JSON body.
    stream: bool,
    options: OllamaOptions,
}
61
/// One chat message in Ollama's wire format.
#[derive(Serialize)]
struct OllamaMessage {
    // One of "system", "user", or "assistant" (see the From impl below).
    role: String,
    content: String,
}
67
/// Generation options forwarded to Ollama under the request's `options` key.
#[derive(Serialize)]
struct OllamaOptions {
    temperature: f32,
}
72
/// One response object from `/api/chat`: the whole body when not streaming,
/// or a single NDJSON line when streaming.
#[derive(Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
    // true on the final (or only) object of a response.
    done: bool,
    // Prompt token count; may be absent, hence Option + serde(default).
    // Presumably only populated on the final object — TODO confirm.
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    // Completion token count; may be absent, same caveat as above.
    #[serde(default)]
    eval_count: Option<u32>,
}
82
/// The `message` payload nested inside an Ollama chat response.
#[derive(Deserialize)]
struct OllamaResponseMessage {
    content: String,
}
87
88impl From<&LLMMessage> for OllamaMessage {
89 fn from(msg: &LLMMessage) -> Self {
90 Self {
91 role: match msg.role {
92 Role::System => "system".to_string(),
93 Role::User => "user".to_string(),
94 Role::Assistant => "assistant".to_string(),
95 },
96 content: msg.content.clone(),
97 }
98 }
99}
100
101#[async_trait]
102impl LLMAdapter for OllamaAdapter {
103 fn provider(&self) -> &'static str {
104 "ollama"
105 }
106
107 fn model(&self) -> &str {
108 &self.model
109 }
110
111 #[instrument(skip(self, messages), fields(provider = "ollama", model = %self.model))]
112 async fn generate(&self, messages: &[LLMMessage]) -> Result<LLMResponse, LLMError> {
113 debug!("Generating completion with {} messages", messages.len());
114
115 let request = OllamaChatRequest {
116 model: self.model.clone(),
117 messages: messages.iter().map(OllamaMessage::from).collect(),
118 stream: false,
119 options: OllamaOptions {
120 temperature: self.temperature,
121 },
122 };
123
124 let response = self
125 .client
126 .post(format!("{}/api/chat", self.base_url))
127 .json(&request)
128 .send()
129 .await
130 .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
131
132 if !response.status().is_success() {
133 return Err(LLMError::ApiError(format!(
134 "Ollama returned status {}",
135 response.status()
136 )));
137 }
138
139 let chat_response: OllamaChatResponse = response
140 .json()
141 .await
142 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
143
144 let prompt_tokens = chat_response.prompt_eval_count.unwrap_or(0);
145 let completion_tokens = chat_response.eval_count.unwrap_or(0);
146
147 Ok(LLMResponse {
148 content: chat_response.message.content,
149 tokens_used: TokenUsage {
150 prompt: prompt_tokens,
151 completion: completion_tokens,
152 total: prompt_tokens + completion_tokens,
153 },
154 finish_reason: FinishReason::Stop,
155 model: self.model.clone(),
156 })
157 }
158
159 fn generate_stream(
160 &self,
161 messages: &[LLMMessage],
162 ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send + '_>> {
163 let request = OllamaChatRequest {
164 model: self.model.clone(),
165 messages: messages.iter().map(OllamaMessage::from).collect(),
166 stream: true,
167 options: OllamaOptions {
168 temperature: self.temperature,
169 },
170 };
171
172 let client = self.client.clone();
173 let url = format!("{}/api/chat", self.base_url);
174
175 Box::pin(async_stream::try_stream! {
176 let response = client
177 .post(&url)
178 .json(&request)
179 .send()
180 .await
181 .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
182
183 let mut stream = response.bytes_stream();
184
185 use futures::StreamExt;
186 while let Some(chunk) = stream.next().await {
187 let bytes = chunk.map_err(|e| LLMError::ConnectionError(e.to_string()))?;
188 let text = String::from_utf8_lossy(&bytes);
189
190 for line in text.lines() {
191 if line.is_empty() {
192 continue;
193 }
194
195 if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
196 yield StreamChunk {
197 content: response.message.content,
198 done: response.done,
199 tokens_used: if response.done {
200 Some(TokenUsage {
201 prompt: response.prompt_eval_count.unwrap_or(0),
202 completion: response.eval_count.unwrap_or(0),
203 total: response.prompt_eval_count.unwrap_or(0)
204 + response.eval_count.unwrap_or(0),
205 })
206 } else {
207 None
208 },
209 finish_reason: if response.done {
210 Some(FinishReason::Stop)
211 } else {
212 None
213 },
214 };
215 }
216 }
217 }
218 })
219 }
220
221 async fn health_check(&self) -> Result<bool, LLMError> {
222 let response = self
223 .client
224 .get(format!("{}/api/tags", self.base_url))
225 .send()
226 .await
227 .map_err(|e| LLMError::ConnectionError(e.to_string()))?;
228
229 Ok(response.status().is_success())
230 }
231}