Skip to main content

cortex_llm/
ollama_http.rs

//! HTTP adapter that posts to a local Ollama `/api/chat` endpoint.
//!
//! [`OllamaHttpAdapter`] implements [`LlmAdapter`] by forwarding requests to
//! the Ollama REST API. Because `ureq` is synchronous, each blocking HTTP call
//! runs inside `tokio::task::spawn_blocking` so the adapter can satisfy the
//! async trait contract without stalling the async executor.
//!
//! The adapter enforces the same loopback-only and digest-pinned-model
//! invariants as [`crate::ollama::validate_config`]: construction fails for
//! non-loopback endpoints, and calls fail when the configured model reference
//! is not digest-pinned.
11
use std::io::Read;
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, StreamChunk,
};
use crate::ollama::{validate_endpoint_url, validate_model_ref, OllamaConfig};
21
22/// HTTP adapter that routes to a local Ollama instance via `/api/chat`.
23#[derive(Debug, Clone)]
24pub struct OllamaHttpAdapter {
25    config: OllamaConfig,
26}
27
28impl OllamaHttpAdapter {
29    /// Build an adapter from `config`.
30    ///
31    /// The endpoint URL is validated immediately; returns
32    /// [`LlmError::InvalidRequest`] if the endpoint does not satisfy the
33    /// loopback-only constraint. The model reference is validated per-call
34    /// (inside [`LlmAdapter::complete`]) because the model is also present on
35    /// the request.
36    pub fn new(config: OllamaConfig) -> Result<Self, LlmError> {
37        validate_endpoint_url(&config.endpoint_url).map_err(|e| match e {
38            LlmError::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
39            other => other,
40        })?;
41        Ok(Self { config })
42    }
43}
44
45// ---------------------------------------------------------------------------
46// Wire types
47// ---------------------------------------------------------------------------
48
49/// Outgoing payload for `POST /api/chat`.
50#[derive(Debug, Serialize)]
51struct ChatRequest<'a> {
52    model: &'a str,
53    messages: Vec<OllamaMessage<'a>>,
54    stream: bool,
55}
56
57/// One message in the Ollama chat format.
58#[derive(Debug, Serialize)]
59struct OllamaMessage<'a> {
60    role: &'a str,
61    content: &'a str,
62}
63
64/// Top-level Ollama `/api/chat` response envelope (non-streaming).
65#[derive(Debug, Deserialize)]
66struct ChatResponse {
67    #[serde(default)]
68    message: MessageField,
69}
70
71/// The `message` field inside a chat response.
72#[derive(Debug, Default, Deserialize)]
73struct MessageField {
74    #[serde(default)]
75    content: String,
76}
77
78/// One newline-delimited JSON line emitted by Ollama's streaming `/api/chat`.
79///
80/// Ollama sends objects of the form:
81/// ```json
82/// {"message":{"role":"assistant","content":"Hello"},"done":false}
83/// {"message":{"role":"assistant","content":""},"done":true,"done_reason":"stop"}
84/// ```
85#[derive(Debug, Deserialize)]
86struct StreamLine {
87    #[serde(default)]
88    message: MessageField,
89    #[serde(default)]
90    done: bool,
91    /// Present on the terminal line when `done` is `true`.
92    done_reason: Option<String>,
93}
94
95// ---------------------------------------------------------------------------
96// LlmAdapter implementation
97// ---------------------------------------------------------------------------
98
99#[async_trait]
100impl LlmAdapter for OllamaHttpAdapter {
101    fn adapter_id(&self) -> &'static str {
102        "ollama"
103    }
104
105    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
106        // Use the adapter's configured model, not req.model. req.model is set
107        // by the caller (e.g. cortex_reflect uses DEFAULT_REFLECTION_MODEL =
108        // "replay-reflection-v1" so the ReplayAdapter can look up fixtures).
109        // The Ollama adapter always drives its own pinned model.
110        validate_model_ref(&self.config.model)?;
111        let req = LlmRequest { model: self.config.model.clone(), ..req };
112
113        let config = self.config.clone();
114        let timeout_ms = req.timeout_ms;
115
116        let result = tokio::task::spawn_blocking(move || call_ollama(&config, &req, timeout_ms))
117            .await
118            .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;
119
120        result
121    }
122
123    /// Override with true Ollama streaming via newline-delimited JSON.
124    ///
125    /// Uses `ureq` (synchronous) inside `spawn_blocking`. Because
126    /// `spawn_blocking` requires the entire blocking work to complete before
127    /// returning, all stream lines are collected into a `Vec` before being
128    /// yielded. This means backpressure and incremental display require the
129    /// full response to arrive first.
130    ///
131    /// TODO: replace `ureq` with an async HTTP client (e.g. `reqwest`) and
132    /// drive the response body with `tokio::io::AsyncBufReadExt` to achieve
133    /// true line-by-line streaming without buffering the entire response.
134    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
135        // Same model-override pattern as complete(): use the adapter's pinned
136        // model, not whatever placeholder the caller put in req.model.
137        let req = LlmRequest { model: self.config.model.clone(), ..req };
138        validate_model_ref_and_stream(self.config.clone(), req)
139    }
140}
141
142/// Synchronous Ollama HTTP call, executed inside `spawn_blocking`.
143fn call_ollama(
144    config: &OllamaConfig,
145    req: &LlmRequest,
146    timeout_ms: u64,
147) -> Result<LlmResponse, LlmError> {
148    let url = format!("{}/api/chat", config.endpoint_url);
149
150    let messages: Vec<OllamaMessage<'_>> = req
151        .messages
152        .iter()
153        .map(|m| OllamaMessage {
154            role: m.role.as_str(),
155            content: &m.content,
156        })
157        .collect();
158
159    // Ollama API uses "name:tag" format; strip "@sha256:<digest>" if present.
160    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);
161
162    let body = ChatRequest {
163        model: ollama_model,
164        messages,
165        stream: false,
166    };
167
168    let timeout = Duration::from_millis(timeout_ms);
169
170    let agent = ureq::AgentBuilder::new().timeout(timeout).build();
171
172    let raw_response = agent
173        .post(&url)
174        .send_json(
175            serde_json::to_value(&body)
176                .map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?,
177        )
178        .map_err(|err| map_ureq_error(err, timeout_ms))?;
179
180    let status = raw_response.status();
181    if status != 200 {
182        return Err(LlmError::Transport(format!("HTTP {status}")));
183    }
184
185    let response_text = raw_response
186        .into_string()
187        .map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;
188
189    const MAX_RESPONSE_BYTES: usize = 16 * 1024 * 1024; // 16 MiB
190    if response_text.len() > MAX_RESPONSE_BYTES {
191        return Err(LlmError::Transport(format!(
192            "ollama response body exceeds 16 MiB limit ({} bytes); refusing to store",
193            response_text.len()
194        )));
195    }
196
197    let parsed: ChatResponse = serde_json::from_str(&response_text)
198        .map_err(|e| LlmError::Parse(format!("ollama response parse: {e}")))?;
199
200    let text = parsed.message.content;
201    let raw_hash = blake3_hex(response_text.as_bytes());
202
203    Ok(LlmResponse {
204        text,
205        parsed_json: None,
206        model: config.model.clone(),
207        usage: None,
208        raw_hash,
209    })
210}
211
212/// Build a `BoxStream` after validating the model ref.
213///
214/// Extracted as a free function so the `stream_boxed` method body stays short
215/// and the `async_stream::stream!` macro is not inside an `impl` block.
216fn validate_model_ref_and_stream(config: OllamaConfig, req: LlmRequest) -> BoxStream<'static> {
217    Box::pin(async_stream::stream! {
218        // req.model was already overridden to config.model by stream_boxed;
219        // validate against the config model (not a caller placeholder).
220        if let Err(e) = validate_model_ref(&config.model) {
221            yield Err(e);
222            return;
223        }
224
225        let timeout_ms = req.timeout_ms;
226        let result = tokio::task::spawn_blocking(move || {
227            call_ollama_streaming(&config, &req, timeout_ms)
228        })
229        .await;
230
231        match result {
232            Ok(chunks) => {
233                for chunk in chunks {
234                    yield chunk;
235                }
236            }
237            Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
238        }
239    })
240}
241
242/// Synchronous Ollama streaming call, executed inside `spawn_blocking`.
243///
244/// Posts to `/api/chat` with `stream: true`, then reads the response body
245/// line by line. Each non-empty line is parsed as a [`StreamLine`] and
246/// converted to a [`StreamChunk`]. The complete `Vec` is returned so the
247/// caller's `async_stream::stream!` block can yield items incrementally.
248fn call_ollama_streaming(
249    config: &OllamaConfig,
250    req: &LlmRequest,
251    timeout_ms: u64,
252) -> Vec<Result<StreamChunk, LlmError>> {
253    let url = format!("{}/api/chat", config.endpoint_url);
254
255    let messages: Vec<OllamaMessage<'_>> = req
256        .messages
257        .iter()
258        .map(|m| OllamaMessage {
259            role: m.role.as_str(),
260            content: &m.content,
261        })
262        .collect();
263
264    // Strip "@sha256:<digest>" for Ollama's "name:tag" API format.
265    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);
266
267    let body = ChatRequest {
268        model: ollama_model,
269        messages,
270        stream: true,
271    };
272
273    let timeout = Duration::from_millis(timeout_ms);
274    let agent = ureq::AgentBuilder::new().timeout(timeout).build();
275
276    let body_value = match serde_json::to_value(&body) {
277        Ok(v) => v,
278        Err(e) => {
279            return vec![Err(LlmError::Transport(format!(
280                "request serialization failed: {e}"
281            )))]
282        }
283    };
284
285    let raw_response = match agent.post(&url).send_json(body_value) {
286        Ok(r) => r,
287        Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
288    };
289
290    let status = raw_response.status();
291    if status != 200 {
292        return vec![Err(LlmError::Transport(format!("HTTP {status}")))];
293    }
294
295    let body_text = match raw_response.into_string() {
296        Ok(s) => s,
297        Err(e) => {
298            return vec![Err(LlmError::Transport(format!(
299                "reading streaming response body: {e}"
300            )))]
301        }
302    };
303
304    body_text
305        .lines()
306        .filter(|line| !line.trim().is_empty())
307        .map(|line| {
308            let parsed: StreamLine = serde_json::from_str(line)
309                .map_err(|e| LlmError::Parse(format!("ollama stream line parse: {e}")))?;
310            Ok(StreamChunk {
311                delta: parsed.message.content,
312                finish_reason: if parsed.done {
313                    parsed.done_reason
314                } else {
315                    None
316                },
317            })
318        })
319        .collect()
320}
321
322/// Map a `ureq` error to an [`LlmError`] variant.
323fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
324    match err {
325        ureq::Error::Transport(t) => {
326            // ureq surfaces timeout as a Transport error whose `kind()` is
327            // `Io` and whose inner source is a `TimedOut` OS error.
328            let msg = t.to_string();
329            if is_timeout_message(&msg) {
330                LlmError::Timeout { timeout_ms }
331            } else {
332                LlmError::Transport(msg)
333            }
334        }
335        ureq::Error::Status(code, _) => LlmError::Transport(format!("HTTP {code}")),
336    }
337}
338
/// Heuristic: does the transport error message look like a timeout?
///
/// Matches, ASCII-case-insensitively, any of "timed out",
/// "deadline exceeded", or "timeout" anywhere in `msg`.
fn is_timeout_message(msg: &str) -> bool {
    const MARKERS: [&str; 3] = ["timed out", "deadline exceeded", "timeout"];
    let haystack = msg.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
344
345// ---------------------------------------------------------------------------
346// Role serialization helper
347// ---------------------------------------------------------------------------
348
349use crate::adapter::LlmRole;
350
351impl LlmRole {
352    /// Return the lowercase string representation used by Ollama's API.
353    fn as_str(self) -> &'static str {
354        match self {
355            LlmRole::User => "user",
356            LlmRole::Assistant => "assistant",
357            LlmRole::Tool => "tool",
358        }
359    }
360}