chie_crypto/
formal_verify.rs

1//! Formal Verification Helpers
2//!
3//! This module provides utilities and helpers to support formal verification of
4//! cryptographic implementations, including property-based testing, invariant checking,
5//! and verification condition generation.
6//!
7//! # Features
8//!
9//! - **Property-based testing**: Automatic generation of test cases
10//! - **Invariant checking**: Runtime verification of cryptographic properties
11//! - **Pre/post-condition checking**: Function contract verification
12//! - **State machine verification**: Verify state transitions are valid
13//! - **Symbolic execution helpers**: Support for symbolic analysis
14//! - **Proof obligations**: Generate verification conditions
15//!
16//! # Example
17//!
18//! ```
19//! use chie_crypto::formal_verify::{Invariant, PropertyChecker, check_invariant};
20//!
21//! // Define an invariant
22//! let inv = Invariant::new("key_length", |state: &[u8]| {
23//!     state.len() == 32
24//! });
25//!
26//! // Check the invariant
27//! let key = [0u8; 32];
28//! assert!(inv.check(&key));
29//!
30//! // Property-based testing
31//! let mut checker = PropertyChecker::new();
32//! checker.add_property("encryption_decryption_roundtrip", |data: &[u8]| {
//!     // Property: decrypt(encrypt(x)) == x
34//!     true // simplified example
35//! });
36//! ```
37
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40
/// Type alias for post-condition predicates.
///
/// The predicate receives the function's input and the output it produced,
/// and returns `true` when the post-condition holds.
type PostConditionFn<T, U> = Box<dyn Fn(&T, &U) -> bool>;

/// Type alias for property check functions.
///
/// A property receives one generated byte input and returns `true` when the
/// property holds for that input.
type PropertyFn = Box<dyn Fn(&[u8]) -> bool>;

/// Type alias for state machine transition functions.
///
/// Given the current state and an event name, returns `Some(next_state)` when
/// this function handles the event, or `None` to let other transitions try.
type TransitionFn<S> = Box<dyn Fn(&S, &str) -> Option<S>>;
49
/// A named predicate that must hold for some piece of state.
///
/// `T` may be unsized (for example `[u8]`), so invariants can be written
/// directly over byte slices such as keys or buffers.
pub struct Invariant<T: ?Sized> {
    /// Human-readable identifier, reported when the invariant fails.
    name: String,
    /// Boxed predicate; returns `true` when the invariant holds.
    predicate: Box<dyn Fn(&T) -> bool>,
}

impl<T: ?Sized> Invariant<T> {
    /// Builds an invariant called `name` that is satisfied whenever
    /// `predicate` returns `true`.
    pub fn new<F>(name: &str, predicate: F) -> Self
    where
        F: Fn(&T) -> bool + 'static,
    {
        let predicate: Box<dyn Fn(&T) -> bool> = Box::new(predicate);
        Invariant {
            name: String::from(name),
            predicate,
        }
    }

    /// Evaluates the predicate against `state`; `true` means the invariant holds.
    pub fn check(&self, state: &T) -> bool {
        let holds = &self.predicate;
        holds(state)
    }

    /// Returns the name this invariant was created with.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
80
/// A named requirement on a function's input, meant to be checked before the
/// function runs.
pub struct PreCondition<T: ?Sized> {
    /// Identifier used when reporting this pre-condition.
    name: String,
    /// Returns `true` when the input satisfies the pre-condition.
    predicate: Box<dyn Fn(&T) -> bool>,
}

impl<T: ?Sized> PreCondition<T> {
    /// Creates a pre-condition named `name` backed by `predicate`.
    pub fn new<F>(name: &str, predicate: F) -> Self
    where
        F: Fn(&T) -> bool + 'static,
    {
        let boxed = Box::new(predicate);
        Self {
            predicate: boxed,
            name: name.to_owned(),
        }
    }

    /// Returns `true` if `input` satisfies the pre-condition.
    pub fn check(&self, input: &T) -> bool {
        (*self.predicate)(input)
    }

