1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6 get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
7 RequestLog,
8};
9use crate::config::declarative_providers::DeclarativeProviderConfig;
10use crate::config::AsterMode;
11use crate::conversation::message::Message;
12use crate::conversation::Conversation;
13
14use crate::model::ModelConfig;
15use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
16use crate::utils::safe_truncate;
17use anyhow::Result;
18use async_trait::async_trait;
19use regex::Regex;
20use rmcp::model::Tool;
21use serde_json::Value;
22use std::time::Duration;
23use url::Url;
24
/// Host used when `OLLAMA_HOST` is not configured.
pub const OLLAMA_HOST: &str = "localhost";
/// Request timeout in seconds (ten minutes) used when `OLLAMA_TIMEOUT` is not configured.
pub const OLLAMA_TIMEOUT: u64 = 10 * 60;
/// Default Ollama server port; the constructors add it to configured addresses that omit a
/// port (see `from_env` / `from_custom_config` for the exact conditions).
pub const OLLAMA_DEFAULT_PORT: u16 = 11434;

/// Default model advertised in the provider metadata.
pub const OLLAMA_DEFAULT_MODEL: &str = "qwen3";
/// Models advertised in the provider metadata, starting with the default model.
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
    OLLAMA_DEFAULT_MODEL,
    "qwen3-coder:30b",
    "qwen3-coder:480b-cloud",
];
/// Documentation link advertised in the provider metadata (the Ollama model library).
pub const OLLAMA_DOC_URL: &str = "https://ollama.com/library";
35
/// Chat provider that talks to an Ollama server through its OpenAI-compatible
/// `v1/chat/completions` endpoint.
#[derive(serde::Serialize)]
pub struct OllamaProvider {
    /// HTTP client bound to the server's base URL; excluded from serialization.
    #[serde(skip)]
    api_client: ApiClient,
    /// Model configuration used by `stream` and returned from `get_model_config`.
    model: ModelConfig,
    /// Value reported by `supports_streaming()`: always `true` when built by `from_env`,
    /// taken from the declarative config (default `true`) in `from_custom_config`.
    supports_streaming: bool,
    /// Name reported by `get_name`: the metadata name in `from_env`, the declarative
    /// config's name in `from_custom_config`.
    name: String,
}
44
45impl OllamaProvider {
46 pub async fn from_env(model: ModelConfig) -> Result<Self> {
47 let config = crate::config::Config::global();
48 let host: String = config
49 .get_param("OLLAMA_HOST")
50 .unwrap_or_else(|_| OLLAMA_HOST.to_string());
51
52 let timeout: Duration =
53 Duration::from_secs(config.get_param("OLLAMA_TIMEOUT").unwrap_or(OLLAMA_TIMEOUT));
54
55 let base = if host.starts_with("http://") || host.starts_with("https://") {
56 host.clone()
57 } else {
58 format!("http://{}", host)
59 };
60
61 let mut base_url =
62 Url::parse(&base).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?;
63
64 let explicit_port = host.contains(':');
65 let is_localhost = host == "localhost" || host == "127.0.0.1" || host == "::1";
66
67 if base_url.port().is_none() && !explicit_port && !host.starts_with("http") && is_localhost
68 {
69 base_url
70 .set_port(Some(OLLAMA_DEFAULT_PORT))
71 .map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
72 }
73
74 let auth = AuthMethod::Custom(Box::new(NoAuth));
75 let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
76
77 Ok(Self {
78 api_client,
79 model,
80 supports_streaming: true,
81 name: Self::metadata().name,
82 })
83 }
84
85 pub fn from_custom_config(
86 model: ModelConfig,
87 config: DeclarativeProviderConfig,
88 ) -> Result<Self> {
89 let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT));
90
91 let base =
92 if config.base_url.starts_with("http://") || config.base_url.starts_with("https://") {
93 config.base_url.clone()
94 } else {
95 format!("http://{}", config.base_url)
96 };
97
98 let mut base_url = Url::parse(&base)
99 .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?;
100
101 let explicit_default_port =
102 config.base_url.ends_with(":80") || config.base_url.ends_with(":443");
103 let is_https = base_url.scheme() == "https";
104
105 if base_url.port().is_none() && !explicit_default_port && !is_https {
106 base_url
107 .set_port(Some(OLLAMA_DEFAULT_PORT))
108 .map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
109 }
110
111 let auth = AuthMethod::Custom(Box::new(NoAuth));
112 let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
113
114 Ok(Self {
115 api_client,
116 model,
117 supports_streaming: config.supports_streaming.unwrap_or(true),
118 name: config.name.clone(),
119 })
120 }
121
122 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
123 let response = self
124 .api_client
125 .response_post("v1/chat/completions", payload)
126 .await?;
127 handle_response_openai_compat(response).await
128 }
129}
130
131struct NoAuth;
132
133#[async_trait]
134impl super::api_client::AuthProvider for NoAuth {
135 async fn get_auth_header(&self) -> Result<(String, String)> {
136 Ok(("X-No-Auth".to_string(), "true".to_string()))
137 }
138}
139
140#[async_trait]
141impl Provider for OllamaProvider {
142 fn metadata() -> ProviderMetadata {
143 ProviderMetadata::new(
144 "ollama",
145 "Ollama",
146 "Local open source models",
147 OLLAMA_DEFAULT_MODEL,
148 OLLAMA_KNOWN_MODELS.to_vec(),
149 OLLAMA_DOC_URL,
150 vec![
151 ConfigKey::new("OLLAMA_HOST", true, false, Some(OLLAMA_HOST)),
152 ConfigKey::new(
153 "OLLAMA_TIMEOUT",
154 false,
155 false,
156 Some(&(OLLAMA_TIMEOUT.to_string())),
157 ),
158 ],
159 )
160 }
161
162 fn get_name(&self) -> &str {
163 &self.name
164 }
165
166 fn get_model_config(&self) -> ModelConfig {
167 self.model.clone()
168 }
169
170 #[tracing::instrument(
171 skip(self, model_config, system, messages, tools),
172 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
173 )]
174 async fn complete_with_model(
175 &self,
176 model_config: &ModelConfig,
177 system: &str,
178 messages: &[Message],
179 tools: &[Tool],
180 ) -> Result<(Message, ProviderUsage), ProviderError> {
181 let config = crate::config::Config::global();
182 let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
183 let filtered_tools = if aster_mode == AsterMode::Chat {
184 &[]
185 } else {
186 tools
187 };
188
189 let payload = create_request(
190 model_config,
191 system,
192 messages,
193 filtered_tools,
194 &super::utils::ImageFormat::OpenAi,
195 false,
196 )?;
197
198 let mut log = RequestLog::start(model_config, &payload)?;
199 let response = self
200 .with_retry(|| async {
201 let payload_clone = payload.clone();
202 self.post(&payload_clone).await
203 })
204 .await
205 .inspect_err(|e| {
206 let _ = log.error(e);
207 })?;
208
209 let message = response_to_message(&response)?;
210
211 let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
212 tracing::debug!("Failed to get usage data");
213 Usage::default()
214 });
215 let response_model = get_model(&response);
216 log.write(&response, Some(&usage))?;
217 Ok((message, ProviderUsage::new(response_model, usage)))
218 }
219
220 async fn generate_session_name(
221 &self,
222 messages: &Conversation,
223 ) -> Result<String, ProviderError> {
224 let context = self.get_initial_user_messages(messages);
225 let message = Message::user().with_text(self.create_session_name_prompt(&context));
226 let result = self
227 .complete(
228 "You are a title generator. Output only the requested title of 4 words or less, with no additional text, reasoning, or explanations.",
229 &[message],
230 &[],
231 )
232 .await?;
233
234 let mut description = result.0.as_concat_text();
235 description = Self::filter_reasoning_tokens(&description);
236
237 Ok(safe_truncate(&description, 100))
238 }
239
240 fn supports_streaming(&self) -> bool {
241 self.supports_streaming
242 }
243
244 async fn stream(
245 &self,
246 system: &str,
247 messages: &[Message],
248 tools: &[Tool],
249 ) -> Result<MessageStream, ProviderError> {
250 let config = crate::config::Config::global();
251 let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
252 let filtered_tools = if aster_mode == AsterMode::Chat {
253 &[]
254 } else {
255 tools
256 };
257
258 let payload = create_request(
259 &self.model,
260 system,
261 messages,
262 filtered_tools,
263 &super::utils::ImageFormat::OpenAi,
264 true,
265 )?;
266 let mut log = RequestLog::start(&self.model, &payload)?;
267
268 let response = self
269 .with_retry(|| async {
270 let resp = self
271 .api_client
272 .response_post("v1/chat/completions", &payload)
273 .await?;
274 handle_status_openai_compat(resp).await
275 })
276 .await
277 .inspect_err(|e| {
278 let _ = log.error(e);
279 })?;
280 stream_openai_compat(response, log)
281 }
282
283 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
284 let response = self
285 .api_client
286 .response_get("api/tags")
287 .await
288 .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?;
289
290 if !response.status().is_success() {
291 return Err(ProviderError::RequestFailed(format!(
292 "Failed to fetch models: HTTP {}",
293 response.status()
294 )));
295 }
296
297 let json_response = response.json::<Value>().await.map_err(|e| {
298 ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
299 })?;
300
301 let models = json_response
302 .get("models")
303 .and_then(|m| m.as_array())
304 .ok_or_else(|| {
305 ProviderError::RequestFailed("No models array in response".to_string())
306 })?;
307
308 let mut model_names: Vec<String> = models
309 .iter()
310 .filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
311 .collect();
312
313 model_names.sort();
314
315 Ok(Some(model_names))
316 }
317}
318
319impl OllamaProvider {
320 fn filter_reasoning_tokens(text: &str) -> String {
321 let mut filtered = text.to_string();
322
323 let reasoning_patterns = [
324 r"<think>.*?</think>",
325 r"<thinking>.*?</thinking>",
326 r"Let me think.*?\n",
327 r"I need to.*?\n",
328 r"First, I.*?\n",
329 r"Okay, .*?\n",
330 r"So, .*?\n",
331 r"Well, .*?\n",
332 r"Hmm, .*?\n",
333 r"Actually, .*?\n",
334 r"Based on.*?I think",
335 r"Looking at.*?I would say",
336 ];
337
338 for pattern in reasoning_patterns {
339 if let Ok(re) = Regex::new(pattern) {
340 filtered = re.replace_all(&filtered, "").to_string();
341 }
342 }
343 filtered = filtered
344 .replace("<think>", "")
345 .replace("</think>", "")
346 .replace("<thinking>", "")
347 .replace("</thinking>", "");
348 filtered = filtered
349 .lines()
350 .map(|line| line.trim())
351 .filter(|line| !line.is_empty())
352 .collect::<Vec<_>>()
353 .join(" ");
354
355 filtered
356 }
357}