Skip to main content

mermaid_cli/providers/model/
ollama.rs

//! Ollama provider wrapping `OllamaAdapter`.
//!
//! The adapter owns the wire format (NDJSON framing, gpt-oss
//! reasoning dispatch, truncation marker, retry). The wrapper
//! translates `ChatRequest` ↔ `ModelConfig` and bridges the
//! adapter's legacy `StreamCallback` to the typed `StreamEvent`
//! sink. Adapter internals stay where they are; the architecture
//! boundary is at `ModelProvider::chat`.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::domain::ChatRequest;
15use crate::models::adapters::ollama::OllamaAdapter;
16use crate::models::{
17    BackendConfig, Model, ModelConfig, ModelError, ReasoningChunk, Result, StreamCallback,
18    StreamEvent as ModelStreamEvent,
19};
20
21use super::super::capabilities::Capabilities;
22use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
23use super::ModelProvider;
24
25/// Ollama adapter fronted by `ModelProvider`.
26pub struct OllamaProvider {
27    adapter: OllamaAdapter,
28    capabilities: Capabilities,
29    /// Shared app `Config` so `build_model_config` can read Ollama
30    /// hardware options (`num_ctx`, `num_gpu`, `num_thread`, `numa`) at
31    /// call time. Before F11 these were silently dropped because the
32    /// wrapper built `ModelConfig` only from `ChatRequest` fields.
33    config: Arc<crate::app::Config>,
34}
35
36impl OllamaProvider {
37    /// Backward-compatible constructor that uses a default app config.
38    /// Call `with_app_config` instead when you have one available so
39    /// Ollama hardware options actually reach the adapter.
40    pub async fn new(model_name: &str, backend: Arc<BackendConfig>) -> Result<Self> {
41        Self::with_app_config(model_name, backend, Arc::new(crate::app::Config::default())).await
42    }
43
44    /// Construct with an explicit `app::Config` reference. Used by
45    /// `ProviderFactory::build_provider` so `config.ollama.{num_gpu,
46    /// num_ctx, num_thread, numa}` make it into the Ollama request's
47    /// `options` block.
48    pub async fn with_app_config(
49        model_name: &str,
50        backend: Arc<BackendConfig>,
51        config: Arc<crate::app::Config>,
52    ) -> Result<Self> {
53        let adapter = OllamaAdapter::new(model_name, backend).await?;
54        let capabilities = Capabilities::from_legacy(adapter.capabilities());
55        Ok(Self {
56            adapter,
57            capabilities,
58            config,
59        })
60    }
61}
62
63#[async_trait]
64impl ModelProvider for OllamaProvider {
65    fn capabilities(&self) -> &Capabilities {
66        &self.capabilities
67    }
68
69    async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
70        let config = build_model_config(&request, &self.config);
71        // Ordered relay (F2): the adapter's sync callback pushes into an
72        // `UnboundedSender` (synchronous, FIFO). A single relay task drains
73        // into the bounded sink in order, avoiding the per-event `tokio::
74        // spawn` race that could deliver `Done` before prior tool calls.
75        let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
76        let callback = stream_callback_for(relay_tx);
77
78        // Race adapter.chat against the cancellation token. When
79        // cancelled, the adapter's stream loop observes the sink
80        // closing (we drop `callback`) and exits at its next await.
81        // This is the crucial structural win vs. the old
82        // `check_interrupt` polling: the adapter doesn't need to
83        // know anything about turn IDs — the sink either drains or
84        // doesn't, and the tokens handle everything else.
85        let chat_fut = self
86            .adapter
87            .chat(&request.messages, &config, Some(callback));
88
89        let response = tokio::select! {
90            biased;
91            _ = ctx.token.cancelled() => {
92                // Terminal event for a cancelled turn comes from the
93                // runner's `drop_scope` once the `TurnScope` drains, so
94                // we neither emit `StreamEvent::Done` here nor surface
95                // an `UpstreamError`. `ModelError::Cancelled` is the
96                // sentinel the runner swallows.
97                return Err(ModelError::Cancelled);
98            },
99            r = chat_fut => r?,
100        };
101
102        // F3: the wrapper's `Done` is now the sole terminal event —
103        // the adapter no longer emits one from the callback. Carrying
104        // `thinking_signature` out of `ModelResponse` here is what
105        // lets multi-turn extended thinking round-trip.
106        let usage = response.usage.clone();
107        let thinking_signature = response.thinking_signature.clone();
108        let _ = ctx
109            .sink
110            .send(StreamEvent::Done {
111                usage: usage.clone(),
112                thinking_signature: thinking_signature.clone(),
113            })
114            .await;
115
116        Ok(FinalResponse {
117            usage,
118            thinking_signature,
119            tool_calls: response.tool_calls.unwrap_or_default(),
120        })
121    }
122}
123
124// ─── helpers ────────────────────────────────────────────────────────
125
126fn build_model_config(request: &ChatRequest, app_config: &crate::app::Config) -> ModelConfig {
127    let mut mc = ModelConfig {
128        model: request.model_id.clone(),
129        temperature: request.temperature,
130        max_tokens: request.max_tokens,
131        reasoning: request.reasoning,
132        system_prompt: Some(request.system_prompt.clone()),
133        dynamic_system_suffix: request.instructions.clone(),
134        tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
135        ..Default::default()
136    };
137    // F11: forward Ollama hardware options from the user's app config.
138    // Previously these fields were configured + persisted but silently
139    // ignored because this wrapper built ModelConfig in isolation.
140    if let Some(v) = app_config.ollama.num_gpu {
141        mc.set_backend_option("ollama".into(), "num_gpu".into(), v.to_string());
142    }
143    if let Some(v) = app_config.ollama.num_ctx {
144        mc.set_backend_option("ollama".into(), "num_ctx".into(), v.to_string());
145    }
146    if let Some(v) = app_config.ollama.num_thread {
147        mc.set_backend_option("ollama".into(), "num_thread".into(), v.to_string());
148    }
149    if let Some(v) = app_config.ollama.numa {
150        mc.set_backend_option("ollama".into(), "numa".into(), v.to_string());
151    }
152    mc
153}
154
155/// Build a `StreamCallback` that forwards `ModelStreamEvent`s from the
156/// adapter into an `UnboundedSender<StreamEvent>` (ordered relay). The
157/// caller wires that sender to a bounded sink via
158/// `stream_bridge::ordered_relay`; this keeps event delivery FIFO even
159/// though the adapter calls us from a sync context.
160fn stream_callback_for(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
161    Arc::new(move |event: ModelStreamEvent| {
162        let mapped = match event {
163            ModelStreamEvent::Text(s) => StreamEvent::Text(s),
164            ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
165                text: chunk.text,
166                signature: chunk.signature,
167            }),
168            ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
169            ModelStreamEvent::Done { tokens } => StreamEvent::Done {
170                usage: if tokens > 0 {
171                    Some(crate::models::TokenUsage::provider(0, tokens, tokens))
172                } else {
173                    None
174                },
175                thinking_signature: None,
176            },
177        };
178        // Synchronous send preserves ordering. Ignore errors — the
179        // receiver has closed means the turn is already gone.
180        let _ = sink.send(mapped);
181    })
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request for the config tests; individual tests override
    /// the fields they care about with struct-update syntax.
    fn base_request() -> ChatRequest {
        ChatRequest {
            model_id: "ollama/test".to_string(),
            messages: vec![],
            system_prompt: "sys".to_string(),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::Medium,
            temperature: 0.7,
            max_tokens: 4096,
            tools: vec![],
        }
    }

    /// Push `event` through a fresh callback and return what arrives on
    /// the receiving end. `closed_msg` is the panic message used if the
    /// channel turns out to be closed.
    async fn relay_one(event: ModelStreamEvent, closed_msg: &str) -> StreamEvent {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let callback = stream_callback_for(tx);
        callback(event);
        let wait = std::time::Duration::from_millis(100);
        tokio::time::timeout(wait, rx.recv())
            .await
            .expect("recv")
            .expect(closed_msg)
    }

    #[test]
    fn build_model_config_maps_request_fields() {
        let request = ChatRequest {
            instructions: Some("instructions text".to_string()),
            reasoning: crate::models::ReasoningLevel::High,
            temperature: 0.3,
            max_tokens: 2048,
            ..base_request()
        };
        let app_config = crate::app::Config::default();

        let built = build_model_config(&request, &app_config);

        assert_eq!(built.model, "ollama/test");
        assert_eq!(built.temperature, 0.3);
        assert_eq!(built.max_tokens, 2048);
        assert_eq!(built.reasoning, crate::models::ReasoningLevel::High);
        assert_eq!(built.system_prompt.as_deref(), Some("sys"));
        assert_eq!(
            built.dynamic_system_suffix.as_deref(),
            Some("instructions text")
        );
    }

    /// F11 regression guard: Ollama hardware options set in the user's
    /// app config must show up in the `ModelConfig`'s backend options,
    /// from where the adapter's `build_request_body` emits them under
    /// `options`. Before F11 they were configured and persisted but
    /// never reached the wire; `num_ctx = 8192` in config.toml was a
    /// silent no-op.
    #[test]
    fn build_model_config_forwards_ollama_hardware_options() {
        let request = base_request();
        let mut app_config = crate::app::Config::default();
        app_config.ollama.num_ctx = Some(8192);
        app_config.ollama.num_gpu = Some(10);
        app_config.ollama.num_thread = Some(8);
        app_config.ollama.numa = Some(true);

        let options = build_model_config(&request, &app_config).ollama_options();

        assert_eq!(options.num_ctx, Some(8192));
        assert_eq!(options.num_gpu, Some(10));
        assert_eq!(options.num_thread, Some(8));
        assert_eq!(options.numa, Some(true));
    }

    #[tokio::test]
    async fn stream_callback_forwards_text_event() {
        let event = ModelStreamEvent::Text("hello".to_string());
        match relay_one(event, "sender alive").await {
            StreamEvent::Text(text) => assert_eq!(text, "hello"),
            _ => panic!("wrong variant"),
        }
    }

    #[tokio::test]
    async fn stream_callback_forwards_done_with_tokens() {
        let event = ModelStreamEvent::Done { tokens: 42 };
        match relay_one(event, "sender").await {
            StreamEvent::Done { usage, .. } => {
                let reported = usage.expect("tokens > 0 → Some");
                assert_eq!(reported.total_tokens, 42);
            },
            _ => panic!("wrong variant"),
        }
    }

    #[tokio::test]
    async fn stream_callback_done_zero_tokens_is_none_usage() {
        let event = ModelStreamEvent::Done { tokens: 0 };
        match relay_one(event, "sender").await {
            StreamEvent::Done { usage, .. } => assert!(usage.is_none()),
            _ => panic!("wrong variant"),
        }
    }
}