mini_langchain/llm/ollama.rs

use std::sync::Arc;
use serde_json::error::Error as SerdeJsonError;
use serde::de::Error as SerdeDeError;
use async_stream::stream as async_stream;
use futures::{
    FutureExt,
    future::BoxFuture,
    stream::BoxStream,
};

use crate::message::Message;
use crate::message::MessageRole as MsgRole;
use crate::tools::stream::StreamData;

use crate::llm::{
    traits::LLM,
    tokens::TokenUsage,
    error::LLMError,
    GenerateResult,
    LLMResult,
};

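/// Model tag used when the caller does not override it via [`Ollama::with_model`].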
pub const DEFAULT_MODEL: &str = "llama3.2";

pub use ollama_rs::{
    error::OllamaError,
    Ollama as OllamaClient,
    models::ModelOptions,
    generation::{
        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
        completion::request::GenerationRequest,
    },
};

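/// Chat-capable LLM backed by an Ollama server.
///
/// Minimal construction sketch (module paths assumed to mirror the file
/// layout; adjust the `use` lines to your crate's actual re-exports):
///
/// ```ignore
/// use std::sync::Arc;
/// use mini_langchain::llm::ollama::{Ollama, OllamaClient};
///
/// // Point at an explicit endpoint instead of the default local daemon.
/// let client = Arc::new(OllamaClient::new("http://localhost", 11434));
/// let llm = Ollama::new(client).with_model("llama3.2");
/// ```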
#[derive(Debug, Clone)]
pub struct Ollama {
    pub(crate) client: Arc<OllamaClient>,
    pub(crate) model: String,
    pub(crate) options: Option<ModelOptions>,
}

impl Ollama {
    /// Wraps an existing client, starting from [`DEFAULT_MODEL`] with no extra options.
    pub fn new(client: Arc<OllamaClient>) -> Self {
        Self {
            client,
            model: DEFAULT_MODEL.to_string(),
            options: None,
        }
    }

    /// Overrides the model tag used for requests.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets sampler/runtime options forwarded with every request.
    pub fn with_options(mut self, options: ModelOptions) -> Self {
        self.options = Some(options);
        self
    }

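    /// Builds the chat request shared by `generate` and `stream`: converts the
    /// crate's messages, applies any configured [`ModelOptions`], and enables
    /// thinking mode.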
    fn generate_request(&self, messages: &[Message]) -> ChatMessageRequest {
        let mapped_messages = messages.iter().map(|message| message.into()).collect();
        let mut request = ChatMessageRequest::new(self.model.clone(), mapped_messages).think(true);
        // Forward any configured sampler options with the request.
        if let Some(options) = self.options.clone() {
            request = request.options(options);
        }
        request
    }
}
79
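// `OllamaClient::default()` targets the local Ollama daemon on its standard
// port (11434).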
impl Default for Ollama {
    fn default() -> Self {
        let client = Arc::new(OllamaClient::default());
        Ollama::new(client)
    }
}
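// Map crate roles onto the roles `ollama_rs` understands; roles without a
// direct equivalent (Tool, Developer) are delivered as system messages.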
impl From<&Message> for ChatMessage {
    fn from(message: &Message) -> Self {
        let role = match message.role {
            MsgRole::System => MessageRole::System,
            MsgRole::User => MessageRole::User,
            MsgRole::Assistant => MessageRole::Assistant,
            MsgRole::ToolResponce => MessageRole::Tool,
            MsgRole::Tool | MsgRole::Developer => MessageRole::System,
        };
        ChatMessage::new(role, message.content.clone())
    }
}

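// `generate` scans the completion for a tool-call payload of the form
// {"tool_calls": [{"name": <string>, "args": <object>}]}; prompting the model
// to emit that shape is the caller's responsibility.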
impl LLM for Ollama {
    fn generate<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, LLMResult<GenerateResult>> {
        async move {
            let request = self.generate_request(messages);

            let response = self
                .client
                .send_chat_messages(request)
                .await
                .map_err(|e| LLMError::InvalidResponse(format!("{:?}", e)))?;
            let generation = response.message.content.clone();

            // Ollama reports token counts in `final_data`; fall back to zeroed
            // usage when it is absent.
            let tokens = if let Some(final_data) = response.final_data {
                let prompt_tokens = final_data.prompt_eval_count as u32;
                let completion_tokens = final_data.eval_count as u32;

                TokenUsage {
                    prompt_tokens,
                    completion_tokens,
                    total_tokens: prompt_tokens + completion_tokens,
                }
            } else {
                TokenUsage::default()
            };
            let mut call_tools: Vec<crate::llm::CallInfo> = Vec::new();
            // Models often wrap the JSON in prose, so if a direct parse fails,
            // retry on the substring between the first '{' and the last '}'.
            let parsed_json_res = serde_json::from_str::<serde_json::Value>(&generation)
                .or_else(|_err| {
                    if let (Some(start), Some(end)) = (generation.find('{'), generation.rfind('}')) {
                        let sub = &generation[start..=end];
                        serde_json::from_str::<serde_json::Value>(sub)
                    } else {
                        Err(SerdeJsonError::custom("no json substring"))
                    }
                });
            if let Ok(parsed) = parsed_json_res {
                if let Some(arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
                    for entry in arr.iter() {
                        if let Some(obj) = entry.as_object() {
                            if let Some(name_val) = obj.get("name").and_then(|v| v.as_str()) {
                                let name = name_val.to_string();
                                // Missing `args` default to an empty object.
                                let args = obj.get("args").cloned().unwrap_or_else(|| serde_json::json!({}));
                                call_tools.push(crate::llm::CallInfo { name, args });
                            }
                        }
                    }
                }
            }

            Ok(GenerateResult { tokens, generation, call_tools })
        }
        .boxed()
    }
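    // With the `ollama_stream` feature enabled this forwards Ollama's native
    // chunk stream; without it, the whole response is fetched in one request
    // and yielded as a single `StreamData`. Consumption sketch:
    //
    //     use futures::StreamExt;
    //     let mut s = llm.stream(&messages);
    //     while let Some(chunk) = s.next().await { /* ... */ }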
    fn stream<'a>(&'a self, messages: &'a [Message]) -> BoxStream<'a, LLMResult<StreamData>> {
        let this = self;
        let msgs = messages;

        let s = async_stream! {
            #[cfg(feature = "ollama_stream")]
            {
                // `StreamExt::next` is only needed on this path.
                use futures::StreamExt;

                let request = this.generate_request(msgs);
                let upstream = match this.client.send_chat_messages_stream(request).await {
                    Ok(s) => s,
                    Err(e) => {
                        yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                        return;
                    }
                };

                futures::pin_mut!(upstream);
                while let Some(item_res) = upstream.next().await {
                    match item_res {
                        Ok(item) => {
                            let value = serde_json::to_value(&item).unwrap_or_default();
                            let content = item.message.content.clone();
                            let tokens = item.final_data.map(|final_data| TokenUsage {
                                prompt_tokens: final_data.prompt_eval_count as u32,
                                completion_tokens: final_data.eval_count as u32,
                                total_tokens: final_data.prompt_eval_count as u32 + final_data.eval_count as u32,
                            });
                            yield Ok(StreamData::new(value, tokens, content));
                        }
                        Err(e) => {
                            yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                        }
                    }
                }
            }

            #[cfg(not(feature = "ollama_stream"))]
            {
                let request = this.generate_request(msgs);
                match this.client.send_chat_messages(request).await {
                    Ok(response) => {
                        let content = response.message.content.clone();
                        let value = serde_json::to_value(response.message).unwrap_or_default();

                        let tokens = response.final_data.map(|final_data| {
                            let prompt_tokens = final_data.prompt_eval_count as u32;
                            let completion_tokens = final_data.eval_count as u32;
                            TokenUsage {
                                prompt_tokens,
                                completion_tokens,
                                total_tokens: prompt_tokens + completion_tokens,
                            }
                        });

                        // Entire response delivered as a single chunk.
                        let sd = StreamData::new(value, tokens, content);
                        yield Ok(sd);
                    }
                    Err(e) => {
                        yield Err(LLMError::InvalidResponse(format!("{:?}", e)));
                    }
                }
            }
        };

        Box::pin(s)
    }
}