1use std::pin::Pin;
6
7use crate::{
8 chat::{
9 ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
10 Tool,
11 },
12 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
13 embedding::EmbeddingProvider,
14 error::LLMError,
15 models::ModelsProvider,
16 stt::SpeechToTextProvider,
17 tts::TextToSpeechProvider,
18 FunctionCall, ToolCall,
19};
20use async_trait::async_trait;
21use base64::{self, Engine};
22use futures::Stream;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26
/// Client for an Ollama server, implementing chat, streaming chat,
/// completion and embedding endpoints over HTTP.
pub struct Ollama {
    /// Base URL of the Ollama server, e.g. `http://localhost:11434`.
    pub base_url: String,
    /// Optional API key. NOTE(review): never attached to any request in this
    /// file — confirm whether authentication is handled elsewhere.
    pub api_key: Option<String>,
    /// Model name; `new` defaults this to `llama3.1` when unset.
    pub model: String,
    /// NOTE(review): stored but not forwarded in any request body here — TODO confirm.
    pub max_tokens: Option<u32>,
    /// NOTE(review): stored but not forwarded in any request body here — TODO confirm.
    pub temperature: Option<f32>,
    /// System prompt, inserted as the first chat message when present.
    pub system: Option<String>,
    /// Request timeout in seconds; applied to the shared client at build
    /// time and again per request.
    pub timeout_seconds: Option<u64>,
    /// Nucleus-sampling parameter, forwarded via the chat `options` object.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter, forwarded via the chat `options` object.
    pub top_k: Option<u32>,
    /// JSON schema for structured output (becomes the chat `format` field).
    pub json_schema: Option<StructuredOutputFormat>,
    /// Tools exposed through `LLMProvider::tools`.
    pub tools: Option<Vec<Tool>>,
    // Shared HTTP client, built once in `new`.
    client: Client,
}
46
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize)]
struct OllamaChatRequest<'a> {
    /// Model to run the conversation against.
    model: String,
    /// Ordered conversation, with the system prompt (if any) first.
    messages: Vec<OllamaChatMessage<'a>>,
    /// Whether the server should stream the response.
    stream: bool,
    /// Sampling options (`top_p`/`top_k`).
    options: Option<OllamaOptions>,
    /// Structured-output format; serialized as `null` when absent.
    format: Option<OllamaResponseFormat>,
    /// Tools the model may call; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
}
58
/// Generation options nested under `options` in a chat request.
///
/// NOTE(review): unset fields serialize as JSON `null` (no
/// `skip_serializing_if`) — confirm the server treats `null` as "use default".
#[derive(Serialize)]
struct OllamaOptions {
    top_p: Option<f32>,
    top_k: Option<u32>,
}
64
/// A single chat message in Ollama wire format.
#[derive(Serialize)]
struct OllamaChatMessage<'a> {
    /// One of `"system"`, `"user"` or `"assistant"`.
    role: &'a str,
    /// Message text, borrowed from the source `ChatMessage`.
    content: &'a str,
    /// Base64-encoded image payloads; present only for image messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    images: Option<Vec<String>>,
}
73
74impl<'a> From<&'a ChatMessage> for OllamaChatMessage<'a> {
75 fn from(msg: &'a ChatMessage) -> Self {
76 Self {
77 role: match msg.role {
78 ChatRole::User => "user",
79 ChatRole::Assistant => "assistant",
80 },
81 content: &msg.content,
82 images: match &msg.message_type {
83 MessageType::Image((_mime, data)) => {
84 Some(vec![base64::engine::general_purpose::STANDARD.encode(data)])
85 }
86 _ => None,
87 },
88 }
89 }
90}
91
/// Response payload shared by `/api/chat` and `/api/generate`.
///
/// NOTE(review): typically only one text field is populated — `response` for
/// generate, `message` for chat; `content` is accepted as an alternative.
#[derive(Deserialize, Debug)]
struct OllamaResponse {
    content: Option<String>,
    response: Option<String>,
    message: Option<OllamaChatResponseMessage>,
}
99
100impl std::fmt::Display for OllamaResponse {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 let empty = String::new();
103 let text = self
104 .content
105 .as_ref()
106 .or(self.response.as_ref())
107 .or(self.message.as_ref().map(|m| &m.content))
108 .unwrap_or(&empty);
109
110 if let Some(message) = &self.message {
112 if let Some(tool_calls) = &message.tool_calls {
113 for tc in tool_calls {
114 writeln!(
115 f,
116 "{{\"name\": \"{}\", \"arguments\": {}}}",
117 tc.function.name,
118 serde_json::to_string_pretty(&tc.function.arguments).unwrap_or_default()
119 )?;
120 }
121 }
122 }
123
124 write!(f, "{}", text)
125 }
126}
127
128impl ChatResponse for OllamaResponse {
129 fn text(&self) -> Option<String> {
130 self.content
131 .as_ref()
132 .or(self.response.as_ref())
133 .or(self.message.as_ref().map(|m| &m.content))
134 .map(|s| s.to_string())
135 }
136
137 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
138 self.message.as_ref().and_then(|msg| {
139 msg.tool_calls.as_ref().map(|tcs| {
140 tcs.iter()
141 .map(|tc| ToolCall {
142 id: format!("call_{}", tc.function.name),
143 call_type: "function".to_string(),
144 function: FunctionCall {
145 name: tc.function.name.clone(),
146 arguments: serde_json::to_string(&tc.function.arguments)
147 .unwrap_or_default(),
148 },
149 })
150 .collect()
151 })
152 })
153 }
154}
155
/// The `message` object of a chat response: text plus any tool calls.
#[derive(Deserialize, Debug)]
struct OllamaChatResponseMessage {
    content: String,
    tool_calls: Option<Vec<OllamaToolCall>>,
}
162
/// One line of a streaming chat response (see `parse_ollama_sse`).
#[derive(Deserialize, Debug)]
struct OllamaChatStreamResponse {
    message: OllamaChatStreamMessage,
}
167
/// The `message` object inside one streaming chunk; only the text delta.
#[derive(Deserialize, Debug)]
struct OllamaChatStreamMessage {
    content: String,
}
172
/// Request body for `/api/generate` (plain completion).
#[derive(Serialize)]
struct OllamaGenerateRequest<'a> {
    model: String,
    /// Prompt text, borrowed from the `CompletionRequest`.
    prompt: &'a str,
    /// Always sent as `true` by `complete` — NOTE(review): presumably to
    /// bypass server-side prompt templating; confirm against Ollama API docs.
    raw: bool,
    stream: bool,
}
181
/// Request body for `/api/embed`: a batch of input strings.
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    input: Vec<String>,
}
187
/// Response from `/api/embed`: one embedding vector per input string.
#[derive(Deserialize, Debug)]
struct OllamaEmbeddingResponse {
    embeddings: Vec<Vec<f32>>,
}
192
/// Value of the chat `format` field: either JSON mode or a full schema.
///
/// NOTE(review): with `#[serde(untagged)]`, the `rename = "json"` on the unit
/// variant has no effect — `Json` would serialize as `null`, not `"json"`.
/// Only `StructuredOutput` is ever constructed in this file, so the defect is
/// latent; confirm intent before using `Json`.
#[derive(Deserialize, Debug, Serialize)]
#[serde(untagged)]
enum OllamaResponseType {
    #[serde(rename = "json")]
    Json,
    StructuredOutput(Value),
}
200
/// Wrapper for the chat `format` field; `flatten` emits the inner value
/// directly rather than nesting it under a `format` key.
#[derive(Deserialize, Debug, Serialize)]
struct OllamaResponseFormat {
    #[serde(flatten)]
    format: OllamaResponseType,
}
206
/// A tool definition in Ollama wire format.
#[derive(Serialize, Debug)]
struct OllamaTool {
    /// Always `"function"` (set in the `From<&Tool>` conversion).
    #[serde(rename = "type")]
    pub tool_type: String,

