crtx-llm 0.1.1

Claude, Ollama, and replay adapters behind a shared trait.
Documentation
//! HTTP adapter that posts to a local Ollama `/api/chat` endpoint.
//!
//! [`OllamaHttpAdapter`] implements [`LlmAdapter`] by forwarding requests to
//! the Ollama REST API. Because `ureq` is synchronous, the blocking I/O is
//! wrapped with `tokio::task::spawn_blocking` so the adapter can satisfy the
//! async trait contract without blocking the async executor.
//!
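//! The adapter enforces the same loopback-only and digest-pinned-model
//! invariants as [`crate::ollama::validate_config`]: construction fails for
//! non-loopback endpoints, and every call fails if the *configured* model ref
//! is not digest-pinned. The `model` field of an incoming request is ignored;
//! the adapter always sends its configured model.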

use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, StreamChunk,
};
use crate::ollama::{validate_endpoint_url, validate_model_ref, OllamaConfig};

/// HTTP adapter that routes to a local Ollama instance via `/api/chat`.
///
/// Holds the [`OllamaConfig`] it was built with. The endpoint is checked for
/// the loopback-only constraint once, in [`OllamaHttpAdapter::new`]; the
/// configured model reference is checked on every call.
#[derive(Debug, Clone)]
pub struct OllamaHttpAdapter {
    config: OllamaConfig,
}

impl OllamaHttpAdapter {
    /// Build an adapter from `config`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`validate_endpoint_url`] for
    /// `config.endpoint_url` ([`LlmError::InvalidRequest`] when the endpoint
    /// does not satisfy the loopback-only constraint).
    ///
    /// The model reference is *not* checked here. It is validated on each
    /// [`LlmAdapter::complete`] / [`LlmAdapter::stream_boxed`] call, which
    /// always use `config.model` rather than the request's `model` field.
    pub fn new(config: OllamaConfig) -> Result<Self, LlmError> {
        // `?` propagates the validator's `LlmError` unchanged.
        validate_endpoint_url(&config.endpoint_url)?;
        Ok(Self { config })
    }
}

// ---------------------------------------------------------------------------
// Wire types
// ---------------------------------------------------------------------------

/// JSON body sent with `POST /api/chat`.
///
/// Fields are serialized in declaration order: `model`, `messages`, `stream`.
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
    /// Ollama model name. Callers in this module pass the `name:tag` form with
    /// any `@sha256:<digest>` suffix already removed.
    model: &'a str,
    /// Conversation turns, borrowed from the originating request.
    messages: Vec<OllamaMessage<'a>>,
    /// `true` asks Ollama for newline-delimited JSON chunks; `false` for a
    /// single JSON object.
    stream: bool,
}

/// A single chat turn as Ollama expects it: `{"role": ..., "content": ...}`.
#[derive(Debug, Serialize)]
struct OllamaMessage<'a> {
    /// Lowercase role name (`"user"`, `"assistant"`, or `"tool"`).
    role: &'a str,
    /// Message text, borrowed from the caller's request.
    content: &'a str,
}

/// Envelope of a non-streaming `/api/chat` reply.
///
/// Only `message` is read; any other keys in the reply are ignored.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Assistant message; an empty one is substituted when the key is absent.
    #[serde(default)]
    message: MessageField,
}

/// Contents of the `message` object in a chat reply or stream line.
///
/// Only `content` is read; `role` and any other keys are ignored.
#[derive(Debug, Default, Deserialize)]
struct MessageField {
    /// Assistant text; the empty string when the key is absent.
    #[serde(default)]
    content: String,
}

/// One NDJSON record from Ollama's streaming `/api/chat` response.
///
/// Each non-empty line of the body is a record such as:
/// ```json
/// {"message":{"role":"assistant","content":"Hello"},"done":false}
/// {"message":{"role":"assistant","content":""},"done":true,"done_reason":"stop"}
/// ```
/// Missing `message` / `done` keys fall back to their defaults; unknown keys
/// are ignored.
#[derive(Debug, Deserialize)]
struct StreamLine {
    /// Incremental assistant text for this record.
    #[serde(default)]
    message: MessageField,
    /// `true` on the terminal record; `false` when the key is absent.
    #[serde(default)]
    done: bool,
    /// Why generation stopped (e.g. `"stop"`); expected on the terminal record.
    done_reason: Option<String>,
}

// ---------------------------------------------------------------------------
// LlmAdapter implementation
// ---------------------------------------------------------------------------

#[async_trait]
impl LlmAdapter for OllamaHttpAdapter {
    fn adapter_id(&self) -> &'static str {
        "ollama"
    }

    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        // Use the adapter's configured model, not req.model. req.model is set
        // by the caller (e.g. cortex_reflect uses DEFAULT_REFLECTION_MODEL =
        // "replay-reflection-v1" so the ReplayAdapter can look up fixtures).
        // The Ollama adapter always drives its own pinned model.
        validate_model_ref(&self.config.model)?;
        let req = LlmRequest { model: self.config.model.clone(), ..req };

        let config = self.config.clone();
        let timeout_ms = req.timeout_ms;

        let result = tokio::task::spawn_blocking(move || call_ollama(&config, &req, timeout_ms))
            .await
            .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;

        result
    }

    /// Override with true Ollama streaming via newline-delimited JSON.
    ///
    /// Uses `ureq` (synchronous) inside `spawn_blocking`. Because
    /// `spawn_blocking` requires the entire blocking work to complete before
    /// returning, all stream lines are collected into a `Vec` before being
    /// yielded. This means backpressure and incremental display require the
    /// full response to arrive first.
    ///
    /// TODO: replace `ureq` with an async HTTP client (e.g. `reqwest`) and
    /// drive the response body with `tokio::io::AsyncBufReadExt` to achieve
    /// true line-by-line streaming without buffering the entire response.
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        // Same model-override pattern as complete(): use the adapter's pinned
        // model, not whatever placeholder the caller put in req.model.
        let req = LlmRequest { model: self.config.model.clone(), ..req };
        validate_model_ref_and_stream(self.config.clone(), req)
    }
}

/// Synchronous, non-streaming Ollama call, executed inside `spawn_blocking`.
///
/// The call proceeds as follows:
/// 1. Posts `req.messages` to `{config.endpoint_url}/api/chat` with
///    `stream: false`, using `timeout_ms` as the request timeout.
/// 2. Reads at most 16 MiB of the response body.
/// 3. Parses the body as a [`ChatResponse`].
///
/// The returned [`LlmResponse`] carries the assistant text, `config.model`,
/// and a blake3 hash of the response body text.
///
/// # Errors
///
/// - [`LlmError::Transport`] for serialization failures, transport or HTTP
///   status errors, any non-200 status, body read failures, and bodies over
///   16 MiB.
/// - [`LlmError::Timeout`] when the transport error looks like a timeout.
/// - [`LlmError::Parse`] when the body is not a valid chat response.
fn call_ollama(
    config: &OllamaConfig,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
    use std::io::Read;

    /// Largest response body this adapter will buffer, hash, and store.
    const MAX_RESPONSE_BYTES: usize = 16 * 1024 * 1024; // 16 MiB

    let url = format!("{}/api/chat", config.endpoint_url);

    let messages: Vec<OllamaMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OllamaMessage {
            role: m.role.as_str(),
            content: &m.content,
        })
        .collect();

    // Ollama API uses "name:tag" format; strip "@sha256:<digest>" if present.
    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);

    let body = ChatRequest {
        model: ollama_model,
        messages,
        stream: false,
    };
    let body_value = serde_json::to_value(&body)
        .map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?;

    let agent = ureq::AgentBuilder::new()
        .timeout(Duration::from_millis(timeout_ms))
        .build();

    let raw_response = agent
        .post(&url)
        .send_json(body_value)
        .map_err(|err| map_ureq_error(err, timeout_ms))?;

    let status = raw_response.status();
    if status != 200 {
        return Err(LlmError::Transport(format!("HTTP {status}")));
    }

    // `ureq::Response::into_string` rejects bodies over 10 MiB, which made a
    // 16 MiB check after it unreachable. Read through a limited reader instead.
    // One byte past the limit is enough to detect an oversized body without
    // buffering all of it.
    let mut body_bytes = Vec::new();
    raw_response
        .into_reader()
        .take(MAX_RESPONSE_BYTES as u64 + 1)
        .read_to_end(&mut body_bytes)
        .map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;

    if body_bytes.len() > MAX_RESPONSE_BYTES {
        return Err(LlmError::Transport(format!(
            "ollama response body exceeds 16 MiB limit (> {MAX_RESPONSE_BYTES} bytes); refusing to store"
        )));
    }

    // Decode lossily, as `into_string` does without a charset feature, so
    // well-formed UTF-8 bodies produce the same text (and hash) as before.
    let response_text = String::from_utf8_lossy(&body_bytes).into_owned();

    let parsed: ChatResponse = serde_json::from_str(&response_text)
        .map_err(|e| LlmError::Parse(format!("ollama response parse: {e}")))?;

    let raw_hash = blake3_hex(response_text.as_bytes());

    Ok(LlmResponse {
        text: parsed.message.content,
        parsed_json: None,
        model: config.model.clone(),
        usage: None,
        raw_hash,
    })
}

/// Produce the boxed chunk stream used by `stream_boxed`.
///
/// The stream first checks `config.model` with `validate_model_ref`. On
/// failure it yields that single error and ends. Otherwise it runs the
/// blocking streaming call on tokio's blocking pool and yields each collected
/// item in order. A join failure becomes one `LlmError::Transport` item.
///
/// This is a free function so that the `async_stream::stream!` invocation
/// lives outside the `impl` block and the returned stream owns its inputs
/// (`'static`).
fn validate_model_ref_and_stream(config: OllamaConfig, req: LlmRequest) -> BoxStream<'static> {
    let chunk_stream = async_stream::stream! {
        // `stream_boxed` has already copied config.model into req.model, so
        // checking the configured model covers what will be sent.
        if let Err(rejected) = validate_model_ref(&config.model) {
            yield Err(rejected);
            return;
        }

        let deadline_ms = req.timeout_ms;
        let joined = tokio::task::spawn_blocking(move || {
            call_ollama_streaming(&config, &req, deadline_ms)
        })
        .await;

        let items = match joined {
            Ok(collected) => collected,
            Err(join_err) => vec![Err(LlmError::Transport(format!(
                "spawn_blocking join error: {join_err}"
            )))],
        };

        for item in items {
            yield item;
        }
    };

    Box::pin(chunk_stream)
}

/// Synchronous Ollama streaming call, executed inside `spawn_blocking`.
///
/// Posts to `/api/chat` with `stream: true`, then reads the response body
/// line by line. Each non-empty line is parsed as a [`StreamLine`] and
/// converted to a [`StreamChunk`]. The complete `Vec` is returned so the
/// caller's `async_stream::stream!` block can yield items incrementally.
fn call_ollama_streaming(
    config: &OllamaConfig,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
    let url = format!("{}/api/chat", config.endpoint_url);

    let messages: Vec<OllamaMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OllamaMessage {
            role: m.role.as_str(),
            content: &m.content,
        })
        .collect();

    // Strip "@sha256:<digest>" for Ollama's "name:tag" API format.
    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);

    let body = ChatRequest {
        model: ollama_model,
        messages,
        stream: true,
    };

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let body_value = match serde_json::to_value(&body) {
        Ok(v) => v,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "request serialization failed: {e}"
            )))]
        }
    };

    let raw_response = match agent.post(&url).send_json(body_value) {
        Ok(r) => r,
        Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
    };

    let status = raw_response.status();
    if status != 200 {
        return vec![Err(LlmError::Transport(format!("HTTP {status}")))];
    }

    let body_text = match raw_response.into_string() {
        Ok(s) => s,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "reading streaming response body: {e}"
            )))]
        }
    };

    body_text
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| {
            let parsed: StreamLine = serde_json::from_str(line)
                .map_err(|e| LlmError::Parse(format!("ollama stream line parse: {e}")))?;
            Ok(StreamChunk {
                delta: parsed.message.content,
                finish_reason: if parsed.done {
                    parsed.done_reason
                } else {
                    None
                },
            })
        })
        .collect()
}

/// Map a `ureq` error to an [`LlmError`] variant.
fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
    match err {
        ureq::Error::Transport(t) => {
            // ureq surfaces timeout as a Transport error whose `kind()` is
            // `Io` and whose inner source is a `TimedOut` OS error.
            let msg = t.to_string();
            if is_timeout_message(&msg) {
                LlmError::Timeout { timeout_ms }
            } else {
                LlmError::Transport(msg)
            }
        }
        ureq::Error::Status(code, _) => LlmError::Transport(format!("HTTP {code}")),
    }
}

/// Heuristically decide whether a transport error message describes a timeout.
///
/// The comparison is ASCII case-insensitive. It returns `true` when the message
/// contains any of `"timed out"`, `"deadline exceeded"`, or `"timeout"`.
fn is_timeout_message(msg: &str) -> bool {
    const TIMEOUT_MARKERS: [&str; 3] = ["timed out", "deadline exceeded", "timeout"];
    let normalized = msg.to_ascii_lowercase();
    TIMEOUT_MARKERS
        .iter()
        .any(|&marker| normalized.contains(marker))
}

// ---------------------------------------------------------------------------
// Role serialization helper
// ---------------------------------------------------------------------------

use crate::adapter::LlmRole;

impl LlmRole {
    /// Return the lowercase string representation used by Ollama's API.
    fn as_str(self) -> &'static str {
        match self {
            LlmRole::User => "user",
            LlmRole::Assistant => "assistant",
            LlmRole::Tool => "tool",
        }
    }
}