Skip to main content

smos_application/testkit/
providers.rs

//! Provider test doubles: scripted LLM extractor, constant/recording
//! embedders, and the dual-mode scripted NLI classifier.
3
4use std::sync::{Arc, Mutex};
5
6use smos_domain::NliResult;
7use smos_domain::chat::ToolCall;
8
9use crate::errors::ProviderError;
10use crate::ports::{EmbeddingProvider, LlmExtractor, NliClassifier};
11
12/// LLM extractor that pops pre-scripted results in FIFO order and counts
13/// invocations. When the script is exhausted, subsequent calls return an empty
14/// `Vec` (mirroring a provider that simply finds no facts) rather than
15/// erroring, so tests that do not care about the Nth call still pass.
16pub struct ScriptedExtractor {
17    results: Mutex<Vec<Result<Vec<String>, ProviderError>>>,
18    calls: Mutex<u32>,
19}
20
21impl ScriptedExtractor {
22    pub fn new(results: Vec<Result<Vec<String>, ProviderError>>) -> Self {
23        Self {
24            results: Mutex::new(results),
25            calls: Mutex::new(0),
26        }
27    }
28
29    pub fn call_count(&self) -> u32 {
30        *self.calls.lock().unwrap()
31    }
32}
33
34impl LlmExtractor for ScriptedExtractor {
35    async fn extract_facts(
36        &self,
37        _content: &str,
38        _tool_calls: &[ToolCall],
39    ) -> Result<Vec<String>, ProviderError> {
40        *self.calls.lock().unwrap() += 1;
41        let mut guard = self.results.lock().unwrap();
42        if guard.is_empty() {
43            Ok(Vec::new())
44        } else {
45            guard.remove(0)
46        }
47    }
48}
49
/// Embedding provider that always returns the same vector regardless of input.
///
/// The wrapped `Vec<f32>` is the vector handed back on every `embed` call.
/// `Debug`, `Clone` and `PartialEq` are derived so tests can print, copy and
/// compare the double like any other value.
#[derive(Debug, Clone, PartialEq)]
pub struct ConstantEmbedder(pub Vec<f32>);
52
53impl EmbeddingProvider for ConstantEmbedder {
54    async fn embed(&self, _text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
55        Ok(Some(self.0.clone()))
56    }
57}
58
/// Embedding provider that records every `embed` call and returns a
/// deterministic, content-derived one-hot vector for the input text.
///
/// Used to verify the extraction pipeline hands distinct embeddings to
/// distinct facts (so Layer 2 dedup makes the right call). `new` returns the
/// double together with the shared call-log handle so the test body can
/// assert on it.
///
/// Vectors are NOT guaranteed unique per input: the text is hashed into one
/// of 1024 buckets, so different strings can collide (for example `"Aa"` and
/// `"BB"` produce the same vector). Tests relying on distinct embeddings must
/// use inputs that land in different buckets.
#[derive(Debug)]
pub struct RecordingEmbedder {
    /// Every text passed to `embed`, in call order; shared with the handle
    /// returned by `new`.
    calls: Arc<Mutex<Vec<String>>>,
}

impl RecordingEmbedder {
    /// Length of every vector produced by `vector_for`.
    const DIM: usize = 1024;

    /// Creates the double together with a second handle to its call log.
    pub fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
        let calls = Arc::new(Mutex::new(Vec::new()));
        (
            Self {
                calls: Arc::clone(&calls),
            },
            calls,
        )
    }

    /// Returns a `DIM`-length vector that is all zeros except for a single
    /// `1.0` at an index derived from `text`.
    ///
    /// The index is a base-31 polynomial hash of the UTF-8 bytes (wrapping
    /// `u64` arithmetic), reduced modulo `DIM`. Equal inputs always give equal
    /// vectors. Inputs hashing to different buckets give orthogonal vectors
    /// (cosine similarity 0), but colliding inputs give identical vectors.
    fn vector_for(text: &str) -> Vec<f32> {
        let hash = text
            .bytes()
            .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(u64::from(b)));
        // Reduce before narrowing: the remainder is < DIM, so the cast to
        // `usize` can never truncate (and matches `hash as usize % DIM`).
        let index = (hash % Self::DIM as u64) as usize;
        let mut vec = vec![0.0; Self::DIM];
        vec[index] = 1.0;
        vec
    }
}
92
93impl EmbeddingProvider for RecordingEmbedder {
94    async fn embed(&self, text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
95        self.calls.lock().unwrap().push(text.to_string());
96        Ok(Some(Self::vector_for(text)))
97    }
98}
99
100/// Closure type used by the matcher variant of [`ScriptedNliClassifier`].
101type NliResolver = Box<dyn Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync>;
102
103/// Scripted NLI classifier with two modes:
104/// - [`ScriptedNliClassifier::new`] (FIFO): each call pops the next verdict
105///   from the queue. Use when the test controls call order.
106/// - [`ScriptedNliClassifier::matching`] (Match): each call dispatches to the
107///   supplied closure. Use when pending iteration order is not deterministic
108///   (`HashMap` order) and the test keys verdicts on the candidate text.
109///
110/// Both modes record every (premise, hypothesis) pair so tests can assert on
111/// the exact set of pairs the use case asked about.
112pub enum ScriptedNliClassifier {
113    Fifo {
114        verdicts: Mutex<Vec<Result<NliResult, ProviderError>>>,
115        calls: Mutex<Vec<(String, String)>>,
116    },
117    Match {
118        resolver: NliResolver,
119        calls: Mutex<Vec<(String, String)>>,
120    },
121}
122
123impl ScriptedNliClassifier {
124    pub fn new(verdicts: Vec<Result<NliResult, ProviderError>>) -> Self {
125        Self::Fifo {
126            verdicts: Mutex::new(verdicts),
127            calls: Mutex::new(Vec::new()),
128        }
129    }
130
131    pub fn matching<F>(resolver: F) -> Self
132    where
133        F: Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync + 'static,
134    {
135        Self::Match {
136            resolver: Box::new(resolver),
137            calls: Mutex::new(Vec::new()),
138        }
139    }
140
141    pub fn calls(&self) -> Vec<(String, String)> {
142        match self {
143            Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
144        }
145    }
146}
147
148impl NliClassifier for ScriptedNliClassifier {
149    async fn classify(&self, premise: &str, hypothesis: &str) -> Result<NliResult, ProviderError> {
150        match self {
151            Self::Fifo { verdicts, calls } => {
152                calls
153                    .lock()
154                    .unwrap()
155                    .push((premise.to_string(), hypothesis.to_string()));
156                let mut queue = verdicts.lock().unwrap();
157                if queue.is_empty() {
158                    Err(ProviderError::Unavailable("scripted queue empty".into()))
159                } else {
160                    queue.remove(0)
161                }
162            }
163            Self::Match { resolver, calls } => {
164                calls
165                    .lock()
166                    .unwrap()
167                    .push((premise.to_string(), hypothesis.to_string()));
168                resolver(premise, hypothesis)
169            }
170        }
171    }
172}