    /// Returns the pre-condition's name.
    pub fn name(&self) -> &str {
        &self.name[..]
    }
}
111
112/// Post-condition for a function
113pub struct PostCondition<T: ?Sized, U: ?Sized> {
114    /// Name of the post-condition
115    name: String,
116    /// Predicate that must be true after function execution
117    predicate: PostConditionFn<T, U>,
118}
119
120impl<T: ?Sized, U: ?Sized> PostCondition<T, U> {
121    /// Create a new post-condition
122    pub fn new<F>(name: &str, predicate: F) -> Self
123    where
124        F: Fn(&T, &U) -> bool + 'static,
125    {
126        Self {
127            name: name.to_string(),
128            predicate: Box::new(predicate),
129        }
130    }
131
132    /// Check if the post-condition holds
133    pub fn check(&self, input: &T, output: &U) -> bool {
134        (self.predicate)(input, output)
135    }
136
137    /// Get post-condition name
138    pub fn name(&self) -> &str {
139        &self.name
140    }
141}
142
143/// Property-based test checker
144pub struct PropertyChecker {
145    /// Properties to check
146    properties: HashMap<String, PropertyFn>,
147    /// Number of test cases per property
148    num_cases: usize,
149}
150
151impl Default for PropertyChecker {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl PropertyChecker {
158    /// Create a new property checker
159    pub fn new() -> Self {
160        Self {
161            properties: HashMap::new(),
162            num_cases: 100,
163        }
164    }
165
166    /// Set number of test cases
167    pub fn with_num_cases(mut self, num: usize) -> Self {
168        self.num_cases = num;
169        self
170    }
171
172    /// Add a property to check
173    pub fn add_property<F>(&mut self, name: &str, property: F)
174    where
175        F: Fn(&[u8]) -> bool + 'static,
176    {
177        self.properties.insert(name.to_string(), Box::new(property));
178    }
179
180    /// Check all properties with random inputs
181    pub fn check_all(&self) -> PropertyCheckResult {
182        use rand::RngCore;
183        let mut rng = rand::thread_rng();
184        let mut results = HashMap::new();
185
186        for (name, property) in &self.properties {
187            let mut passed = 0;
188            let mut failed = 0;
189
190            for _ in 0..self.num_cases {
191                // Generate random input
192                let mut data = vec![0u8; 32];
193                rng.fill_bytes(&mut data);
194
195                if property(&data) {
196                    passed += 1;
197                } else {
198                    failed += 1;
199                }
200            }
201
202            results.insert(
203                name.clone(),
204                PropertyResult {
205                    passed,
206                    failed,
207                    total: self.num_cases,
208                },
209            );
210        }
211
212        PropertyCheckResult { results }
213    }
214
215    /// Check a specific property
216    pub fn check_property(&self, name: &str) -> Option<PropertyResult> {
217        use rand::RngCore;
218        let property = self.properties.get(name)?;
219        let mut rng = rand::thread_rng();
220
221        let mut passed = 0;
222        let mut failed = 0;
223
224        for _ in 0..self.num_cases {
225            let mut data = vec![0u8; 32];
226            rng.fill_bytes(&mut data);
227
228            if property(&data) {
229                passed += 1;
230            } else {
231                failed += 1;
232            }
233        }
234
235        Some(PropertyResult {
236            passed,
237            failed,
238            total: self.num_cases,
239        })
240    }
241}
242
243/// Result of property checking
244#[derive(Debug, Clone)]
245pub struct PropertyCheckResult {
246    /// Results for each property
247    pub results: HashMap<String, PropertyResult>,
248}
249
250impl PropertyCheckResult {
251    /// Check if all properties passed
252    pub fn all_passed(&self) -> bool {
253        self.results.values().all(|r| r.failed == 0)
254    }
255
256    /// Get failed properties
257    pub fn failed_properties(&self) -> Vec<String> {
258        self.results
259            .iter()
260            .filter(|(_, r)| r.failed > 0)
261            .map(|(name, _)| name.clone())
262            .collect()
263    }
264}
265
266/// Result for a single property
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct PropertyResult {
269    /// Number of test cases that passed
270    pub passed: usize,
271    /// Number of test cases that failed
272    pub failed: usize,
273    /// Total number of test cases
274    pub total: usize,
275}
276
277impl PropertyResult {
278    /// Get success rate (0.0 - 1.0)
279    pub fn success_rate(&self) -> f64 {
280        if self.total == 0 {
281            return 0.0;
282        }
283        self.passed as f64 / self.total as f64
284    }
285
286    /// Check if all tests passed
287    pub fn all_passed(&self) -> bool {
288        self.failed == 0
289    }
290}
291
292/// State machine verifier
293pub struct StateMachine<S> {
294    /// Current state
295    current_state: S,
296    /// Valid transitions
297    transitions: Vec<TransitionFn<S>>,
298    /// State invariants
299    invariants: Vec<Invariant<S>>,
300}
301
302impl<S: Clone> StateMachine<S> {
303    /// Create a new state machine with initial state
304    pub fn new(initial_state: S) -> Self {
305        Self {
306            current_state: initial_state,
307            transitions: Vec::new(),
308            invariants: Vec::new(),
309        }
310    }
311
312    /// Add a transition function
313    pub fn add_transition<F>(&mut self, transition: F)
314    where
315        F: Fn(&S, &str) -> Option<S> + 'static,
316    {
317        self.transitions.push(Box::new(transition));
318    }
319
320    /// Add a state invariant
321    pub fn add_invariant(&mut self, invariant: Invariant<S>) {
322        self.invariants.push(invariant);
323    }
324
325    /// Check if current state satisfies all invariants
326    pub fn check_invariants(&self) -> Vec<String> {
327        self.invariants
328            .iter()
329            .filter(|inv| !inv.check(&self.current_state))
330            .map(|inv| inv.name().to_string())
331            .collect()
332    }
333
334    /// Attempt a transition
335    pub fn transition(&mut self, event: &str) -> Result<(), String> {
336        // Try each transition function
337        for trans in &self.transitions {
338            if let Some(new_state) = trans(&self.current_state, event) {
339                // Check invariants before transitioning
340                let old_state = self.current_state.clone();
341                self.current_state = new_state;
342
343                let violations = self.check_invariants();
344                if !violations.is_empty() {
345                    // Rollback if invariants violated
346                    self.current_state = old_state;
347                    return Err(format!("Invariant violations: {:?}", violations));
348                }
349
350                return Ok(());
351            }
352        }
353
354        Err(format!("No valid transition for event: {}", event))
355    }
356
357    /// Get current state
358    pub fn current_state(&self) -> &S {
359        &self.current_state
360    }
361}
362
363/// Verification condition
364#[derive(Debug, Clone, Serialize, Deserialize)]
365pub struct VerificationCondition {
366    /// Name of the condition
367    pub name: String,
368    /// Description
369    pub description: String,
370    /// Formula (in informal notation)
371    pub formula: String,
372}
373
374impl VerificationCondition {
375    /// Create a new verification condition
376    pub fn new(name: &str, description: &str, formula: &str) -> Self {
377        Self {
378            name: name.to_string(),
379            description: description.to_string(),
380            formula: formula.to_string(),
381        }
382    }
383}
384
/// Asserts that `predicate` holds for `state`.
///
/// # Panics
///
/// Panics with `Invariant '<name>' violated` if `predicate(state)` returns
/// `false`. The function is `#[track_caller]`, so the panic location is the
/// caller's line rather than this helper's.
#[track_caller]
pub fn check_invariant<T: ?Sized>(name: &str, state: &T, predicate: impl Fn(&T) -> bool) {
    if !predicate(state) {
        panic!("Invariant '{}' violated", name);
    }
}
391
/// Asserts that the pre-condition `predicate` holds for `input`.
///
/// # Panics
///
/// Panics with `Pre-condition '<name>' violated` if `predicate(input)` returns
/// `false`. The function is `#[track_caller]`, so the panic location is the
/// caller's line rather than this helper's.
#[track_caller]
pub fn check_precondition<T: ?Sized>(name: &str, input: &T, predicate: impl Fn(&T) -> bool) {
    if !predicate(input) {
        panic!("Pre-condition '{}' violated", name);
    }
}
398
/// Asserts that the post-condition `predicate` holds for `input` and `output`.
///
/// # Panics
///
/// Panics with `Post-condition '<name>' violated` if
/// `predicate(input, output)` returns `false`. The function is
/// `#[track_caller]`, so the panic location is the caller's line rather than
/// this helper's.
#[track_caller]
pub fn check_postcondition<T: ?Sized, U: ?Sized>(
    name: &str,
    input: &T,
    output: &U,
    predicate: impl Fn(&T, &U) -> bool,
) {
    if !predicate(input, output) {
        panic!("Post-condition '{}' violated", name);
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invariant_creation() {
        let inv = Invariant::new("test", |x: &i32| *x > 0);
        assert_eq!(inv.name(), "test");
        assert!(inv.check(&5));
        assert!(!inv.check(&-5));
    }

    #[test]
    fn test_invariant_key_length() {
        let inv = Invariant::new("key_length_32", |key: &[u8]| key.len() == 32);
        assert!(inv.check(&[0u8; 32]));
        assert!(!inv.check(&[0u8; 16]));
    }

    #[test]
    fn test_precondition() {
        let pre = PreCondition::new("non_empty", |data: &[u8]| !data.is_empty());
        assert!(pre.check(&[1, 2, 3]));
        assert!(!pre.check(&[]));
    }

    #[test]
    fn test_postcondition() {
        let post = PostCondition::new("output_not_empty", |_input: &[u8], output: &[u8]| {
            !output.is_empty()
        });
        assert!(post.check(&[1, 2], &[3, 4]));
        assert!(!post.check(&[1, 2], &[]));
    }

    #[test]
    fn test_property_checker() {
        let mut checker = PropertyChecker::new().with_num_cases(10);
        checker.add_property("always_true", |_| true);
        checker.add_property("always_false", |_| false);

        let results = checker.check_all();
        assert!(results.results["always_true"].all_passed());
        assert!(!results.results["always_false"].all_passed());
        assert!(!results.all_passed());
    }

    #[test]
    fn test_property_result_success_rate() {
        let result = PropertyResult {
            passed: 75,
            failed: 25,
            total: 100,
        };
        assert_eq!(result.success_rate(), 0.75);
    }

    #[test]
    fn test_property_result_zero_total() {
        // With no cases run, the rate is defined as 0.0 and nothing has failed.
        let result = PropertyResult {
            passed: 0,
            failed: 0,
            total: 0,
        };
        assert_eq!(result.success_rate(), 0.0);
        assert!(result.all_passed());
    }

    #[test]
    fn test_property_checker_single_property() {
        let mut checker = PropertyChecker::new().with_num_cases(20);
        checker.add_property("test_prop", |_| true);

        let result = checker.check_property("test_prop").unwrap();
        assert_eq!(result.passed, 20);
        assert_eq!(result.failed, 0);
        assert!(result.all_passed());
    }

    #[test]
    fn test_check_property_missing() {
        let checker = PropertyChecker::new();
        assert!(checker.check_property("missing").is_none());
    }

    #[test]
    fn test_property_checker_no_properties() {
        let results = PropertyChecker::new().check_all();
        assert!(results.results.is_empty());
        assert!(results.all_passed());
        assert!(results.failed_properties().is_empty());
    }

    #[test]
    fn test_state_machine_basic() {
        #[derive(Clone, PartialEq, Debug)]
        enum State {
            Init,
            Ready,
            Running,
        }

        let mut sm = StateMachine::new(State::Init);

        // Add transition: Init -> Ready on "start"
        sm.add_transition(|state, event| match (state, event) {
            (State::Init, "start") => Some(State::Ready),
            (State::Ready, "run") => Some(State::Running),
            _ => None,
        });

        // Transition to Ready
        assert!(sm.transition("start").is_ok());
        assert_eq!(*sm.current_state(), State::Ready);

        // Transition to Running
        assert!(sm.transition("run").is_ok());
        assert_eq!(*sm.current_state(), State::Running);

        // Invalid transition
        assert!(sm.transition("start").is_err());
    }

    #[test]
    fn test_state_machine_unknown_event() {
        let mut sm = StateMachine::new(0u8);
        sm.add_transition(|s, event| if event == "inc" { Some(s + 1) } else { None });

        let err = sm.transition("dec").unwrap_err();
        assert_eq!(err, "No valid transition for event: dec");
        assert_eq!(*sm.current_state(), 0);
    }

    #[test]
    fn test_state_machine_with_invariant() {
        let mut sm = StateMachine::new(0i32);

        // Add invariant: state must be non-negative
        sm.add_invariant(Invariant::new("non_negative", |s: &i32| *s >= 0));

        // Add transition that increments state
        sm.add_transition(|state, event| {
            if event == "increment" {
                Some(state + 1)
            } else {
                None
            }
        });

        // Valid transition
        assert!(sm.transition("increment").is_ok());
        assert_eq!(*sm.current_state(), 1);
        assert!(sm.check_invariants().is_empty());
    }

    #[test]
    fn test_state_machine_invariant_violation() {
        let mut sm = StateMachine::new(5i32);

        // Add invariant: state must be <= 10
        sm.add_invariant(Invariant::new("max_10", |s: &i32| *s <= 10));

        // Add transition that adds 10
        sm.add_transition(|state, event| {
            if event == "add_10" {
                Some(state + 10)
            } else {
                None
            }
        });

        // This transition would violate the invariant
        assert!(sm.transition("add_10").is_err());
        // State should be unchanged
        assert_eq!(*sm.current_state(), 5);
    }

    #[test]
    fn test_check_invariant_helper() {
        let state = vec![1, 2, 3];
        check_invariant("non_empty", &state, |s| !s.is_empty());
    }

    #[test]
    #[should_panic(expected = "Invariant 'empty' violated")]
    fn test_check_invariant_helper_panic() {
        let state = vec![1, 2, 3];
        check_invariant("empty", &state, |s| s.is_empty());
    }

    #[test]
    fn test_check_precondition_helper() {
        check_precondition("positive", &5i32, |x| *x > 0);
    }

    #[test]
    #[should_panic(expected = "Pre-condition 'positive' violated")]
    fn test_check_precondition_helper_panic() {
        check_precondition("positive", &-1i32, |x| *x > 0);
    }

    #[test]
    fn test_check_postcondition_helper() {
        check_postcondition("doubled", &2i32, &4i32, |input, output| *output == input * 2);
    }

    #[test]
    #[should_panic(expected = "Post-condition 'doubled' violated")]
    fn test_check_postcondition_helper_panic() {
        check_postcondition("doubled", &2i32, &5i32, |input, output| *output == input * 2);
    }

    #[test]
    fn test_verification_condition() {
        let vc = VerificationCondition::new(
            "encryption_correctness",
            "Decryption of encrypted data returns original",
            "forall m, k: decrypt(encrypt(m, k), k) = m",
        );

        assert_eq!(vc.name, "encryption_correctness");
        assert!(vc.formula.contains("decrypt"));
    }

    #[test]
    fn test_failed_properties() {
        let mut checker = PropertyChecker::new().with_num_cases(10);
        checker.add_property("pass1", |_| true);
        checker.add_property("fail1", |_| false);
        checker.add_property("pass2", |_| true);

        let results = checker.check_all();
        let failed = results.failed_properties();
        assert_eq!(failed.len(), 1);
        assert!(failed.contains(&"fail1".to_string()));
    }
}