Skip to main content

aster/providers/
ollama.rs

1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
7    RequestLog,
8};
9use crate::config::declarative_providers::DeclarativeProviderConfig;
10use crate::config::AsterMode;
11use crate::conversation::message::Message;
12use crate::conversation::Conversation;
13
14use crate::model::ModelConfig;
15use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
16use crate::utils::safe_truncate;
17use anyhow::Result;
18use async_trait::async_trait;
19use regex::Regex;
20use rmcp::model::Tool;
21use serde_json::Value;
22use std::time::Duration;
23use url::Url;
24
/// Host used when `OLLAMA_HOST` is not present in the configuration.
pub const OLLAMA_HOST: &str = "localhost";
/// Default request timeout in seconds (ten minutes); fed to `Duration::from_secs`.
pub const OLLAMA_TIMEOUT: u64 = 10 * 60;
/// Port applied when the configured host does not carry one of its own.
pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
/// Model advertised as the provider default.
pub const OLLAMA_DEFAULT_MODEL: &str = "qwen3";
/// Models advertised in the provider metadata; the default model comes first.
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
    OLLAMA_DEFAULT_MODEL,
    "qwen3-coder:30b",
    "qwen3-coder:480b-cloud",
];
/// Documentation link exposed through the provider metadata.
pub const OLLAMA_DOC_URL: &str = "https://ollama.com/library";
35
/// Provider backed by an Ollama server, reached through `v1/chat/completions`.
///
/// The type derives `Serialize`; the HTTP client is skipped during serialization.
#[derive(serde::Serialize)]
pub struct OllamaProvider {
    /// HTTP client bound to the resolved base URL and timeout (not serialized).
    #[serde(skip)]
    api_client: ApiClient,
    /// Model configuration returned by `get_model_config` and used by `stream`.
    model: ModelConfig,
    /// Value reported by `supports_streaming`.
    supports_streaming: bool,
    /// Provider name reported by `get_name`.
    name: String,
}
44
45impl OllamaProvider {
46    pub async fn from_env(model: ModelConfig) -> Result<Self> {
47        let config = crate::config::Config::global();
48        let host: String = config
49            .get_param("OLLAMA_HOST")
50            .unwrap_or_else(|_| OLLAMA_HOST.to_string());
51
52        let timeout: Duration =
53            Duration::from_secs(config.get_param("OLLAMA_TIMEOUT").unwrap_or(OLLAMA_TIMEOUT));
54
55        let base = if host.starts_with("http://") || host.starts_with("https://") {
56            host.clone()
57        } else {
58            format!("http://{}", host)
59        };
60
61        let mut base_url =
62            Url::parse(&base).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?;
63
64        let explicit_port = host.contains(':');
65        let is_localhost = host == "localhost" || host == "127.0.0.1" || host == "::1";
66
67        if base_url.port().is_none() && !explicit_port && !host.starts_with("http") && is_localhost
68        {
69            base_url
70                .set_port(Some(OLLAMA_DEFAULT_PORT))
71                .map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
72        }
73
74        let auth = AuthMethod::Custom(Box::new(NoAuth));
75        let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
76
77        Ok(Self {
78            api_client,
79            model,
80            supports_streaming: true,
81            name: Self::metadata().name,
82        })
83    }
84
85    pub fn from_custom_config(
86        model: ModelConfig,
87        config: DeclarativeProviderConfig,
88    ) -> Result<Self> {
89        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT));
90
91        let base =
92            if config.base_url.starts_with("http://") || config.base_url.starts_with("https://") {
93                config.base_url.clone()
94            } else {
95                format!("http://{}", config.base_url)
96            };
97
98        let mut base_url = Url::parse(&base)
99            .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?;
100
101        let explicit_default_port =
102            config.base_url.ends_with(":80") || config.base_url.ends_with(":443");
103        let is_https = base_url.scheme() == "https";
104
105        if base_url.port().is_none() && !explicit_default_port && !is_https {
106            base_url
107                .set_port(Some(OLLAMA_DEFAULT_PORT))
108                .map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
109        }
110
111        let auth = AuthMethod::Custom(Box::new(NoAuth));
112        let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
113
114        Ok(Self {
115            api_client,
116            model,
117            supports_streaming: config.supports_streaming.unwrap_or(true),
118            name: config.name.clone(),
119        })
120    }
121
122    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
123        let response = self
124            .api_client
125            .response_post("v1/chat/completions", payload)
126            .await?;
127        handle_response_openai_compat(response).await
128    }
129}
130
131struct NoAuth;
132
133#[async_trait]
134impl super::api_client::AuthProvider for NoAuth {
135    async fn get_auth_header(&self) -> Result<(String, String)> {
136        Ok(("X-No-Auth".to_string(), "true".to_string()))
137    }
138}
139
140#[async_trait]
141impl Provider for OllamaProvider {
142    fn metadata() -> ProviderMetadata {
143        ProviderMetadata::new(
144            "ollama",
145            "Ollama",
146            "Local open source models",
147            OLLAMA_DEFAULT_MODEL,
148            OLLAMA_KNOWN_MODELS.to_vec(),
149            OLLAMA_DOC_URL,
150            vec![
151                ConfigKey::new("OLLAMA_HOST", true, false, Some(OLLAMA_HOST)),
152                ConfigKey::new(
153                    "OLLAMA_TIMEOUT",
154                    false,
155                    false,
156                    Some(&(OLLAMA_TIMEOUT.to_string())),
157                ),
158            ],
159        )
160    }
161
162    fn get_name(&self) -> &str {
163        &self.name
164    }
165
166    fn get_model_config(&self) -> ModelConfig {
167        self.model.clone()
168    }
169
170    #[tracing::instrument(
171        skip(self, model_config, system, messages, tools),
172        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
173    )]
174    async fn complete_with_model(
175        &self,
176        model_config: &ModelConfig,
177        system: &str,
178        messages: &[Message],
179        tools: &[Tool],
180    ) -> Result<(Message, ProviderUsage), ProviderError> {
181        let config = crate::config::Config::global();
182        let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
183        let filtered_tools = if aster_mode == AsterMode::Chat {
184            &[]
185        } else {
186            tools
187        };
188
189        let payload = create_request(
190            model_config,
191            system,
192            messages,
193            filtered_tools,
194            &super::utils::ImageFormat::OpenAi,
195            false,
196        )?;
197
198        let mut log = RequestLog::start(model_config, &payload)?;
199        let response = self
200            .with_retry(|| async {
201                let payload_clone = payload.clone();
202                self.post(&payload_clone).await
203            })
204            .await
205            .inspect_err(|e| {
206                let _ = log.error(e);
207            })?;
208
209        let message = response_to_message(&response)?;
210
211        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
212            tracing::debug!("Failed to get usage data");
213            Usage::default()
214        });
215        let response_model = get_model(&response);
216        log.write(&response, Some(&usage))?;
217        Ok((message, ProviderUsage::new(response_model, usage)))
218    }
219
220    async fn generate_session_name(
221        &self,
222        messages: &Conversation,
223    ) -> Result<String, ProviderError> {
224        let context = self.get_initial_user_messages(messages);
225        let message = Message::user().with_text(self.create_session_name_prompt(&context));
226        let result = self
227            .complete(
228                "You are a title generator. Output only the requested title of 4 words or less, with no additional text, reasoning, or explanations.",
229                &[message],
230                &[],
231            )
232            .await?;
233
234        let mut description = result.0.as_concat_text();
235        description = Self::filter_reasoning_tokens(&description);
236
237        Ok(safe_truncate(&description, 100))
238    }
239
240    fn supports_streaming(&self) -> bool {
241        self.supports_streaming
242    }
243
244    async fn stream(
245        &self,
246        system: &str,
247        messages: &[Message],
248        tools: &[Tool],
249    ) -> Result<MessageStream, ProviderError> {
250        let config = crate::config::Config::global();
251        let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
252        let filtered_tools = if aster_mode == AsterMode::Chat {
253            &[]
254        } else {
255            tools
256        };
257
258        let payload = create_request(
259            &self.model,
260            system,
261            messages,
262            filtered_tools,
263            &super::utils::ImageFormat::OpenAi,
264            true,
265        )?;
266        let mut log = RequestLog::start(&self.model, &payload)?;
267
268        let response = self
269            .with_retry(|| async {
270                let resp = self
271                    .api_client
272                    .response_post("v1/chat/completions", &payload)
273                    .await?;
274                handle_status_openai_compat(resp).await
275            })
276            .await
277            .inspect_err(|e| {
278                let _ = log.error(e);
279            })?;
280        stream_openai_compat(response, log)
281    }
282
283    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
284        let response = self
285            .api_client
286            .response_get("api/tags")
287            .await
288            .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?;
289
290        if !response.status().is_success() {
291            return Err(ProviderError::RequestFailed(format!(
292                "Failed to fetch models: HTTP {}",
293                response.status()
294            )));
295        }
296
297        let json_response = response.json::<Value>().await.map_err(|e| {
298            ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
299        })?;
300
301        let models = json_response
302            .get("models")
303            .and_then(|m| m.as_array())
304            .ok_or_else(|| {
305                ProviderError::RequestFailed("No models array in response".to_string())
306            })?;
307
308        let mut model_names: Vec<String> = models
309            .iter()
310            .filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
311            .collect();
312
313        model_names.sort();
314
315        Ok(Some(model_names))
316    }
317}
318
319impl OllamaProvider {
320    fn filter_reasoning_tokens(text: &str) -> String {
321        let mut filtered = text.to_string();
322
323        let reasoning_patterns = [
324            r"<think>.*?</think>",
325            r"<thinking>.*?</thinking>",
326            r"Let me think.*?\n",
327            r"I need to.*?\n",
328            r"First, I.*?\n",
329            r"Okay, .*?\n",
330            r"So, .*?\n",
331            r"Well, .*?\n",
332            r"Hmm, .*?\n",
333            r"Actually, .*?\n",
334            r"Based on.*?I think",
335            r"Looking at.*?I would say",
336        ];
337
338        for pattern in reasoning_patterns {
339            if let Ok(re) = Regex::new(pattern) {
340                filtered = re.replace_all(&filtered, "").to_string();
341            }
342        }
343        filtered = filtered
344            .replace("<think>", "")
345            .replace("</think>", "")
346            .replace("<thinking>", "")
347            .replace("</thinking>", "");
348        filtered = filtered
349            .lines()
350            .map(|line| line.trim())
351            .filter(|line| !line.is_empty())
352            .collect::<Vec<_>>()
353            .join(" ");
354
355        filtered
356    }
357}