Skip to main content

spine_crypto/
lib.rs

1// Allow dead code for cryptographic API surface
2#![allow(dead_code)]
3
4//! # SPINE Crypto
5//!
6//! Advanced cryptographic primitives for the SPINE stack including:
7//! - **Titans-based message prediction** for speculative decoding with Neural Long-Term Memory
8//! - **MIRAS-adaptive prediction** with automatic variant switching
9//! - Post-quantum key evolution using lattice-based cryptography concepts
10//!
11//! ## Titans Predictor (Neural Long-Term Memory)
12//!
13//! Uses the **Titans architecture** with test-time training for unbounded context
14//! message prediction, enabling speculative decoding where receivers can pre-compute
15//! responses. Unlike standard Transformers, Titans maintains persistent memory that
16//! survives across sequences through surprise-gated updates.
17//!
18//! Key advantages over Transformers:
19//! - **Unbounded context**: Memory persists indefinitely via consolidation
20//! - **Test-time training**: Adapts to message patterns during inference
21//! - **Surprise detection**: Identifies anomalous messages for security
22//! - **Memory efficiency**: O(1) memory vs O(n²) for attention
23//!
24//! ## MIRAS-Adaptive Prediction
25//!
26//! Integrates the MIRAS framework for continual learning:
27//! - **YAAD**: Yield-Adaptive Anomaly Detection for outlier-robust prediction
28//! - **MONETA**: Memory-Optimized Network for stable long-running sessions
29//! - **MEMORA**: Balanced updates for mixed traffic patterns
30//!
31//! The predictor automatically switches between variants based on surprise levels.
32//!
33//! ## Quantum-Resistant Key Evolution
34//!
35//! Implements NTRU-inspired lattice operations for key evolution that
36//! resists quantum computing attacks (Shor's algorithm).
37
38use rand::prelude::*;
39use rand::rngs::StdRng;
40use serde::{Deserialize, Serialize};
41use sha2::{Digest, Sha256};
42use zeroize::{Zeroize, ZeroizeOnDrop};
43use spine_neural::{
44    Activation, DenseLayer, MirasNeuralEncoder, MirasVariant, MultiHeadAttention,
45    NeuralEncoderConfig, TitansMemory,
46};
47use std::collections::VecDeque;
48use subtle::ConstantTimeEq;
49
50// ML-KEM (FIPS 203) post-quantum KEM
51use ml_kem::kem::{Decapsulate, Encapsulate, EncapsulationKey, DecapsulationKey};
52use ml_kem::{MlKem512, MlKem768, MlKem1024, KemCore, EncodedSizeUser, Encoded,
53    MlKem512Params, MlKem768Params, MlKem1024Params};
54
55// AES-256-GCM for authenticated encryption (replaces XOR)
56use aes_gcm::aead::Aead;
57use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
58// HKDF for proper key derivation
59use hkdf::Hkdf;
60
61// ============================================================================
62// TITANS-BASED MESSAGE PREDICTOR (Neural Long-Term Memory)
63// ============================================================================
64
/// Precomputed positional-encoding table for sequence modeling.
///
/// Row `pos` of `encodings` is the encoding vector for sequence position
/// `pos`. The table is filled in by `PositionalEncoding::new` and read back
/// through `PositionalEncoding::get`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionalEncoding {
    /// Number of positions (rows) the table was built for.
    max_len: usize,
    /// Width of every encoding vector (columns per row).
    embed_dim: usize,
    /// The table itself: `max_len` rows of `embed_dim` values each.
    encodings: Vec<Vec<f32>>,
}
72
73impl PositionalEncoding {
74    pub fn new(max_len: usize, embed_dim: usize) -> Self {
75        // Use enumerate patterns for better clippy compliance
76        let encodings: Vec<Vec<f32>> = (0..max_len)
77            .map(|pos| {
78                (0..embed_dim)
79                    .map(|i| {
80                        let angle = pos as f32
81                            / (10000.0_f32).powf(2.0 * (i / 2) as f32 / embed_dim as f32);
82                        if i % 2 == 0 {
83                            angle.sin()
84                        } else {
85                            angle.cos()
86                        }
87                    })
88                    .collect()
89            })
90            .collect();
91
92        Self {
93            max_len,
94            embed_dim,
95            encodings,
96        }
97    }
98
99    pub fn get(&self, position: usize) -> &[f32] {
100        // Handle edge case: if max_len is 0, return empty slice (though this shouldn't happen)
101        if self.encodings.is_empty() {
102            return &[];
103        }
104        &self.encodings[position.min(self.max_len.saturating_sub(1))]
105    }
106}
107
/// Layer normalization with a per-element scale (`gamma`) and shift (`beta`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Length of `gamma` and `beta` as allocated by `LayerNorm::new`.
    dim: usize,
    /// Multiplicative scale applied after normalization; `new` fills it with 1.0.
    gamma: Vec<f32>,
    /// Additive shift applied after scaling; `new` fills it with 0.0.
    beta: Vec<f32>,
    /// Small constant added to the variance before the square root (`new` uses 1e-5).
    eps: f32,
}
116
117impl LayerNorm {
118    pub fn new(dim: usize) -> Self {
119        Self {
120            dim,
121            gamma: vec![1.0; dim],
122            beta: vec![0.0; dim],
123            eps: 1e-5,
124        }
125    }
126
127    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
128        // Handle empty input to prevent division by zero
129        if x.is_empty() {
130            return Vec::new();
131        }
132        let n = x.len() as f32;
133        let mean: f32 = x.iter().sum::<f32>() / n;
134        let var: f32 = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
135        let std = (var + self.eps).sqrt();
136
137        x.iter()
138            .enumerate()
139            .map(|(i, &v)| self.gamma[i % self.dim] * (v - mean) / std + self.beta[i % self.dim])
140            .collect()
141    }
142}
143
/// Two-layer feed-forward network used inside a `TitansBlock`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedForward {
    /// First dense layer (`embed_dim -> ff_dim`, GELU activation in `new`).
    linear1: DenseLayer,
    /// Second dense layer (`ff_dim -> embed_dim`, no activation in `new`).
    linear2: DenseLayer,
}
150
151impl FeedForward {
152    pub fn new(embed_dim: usize, ff_dim: usize, rng: &mut StdRng) -> Self {
153        Self {
154            linear1: DenseLayer::new(embed_dim, ff_dim, Activation::GELU, rng),
155            linear2: DenseLayer::new(ff_dim, embed_dim, Activation::None, rng),
156        }
157    }
158
159    pub fn forward(&mut self, x: &[f32]) -> Vec<f32> {
160        let hidden = self.linear1.forward(x);
161        self.linear2.forward(&hidden)
162    }
163}
164
/// Single Titans decoder block with Neural Long-Term Memory.
///
/// Combines a memory lookup, self-attention over the input sequence and a
/// feed-forward network, each followed by a residual add and a layer norm
/// (see `TitansBlock::forward`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansBlock {
    /// Neural Long-Term Memory for persistent context
    memory: TitansMemory,
    /// Short-term attention for recent sequence
    attention: MultiHeadAttention,
    /// Feed-forward network applied after attention.
    ff: FeedForward,
    /// Normalization after the memory residual.
    norm1: LayerNorm,
    /// Normalization after the attention residual.
    norm2: LayerNorm,
    /// Normalization after the feed-forward residual (produces the block output).
    norm3: LayerNorm,
    /// Embedding width; also the length of the zero vector returned for an empty sequence.
    embed_dim: usize,
}
178
179impl TitansBlock {
180    pub fn new(
181        embed_dim: usize,
182        num_heads: usize,
183        ff_dim: usize,
184        memory_size: usize,
185        rng: &mut StdRng,
186    ) -> Self {
187        Self {
188            memory: TitansMemory::new(embed_dim, embed_dim, memory_size, rng),
189            attention: MultiHeadAttention::new(embed_dim, num_heads, rng),
190            ff: FeedForward::new(embed_dim, ff_dim, rng),
191            norm1: LayerNorm::new(embed_dim),
192            norm2: LayerNorm::new(embed_dim),
193            norm3: LayerNorm::new(embed_dim),
194            embed_dim,
195        }
196    }
197
198    pub fn forward(&mut self, sequence: &[Vec<f32>]) -> Vec<f32> {
199        if sequence.is_empty() {
200            return vec![0.0; self.embed_dim];
201        }
202
203        let last = &sequence[sequence.len() - 1];
204
205        // Step 1: Query Neural Long-Term Memory (persistent context)
206        let memory_out = self.memory.forward(last);
207        let residual1: Vec<f32> = memory_out
208            .iter()
209            .zip(last.iter())
210            .map(|(m, l)| m + l)
211            .collect();
212        let normed1 = self.norm1.forward(&residual1);
213
214        // Step 2: Short-term self-attention for recent sequence
215        let attended = self.attention.forward(sequence);
216        let residual2: Vec<f32> = attended
217            .iter()
218            .zip(normed1.iter())
219            .map(|(a, n)| a + n)
220            .collect();
221        let normed2 = self.norm2.forward(&residual2);
222
223        // Step 3: Feed-forward with residual
224        let ff_out = self.ff.forward(&normed2);
225        let residual3: Vec<f32> = ff_out
226            .iter()
227            .zip(normed2.iter())
228            .map(|(f, n)| f + n)
229            .collect();
230        self.norm3.forward(&residual3)
231    }
232
233    /// Get current surprise level (for anomaly detection)
234    pub fn get_surprise(&self) -> f32 {
235        self.memory.get_surprise()
236    }
237
238    /// Reset memory state
239    pub fn reset_memory(&mut self) {
240        self.memory.reset_state();
241    }
242}
243
/// Byte-level tokenizer: one embedding vector per possible byte value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ByteTokenizer {
    /// Length of every embedding vector.
    embed_dim: usize,
    /// Lookup table indexed by byte value; `new` allocates 256 rows.
    embeddings: Vec<Vec<f32>>, // 256 byte embeddings
}
250
251impl ByteTokenizer {
252    pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
253        let scale = (1.0 / embed_dim as f32).sqrt();
254        let embeddings: Vec<Vec<f32>> = (0..256)
255            .map(|_| {
256                (0..embed_dim)
257                    .map(|_| rng.gen::<f32>() * 2.0 * scale - scale)
258                    .collect()
259            })
260            .collect();
261
262        Self {
263            embed_dim,
264            embeddings,
265        }
266    }
267
268    pub fn encode(&self, byte: u8) -> &[f32] {
269        &self.embeddings[byte as usize]
270    }
271
272    pub fn encode_sequence(&self, bytes: &[u8]) -> Vec<Vec<f32>> {
273        bytes
274            .iter()
275            .map(|&b| self.embeddings[b as usize].clone())
276            .collect()
277    }
278}
279
/// Output projection that turns a hidden vector into a next-byte distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputProjection {
    /// One weight row per byte value; `new` allocates 256 rows of `embed_dim` values.
    weights: Vec<Vec<f32>>, // [256][embed_dim]
    /// Divisor applied to the logits before softmax (1.0 from `new`;
    /// `set_temperature` keeps it at or above 0.01).
    temperature: f32,
}
286
287impl OutputProjection {
288    pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
289        let scale = (1.0 / embed_dim as f32).sqrt();
290        let weights: Vec<Vec<f32>> = (0..256)
291            .map(|_| {
292                (0..embed_dim)
293                    .map(|_| rng.gen::<f32>() * 2.0 * scale - scale)
294                    .collect()
295            })
296            .collect();
297
298        Self {
299            weights,
300            temperature: 1.0,
301        }
302    }
303
304    pub fn set_temperature(&mut self, temp: f32) {
305        self.temperature = temp.max(0.01);
306    }
307
308    pub fn forward(&self, hidden: &[f32]) -> Vec<f32> {
309        let mut logits = vec![0.0; 256];
310        for (i, w) in self.weights.iter().enumerate() {
311            for (j, &h) in hidden.iter().enumerate() {
312                logits[i] += w[j] * h;
313            }
314            logits[i] /= self.temperature;
315        }
316
317        // Softmax
318        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
319        let mut sum = 0.0;
320        for l in &mut logits {
321            *l = (*l - max).exp();
322            sum += *l;
323        }
324        for l in &mut logits {
325            *l /= sum;
326        }
327
328        logits
329    }
330
331    pub fn sample(&self, probs: &[f32], rng: &mut StdRng) -> u8 {
332        let mut cumsum = 0.0;
333        let r: f32 = rng.gen();
334        for (i, &p) in probs.iter().enumerate() {
335            cumsum += p;
336            if r < cumsum {
337                return i as u8;
338            }
339        }
340        255
341    }
342
343    pub fn argmax(&self, probs: &[f32]) -> u8 {
344        probs
345            .iter()
346            .enumerate()
347            .max_by(|(_, a), (_, b)| {
348                // Handle NaN: treat NaN as less than any number
349                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less)
350            })
351            .map(|(i, _)| i as u8)
352            .unwrap_or(0)
353    }
354}
355
/// Titans-based message predictor with Neural Long-Term Memory
///
/// Byte-level next-byte predictor: observed bytes are embedded, combined with
/// a positional encoding and kept in a bounded context window, which a stack
/// of `TitansBlock`s turns into a distribution over the next byte.
///
/// Design intent stated by the original documentation:
/// - **Unbounded context**: Memory consolidates patterns indefinitely
/// - **Anomaly detection**: High surprise indicates novel/malicious messages
/// - **Adaptive prediction**: Memory updates based on prediction errors
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansPredictor {
    /// Maps each byte value to its embedding vector.
    tokenizer: ByteTokenizer,
    /// Positional encodings added to embeddings as they enter the context window.
    positional: PositionalEncoding,
    /// Decoder stack; the first block sees the whole context window.
    blocks: Vec<TitansBlock>,
    /// Projects the final hidden vector to a 256-way byte distribution.
    output: OutputProjection,
    /// Embedding width copied from the config.
    embed_dim: usize,
    /// Maximum number of embeddings kept in `context_window`.
    max_seq_len: usize,
    /// Memory size copied from the config.
    memory_size: usize,
    /// Most recent position-encoded embeddings, oldest at the front.
    context_window: VecDeque<Vec<f32>>,
    /// Accumulated surprise for anomaly detection
    total_surprise: f32,
    /// Sampling RNG. Not serialized; on deserialization serde rebuilds it by
    /// calling `default_rng`.
    #[serde(skip, default = "default_rng")]
    rng: StdRng,
}
379
/// MIRAS-adaptive message predictor with automatic variant switching
///
/// Wraps a `TitansPredictor` and, on every observed message, additionally
/// feeds the message to a `MirasNeuralEncoder`, records the base predictor's
/// surprise in a sliding window and re-selects the active `MirasVariant` from
/// the window average and the message count.
///
/// Design intent stated by the original documentation:
/// - **Adaptive encoding**: Uses MirasNeuralEncoder for latent projections
/// - **Variant switching**: Automatically selects optimal MIRAS variant
/// - **Outlier robustness**: YAAD for high-anomaly traffic
/// - **Long-term stability**: MONETA for extended sessions
#[derive(Debug, Clone)]
pub struct MirasTitansPredictor {
    /// Base Titans predictor
    base: TitansPredictor,
    /// MIRAS encoder for adaptive projections (the visible constructors always set `Some`)
    miras_encoder: Option<MirasNeuralEncoder>,
    /// Current MIRAS variant
    active_variant: MirasVariant,
    /// Surprise history for adaptive switching (capped at 100 entries by `observe`)
    surprise_history: VecDeque<f32>,
    /// Threshold for variant switching
    anomaly_threshold: f32,
    /// Message counter for long-running detection
    message_count: u64,
    /// Number of observed messages that were also passed through the MIRAS encoder
    miras_enhanced_predictions: u64,
    /// Latent dimension for encoding (set to the config's `embed_dim`)
    latent_dim: usize,
}
406
407impl MirasTitansPredictor {
408    /// Create a new MIRAS-enhanced predictor
409    pub fn new(config: TitansConfig) -> Self {
410        let base = TitansPredictor::new(config.clone());
411
412        // Create MIRAS encoder (embed_dim must be divisible by attention_heads)
413        let encoder_config = NeuralEncoderConfig {
414            input_dim: config.embed_dim,
415            latent_dim: config.embed_dim,
416            hidden_dims: vec![config.ff_dim, config.embed_dim],
417            attention_heads: config.num_heads,
418            seed: config.seed + 1,
419            miras_variant: MirasVariant::Titans,
420            memory_tokens: config.memory_size,
421        };
422
423        let miras_encoder = Some(MirasNeuralEncoder::new(&encoder_config));
424
425        Self {
426            base,
427            miras_encoder,
428            active_variant: MirasVariant::Titans,
429            surprise_history: VecDeque::with_capacity(100),
430            anomaly_threshold: 0.5,
431            message_count: 0,
432            miras_enhanced_predictions: 0,
433            latent_dim: config.embed_dim,
434        }
435    }
436
437    /// Create with specific MIRAS variant
438    pub fn new_with_variant(config: TitansConfig, variant: MirasVariant) -> Self {
439        let base = TitansPredictor::new(config.clone());
440
441        // Recreate encoder with specified variant
442        let encoder_config = NeuralEncoderConfig {
443            input_dim: config.embed_dim,
444            latent_dim: config.embed_dim,
445            hidden_dims: vec![config.ff_dim, config.embed_dim],
446            attention_heads: config.num_heads,
447            seed: config.seed + 1,
448            miras_variant: variant,
449            memory_tokens: config.memory_size,
450        };
451
452        Self {
453            base,
454            miras_encoder: Some(MirasNeuralEncoder::new(&encoder_config)),
455            active_variant: variant,
456            surprise_history: VecDeque::with_capacity(100),
457            anomaly_threshold: 0.5,
458            message_count: 0,
459            miras_enhanced_predictions: 0,
460            latent_dim: config.embed_dim,
461        }
462    }
463
464    /// Set anomaly threshold for variant switching
465    pub fn set_anomaly_threshold(&mut self, threshold: f32) {
466        self.anomaly_threshold = threshold;
467    }
468
469    /// Get current MIRAS variant
470    pub fn variant(&self) -> &str {
471        match self.active_variant {
472            MirasVariant::Titans => "titans",
473            MirasVariant::Yaad => "yaad",
474            MirasVariant::Moneta { .. } => "moneta",
475            MirasVariant::Memora => "memora",
476        }
477    }
478
479    /// Get average anomaly level
480    pub fn anomaly_level(&self) -> f32 {
481        if self.surprise_history.is_empty() {
482            0.0
483        } else {
484            self.surprise_history.iter().sum::<f32>() / self.surprise_history.len() as f32
485        }
486    }
487
488    /// Adaptively switch MIRAS variant based on traffic patterns
489    fn maybe_switch_variant(&mut self) {
490        let anomaly = self.anomaly_level();
491
492        let new_variant = if anomaly > self.anomaly_threshold * 2.0 {
493            // High anomaly: use YAAD for outlier robustness
494            MirasVariant::Yaad
495        } else if anomaly > self.anomaly_threshold {
496            // Moderate anomaly: use MEMORA for balanced updates
497            MirasVariant::Memora
498        } else if self.message_count > 10000 {
499            // Long-running session: use MONETA for stability (p=2 is L2 norm)
500            MirasVariant::Moneta { p: 2.0 }
501        } else {
502            // Normal: baseline Titans
503            MirasVariant::Titans
504        };
505
506        // Check if variant changed (ignoring Moneta's p value for comparison)
507        let variant_changed = !matches!(
508            (&new_variant, &self.active_variant),
509            (MirasVariant::Titans, MirasVariant::Titans)
510                | (MirasVariant::Yaad, MirasVariant::Yaad)
511                | (MirasVariant::Moneta { .. }, MirasVariant::Moneta { .. })
512                | (MirasVariant::Memora, MirasVariant::Memora)
513        );
514
515        if variant_changed {
516            self.active_variant = new_variant;
517            // Note: In production, we'd rebuild the encoder here
518            // For efficiency, we keep the same encoder but track the variant
519        }
520    }
521
522    /// Observe a message with MIRAS-enhanced encoding
523    pub fn observe(&mut self, message: &[u8]) {
524        // Base observation
525        self.base.observe(message);
526
527        // Track surprise
528        let surprise = self.base.get_surprise();
529        self.surprise_history.push_back(surprise);
530        if self.surprise_history.len() > 100 {
531            self.surprise_history.pop_front();
532        }
533
534        // MIRAS encoding step (for enhanced pattern learning)
535        if let Some(ref mut encoder) = self.miras_encoder {
536            // Encode with MIRAS (triggers surprise tracking)
537            let _latent = encoder.encode(message);
538            self.miras_enhanced_predictions += 1;
539        }
540
541        self.message_count += 1;
542
543        // Check if we should switch variants
544        self.maybe_switch_variant();
545    }
546
547    /// Predict next byte (delegates to base)
548    pub fn predict_next(&mut self) -> (u8, f32) {
549        self.base.predict_next()
550    }
551
552    /// Predict sequence (delegates to base)
553    pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
554        self.base.predict_sequence(length, greedy)
555    }
556
557    /// Verify prediction (delegates to base)
558    pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
559        self.base.verify_prediction(message)
560    }
561
562    /// Get surprise from base predictor
563    pub fn get_surprise(&self) -> f32 {
564        self.base.get_surprise()
565    }
566
567    /// Check if anomalous
568    pub fn is_anomalous(&self, threshold: f32) -> bool {
569        self.base.is_anomalous(threshold)
570    }
571
572    /// Get MIRAS encoder surprise (if available)
573    pub fn get_miras_surprise(&self) -> Option<f32> {
574        self.miras_encoder.as_ref().map(|e| e.get_surprise())
575    }
576
577    /// Get combined surprise (Titans + MIRAS)
578    pub fn get_combined_surprise(&self) -> f32 {
579        let titans = self.base.get_surprise();
580        let miras = self.get_miras_surprise().unwrap_or(0.0);
581        (titans + miras) / 2.0
582    }
583
584    /// Reset context (preserves memory)
585    pub fn reset(&mut self) {
586        self.base.reset();
587    }
588
589    /// Full reset including MIRAS state
590    pub fn reset_all(&mut self) {
591        self.base.reset_all();
592        self.surprise_history.clear();
593        self.message_count = 0;
594        if let Some(ref mut encoder) = self.miras_encoder {
595            encoder.reset();
596        }
597    }
598
599    /// Get statistics
600    pub fn stats(&self) -> MirasPredictorStats {
601        MirasPredictorStats {
602            message_count: self.message_count,
603            miras_enhanced_predictions: self.miras_enhanced_predictions,
604            current_variant: self.variant().to_string(),
605            anomaly_level: self.anomaly_level(),
606            titans_surprise: self.base.get_surprise(),
607            miras_surprise: self.get_miras_surprise(),
608        }
609    }
610}
611
/// Statistics for MIRAS predictor, as returned by `MirasTitansPredictor::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirasPredictorStats {
    /// Messages observed since construction or the last full reset.
    pub message_count: u64,
    /// Observed messages that were also passed through the MIRAS encoder.
    pub miras_enhanced_predictions: u64,
    /// Lowercase name of the active variant ("titans", "yaad", "moneta" or "memora").
    pub current_variant: String,
    /// Mean of the recent surprise window (0.0 when the window is empty).
    pub anomaly_level: f32,
    /// Surprise currently reported by the base Titans predictor.
    pub titans_surprise: f32,
    /// Surprise reported by the MIRAS encoder, or `None` when no encoder is present.
    pub miras_surprise: Option<f32>,
}
622
623fn default_rng() -> StdRng {
624    StdRng::seed_from_u64(42)
625}
626
627impl TitansPredictor {
628    pub fn new(config: TitansConfig) -> Self {
629        let mut rng = StdRng::seed_from_u64(config.seed);
630
631        let tokenizer = ByteTokenizer::new(config.embed_dim, &mut rng);
632        let positional = PositionalEncoding::new(config.max_seq_len, config.embed_dim);
633
634        let blocks: Vec<TitansBlock> = (0..config.num_layers)
635            .map(|_| {
636                TitansBlock::new(
637                    config.embed_dim,
638                    config.num_heads,
639                    config.ff_dim,
640                    config.memory_size,
641                    &mut rng,
642                )
643            })
644            .collect();
645
646        let output = OutputProjection::new(config.embed_dim, &mut rng);
647
648        Self {
649            tokenizer,
650            positional,
651            blocks,
652            output,
653            embed_dim: config.embed_dim,
654            max_seq_len: config.max_seq_len,
655            memory_size: config.memory_size,
656            context_window: VecDeque::with_capacity(config.max_seq_len),
657            total_surprise: 0.0,
658            rng,
659        }
660    }
661
662    /// Add a message to the context for prediction (triggers test-time training)
663    pub fn observe(&mut self, message: &[u8]) {
664        for &byte in message {
665            let mut embedding = self.tokenizer.encode(byte).to_vec();
666            let pos = self.context_window.len();
667            let pos_enc = self.positional.get(pos);
668            for (e, p) in embedding.iter_mut().zip(pos_enc.iter()) {
669                *e += *p;
670            }
671
672            self.context_window.push_back(embedding);
673            if self.context_window.len() > self.max_seq_len {
674                self.context_window.pop_front();
675            }
676        }
677
678        // Accumulate surprise from all blocks (for anomaly detection)
679        self.total_surprise = self.blocks.iter().map(|b| b.get_surprise()).sum::<f32>()
680            / self.blocks.len().max(1) as f32;
681    }
682
683    /// Predict the next byte
684    pub fn predict_next(&mut self) -> (u8, f32) {
685        let sequence: Vec<Vec<f32>> = self.context_window.iter().cloned().collect();
686
687        if sequence.is_empty() {
688            return (0, 1.0 / 256.0);
689        }
690
691        // Forward through Titans blocks (with persistent memory)
692        let mut hidden = self.blocks[0].forward(&sequence);
693        for block in &mut self.blocks[1..] {
694            let seq_with_hidden = vec![hidden.clone()];
695            hidden = block.forward(&seq_with_hidden);
696        }
697
698        // Project to output
699        let probs = self.output.forward(&hidden);
700        let predicted = self.output.argmax(&probs);
701        let confidence = probs[predicted as usize];
702
703        (predicted, confidence)
704    }
705
706    /// Predict multiple bytes autoregressively
707    pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
708        let mut result = Vec::with_capacity(length);
709
710        for _ in 0..length {
711            let sequence: Vec<Vec<f32>> = self.context_window.iter().cloned().collect();
712
713            if sequence.is_empty() {
714                let byte = if greedy { 0 } else { self.rng.gen() };
715                result.push(byte);
716                continue;
717            }
718
719            // Forward through Titans blocks
720            let mut hidden = self.blocks[0].forward(&sequence);
721            for block in &mut self.blocks[1..] {
722                let seq_with_hidden = vec![hidden.clone()];
723                hidden = block.forward(&seq_with_hidden);
724            }
725
726            let probs = self.output.forward(&hidden);
727            let byte = if greedy {
728                self.output.argmax(&probs)
729            } else {
730                self.output.sample(&probs, &mut self.rng)
731            };
732
733            result.push(byte);
734
735            // Add prediction to context for autoregressive generation
736            let mut embedding = self.tokenizer.encode(byte).to_vec();
737            let pos = self.context_window.len();
738            let pos_enc = self.positional.get(pos);
739            for (e, p) in embedding.iter_mut().zip(pos_enc.iter()) {
740                *e += *p;
741            }
742            self.context_window.push_back(embedding);
743            if self.context_window.len() > self.max_seq_len {
744                self.context_window.pop_front();
745            }
746        }
747
748        result
749    }
750
751    /// Check if a message matches prediction
752    pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
753        let predicted = self.predict_sequence(message.len(), true);
754        let matches = predicted == message;
755
756        let similarity = predicted
757            .iter()
758            .zip(message.iter())
759            .filter(|(p, m)| p == m)
760            .count() as f32
761            / message.len().max(1) as f32;
762
763        (matches, similarity)
764    }
765
766    /// Get accumulated surprise (anomaly score)
767    /// High values indicate unexpected/novel message patterns
768    pub fn get_surprise(&self) -> f32 {
769        self.total_surprise
770    }
771
772    /// Check if current message pattern is anomalous
773    pub fn is_anomalous(&self, threshold: f32) -> bool {
774        self.total_surprise > threshold
775    }
776
777    /// Reset context window (but preserve long-term memory)
778    pub fn reset(&mut self) {
779        self.context_window.clear();
780        self.total_surprise = 0.0;
781    }
782
783    /// Full reset including long-term memory
784    pub fn reset_all(&mut self) {
785        self.context_window.clear();
786        self.total_surprise = 0.0;
787        for block in &mut self.blocks {
788            block.reset_memory();
789        }
790    }
791
792    /// Set temperature for sampling
793    pub fn set_temperature(&mut self, temp: f32) {
794        self.output.set_temperature(temp);
795    }
796}
797
// Backwards-compatibility aliases so code written against the older
// `Transformer*` names keeps compiling against the Titans types.
pub type TransformerPredictor = TitansPredictor;
pub type TransformerConfig = TitansConfig;
801
/// Configuration for Titans predictor
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansConfig {
    /// Width of byte embeddings and of every hidden vector.
    pub embed_dim: usize,
    /// Number of attention heads per block.
    pub num_heads: usize,
    /// Number of `TitansBlock`s in the stack.
    pub num_layers: usize,
    /// Hidden width of each block's feed-forward network.
    pub ff_dim: usize,
    /// Capacity of the context window and length of the positional table.
    pub max_seq_len: usize,
    /// Size of persistent memory (number of memory tokens)
    pub memory_size: usize,
    /// Seed for the RNG that initializes all weights.
    pub seed: u64,
}
814
815impl Default for TitansConfig {
816    fn default() -> Self {
817        Self {
818            embed_dim: 64,
819            num_heads: 4,
820            num_layers: 2,
821            ff_dim: 128,
822            max_seq_len: 256,
823            memory_size: 64, // Persistent memory tokens
824            seed: 42,
825        }
826    }
827}
828
829// ============================================================================
830// QUANTUM-RESISTANT KEY EVOLUTION
831// ============================================================================
832
/// Parameters for NTRU-like lattice operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatticeParams {
    /// Polynomial degree (power of 2)
    pub n: usize,
    /// Large modulus
    pub q: u64,
    /// Small modulus for message space
    pub p: u64,
    /// Gaussian noise standard deviation
    pub sigma: f64,
}
841
842impl Default for LatticeParams {
843    fn default() -> Self {
844        Self {
845            n: 1024,    // NIST Level 3 security (~192-bit classical)
846            q: 12289,   // NTT-friendly prime for n=1024
847            p: 3,       // Ternary message space
848            sigma: 3.2, // Standard deviation per NIST recommendations
849        }
850    }
851}
852
/// Polynomial ring element Z_q[X]/(X^n + 1).
///
/// Holds RLWE secret-key coefficients. Memory is zeroed on `Drop` so a
/// dropped `RingElement` cannot leak its secret via a core dump or
/// swap-to-disk event (NIST SP 800-171 § 3.13.10).
///
/// NOTE(review): the derived `Debug` and `Serialize` impls expose the
/// coefficients verbatim, and `Clone` creates additional copies; confirm
/// that is acceptable wherever an element holds secret material.
#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct RingElement {
    /// Coefficients, lowest degree first; the constructors allocate `n` of them.
    coeffs: Vec<i64>,
    /// Ring degree: number of coefficients the ring operations iterate over.
    n: usize,
    /// Coefficient modulus.
    q: u64,
}
864
865impl RingElement {
866    pub fn new(n: usize, q: u64) -> Self {
867        Self {
868            coeffs: vec![0; n],
869            n,
870            q,
871        }
872    }
873
874    pub fn random(n: usize, q: u64, rng: &mut StdRng) -> Self {
875        let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(0..q as i64)).collect();
876        Self { coeffs, n, q }
877    }
878
879    pub fn random_ternary(n: usize, q: u64, rng: &mut StdRng) -> Self {
880        let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(-1..=1)).collect();
881        Self { coeffs, n, q }
882    }
883
884    pub fn random_gaussian(n: usize, q: u64, sigma: f64, rng: &mut StdRng) -> Self {
885        // Box-Muller transform for Gaussian
886        let coeffs: Vec<i64> = (0..n)
887            .map(|_| {
888                let u1: f64 = rng.gen::<f64>().max(1e-10);
889                let u2: f64 = rng.gen();
890                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
891                (z * sigma).round() as i64
892            })
893            .collect();
894        Self { coeffs, n, q }
895    }
896
897    pub fn from_bytes(bytes: &[u8], n: usize, q: u64) -> Self {
898        let mut coeffs = vec![0i64; n];
899        for (i, chunk) in bytes.chunks(2).enumerate() {
900            if i >= n {
901                break;
902            }
903            let val = if chunk.len() == 2 {
904                ((chunk[0] as u16) | ((chunk[1] as u16) << 8)) as i64
905            } else {
906                chunk[0] as i64
907            };
908            coeffs[i] = val % q as i64;
909        }
910        Self { coeffs, n, q }
911    }
912
913    pub fn to_bytes(&self) -> Vec<u8> {
914        let mut bytes = Vec::with_capacity(self.n * 2);
915        for &c in &self.coeffs {
916            let val = ((c % self.q as i64 + self.q as i64) % self.q as i64) as u16;
917            bytes.push(val as u8);
918            bytes.push((val >> 8) as u8);
919        }
920        bytes
921    }
922
923    fn reduce(&mut self) {
924        for c in &mut self.coeffs {
925            *c = ((*c % self.q as i64) + self.q as i64) % self.q as i64;
926        }
927    }
928
929    /// Polynomial multiplication in R_q = Z_q[X]/(X^n + 1)
930    pub fn mul(&self, other: &RingElement) -> RingElement {
931        assert_eq!(self.n, other.n);
932        let mut result = vec![0i64; self.n];
933
934        for i in 0..self.n {
935            for j in 0..self.n {
936                let idx = i + j;
937                let coeff = self.coeffs[i] * other.coeffs[j];
938                if idx < self.n {
939                    result[idx] += coeff;
940                } else {
941                    // X^n = -1 in the ring
942                    result[idx - self.n] -= coeff;
943                }
944            }
945        }
946
947        let mut elem = RingElement {
948            coeffs: result,
949            n: self.n,
950            q: self.q,
951        };
952        elem.reduce();
953        elem
954    }
955
956    /// Polynomial addition
957    pub fn add(&self, other: &RingElement) -> RingElement {
958        assert_eq!(self.n, other.n);
959        let coeffs: Vec<i64> = self
960            .coeffs
961            .iter()
962            .zip(other.coeffs.iter())
963            .map(|(a, b)| (a + b) % self.q as i64)
964            .collect();
965        let mut elem = RingElement {
966            coeffs,
967            n: self.n,
968            q: self.q,
969        };
970        elem.reduce();
971        elem
972    }
973
974    /// Polynomial subtraction
975    pub fn sub(&self, other: &RingElement) -> RingElement {
976        assert_eq!(self.n, other.n);
977        let coeffs: Vec<i64> = self
978            .coeffs
979            .iter()
980            .zip(other.coeffs.iter())
981            .map(|(a, b)| (a - b) % self.q as i64)
982            .collect();
983        let mut elem = RingElement {
984            coeffs,
985            n: self.n,
986            q: self.q,
987        };
988        elem.reduce();
989        elem
990    }
991
992    /// Scale coefficients
993    pub fn scale(&self, scalar: i64) -> RingElement {
994        let coeffs: Vec<i64> = self
995            .coeffs
996            .iter()
997            .map(|&c| (c * scalar) % self.q as i64)
998            .collect();
999        let mut elem = RingElement {
1000            coeffs,
1001            n: self.n,
1002            q: self.q,
1003        };
1004        elem.reduce();
1005        elem
1006    }
1007}
1008
1009/// Key Encapsulation Mechanism algorithm selection
1010#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
1011pub enum KemAlgorithm {
1012    /// Custom RLWE (existing implementation) — NIST Level 3 equivalent
1013    Rlwe,
1014    /// ML-KEM-512 (FIPS 203) — NIST Level 1
1015    MlKem512,
1016    /// ML-KEM-768 (FIPS 203) — NIST Level 3 (recommended)
1017    #[default]
1018    MlKem768,
1019    /// ML-KEM-1024 (FIPS 203) — NIST Level 5
1020    MlKem1024,
1021    /// Hybrid: RLWE + ML-KEM-768 (defense in depth)
1022    Hybrid,
1023}
1024
1025/// ML-KEM key encapsulation result.
1026///
1027/// `dk_bytes` is the FIPS 203 decapsulation key — the private half of
1028/// the KEM. Zeroed on `Drop`; the public `ek_bytes` is zeroed too
1029/// because it's free to do so and keeps the Drop impl uniform.
1030#[derive(Debug, Clone, Zeroize, ZeroizeOnDrop)]
1031struct MlKemKeyPair {
1032    dk_bytes: Vec<u8>,  // Decapsulation key (private)
1033    ek_bytes: Vec<u8>,  // Encapsulation key (public)
1034    #[zeroize(skip)]
1035    algorithm: KemAlgorithm,
1036}
1037
1038/// ML-KEM operations using FIPS 203
1039mod mlkem_ops {
1040    use super::*;
1041
1042    pub fn generate_512(rng: &mut StdRng) -> MlKemKeyPair {
1043        let (dk, ek) = MlKem512::generate(rng);
1044        MlKemKeyPair {
1045            dk_bytes: dk.as_bytes().to_vec(),
1046            ek_bytes: ek.as_bytes().to_vec(),
1047            algorithm: KemAlgorithm::MlKem512,
1048        }
1049    }
1050
1051    pub fn generate_768(rng: &mut StdRng) -> MlKemKeyPair {
1052        let (dk, ek) = MlKem768::generate(rng);
1053        MlKemKeyPair {
1054            dk_bytes: dk.as_bytes().to_vec(),
1055            ek_bytes: ek.as_bytes().to_vec(),
1056            algorithm: KemAlgorithm::MlKem768,
1057        }
1058    }
1059
1060    pub fn generate_1024(rng: &mut StdRng) -> MlKemKeyPair {
1061        let (dk, ek) = MlKem1024::generate(rng);
1062        MlKemKeyPair {
1063            dk_bytes: dk.as_bytes().to_vec(),
1064            ek_bytes: ek.as_bytes().to_vec(),
1065            algorithm: KemAlgorithm::MlKem1024,
1066        }
1067    }
1068
1069    pub fn encapsulate_512(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1070        let ek_encoded = <Encoded<EncapsulationKey<MlKem512Params>>>::try_from(ek_bytes).ok()?;
1071        let ek = EncapsulationKey::<MlKem512Params>::from_bytes(&ek_encoded);
1072        let (ct, ss) = ek.encapsulate(rng).ok()?;
1073        let mut shared = [0u8; 32];
1074        shared.copy_from_slice(ss.as_slice());
1075        Some((ct.to_vec(), shared))
1076    }
1077
1078    pub fn encapsulate_768(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1079        let ek_encoded = <Encoded<EncapsulationKey<MlKem768Params>>>::try_from(ek_bytes).ok()?;
1080        let ek = EncapsulationKey::<MlKem768Params>::from_bytes(&ek_encoded);
1081        let (ct, ss) = ek.encapsulate(rng).ok()?;
1082        let mut shared = [0u8; 32];
1083        shared.copy_from_slice(ss.as_slice());
1084        Some((ct.to_vec(), shared))
1085    }
1086
1087    pub fn encapsulate_1024(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1088        let ek_encoded = <Encoded<EncapsulationKey<MlKem1024Params>>>::try_from(ek_bytes).ok()?;
1089        let ek = EncapsulationKey::<MlKem1024Params>::from_bytes(&ek_encoded);
1090        let (ct, ss) = ek.encapsulate(rng).ok()?;
1091        let mut shared = [0u8; 32];
1092        shared.copy_from_slice(ss.as_slice());
1093        Some((ct.to_vec(), shared))
1094    }
1095
1096    pub fn decapsulate_512(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1097        let dk_encoded = <Encoded<DecapsulationKey<MlKem512Params>>>::try_from(dk_bytes).ok()?;
1098        let dk = DecapsulationKey::<MlKem512Params>::from_bytes(&dk_encoded);
1099        let ct = <ml_kem::Ciphertext<MlKem512>>::try_from(ct_bytes).ok()?;
1100        let ss = dk.decapsulate(&ct).ok()?;
1101        let mut shared = [0u8; 32];
1102        shared.copy_from_slice(ss.as_slice());
1103        Some(shared)
1104    }
1105
1106    pub fn decapsulate_768(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1107        let dk_encoded = <Encoded<DecapsulationKey<MlKem768Params>>>::try_from(dk_bytes).ok()?;
1108        let dk = DecapsulationKey::<MlKem768Params>::from_bytes(&dk_encoded);
1109        let ct = <ml_kem::Ciphertext<MlKem768>>::try_from(ct_bytes).ok()?;
1110        let ss = dk.decapsulate(&ct).ok()?;
1111        let mut shared = [0u8; 32];
1112        shared.copy_from_slice(ss.as_slice());
1113        Some(shared)
1114    }
1115
1116    pub fn decapsulate_1024(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1117        let dk_encoded = <Encoded<DecapsulationKey<MlKem1024Params>>>::try_from(dk_bytes).ok()?;
1118        let dk = DecapsulationKey::<MlKem1024Params>::from_bytes(&dk_encoded);
1119        let ct = <ml_kem::Ciphertext<MlKem1024>>::try_from(ct_bytes).ok()?;
1120        let ss = dk.decapsulate(&ct).ok()?;
1121        let mut shared = [0u8; 32];
1122        shared.copy_from_slice(ss.as_slice());
1123        Some(shared)
1124    }
1125}
1126
1127/// Quantum-resistant key pair.
1128///
1129/// `secret_key` is the RLWE secret coefficient vector. All three
1130/// `RingElement` fields zeroize on drop via their own derived
1131/// `ZeroizeOnDrop`. `params` is plaintext metadata (n, q, sigma) so it
1132/// is intentionally skipped.
1133#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
1134pub struct QuantumKeyPair {
1135    /// Public parameter `a` — must be stored for correct KEM encaps/decaps
1136    pub a: RingElement,
1137    pub public_key: RingElement,
1138    secret_key: RingElement,
1139    #[zeroize(skip)]
1140    params: LatticeParams,
1141}
1142
1143/// Quantum-resistant key evolution system.
1144///
1145/// `Drop` is implemented manually below because `VecDeque<[u8; 32]>`
1146/// and `StdRng` don't impl `Zeroize` in the derive form. The wrapped
1147/// `QuantumKeyPair` and `MlKemKeyPair` zero themselves on their own
1148/// drops; we only need to scrub the history buffer here.
1149#[derive(Debug, Clone, Serialize, Deserialize)]
1150pub struct QuantumKeyEvolution {
1151    params: LatticeParams,
1152    current_key: QuantumKeyPair,
1153    evolution_counter: u64,
1154    key_history: VecDeque<[u8; 32]>, // Hashes of past keys for forward secrecy
1155    max_history: usize,
1156    #[serde(skip, default = "default_rng")]
1157    rng: StdRng,
1158    /// KEM algorithm in use
1159    algorithm: KemAlgorithm,
1160    /// ML-KEM keypair (when using FIPS 203 algorithms)
1161    #[serde(skip)]
1162    mlkem_keypair: Option<MlKemKeyPair>,
1163}
1164
1165impl Drop for QuantumKeyEvolution {
1166    fn drop(&mut self) {
1167        // The wrapped key structs zero themselves on their own drops.
1168        // We only need to scrub the rolling history buffer here — each
1169        // entry is a SHA-256 over a past secret and is treated as
1170        // sensitive even though it is not the secret itself.
1171        for h in self.key_history.iter_mut() {
1172            h.zeroize();
1173        }
1174        self.key_history.clear();
1175    }
1176}
1177
1178impl QuantumKeyEvolution {
1179    pub fn new(params: LatticeParams, seed: u64) -> Self {
1180        let mut rng = StdRng::seed_from_u64(seed);
1181        let current_key = Self::generate_keypair(&params, &mut rng);
1182
1183        Self {
1184            params,
1185            current_key,
1186            evolution_counter: 0,
1187            key_history: VecDeque::new(),
1188            max_history: 100,
1189            rng,
1190            algorithm: KemAlgorithm::Rlwe,
1191            mlkem_keypair: None,
1192        }
1193    }
1194
1195    /// Create with a specific KEM algorithm
1196    pub fn new_with_algorithm(params: LatticeParams, seed: u64, algorithm: KemAlgorithm) -> Self {
1197        let mut rng = StdRng::seed_from_u64(seed);
1198        let current_key = Self::generate_keypair(&params, &mut rng);
1199        let mlkem_keypair = match algorithm {
1200            KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut rng)),
1201            KemAlgorithm::MlKem768 => Some(mlkem_ops::generate_768(&mut rng)),
1202            KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut rng)),
1203            KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut rng)),
1204            KemAlgorithm::Rlwe => None,
1205        };
1206
1207        Self {
1208            params,
1209            current_key,
1210            evolution_counter: 0,
1211            key_history: VecDeque::new(),
1212            max_history: 100,
1213            rng,
1214            algorithm,
1215            mlkem_keypair,
1216        }
1217    }
1218
1219    fn generate_keypair(params: &LatticeParams, rng: &mut StdRng) -> QuantumKeyPair {
1220        // RLWE-style key generation
1221        let a = RingElement::random(params.n, params.q, rng);
1222        let s = RingElement::random_ternary(params.n, params.q, rng);
1223        let e = RingElement::random_gaussian(params.n, params.q, params.sigma, rng);
1224
1225        // Public key: b = a*s + e
1226        let b = a.mul(&s).add(&e);
1227
1228        QuantumKeyPair {
1229            a, // Store `a` for correct KEM operation
1230            public_key: b,
1231            secret_key: s,
1232            params: params.clone(),
1233        }
1234    }
1235
1236    /// Evolve the key forward (one-way function)
1237    ///
1238    /// Uses HKDF to derive a new seed from the current key material,
1239    /// then generates a fresh RLWE keypair that maintains the b=a*s+e invariant.
1240    pub fn evolve(&mut self) -> [u8; 32] {
1241        // Hash current key material (public key + secret key + counter)
1242        let mut hasher = Sha256::new();
1243        hasher.update(self.current_key.public_key.to_bytes());
1244        hasher.update(self.current_key.secret_key.to_bytes());
1245        hasher.update(self.evolution_counter.to_le_bytes());
1246        let hash: [u8; 32] = hasher.finalize().into();
1247
1248        // Store in history
1249        self.key_history.push_back(hash);
1250        if self.key_history.len() > self.max_history {
1251            self.key_history.pop_front();
1252        }
1253
1254        // Use HKDF to derive new seed (mixes entropy from current key + counter)
1255        let hk = Hkdf::<Sha256>::new(Some(&self.evolution_counter.to_le_bytes()), &hash);
1256        let mut okm = [0u8; 32];
1257        hk.expand(b"spine-key-evolution", &mut okm)
1258            .expect("HKDF expand failed");
1259        let new_seed = u64::from_le_bytes(okm[0..8].try_into().unwrap());
1260        let mut new_rng = StdRng::seed_from_u64(new_seed);
1261
1262        // Generate a proper RLWE keypair that maintains the b=a*s+e invariant
1263        self.current_key = Self::generate_keypair(&self.params, &mut new_rng);
1264
1265        // Also evolve ML-KEM keypair if in use
1266        if self.algorithm != KemAlgorithm::Rlwe {
1267            self.mlkem_keypair = match self.algorithm {
1268                KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut new_rng)),
1269                KemAlgorithm::MlKem768 | KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut new_rng)),
1270                KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut new_rng)),
1271                KemAlgorithm::Rlwe => None,
1272            };
1273        }
1274
1275        self.evolution_counter += 1;
1276
1277        hash
1278    }
1279
1280    /// Encapsulate a shared secret using the current KEM algorithm
1281    pub fn encapsulate(&mut self) -> (Vec<u8>, [u8; 32]) {
1282        match self.algorithm {
1283            KemAlgorithm::Rlwe => self.encapsulate_rlwe(),
1284            KemAlgorithm::MlKem512 => self.encapsulate_mlkem(KemAlgorithm::MlKem512),
1285            KemAlgorithm::MlKem768 => self.encapsulate_mlkem(KemAlgorithm::MlKem768),
1286            KemAlgorithm::MlKem1024 => self.encapsulate_mlkem(KemAlgorithm::MlKem1024),
1287            KemAlgorithm::Hybrid => self.encapsulate_hybrid(),
1288        }
1289    }
1290
1291    /// Encapsulate using ML-KEM (FIPS 203)
1292    fn encapsulate_mlkem(&mut self, alg: KemAlgorithm) -> (Vec<u8>, [u8; 32]) {
1293        let kp = self.mlkem_keypair.as_ref().expect("ML-KEM keypair required");
1294        let result = match alg {
1295            KemAlgorithm::MlKem512 => mlkem_ops::encapsulate_512(&kp.ek_bytes, &mut self.rng),
1296            KemAlgorithm::MlKem768 => mlkem_ops::encapsulate_768(&kp.ek_bytes, &mut self.rng),
1297            KemAlgorithm::MlKem1024 => mlkem_ops::encapsulate_1024(&kp.ek_bytes, &mut self.rng),
1298            _ => unreachable!(),
1299        };
1300        result.unwrap_or_else(|| {
1301            // Fallback to RLWE if ML-KEM fails
1302            self.encapsulate_rlwe()
1303        })
1304    }
1305
1306    /// Encapsulate using hybrid RLWE + ML-KEM-768 (defense in depth)
1307    fn encapsulate_hybrid(&mut self) -> (Vec<u8>, [u8; 32]) {
1308        // Get shared secrets from both algorithms
1309        let (rlwe_ct, rlwe_ss) = self.encapsulate_rlwe();
1310        let (mlkem_ct, mlkem_ss) = self.encapsulate_mlkem(KemAlgorithm::MlKem768);
1311
1312        // Combine shared secrets via HKDF
1313        let mut combined_ikm = [0u8; 64];
1314        combined_ikm[..32].copy_from_slice(&rlwe_ss);
1315        combined_ikm[32..].copy_from_slice(&mlkem_ss);
1316        let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
1317        let mut hybrid_ss = [0u8; 32];
1318        hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
1319
1320        // Concatenate ciphertexts with length prefix for RLWE part
1321        let rlwe_len = (rlwe_ct.len() as u32).to_le_bytes();
1322        let mut hybrid_ct = Vec::with_capacity(4 + rlwe_ct.len() + mlkem_ct.len());
1323        hybrid_ct.extend_from_slice(&rlwe_len);
1324        hybrid_ct.extend_from_slice(&rlwe_ct);
1325        hybrid_ct.extend_from_slice(&mlkem_ct);
1326
1327        (hybrid_ct, hybrid_ss)
1328    }
1329
1330    /// Encapsulate a shared secret using the public key (RLWE KEM)
1331    ///
1332    /// Uses the stored `a` from keygen to ensure sender and receiver
1333    /// derive the same shared secret. Encodes a random message `m` into
1334    /// the high bits and recovers it via rounding on decapsulation.
1335    fn encapsulate_rlwe(&mut self) -> (Vec<u8>, [u8; 32]) {
1336        // Use the SAME `a` from keygen — critical for correctness
1337        let a = &self.current_key.a;
1338        let r = RingElement::random_ternary(self.params.n, self.params.q, &mut self.rng);
1339        let e1 = RingElement::random_gaussian(
1340            self.params.n,
1341            self.params.q,
1342            self.params.sigma,
1343            &mut self.rng,
1344        );
1345        let e2 = RingElement::random_gaussian(
1346            self.params.n,
1347            self.params.q,
1348            self.params.sigma,
1349            &mut self.rng,
1350        );
1351
1352        // Generate random message m ∈ {0, 1}^n for KEM
1353        let m: Vec<i64> = (0..self.params.n)
1354            .map(|_| self.rng.gen_range(0..2i64))
1355            .collect();
1356
1357        // u = a*r + e1
1358        let u = a.mul(&r).add(&e1);
1359
1360        // v = b*r + e2 + ⌊q/2⌋·m (encode message in high bits)
1361        let half_q = (self.params.q / 2) as i64;
1362        let encoded_m = RingElement {
1363            coeffs: m.iter().map(|&mi| mi * half_q).collect(),
1364            n: self.params.n,
1365            q: self.params.q,
1366        };
1367        let v = self.current_key.public_key.mul(&r).add(&e2).add(&encoded_m);
1368
1369        // Ciphertext = (u, v)
1370        let mut ciphertext = u.to_bytes();
1371        ciphertext.extend(v.to_bytes());
1372
1373        // Shared secret = H(m) — both sides derive from the same m
1374        let mut hasher = Sha256::new();
1375        for &mi in &m {
1376            hasher.update(mi.to_le_bytes());
1377        }
1378        let shared_secret: [u8; 32] = hasher.finalize().into();
1379
1380        (ciphertext, shared_secret)
1381    }
1382
1383    /// Decapsulate to recover shared secret using the current KEM algorithm
1384    pub fn decapsulate(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1385        match self.algorithm {
1386            KemAlgorithm::Rlwe => self.decapsulate_rlwe(ciphertext),
1387            KemAlgorithm::MlKem512 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem512),
1388            KemAlgorithm::MlKem768 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem768),
1389            KemAlgorithm::MlKem1024 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem1024),
1390            KemAlgorithm::Hybrid => self.decapsulate_hybrid(ciphertext),
1391        }
1392    }
1393
1394    /// Decapsulate using ML-KEM (FIPS 203)
1395    fn decapsulate_mlkem(&self, ciphertext: &[u8], alg: KemAlgorithm) -> Option<[u8; 32]> {
1396        let kp = self.mlkem_keypair.as_ref()?;
1397        match alg {
1398            KemAlgorithm::MlKem512 => mlkem_ops::decapsulate_512(&kp.dk_bytes, ciphertext),
1399            KemAlgorithm::MlKem768 => mlkem_ops::decapsulate_768(&kp.dk_bytes, ciphertext),
1400            KemAlgorithm::MlKem1024 => mlkem_ops::decapsulate_1024(&kp.dk_bytes, ciphertext),
1401            _ => None,
1402        }
1403    }
1404
1405    /// Decapsulate using hybrid RLWE + ML-KEM-768
1406    fn decapsulate_hybrid(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1407        if ciphertext.len() < 4 { return None; }
1408        let rlwe_len = u32::from_le_bytes(ciphertext[..4].try_into().ok()?) as usize;
1409        if ciphertext.len() < 4 + rlwe_len { return None; }
1410
1411        let rlwe_ct = &ciphertext[4..4+rlwe_len];
1412        let mlkem_ct = &ciphertext[4+rlwe_len..];
1413
1414        let rlwe_ss = self.decapsulate_rlwe(rlwe_ct)?;
1415        let mlkem_ss = self.decapsulate_mlkem(mlkem_ct, KemAlgorithm::MlKem768)?;
1416
1417        let mut combined_ikm = [0u8; 64];
1418        combined_ikm[..32].copy_from_slice(&rlwe_ss);
1419        combined_ikm[32..].copy_from_slice(&mlkem_ss);
1420        let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
1421        let mut hybrid_ss = [0u8; 32];
1422        hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
1423
1424        Some(hybrid_ss)
1425    }
1426
1427    /// Decapsulate to recover shared secret (RLWE KEM)
1428    ///
1429    /// Recovers the encoded message by computing v - u·s, then rounding
1430    /// each coefficient to 0 or 1 to recover the original message m.
1431    /// The shared secret is H(m), matching the encapsulator.
1432    fn decapsulate_rlwe(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1433        let half = ciphertext.len() / 2;
1434        if half < self.params.n * 2 {
1435            return None;
1436        }
1437
1438        let u = RingElement::from_bytes(&ciphertext[..half], self.params.n, self.params.q);
1439        let v = RingElement::from_bytes(&ciphertext[half..], self.params.n, self.params.q);
1440
1441        // recovered = v - u·s ≈ ⌊q/2⌋·m + (small noise)
1442        let recovered = v.sub(&u.mul(&self.current_key.secret_key));
1443
1444        // Round each coefficient: if closer to ⌊q/2⌋ → 1, else → 0
1445        let half_q = self.params.q as i64 / 2;
1446        let quarter_q = self.params.q as i64 / 4;
1447        let m: Vec<i64> = recovered
1448            .coeffs
1449            .iter()
1450            .map(|&c| {
1451                // Normalize to [0, q)
1452                let c_pos =
1453                    ((c % self.params.q as i64) + self.params.q as i64) % self.params.q as i64;
1454                // If |c - q/2| < q/4, round to 1; otherwise 0
1455                if (c_pos - half_q).abs() < quarter_q {
1456                    1i64
1457                } else {
1458                    0i64
1459                }
1460            })
1461            .collect();
1462
1463        // Shared secret = H(m) — matches encapsulate
1464        let mut hasher = Sha256::new();
1465        for &mi in &m {
1466            hasher.update(mi.to_le_bytes());
1467        }
1468        Some(hasher.finalize().into())
1469    }
1470
1471    /// Get current key hash for synchronization
1472    pub fn get_key_hash(&self) -> [u8; 32] {
1473        let mut hasher = Sha256::new();
1474        hasher.update(self.current_key.public_key.to_bytes());
1475        hasher.finalize().into()
1476    }
1477
1478    /// Verify key chain integrity (constant-time comparison)
1479    pub fn verify_evolution(&self, expected_hash: &[u8; 32]) -> bool {
1480        self.key_history
1481            .iter()
1482            .any(|h| h.ct_eq(expected_hash).into())
1483    }
1484
1485    /// Get evolution counter for synchronization
1486    pub fn get_evolution_counter(&self) -> u64 {
1487        self.evolution_counter
1488    }
1489
1490    /// Export public key for key exchange
1491    pub fn export_public_key(&self) -> Vec<u8> {
1492        self.current_key.public_key.to_bytes()
1493    }
1494}
1495
1496/// Combined quantum-resistant speculative protocol
1497#[derive(Debug, Clone, Serialize, Deserialize)]
1498pub struct QuantumSpeculativeProtocol {
1499    predictor: TransformerPredictor,
1500    key_evolution: QuantumKeyEvolution,
1501    prediction_threshold: f32,
1502    evolution_interval: u64,
1503    message_count: u64,
1504}
1505
1506impl QuantumSpeculativeProtocol {
1507    pub fn new(
1508        transformer_config: TransformerConfig,
1509        lattice_params: LatticeParams,
1510        seed: u64,
1511    ) -> Self {
1512        Self {
1513            predictor: TransformerPredictor::new(transformer_config),
1514            key_evolution: QuantumKeyEvolution::new(lattice_params, seed),
1515            prediction_threshold: 0.8,
1516            evolution_interval: 10,
1517            message_count: 0,
1518        }
1519    }
1520
1521    /// Create with a specific KEM algorithm
1522    pub fn new_with_algorithm(
1523        transformer_config: TransformerConfig,
1524        lattice_params: LatticeParams,
1525        seed: u64,
1526        algorithm: KemAlgorithm,
1527    ) -> Self {
1528        Self {
1529            predictor: TransformerPredictor::new(transformer_config),
1530            key_evolution: QuantumKeyEvolution::new_with_algorithm(lattice_params, seed, algorithm),
1531            prediction_threshold: 0.8,
1532            evolution_interval: 10,
1533            message_count: 0,
1534        }
1535    }
1536
1537    /// Get the KEM algorithm in use
1538    pub fn algorithm(&self) -> KemAlgorithm {
1539        self.key_evolution.algorithm
1540    }
1541
1542    /// Process an outgoing message with prediction and encryption
1543    pub fn send(&mut self, message: &[u8]) -> QuantumMessage {
1544        // Check if receiver could predict this
1545        let (matches, similarity) = self.predictor.verify_prediction(message);
1546
1547        let payload = if matches && similarity >= self.prediction_threshold {
1548            // Send confirmation only
1549            MessagePayload::Confirmation {
1550                hash: Self::hash_message(message),
1551                length: message.len(),
1552            }
1553        } else {
1554            // Encapsulate shared secret via RLWE KEM
1555            let (ciphertext, shared_secret) = self.key_evolution.encapsulate();
1556
1557            // Derive AES-256-GCM key from KEM shared secret via HKDF
1558            let hk = Hkdf::<Sha256>::new(None, &shared_secret);
1559            let mut aes_key = [0u8; 32];
1560            hk.expand(b"spine-aead-key", &mut aes_key)
1561                .expect("HKDF expand failed");
1562
1563            // Create nonce from message count (unique per message)
1564            let mut nonce_bytes = [0u8; 12];
1565            nonce_bytes[..8].copy_from_slice(&self.message_count.to_le_bytes());
1566            let nonce = Nonce::from_slice(&nonce_bytes);
1567
1568            // Encrypt with AES-256-GCM (authenticated encryption)
1569            let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
1570            let encrypted = cipher.encrypt(nonce, message).expect("AES-GCM encrypt");
1571
1572            // Prepend 12-byte nonce to ciphertext for receiver
1573            let mut encrypted_message = nonce_bytes.to_vec();
1574            encrypted_message.extend(encrypted);
1575
1576            MessagePayload::Full {
1577                ciphertext,
1578                encrypted_message,
1579            }
1580        };
1581
1582        // Evolve key periodically
1583        self.message_count += 1;
1584        let key_evolution = if self.message_count.is_multiple_of(self.evolution_interval) {
1585            Some(self.key_evolution.evolve())
1586        } else {
1587            None
1588        };
1589
1590        QuantumMessage {
1591            payload,
1592            evolution_counter: self.key_evolution.get_evolution_counter(),
1593            key_evolution,
1594        }
1595    }
1596
1597    /// Get a seed for protocol morphing based on current quantum key state
1598    pub fn get_morph_seed(&self) -> u64 {
1599        let key_hash = self.key_evolution.get_key_hash();
1600        u64::from_le_bytes(key_hash[0..8].try_into().unwrap())
1601    }
1602
1603    /// Process an incoming message
1604    pub fn receive(&mut self, quantum_msg: &QuantumMessage) -> Option<Vec<u8>> {
1605        // Sync key evolution if needed
1606        while self.key_evolution.get_evolution_counter() < quantum_msg.evolution_counter {
1607            self.key_evolution.evolve();
1608        }
1609
1610        let message = match &quantum_msg.payload {
1611            MessagePayload::Confirmation { hash, length } => {
1612                // Use prediction
1613                let predicted = self.predictor.predict_sequence(*length, true);
1614
1615                // Verify hash
1616                let predicted_hash = Self::hash_message(&predicted);
1617                if &predicted_hash == hash {
1618                    Some(predicted)
1619                } else {
1620                    None // Prediction mismatch, need retransmission
1621                }
1622            }
1623            MessagePayload::Full {
1624                ciphertext,
1625                encrypted_message,
1626            } => {
1627                // Decapsulate KEM ciphertext to recover shared secret
1628                let shared_secret = self.key_evolution.decapsulate(ciphertext)?;
1629
1630                // Derive AES-256-GCM key from KEM shared secret via HKDF
1631                let hk = Hkdf::<Sha256>::new(None, &shared_secret);
1632                let mut aes_key = [0u8; 32];
1633                hk.expand(b"spine-aead-key", &mut aes_key)
1634                    .expect("HKDF expand failed");
1635
1636                // Extract nonce (first 12 bytes) and ciphertext
1637                if encrypted_message.len() < 12 {
1638                    return None;
1639                }
1640                let nonce = Nonce::from_slice(&encrypted_message[..12]);
1641                let ciphertext_data = &encrypted_message[12..];
1642
1643                // Decrypt with AES-256-GCM (authenticated — rejects tampered data)
1644                let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
1645                cipher.decrypt(nonce, ciphertext_data).ok()
1646            }
1647        };
1648
1649        // Update predictor
1650        if let Some(ref msg) = message {
1651            self.predictor.observe(msg);
1652        }
1653
1654        message
1655    }
1656
1657    fn hash_message(message: &[u8]) -> [u8; 32] {
1658        let mut hasher = Sha256::new();
1659        hasher.update(message);
1660        hasher.finalize().into()
1661    }
1662
1663    /// Set prediction confidence threshold
1664    pub fn set_threshold(&mut self, threshold: f32) {
1665        self.prediction_threshold = threshold.clamp(0.0, 1.0);
1666    }
1667
1668    /// Set key evolution interval
1669    pub fn set_evolution_interval(&mut self, interval: u64) {
1670        self.evolution_interval = interval.max(1);
1671    }
1672
1673    /// Reset protocol state
1674    pub fn reset(&mut self) {
1675        self.predictor.reset();
1676        self.message_count = 0;
1677    }
1678}
1679
1680/// Wire format for quantum-protected messages
1681#[derive(Debug, Clone, Serialize, Deserialize)]
1682pub struct QuantumMessage {
1683    pub payload: MessagePayload,
1684    pub evolution_counter: u64,
1685    pub key_evolution: Option<[u8; 32]>,
1686}
1687
1688#[derive(Debug, Clone, Serialize, Deserialize)]
1689pub enum MessagePayload {
1690    /// Prediction matched - send only confirmation
1691    Confirmation { hash: [u8; 32], length: usize },
1692    /// Full encrypted message
1693    Full {
1694        ciphertext: Vec<u8>,
1695        encrypted_message: Vec<u8>,
1696    },
1697}
1698
1699#[cfg(test)]
1700mod tests {
1701    use super::*;
1702
1703    // -----------------------------------------------------------------
1704    // Zeroization regression guards (NIST SP 800-171 § 3.13.10).
1705    // -----------------------------------------------------------------
1706
1707    /// Compile-time proof that the type implements `ZeroizeOnDrop`.
1708    /// If the derive ever falls off the struct definition, this stops
1709    /// compiling — a louder failure than a silently-leaking secret.
1710    fn assert_zeroize_on_drop<T: ZeroizeOnDrop>() {}
1711
/// Fails to compile if `RingElement` loses its `ZeroizeOnDrop` impl.
#[test]
fn ringelement_implements_zeroize_on_drop() {
    assert_zeroize_on_drop::<RingElement>()
}
1716
/// Fails to compile if `MlKemKeyPair` loses its `ZeroizeOnDrop` impl.
#[test]
fn mlkemkeypair_implements_zeroize_on_drop() {
    assert_zeroize_on_drop::<MlKemKeyPair>()
}
1721
/// Fails to compile if `QuantumKeyPair` loses its `ZeroizeOnDrop` impl.
#[test]
fn quantumkeypair_implements_zeroize_on_drop() {
    assert_zeroize_on_drop::<QuantumKeyPair>()
}
1726
/// Runtime check that `RingElement::zeroize` scrubs `coeffs`.
///
/// A seeded random element is first confirmed to hold at least one
/// non-zero coefficient, then zeroized explicitly; afterwards no non-zero
/// coefficient may remain. A failure means `Zeroize` no longer covers the
/// `coeffs` field.
#[test]
fn ringelement_zeroize_clears_all_coefficients() {
    let mut rng = StdRng::seed_from_u64(0xAB_CD);
    let mut element = RingElement::random(64, 8_192, &mut rng);

    let populated = element.coeffs.iter().any(|&c| c != 0);
    assert!(
        populated,
        "test precondition: random RingElement should have non-zero coeffs"
    );

    element.zeroize();

    let residue = element.coeffs.iter().any(|&c| c != 0);
    assert!(
        !residue,
        "RingElement::zeroize did not clear every coefficient"
    );
}
1746
/// Runtime check that `MlKemKeyPair::zeroize` leaves no non-zero byte
/// visible in `dk_bytes`.
///
/// NOTE(review): if `zeroize` truncates the buffer to length 0 (as it does
/// for `Vec`), the post-condition iterates nothing and passes vacuously —
/// confirm this still guards what it is meant to guard.
#[test]
fn mlkemkeypair_zeroize_clears_dk_bytes() {
    let mut rng = StdRng::seed_from_u64(0x12_34);
    let mut keypair = mlkem_ops::generate_768(&mut rng);

    let populated = keypair.dk_bytes.iter().any(|&b| b != 0);
    assert!(
        populated,
        "test precondition: fresh ML-KEM dk should be non-zero"
    );

    keypair.zeroize();

    // The contract under test is "no surviving secret": either the buffer
    // is now empty or every remaining byte is zero.
    let residue = keypair.dk_bytes.iter().any(|&b| b != 0);
    assert!(
        !residue,
        "MlKemKeyPair::zeroize left non-zero bytes in dk_bytes"
    );
}
1765
/// Checks that every `key_history` entry can be scrubbed in place.
///
/// NOTE(review): despite the name, this test never observes the `Drop`
/// impl. Memory cannot be inspected safely after a drop, so the entries
/// are zeroized by hand here and the result is checked. `ev` is dropped
/// normally when the test returns.
#[test]
fn quantumkeyevolution_drop_clears_key_history() {
    let mut ev = QuantumKeyEvolution::new(LatticeParams::default(), 0xCAFE);

    // Seed two recognisable, non-zero history entries.
    for fill in [0x11u8, 0x22u8] {
        ev.key_history.push_back([fill; 32]);
    }
    assert_eq!(ev.key_history.len(), 2);

    // Scrub each entry manually.
    for entry in ev.key_history.iter_mut() {
        entry.zeroize();
    }

    let residue = ev
        .key_history
        .iter()
        .any(|entry| entry.iter().any(|&b| b != 0));
    assert!(!residue, "QuantumKeyEvolution key_history not zeroed");
}
1784
/// Encodings have the configured width and differ between positions.
#[test]
fn test_positional_encoding() {
    let table = PositionalEncoding::new(100, 64);

    let first = table.get(0);
    assert_eq!(first.len(), 64);

    let middle = table.get(50);
    assert_ne!(first, middle);
}
1793
/// Layer normalisation keeps the input length and centres the output.
#[test]
fn test_layer_norm() {
    let norm = LayerNorm::new(8);
    let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

    let normalized = norm.forward(&input);
    assert_eq!(normalized.len(), 8);

    // The normalised values should average out to roughly zero.
    let total: f32 = normalized.iter().sum::<f32>();
    let mean: f32 = total / normalized.len() as f32;
    assert!(mean.abs() < 0.01);
}
1805
/// Smoke test: observe two chunks, then check the prediction confidence
/// lies in `(0, 1]` and the surprise value is non-negative.
#[test]
fn test_titans_predictor() {
    let mut predictor = TitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Feed some data.
    predictor.observe(b"Hello ");
    predictor.observe(b"World");

    // The predicted byte is a `u8`, so only the confidence needs checking.
    let (_next, conf) = predictor.predict_next();
    assert!(conf > 0.0 && conf <= 1.0);

    // Surprise is tracked.
    assert!(predictor.get_surprise() >= 0.0);
}
1833
/// Feeds a repeated pattern followed by an out-of-pattern payload and
/// checks that a non-negative surprise value is reported afterwards.
///
/// NOTE(review): the baseline surprise is read but never compared with the
/// anomaly surprise, so this does not assert that anomalies score higher.
#[test]
fn test_titans_anomaly_detection() {
    let mut predictor = TitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Establish the normal traffic pattern.
    (0..10).for_each(|_| predictor.observe(b"GET /api/status\n"));
    let _baseline = predictor.get_surprise();

    // Now inject something that does not fit the pattern.
    predictor.observe(b"MALICIOUS_PAYLOAD_XYZ!!!");
    let after_anomaly = predictor.get_surprise();

    assert!(after_anomaly >= 0.0);
}
1860
/// Sum and product of two random ring elements keep length `n`, and every
/// coefficient of the sum lies in `[0, q)`.
#[test]
fn test_ring_operations() {
    let params = LatticeParams {
        n: 16,
        q: 97,
        p: 3,
        sigma: 2.0,
    };
    let mut rng = StdRng::seed_from_u64(42);

    let lhs = RingElement::random(params.n, params.q, &mut rng);
    let rhs = RingElement::random(params.n, params.q, &mut rng);

    let sum = lhs.add(&rhs);
    let product = lhs.mul(&rhs);

    assert_eq!(sum.coeffs.len(), params.n);
    assert_eq!(product.coeffs.len(), params.n);

    // Every coefficient of the sum must be reduced into [0, q).
    let modulus = params.q as i64;
    for &coeff in sum.coeffs.iter() {
        assert!((0..modulus).contains(&coeff));
    }
}
1885
/// One evolution step changes the key hash and advances the counter to 1.
#[test]
fn test_key_evolution() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 32,
            q: 257,
            p: 3,
            sigma: 2.0,
        },
        42,
    );

    let before = evolution.get_key_hash();
    evolution.evolve();
    let after = evolution.get_key_hash();

    // The key must have changed.
    assert_ne!(before, after);

    // The step must have been counted.
    assert_eq!(evolution.get_evolution_counter(), 1);
}
1906
/// Encapsulation yields a non-empty ciphertext that decapsulation accepts.
///
/// This test deliberately does not compare the two shared secrets; it only
/// checks that decapsulation returns `Some`.
#[test]
fn test_encapsulation() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 32,
            q: 257,
            p: 3,
            sigma: 2.0,
        },
        42,
    );

    let (ciphertext, _sender_secret) = evolution.encapsulate();
    assert!(!ciphertext.is_empty());

    let receiver_secret = evolution.decapsulate(&ciphertext);
    assert!(receiver_secret.is_some());
}
1926
/// A message sent by one endpoint is recovered verbatim by a second
/// endpoint built from the same configuration, parameters and seed.
#[test]
fn test_quantum_speculative_protocol() {
    let params = LatticeParams {
        n: 16,
        q: 97,
        p: 3,
        sigma: 2.0,
    };
    let config = TitansConfig {
        embed_dim: 16,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 32,
        max_seq_len: 32,
        memory_size: 8,
        seed: 42,
    };

    let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
    let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);

    // Alice -> Bob
    let plaintext = b"Hello Bob!";
    let wire = alice.send(plaintext);

    let delivered = bob.receive(&wire);
    assert!(delivered.is_some());
    assert_eq!(delivered.unwrap(), plaintext.to_vec());
}
1957
/// After repeated request/response traffic the protocol must still deliver
/// a further message, whether or not prediction has kicked in by then.
#[test]
fn test_prediction_efficiency() {
    let config = TransformerConfig::default();
    let params = LatticeParams::default();

    let mut sender = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
    let mut receiver = QuantumSpeculativeProtocol::new(config, params, 42);

    // Repeat the same request/response pair to train the predictor.
    for _ in 0..5 {
        let request = sender.send(b"GET /api/status");
        let _ = receiver.receive(&request);

        let response = sender.send(b"200 OK");
        let _ = receiver.receive(&response);
    }

    // Delivery must succeed regardless of which payload variant was sent.
    let probe = sender.send(b"GET /api/status");
    assert!(receiver.receive(&probe).is_some());
}
1982
1983    // =========================================================================
1984    // MIRAS-ADAPTIVE PREDICTOR TESTS
1985    // =========================================================================
1986
/// Basic MIRAS predictor flow: the default variant is "titans", prediction
/// yields a confidence in `(0, 1]`, and stats reflect one observed message.
#[test]
fn test_miras_predictor_basic() {
    let mut predictor = MirasTitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    predictor.observe(b"Hello World");

    // A predictor built with `new` reports the Titans variant here.
    assert_eq!(predictor.variant(), "titans");

    // The predicted byte is a `u8`, so only the confidence needs checking.
    let (_next, conf) = predictor.predict_next();
    assert!(conf > 0.0 && conf <= 1.0);

    // Statistics must have been populated.
    let stats = predictor.stats();
    assert_eq!(stats.message_count, 1);
    assert!(stats.miras_enhanced_predictions > 0);
}
2017
/// Each explicitly requested variant is reported by name before any data
/// is observed, and a non-default variant still accepts observations.
#[test]
fn test_miras_predictor_variants() {
    let config = TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    };

    // Initial state only: nothing is observed, so no adaptive switch can
    // have happened yet.
    let requested = [
        MirasVariant::Titans,
        MirasVariant::Yaad,
        MirasVariant::Moneta { p: 2.0 },
        MirasVariant::Memora,
    ];
    for variant in requested {
        let predictor = MirasTitansPredictor::new_with_variant(config.clone(), variant);
        let expected_name = match variant {
            MirasVariant::Titans => "titans",
            MirasVariant::Yaad => "yaad",
            MirasVariant::Moneta { .. } => "moneta",
            MirasVariant::Memora => "memora",
        };
        assert_eq!(predictor.variant(), expected_name);
    }

    // A variant remains usable after an observation. The variant name is
    // intentionally not re-checked afterwards because adaptive switching
    // may legitimately change it.
    let mut yaad = MirasTitansPredictor::new_with_variant(config.clone(), MirasVariant::Yaad);
    assert_eq!(yaad.variant(), "yaad");
    yaad.observe(b"test");
}
2057
/// After a few observations the combined and Titans surprise values are
/// non-negative and a MIRAS surprise value is available.
#[test]
fn test_miras_predictor_combined_surprise() {
    let mut predictor = MirasTitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Train on a steady pattern.
    (0..5).for_each(|_| predictor.observe(b"normal message pattern"));

    // Combined metric.
    assert!(predictor.get_combined_surprise() >= 0.0);

    // Individual metrics.
    let titans = predictor.get_surprise();
    let miras = predictor.get_miras_surprise();
    assert!(titans >= 0.0);
    assert!(miras.is_some());
}
2087
/// The anomaly level starts at exactly zero and stays non-negative once
/// data has been observed.
#[test]
fn test_miras_predictor_anomaly_level() {
    let mut predictor = MirasTitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Fresh predictor: no anomaly yet.
    assert_eq!(predictor.anomaly_level(), 0.0);

    // After an observation some non-negative level is reported.
    predictor.observe(b"test");
    assert!(predictor.anomaly_level() >= 0.0);
}
2109
/// `reset_all` returns the message counter in the stats to zero.
#[test]
fn test_miras_predictor_reset() {
    let mut predictor = MirasTitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Accumulate some state first.
    (0..10).for_each(|_| predictor.observe(b"data"));
    assert!(predictor.stats().message_count > 0);

    predictor.reset_all();
    assert_eq!(predictor.stats().message_count, 0);
}
2134
2135    // =========================================================================
2136    // COMPREHENSIVE SECURITY TESTS (W3: Security Validation)
2137    // =========================================================================
2138
/// Ring-arithmetic laws on three random elements: addition is commutative
/// and associative, and multiplication distributes over addition.
#[test]
fn test_rlwe_ring_arithmetic_correctness() {
    let params = LatticeParams {
        n: 32,
        q: 257,
        p: 3,
        sigma: 2.0,
    };
    let mut rng = StdRng::seed_from_u64(12345);

    let a = RingElement::random(params.n, params.q, &mut rng);
    let b = RingElement::random(params.n, params.q, &mut rng);
    let c = RingElement::random(params.n, params.q, &mut rng);

    // a + b == b + a
    let a_plus_b = a.add(&b);
    let b_plus_a = b.add(&a);
    assert_eq!(
        a_plus_b.coeffs, b_plus_a.coeffs,
        "Addition should be commutative"
    );

    // (a + b) + c == a + (b + c)
    let left_grouped = a.add(&b).add(&c);
    let b_plus_c = b.add(&c);
    let right_grouped = a.add(&b_plus_c);
    assert_eq!(
        left_grouped.coeffs, right_grouped.coeffs,
        "Addition should be associative"
    );

    // a * (b + c) == a*b + a*c
    let factored = a.mul(&b.add(&c));
    let a_times_b = a.mul(&b);
    let a_times_c = a.mul(&c);
    let expanded = a_times_b.add(&a_times_c);
    assert_eq!(
        factored.coeffs, expanded.coeffs,
        "Multiplication should distribute over addition"
    );
}
2172
/// Statistical sanity check of the Gaussian sampler: over 1024 samples the
/// mean must be within one sigma of zero and the variance within 50% of
/// sigma squared.
#[test]
fn test_rlwe_gaussian_distribution() {
    let params = LatticeParams {
        n: 1024,
        q: 12289, // NIST-like parameter
        p: 3,
        sigma: 3.2,
    };
    let mut rng = StdRng::seed_from_u64(54321);

    let noise = RingElement::random_gaussian(params.n, params.q, params.sigma, &mut rng);

    // Sample mean and (population) variance.
    let sample_sum = noise.coeffs.iter().map(|&c| c as f64).sum::<f64>();
    let mean: f64 = sample_sum / params.n as f64;
    let squared_deviations = noise
        .coeffs
        .iter()
        .map(|&c| (c as f64 - mean).powi(2))
        .sum::<f64>();
    let variance: f64 = squared_deviations / params.n as f64;

    // Centred around zero.
    assert!(
        mean.abs() < params.sigma,
        "Gaussian mean should be near 0, got {}",
        mean
    );

    // Spread close to sigma^2.
    let expected_variance = params.sigma * params.sigma;
    let variance_error = (variance - expected_variance).abs();
    assert!(
        variance_error < expected_variance * 0.5,
        "Variance {} should be close to sigma^2 = {}",
        variance,
        expected_variance
    );
}
2211
/// Ternary sampler check: every coefficient is -1, 0 or 1, and each of the
/// three values occurs roughly a third of the time.
///
/// The balance check is applied to all three values. Checking only the
/// zero count would let a sampler that never emits one of the non-zero
/// values slip through.
#[test]
fn test_rlwe_ternary_distribution() {
    let mut rng = StdRng::seed_from_u64(98765);
    let params = LatticeParams {
        n: 256,
        q: 257,
        p: 3,
        sigma: 2.0,
    };

    let s = RingElement::random_ternary(params.n, params.q, &mut rng);

    // Coefficients are expected in signed form, not reduced mod q.
    for &coeff in &s.coeffs {
        assert!(
            coeff == 0 || coeff == 1 || coeff == -1,
            "Ternary coefficient should be -1, 0, or 1, got {}",
            coeff
        );
    }

    let count_zero = s.coeffs.iter().filter(|&&c| c == 0).count();
    let count_one = s.coeffs.iter().filter(|&&c| c == 1).count();
    let count_neg = s.coeffs.iter().filter(|&&c| c == -1).count();

    // Each value should account for roughly 1/3 of the coefficients; the
    // tolerance is a quarter of the total.
    let expected = params.n / 3;
    let tolerance = params.n / 4;
    for count in [count_zero, count_one, count_neg] {
        assert!(
            count.abs_diff(expected) < tolerance,
            "Ternary distribution unbalanced: zeros={}, ones={}, neg={}",
            count_zero,
            count_one,
            count_neg
        );
    }
}
2250
/// Two instances built from the same seed start with equal key hashes,
/// diverge while only one of them evolves, and agree again once both have
/// evolved the same number of times.
#[test]
fn test_key_evolution_forward_secrecy() {
    let params = LatticeParams {
        n: 64,
        q: 257,
        p: 3,
        sigma: 2.0,
    };

    let mut leader = QuantumKeyEvolution::new(params.clone(), 42);
    let mut follower = QuantumKeyEvolution::new(params, 42);

    // Identical starting state.
    assert_eq!(leader.get_key_hash(), follower.get_key_hash());

    // Advance only the leader.
    (0..5).for_each(|_| {
        leader.evolve();
    });

    // The two instances have diverged and the counters show by how much.
    assert_ne!(leader.get_key_hash(), follower.get_key_hash());
    assert_eq!(leader.get_evolution_counter(), 5);
    assert_eq!(follower.get_evolution_counter(), 0);

    // Bring the follower up to the same step count.
    (0..5).for_each(|_| {
        follower.evolve();
    });

    // Evolution is deterministic, so the hashes agree again.
    assert_eq!(leader.get_key_hash(), follower.get_key_hash());
}
2287
/// Ten evolution steps return ten distinct hashes, each of which
/// `verify_evolution` accepts afterwards.
#[test]
fn test_key_evolution_history_integrity() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 32,
            q: 257,
            p: 3,
            sigma: 2.0,
        },
        42,
    );

    // Record the hash returned by each evolution step.
    let hashes: Vec<_> = (0..10).map(|_| evolution.evolve()).collect();

    // No hash may repeat (no cycles).
    let distinct: std::collections::HashSet<_> = hashes.iter().collect();
    assert_eq!(distinct.len(), 10, "All evolution hashes should be unique");

    // Every recorded hash must still be verifiable.
    for recorded in hashes.iter() {
        assert!(
            evolution.verify_evolution(recorded),
            "Recent evolution should be verifiable"
        );
    }
}
2321
/// Messages of several lengths survive a send/receive round trip
/// unchanged.
#[test]
fn test_quantum_protocol_message_integrity() {
    let params = LatticeParams {
        n: 32,
        q: 257,
        p: 3,
        sigma: 2.0,
    };
    let config = TitansConfig {
        embed_dim: 16,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 32,
        max_seq_len: 32,
        memory_size: 8,
        seed: 42,
    };

    let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
    let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);

    // From a single byte up to a sentence-length payload.
    let test_messages = [
        b"A".to_vec(),
        b"Short".to_vec(),
        b"Medium length message".to_vec(),
        b"This is a longer message to test variable length handling properly".to_vec(),
    ];

    for original in test_messages.iter() {
        let wire = alice.send(original);
        let delivered = bob.receive(&wire);
        assert!(delivered.is_some(), "Should receive message");
        assert_eq!(
            &delivered.unwrap(),
            original,
            "Received message should match original"
        );
    }
}
2363
/// Flipping every bit of the first KEM ciphertext byte must not yield the
/// original shared secret.
///
/// Decapsulation is allowed to return `None` for the tampered input; only
/// a `Some` result is compared against the original secret.
#[test]
fn test_tampered_ciphertext_detection() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 32,
            q: 257,
            p: 3,
            sigma: 2.0,
        },
        42,
    );

    let (mut ciphertext, original_secret) = evolution.encapsulate();

    // Corrupt the first byte, if there is one.
    if let Some(first) = ciphertext.first_mut() {
        *first ^= 0xFF;
    }

    // There is no integrity tag at this layer, so the expectation is a
    // different secret rather than an outright rejection.
    if let Some(tampered) = evolution.decapsulate(&ciphertext) {
        assert_ne!(
            tampered, original_secret,
            "Tampered ciphertext should produce different secret"
        );
    }
}
2394
/// Validates three parameter sets (toy, medium, high) and checks that the
/// toy and medium sets both encapsulate, with the medium set producing the
/// larger ciphertext.
///
/// Every set is checked for both stated validity conditions: `n` is a
/// power of two and `q` is prime. This covers the high-security set too,
/// which is not used for key generation below.
#[test]
fn test_lattice_params_security_levels() {
    /// Trial-division primality check; sufficient for the small moduli
    /// used in this test.
    fn is_prime(value: i64) -> bool {
        if value < 2 {
            return false;
        }
        let mut divisor = 2;
        while divisor * divisor <= value {
            if value % divisor == 0 {
                return false;
            }
            divisor += 1;
        }
        true
    }

    let toy_params = LatticeParams {
        n: 16,
        q: 97,
        p: 3,
        sigma: 2.0,
    };
    let medium_params = LatticeParams {
        n: 256,
        q: 7681,
        p: 3,
        sigma: 3.19,
    };
    let high_params = LatticeParams {
        n: 1024,
        q: 12289,
        p: 3,
        sigma: 3.19,
    };

    // Verify params are valid (n is power of 2, q is prime).
    for (label, params) in [
        ("toy", &toy_params),
        ("medium", &medium_params),
        ("high", &high_params),
    ] {
        assert!(
            params.n.is_power_of_two(),
            "{} params: n should be power of 2 for NTT",
            label
        );
        assert!(
            is_prime(params.q as i64),
            "{} params: q = {} should be prime",
            label,
            params.q
        );
    }

    // Key generation should work for the toy and medium sets.
    let mut ke_toy = QuantumKeyEvolution::new(toy_params, 1);
    let mut ke_med = QuantumKeyEvolution::new(medium_params, 1);

    // Both should be able to encapsulate.
    let (ct_toy, _) = ke_toy.encapsulate();
    let (ct_med, _) = ke_med.encapsulate();

    assert!(!ct_toy.is_empty());
    assert!(!ct_med.is_empty());

    // Medium params should produce larger ciphertext.
    assert!(
        ct_med.len() > ct_toy.len(),
        "Higher security params should produce larger ciphertext"
    );
}
2444
/// After training on a repetitive pattern the confidence is within
/// `[0, 1]` and the surprise value is non-negative.
#[test]
fn test_titans_predictor_statistical_properties() {
    let mut predictor = TitansPredictor::new(TitansConfig {
        embed_dim: 32,
        num_heads: 2,
        num_layers: 1,
        ff_dim: 64,
        max_seq_len: 64,
        memory_size: 16,
        seed: 42,
    });

    // Repeat a short pattern twenty times.
    let pattern = b"ABCABC";
    (0..20).for_each(|_| predictor.observe(pattern));

    // Confidence must be a normalised value.
    let (_next, confidence) = predictor.predict_next();
    let normalized = (0.0..=1.0).contains(&confidence);
    assert!(normalized, "Confidence should be normalized");

    // Surprise must be tracked and non-negative.
    let surprise = predictor.get_surprise();
    assert!(surprise >= 0.0, "Surprise should be non-negative");
}
2475
/// The shared secret recovered by `decapsulate` must equal the one
/// produced by `encapsulate`.
#[test]
fn test_kem_shared_secret_match() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 64,
            q: 257,
            p: 3,
            sigma: 1.5,
        },
        12345,
    );

    let (ciphertext, sender_secret) = evolution.encapsulate();
    let receiver_secret = evolution.decapsulate(&ciphertext).unwrap();

    assert_eq!(
        sender_secret, receiver_secret,
        "KEM shared secrets must match between encapsulate and decapsulate"
    );
}
2495
/// A `Full` payload whose last encrypted byte has been flipped must be
/// rejected by the receiver, which authenticates with AES-256-GCM.
#[test]
fn test_aead_tampered_ciphertext_rejected() {
    let config = TransformerConfig::default();
    let params = LatticeParams {
        n: 32,
        q: 257,
        p: 3,
        sigma: 2.0,
    };

    let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
    let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);

    let mut wire = alice.send(b"Secret message");

    // Corrupt the final byte of the encrypted message. A non-`Full`
    // payload is left untouched.
    if let MessagePayload::Full {
        encrypted_message, ..
    } = &mut wire.payload
    {
        if let Some(last) = encrypted_message.last_mut() {
            *last ^= 0xFF; // Flip bits
        }
    }

    // Authentication must fail, so nothing is delivered.
    let delivered = bob.receive(&wire);
    assert!(
        delivered.is_none(),
        "Tampered ciphertext must be rejected by AEAD"
    );
}
2531
/// After each of five evolution steps the evolved keypair must still give
/// matching secrets from encapsulate and decapsulate.
#[test]
fn test_key_evolution_maintains_kem_invariant() {
    let mut evolution = QuantumKeyEvolution::new(
        LatticeParams {
            n: 32,
            q: 257,
            p: 3,
            sigma: 2.0,
        },
        99,
    );

    (0..5).for_each(|_| {
        evolution.evolve();
        // The evolved keypair must still be usable for encaps/decaps.
        let (ciphertext, sender_secret) = evolution.encapsulate();
        let receiver_secret = evolution.decapsulate(&ciphertext).unwrap();
        assert_eq!(
            sender_secret, receiver_secret,
            "KEM must work after key evolution"
        );
    });
}
2551
/// Two instances created from the same seed return identical hashes at
/// every evolution step and end with the same key hash.
#[test]
fn test_key_evolution_deterministic_hkdf() {
    let params = LatticeParams::default();
    let mut first = QuantumKeyEvolution::new(params.clone(), 7777);
    let mut second = QuantumKeyEvolution::new(params, 7777);

    // Step both instances in lockstep and compare the returned hashes.
    (0..5).for_each(|_| {
        let first_hash = first.evolve();
        let second_hash = second.evolve();
        assert_eq!(
            first_hash, second_hash,
            "Deterministic evolution must produce identical hashes"
        );
    });

    assert_eq!(first.get_key_hash(), second.get_key_hash());
}
2569
/// Five consecutive messages each survive a full send/receive round trip.
#[test]
fn test_aes_gcm_round_trip() {
    let config = TransformerConfig::default();
    let params = LatticeParams {
        n: 64,
        q: 257,
        p: 3,
        sigma: 1.5,
    };

    let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 100);
    let mut bob = QuantumSpeculativeProtocol::new(config, params, 100);

    // Several messages in a row over the same pair of endpoints.
    for i in 0..5 {
        let text = format!("Message number {}", i);
        let wire = alice.send(text.as_bytes());

        let delivered = bob.receive(&wire);
        assert!(delivered.is_some(), "Message {} should decrypt", i);
        assert_eq!(
            delivered.unwrap(),
            text.as_bytes(),
            "Message {} content mismatch",
            i
        );
    }
}
2598
2599    // ── Phase 24: ML-KEM (FIPS 203) tests ──────────────────────────────
2600
2601    #[test]
2602    fn test_mlkem_512_round_trip() {
2603        let mut rng = StdRng::seed_from_u64(1);
2604        let kp = mlkem_ops::generate_512(&mut rng);
2605        assert_eq!(kp.algorithm, KemAlgorithm::MlKem512);
2606
2607        let (ct, ss_enc) = mlkem_ops::encapsulate_512(&kp.ek_bytes, &mut rng).unwrap();
2608        let ss_dec = mlkem_ops::decapsulate_512(&kp.dk_bytes, &ct).unwrap();
2609        assert_eq!(ss_enc.len(), 32);
2610        assert_eq!(ss_enc, ss_dec, "ML-KEM-512 shared secret mismatch");
2611    }
2612
2613    #[test]
2614    fn test_mlkem_768_round_trip() {
2615        let mut rng = StdRng::seed_from_u64(2);
2616        let kp = mlkem_ops::generate_768(&mut rng);
2617        assert_eq!(kp.algorithm, KemAlgorithm::MlKem768);
2618
2619        let (ct, ss_enc) = mlkem_ops::encapsulate_768(&kp.ek_bytes, &mut rng).unwrap();
2620        let ss_dec = mlkem_ops::decapsulate_768(&kp.dk_bytes, &ct).unwrap();
2621        assert_eq!(ss_enc.len(), 32);
2622        assert_eq!(ss_enc, ss_dec, "ML-KEM-768 shared secret mismatch");
2623    }
2624
2625    #[test]
2626    fn test_mlkem_1024_round_trip() {
2627        let mut rng = StdRng::seed_from_u64(3);
2628        let kp = mlkem_ops::generate_1024(&mut rng);
2629        assert_eq!(kp.algorithm, KemAlgorithm::MlKem1024);
2630
2631        let (ct, ss_enc) = mlkem_ops::encapsulate_1024(&kp.ek_bytes, &mut rng).unwrap();
2632        let ss_dec = mlkem_ops::decapsulate_1024(&kp.dk_bytes, &ct).unwrap();
2633        assert_eq!(ss_enc.len(), 32);
2634        assert_eq!(ss_enc, ss_dec, "ML-KEM-1024 shared secret mismatch");
2635    }
2636
2637    #[test]
2638    fn test_mlkem_different_keypairs_produce_different_secrets() {
2639        let mut rng = StdRng::seed_from_u64(4);
2640        let kp1 = mlkem_ops::generate_768(&mut rng);
2641        let kp2 = mlkem_ops::generate_768(&mut rng);
2642
2643        let (_, ss1) = mlkem_ops::encapsulate_768(&kp1.ek_bytes, &mut rng).unwrap();
2644        let (_, ss2) = mlkem_ops::encapsulate_768(&kp2.ek_bytes, &mut rng).unwrap();
2645
2646        // Overwhelmingly likely to differ (2^-256 collision probability)
2647        assert_ne!(ss1, ss2, "Different keypairs should yield different secrets");
2648    }
2649
2650    #[test]
2651    fn test_mlkem_wrong_key_decapsulation_fails() {
2652        let mut rng = StdRng::seed_from_u64(5);
2653        let kp1 = mlkem_ops::generate_768(&mut rng);
2654        let kp2 = mlkem_ops::generate_768(&mut rng);
2655
2656        let (ct, ss_enc) = mlkem_ops::encapsulate_768(&kp1.ek_bytes, &mut rng).unwrap();
2657        // Decapsulating with wrong key should yield a different (implicit reject) secret
2658        let ss_wrong = mlkem_ops::decapsulate_768(&kp2.dk_bytes, &ct).unwrap();
2659        assert_ne!(
2660            ss_enc, ss_wrong,
2661            "Wrong DK must produce different shared secret (implicit reject)"
2662        );
2663    }
2664
2665    #[test]
2666    fn test_kem_algorithm_default() {
2667        assert_eq!(KemAlgorithm::default(), KemAlgorithm::MlKem768);
2668    }
2669
2670    #[test]
2671    fn test_quantum_key_evolution_with_mlkem() {
2672        let params = LatticeParams {
2673            n: 32,
2674            q: 257,
2675            p: 3,
2676            sigma: 2.0,
2677        };
2678        let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 42, KemAlgorithm::MlKem768);
2679
2680        // Encapsulate/decapsulate should work with ML-KEM
2681        let (ct, ss_enc) = ke.encapsulate();
2682        let ss_dec = ke.decapsulate(&ct).unwrap();
2683        assert_eq!(ss_enc, ss_dec, "ML-KEM encaps/decaps via QuantumKeyEvolution");
2684        assert!(!ct.is_empty());
2685    }
2686
2687    #[test]
2688    fn test_quantum_key_evolution_hybrid_kem() {
2689        let params = LatticeParams {
2690            n: 32,
2691            q: 257,
2692            p: 3,
2693            sigma: 2.0,
2694        };
2695        let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 42, KemAlgorithm::Hybrid);
2696
2697        let (ct, ss_enc) = ke.encapsulate();
2698        let ss_dec = ke.decapsulate(&ct).unwrap();
2699        assert_eq!(ss_enc, ss_dec, "Hybrid RLWE+ML-KEM shared secret mismatch");
2700        assert_eq!(ss_enc.len(), 32, "Hybrid shared secret should be 32 bytes");
2701        // Hybrid ciphertext is larger (RLWE + ML-KEM-768 concatenated)
2702        assert!(ct.len() > 100, "Hybrid ciphertext should be large");
2703    }
2704
2705    #[test]
2706    fn test_mlkem_key_evolution_maintains_invariant() {
2707        let params = LatticeParams {
2708            n: 32,
2709            q: 257,
2710            p: 3,
2711            sigma: 2.0,
2712        };
2713        let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 55, KemAlgorithm::MlKem768);
2714
2715        for i in 0..5 {
2716            ke.evolve();
2717            let (ct, ss_enc) = ke.encapsulate();
2718            let ss_dec = ke.decapsulate(&ct).unwrap();
2719            assert_eq!(ss_enc, ss_dec, "ML-KEM must work after evolution step {}", i);
2720        }
2721    }
2722
2723    #[test]
2724    fn test_quantum_speculative_protocol_with_mlkem() {
2725        let config = TransformerConfig::default();
2726        let params = LatticeParams {
2727            n: 32,
2728            q: 257,
2729            p: 3,
2730            sigma: 2.0,
2731        };
2732
2733        let mut alice = QuantumSpeculativeProtocol::new_with_algorithm(
2734            config.clone(),
2735            params.clone(),
2736            42,
2737            KemAlgorithm::MlKem768,
2738        );
2739        let mut bob = QuantumSpeculativeProtocol::new_with_algorithm(
2740            config,
2741            params,
2742            42,
2743            KemAlgorithm::MlKem768,
2744        );
2745
2746        assert_eq!(alice.algorithm(), KemAlgorithm::MlKem768);
2747        assert_eq!(bob.algorithm(), KemAlgorithm::MlKem768);
2748
2749        let msg = b"ML-KEM secured message";
2750        let quantum_msg = alice.send(msg);
2751        let received = bob.receive(&quantum_msg);
2752        assert!(received.is_some());
2753        assert_eq!(received.unwrap(), msg);
2754    }
2755
2756    #[test]
2757    fn test_mlkem_ciphertext_sizes() {
2758        let mut rng = StdRng::seed_from_u64(6);
2759        let kp512 = mlkem_ops::generate_512(&mut rng);
2760        let kp768 = mlkem_ops::generate_768(&mut rng);
2761        let kp1024 = mlkem_ops::generate_1024(&mut rng);
2762
2763        let (ct512, _) = mlkem_ops::encapsulate_512(&kp512.ek_bytes, &mut rng).unwrap();
2764        let (ct768, _) = mlkem_ops::encapsulate_768(&kp768.ek_bytes, &mut rng).unwrap();
2765        let (ct1024, _) = mlkem_ops::encapsulate_1024(&kp1024.ek_bytes, &mut rng).unwrap();
2766
2767        assert_eq!(ct512.len(), 768, "ML-KEM-512 ciphertext should be 768 bytes");
2768        assert_eq!(ct768.len(), 1088, "ML-KEM-768 ciphertext should be 1088 bytes");
2769        assert_eq!(ct1024.len(), 1568, "ML-KEM-1024 ciphertext should be 1568 bytes");
2770
2771        // Monotonically increasing with security level
2772        assert!(ct512.len() < ct768.len());
2773        assert!(ct768.len() < ct1024.len());
2774    }
2775}