crtx-llm 0.1.1

Claude, Ollama, and replay adapters behind a shared trait.
Documentation
//! `LlmAdapter` trait and the request / response / error types it exchanges.
//!
//! This module is the single shape that every LLM backend in Cortex implements
//! — Claude, Ollama, and the deterministic [`crate::replay::ReplayAdapter`]
//! used in CI. The contract is the one frozen in
//! [BUILD_SPEC §12](../../docs/BUILD_SPEC.md): a single async `complete`
//! entry point, a request struct that fully describes the call, and a
//! response struct that returns text, optionally-parsed JSON, the model name
//! that actually answered, token usage, and a stable byte-hash of the raw
//! response (`raw_hash`). The streaming methods (`stream` / `stream_boxed`)
//! are layered on top of `complete`: their default implementations call it
//! and emit the whole reply as a single chunk.
//!
//! `LlmAdapter` is `Send + Sync` so it can live behind an
//! `Arc<dyn LlmAdapter>` shared across the agent runtime.
//!
//! ## Example
//!
//! Implementations live in adapter-specific modules; here is the trait and
//! the request / response surface in their canonical form:
//!
//! ```rust
//! use async_trait::async_trait;
//! use cortex_llm::adapter::{LlmAdapter, LlmError, LlmRequest, LlmResponse};
//!
//! struct EchoAdapter;
//!
//! #[async_trait]
//! impl LlmAdapter for EchoAdapter {
//!     fn adapter_id(&self) -> &'static str {
//!         "echo"
//!     }
//!
//!     async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
//!         let text = req.messages.last().map(|m| m.content.clone()).unwrap_or_default();
//!         Ok(LlmResponse {
//!             text: text.clone(),
//!             parsed_json: None,
//!             model: req.model,
//!             usage: None,
//!             raw_hash: cortex_llm::adapter::blake3_hex(text.as_bytes()),
//!         })
//!     }
//! }
//! ```

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Role of a chat message in an [`LlmRequest`].
///
/// Mirrors the OpenAI / Anthropic role taxonomy so adapters can pass values
/// through to upstream APIs without translation. Because of
/// `rename_all = "lowercase"`, the serde form of each variant is its
/// lowercase name (`"user"`, `"assistant"`, `"tool"`).
///
/// There is no `System` variant: the system prompt travels separately in
/// [`LlmRequest::system`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRole {
    /// Plain user prompt.
    User,
    /// Model reply (used to seed multi-turn replay fixtures).
    Assistant,
    /// Tool / function call result fed back to the model.
    Tool,
}

/// One message in the request's conversation history.
///
/// Messages are carried in [`LlmRequest::messages`] and are serialized into
/// the input of [`LlmRequest::prompt_hash`], so both `role` and `content`
/// influence the resulting hash.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LlmMessage {
    /// Speaker role.
    pub role: LlmRole,
    /// UTF-8 text body. The adapter MAY reject non-UTF-8 upstream payloads
    /// before they reach this surface.
    pub content: String,
}

/// Token usage echoed back from the provider.
///
/// Both counters are always present on this struct; the "optional" part is
/// expressed by [`LlmResponse::usage`] being an `Option<TokenUsage>`, which
/// is `None` when the provider reports nothing.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (system + messages).
    pub prompt_tokens: u32,
    /// Tokens emitted in the completion text.
    pub completion_tokens: u32,
}

/// A single LLM call.
///
/// Field shape is frozen by BUILD_SPEC §12. Adapters MUST NOT reorder, rename,
/// or hide fields without bumping
/// [`cortex_core::SCHEMA_VERSION`] — the request
/// is part of the audit envelope downstream.
///
/// Only `system`, `messages`, `temperature`, `max_tokens`, and `json_schema`
/// feed [`LlmRequest::prompt_hash`]; `model` and `timeout_ms` are deliberately
/// left out of that hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Provider-specific model identifier, e.g. `claude-3-5-sonnet-20240620`.
    /// Not part of `prompt_hash()`.
    pub model: String,
    /// System prompt; may be empty.
    pub system: String,
    /// Conversation history in chronological order.
    pub messages: Vec<LlmMessage>,
    /// Sampling temperature; passed through verbatim.
    pub temperature: f32,
    /// Hard cap on completion length.
    pub max_tokens: u32,
    /// Optional JSON Schema constraint applied to the response (provider may
    /// enforce or just inject as guidance).
    pub json_schema: Option<serde_json::Value>,
    /// Wallclock budget in milliseconds. Adapters are expected to honour it
    /// and report [`LlmError::Timeout`] when it is exceeded. Not part of
    /// `prompt_hash()`.
    pub timeout_ms: u64,
}

