llm_connector/providers/ollama.rs

//! Ollama provider implementation
//!
//! Ollama is a local LLM service with special model-management features, so it needs a custom Provider implementation.
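//!
//! A minimal usage sketch (illustrative only; it assumes an Ollama instance is reachable at the
//! default local address, and "llama2" stands in for any model name):
//!
//! ```rust,no_run
//! use llm_connector::providers::OllamaProvider;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let provider = OllamaProvider::new("http://localhost:11434")?;
//!
//! // Pull the model only if it is not already installed locally.
//! if !provider.model_exists("llama2").await? {
//!     provider.pull_model("llama2").await?;
//! }
//! # Ok(())
//! # }
//! ```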

use crate::core::{HttpClient, Provider};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse, Choice, Message, Role};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;

#[cfg(feature = "streaming")]
use crate::types::ChatStream;

/// Ollama service provider
///
/// Because Ollama has special model-management features, we use a custom Provider
/// implementation rather than the GenericProvider pattern.
#[derive(Clone, Debug)]
pub struct OllamaProvider {
    client: HttpClient,
    base_url: String,
}

impl OllamaProvider {
    /// Creates a new Ollama provider
    ///
    /// # Parameters
    /// - `base_url`: URL of the Ollama service (default: http://localhost:11434)
    ///
    /// # Example
    /// ```rust,no_run
    /// use llm_connector::providers::OllamaProvider;
    ///
    /// let provider = OllamaProvider::new("http://localhost:11434").unwrap();
    /// ```
    pub fn new(base_url: &str) -> Result<Self, LlmConnectorError> {
        // Content-Type is set automatically by the .json() call inside HttpClient::post()
        let client = HttpClient::new(base_url)?;

        Ok(Self {
            client,
            base_url: base_url.to_string(),
        })
    }

    /// Creates an Ollama provider with custom configuration
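    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative; the timeout and proxy values are arbitrary placeholders):
    ///
    /// ```rust,no_run
    /// use llm_connector::providers::OllamaProvider;
    ///
    /// // 120-second request timeout, no proxy.
    /// let provider = OllamaProvider::with_config("http://localhost:11434", Some(120), None).unwrap();
    /// ```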
    pub fn with_config(
        base_url: &str,
        timeout_secs: Option<u64>,
        proxy: Option<&str>,
    ) -> Result<Self, LlmConnectorError> {
        // Content-Type is set automatically by the .json() call inside HttpClient::post()
        let client = HttpClient::with_config(base_url, timeout_secs, proxy)?;

        Ok(Self {
            client,
            base_url: base_url.to_string(),
        })
    }

    /// Pulls a model
    ///
    /// # Parameters
    /// - `model_name`: name of the model to pull (e.g. "llama2", "codellama")
    ///
    /// # Example
    /// ```rust,no_run
    /// # use llm_connector::providers::OllamaProvider;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// provider.pull_model("llama2").await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn pull_model(&self, model_name: &str) -> Result<(), LlmConnectorError> {
        let request = OllamaPullRequest {
            name: model_name.to_string(),
            stream: Some(false),
        };

        let url = format!("{}/api/pull", self.base_url);
        let response = self.client.post(&url, &request).await?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to pull model: {}",
                text
            )));
        }

        Ok(())
    }

    /// Deletes a model
    ///
    /// # Parameters
    /// - `model_name`: name of the model to delete
    ///
    /// # Example
    /// ```rust,no_run
    /// # use llm_connector::providers::OllamaProvider;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// provider.delete_model("llama2").await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn delete_model(&self, model_name: &str) -> Result<(), LlmConnectorError> {
        let request = OllamaDeleteRequest {
            name: model_name.to_string(),
        };

        let url = format!("{}/api/delete", self.base_url);
        let response = self.client.post(&url, &request).await?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to delete model: {}",
                text
            )));
        }

        Ok(())
    }

    /// Gets information about a model
    ///
    /// # Parameters
    /// - `model_name`: name of the model
    ///
    /// # Returns
    /// Detailed information about the model
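    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative; "llama2" stands in for any locally installed model):
    ///
    /// ```rust,no_run
    /// # use llm_connector::providers::OllamaProvider;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// let info = provider.show_model("llama2").await?;
    /// println!("family: {}, parameters: {}", info.details.family, info.details.parameter_size);
    /// # Ok(())
    /// # }
    /// ```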
    pub async fn show_model(&self, model_name: &str) -> Result<OllamaModelInfo, LlmConnectorError> {
        let request = OllamaShowRequest {
            name: model_name.to_string(),
        };

        let url = format!("{}/api/show", self.base_url);
        let response = self.client.post(&url, &request).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to show model: {}",
                text
            )));
        }

        serde_json::from_str(&text).map_err(|e| {
            LlmConnectorError::ParseError(format!("Failed to parse model info: {}", e))
        })
    }

    /// Checks whether a model exists locally
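    ///
    /// A minimal sketch (illustrative; the model name is a placeholder):
    ///
    /// ```rust,no_run
    /// # use llm_connector::providers::OllamaProvider;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// if provider.model_exists("llama2").await? {
    ///     println!("llama2 is available locally");
    /// }
    /// # Ok(())
    /// # }
    /// ```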
    pub async fn model_exists(&self, model_name: &str) -> Result<bool, LlmConnectorError> {
        match self.show_model(model_name).await {
            Ok(_) => Ok(true),
            Err(LlmConnectorError::ApiError(_)) => Ok(false),
            Err(e) => Err(e),
        }
    }
}

#[async_trait]
impl Provider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

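    /// Lists locally installed models by querying `GET /api/tags`.
    ///
    /// A minimal sketch (illustrative; the import path of the `Provider` trait is an assumption
    /// about the crate's public module layout, so the example is marked `ignore`):
    ///
    /// ```rust,ignore
    /// use llm_connector::core::Provider; // assumed public path to the trait
    /// use llm_connector::providers::OllamaProvider;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// let models = provider.models().await?;
    /// println!("installed models: {:?}", models);
    /// # Ok(())
    /// # }
    /// ```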
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let url = format!("{}/api/tags", self.base_url);
        let response = self.client.get(&url).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to get models: {}",
                text
            )));
        }

        let models_response: OllamaModelsResponse = serde_json::from_str(&text)
            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse models: {}", e)))?;

        Ok(models_response.models.into_iter().map(|m| m.name).collect())
    }

    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|msg| OllamaMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "user".to_string(), // Ollama does not support a tool role
                    },
                    // Ollama expects plain-text message content
                    content: msg.content_as_text(),
                })
                .collect(),
            stream: Some(false),
            options: Some(OllamaOptions {
                temperature: request.temperature,
                num_predict: request.max_tokens.map(|t| t as i32),
                top_p: request.top_p,
            }),
        };

        let url = format!("{}/api/chat", self.base_url);
        let response = self.client.post(&url, &ollama_request).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Ollama chat failed: {}",
                text
            )));
        }

        let ollama_response: OllamaChatResponse = serde_json::from_str(&text).map_err(|e| {
            LlmConnectorError::ParseError(format!("Failed to parse Ollama response: {}", e))
        })?;

        let content = ollama_response.message.content.clone();

        let choices = vec![Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: vec![crate::types::MessageBlock::text(&content)],
                name: None,
                tool_calls: None,
                tool_call_id: None,
                reasoning_content: None,
                reasoning: None,
                thought: None,
                thinking: None,
            },
            finish_reason: Some("stop".to_string()),
            logprobs: None,
        }];

        Ok(ChatResponse {
            id: "ollama-response".to_string(),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            model: ollama_response.model,
            choices,
            content,
            reasoning_content: None,
            usage: None, // Ollama does not return token usage information
            system_fingerprint: None,
        })
    }

    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|msg| OllamaMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "user".to_string(),
                    },
                    content: msg.content_as_text(),
                })
                .collect(),
            stream: Some(true),
            options: Some(OllamaOptions {
                temperature: request.temperature,
                num_predict: request.max_tokens.map(|t| t as i32),
                top_p: request.top_p,
            }),
        };

        let url = format!("{}/api/chat", self.base_url);
        let response = self.client.stream(&url, &ollama_request).await?;
        let status = response.status();

        if !status.is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Ollama stream failed: {}",
                text
            )));
        }

        // Ollama uses JSONL rather than SSE
        Ok(crate::sse::sse_to_streaming_response(response))
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Ollama request/response types
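//
// For reference, `OllamaChatRequest` below serializes to a body of roughly this shape (a sketch;
// the concrete values are placeholders):
//
//   {
//     "model": "llama2",
//     "messages": [{"role": "user", "content": "Hello"}],
//     "stream": false,
//     "options": {"temperature": 0.7, "num_predict": 256, "top_p": 0.9}
//   }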
#[derive(Serialize, Debug)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: Option<bool>,
    options: Option<OllamaOptions>,
}

#[derive(Serialize, Debug)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Serialize, Debug)]
struct OllamaOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
struct OllamaChatResponse {
    model: String,
    message: OllamaResponseMessage,
    #[allow(dead_code)]
    done: bool,
}

#[derive(Deserialize, Debug)]
struct OllamaResponseMessage {
    #[allow(dead_code)]
    role: String,
    content: String,
}

#[derive(Serialize, Debug)]
struct OllamaPullRequest {
    name: String,
    stream: Option<bool>,
}

#[derive(Serialize, Debug)]
struct OllamaDeleteRequest {
    name: String,
}

#[derive(Serialize, Debug)]
struct OllamaShowRequest {
    name: String,
}

#[derive(Deserialize, Debug)]
pub struct OllamaModelInfo {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: OllamaModelDetails,
}

#[derive(Deserialize, Debug)]
pub struct OllamaModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Deserialize, Debug)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    modified_at: String,
    #[allow(dead_code)]
    size: u64,
}

/// Creates an Ollama provider using the default local address
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::ollama;
///
/// let provider = ollama().unwrap();
/// ```
pub fn ollama() -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::new("http://localhost:11434")
}

/// Creates an Ollama provider with a custom URL
///
/// # Parameters
/// - `base_url`: URL of the Ollama service
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::ollama_with_base_url;
///
/// let provider = ollama_with_base_url("http://192.168.1.100:11434").unwrap();
/// ```
pub fn ollama_with_base_url(base_url: &str) -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::new(base_url)
}

/// Creates an Ollama provider with custom configuration
///
/// # Parameters
/// - `base_url`: URL of the Ollama service
/// - `timeout_secs`: timeout in seconds
/// - `proxy`: proxy URL (optional)
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::ollama_with_config;
///
/// let provider = ollama_with_config(
///     "http://localhost:11434",
///     Some(120), // 2-minute timeout
///     None
/// ).unwrap();
/// ```
pub fn ollama_with_config(
    base_url: &str,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::with_config(base_url, timeout_secs, proxy)
}