cortex_llm/adapter.rs

//! `LlmAdapter` trait and the request / response / error types it exchanges.
//!
//! This module defines the single shape that every LLM backend in Cortex
//! implements — Claude, Ollama, and the deterministic
//! [`crate::replay::ReplayAdapter`] used in CI. The contract is the one
//! frozen in [BUILD_SPEC §12](../../docs/BUILD_SPEC.md): a single async
//! `complete` entry point, a request struct that fully describes the call,
//! and a response struct that carries the reply text, optionally-parsed JSON,
//! the model name that actually answered, token usage, and a stable byte-hash
//! of the raw response (`raw_hash`).
//!
//! `LlmAdapter` is `Send + Sync` so it can live behind an
//! `Arc<dyn LlmAdapter>` shared across the agent runtime.
//!
//! ## Example
//!
//! Implementations live in adapter-specific modules; here is a minimal
//! implementation against the canonical request / response surface:
//!
//! ```rust
//! use async_trait::async_trait;
//! use cortex_llm::adapter::{LlmAdapter, LlmError, LlmRequest, LlmResponse};
//!
//! struct EchoAdapter;
//!
//! #[async_trait]
//! impl LlmAdapter for EchoAdapter {
//!     fn adapter_id(&self) -> &'static str {
//!         "echo"
//!     }
//!
//!     async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
//!         let text = req.messages.last().map(|m| m.content.clone()).unwrap_or_default();
//!         Ok(LlmResponse {
//!             text: text.clone(),
//!             parsed_json: None,
//!             model: req.model,
//!             usage: None,
//!             raw_hash: cortex_llm::adapter::blake3_hex(text.as_bytes()),
//!         })
//!     }
//! }
//! ```

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Role of a chat message in an [`LlmRequest`].
///
/// Mirrors the OpenAI / Anthropic role taxonomy so adapters can pass values
/// through to upstream APIs without translation.
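///
/// The `lowercase` rename below means the serialized wire form matches the
/// provider taxonomy directly; a quick sketch:
///
/// ```rust
/// use cortex_llm::adapter::LlmRole;
///
/// assert_eq!(serde_json::to_string(&LlmRole::Assistant).unwrap(), "\"assistant\"");
/// ```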
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRole {
    /// Plain user prompt.
    User,
    /// Model reply (used to seed multi-turn replay fixtures).
    Assistant,
    /// Tool / function call result fed back to the model.
    Tool,
}

/// One message in the request's conversation history.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LlmMessage {
    /// Speaker role.
    pub role: LlmRole,
    /// UTF-8 text body. The adapter MAY reject non-UTF-8 upstream payloads
    /// before they reach this surface.
    pub content: String,
}

/// Optional token usage echoed back from the provider.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (system + messages).
    pub prompt_tokens: u32,
    /// Tokens emitted in the completion text.
    pub completion_tokens: u32,
}

/// A single LLM call.
///
/// Field shape is frozen by BUILD_SPEC §12. Adapters MUST NOT reorder, rename,
/// or hide fields without bumping [`cortex_core::SCHEMA_VERSION`] — the
/// request is part of the audit envelope downstream.
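///
/// A minimal construction sketch (field values are illustrative only):
///
/// ```rust
/// use cortex_llm::adapter::{LlmMessage, LlmRequest, LlmRole};
///
/// let req = LlmRequest {
///     model: "claude-3-5-sonnet-20240620".into(),
///     system: "be precise".into(),
///     messages: vec![LlmMessage { role: LlmRole::User, content: "hello".into() }],
///     temperature: 0.0,
///     max_tokens: 256,
///     json_schema: None,
///     timeout_ms: 30_000,
/// };
/// assert_eq!(req.messages.len(), 1);
/// ```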
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Provider-specific model identifier, e.g. `claude-3-5-sonnet-20240620`.
    pub model: String,
    /// System prompt; may be empty.
    pub system: String,
    /// Conversation history in chronological order.
    pub messages: Vec<LlmMessage>,
    /// Sampling temperature; passed through verbatim.
    pub temperature: f32,
    /// Hard cap on completion length.
    pub max_tokens: u32,
    /// Optional JSON Schema constraint applied to the response (provider may
    /// enforce or just inject as guidance).
    pub json_schema: Option<serde_json::Value>,
    /// Wallclock budget in milliseconds.
    pub timeout_ms: u64,
}

impl LlmRequest {
    /// Stable BLAKE3 hash over the canonical fields used for fixture matching:
    /// `(system, messages, temperature, max_tokens, json_schema)`.
    ///
    /// **Excludes `model` and `timeout_ms`** so the same prompt can be
    /// replayed against multiple models (matching is by `(model, prompt_hash)`
    /// pair) and so transient budget tweaks do not invalidate fixtures.
    ///
    /// The hash is computed over the serialized JSON of a dedicated helper
    /// struct, so struct-declaration field order is the deterministic input.
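    ///
    /// A minimal sketch of the invariance (values are illustrative):
    ///
    /// ```rust
    /// use cortex_llm::adapter::{LlmMessage, LlmRequest, LlmRole};
    ///
    /// let a = LlmRequest {
    ///     model: "model-a".into(),
    ///     system: String::new(),
    ///     messages: vec![LlmMessage { role: LlmRole::User, content: "hi".into() }],
    ///     temperature: 0.0,
    ///     max_tokens: 16,
    ///     json_schema: None,
    ///     timeout_ms: 30_000,
    /// };
    /// // Different model and budget, same participating fields:
    /// let b = LlmRequest { model: "model-b".into(), timeout_ms: 1, ..a.clone() };
    /// assert_eq!(a.prompt_hash(), b.prompt_hash());
    /// ```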
    pub fn prompt_hash(&self) -> String {
        let canonical = CanonicalPrompt {
            system: &self.system,
            messages: &self.messages,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            json_schema: self.json_schema.as_ref(),
        };
        // serde_json writes object fields in struct-declaration order, which
        // is what we want — no `BTreeMap` shuffling.
        let bytes = serde_json::to_vec(&canonical).expect("CanonicalPrompt is always serializable");
        blake3_hex(&bytes)
    }
}

/// Internal helper: the subset of [`LlmRequest`] that participates in
/// `prompt_hash()`. Holds only borrowed references so the payload is never
/// copied.
#[derive(Serialize)]
struct CanonicalPrompt<'a> {
    system: &'a str,
    messages: &'a [LlmMessage],
    temperature: f32,
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<&'a serde_json::Value>,
}

/// Result of a successful LLM call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Free-form text reply; equals the stringified `parsed_json` value for
    /// JSON-only adapters that do not echo a separate text form.
    pub text: String,
    /// Parsed JSON, when the request supplied a `json_schema` and the
    /// provider returned a parseable body.
    pub parsed_json: Option<serde_json::Value>,
    /// Model that actually produced the response (may differ from
    /// `LlmRequest::model` when the provider transparently routes).
    pub model: String,
    /// Token accounting from the provider, if reported.
    pub usage: Option<TokenUsage>,
    /// Lowercase hex BLAKE3 of the raw provider bytes (or the fixture's
    /// response payload, in the replay case). Fed straight into the audit
    /// envelope downstream.
    pub raw_hash: String,
}

/// Errors returned by any [`LlmAdapter`] implementation.
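///
/// A sketch of how a caller might split these into retryable and terminal
/// classes; the classification itself is hypothetical, not part of the
/// adapter contract:
///
/// ```rust
/// use cortex_llm::adapter::LlmError;
///
/// // Hypothetical retry classifier:
/// fn is_retryable(e: &LlmError) -> bool {
///     matches!(e, LlmError::Transport(_) | LlmError::Timeout { .. })
/// }
///
/// assert!(is_retryable(&LlmError::Timeout { timeout_ms: 10 }));
/// assert!(!is_retryable(&LlmError::Parse("bad body".into())));
/// ```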
#[derive(Debug, Error)]
pub enum LlmError {
    /// Network or transport error talking to the provider.
    #[error("transport: {0}")]
    Transport(String),

    /// Upstream provider returned a non-success status with the given message.
    #[error("upstream: {0}")]
    Upstream(String),

    /// Request violated provider or local validation (bad shape, banned
    /// content, schema mismatch).
    #[error("invalid request: {0}")]
    InvalidRequest(String),

    /// The request exceeded its `timeout_ms` budget.
    #[error("timeout after {timeout_ms} ms")]
    Timeout {
        /// The original budget the caller supplied.
        timeout_ms: u64,
    },

    /// The response could not be parsed as the requested JSON schema.
    #[error("response parse: {0}")]
    Parse(String),

    /// Replay adapter could not find a fixture matching `(model, prompt_hash)`.
    #[error("no replay fixture for model={model} prompt_hash={prompt_hash}")]
    NoFixture {
        /// Model the request asked for.
        model: String,
        /// `LlmRequest::prompt_hash` value the lookup used.
        prompt_hash: String,
    },

    /// Replay adapter found a fixture but its on-disk bytes did not match the
    /// hash recorded in `INDEX.toml`. Maps to the CLI's
    /// `Exit::QuarantinedInput(5)` exit code in lane 1.C.
    ///
    /// See `THREATS.md` row T-RM-1 for the threat model rationale.
    #[error("fixture integrity failed: {0}")]
    FixtureIntegrityFailed(String),

    /// I/O failure reading fixtures, `INDEX.toml`, or any other adapter-local
    /// resource.
    #[error("io: {0}")]
    Io(String),
}

/// A single token delta emitted by a streaming LLM call.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Token delta — may be empty on the final chunk.
    pub delta: String,
    /// Set on the terminal chunk: `"stop"`, `"max_tokens"`, etc.
    pub finish_reason: Option<String>,
}

/// Boxed, type-erased stream of [`StreamChunk`] items, as returned by the
/// object-safe [`LlmAdapter::stream_boxed`] entry point.
pub type BoxStream<'a> = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send + 'a>>;

/// The shared LLM surface.
///
/// **Send + Sync is required** so the runtime can hold an
/// `Arc<dyn LlmAdapter>` and dispatch from any async task.
///
/// Streaming support is provided through two entry points:
///
/// - [`LlmAdapter::stream`] — ergonomic `impl Stream` variant; only callable on
///   a concrete type (`where Self: Sized`), not through `dyn LlmAdapter`.
/// - [`LlmAdapter::stream_boxed`] — object-safe `Pin<Box<dyn Stream>>` variant,
///   callable on `&dyn LlmAdapter`. Default implementation delegates to
///   [`LlmAdapter::complete`] and yields the full response as a single chunk;
///   adapters override it for true line-by-line streaming.
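///
/// A sketch of driving the object-safe variant from behind a trait object
/// (the `drive` helper is hypothetical):
///
/// ```rust
/// use cortex_llm::adapter::{LlmAdapter, LlmError, LlmRequest};
/// use futures::StreamExt;
///
/// async fn drive(adapter: &dyn LlmAdapter, req: LlmRequest) -> Result<String, LlmError> {
///     let mut out = String::new();
///     let mut stream = adapter.stream_boxed(req);
///     while let Some(chunk) = stream.next().await {
///         out.push_str(&chunk?.delta);
///     }
///     Ok(out)
/// }
/// ```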
#[async_trait]
pub trait LlmAdapter: Send + Sync {
    /// Stable, lowercase identifier used in audit envelopes (e.g. `"claude"`,
    /// `"ollama"`, `"replay"`). Constant per adapter — implementations MUST
    /// NOT vary it per call.
    fn adapter_id(&self) -> &'static str;

    /// Issue a completion call. The adapter is responsible for honouring
    /// `req.timeout_ms` and returning [`LlmError::Timeout`] when exceeded.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError>;

    /// Stream tokens incrementally. The default implementation delegates to
    /// [`stream_boxed`], so it yields the full [`complete`] response as a
    /// single chunk — adapters override this for true streaming.
    ///
    /// Bound to `where Self: Sized` so this method is not required to be
    /// object-safe. Use [`stream_boxed`] when dispatching through
    /// `&dyn LlmAdapter`.
    ///
    /// [`complete`]: LlmAdapter::complete
    /// [`stream_boxed`]: LlmAdapter::stream_boxed
    fn stream(&self, req: LlmRequest) -> impl Stream<Item = Result<StreamChunk, LlmError>> + Send
    where
        Self: Sized,
    {
        self.stream_boxed(req)
    }

    /// Object-safe streaming entry point. Returns a heap-allocated
    /// `Pin<Box<dyn Stream>>` so callers with only a `&dyn LlmAdapter`
    /// reference can drive streaming without knowing the concrete adapter type.
    ///
    /// Default: calls [`complete`] and emits the full response as a single
    /// [`StreamChunk`] with `finish_reason = Some("stop".into())`.
    ///
    /// [`complete`]: LlmAdapter::complete
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        // Default implementation: wrap the complete() result as one chunk.
        // Adapters that can stream natively override this method.
        let fut = self.complete(req);
        Box::pin(async_stream::stream! {
            match fut.await {
                Ok(resp) => {
                    yield Ok(StreamChunk {
                        delta: resp.text,
                        finish_reason: Some("stop".into()),
                    });
                }
                Err(e) => yield Err(e),
            }
        })
    }
}

/// Lowercase hex BLAKE3 of the given bytes.
///
/// Re-exported so adapters and tests can derive `raw_hash` consistently.
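///
/// A quick determinism check (BLAKE3 output is 32 bytes, so 64 hex chars):
///
/// ```rust
/// use cortex_llm::adapter::blake3_hex;
///
/// assert_eq!(blake3_hex(b"abc"), blake3_hex(b"abc"));
/// assert_eq!(blake3_hex(b"abc").len(), 64);
/// ```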
#[must_use]
pub fn blake3_hex(bytes: &[u8]) -> String {
    blake3::hash(bytes).to_hex().to_string()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn req_for(messages: &[(LlmRole, &str)]) -> LlmRequest {
        LlmRequest {
            model: "test-model".into(),
            system: "be precise".into(),
            messages: messages
                .iter()
                .map(|(r, c)| LlmMessage {
                    role: *r,
                    content: (*c).to_string(),
                })
                .collect(),
            temperature: 0.0,
            max_tokens: 256,
            json_schema: None,
            timeout_ms: 30_000,
        }
    }

    #[test]
    fn prompt_hash_is_stable_across_calls() {
        let r = req_for(&[(LlmRole::User, "hello")]);
        assert_eq!(r.prompt_hash(), r.prompt_hash());
    }

    #[test]
    fn prompt_hash_ignores_model_and_timeout() {
        let mut a = req_for(&[(LlmRole::User, "hello")]);
        let mut b = a.clone();
        b.model = "other-model".into();
        b.timeout_ms = 1;
        assert_eq!(a.prompt_hash(), b.prompt_hash());

        // Sanity: changing a participating field DOES change the hash.
        a.temperature = 0.5;
        assert_ne!(a.prompt_hash(), b.prompt_hash());
    }

    #[test]
    fn prompt_hash_changes_with_message_content() {
        let a = req_for(&[(LlmRole::User, "hello")]);
        let b = req_for(&[(LlmRole::User, "world")]);
        assert_ne!(a.prompt_hash(), b.prompt_hash());
    }

    #[test]
    fn blake3_hex_is_64_chars() {
        assert_eq!(blake3_hex(b"abc").len(), 64);
    }
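
    // A sketch of a regression test for the default `stream_boxed` path. It
    // assumes the `futures` crate's default `executor` feature is available
    // for `block_on`; substitute your async runtime's executor otherwise.
    #[test]
    fn default_stream_boxed_yields_single_stop_chunk() {
        use futures::StreamExt;

        struct Fixed;

        #[async_trait]
        impl LlmAdapter for Fixed {
            fn adapter_id(&self) -> &'static str {
                "fixed"
            }

            async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
                Ok(LlmResponse {
                    text: "ok".into(),
                    parsed_json: None,
                    model: req.model,
                    usage: None,
                    raw_hash: blake3_hex(b"ok"),
                })
            }
        }

        let adapter: &dyn LlmAdapter = &Fixed;
        let chunks: Vec<_> = futures::executor::block_on(
            adapter
                .stream_boxed(req_for(&[(LlmRole::User, "hi")]))
                .collect(),
        );

        assert_eq!(chunks.len(), 1);
        let chunk = chunks[0].as_ref().expect("default stream should succeed");
        assert_eq!(chunk.delta, "ok");
        assert_eq!(chunk.finish_reason.as_deref(), Some("stop"));
    }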
}