impl LlmRequest {
    /// Fingerprint of the prompt, used to match replay fixtures.
    ///
    /// The digest is the lowercase-hex BLAKE3 (via [`blake3_hex`]) of the JSON
    /// encoding of `(system, messages, temperature, max_tokens, json_schema)`,
    /// written in exactly that order. A `None` schema is omitted from the
    /// JSON altogether rather than encoded as `null`.
    ///
    /// `model` and `timeout_ms` are **not** hashed: the same prompt can then
    /// be replayed against several models (lookups key on the
    /// `(model, prompt_hash)` pair), and adjusting the time budget does not
    /// invalidate existing fixtures.
    ///
    /// # Panics
    ///
    /// Panics only if `serde_json` fails to serialize the canonical view,
    /// which is treated as an invariant violation.
    pub fn prompt_hash(&self) -> String {
        // Pull out just the participating fields; `..` drops `model` and
        // `timeout_ms` on purpose.
        let Self {
            system,
            messages,
            temperature,
            max_tokens,
            json_schema,
            ..
        } = self;

        // Field order in the serialized JSON follows the declaration order of
        // `CanonicalPrompt`, which keeps the hashed bytes deterministic.
        let encoded = serde_json::to_vec(&CanonicalPrompt {
            system: system.as_str(),
            messages: messages.as_slice(),
            temperature: *temperature,
            max_tokens: *max_tokens,
            json_schema: json_schema.as_ref(),
        })
        .expect("CanonicalPrompt is always serializable");

        blake3_hex(&encoded)
    }
}

/// Internal helper: the subset of [`LlmRequest`] that participates in
/// `prompt_hash()`. Borrowed-references-only so we never copy the payload.
///
/// The declaration order of the fields below is the order in which they are
/// written to JSON, and therefore part of the hash input: reordering,
/// renaming, or adding a field changes every `prompt_hash()` value.
#[derive(Serialize)]
struct CanonicalPrompt<'a> {
    system: &'a str,
    messages: &'a [LlmMessage],
    temperature: f32,
    max_tokens: u32,
    // A missing schema is left out of the JSON entirely instead of being
    // emitted as `"json_schema": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<&'a serde_json::Value>,
}

/// Result of a successful LLM call.
///
/// Returned by [`LlmAdapter::complete`]. The default
/// [`LlmAdapter::stream_boxed`] implementation forwards only `text` (as the
/// single chunk's `delta`); the other fields are not surfaced by that path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Free-form text reply; equals `parsed_json.to_string()` for JSON-only
    /// adapters that do not echo a separate stringified form.
    pub text: String,
    /// Parsed JSON, when the request supplied a `json_schema` and the
    /// provider returned a parseable body.
    pub parsed_json: Option<serde_json::Value>,
    /// Model that actually produced the response (may differ from
    /// `LlmRequest::model` when the provider transparently routes).
    pub model: String,
    /// Token accounting from the provider, if reported.
    pub usage: Option<TokenUsage>,
    /// Lowercase hex BLAKE3 of the raw provider bytes (or the
    /// fixture's response payload, in the replay case). Fed straight into
    /// the audit envelope downstream. [`blake3_hex`] produces this format.
    pub raw_hash: String,
}