    pub function: OllamaFunctionTool,
}
215
/// The `function` body of an Ollama tool: name, description and parameters.
#[derive(Serialize, Debug)]
struct OllamaFunctionTool {
    name: String,
    description: String,
    parameters: OllamaParameters,
}
225
226impl From<&crate::chat::Tool> for OllamaTool {
227 fn from(tool: &crate::chat::Tool) -> Self {
228 let properties_value = tool
229 .function
230 .parameters
231 .get("properties")
232 .cloned()
233 .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
234
235 let required_fields = tool
236 .function
237 .parameters
238 .get("required")
239 .and_then(|v| v.as_array())
240 .map(|arr| {
241 arr.iter()
242 .filter_map(|v| v.as_str().map(|s| s.to_string()))
243 .collect::<Vec<String>>()
244 })
245 .unwrap_or_default();
246
247 OllamaTool {
248 tool_type: "function".to_owned(),
249 function: OllamaFunctionTool {
250 name: tool.function.name.clone(),
251 description: tool.function.description.clone(),
252 parameters: OllamaParameters {
253 schema_type: "object".to_string(),
254 properties: properties_value,
255 required: required_fields,
256 },
257 },
258 }
259 }
260}
261
/// JSON-schema parameter block of a tool (always `type: "object"` here).
#[derive(Serialize, Debug)]
struct OllamaParameters {
    #[serde(rename = "type")]
    schema_type: String,
    /// The schema's `properties` object, passed through as raw JSON.
    properties: Value,
    /// Names of required properties.
    required: Vec<String>,
}
273
/// A tool call returned by the model inside a chat response.
#[derive(Deserialize, Debug)]
struct OllamaToolCall {
    function: OllamaFunctionCall,
}
279
/// Name and raw JSON arguments of a model-requested function call.
#[derive(Deserialize, Debug)]
struct OllamaFunctionCall {
    name: String,
    arguments: Value,
}
287
288impl Ollama {
289 #[allow(clippy::too_many_arguments)]
304 pub fn new(
305 base_url: impl Into<String>,
306 api_key: Option<String>,
307 model: Option<String>,
308 max_tokens: Option<u32>,
309 temperature: Option<f32>,
310 timeout_seconds: Option<u64>,
311 system: Option<String>,
312 stream: Option<bool>,
313 top_p: Option<f32>,
314 top_k: Option<u32>,
315 json_schema: Option<StructuredOutputFormat>,
316 tools: Option<Vec<Tool>>,
317 ) -> Self {
318 let mut builder = Client::builder();
319 if let Some(sec) = timeout_seconds {
320 builder = builder.timeout(std::time::Duration::from_secs(sec));
321 }
322 Self {
323 base_url: base_url.into(),
324 api_key,
325 model: model.unwrap_or("llama3.1".to_string()),
326 temperature,
327 max_tokens,
328 timeout_seconds,
329 system,
330 top_p,
331 top_k,
332 json_schema,
333 tools,
334 client: builder.build().expect("Failed to build reqwest Client"),
335 }
336 }
337
338 fn make_chat_request<'a>(
339 &'a self,
340 messages: &'a [ChatMessage],
341 tools: Option<&'a [Tool]>,
342 stream: bool,
343 ) -> OllamaChatRequest<'a> {
344 let mut chat_messages: Vec<OllamaChatMessage> =
345 messages.iter().map(OllamaChatMessage::from).collect();
346
347 if let Some(system) = &self.system {
348 chat_messages.insert(
349 0,
350 OllamaChatMessage {
351 role: "system",
352 content: system,
353 images: None,
354 },
355 );
356 }
357
358 let ollama_tools = tools.map(|t| t.iter().map(OllamaTool::from).collect());
360
361 let format = if let Some(schema) = &self.json_schema {
363 schema.schema.as_ref().map(|schema| OllamaResponseFormat {
364 format: OllamaResponseType::StructuredOutput(schema.clone()),
365 })
366 } else {
367 None
368 };
369
370 OllamaChatRequest {
371 model: self.model.clone(),
372 messages: chat_messages,
373 stream: stream,
374 options: Some(OllamaOptions {
375 top_p: self.top_p,
376 top_k: self.top_k,
377 }),
378 format,
379 tools: ollama_tools,
380 }
381 }
382}
383
384#[async_trait]
385impl ChatProvider for Ollama {
386 async fn chat_with_tools(
387 &self,
388 messages: &[ChatMessage],
389 tools: Option<&[Tool]>,
390 ) -> Result<Box<dyn ChatResponse>, LLMError> {
391 if self.base_url.is_empty() {
392 return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
393 }
394
395 let req_body = self.make_chat_request(messages, tools, false);
396
397 if log::log_enabled!(log::Level::Trace) {
398 if let Ok(json) = serde_json::to_string(&req_body) {
399 log::trace!("Ollama request payload (tools): {}", json);
400 }
401 }
402
403 let url = format!("{}/api/chat", self.base_url);
404
405 let mut request = self.client.post(&url).json(&req_body);
406
407 if let Some(timeout) = self.timeout_seconds {
408 request = request.timeout(std::time::Duration::from_secs(timeout));
409 }
410
411 let resp = request.send().await?;
412
413 log::debug!("Ollama HTTP status (tools): {}", resp.status());
414
415 let resp = resp.error_for_status()?;
416 let json_resp = resp.json::<OllamaResponse>().await?;
417
418 Ok(Box::new(json_resp))
419 }
420
421 async fn chat_stream(
422 &self,
423 messages: &[ChatMessage],
424 ) -> Result<Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
425 let req_body = self.make_chat_request(messages, None, true);
426
427 let url = format!("{}/api/chat", self.base_url);
428 let mut request = self.client.post(&url).json(&req_body);
429
430 if let Some(timeout) = self.timeout_seconds {
431 request = request.timeout(std::time::Duration::from_secs(timeout));
432 }
433
434 let resp = request.send().await?;
435 log::debug!("Ollama HTTP status: {}", resp.status());
436
437 let resp = resp.error_for_status()?;
438
439 Ok(crate::chat::create_sse_stream(resp, parse_ollama_sse))
440 }
441}
442
443#[async_trait]
444impl CompletionProvider for Ollama {
445 async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
455 if self.base_url.is_empty() {
456 return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
457 }
458 let url = format!("{}/api/generate", self.base_url);
459
460 let req_body = OllamaGenerateRequest {
461 model: self.model.clone(),
462 prompt: &req.prompt,
463 raw: true,
464 stream: false,
465 };
466
467 let resp = self
468 .client
469 .post(&url)
470 .json(&req_body)
471 .send()
472 .await?
473 .error_for_status()?;
474 let json_resp: OllamaResponse = resp.json().await?;
475
476 if let Some(answer) = json_resp.response.or(json_resp.content) {
477 Ok(CompletionResponse { text: answer })
478 } else {
479 Err(LLMError::ProviderError(
480 "No answer returned by Ollama".to_string(),
481 ))
482 }
483 }
484}
485
486#[async_trait]
487impl EmbeddingProvider for Ollama {
488 async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
489 if self.base_url.is_empty() {
490 return Err(LLMError::InvalidRequest("Missing base_url".to_string()));
491 }
492 let url = format!("{}/api/embed", self.base_url);
493
494 let body = OllamaEmbeddingRequest {
495 model: self.model.clone(),
496 input: text,
497 };
498
499 let resp = self
500 .client
501 .post(&url)
502 .json(&body)
503 .send()
504 .await?
505 .error_for_status()?;
506
507 let json_resp: OllamaEmbeddingResponse = resp.json().await?;
508 Ok(json_resp.embeddings)
509 }
510}
511
512#[async_trait]
513impl SpeechToTextProvider for Ollama {
514 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
515 Err(LLMError::ProviderError(
516 "Ollama does not implement speech to text endpoint yet.".into(),
517 ))
518 }
519}
520
// No Ollama-specific overrides: relies entirely on `ModelsProvider`'s
// default method implementations.
#[async_trait]
impl ModelsProvider for Ollama {}
523
524impl crate::LLMProvider for Ollama {
525 fn tools(&self) -> Option<&[Tool]> {
526 self.tools.as_deref()
527 }
528}
529
// No Ollama-specific overrides: relies entirely on `TextToSpeechProvider`'s
// default method implementations.
#[async_trait]
impl TextToSpeechProvider for Ollama {}
532
533fn parse_ollama_sse(chunk: &str) -> Result<Option<String>, LLMError> {
545 let mut collected_content = String::new();
546
547 for line in chunk.lines() {
548 let line = line.trim();
549
550 match serde_json::from_str::<OllamaChatStreamResponse>(line) {
551 Ok(data) => {
552 collected_content.push_str(&data.message.content);
553 }
554 Err(e) => return Err(LLMError::JsonError(e.to_string())),
555 }
556 }
557
558 if collected_content.is_empty() {
559 Ok(None)
560 } else {
561 Ok(Some(collected_content))
562 }
563}