mini_langchain/llm/ollama.rs

use std::sync::Arc;
use serde_json::error::Error as SerdeJsonError;
// Imported for the `serde::de::Error` trait so that `SerdeJsonError::custom` is available below.
use serde::de::Error as SerdeDeError;
use async_stream::stream as async_stream;
use futures::{
    FutureExt,
    future::BoxFuture,
    stream::BoxStream
};

use crate::message::Message;
use crate::tools::stream::StreamData;
use crate::message::MessageRole as MsgRole;

use crate::llm::{
    traits::LLM,
    tokens::TokenUsage,
    error::LLMError,
    GenerateResult,
    LLMResult,
};

/// Default model name used when no model is specified.
/// Adjust this to match the model name you have installed in your local Ollama.
/// Common names: "llama3.2", "llama3", "llama2", or custom names from `ollama list`.
pub const DEFAULT_MODEL: &str = "llama3.2";

pub use ollama_rs::{
    error::OllamaError,
    Ollama as OllamaClient,
    models::ModelOptions,
    generation::{
        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
        completion::request::GenerationRequest,
    }
};

#[derive(Debug, Clone)]
pub struct Ollama {
    pub(crate) client: Arc<OllamaClient>,
    pub(crate) model: String,
    pub(crate) options: Option<ModelOptions>,
}

impl Ollama {
    /// Create an `Ollama` wrapper using the provided client and the default model.
    ///
    /// If your local Ollama uses a different default model name, change
    /// `DEFAULT_MODEL` or call `Ollama::with_model`.
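    ///
    /// A minimal construction sketch (fenced as `ignore`: it assumes a locally
    /// running Ollama daemon and that this crate is named `mini_langchain`):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use mini_langchain::llm::ollama::{Ollama, OllamaClient};
    ///
    /// // `OllamaClient::default()` targets the local default endpoint (127.0.0.1:11434).
    /// let client = Arc::new(OllamaClient::default());
    /// let llm = Ollama::new(client).with_model("llama3.2");
    /// ```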
    pub fn new(client: Arc<OllamaClient>) -> Self {
        Self {
            client,
            model: DEFAULT_MODEL.to_string(),
            options: None,
        }
    }

    /// Set the model name to use for requests (builder-style).
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Set additional generation options for the model (builder-style).
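    ///
    /// A usage sketch (fenced as `ignore`; the option shown is illustrative, see
    /// `ollama_rs::models::ModelOptions` for the available builder methods):
    ///
    /// ```ignore
    /// let options = ModelOptions::default().temperature(0.2);
    /// let llm = Ollama::default().with_options(options);
    /// ```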
    pub fn with_options(mut self, options: ModelOptions) -> Self {
        self.options = Some(options);
        self
    }

    fn generate_request(&self, messages: &[Message]) -> ChatMessageRequest {
        let mapped_messages = messages.iter().map(|message| message.into()).collect();
        // `.think(true)` requests the model's "thinking" output where supported; options
        // configured via `with_options` are forwarded to the request.
        let mut request = ChatMessageRequest::new(self.model.clone(), mapped_messages).think(true);
        if let Some(options) = self.options.clone() {
            request = request.options(options);
        }
        request
    }
}

impl Default for Ollama {
    fn default() -> Self {
        let client = Arc::new(OllamaClient::default());
        Ollama::new(client)
    }
}

impl From<&Message> for ChatMessage {
    fn from(message: &Message) -> Self {
        let role = match message.role {
            MsgRole::System => MessageRole::System,
            MsgRole::User => MessageRole::User,
            MsgRole::Assistant => MessageRole::Assistant,
            MsgRole::ToolResponce => MessageRole::Tool,
            // Ollama's chat API has no dedicated tool-definition or developer
            // role, so map these onto System.
            MsgRole::Tool | MsgRole::Developer => MessageRole::System,
        };
        ChatMessage::new(role, message.content.clone())
    }
}

impl LLM for Ollama {
    fn generate<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, LLMResult<GenerateResult>> {
        async move {
            // Build the chat request from the conversation history.
            let request = self.generate_request(messages);

            // Perform the async call, mapping errors into our LLMError.
            let response = self
                .client
                .send_chat_messages(request)
                .await
                .map_err(|e| LLMError::InvalidResponse(format!("{:?}", e)))?;
            let generation = response.message.content.clone();

            // Ollama reports token counts (prompt_eval_count / eval_count) in the final
            // response metadata; fall back to `TokenUsage::default()` when it is absent.
            let tokens = if let Some(final_data) = response.final_data {
                let prompt_tokens = final_data.prompt_eval_count as u32;
                let completion_tokens = final_data.eval_count as u32;

                TokenUsage {
                    prompt_tokens,
                    completion_tokens,
                    total_tokens: prompt_tokens + completion_tokens,
                }
            } else {
                TokenUsage::default()
            };
            // Robustly extract `tool_calls: [{name, args}]` from the generated text.
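            // Expected shape (illustrative): the reply may embed a JSON object such as
            //   {"tool_calls": [{"name": "search", "args": {"query": "rust streams"}}]}
            // anywhere in the text; if the whole reply is not valid JSON, the fallback
            // below retries on the first-'{'..last-'}' substring.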
            let mut call_tools: Vec<crate::llm::CallInfo> = Vec::new();
            let parsed_json_res = serde_json::from_str::<serde_json::Value>(&generation)
                .or_else(|_err| {
                    if let (Some(start), Some(end)) = (generation.find('{'), generation.rfind('}')) {
                        let sub = &generation[start..=end];
                        serde_json::from_str::<serde_json::Value>(sub)
                    } else {
                        Err(SerdeJsonError::custom("no json substring"))
                    }
                });
            if let Ok(parsed) = parsed_json_res {
                if let Some(arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
                    for entry in arr.iter() {
                        if let Some(obj) = entry.as_object() {
                            if let Some(name_val) = obj.get("name").and_then(|v| v.as_str()) {
                                let name = name_val.to_string();
                                let args = obj.get("args").cloned().unwrap_or_else(|| serde_json::json!({}));
                                call_tools.push(crate::llm::CallInfo { name, args });
                            }
                        }
                    }
                }
            }

            Ok(GenerateResult { tokens, generation, call_tools })
        }
        .boxed()
    }

    fn stream<'a>(&'a self, messages: &'a [Message]) -> BoxStream<'a, LLMResult<StreamData>> {
        // Keep borrowed references to `self` and `messages` in scope for the async generator.
        let this = self;
        let msgs = messages;

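        // Note: the `ollama_stream` cargo feature is assumed to also enable
        // ollama-rs's streaming support (its `stream` feature), which provides
        // `send_chat_messages_stream`.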
        let s = async_stream! {
            // Prefer upstream streaming when the feature is enabled.
            #[cfg(feature = "ollama_stream")]
            {
                // `StreamExt` is needed for `.next()` on the upstream stream.
                use futures::StreamExt;

                let request = this.generate_request(msgs);
                // Obtain the upstream chat stream (awaitable).
                let upstream = match this.client.send_chat_messages_stream(request).await {
                    Ok(s) => s,
                    Err(e) => {
                        yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                        return;
                    }
                };

                futures::pin_mut!(upstream);
                while let Some(item_res) = upstream.next().await {
                    match item_res {
                        Ok(item) => {
                            let value = serde_json::to_value(&item).unwrap_or_default();
                            let content = item.message.content.clone();
                            let tokens = item.final_data.map(|final_data| crate::llm::tokens::TokenUsage {
                                prompt_tokens: final_data.prompt_eval_count as u32,
                                completion_tokens: final_data.eval_count as u32,
                                total_tokens: final_data.prompt_eval_count as u32 + final_data.eval_count as u32,
                            });
                            yield Ok(StreamData::new(value, tokens, content));
                        }
                        Err(e) => {
                            // Map the upstream error into our LLMError.
                            yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                        }
                    }
                }
            }

            // Fallback: call the non-streaming endpoint and yield a single item.
            #[cfg(not(feature = "ollama_stream"))]
            {
                let request = this.generate_request(msgs);
                match this.client.send_chat_messages(request).await {
                    Ok(response) => {
                        let content = response.message.content.clone();
                        let value = serde_json::to_value(response.message).unwrap_or_default();

                        let tokens = response.final_data.map(|final_data| {
                            let prompt_tokens = final_data.prompt_eval_count as u32;
                            let completion_tokens = final_data.eval_count as u32;
                            TokenUsage {
                                prompt_tokens,
                                completion_tokens,
                                total_tokens: prompt_tokens + completion_tokens,
                            }
                        });

                        let sd = StreamData::new(value, tokens, content);
                        yield Ok(sd);
                    }
                    Err(e) => {
                        yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                    }
                }
            }
        };

        Box::pin(s)
    }
}
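
// A minimal construction-only test sketch: it exercises the builder defaults and
// needs no running Ollama daemon (`OllamaClient::default()` only stores endpoint
// configuration). Network-backed tests are intentionally left out here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_uses_default_model_and_no_options() {
        let llm = Ollama::default();
        assert_eq!(llm.model, DEFAULT_MODEL);
        assert!(llm.options.is_none());
    }

    #[test]
    fn with_model_overrides_default() {
        let llm = Ollama::default().with_model("custom-model");
        assert_eq!(llm.model, "custom-model");
    }
}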