/// Errors returned by any [`LlmAdapter`] implementation.
///
/// Every variant carries owned data only (`String` or `u64`); none wraps a
/// source error, so whatever detail an adapter wants to keep has to be
/// rendered into the message it supplies. The `Display` text of each variant
/// is defined by its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Network or transport error talking to the provider.
    #[error("transport: {0}")]
    Transport(String),

    /// Upstream provider returned a non-success status with the given message.
    #[error("upstream: {0}")]
    Upstream(String),

    /// Request violated provider or local validation (bad shape, banned
    /// content, schema mismatch).
    #[error("invalid request: {0}")]
    InvalidRequest(String),

    /// The request exceeded its `timeout_ms` budget.
    #[error("timeout after {timeout_ms} ms")]
    Timeout {
        /// The original budget the caller supplied.
        timeout_ms: u64,
    },

    /// The response could not be parsed as the requested JSON schema.
    #[error("response parse: {0}")]
    Parse(String),

    /// Replay adapter could not find a fixture matching `(model, prompt_hash)`.
    #[error("no replay fixture for model={model} prompt_hash={prompt_hash}")]
    NoFixture {
        /// Model the request asked for.
        model: String,
        /// `LlmRequest::prompt_hash` value the lookup used.
        prompt_hash: String,
    },

    /// Replay adapter found a fixture but its on-disk bytes did not match the
    /// hash recorded in `INDEX.toml`. Maps to the CLI's
    /// `Exit::QuarantinedInput(5)` exit code in lane 1.C.
    ///
    /// See `THREATS.md` row T-RM-1 for the threat model rationale.
    #[error("fixture integrity failed: {0}")]
    FixtureIntegrityFailed(String),

    /// I/O failure reading fixtures, INDEX.toml, or any other adapter-local
    /// resource.
    #[error("io: {0}")]
    Io(String),
}

/// A single token delta emitted by a streaming LLM call.
///
/// This is the item type of [`BoxStream`]. When an adapter relies on the
/// default [`LlmAdapter::stream_boxed`], a successful call produces exactly
/// one chunk whose `delta` is the full response text and whose
/// `finish_reason` is `Some("stop")`.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Token delta — may be empty on the final chunk.
    pub delta: String,
    /// Set on the terminal chunk: `"stop"`, `"max_tokens"`, etc.
    pub finish_reason: Option<String>,
}

/// Object-safe boxed stream of [`StreamChunk`] items.
///
/// Each item is a `Result<StreamChunk, LlmError>`, so a failure is delivered
/// in-band as a stream element. The stream is pinned, heap-allocated, and
/// `Send`; the lifetime `'a` bounds whatever the stream borrows (the return
/// type of [`LlmAdapter::stream_boxed`] ties it to `&self`).
pub type BoxStream<'a> = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send + 'a>>;

/// The shared LLM surface implemented by every backend.
///
/// The `Send + Sync` supertraits are mandatory: they let the runtime keep an
/// `Arc<dyn LlmAdapter>` and dispatch calls from any async task.
///
/// Two streaming entry points sit next to [`LlmAdapter::complete`]:
///
/// - [`LlmAdapter::stream`] — returns `impl Stream`. It is gated by
///   `where Self: Sized`, so it is available on concrete adapter types only
///   and never through `dyn LlmAdapter`.
/// - [`LlmAdapter::stream_boxed`] — returns a [`BoxStream`] and can be called
///   on `&dyn LlmAdapter`. Unless overridden it wraps
///   [`LlmAdapter::complete`] and emits the entire response as one chunk;
///   adapters with native streaming override it.
#[async_trait]
pub trait LlmAdapter: Send + Sync {
    /// Stable, lowercase identifier recorded in audit envelopes (for example
    /// `"claude"`, `"ollama"`, `"replay"`). It is a constant — implementations
    /// MUST NOT vary it from call to call.
    fn adapter_id(&self) -> &'static str;

    /// Issue a completion call.
    ///
    /// # Errors
    ///
    /// The adapter is responsible for honouring `req.timeout_ms` and must
    /// return [`LlmError::Timeout`] once that budget is exceeded; other
    /// failures map onto the remaining [`LlmError`] variants.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError>;

