Skip to main content

batuta/oracle/rag/
validator.rs

1//! Jidoka Index Validator - Stop-on-Error Guarantees
2//!
3//! Implements Toyota Way Jidoka (自働化) principle for automatic defect detection.
4//! Addresses common failure points in RAG engineering per Barnett et al. (2024).
5
6use super::types::JidokaHalt;
7use std::collections::HashMap;
8
9/// Jidoka index validator for stop-on-error guarantees
10///
11/// Validates:
12/// - Embedding dimensions match model
13/// - No NaN/Inf in embeddings (Poka-Yoke)
14/// - Document hashes match content (integrity)
15#[derive(Debug)]
16pub struct JidokaIndexValidator {
17    /// Expected embedding dimensions
18    expected_dims: usize,
19    /// Model hash for verification
20    model_hash: Option<[u8; 32]>,
21    /// Validation statistics
22    stats: ValidationStats,
23}
24
/// Running counters for checks performed by a [`JidokaIndexValidator`].
///
/// Every check increments `total_validations` and exactly one of `successful`
/// or `failed`. Each failure is also recorded as a halt, so `failed` and
/// `halts` currently advance together.
///
/// `PartialEq`/`Eq` are derived so a whole snapshot can be compared at once
/// (e.g. against `ValidationStats::default()` after a reset).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ValidationStats {
    /// Total validations performed
    pub total_validations: u64,
    /// Validations that passed
    pub successful: u64,
    /// Validations that failed
    pub failed: u64,
    /// Jidoka halts triggered (one per failed validation)
    pub halts: u64,
}
37
38impl JidokaIndexValidator {
39    /// Create a new validator with expected embedding dimensions
40    pub fn new(expected_dims: usize) -> Self {
41        Self { expected_dims, model_hash: None, stats: ValidationStats::default() }
42    }
43
44    /// Set expected model hash
45    pub fn with_model_hash(mut self, hash: [u8; 32]) -> Self {
46        self.model_hash = Some(hash);
47        self
48    }
49
50    /// Run a validation check, updating stats based on the result.
51    ///
52    /// Increments `total_validations` unconditionally, then `failed`+`halts`
53    /// on `Err` or `successful` on `Ok`.
54    fn run_check(
55        &mut self,
56        check: impl FnOnce() -> Result<(), JidokaHalt>,
57    ) -> Result<(), JidokaHalt> {
58        self.stats.total_validations += 1;
59        match check() {
60            Ok(()) => {
61                self.stats.successful += 1;
62                Ok(())
63            }
64            Err(e) => {
65                self.stats.failed += 1;
66                self.stats.halts += 1;
67                Err(e)
68            }
69        }
70    }
71
72    /// Validate an embedding vector
73    pub fn validate_embedding(
74        &mut self,
75        doc_id: &str,
76        embedding: &[f32],
77    ) -> Result<(), JidokaHalt> {
78        let expected_dims = self.expected_dims;
79        self.run_check(|| {
80            // Check dimensions
81            if embedding.len() != expected_dims {
82                return Err(JidokaHalt::DimensionMismatch {
83                    expected: expected_dims,
84                    actual: embedding.len(),
85                });
86            }
87
88            // Check for NaN/Inf (Poka-Yoke)
89            if embedding.iter().any(|v| v.is_nan() || v.is_infinite()) {
90                return Err(JidokaHalt::CorruptedEmbedding { doc_id: doc_id.to_string() });
91            }
92
93            Ok(())
94        })
95    }
96
97    /// Validate document content integrity
98    pub fn validate_integrity(
99        &mut self,
100        doc_id: &str,
101        content: &[u8],
102        stored_hash: [u8; 32],
103    ) -> Result<(), JidokaHalt> {
104        self.run_check(|| {
105            let computed_hash = compute_hash(content);
106            if computed_hash != stored_hash {
107                return Err(JidokaHalt::IntegrityViolation { doc_id: doc_id.to_string() });
108            }
109            Ok(())
110        })
111    }
112
113    /// Validate model hash matches expected
114    pub fn validate_model_hash(&mut self, actual_hash: [u8; 32]) -> Result<(), JidokaHalt> {
115        let model_hash = self.model_hash;
116        self.run_check(|| {
117            if let Some(expected) = model_hash {
118                if expected != actual_hash {
119                    return Err(JidokaHalt::ModelMismatch {
120                        expected: hex_encode(&expected),
121                        actual: hex_encode(&actual_hash),
122                    });
123                }
124            }
125            Ok(())
126        })
127    }
128
129    /// Validate a batch of embeddings
130    pub fn validate_batch(
131        &mut self,
132        embeddings: &HashMap<String, Vec<f32>>,
133    ) -> Result<(), JidokaHalt> {
134        for (doc_id, embedding) in embeddings {
135            self.validate_embedding(doc_id, embedding)?;
136        }
137        Ok(())
138    }
139
140    /// Get validation statistics
141    pub fn stats(&self) -> &ValidationStats {
142        &self.stats
143    }
144
145    /// Reset statistics
146    pub fn reset_stats(&mut self) {
147        self.stats = ValidationStats::default();
148    }
149
150    /// Get expected dimensions
151    pub fn expected_dims(&self) -> usize {
152        self.expected_dims
153    }
154}
155
/// How the index should degrade after a Jidoka halt.
///
/// The default is [`FallbackStrategy::LastKnownGood`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FallbackStrategy {
    /// Keep serving the most recent index that passed validation.
    #[default]
    LastKnownGood,
    /// Answer only from the in-memory cache.
    CacheOnly,
    /// Report the index as unavailable.
    Unavailable,
}
167
168/// Jidoka halt handler
169#[derive(Debug)]
170pub struct JidokaHaltHandler {
171    /// Fallback strategy
172    strategy: FallbackStrategy,
173    /// Halt history for debugging
174    halt_history: Vec<HaltRecord>,
175    /// Maximum history size
176    max_history: usize,
177}
178
179/// Record of a Jidoka halt
180#[derive(Debug, Clone)]
181pub struct HaltRecord {
182    /// Timestamp (Unix epoch ms)
183    pub timestamp_ms: u64,
184    /// Halt reason
185    pub halt: JidokaHalt,
186    /// Recovery action taken
187    pub recovery_action: String,
188}
189
190impl JidokaHaltHandler {
191    /// Create a new halt handler
192    pub fn new(strategy: FallbackStrategy) -> Self {
193        Self { strategy, halt_history: Vec::new(), max_history: 100 }
194    }
195
196    /// Handle a Jidoka halt
197    pub fn handle_halt(&mut self, halt: JidokaHalt) -> FallbackStrategy {
198        let timestamp_ms = std::time::SystemTime::now()
199            .duration_since(std::time::UNIX_EPOCH)
200            .map(|d| d.as_millis() as u64)
201            .unwrap_or(0);
202
203        let recovery_action = match self.strategy {
204            FallbackStrategy::LastKnownGood => "Rolling back to last validated index".to_string(),
205            FallbackStrategy::CacheOnly => "Serving from in-memory cache".to_string(),
206            FallbackStrategy::Unavailable => "Index marked unavailable".to_string(),
207        };
208
209        self.halt_history.push(HaltRecord { timestamp_ms, halt, recovery_action });
210
211        // Trim history
212        if self.halt_history.len() > self.max_history {
213            self.halt_history.remove(0);
214        }
215
216        self.strategy
217    }
218
219    /// Get recent halts
220    pub fn recent_halts(&self, count: usize) -> &[HaltRecord] {
221        let start = self.halt_history.len().saturating_sub(count);
222        &self.halt_history[start..]
223    }
224
225    /// Get halt count
226    pub fn halt_count(&self) -> usize {
227        self.halt_history.len()
228    }
229
230    /// Clear history
231    pub fn clear_history(&mut self) {
232        self.halt_history.clear();
233    }
234}
235
236impl Default for JidokaHaltHandler {
237    fn default() -> Self {
238        Self::new(FallbackStrategy::default())
239    }
240}
241
/// Compute a 32-byte content hash (same algorithm as fingerprint).
///
/// The input is folded with 64-bit FNV-1a. The output is four little-endian
/// words: the final FNV state with 0, 1, 2 and 3 (wrapping) added respectively.
/// FNV-1a is not cryptographic, so this should not be relied on to detect
/// deliberate tampering.
fn compute_hash(data: &[u8]) -> [u8; 32] {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    let digest = data
        .iter()
        .fold(FNV_OFFSET_BASIS, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME));

    let mut out = [0u8; 32];
    for (word_index, slot) in (0u64..).zip(out.chunks_exact_mut(8)) {
        slot.copy_from_slice(&digest.wrapping_add(word_index).to_le_bytes());
    }
    out
}
256
/// Encode bytes as a lowercase hexadecimal string (two characters per byte).
///
/// Accepts any byte slice; `&[u8; 32]` hashes coerce automatically. The output
/// is built in a single preallocated `String` rather than formatting each byte
/// into its own temporary allocation.
fn hex_encode(hash: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(hash.len() * 2);
    for &byte in hash {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard test dimension used across unit tests.
    const TEST_DIM: usize = 4;

    /// A valid embedding of `TEST_DIM` elements.
    const VALID_EMBEDDING: [f32; TEST_DIM] = [0.1, 0.2, 0.3, 0.4];

    /// Create a fresh validator with `TEST_DIM` expected dimensions.
    fn test_validator() -> JidokaIndexValidator {
        JidokaIndexValidator::new(TEST_DIM)
    }

    /// Build a `TEST_DIM`-length embedding with a single poisoned value at position 1.
    fn poisoned_embedding(bad_value: f32) -> Vec<f32> {
        let mut v = VALID_EMBEDDING.to_vec();
        v[1] = bad_value;
        v
    }

    /// Assert that validating an embedding with a single non-finite value
    /// at position 1 yields `CorruptedEmbedding` and is counted as a halt.
    fn assert_corrupted_embedding(bad_value: f32) {
        let mut validator = test_validator();
        let embedding = poisoned_embedding(bad_value);
        let result = validator.validate_embedding("doc1", &embedding);
        assert!(matches!(result, Err(JidokaHalt::CorruptedEmbedding { .. })));
        assert_eq!(validator.stats().halts, 1);
    }

    /// Build a `CorruptedEmbedding` halt for `doc_id`.
    fn corrupted(doc_id: &str) -> JidokaHalt {
        JidokaHalt::CorruptedEmbedding { doc_id: doc_id.to_string() }
    }

    /// Return the `doc_id` carried by a `CorruptedEmbedding` halt, or `None`
    /// for any other halt kind.
    fn corrupted_doc_id(halt: &JidokaHalt) -> Option<&str> {
        match halt {
            JidokaHalt::CorruptedEmbedding { doc_id } => Some(doc_id.as_str()),
            _ => None,
        }
    }

    #[test]
    fn test_validator_creation() {
        let validator = JidokaIndexValidator::new(384);
        assert_eq!(validator.expected_dims(), 384);
    }

    #[test]
    fn test_validate_correct_embedding() {
        let mut validator = test_validator();
        let result = validator.validate_embedding("doc1", &VALID_EMBEDDING);
        assert!(result.is_ok());
        assert_eq!(validator.stats().successful, 1);
    }

    #[test]
    fn test_validate_wrong_dimensions() {
        let mut validator = test_validator();
        let embedding = vec![0.1, 0.2]; // Wrong size

        let result = validator.validate_embedding("doc1", &embedding);
        assert!(matches!(result, Err(JidokaHalt::DimensionMismatch { expected: 4, actual: 2 })));
        let stats = validator.stats();
        assert_eq!(stats.total_validations, 1);
        assert_eq!(stats.successful, 0);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.halts, 1);
    }

    #[test]
    fn test_dimension_checked_before_values() {
        // A wrong-length vector that also holds NaN reports the dimension problem.
        let mut validator = test_validator();
        let result = validator.validate_embedding("doc1", &[f32::NAN]);
        assert!(matches!(result, Err(JidokaHalt::DimensionMismatch { expected: 4, actual: 1 })));
    }

    #[test]
    fn test_validate_nan_embedding() {
        assert_corrupted_embedding(f32::NAN);
    }

    #[test]
    fn test_validate_inf_embedding() {
        assert_corrupted_embedding(f32::INFINITY);
    }

    #[test]
    fn test_validate_neg_inf_embedding() {
        assert_corrupted_embedding(f32::NEG_INFINITY);
    }

    #[test]
    fn test_validate_integrity_correct() {
        let mut validator = test_validator();
        let content = b"test content";
        let hash = compute_hash(content);

        let result = validator.validate_integrity("doc1", content, hash);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_integrity_mismatch() {
        let mut validator = test_validator();
        let content = b"test content";
        let wrong_hash = [0u8; 32];

        let result = validator.validate_integrity("doc1", content, wrong_hash);
        assert!(matches!(result, Err(JidokaHalt::IntegrityViolation { .. })));
    }

    #[test]
    fn test_validate_integrity_stats() {
        let mut validator = test_validator();
        let content = b"test content";
        validator
            .validate_integrity("doc1", content, compute_hash(content))
            .expect("hash of identical content must match");
        assert!(validator.validate_integrity("doc1", b"tampered", compute_hash(content)).is_err());

        let stats = validator.stats();
        assert_eq!(stats.total_validations, 2);
        assert_eq!(stats.successful, 1);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.halts, 1);
    }

    #[test]
    fn test_compute_hash_fnv1a_vector() {
        // FNV-1a 64 of "a" is 0xaf63dc4c8601ec8c; digest word i is that state + i.
        let hash = compute_hash(b"a");
        for (i, chunk) in hash.chunks_exact(8).enumerate() {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            assert_eq!(u64::from_le_bytes(word), 0xaf63_dc4c_8601_ec8c + i as u64);
        }
    }

    #[test]
    fn test_validate_model_hash() {
        let expected_hash = [1u8; 32];
        let mut validator = test_validator().with_model_hash(expected_hash);

        // Correct hash
        let result = validator.validate_model_hash(expected_hash);
        assert!(result.is_ok());

        // Wrong hash
        let result = validator.validate_model_hash([2u8; 32]);
        assert!(matches!(result, Err(JidokaHalt::ModelMismatch { .. })));
    }

    #[test]
    fn test_validate_model_hash_unset_accepts_any() {
        let mut validator = test_validator();
        assert!(validator.validate_model_hash([7u8; 32]).is_ok());
        assert_eq!(validator.stats().successful, 1);
    }

    #[test]
    fn test_validate_model_hash_reports_hex() {
        let mut validator = test_validator().with_model_hash([0xab; 32]);
        match validator.validate_model_hash([0x01; 32]) {
            Err(JidokaHalt::ModelMismatch { expected, actual }) => {
                assert_eq!(expected, "ab".repeat(32));
                assert_eq!(actual, "01".repeat(32));
            }
            other => panic!("expected ModelMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_validate_batch() {
        let mut validator = test_validator();
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), VALID_EMBEDDING.to_vec());
        embeddings.insert("doc2".to_string(), vec![0.5, 0.6, 0.7, 0.8]);

        let result = validator.validate_batch(&embeddings);
        assert!(result.is_ok());
        assert_eq!(validator.stats().successful, 2);
    }

    #[test]
    fn test_validate_batch_with_error() {
        let mut validator = test_validator();
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), VALID_EMBEDDING.to_vec());
        embeddings.insert("doc2".to_string(), poisoned_embedding(f32::NAN));

        let err = validator.validate_batch(&embeddings).expect_err("doc2 is poisoned");
        assert_eq!(corrupted_doc_id(&err), Some("doc2"));
    }

    #[test]
    fn test_halt_handler() {
        let mut handler = JidokaHaltHandler::new(FallbackStrategy::LastKnownGood);

        let halt = JidokaHalt::CorruptedEmbedding { doc_id: "doc1".to_string() };
        let strategy = handler.handle_halt(halt);

        assert_eq!(strategy, FallbackStrategy::LastKnownGood);
        assert_eq!(handler.halt_count(), 1);
    }

    #[test]
    fn test_halt_handler_history() {
        let mut handler = JidokaHaltHandler::new(FallbackStrategy::CacheOnly);

        for i in 0..5 {
            handler.handle_halt(JidokaHalt::CorruptedEmbedding { doc_id: format!("doc{}", i) });
        }

        let recent = handler.recent_halts(3);
        assert_eq!(recent.len(), 3);
    }

    #[test]
    fn test_halt_handler_history_capped() {
        let mut handler = JidokaHaltHandler::default();
        for i in 0..105 {
            handler.handle_halt(corrupted(&format!("doc{}", i)));
        }

        assert_eq!(handler.halt_count(), 100);
        let all = handler.recent_halts(usize::MAX);
        assert_eq!(all.len(), 100);
        assert_eq!(corrupted_doc_id(&all[0].halt), Some("doc5"));
        assert_eq!(corrupted_doc_id(&all[99].halt), Some("doc104"));
    }

    #[test]
    fn test_halt_handler_recovery_actions() {
        let cases = [
            (FallbackStrategy::LastKnownGood, "Rolling back to last validated index"),
            (FallbackStrategy::CacheOnly, "Serving from in-memory cache"),
            (FallbackStrategy::Unavailable, "Index marked unavailable"),
        ];
        for (strategy, action) in cases {
            let mut handler = JidokaHaltHandler::new(strategy);
            assert_eq!(handler.handle_halt(corrupted("doc1")), strategy);
            assert_eq!(handler.recent_halts(1)[0].recovery_action, action);
        }
    }

    #[test]
    fn test_halt_handler_recent_and_clear() {
        let mut handler = JidokaHaltHandler::default();
        assert!(handler.recent_halts(3).is_empty());

        handler.handle_halt(corrupted("a"));
        handler.handle_halt(corrupted("b"));
        assert_eq!(handler.recent_halts(10).len(), 2);
        assert!(handler.recent_halts(0).is_empty());
        assert_eq!(corrupted_doc_id(&handler.recent_halts(1)[0].halt), Some("b"));

        handler.clear_history();
        assert_eq!(handler.halt_count(), 0);
    }

    #[test]
    fn test_fallback_strategy_default() {
        assert_eq!(FallbackStrategy::default(), FallbackStrategy::LastKnownGood);
    }

    #[test]
    fn test_reset_stats() {
        let mut validator = test_validator();
        validator.validate_embedding("doc1", &VALID_EMBEDDING).expect("unexpected failure");

        assert_eq!(validator.stats().successful, 1);

        validator.reset_stats();
        assert_eq!(validator.stats().successful, 0);
    }

    #[test]
    fn test_hex_encode() {
        let hash = [
            0x12, 0x34, 0xab, 0xcd, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let hex = hex_encode(&hash);
        assert!(hex.starts_with("1234abcd00ff"));
    }

    // Property-based tests for Jidoka validator
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Build a sequential embedding of the given dimension and inject a
        /// poison value at position `pos % dim`.
        fn embedding_with_poison(dim: usize, pos: usize, poison: f32) -> Vec<f32> {
            let mut v: Vec<f32> = (0..dim).map(|i| i as f32 / 100.0).collect();
            v[pos % dim] = poison;
            v
        }

        /// Assert that a poisoned embedding always fails validation.
        fn assert_poison_fails(
            dim: usize,
            pos: usize,
            poison: f32,
        ) -> Result<(), proptest::test_runner::TestCaseError> {
            let mut validator = JidokaIndexValidator::new(dim);
            let embedding = embedding_with_poison(dim, pos, poison);
            let result = validator.validate_embedding("test_doc", &embedding);
            prop_assert!(result.is_err());
            Ok(())
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]

            /// Property: Valid embeddings always pass validation
            #[test]
            fn prop_valid_embeddings_pass(
                values in prop::collection::vec(-1.0f32..1.0, 4..64)
            ) {
                let mut validator = JidokaIndexValidator::new(values.len());
                let result = validator.validate_embedding("test_doc", &values);
                prop_assert!(result.is_ok());
            }

            /// Property: Wrong dimension always fails
            #[test]
            fn prop_wrong_dim_fails(
                expected_dim in 64usize..128,
                actual_dim in 1usize..32
            ) {
                let mut validator = JidokaIndexValidator::new(expected_dim);
                let embedding: Vec<f32> = (0..actual_dim).map(|i| i as f32 / 100.0).collect();
                let result = validator.validate_embedding("test_doc", &embedding);
                prop_assert!(result.is_err());
            }

            /// Property: NaN values always fail
            #[test]
            fn prop_nan_fails(dim in 4usize..64, nan_pos in 0usize..4) {
                assert_poison_fails(dim, nan_pos, f32::NAN)?;
            }

            /// Property: Infinite values always fail
            #[test]
            fn prop_inf_fails(dim in 4usize..64, inf_pos in 0usize..4) {
                assert_poison_fails(dim, inf_pos, f32::INFINITY)?;
            }

            /// Property: Stats correctly count validations, failures and halts
            #[test]
            fn prop_stats_count_validations(
                valid_count in 0u64..10,
                invalid_count in 0u64..10
            ) {
                let mut validator = JidokaIndexValidator::new(TEST_DIM);

                for i in 0..valid_count {
                    validator.validate_embedding(&format!("valid_{}", i), &VALID_EMBEDDING).ok();
                }
                for i in 0..invalid_count {
                    validator.validate_embedding(&format!("invalid_{}", i), &[0.1]).ok();
                }

                let stats = validator.stats();
                prop_assert_eq!(stats.total_validations, valid_count + invalid_count);
                prop_assert_eq!(stats.successful, valid_count);
                prop_assert_eq!(stats.failed, invalid_count);
                prop_assert_eq!(stats.halts, invalid_count);
            }

            /// Property: Reset clears all stats
            #[test]
            fn prop_reset_clears_stats(count in 1u64..20) {
                let mut validator = JidokaIndexValidator::new(TEST_DIM);

                for i in 0..count {
                    validator.validate_embedding(&format!("doc_{}", i), &VALID_EMBEDDING).ok();
                }

                validator.reset_stats();
                let stats = validator.stats();
                prop_assert_eq!(stats.total_validations, 0);
                prop_assert_eq!(stats.successful, 0);
                prop_assert_eq!(stats.failed, 0);
                prop_assert_eq!(stats.halts, 0);
            }
        }
    }
}