Skip to main content

ai_memory/inference/
mod.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Pluggable inference backend trait — issue #651 (RFC pulled forward
5//! from v0.8 per operator directive `28860423-d12c-4959-bc8b-8fa9a94a33d9`,
6//! 2026-05-18).
7//!
8//! ## Goal
9//!
10//! Provide a single trait surface that unifies the substrate's two
11//! inference paths today (`embeddings::Embedder` for vector embedding,
12//! `llm::OllamaClient` for chat / auto-tag / detect-contradiction)
13//! AND provides a forward-compatible hook for the v0.8 GPU / MTP
14//! distilled hot-path backend (issues #651 / #654 / Gap #10 of #846).
15//!
16//! ## Surface
17//!
18//! ```ignore
19//! pub trait InferenceBackend: Send + Sync {
20//!     fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
21//!     fn chat(&self, prompt: &str) -> anyhow::Result<String>;
22//!     fn attested_weights(&self) -> Option<AttestedWeights>;
23//! }
24//! ```
25//!
26//! ## Backends shipped at v0.7.0
27//!
28//! - [`CpuBackend`] — wraps the existing CPU pipeline
29//!   (`embeddings::Embedder` + `llm::OllamaClient`). This is what
30//!   v0.7.0 actually uses on the recall hot-path.
31//! - [`GpuBackend`] — stub returning `not implemented`. Lands as a
32//!   trait-conformant placeholder so the v0.8 work (issue #651 Phase 1
33//!   — mistralrs or candle in-process GPU backend) can drop in without
34//!   any caller-side refactor.
35//!
36//! ## Attested weights (issue #654)
37//!
//! `attested_weights()` returns the loaded model's SHA-256 plus an
//! optional Ed25519 signature over that SHA-256 hex string. The CPU
//! backend reports a record the caller computed with
//! [`compute_attested_weights`] when the model file was loaded (and
//! pinned via [`CpuBackend::with_attested_weights`]); the GPU backend
//! stub returns `None`.
42//! Documentation for the full v0.8 attested weight chain lives at
43//! `docs/v0.7.0/inference-attestation.md`.
44//!
45//! ## Regression test
46//!
47//! `cpu_backend_round_trips_embed` (in this module) and
48//! `gpu_backend_returns_not_implemented` pin the contract.
49
50use anyhow::{Result, anyhow};
51use std::sync::Arc;
52
/// Provenance record for a model's weights, as surfaced by
/// [`InferenceBackend::attested_weights`].
///
/// This is the issue #654 MVP supply-chain attestation shape. It holds the
/// SHA-256 digest of the weight file on disk, an optional operator Ed25519
/// signature, and a human-readable label.
///
/// A fuller Sigstore-style chain (cosign bundle, transparency-log entry,
/// key-rotation reference) is planned for v0.8. The current fields are
/// already enough to reject a tampered weight file at load time.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AttestedWeights {
    /// SHA-256 digest of the model weight bytes, hex-encoded.
    pub sha256: String,
    /// Base64-encoded Ed25519 signature computed over the `sha256` string.
    /// `None` when the weights have not been signed.
    pub signature: Option<String>,
    /// Human-readable model identifier for operators, such as
    /// `"all-MiniLM-L6-v2"` or `"distilled-hot-path-v0.8"`.
    pub label: String,
}
73
74/// The unified inference surface. v0.8 callers will hold an
75/// `Arc<dyn InferenceBackend>` instead of separate embedder + llm
76/// handles. At v0.7.0 the recall hot-path still uses the legacy
77/// types directly (no callsite churn during the v0.7.0 ship window);
78/// the trait is the seam through which the v0.8 GPU/MTP backend will
79/// be threaded.
80pub trait InferenceBackend: Send + Sync {
81    /// Produce a single embedding vector for `text`.
82    ///
83    /// # Errors
84    ///
85    /// Implementor-specific (model load failure, tokenisation error,
86    /// device OOM, etc.). The GPU stub backend returns a
87    /// `not implemented` error.
88    fn embed(&self, text: &str) -> Result<Vec<f32>>;
89
90    /// Generate a chat completion for `prompt`. Default system prompt
91    /// is `None` (implementor decides); use a concrete backend's API
92    /// for system-prompt support.
93    ///
94    /// # Errors
95    ///
96    /// Implementor-specific (transport error, model unavailable,
97    /// safety refusal, etc.).
98    fn chat(&self, prompt: &str) -> Result<String>;
99
100    /// Return the loaded model's SHA-256 + optional signature for
101    /// issue #654 supply-chain attestation. `None` if the backend
102    /// has no on-disk weights to attest (e.g. a network-only client).
103    fn attested_weights(&self) -> Option<AttestedWeights> {
104        None
105    }
106}
107
108/// CPU backend — wraps the existing v0.7.0 inference path
109/// (`embeddings::Embedder` + `llm::OllamaClient`). This is a thin
110/// adapter; the underlying types are unchanged.
111pub struct CpuBackend {
112    embedder: Arc<dyn crate::embeddings::Embed>,
113    llm: Option<Arc<crate::llm::OllamaClient>>,
114    /// Optional pre-computed attested-weights record. Construct via
115    /// [`CpuBackend::with_attested_weights`] when the operator has
116    /// pinned the model file's SHA-256.
117    attested: Option<AttestedWeights>,
118}
119
120impl CpuBackend {
121    /// Construct a CPU backend from existing handles.
122    #[must_use]
123    pub fn new(
124        embedder: Arc<dyn crate::embeddings::Embed>,
125        llm: Option<Arc<crate::llm::OllamaClient>>,
126    ) -> Self {
127        Self {
128            embedder,
129            llm,
130            attested: None,
131        }
132    }
133
134    /// Pin an attested-weights record (issue #654). Returns a new
135    /// backend wrapping the same handles. The hash is NOT recomputed
136    /// here — the caller pre-computes it via
137    /// [`compute_attested_weights`] at model-load time so the
138    /// `verify_attested_weights` gate can refuse to serve from a
139    /// tampered file.
140    #[must_use]
141    pub fn with_attested_weights(mut self, attested: AttestedWeights) -> Self {
142        self.attested = Some(attested);
143        self
144    }
145}
146
147impl InferenceBackend for CpuBackend {
148    fn embed(&self, text: &str) -> Result<Vec<f32>> {
149        self.embedder.embed(text)
150    }
151
152    fn chat(&self, prompt: &str) -> Result<String> {
153        let llm = self
154            .llm
155            .as_ref()
156            .ok_or_else(|| anyhow!("CpuBackend: chat unavailable (no OllamaClient configured)"))?;
157        llm.generate(prompt, None)
158    }
159
160    fn attested_weights(&self) -> Option<AttestedWeights> {
161        self.attested.clone()
162    }
163}
164
/// GPU backend stub and issue #651 Phase 1 placeholder.
///
/// Every inference call on this type returns a `not implemented` error. It
/// exists as a trait-conformant type so the v0.8 GPU/MTP backend (mistralrs
/// or candle, in-process) can replace it without any caller-side refactor.
///
/// `Debug` and `Clone` are derived so the stub can be logged and duplicated
/// like any other public configuration value.
#[derive(Debug, Clone, Default)]
pub struct GpuBackend {
    /// Operator-readable label (e.g. `"distilled-hot-path-v0.8"`).
    /// Kept even on the stub so attestation plumbing can be exercised
    /// end-to-end during the v0.8 work.
    pub label: String,
}
176
177impl GpuBackend {
178    /// Construct a GPU backend stub with the given operator-readable
179    /// label.
180    #[must_use]
181    pub fn new(label: impl Into<String>) -> Self {
182        Self {
183            label: label.into(),
184        }
185    }
186}
187
188impl InferenceBackend for GpuBackend {
189    fn embed(&self, _text: &str) -> Result<Vec<f32>> {
190        Err(anyhow!(
191            "GpuBackend::embed not implemented (v0.8 work — issue #651 Phase 1; \
192             see docs/v0.7.0/inference-attestation.md for the rollout plan)"
193        ))
194    }
195
196    fn chat(&self, _prompt: &str) -> Result<String> {
197        Err(anyhow!(
198            "GpuBackend::chat not implemented (v0.8 work — issue #651 Phase 1)"
199        ))
200    }
201}
202
203/// Compute the SHA-256 of a model-weight file on disk and assemble an
204/// [`AttestedWeights`] record. Issue #654 MVP supply-chain attestation.
205///
206/// # Errors
207///
208/// Returns an error if the file cannot be read.
209pub fn compute_attested_weights(
210    path: &std::path::Path,
211    label: impl Into<String>,
212    signature: Option<String>,
213) -> Result<AttestedWeights> {
214    use sha2::{Digest, Sha256};
215    let bytes = std::fs::read(path)
216        .map_err(|e| anyhow!("compute_attested_weights: read {}: {e}", path.display()))?;
217    let mut hasher = Sha256::new();
218    hasher.update(&bytes);
219    let digest = hasher.finalize();
220    Ok(AttestedWeights {
221        sha256: hex::encode(digest),
222        signature,
223        label: label.into(),
224    })
225}
226
227/// Verify an in-flight [`AttestedWeights`] record against the file at
228/// `path`. Issue #654 MVP gate — call before binding the backend if
229/// the operator has pinned a known-good hash.
230///
231/// Two checks run, both fail-CLOSED:
232///
233/// 1. **Hash** — the recomputed SHA-256 of the on-disk file MUST equal
234///    `expected.sha256`.
235/// 2. **Signature** — when `expected.signature` is `Some`, the Ed25519
236///    signature MUST verify against the operator's resolved public key
237///    ([`crate::governance::rules_store::resolve_operator_pubkey`]) over
238///    the recomputed SHA-256 hex string's bytes. A signature that is
239///    present but cannot be verified — malformed base64, wrong length,
240///    bad signature, OR no operator key resolvable — is a hard refusal.
241///    Pre-fix the signature field was stored but NEVER checked, so a
242///    record carrying a forged or stale signature passed the gate on the
243///    hash alone (silent unverified-signature gap, issue #654).
244///
245/// # Errors
246///
247/// Returns an error if the file cannot be read, the recomputed hash does
248/// not match `expected.sha256`, or a present signature fails to verify.
249pub fn verify_attested_weights(path: &std::path::Path, expected: &AttestedWeights) -> Result<()> {
250    let operator_pubkey = crate::governance::rules_store::resolve_operator_pubkey();
251    verify_attested_weights_with_key(path, expected, operator_pubkey.as_ref())
252}
253
254/// Key-injecting core of [`verify_attested_weights`]. Production callers
255/// use the wrapper (which resolves the operator key from disk/env);
256/// tests pass an explicit `operator_pubkey` so the signature gate can be
257/// exercised hermetically without touching the operator key directory.
258///
259/// # Errors
260///
261/// See [`verify_attested_weights`].
262pub fn verify_attested_weights_with_key(
263    path: &std::path::Path,
264    expected: &AttestedWeights,
265    operator_pubkey: Option<&ed25519_dalek::VerifyingKey>,
266) -> Result<()> {
267    let recomputed = compute_attested_weights(path, &expected.label, None)?;
268    if recomputed.sha256 != expected.sha256 {
269        return Err(anyhow!(
270            "verify_attested_weights: hash mismatch for {} (expected {}, got {}) — \
271             refusing to serve from a tampered weight file (issue #654)",
272            path.display(),
273            expected.sha256,
274            recomputed.sha256,
275        ));
276    }
277
278    // Signature gate (issue #654). The canonical signed message is the
279    // SHA-256 hex string's ASCII bytes — the same value the operator
280    // signs when pinning the weights. A present-but-unverifiable
281    // signature fails CLOSED.
282    if let Some(sig_b64) = expected.signature.as_deref() {
283        let Some(verifying_key) = operator_pubkey else {
284            return Err(anyhow!(
285                "verify_attested_weights: record for {} carries a signature but no operator \
286                 public key could be resolved — refusing to serve (fail-CLOSED, issue #654)",
287                path.display(),
288            ));
289        };
290        verify_attested_weights_signature(&recomputed.sha256, sig_b64, verifying_key).map_err(
291            |e| {
292                anyhow!(
293                    "verify_attested_weights: signature verification failed for {} ({e}) — \
294                     refusing to serve (issue #654)",
295                    path.display(),
296                )
297            },
298        )?;
299    }
300
301    Ok(())
302}
303
304/// Verify a base64-encoded Ed25519 `signature` over the `sha256` hex
305/// string's bytes against `verifying_key`. Accepts both URL-safe-no-pad
306/// and standard base64 (mirrors
307/// [`crate::governance::rules_store::resolve_operator_pubkey`]).
308fn verify_attested_weights_signature(
309    sha256: &str,
310    signature: &str,
311    verifying_key: &ed25519_dalek::VerifyingKey,
312) -> Result<(), ed25519_dalek::SignatureError> {
313    use base64::Engine;
314    use ed25519_dalek::{Signature, Verifier};
315
316    let trimmed = signature.trim();
317    let sig_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
318        .decode(trimmed)
319        .or_else(|_| base64::engine::general_purpose::STANDARD.decode(trimmed))
320        .map_err(|_| ed25519_dalek::SignatureError::new())?;
321    if sig_bytes.len() != ed25519_dalek::SIGNATURE_LENGTH {
322        return Err(ed25519_dalek::SignatureError::new());
323    }
324    let mut sig_arr = [0u8; ed25519_dalek::SIGNATURE_LENGTH];
325    sig_arr.copy_from_slice(&sig_bytes);
326    let sig = Signature::from_bytes(&sig_arr);
327    verifying_key.verify(sha256.as_bytes(), &sig)
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Deterministic embedder stub. It returns a 4-element vector whose
    /// every entry is the byte length of the input text.
    struct MockEmbedder;

    impl crate::embeddings::Embed for MockEmbedder {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let len = text.len() as f32;
            Ok(vec![len; 4])
        }
    }

    /// Directory for on-disk fixtures. It lives under the crate's
    /// `.local-runs/` rather than /tmp, to honour the no-/tmp HARD RULE in
    /// CLAUDE.md.
    fn fixture_dir() -> std::path::PathBuf {
        let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".local-runs");
        std::fs::create_dir_all(&dir).expect("mkdir .local-runs");
        dir
    }

    /// Create `<fixture_dir>/<prefix>-<uuid>.bin` holding `content`,
    /// fsynced so a subsequent hash sees the full bytes.
    fn write_fixture(prefix: &str, content: &[u8]) -> std::path::PathBuf {
        let path = fixture_dir().join(format!("{prefix}-{}.bin", uuid::Uuid::new_v4()));
        let mut file = std::fs::File::create(&path).expect("create fixture");
        file.write_all(content).expect("write fixture");
        file.sync_all().expect("sync fixture");
        path
    }

    /// Standard-base64 Ed25519 signature of `message` by `signing_key`.
    fn sign_b64(signing_key: &ed25519_dalek::SigningKey, message: &[u8]) -> String {
        use base64::Engine;
        use ed25519_dalek::Signer;
        let raw = signing_key.sign(message).to_bytes();
        base64::engine::general_purpose::STANDARD.encode(raw)
    }

    /// Write a signature-test fixture. Returns its path and an attestation
    /// record whose signature was made by `signer` over the sha256 hex
    /// string's bytes (the canonical signed message).
    fn signed_fixture(
        signer: &ed25519_dalek::SigningKey,
        content: &[u8],
    ) -> (std::path::PathBuf, AttestedWeights) {
        let path = write_fixture("inference-attest-sig", content);
        let unsigned = compute_attested_weights(&path, "fixture", None).expect("compute ok");
        let signature = sign_b64(signer, unsigned.sha256.as_bytes());
        let attested = AttestedWeights {
            signature: Some(signature),
            ..unsigned
        };
        (path, attested)
    }

    #[test]
    fn cpu_backend_round_trips_embed() {
        let backend: Arc<dyn InferenceBackend> =
            Arc::new(CpuBackend::new(Arc::new(MockEmbedder), None));
        let vector = backend.embed("hello").expect("embed ok");
        assert_eq!(vector, vec![5.0_f32; 4]);
    }

    #[test]
    fn cpu_backend_chat_without_llm_errors() {
        let backend = CpuBackend::new(Arc::new(MockEmbedder), None);
        let err = backend.chat("anything").expect_err("must err");
        assert!(err.to_string().contains("chat unavailable"));
    }

    #[test]
    fn gpu_backend_returns_not_implemented() {
        let backend: Arc<dyn InferenceBackend> = Arc::new(GpuBackend::new("test-gpu"));
        let embed_err = backend.embed("x").expect_err("gpu embed must err");
        assert!(embed_err.to_string().contains("not implemented"));
        let chat_err = backend.chat("x").expect_err("gpu chat must err");
        assert!(chat_err.to_string().contains("not implemented"));
        assert!(backend.attested_weights().is_none());
    }

    #[test]
    fn compute_and_verify_attested_weights_round_trip() {
        let path = write_fixture(
            "inference-attest-fixture",
            b"a tiny attested model weight blob",
        );

        let attested =
            compute_attested_weights(&path, "fixture", None).expect("compute_attested_weights ok");
        assert_eq!(attested.sha256.len(), 64, "sha256 hex must be 64 chars");
        verify_attested_weights(&path, &attested).expect("verify ok");

        // Append to the file so the pinned hash no longer matches.
        {
            let mut file = std::fs::OpenOptions::new()
                .append(true)
                .open(&path)
                .expect("open append");
            file.write_all(b"--tampered--").expect("tamper write");
            file.sync_all().expect("sync tamper");
        }
        let err = verify_attested_weights(&path, &attested)
            .expect_err("verify must refuse tampered file");
        assert!(err.to_string().contains("hash mismatch"));

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn cpu_backend_with_attested_weights_round_trip() {
        let attested = AttestedWeights {
            sha256: "0".repeat(64),
            signature: None,
            label: "test".into(),
        };
        let backend = CpuBackend::new(Arc::new(MockEmbedder), None)
            .with_attested_weights(attested.clone());
        assert_eq!(backend.attested_weights(), Some(attested));
    }

    // ---- issue #654 — Ed25519 signature gate on attested weights ----
    //
    // Before the fix, `verify_attested_weights` stored the `signature` field
    // but never checked it. A record could pass on the hash alone while
    // carrying a forged, stale, or key-less signature. These tests pin the
    // fail-CLOSED signature semantics through the key-injecting core
    // (`verify_attested_weights_with_key`), so they run hermetically and do
    // not depend on the operator key directory.

    #[test]
    fn verify_attested_weights_accepts_valid_operator_signature() {
        let mut csprng = rand_core::OsRng;
        let signing_key = ed25519_dalek::SigningKey::generate(&mut csprng);
        let verifying_key = signing_key.verifying_key();

        let (path, attested) = signed_fixture(&signing_key, b"signed weight blob");
        verify_attested_weights_with_key(&path, &attested, Some(&verifying_key))
            .expect("valid signature must verify");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn verify_attested_weights_rejects_forged_signature() {
        let mut csprng = rand_core::OsRng;
        let operator_key = ed25519_dalek::SigningKey::generate(&mut csprng);
        let attacker_key = ed25519_dalek::SigningKey::generate(&mut csprng);

        // Signed by the attacker but checked against the operator key.
        let (path, attested) = signed_fixture(&attacker_key, b"forged weight blob");
        let operator_pub = operator_key.verifying_key();
        let err = verify_attested_weights_with_key(&path, &attested, Some(&operator_pub))
            .expect_err("forged signature must be refused");
        assert!(err.to_string().contains("signature verification failed"));

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn verify_attested_weights_fails_closed_when_signed_but_no_key() {
        let mut csprng = rand_core::OsRng;
        let signing_key = ed25519_dalek::SigningKey::generate(&mut csprng);

        // A signature is present but no operator key can be resolved.
        let (path, attested) = signed_fixture(&signing_key, b"orphan-signature weight blob");
        let err = verify_attested_weights_with_key(&path, &attested, None)
            .expect_err("present signature with no key must fail closed");
        assert!(err.to_string().contains("no operator public key"));

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn verify_attested_weights_unsigned_record_skips_signature_gate() {
        // With no signature only the hash gate runs, so no key is needed.
        let path = write_fixture("inference-attest-sig", b"unsigned weight blob");
        let attested = compute_attested_weights(&path, "fixture", None).expect("compute ok");
        assert!(attested.signature.is_none());
        verify_attested_weights_with_key(&path, &attested, None)
            .expect("unsigned record must verify on hash alone");
        let _ = std::fs::remove_file(&path);
    }
}