mermaid_cli/providers/model/
ollama.rs1use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::domain::ChatRequest;
15use crate::models::adapters::ollama::OllamaAdapter;
16use crate::models::{
17 BackendConfig, Model, ModelConfig, ModelError, ReasoningChunk, Result, StreamCallback,
18 StreamEvent as ModelStreamEvent,
19};
20
21use super::super::capabilities::Capabilities;
22use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
23use super::ModelProvider;
24
25pub struct OllamaProvider {
27 adapter: OllamaAdapter,
28 capabilities: Capabilities,
29 config: Arc<crate::app::Config>,
34}
35
36impl OllamaProvider {
37 pub async fn new(model_name: &str, backend: Arc<BackendConfig>) -> Result<Self> {
41 Self::with_app_config(model_name, backend, Arc::new(crate::app::Config::default())).await
42 }
43
44 pub async fn with_app_config(
49 model_name: &str,
50 backend: Arc<BackendConfig>,
51 config: Arc<crate::app::Config>,
52 ) -> Result<Self> {
53 let adapter = OllamaAdapter::new(model_name, backend).await?;
54 let capabilities = Capabilities::from_legacy(adapter.capabilities());
55 Ok(Self {
56 adapter,
57 capabilities,
58 config,
59 })
60 }
61}
62
63#[async_trait]
64impl ModelProvider for OllamaProvider {
65 fn capabilities(&self) -> &Capabilities {
66 &self.capabilities
67 }
68
69 async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
70 let config = build_model_config(&request, &self.config);
71 let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
76 let callback = stream_callback_for(relay_tx);
77
78 let chat_fut = self
86 .adapter
87 .chat(&request.messages, &config, Some(callback));
88
89 let response = tokio::select! {
90 biased;
91 _ = ctx.token.cancelled() => {
92 return Err(ModelError::Cancelled);
98 },
99 r = chat_fut => r?,
100 };
101
102 let usage = response.usage.clone();
107 let thinking_signature = response.thinking_signature.clone();
108 let _ = ctx
109 .sink
110 .send(StreamEvent::Done {
111 usage: usage.clone(),
112 thinking_signature: thinking_signature.clone(),
113 })
114 .await;
115
116 Ok(FinalResponse {
117 usage,
118 thinking_signature,
119 tool_calls: response.tool_calls.unwrap_or_default(),
120 })
121 }
122}
123
124fn build_model_config(request: &ChatRequest, app_config: &crate::app::Config) -> ModelConfig {
127 let mut mc = ModelConfig {
128 model: request.model_id.clone(),
129 temperature: request.temperature,
130 max_tokens: request.max_tokens,
131 reasoning: request.reasoning,
132 system_prompt: Some(request.system_prompt.clone()),
133 dynamic_system_suffix: request.instructions.clone(),
134 tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
135 ..Default::default()
136 };
137 if let Some(v) = app_config.ollama.num_gpu {
141 mc.set_backend_option("ollama".into(), "num_gpu".into(), v.to_string());
142 }
143 if let Some(v) = app_config.ollama.num_ctx {
144 mc.set_backend_option("ollama".into(), "num_ctx".into(), v.to_string());
145 }
146 if let Some(v) = app_config.ollama.num_thread {
147 mc.set_backend_option("ollama".into(), "num_thread".into(), v.to_string());
148 }
149 if let Some(v) = app_config.ollama.numa {
150 mc.set_backend_option("ollama".into(), "numa".into(), v.to_string());
151 }
152 mc
153}
154
155fn stream_callback_for(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
161 Arc::new(move |event: ModelStreamEvent| {
162 let mapped = match event {
163 ModelStreamEvent::Text(s) => StreamEvent::Text(s),
164 ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
165 text: chunk.text,
166 signature: chunk.signature,
167 }),
168 ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
169 ModelStreamEvent::Done { tokens } => StreamEvent::Done {
170 usage: if tokens > 0 {
171 Some(crate::models::TokenUsage::provider(0, tokens, tokens))
172 } else {
173 None
174 },
175 thinking_signature: None,
176 },
177 };
178 let _ = sink.send(mapped);
181 })
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;
    use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};

    /// Longest a test waits for the callback's event to show up on the channel.
    const RECV_TIMEOUT: Duration = Duration::from_millis(100);

    /// Receives the next event from `rx`, panicking with `"recv"` on timeout
    /// and with `closed_msg` when the channel yields nothing.
    async fn next_event(rx: &mut UnboundedReceiver<StreamEvent>, closed_msg: &str) -> StreamEvent {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("recv")
            .expect(closed_msg)
    }

    /// Every request field must land in the matching `ModelConfig` field.
    #[test]
    fn build_model_config_maps_request_fields() {
        let request = ChatRequest {
            model_id: String::from("ollama/test"),
            messages: Vec::new(),
            system_prompt: String::from("sys"),
            instructions: Some(String::from("instructions text")),
            reasoning: crate::models::ReasoningLevel::High,
            temperature: 0.3,
            max_tokens: 2048,
            tools: Vec::new(),
        };
        let app_config = crate::app::Config::default();

        let built = build_model_config(&request, &app_config);

        assert_eq!(built.model, "ollama/test");
        assert_eq!(built.temperature, 0.3);
        assert_eq!(built.max_tokens, 2048);
        assert_eq!(built.reasoning, crate::models::ReasoningLevel::High);
        assert_eq!(built.system_prompt.as_deref(), Some("sys"));
        assert_eq!(
            built.dynamic_system_suffix.as_deref(),
            Some("instructions text")
        );
    }

    /// Hardware settings from the app config must be readable back through
    /// `ollama_options()`.
    #[test]
    fn build_model_config_forwards_ollama_hardware_options() {
        let request = ChatRequest {
            model_id: String::from("ollama/test"),
            messages: Vec::new(),
            system_prompt: String::from("sys"),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::Medium,
            temperature: 0.7,
            max_tokens: 4096,
            tools: Vec::new(),
        };
        let mut app_config = crate::app::Config::default();
        app_config.ollama.num_ctx = Some(8192);
        app_config.ollama.num_gpu = Some(10);
        app_config.ollama.num_thread = Some(8);
        app_config.ollama.numa = Some(true);

        let options = build_model_config(&request, &app_config).ollama_options();

        assert_eq!(options.num_ctx, Some(8192));
        assert_eq!(options.num_gpu, Some(10));
        assert_eq!(options.num_thread, Some(8));
        assert_eq!(options.numa, Some(true));
    }

    /// A text event passes through the callback unchanged.
    #[tokio::test]
    async fn stream_callback_forwards_text_event() {
        let (tx, mut rx) = unbounded_channel();
        let callback = stream_callback_for(tx);

        callback(ModelStreamEvent::Text("hello".to_string()));

        if let StreamEvent::Text(text) = next_event(&mut rx, "sender alive").await {
            assert_eq!(text, "hello");
        } else {
            panic!("wrong variant");
        }
    }

    /// A positive token count on `Done` is surfaced as usage.
    #[tokio::test]
    async fn stream_callback_forwards_done_with_tokens() {
        let (tx, mut rx) = unbounded_channel();
        let callback = stream_callback_for(tx);

        callback(ModelStreamEvent::Done { tokens: 42 });

        if let StreamEvent::Done { usage, .. } = next_event(&mut rx, "sender").await {
            let reported = usage.expect("tokens > 0 → Some");
            assert_eq!(reported.total_tokens, 42);
        } else {
            panic!("wrong variant");
        }
    }

    /// A zero token count on `Done` yields no usage at all.
    #[tokio::test]
    async fn stream_callback_done_zero_tokens_is_none_usage() {
        let (tx, mut rx) = unbounded_channel();
        let callback = stream_callback_for(tx);

        callback(ModelStreamEvent::Done { tokens: 0 });

        if let StreamEvent::Done { usage, .. } = next_event(&mut rx, "sender").await {
            assert!(usage.is_none());
        } else {
            panic!("wrong variant");
        }
    }
}