    /// Stream tokens incrementally from a concrete adapter type.
    ///
    /// The default simply forwards to [`stream_boxed`], so an adapter that
    /// overrides only [`stream_boxed`] gets the same behaviour here. With both
    /// defaults in place the call ends up in [`complete`] and yields the full
    /// response as a single chunk.
    ///
    /// The `where Self: Sized` bound keeps this method out of the trait's
    /// vtable, which is what allows the `impl Stream` return type. Use
    /// [`stream_boxed`] when dispatching through `&dyn LlmAdapter`.
    ///
    /// [`complete`]: LlmAdapter::complete
    /// [`stream_boxed`]: LlmAdapter::stream_boxed
    fn stream(&self, req: LlmRequest) -> impl Stream<Item = Result<StreamChunk, LlmError>> + Send
    where
        Self: Sized,
    {
        Self::stream_boxed(self, req)
    }

    /// Object-safe streaming entry point.
    ///
    /// Returns a heap-allocated, pinned stream so that callers holding only a
    /// `&dyn LlmAdapter` can stream without knowing the concrete adapter.
    ///
    /// Default behaviour: run [`complete`] and produce exactly one item —
    /// either a [`StreamChunk`] carrying the whole response text with
    /// `finish_reason = Some("stop".into())`, or the error from [`complete`].
    /// Adapters that can stream natively override this method.
    ///
    /// [`complete`]: LlmAdapter::complete
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        // Build the completion future up front; it is first polled when the
        // returned stream is polled.
        let pending = self.complete(req);
        Box::pin(async_stream::stream! {
            // One item in either case: the response mapped to a terminal
            // chunk, or the error passed through untouched.
            yield pending.await.map(|response| StreamChunk {
                delta: response.text,
                finish_reason: Some(String::from("stop")),
            });
        })
    }
}

/// Hashes `bytes` with BLAKE3 and returns the digest as a lowercase hex string.
///
/// Public so that adapters and tests derive `raw_hash` the same way
/// [`LlmRequest::prompt_hash`] derives its digest.
#[must_use]
pub fn blake3_hex(bytes: &[u8]) -> String {
    let digest = blake3::hash(bytes);
    let hex = digest.to_hex();
    hex.as_str().to_owned()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with fixed model, system prompt, and sampling
    /// settings whose conversation is the given `(role, text)` pairs in order.
    fn req_for(messages: &[(LlmRole, &str)]) -> LlmRequest {
        let history: Vec<LlmMessage> = messages
            .iter()
            .map(|&(role, body)| LlmMessage {
                role,
                content: body.to_owned(),
            })
            .collect();

        LlmRequest {
            model: String::from("test-model"),
            system: String::from("be precise"),
            messages: history,
            temperature: 0.0,
            max_tokens: 256,
            json_schema: None,
            timeout_ms: 30_000,
        }
    }

    /// Hashing the same request twice must give the same digest.
    #[test]
    fn prompt_hash_is_stable_across_calls() {
        let request = req_for(&[(LlmRole::User, "hello")]);
        let first = request.prompt_hash();
        let second = request.prompt_hash();
        assert_eq!(first, second);
    }

    /// `model` and `timeout_ms` are excluded from the hash; `temperature`
    /// is included.
    #[test]
    fn prompt_hash_ignores_model_and_timeout() {
        let baseline = req_for(&[(LlmRole::User, "hello")]);
        let retargeted = LlmRequest {
            model: String::from("other-model"),
            timeout_ms: 1,
            ..baseline.clone()
        };
        assert_eq!(baseline.prompt_hash(), retargeted.prompt_hash());

        // Sanity: changing a participating field DOES change the hash.
        let warmer = LlmRequest {
            temperature: 0.5,
            ..baseline
        };
        assert_ne!(warmer.prompt_hash(), retargeted.prompt_hash());
    }

    /// Different message text must produce a different digest.
    #[test]
    fn prompt_hash_changes_with_message_content() {
        let hello = req_for(&[(LlmRole::User, "hello")]);
        let world = req_for(&[(LlmRole::User, "world")]);
        assert_ne!(hello.prompt_hash(), world.prompt_hash());
    }

    /// A BLAKE3 digest is 32 bytes, i.e. 64 hex characters.
    #[test]
    fn blake3_hex_is_64_chars() {
        let hex = blake3_hex(b"abc");
        assert_eq!(hex.len(), 64);
    }
}