// oli_server/apis/ollama.rs

1use crate::apis::api_client::{
2    ApiClient, CompletionOptions, Message, ToolCall, ToolDefinition, ToolResult,
3};
4use crate::app::logger::{format_log_with_color, LogLevel};
5use crate::errors::AppError;
6use anyhow::Result;
7use async_trait::async_trait;
8use rand;
9use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
10use reqwest::Client as ReqwestClient;
11use serde::{Deserialize, Serialize};
12use serde_json::{self, json, Value};
13use std::time::Duration;
14
15// Ollama API Types
/// A single chat message in Ollama's `/api/chat` wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
    /// Author of the message (e.g. "assistant", "tool").
    role: String,
    // Ollama sometimes returns `content` as a JSON object instead of a plain
    // string; the custom serde module normalizes both cases to `String`.
    #[serde(default)]
    #[serde(with = "content_string_or_object")]
    content: String,
    /// Tool invocations requested by the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OllamaToolCall>>,
    /// For role "tool": id of the tool call this message is a result for.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
27
28// Custom serializer to handle content that might be a string or a complex object
29mod content_string_or_object {
30    use serde::{self, Deserialize, Deserializer, Serializer};
31    use serde_json::Value;
32
33    pub fn serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: Serializer,
36    {
37        serializer.serialize_str(content)
38    }
39
40    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        let value = Value::deserialize(deserializer)?;
45        match value {
46            Value::String(s) => Ok(s),
47            _ => Ok(value.to_string()),
48        }
49    }
50}
51
/// A tool invocation emitted by the model in a response message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
    // May be absent in Ollama responses; defaults to "". A random id is
    // generated later when mapping to the generic `ToolCall` type.
    #[serde(default)]
    id: String,
    function: OllamaFunctionCall,
    /// Call type when present (e.g. "function"); optional on the wire.
    #[serde(rename = "type")]
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    tool_type: Option<String>,
}
62
/// Name and arguments of a function invoked via a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaFunctionCall {
    name: String,
    // Arguments may arrive as a JSON object or a string; the custom serde
    // module normalizes both to a `String` holding JSON text.
    #[serde(with = "arguments_as_string_or_object")]
    arguments: String,
}
69
70// Custom serde module for function arguments
71mod arguments_as_string_or_object {
72    use serde::{self, Deserialize, Deserializer, Serializer};
73    use serde_json::Value;
74
75    pub fn serialize<S>(arguments: &str, serializer: S) -> Result<S::Ok, S::Error>
76    where
77        S: Serializer,
78    {
79        serializer.serialize_str(arguments)
80    }
81
82    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
83    where
84        D: Deserializer<'de>,
85    {
86        let value = Value::deserialize(deserializer)?;
87
88        match value {
89            Value::String(s) => Ok(s),
90            _ => {
91                // If it's not a string, convert the JSON to a string
92                let json_str = serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string());
93                Ok(json_str)
94            }
95        }
96    }
97}
98
/// A tool definition advertised to the model in a chat request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaTool {
    /// Always "function" for tools built by `convert_tool_definitions`.
    #[serde(rename = "type")]
    tool_type: String,
    function: OllamaFunction,
}
105
/// Callable-function metadata included in a tool definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaFunction {
    name: String,
    description: String,
    // Arbitrary JSON describing the function's parameters (copied verbatim
    // from the generic `ToolDefinition`; presumably a JSON schema).
    parameters: Value,
}
112
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    /// Whether to stream the response; this client always sends `false`.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Extra model options; currently always `None` in this client.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<Value>,
    /// Set to "json" when the caller supplies a JSON schema for the output.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<String>,
    /// Tool definitions the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
}
129
/// A single (non-streaming) response from Ollama's `/api/chat` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaResponse {
    model: String,
    created_at: String,
    /// The model's reply, possibly carrying tool calls.
    message: OllamaMessage,
    /// True when generation has finished (defaulted to true in the manual
    /// fallback parser, since streaming is disabled).
    done: bool,
    // Optional timing/usage statistics reported by Ollama; not interpreted
    // by this client.
    #[serde(skip_serializing_if = "Option::is_none")]
    total_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    load_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_eval_duration: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    eval_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    eval_duration: Option<u64>,
}
147
/// Response envelope for `/api/tags` (list of locally installed models).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaListModelsResponse {
    models: Vec<OllamaModelInfo>,
}
152
/// Metadata for an installed Ollama model, as returned by `/api/tags`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModelInfo {
    pub name: String,
    /// Last-modified timestamp string as reported by the server.
    pub modified_at: String,
    // Model size (presumably bytes — TODO confirm against Ollama API docs);
    // not interpreted by this client.
    pub size: u64,
    pub digest: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<OllamaModelDetails>,
}
162
/// Optional detail fields attached to a model listing entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModelDetails {
    pub parameter_size: Option<String>,
    pub quantization_level: Option<String>,
    pub format: Option<String>,
    pub families: Option<Vec<String>>,
    pub description: Option<String>,
}
171
/// HTTP client for an Ollama server's chat and model-listing APIs.
pub struct OllamaClient {
    // Reqwest client preconfigured with a JSON content-type header and a
    // 5-minute request timeout (see `with_base_url`).
    client: ReqwestClient,
    // Model name, e.g. "qwen2.5-coder:14b".
    model: String,
    // Server base URL, e.g. "http://localhost:11434".
    api_base: String,
}
177
178impl OllamaClient {
179    pub fn new(model: Option<String>) -> Result<Self> {
180        // Default to qwen2.5-coder:14b model if None or empty string
181        let model_name = match model {
182            Some(m) if !m.trim().is_empty() => m,
183            _ => "qwen2.5-coder:14b".to_string(),
184        };
185
186        Self::with_base_url(model_name, "http://localhost:11434".to_string())
187    }
188
189    pub fn with_base_url(model: String, api_base: String) -> Result<Self> {
190        let mut headers = HeaderMap::new();
191        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
192
193        let client = ReqwestClient::builder()
194            .default_headers(headers)
195            .timeout(Duration::from_secs(300)) // 5 minutes timeout for operations
196            .build()?;
197
198        Ok(Self {
199            client,
200            model,
201            api_base,
202        })
203    }
204
205    fn convert_messages(&self, messages: Vec<Message>) -> Vec<OllamaMessage> {
206        messages
207            .into_iter()
208            .map(|msg| {
209                // Convert standard messages to Ollama format
210                OllamaMessage {
211                    role: msg.role,
212                    content: msg.content,
213                    tool_calls: None,
214                    tool_call_id: None,
215                }
216            })
217            .collect()
218    }
219
220    fn convert_tool_definitions(&self, tools: Vec<ToolDefinition>) -> Vec<OllamaTool> {
221        tools
222            .into_iter()
223            .map(|tool| OllamaTool {
224                tool_type: "function".to_string(),
225                function: OllamaFunction {
226                    name: tool.name,
227                    description: tool.description,
228                    parameters: tool.parameters,
229                },
230            })
231            .collect()
232    }
233
234    pub async fn list_models(&self) -> Result<Vec<OllamaModelInfo>> {
235        let url = format!("{}/api/tags", self.api_base);
236
237        eprintln!(
238            "{}",
239            format_log_with_color(
240                LogLevel::Debug,
241                &format!("Listing Ollama models from: {}", url)
242            )
243        );
244
245        let response = self.client.get(&url).send().await.map_err(|e| {
246            let error_msg = if e.is_connect() {
247                // Connection failed - likely Ollama is not running
248                "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
249                    .to_string()
250            } else {
251                format!("Failed to send request to Ollama: {}", e)
252            };
253            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
254            AppError::NetworkError(error_msg)
255        })?;
256
257        if !response.status().is_success() {
258            let status = response.status();
259            let error_text = response
260                .text()
261                .await
262                .unwrap_or_else(|_| "Unknown error".to_string());
263            return Err(AppError::NetworkError(format!(
264                "Ollama API error: {} - {}",
265                status, error_text
266            ))
267            .into());
268        }
269
270        // Parse response
271        let response_text = response.text().await.map_err(|e| {
272            let error_msg = format!("Failed to get response text: {}", e);
273            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
274            AppError::NetworkError(error_msg)
275        })?;
276
277        eprintln!(
278            "{}",
279            format_log_with_color(
280                LogLevel::Debug,
281                &format!(
282                    "Ollama API response received: {} bytes",
283                    response_text.len()
284                )
285            )
286        );
287
288        let models_response: OllamaListModelsResponse = serde_json::from_str(&response_text)
289            .map_err(|e| {
290                let error_msg = format!("Failed to parse Ollama response: {}", e);
291                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
292                AppError::LLMError(error_msg)
293            })?;
294
295        Ok(models_response.models)
296    }
297}
298
299#[async_trait]
300impl ApiClient for OllamaClient {
301    async fn complete(&self, messages: Vec<Message>, options: CompletionOptions) -> Result<String> {
302        let ollama_messages = self.convert_messages(messages);
303
304        // Make sure we have a valid model name
305        let model_name = if self.model.is_empty() {
306            "qwen2.5-coder:14b".to_string() // Fallback to the default model
307        } else {
308            self.model.clone()
309        };
310
311        let request = OllamaRequest {
312            model: model_name,
313            messages: ollama_messages,
314            stream: false,
315            temperature: options.temperature,
316            top_p: options.top_p,
317            options: None,
318            format: if options.json_schema.is_some() {
319                Some("json".to_string())
320            } else {
321                None
322            },
323            tools: None,
324        };
325
326        let url = format!("{}/api/chat", self.api_base);
327
328        eprintln!(
329            "{}",
330            format_log_with_color(
331                LogLevel::Debug,
332                &format!("Sending request to Ollama API with model: {}", self.model)
333            )
334        );
335
336        let response = self
337            .client
338            .post(&url)
339            .json(&request)
340            .send()
341            .await
342            .map_err(|e| {
343                if e.is_connect() {
344                    // Connection failed - likely Ollama is not running
345                    AppError::NetworkError(
346                        "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
347                            .to_string(),
348                    )
349                } else {
350                    AppError::NetworkError(format!("Failed to send request to Ollama: {}", e))
351                }
352            })?;
353
354        if !response.status().is_success() {
355            let status = response.status();
356            let error_text = response
357                .text()
358                .await
359                .unwrap_or_else(|_| "Unknown error".to_string());
360            return Err(AppError::NetworkError(format!(
361                "Ollama API error: {} - {}",
362                status, error_text
363            ))
364            .into());
365        }
366
367        // Parse response
368        let response_text = response.text().await.map_err(|e| {
369            let error_msg = format!("Failed to get response text: {}", e);
370            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
371            AppError::NetworkError(error_msg)
372        })?;
373
374        eprintln!(
375            "{}",
376            format_log_with_color(
377                LogLevel::Debug,
378                &format!(
379                    "Ollama API response received: {} bytes",
380                    response_text.len()
381                )
382            )
383        );
384
385        // Try to parse as a direct response
386        let ollama_response = match serde_json::from_str::<OllamaResponse>(&response_text) {
387            Ok(resp) => resp,
388            Err(e) => {
389                // Log errors when parsing Ollama API response
390                eprintln!(
391                    "{}",
392                    format_log_with_color(
393                        LogLevel::Warning,
394                        &format!("Failed to parse standard Ollama response: {}, attempting alternate parsing", e)
395                    )
396                );
397
398                // Try to parse as a generic JSON value to extract what we need
399                let json_value: Result<serde_json::Value, _> = serde_json::from_str(&response_text);
400                if let Ok(value) = json_value {
401                    if let Some(message) = value.get("message") {
402                        let role = message
403                            .get("role")
404                            .and_then(|r| r.as_str())
405                            .unwrap_or("assistant")
406                            .to_string();
407
408                        // Extract content, which might be a string or object
409                        let content = match message.get("content") {
410                            Some(c) if c.is_string() => c.as_str().unwrap_or("").to_string(),
411                            Some(c) => c.to_string(),
412                            None => "".to_string(),
413                        };
414
415                        // Construct a valid OllamaResponse with the extracted data
416                        OllamaResponse {
417                            model: value
418                                .get("model")
419                                .and_then(|m| m.as_str())
420                                .unwrap_or("unknown")
421                                .to_string(),
422                            created_at: value
423                                .get("created_at")
424                                .and_then(|t| t.as_str())
425                                .unwrap_or("")
426                                .to_string(),
427                            message: OllamaMessage {
428                                role,
429                                content,
430                                tool_calls: None,
431                                tool_call_id: None,
432                            },
433                            done: value.get("done").and_then(|d| d.as_bool()).unwrap_or(true),
434                            total_duration: None,
435                            load_duration: None,
436                            prompt_eval_duration: None,
437                            eval_count: None,
438                            eval_duration: None,
439                        }
440                    } else {
441                        return Err(AppError::Other(format!(
442                            "Failed to parse Ollama response: {}",
443                            e
444                        ))
445                        .into());
446                    }
447                } else {
448                    return Err(
449                        AppError::Other(format!("Failed to parse Ollama response: {}", e)).into(),
450                    );
451                }
452            }
453        };
454
455        Ok(ollama_response.message.content)
456    }
457
458    async fn complete_with_tools(
459        &self,
460        messages: Vec<Message>,
461        options: CompletionOptions,
462        tool_results: Option<Vec<ToolResult>>,
463    ) -> Result<(String, Option<Vec<ToolCall>>)> {
464        // Ensure we have a valid model
465        if self.model.is_empty() {
466            return Err(anyhow::anyhow!(
467                "Model name is empty. Please select a valid Ollama model."
468            ));
469        }
470
471        // Convert messages to Ollama format
472        let mut ollama_messages = self.convert_messages(messages);
473
474        // Add tool results if provided
475        if let Some(results) = tool_results {
476            for result in results {
477                ollama_messages.push(OllamaMessage {
478                    role: "tool".to_string(),
479                    content: result.output,
480                    tool_calls: None,
481                    tool_call_id: Some(result.tool_call_id),
482                });
483            }
484        }
485
486        // Make sure we have a valid model name
487        let model_name = if self.model.is_empty() {
488            "qwen2.5-coder:14b".to_string() // Fallback to the default model
489        } else {
490            self.model.clone()
491        };
492
493        // Create the request payload
494        let mut request = OllamaRequest {
495            model: model_name,
496            messages: ollama_messages,
497            stream: false,
498            temperature: options.temperature,
499            top_p: options.top_p,
500            options: None,
501            format: if options.json_schema.is_some() {
502                Some("json".to_string())
503            } else {
504                None
505            },
506            tools: None,
507        };
508
509        // Add tools if provided
510        if let Some(tools) = options.tools {
511            let converted_tools = self.convert_tool_definitions(tools);
512            request.tools = Some(converted_tools);
513        }
514
515        let url = format!("{}/api/chat", self.api_base);
516
517        eprintln!(
518            "{}",
519            format_log_with_color(
520                LogLevel::Debug,
521                &format!("Sending request to Ollama API with model: {}", self.model)
522            )
523        );
524
525        let response = self
526            .client
527            .post(&url)
528            .json(&request)
529            .send()
530            .await
531            .map_err(|e| {
532                if e.is_connect() {
533                    // Connection failed - likely Ollama is not running
534                    AppError::NetworkError(
535                        "Failed to connect to Ollama server. Make sure 'ollama serve' is running."
536                            .to_string(),
537                    )
538                } else {
539                    AppError::NetworkError(format!("Failed to send request to Ollama: {}", e))
540                }
541            })?;
542
543        if !response.status().is_success() {
544            let status = response.status();
545            let error_text = response
546                .text()
547                .await
548                .unwrap_or_else(|_| "Unknown error".to_string());
549            return Err(AppError::NetworkError(format!(
550                "Ollama API error: {} - {}",
551                status, error_text
552            ))
553            .into());
554        }
555
556        // Parse response
557        let response_text = response.text().await.map_err(|e| {
558            let error_msg = format!("Failed to get response text: {}", e);
559            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
560            AppError::NetworkError(error_msg)
561        })?;
562
563        eprintln!(
564            "{}",
565            format_log_with_color(
566                LogLevel::Debug,
567                &format!(
568                    "Ollama API response received: {} bytes",
569                    response_text.len()
570                )
571            )
572        );
573
574        // Try to parse as a direct response
575        let ollama_response = match serde_json::from_str::<OllamaResponse>(&response_text) {
576            Ok(resp) => resp,
577            Err(e) => {
578                // Log errors when parsing Ollama API response
579                eprintln!(
580                    "{}",
581                    format_log_with_color(
582                        LogLevel::Warning,
583                        &format!("Failed to parse standard Ollama response: {}, attempting alternate parsing", e)
584                    )
585                );
586
587                // Try to parse as a generic JSON value to extract what we need
588                let json_value: Result<serde_json::Value, _> = serde_json::from_str(&response_text);
589                if let Ok(value) = json_value {
590                    if let Some(message) = value.get("message") {
591                        let role = message
592                            .get("role")
593                            .and_then(|r| r.as_str())
594                            .unwrap_or("assistant")
595                            .to_string();
596
597                        // Extract content, which might be a string or object
598                        let content = match message.get("content") {
599                            Some(c) if c.is_string() => c.as_str().unwrap_or("").to_string(),
600                            Some(c) => c.to_string(),
601                            None => "".to_string(),
602                        };
603
604                        // Construct a valid OllamaResponse with the extracted data
605                        OllamaResponse {
606                            model: value
607                                .get("model")
608                                .and_then(|m| m.as_str())
609                                .unwrap_or("unknown")
610                                .to_string(),
611                            created_at: value
612                                .get("created_at")
613                                .and_then(|t| t.as_str())
614                                .unwrap_or("")
615                                .to_string(),
616                            message: OllamaMessage {
617                                role,
618                                content,
619                                tool_calls: None,
620                                tool_call_id: None,
621                            },
622                            done: value.get("done").and_then(|d| d.as_bool()).unwrap_or(true),
623                            total_duration: None,
624                            load_duration: None,
625                            prompt_eval_duration: None,
626                            eval_count: None,
627                            eval_duration: None,
628                        }
629                    } else {
630                        return Err(AppError::Other(format!(
631                            "Failed to parse Ollama response: {}",
632                            e
633                        ))
634                        .into());
635                    }
636                } else {
637                    return Err(
638                        AppError::Other(format!("Failed to parse Ollama response: {}", e)).into(),
639                    );
640                }
641            }
642        };
643
644        // Extract the content and tool calls from the response
645        let content = ollama_response.message.content.clone();
646
647        // Check for tool calls in the response
648        if let Some(ollama_tool_calls) = ollama_response.message.tool_calls {
649            if !ollama_tool_calls.is_empty() {
650                let tool_calls = ollama_tool_calls
651                    .iter()
652                    .map(|call| {
653                        // Parse arguments as JSON
654                        let arguments_result =
655                            serde_json::from_str::<Value>(&call.function.arguments);
656                        let arguments = match arguments_result {
657                            Ok(args) => args,
658                            Err(_) => json!({}),
659                        };
660
661                        // Generate a random ID if one wasn't provided
662                        let id = if call.id.is_empty() {
663                            format!("ollama-tool-{}", rand::random::<u64>())
664                        } else {
665                            call.id.clone()
666                        };
667
668                        // Create a tool call
669                        ToolCall {
670                            id: Some(id),
671                            name: call.function.name.clone(),
672                            arguments,
673                        }
674                    })
675                    .collect::<Vec<_>>();
676
677                return Ok((String::new(), Some(tool_calls)));
678            }
679        }
680
681        // Also try to check if the content itself contains a tool call in JSON format
682        // This handles cases where Ollama doesn't properly format its tool_calls field
683        // but still returns JSON in the content field that looks like a tool call
684        let content_str = content.trim();
685        if content_str.starts_with('{') && content_str.ends_with('}') {
686            if let Ok(json_value) = serde_json::from_str::<Value>(content_str) {
687                // Check for OpenAI style tool calls
688                if let Some(tool_calls) = json_value.get("tool_calls").and_then(|tc| tc.as_array())
689                {
690                    if !tool_calls.is_empty() {
691                        let calls = tool_calls
692                            .iter()
693                            .filter_map(|call| {
694                                let id = call.get("id").and_then(|id| id.as_str()).unwrap_or("");
695                                let function = call.get("function")?;
696                                let name = function.get("name")?.as_str()?;
697                                let arguments = function.get("arguments")?;
698
699                                let args_str = arguments.as_str().unwrap_or("{}");
700                                let args: Value =
701                                    serde_json::from_str(args_str).unwrap_or(json!({}));
702
703                                Some(ToolCall {
704                                    id: Some(id.to_string()),
705                                    name: name.to_string(),
706                                    arguments: args,
707                                })
708                            })
709                            .collect::<Vec<_>>();
710
711                        if !calls.is_empty() {
712                            return Ok((String::new(), Some(calls)));
713                        }
714                    }
715                }
716
717                // Check for the simpler/custom format that our old implementation expected
718                if let (Some(tool_name), Some(tool_args)) = (
719                    json_value.get("tool").and_then(|t| t.as_str()),
720                    json_value.get("args"),
721                ) {
722                    let tool_call = ToolCall {
723                        id: Some(format!("ollama-tool-{}", rand::random::<u64>())),
724                        name: tool_name.to_string(),
725                        arguments: tool_args.clone(),
726                    };
727
728                    return Ok((String::new(), Some(vec![tool_call])));
729                }
730            }
731        }
732
733        // If no tool calls were found, just return the content
734        Ok((content, None))
735    }
736}