1use crate::apis::api_client::{
2 ApiClient, CompletionOptions, Message, ToolCall, ToolDefinition, ToolResult,
3};
4use crate::app::logger::{format_log_with_color, LogLevel};
5use crate::errors::AppError;
6use anyhow::Result;
7use async_trait::async_trait;
8use rand;
9use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
10use reqwest::Client as ReqwestClient;
11use serde::{Deserialize, Serialize};
12use serde_json::{self, json, Value};
13use std::time::Duration;
14
/// One chat message in Ollama's `/api/chat` wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
    // Chat role, e.g. "system", "user", "assistant", or "tool".
    role: String,
    // Message body. Serialized as a plain string; on deserialization a
    // non-string JSON value is re-encoded to text (see
    // `content_string_or_object`). Defaults to "" when absent.
    #[serde(default)]
    #[serde(with = "content_string_or_object")]
    content: String,
    // Tool invocations requested by the assistant, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OllamaToolCall>>,
    // For role "tool": id of the tool call this message is a result for.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
27
28mod content_string_or_object {
30 use serde::{self, Deserialize, Deserializer, Serializer};
31 use serde_json::Value;
32
33 pub fn serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 serializer.serialize_str(content)
38 }
39
40 pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
41 where
42 D: Deserializer<'de>,
43 {
44 let value = Value::deserialize(deserializer)?;
45 match value {
46 Value::String(s) => Ok(s),
47 _ => Ok(value.to_string()),
48 }
49 }
50}
51
/// A tool invocation emitted by the model inside a chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
    // Call identifier; defaults to "" since some Ollama builds omit it
    // (callers synthesize an id in that case).
    #[serde(default)]
    id: String,
    // The function name and arguments to invoke.
    function: OllamaFunctionCall,
    // Usually "function" when present; serialized under the JSON key "type".
    #[serde(rename = "type")]
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    tool_type: Option<String>,
}
62
/// The function portion of a tool call: which function, with what arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaFunctionCall {
    // Name of the function to call.
    name: String,
    // Arguments kept as a JSON-encoded string; Ollama may send either a
    // string or an object (see `arguments_as_string_or_object`).
    #[serde(with = "arguments_as_string_or_object")]
    arguments: String,
}
69
70mod arguments_as_string_or_object {
72 use serde::{self, Deserialize, Deserializer, Serializer};
73 use serde_json::Value;
74
75 pub fn serialize<S>(arguments: &str, serializer: S) -> Result<S::Ok, S::Error>
76 where
77 S: Serializer,
78 {
79 serializer.serialize_str(arguments)
80 }
81
82 pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
83 where
84 D: Deserializer<'de>,
85 {
86 let value = Value::deserialize(deserializer)?;
87
88 match value {
89 Value::String(s) => Ok(s),
90 _ => {
91 let json_str = serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string());
93 Ok(json_str)
94 }
95 }
96 }
97}
98
/// A tool definition advertised to the model in a chat request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaTool {
    // Always set to "function" by `convert_tool_definitions`; serialized
    // under the JSON key "type".
    #[serde(rename = "type")]
    tool_type: String,
    // The callable function's name, description, and parameter schema.
    function: OllamaFunction,
}
105
/// Declaration of a callable function within a tool definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaFunction {
    // Function name the model uses to reference this tool.
    name: String,
    // Human-readable description shown to the model.
    description: String,
    // Parameter schema as free-form JSON (typically a JSON Schema object).
    parameters: Value,
}
112
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaRequest {
    // Model name, e.g. "qwen2.5-coder:14b".
    model: String,
    // Conversation history, oldest first.
    messages: Vec<OllamaMessage>,
    // Always false here: callers expect a single complete response.
    stream: bool,
    // Sampling temperature, when the caller supplies one.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Nucleus-sampling cutoff, when the caller supplies one.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    // Extra model options; currently always None in this client.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<Value>,
    // Set to "json" when a JSON schema is requested, to force JSON output.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<String>,
    // Tool definitions offered to the model, when tools are enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
}
129
/// Non-streaming response body from Ollama's `/api/chat` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaResponse {
    // Model that produced the response.
    model: String,
    // Server-reported creation timestamp (kept as the raw string).
    created_at: String,
    // The assistant's reply, possibly carrying tool calls.
    message: OllamaMessage,
    // True when generation has finished (always expected for stream=false).
    done: bool,
    // Timing/usage metadata; optional because servers may omit them.
    // Durations are presumably nanoseconds per Ollama's API — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    total_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    load_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_eval_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    eval_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    eval_duration: Option<u64>,
}
147
/// Response body from Ollama's `/api/tags` (list installed models) endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaListModelsResponse {
    // Models currently installed on the server.
    models: Vec<OllamaModelInfo>,
}
152
/// Metadata for one installed model, as reported by `/api/tags`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModelInfo {
    // Model name including tag, e.g. "qwen2.5-coder:14b".
    pub name: String,
    // Last-modified timestamp (kept as the raw string).
    pub modified_at: String,
    // On-disk size in bytes.
    pub size: u64,
    // Content digest identifying the model blob.
    pub digest: String,
    // Extra details; optional because servers may omit them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<OllamaModelDetails>,
}
162
/// Optional per-model details nested inside `OllamaModelInfo`.
/// All fields are optional since their presence varies by server version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModelDetails {
    // Parameter count label, e.g. "14B" — format set by the server.
    pub parameter_size: Option<String>,
    // Quantization label, e.g. "Q4_K_M" — format set by the server.
    pub quantization_level: Option<String>,
    // Model file format.
    pub format: Option<String>,
    // Model family names.
    pub families: Option<Vec<String>>,
    // Free-form description.
    pub description: Option<String>,
}
171
/// HTTP client for a local or remote Ollama server.
pub struct OllamaClient {
    // Underlying reqwest client (JSON content type, 300 s timeout).
    client: ReqwestClient,
    // Model name used for chat requests.
    model: String,
    // Server base URL without trailing slash, e.g. "http://localhost:11434".
    api_base: String,
}
177
178impl OllamaClient {
179 pub fn new(model: Option<String>) -> Result<Self> {
180 let model_name = match model {
182 Some(m) if !m.trim().is_empty() => m,
183 _ => "qwen2.5-coder:14b".to_string(),
184 };
185
186 Self::with_base_url(model_name, "http://localhost:11434".to_string())
187 }
188
189 pub fn with_base_url(model: String, api_base: String) -> Result<Self> {
190 let mut headers = HeaderMap::new();
191 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
192
193 let client = ReqwestClient::builder()
194 .default_headers(headers)
195 .timeout(Duration::from_secs(300)) .build()?;
197
198 Ok(Self {
199 client,
200 model,
201 api_base,
202 })
203 }
204
205 fn convert_messages(&self, messages: Vec<Message>) -> Vec<OllamaMessage> {
206 messages
207 .into_iter()
208 .map(|msg| {
209 OllamaMessage {
211 role: msg.role,
212 content: msg.content,
213 tool_calls: None,
214 tool_call_id: None,
215 }
216 })
217 .collect()
218 }
219
220 fn convert_tool_definitions(&self, tools: Vec<ToolDefinition>) -> Vec<OllamaTool> {
221 tools
222 .into_iter()
223 .map(|tool| OllamaTool {
224 tool_type: "function".to_string(),
225 function: OllamaFunction {
226 name: tool.name,
227 description: tool.description,
228 parameters: tool.parameters,
229 },
230 })
231 .collect()
232 }
233
234 pub async fn list_models(&self) -> Result<Vec<OllamaModelInfo>> {
235 let url = format!("{}/api/tags", self.api_base);
236
237 eprintln!(
238 "{}",
239 format_log_with_color(
240 LogLevel::Debug,
241 &format!("Listing Ollama models from: {}", url)
242 )
243 );
244
245 let response = self.client.get(&url).send().await.map_err(|e| {
246 let error_msg = if e.is_connect() {
247 "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
249 .to_string()
250 } else {
251 format!("Failed to send request to Ollama: {}", e)
252 };
253 eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
254 AppError::NetworkError(error_msg)
255 })?;
256
257 if !response.status().is_success() {
258 let status = response.status();
259 let error_text = response
260 .text()
261 .await
262 .unwrap_or_else(|_| "Unknown error".to_string());
263 return Err(AppError::NetworkError(format!(
264 "Ollama API error: {} - {}",
265 status, error_text
266 ))
267 .into());
268 }
269
270 let response_text = response.text().await.map_err(|e| {
272 let error_msg = format!("Failed to get response text: {}", e);
273 eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
274 AppError::NetworkError(error_msg)
275 })?;
276
277 eprintln!(
278 "{}",
279 format_log_with_color(
280 LogLevel::Debug,
281 &format!(
282 "Ollama API response received: {} bytes",
283 response_text.len()
284 )
285 )
286 );
287
288 let models_response: OllamaListModelsResponse = serde_json::from_str(&response_text)
289 .map_err(|e| {
290 let error_msg = format!("Failed to parse Ollama response: {}", e);
291 eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
292 AppError::LLMError(error_msg)
293 })?;
294
295 Ok(models_response.models)
296 }
297}
298
299#[async_trait]
300impl ApiClient for OllamaClient {
301 async fn complete(&self, messages: Vec<Message>, options: CompletionOptions) -> Result<String> {
302 let ollama_messages = self.convert_messages(messages);
303
304 let model_name = if self.model.is_empty() {
306 "qwen2.5-coder:14b".to_string() } else {
308 self.model.clone()
309 };
310
311 let request = OllamaRequest {
312 model: model_name,
313 messages: ollama_messages,
314 stream: false,
315 temperature: options.temperature,
316 top_p: options.top_p,
317 options: None,
318 format: if options.json_schema.is_some() {
319 Some("json".to_string())
320 } else {
321 None
322 },
323 tools: None,
324 };
325
326 let url = format!("{}/api/chat", self.api_base);
327
328 eprintln!(
329 "{}",
330 format_log_with_color(
331 LogLevel::Debug,
332 &format!("Sending request to Ollama API with model: {}", self.model)
333 )
334 );
335
336 let response = self
337 .client
338 .post(&url)
339 .json(&request)
340 .send()
341 .await
342 .map_err(|e| {
343 if e.is_connect() {
344 AppError::NetworkError(
346 "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
347 .to_string(),
348 )
349 } else {
350 AppError::NetworkError(format!("Failed to send request to Ollama: {}", e))
351 }
352 })?;
353
354 if !response.status().is_success() {
355 let status = response.status();
356 let error_text = response
357 .text()
358 .await
359 .unwrap_or_else(|_| "Unknown error".to_string());
360 return Err(AppError::NetworkError(format!(
361 "Ollama API error: {} - {}",
362 status, error_text
363 ))
364 .into());
365 }
366
367 let response_text = response.text().await.map_err(|e| {
369 let error_msg = format!("Failed to get response text: {}", e);
370 eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
371 AppError::NetworkError(error_msg)
372 })?;
373
374 eprintln!(
375 "{}",
376 format_log_with_color(
377 LogLevel::Debug,
378 &format!(
379 "Ollama API response received: {} bytes",
380 response_text.len()
381 )
382 )
383 );
384
385 let ollama_response = match serde_json::from_str::<OllamaResponse>(&response_text) {
387 Ok(resp) => resp,
388 Err(e) => {
389 eprintln!(
391 "{}",
392 format_log_with_color(
393 LogLevel::Warning,
394 &format!("Failed to parse standard Ollama response: {}, attempting alternate parsing", e)
395 )
396 );
397
398 let json_value: Result<serde_json::Value, _> = serde_json::from_str(&response_text);
400 if let Ok(value) = json_value {
401 if let Some(message) = value.get("message") {
402 let role = message
403 .get("role")
404 .and_then(|r| r.as_str())
405 .unwrap_or("assistant")
406 .to_string();
407
408 let content = match message.get("content") {
410 Some(c) if c.is_string() => c.as_str().unwrap_or("").to_string(),
411 Some(c) => c.to_string(),
412 None => "".to_string(),
413 };
414
415 OllamaResponse {
417 model: value
418 .get("model")
419 .and_then(|m| m.as_str())
420 .unwrap_or("unknown")
421 .to_string(),
422 created_at: value
423 .get("created_at")
424 .and_then(|t| t.as_str())
425 .unwrap_or("")
426 .to_string(),
427 message: OllamaMessage {
428 role,
429 content,
430 tool_calls: None,
431 tool_call_id: None,
432 },
433 done: value.get("done").and_then(|d| d.as_bool()).unwrap_or(true),
434 total_duration: None,
435 load_duration: None,
436 prompt_eval_duration: None,
437 eval_count: None,
438 eval_duration: None,
439 }
440 } else {
441 return Err(AppError::Other(format!(
442 "Failed to parse Ollama response: {}",
443 e
444 ))
445 .into());
446 }
447 } else {
448 return Err(
449 AppError::Other(format!("Failed to parse Ollama response: {}", e)).into(),
450 );
451 }
452 }
453 };
454
455 Ok(ollama_response.message.content)
456 }
457
458 async fn complete_with_tools(
459 &self,
460 messages: Vec<Message>,
461 options: CompletionOptions,
462 tool_results: Option<Vec<ToolResult>>,
463 ) -> Result<(String, Option<Vec<ToolCall>>)> {
464 if self.model.is_empty() {
466 return Err(anyhow::anyhow!(
467 "Model name is empty. Please select a valid Ollama model."
468 ));
469 }
470
471 let mut ollama_messages = self.convert_messages(messages);
473
474 if let Some(results) = tool_results {
476 for result in results {
477 ollama_messages.push(OllamaMessage {
478 role: "tool".to_string(),
479 content: result.output,
480 tool_calls: None,
481 tool_call_id: Some(result.tool_call_id),
482 });
483 }
484 }
485
486 let model_name = if self.model.is_empty() {
488 "qwen2.5-coder:14b".to_string() } else {
490 self.model.clone()
491 };
492
493 let mut request = OllamaRequest {
495 model: model_name,
496 messages: ollama_messages,
497 stream: false,
498 temperature: options.temperature,
499 top_p: options.top_p,
500 options: None,
501 format: if options.json_schema.is_some() {
502 Some("json".to_string())
503 } else {
504 None
505 },
506 tools: None,
507 };
508
509 if let Some(tools) = options.tools {
511 let converted_tools = self.convert_tool_definitions(tools);
512 request.tools = Some(converted_tools);
513 }
514
515 let url = format!("{}/api/chat", self.api_base);
516
517 eprintln!(
518 "{}",
519 format_log_with_color(
520 LogLevel::Debug,
521 &format!("Sending request to Ollama API with model: {}", self.model)
522 )
523 );
524
525 let response = self
526 .client
527 .post(&url)
528 .json(&request)
529 .send()
530 .await
531 .map_err(|e| {
532 if e.is_connect() {
533 AppError::NetworkError(
535 "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
536 .to_string(),
537 )
538 } else {
539 AppError::NetworkError(format!("Failed to send request to Ollama: {}", e))
540 }
541 })?;
542
543 if !response.status().is_success() {
544 let status = response.status();
545 let error_text = response
546 .text()
547 .await
548 .unwrap_or_else(|_| "Unknown error".to_string());
549 return Err(AppError::NetworkError(format!(
550 "Ollama API error: {} - {}",
551 status, error_text
552 ))
553 .into());
554 }
555
556 let response_text = response.text().await.map_err(|e| {
558 let error_msg = format!("Failed to get response text: {}", e);
559 eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
560 AppError::NetworkError(error_msg)
561 })?;
562
563 eprintln!(
564 "{}",
565 format_log_with_color(
566 LogLevel::Debug,
567 &format!(
568 "Ollama API response received: {} bytes",
569 response_text.len()
570 )
571 )
572 );
573
574 let ollama_response = match serde_json::from_str::<OllamaResponse>(&response_text) {
576 Ok(resp) => resp,
577 Err(e) => {
578 eprintln!(
580 "{}",
581 format_log_with_color(
582 LogLevel::Warning,
583 &format!("Failed to parse standard Ollama response: {}, attempting alternate parsing", e)
584 )
585 );
586
587 let json_value: Result<serde_json::Value, _> = serde_json::from_str(&response_text);
589 if let Ok(value) = json_value {
590 if let Some(message) = value.get("message") {
591 let role = message
592 .get("role")
593 .and_then(|r| r.as_str())
594 .unwrap_or("assistant")
595 .to_string();
596
597 let content = match message.get("content") {
599 Some(c) if c.is_string() => c.as_str().unwrap_or("").to_string(),
600 Some(c) => c.to_string(),
601 None => "".to_string(),
602 };
603
604 OllamaResponse {
606 model: value
607 .get("model")
608 .and_then(|m| m.as_str())
609 .unwrap_or("unknown")
610 .to_string(),
611 created_at: value
612 .get("created_at")
613 .and_then(|t| t.as_str())
614 .unwrap_or("")
615 .to_string(),
616 message: OllamaMessage {
617 role,
618 content,
619 tool_calls: None,
620 tool_call_id: None,
621 },
622 done: value.get("done").and_then(|d| d.as_bool()).unwrap_or(true),
623 total_duration: None,
624 load_duration: None,
625 prompt_eval_duration: None,
626 eval_count: None,
627 eval_duration: None,
628 }
629 } else {
630 return Err(AppError::Other(format!(
631 "Failed to parse Ollama response: {}",
632 e
633 ))
634 .into());
635 }
636 } else {
637 return Err(
638 AppError::Other(format!("Failed to parse Ollama response: {}", e)).into(),
639 );
640 }
641 }
642 };
643
644 let content = ollama_response.message.content.clone();
646
647 if let Some(ollama_tool_calls) = ollama_response.message.tool_calls {
649 if !ollama_tool_calls.is_empty() {
650 let tool_calls = ollama_tool_calls
651 .iter()
652 .map(|call| {
653 let arguments_result =
655 serde_json::from_str::<Value>(&call.function.arguments);
656 let arguments = match arguments_result {
657 Ok(args) => args,
658 Err(_) => json!({}),
659 };
660
661 let id = if call.id.is_empty() {
663 format!("ollama-tool-{}", rand::random::<u64>())
664 } else {
665 call.id.clone()
666 };
667
668 ToolCall {
670 id: Some(id),
671 name: call.function.name.clone(),
672 arguments,
673 }
674 })
675 .collect::<Vec<_>>();
676
677 return Ok((String::new(), Some(tool_calls)));
678 }
679 }
680
681 let content_str = content.trim();
685 if content_str.starts_with('{') && content_str.ends_with('}') {
686 if let Ok(json_value) = serde_json::from_str::<Value>(content_str) {
687 if let Some(tool_calls) = json_value.get("tool_calls").and_then(|tc| tc.as_array())
689 {
690 if !tool_calls.is_empty() {
691 let calls = tool_calls
692 .iter()
693 .filter_map(|call| {
694 let id = call.get("id").and_then(|id| id.as_str()).unwrap_or("");
695 let function = call.get("function")?;
696 let name = function.get("name")?.as_str()?;
697 let arguments = function.get("arguments")?;
698
699 let args_str = arguments.as_str().unwrap_or("{}");
700 let args: Value =
701 serde_json::from_str(args_str).unwrap_or(json!({}));
702
703 Some(ToolCall {
704 id: Some(id.to_string()),
705 name: name.to_string(),
706 arguments: args,
707 })
708 })
709 .collect::<Vec<_>>();
710
711 if !calls.is_empty() {
712 return Ok((String::new(), Some(calls)));
713 }
714 }
715 }
716
717 if let (Some(tool_name), Some(tool_args)) = (
719 json_value.get("tool").and_then(|t| t.as_str()),
720 json_value.get("args"),
721 ) {
722 let tool_call = ToolCall {
723 id: Some(format!("ollama-tool-{}", rand::random::<u64>())),
724 name: tool_name.to_string(),
725 arguments: tool_args.clone(),
726 };
727
728 return Ok((String::new(), Some(vec![tool_call])));
729 }
730 }
731 }
732
733 Ok((content, None))
735 }
736}