Skip to main content

inferd_engine/
mock.rs

1//! Deterministic mock backend used by tests and by the daemon's M1 echo
2//! milestone.
3//!
//! Configurable knobs cover the failure modes adapters must support:
//! - `Mock::set_ready` toggles `Backend::ready()` for testing the
//!   listener-gate invariant (`THREAT_MODEL.md` F-13).
//! - `MockConfig::pre_stream_error` causes `generate()` / `generate_v2()`
//!   to return a `GenerateError` (and `embed()` an `EmbedError`) before
//!   yielding any output.
//! - `MockConfig::mid_stream_drop_after` truncates the stream after N
//!   tokens (no `Done` event) to exercise the mid-stream failure path.
//! - `MockConfig::token_delay_ms` inserts a sleep before each emitted
//!   token so concurrency tests can observe per-request work.
11
12use crate::backend::{
13    Backend, BackendCapabilities, EmbedError, EmbedResult, GenerateError, TokenEvent, TokenEventV2,
14    TokenStream, TokenStreamV2,
15};
16use async_trait::async_trait;
17use inferd_proto::embed::{EmbedResolved, EmbedUsage};
18use inferd_proto::v2::{ResolvedV2, StopReasonV2, UsageV2};
19use inferd_proto::{Resolved, StopReason, Usage};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicBool, Ordering};
22use tokio_stream::wrappers::ReceiverStream;
23
24/// Configuration for `Mock` failure-mode injection.
25#[derive(Debug, Clone, Default)]
26pub struct MockConfig {
27    /// If `Some`, `generate()` returns this error immediately. Defaults to
28    /// `None` (success).
29    pub pre_stream_error: Option<MockError>,
30    /// If `Some(N)`, the stream yields N tokens then ends without a `Done`
31    /// event, simulating a mid-stream backend failure.
32    pub mid_stream_drop_after: Option<usize>,
33    /// Tokens to emit (if `mid_stream_drop_after` is `None` they all stream
34    /// followed by a `Done`). Default: a single canned response so callers
35    /// without a config still get something useful.
36    pub tokens: Vec<String>,
37    /// Optional sleep between emitted tokens, in milliseconds. Used by
38    /// the concurrency stress harness to make per-request work
39    /// observable so admission queueing actually engages. `None` means
40    /// no delay (the historical behaviour).
41    pub token_delay_ms: Option<u64>,
42}
43
/// Failure modes selectable through `MockConfig::pre_stream_error`.
///
/// Each variant is translated into the like-named `GenerateError` /
/// `EmbedError` variant; variants that carry a detail message use `"mock"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MockError {
    /// Simulate a backend that is not ready to serve.
    NotReady,
    /// Simulate a backend rejecting the request as invalid.
    InvalidRequest,
    /// Simulate a backend that is unavailable.
    Unavailable,
}
54
55impl From<MockError> for GenerateError {
56    fn from(e: MockError) -> Self {
57        match e {
58            MockError::NotReady => GenerateError::NotReady,
59            MockError::InvalidRequest => GenerateError::InvalidRequest("mock".into()),
60            MockError::Unavailable => GenerateError::Unavailable("mock".into()),
61        }
62    }
63}
64
65/// Deterministic test backend.
66pub struct Mock {
67    name: &'static str,
68    ready: Arc<AtomicBool>,
69    config: MockConfig,
70}
71
72impl Mock {
73    /// Build a `Mock` that reports ready immediately and emits a single canned
74    /// token followed by `Done`.
75    pub fn new() -> Self {
76        Self::with_config(MockConfig {
77            tokens: vec!["mock-response".into()],
78            ..Default::default()
79        })
80    }
81
82    /// Build a `Mock` with custom failure-mode configuration.
83    pub fn with_config(config: MockConfig) -> Self {
84        Self {
85            name: "mock",
86            ready: Arc::new(AtomicBool::new(true)),
87            config,
88        }
89    }
90
91    /// Toggle the backend's reported readiness. Used by tests of the
92    /// listener-gate invariant.
93    pub fn set_ready(&self, ready: bool) {
94        self.ready.store(ready, Ordering::SeqCst);
95    }
96}
97
98impl Default for Mock {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104#[async_trait]
105impl Backend for Mock {
106    fn name(&self) -> &str {
107        self.name
108    }
109
110    fn ready(&self) -> bool {
111        self.ready.load(Ordering::SeqCst)
112    }
113
114    /// Mock advertises v2 + thinking + embed so daemon-side dispatch
115    /// across all three sockets can be exercised end-to-end without a
116    /// real engine. Multimodal / tool flags stay `false` — Mock
117    /// doesn't pretend to ingest images or parse tool calls.
118    fn capabilities(&self) -> BackendCapabilities {
119        BackendCapabilities {
120            v2: true,
121            thinking: true,
122            embed: true,
123            ..BackendCapabilities::default()
124        }
125    }
126
127    async fn generate(&self, _req: Resolved) -> Result<TokenStream, GenerateError> {
128        if let Some(err) = self.config.pre_stream_error {
129            return Err(err.into());
130        }
131        if !self.ready() {
132            return Err(GenerateError::NotReady);
133        }
134
135        let tokens = self.config.tokens.clone();
136        let drop_after = self.config.mid_stream_drop_after;
137        let token_delay = self
138            .config
139            .token_delay_ms
140            .map(std::time::Duration::from_millis);
141        let (tx, rx) = tokio::sync::mpsc::channel(8);
142
143        // Spawned so dropping the stream (which drops `rx`) cancels by
144        // closing the channel — `tx.send` then returns Err and we exit.
145        tokio::spawn(async move {
146            let mut completion_tokens: u32 = 0;
147            for (emitted, tok) in tokens.into_iter().enumerate() {
148                if let Some(n) = drop_after
149                    && emitted >= n
150                {
151                    // Simulate mid-stream failure: stop without Done.
152                    return;
153                }
154                if let Some(d) = token_delay {
155                    tokio::time::sleep(d).await;
156                }
157                if tx.send(TokenEvent::Token(tok)).await.is_err() {
158                    return; // receiver dropped → cancellation
159                }
160                completion_tokens = completion_tokens.saturating_add(1);
161            }
162            let _ = tx
163                .send(TokenEvent::Done {
164                    stop_reason: StopReason::End,
165                    usage: Usage {
166                        prompt_tokens: 0,
167                        completion_tokens,
168                    },
169                })
170                .await;
171        });
172
173        Ok(Box::pin(ReceiverStream::new(rx)))
174    }
175
176    /// v2 generation. Same token tape + delays as `generate` but
177    /// emits `TokenEventV2::Text(...)` and a v2 `Done` frame with
178    /// `StopReasonV2::EndTurn` and `UsageV2` field names. Mid-stream
179    /// drop and pre-stream error knobs apply identically.
180    async fn generate_v2(&self, _req: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
181        if let Some(err) = self.config.pre_stream_error {
182            return Err(err.into());
183        }
184        if !self.ready() {
185            return Err(GenerateError::NotReady);
186        }
187
188        let tokens = self.config.tokens.clone();
189        let drop_after = self.config.mid_stream_drop_after;
190        let token_delay = self
191            .config
192            .token_delay_ms
193            .map(std::time::Duration::from_millis);
194        let (tx, rx) = tokio::sync::mpsc::channel(8);
195
196        tokio::spawn(async move {
197            let mut output_tokens: u32 = 0;
198            for (emitted, tok) in tokens.into_iter().enumerate() {
199                if let Some(n) = drop_after
200                    && emitted >= n
201                {
202                    return;
203                }
204                if let Some(d) = token_delay {
205                    tokio::time::sleep(d).await;
206                }
207                if tx.send(TokenEventV2::Text(tok)).await.is_err() {
208                    return;
209                }
210                output_tokens = output_tokens.saturating_add(1);
211            }
212            let _ = tx
213                .send(TokenEventV2::Done {
214                    stop_reason: StopReasonV2::EndTurn,
215                    usage: UsageV2 {
216                        input_tokens: 0,
217                        output_tokens,
218                    },
219                })
220                .await;
221        });
222
223        Ok(Box::pin(ReceiverStream::new(rx)))
224    }
225
226    /// Deterministic mock embedding. Emits one fixed-length vector per
227    /// input string; entries are derived from the input length so
228    /// tests can assert correlation. `dimensions` defaults to 8 when
229    /// the request doesn't supply one; otherwise the request's value
230    /// is honoured (no model-specific MRL set to validate against).
231    /// Pre-stream-error knob (`MockError::Unavailable` /
232    /// `InvalidRequest`) is reused on the embed path to exercise
233    /// daemon-side error mapping; `NotReady` mode short-circuits to
234    /// `EmbedError::NotReady` to mirror the v1/v2 paths.
235    async fn embed(&self, req: EmbedResolved) -> Result<EmbedResult, EmbedError> {
236        if let Some(err) = self.config.pre_stream_error {
237            return Err(match err {
238                MockError::NotReady => EmbedError::NotReady,
239                MockError::InvalidRequest => EmbedError::InvalidRequest("mock".into()),
240                MockError::Unavailable => EmbedError::Unavailable("mock".into()),
241            });
242        }
243        if !self.ready() {
244            return Err(EmbedError::NotReady);
245        }
246
247        let dimensions = req.dimensions.unwrap_or(8);
248        let mut input_tokens: u32 = 0;
249        let embeddings = req
250            .input
251            .iter()
252            .map(|s| {
253                input_tokens = input_tokens.saturating_add(s.len() as u32);
254                let len_f = s.len() as f32;
255                (0..dimensions)
256                    .map(|i| (i as f32 + 1.0) / (len_f + 1.0))
257                    .collect()
258            })
259            .collect();
260
261        Ok(EmbedResult {
262            embeddings,
263            dimensions,
264            model: "mock".into(),
265            usage: EmbedUsage { input_tokens },
266        })
267    }
268}