#![allow(dead_code)]
use rand::prelude::*;
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};
use spine_neural::{
Activation, DenseLayer, MirasNeuralEncoder, MirasVariant, MultiHeadAttention,
NeuralEncoderConfig, TitansMemory,
};
use std::collections::VecDeque;
use subtle::ConstantTimeEq;
use ml_kem::kem::{Decapsulate, Encapsulate, EncapsulationKey, DecapsulationKey};
use ml_kem::{MlKem512, MlKem768, MlKem1024, KemCore, EncodedSizeUser, Encoded,
MlKem512Params, MlKem768Params, MlKem1024Params};
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use hkdf::Hkdf;
/// Precomputed sinusoidal positional encodings.
///
/// Row `pos` holds `embed_dim` values: even columns are `sin(angle)` and odd columns are
/// `cos(angle)`, where `angle = pos / 10000^(2 * (i / 2) / embed_dim)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionalEncoding {
    /// Number of precomputed positions (rows).
    max_len: usize,
    /// Width of every row.
    embed_dim: usize,
    /// `max_len` rows of `embed_dim` values each.
    encodings: Vec<Vec<f32>>,
}

impl PositionalEncoding {
    /// Builds the encoding table for positions `0..max_len`.
    pub fn new(max_len: usize, embed_dim: usize) -> Self {
        let mut encodings = Vec::with_capacity(max_len);
        for pos in 0..max_len {
            let mut row = Vec::with_capacity(embed_dim);
            for i in 0..embed_dim {
                let exponent = 2.0 * (i / 2) as f32 / embed_dim as f32;
                let angle = pos as f32 / (10000.0_f32).powf(exponent);
                row.push(if i % 2 == 0 { angle.sin() } else { angle.cos() });
            }
            encodings.push(row);
        }
        Self {
            max_len,
            embed_dim,
            encodings,
        }
    }

    /// Returns the encoding for `position`, clamped to the last precomputed row.
    ///
    /// Returns an empty slice when the table has no rows (`max_len == 0`).
    pub fn get(&self, position: usize) -> &[f32] {
        let row = position.min(self.max_len.saturating_sub(1));
        self.encodings.get(row).map(Vec::as_slice).unwrap_or(&[])
    }
}
/// Layer normalisation with a learnable per-dimension scale (`gamma`) and shift (`beta`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Number of `gamma`/`beta` parameters.
    dim: usize,
    /// Scale applied after normalisation.
    gamma: Vec<f32>,
    /// Shift applied after scaling.
    beta: Vec<f32>,
    /// Added to the variance before the square root, for numerical stability.
    eps: f32,
}

impl LayerNorm {
    /// Creates an identity-initialised norm: `gamma = 1`, `beta = 0`, `eps = 1e-5`.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            gamma: vec![1.0; dim],
            beta: vec![0.0; dim],
            eps: 1e-5,
        }
    }

    /// Normalises `x` to zero mean and unit variance over all of its elements, then applies
    /// `gamma` and `beta`.
    ///
    /// Element `i` uses parameter index `i % dim`, so inputs longer than `dim` reuse the
    /// parameters cyclically. Empty input yields an empty vector.
    ///
    /// # Panics
    /// Panics if `dim == 0` and `x` is non-empty (integer modulo by zero).
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        if x.is_empty() {
            return Vec::new();
        }
        let count = x.len() as f32;
        let mean = x.iter().sum::<f32>() / count;
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / count;
        let denom = (variance + self.eps).sqrt();
        let mut out = Vec::with_capacity(x.len());
        for (i, &v) in x.iter().enumerate() {
            let k = i % self.dim;
            out.push(self.gamma[k] * (v - mean) / denom + self.beta[k]);
        }
        out
    }
}
/// Two-layer feed-forward network: `embed_dim -> ff_dim` (GELU) `-> embed_dim` (linear).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedForward {
    /// Expansion layer with GELU activation.
    linear1: DenseLayer,
    /// Projection back to `embed_dim`, no activation.
    linear2: DenseLayer,
}

impl FeedForward {
    /// Creates both layers, drawing `linear1`'s weights from `rng` before `linear2`'s.
    pub fn new(embed_dim: usize, ff_dim: usize, rng: &mut StdRng) -> Self {
        let linear1 = DenseLayer::new(embed_dim, ff_dim, Activation::GELU, rng);
        let linear2 = DenseLayer::new(ff_dim, embed_dim, Activation::None, rng);
        Self { linear1, linear2 }
    }

    /// Applies `linear1` and then `linear2` to `x`.
    pub fn forward(&mut self, x: &[f32]) -> Vec<f32> {
        let Self { linear1, linear2 } = self;
        linear2.forward(&linear1.forward(x))
    }
}
/// One Titans layer.
///
/// The layer runs three stages, each followed by layer normalisation:
/// 1. A memory read of the last sequence element, with a residual connection.
/// 2. Attention over the whole sequence, with a residual connection.
/// 3. A feed-forward network, with a residual connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansBlock {
    /// Neural memory queried with the most recent element.
    memory: TitansMemory,
    /// Attention over the full input sequence.
    attention: MultiHeadAttention,
    /// Position-wise feed-forward network.
    ff: FeedForward,
    /// Norm after the memory residual.
    norm1: LayerNorm,
    /// Norm after the attention residual.
    norm2: LayerNorm,
    /// Norm after the feed-forward residual.
    norm3: LayerNorm,
    /// Output width; also the length of the zero vector returned for empty input.
    embed_dim: usize,
}

impl TitansBlock {
    /// Creates the block, drawing weights from `rng` in the order memory, attention,
    /// feed-forward.
    pub fn new(
        embed_dim: usize,
        num_heads: usize,
        ff_dim: usize,
        memory_size: usize,
        rng: &mut StdRng,
    ) -> Self {
        let memory = TitansMemory::new(embed_dim, embed_dim, memory_size, rng);
        let attention = MultiHeadAttention::new(embed_dim, num_heads, rng);
        let ff = FeedForward::new(embed_dim, ff_dim, rng);
        Self {
            memory,
            attention,
            ff,
            norm1: LayerNorm::new(embed_dim),
            norm2: LayerNorm::new(embed_dim),
            norm3: LayerNorm::new(embed_dim),
            embed_dim,
        }
    }

    /// Processes `sequence` and returns a single hidden vector.
    ///
    /// An empty sequence yields `embed_dim` zeros. Residual sums use `zip`, so they are
    /// truncated to the shorter operand if lengths differ.
    pub fn forward(&mut self, sequence: &[Vec<f32>]) -> Vec<f32> {
        let Some(last) = sequence.last() else {
            return vec![0.0; self.embed_dim];
        };
        let add = |lhs: &[f32], rhs: &[f32]| -> Vec<f32> {
            lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
        };

        let recalled = self.memory.forward(last);
        let normed1 = self.norm1.forward(&add(&recalled, last));

        let attended = self.attention.forward(sequence);
        let normed2 = self.norm2.forward(&add(&attended, &normed1));

        let ff_out = self.ff.forward(&normed2);
        self.norm3.forward(&add(&ff_out, &normed2))
    }

    /// Returns the memory module's current surprise value.
    pub fn get_surprise(&self) -> f32 {
        self.memory.get_surprise()
    }

    /// Resets the memory module's internal state.
    pub fn reset_memory(&mut self) {
        self.memory.reset_state();
    }
}
/// Embedding table with one learnable vector per byte value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ByteTokenizer {
    /// Length of each embedding vector.
    embed_dim: usize,
    /// 256 rows, indexed by byte value.
    embeddings: Vec<Vec<f32>>,
}

impl ByteTokenizer {
    /// Creates 256 embeddings of `embed_dim` values.
    ///
    /// Values are drawn uniformly from `[-1/sqrt(embed_dim), 1/sqrt(embed_dim))`, row by
    /// row in byte order.
    pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
        let bound = (1.0 / embed_dim as f32).sqrt();
        let mut embeddings = Vec::with_capacity(256);
        for _ in 0..256 {
            let row: Vec<f32> = (0..embed_dim)
                .map(|_| rng.gen::<f32>() * 2.0 * bound - bound)
                .collect();
            embeddings.push(row);
        }
        Self {
            embed_dim,
            embeddings,
        }
    }

    /// Borrows the embedding for `byte`.
    pub fn encode(&self, byte: u8) -> &[f32] {
        self.embeddings[usize::from(byte)].as_slice()
    }

    /// Returns owned copies of the embeddings for every byte in `bytes`.
    pub fn encode_sequence(&self, bytes: &[u8]) -> Vec<Vec<f32>> {
        bytes.iter().map(|&b| self.encode(b).to_vec()).collect()
    }
}
/// Linear projection from a hidden state to a temperature-scaled softmax over the 256 byte
/// values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputProjection {
    /// One weight row per byte value.
    weights: Vec<Vec<f32>>,
    /// Logit divisor; `set_temperature` keeps it at or above `0.01`.
    temperature: f32,
}

impl OutputProjection {
    /// Creates 256 rows of `embed_dim` weights with temperature 1.
    ///
    /// Weights are drawn uniformly from `[-1/sqrt(embed_dim), 1/sqrt(embed_dim))`.
    pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
        let bound = (1.0 / embed_dim as f32).sqrt();
        let mut weights = Vec::with_capacity(256);
        for _ in 0..256 {
            let row: Vec<f32> = (0..embed_dim)
                .map(|_| rng.gen::<f32>() * 2.0 * bound - bound)
                .collect();
            weights.push(row);
        }
        Self {
            weights,
            temperature: 1.0,
        }
    }

    /// Sets the softmax temperature, with a floor of `0.01`.
    pub fn set_temperature(&mut self, temp: f32) {
        self.temperature = f32::max(temp, 0.01);
    }

    /// Returns a 256-entry probability vector.
    ///
    /// Each entry is `softmax(W · hidden / temperature)`.
    ///
    /// # Panics
    /// Panics if `hidden` is longer than a weight row.
    pub fn forward(&self, hidden: &[f32]) -> Vec<f32> {
        let mut probs = vec![0.0f32; 256];
        for (i, row) in self.weights.iter().enumerate() {
            let dot = hidden
                .iter()
                .enumerate()
                .fold(0.0f32, |acc, (j, &h)| acc + row[j] * h);
            probs[i] = dot / self.temperature;
        }
        // Subtract the largest logit before exponentiating to avoid overflow.
        let peak = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut total = 0.0f32;
        for p in probs.iter_mut() {
            *p = (*p - peak).exp();
            total += *p;
        }
        probs.iter_mut().for_each(|p| *p /= total);
        probs
    }

    /// Draws a byte by inverse-CDF sampling over `probs`.
    ///
    /// Returns 255 if the cumulative sum never exceeds the random draw (e.g. rounding
    /// error).
    pub fn sample(&self, probs: &[f32], rng: &mut StdRng) -> u8 {
        let draw: f32 = rng.gen();
        let mut running = 0.0f32;
        probs
            .iter()
            .position(|&p| {
                running += p;
                draw < running
            })
            .map_or(255, |i| i as u8)
    }

    /// Returns the index of the largest probability, or 0 for an empty slice.
    pub fn argmax(&self, probs: &[f32]) -> u8 {
        probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or(0, |(i, _)| i as u8)
    }
}
/// Byte-level next-byte predictor.
///
/// Combines byte embeddings with sinusoidal positions, a stack of [`TitansBlock`]s, and a
/// 256-way output projection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansPredictor {
    /// Byte -> embedding table.
    tokenizer: ByteTokenizer,
    /// Positional encodings added to embeddings as they enter `context_window`.
    positional: PositionalEncoding,
    /// Processing layers, applied in order.
    blocks: Vec<TitansBlock>,
    /// Maps the final hidden state to a distribution over byte values.
    output: OutputProjection,
    embed_dim: usize,
    /// Maximum number of entries retained in `context_window`.
    max_seq_len: usize,
    memory_size: usize,
    /// Most recent position-encoded embeddings, oldest first.
    context_window: VecDeque<Vec<f32>>,
    /// Mean block surprise, refreshed by `observe`.
    total_surprise: f32,
    /// Sampling RNG. Not serialized; deserialized values get `default_rng()`.
    #[serde(skip, default = "default_rng")]
    rng: StdRng,
}
/// [`TitansPredictor`] wrapper that also runs every observed message through a MIRAS neural
/// encoder and tracks a rolling window of Titans surprise values.
#[derive(Debug, Clone)]
pub struct MirasTitansPredictor {
    /// Underlying byte-level predictor; all prediction calls delegate to it.
    base: TitansPredictor,
    /// MIRAS encoder run on each observed message (always `Some` from the constructors).
    miras_encoder: Option<MirasNeuralEncoder>,
    /// Variant label reported by [`Self::variant`]; updated by `maybe_switch_variant`.
    active_variant: MirasVariant,
    /// The most recent Titans surprise values (at most `SURPRISE_WINDOW`), oldest first.
    surprise_history: VecDeque<f32>,
    /// Base threshold used when choosing a variant.
    anomaly_threshold: f32,
    /// Messages passed to [`Self::observe`] since creation or the last `reset_all`.
    message_count: u64,
    /// Messages that were also run through the MIRAS encoder.
    miras_enhanced_predictions: u64,
    latent_dim: usize,
}

impl MirasTitansPredictor {
    /// Number of recent surprise samples kept for [`Self::anomaly_level`].
    const SURPRISE_WINDOW: usize = 100;

    /// Creates a predictor whose MIRAS encoder uses the Titans variant.
    ///
    /// Equivalent to `new_with_variant(config, MirasVariant::Titans)`.
    pub fn new(config: TitansConfig) -> Self {
        Self::new_with_variant(config, MirasVariant::Titans)
    }

    /// Creates a predictor whose MIRAS encoder uses `variant`.
    ///
    /// The encoder is seeded with `config.seed + 1` (wrapping) so its weights differ from
    /// the base predictor's.
    pub fn new_with_variant(config: TitansConfig, variant: MirasVariant) -> Self {
        let latent_dim = config.embed_dim;
        let encoder_config = NeuralEncoderConfig {
            input_dim: config.embed_dim,
            latent_dim: config.embed_dim,
            hidden_dims: vec![config.ff_dim, config.embed_dim],
            attention_heads: config.num_heads,
            // `wrapping_add` avoids a debug-build overflow panic when `seed == u64::MAX`.
            seed: config.seed.wrapping_add(1),
            miras_variant: variant,
            memory_tokens: config.memory_size,
        };
        // The base predictor is built before the encoder.
        let base = TitansPredictor::new(config);
        let miras_encoder = Some(MirasNeuralEncoder::new(&encoder_config));
        Self {
            base,
            miras_encoder,
            active_variant: variant,
            surprise_history: VecDeque::with_capacity(Self::SURPRISE_WINDOW),
            anomaly_threshold: 0.5,
            message_count: 0,
            miras_enhanced_predictions: 0,
            latent_dim,
        }
    }

    /// Sets the base threshold used by automatic variant selection.
    pub fn set_anomaly_threshold(&mut self, threshold: f32) {
        self.anomaly_threshold = threshold;
    }

    /// Returns the active variant's name (`"titans"`, `"yaad"`, `"moneta"`, `"memora"`).
    pub fn variant(&self) -> &str {
        match self.active_variant {
            MirasVariant::Titans => "titans",
            MirasVariant::Yaad => "yaad",
            MirasVariant::Moneta { .. } => "moneta",
            MirasVariant::Memora => "memora",
        }
    }

    /// Mean of the retained surprise history, or 0 when empty.
    pub fn anomaly_level(&self) -> f32 {
        if self.surprise_history.is_empty() {
            0.0
        } else {
            self.surprise_history.iter().sum::<f32>() / self.surprise_history.len() as f32
        }
    }

    /// Picks a variant from the current anomaly level and message count.
    ///
    /// The rules are applied in order:
    /// - Yaad above `2 * threshold`.
    /// - Memora above `threshold`.
    /// - Moneta after 10 000 messages.
    /// - Titans otherwise.
    ///
    /// NOTE(review): only `active_variant` (the reported label) changes; the encoder is not
    /// reconfigured here.
    fn maybe_switch_variant(&mut self) {
        let anomaly = self.anomaly_level();
        let candidate = if anomaly > self.anomaly_threshold * 2.0 {
            MirasVariant::Yaad
        } else if anomaly > self.anomaly_threshold {
            MirasVariant::Memora
        } else if self.message_count > 10000 {
            MirasVariant::Moneta { p: 2.0 }
        } else {
            MirasVariant::Titans
        };
        // Compare variant kinds only, so an existing `Moneta { p }` keeps its own `p`.
        if std::mem::discriminant(&candidate) != std::mem::discriminant(&self.active_variant) {
            self.active_variant = candidate;
        }
    }

    /// Feeds `message` to the base predictor and the MIRAS encoder.
    ///
    /// Records the base surprise in the rolling window, then re-evaluates the variant.
    pub fn observe(&mut self, message: &[u8]) {
        self.base.observe(message);
        self.surprise_history.push_back(self.base.get_surprise());
        if self.surprise_history.len() > Self::SURPRISE_WINDOW {
            self.surprise_history.pop_front();
        }
        if let Some(encoder) = self.miras_encoder.as_mut() {
            // NOTE(review): the latent is discarded; presumably `encode` is called for its
            // effect on the encoder's internal state — confirm.
            let _latent = encoder.encode(message);
            self.miras_enhanced_predictions += 1;
        }
        self.message_count += 1;
        self.maybe_switch_variant();
    }

    /// Delegates to [`TitansPredictor::predict_next`].
    pub fn predict_next(&mut self) -> (u8, f32) {
        self.base.predict_next()
    }

    /// Delegates to [`TitansPredictor::predict_sequence`].
    pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
        self.base.predict_sequence(length, greedy)
    }

    /// Delegates to [`TitansPredictor::verify_prediction`].
    pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
        self.base.verify_prediction(message)
    }

    /// Returns the base predictor's surprise.
    pub fn get_surprise(&self) -> f32 {
        self.base.get_surprise()
    }

    /// Returns whether the base predictor's surprise exceeds `threshold`.
    pub fn is_anomalous(&self, threshold: f32) -> bool {
        self.base.is_anomalous(threshold)
    }

    /// Returns the MIRAS encoder's surprise, if an encoder is present.
    pub fn get_miras_surprise(&self) -> Option<f32> {
        self.miras_encoder.as_ref().map(|e| e.get_surprise())
    }

    /// Returns the mean of the Titans and MIRAS surprise values.
    ///
    /// Without an encoder this is the Titans surprise alone, rather than an average with a
    /// phantom zero.
    pub fn get_combined_surprise(&self) -> f32 {
        let titans = self.base.get_surprise();
        match self.get_miras_surprise() {
            Some(miras) => (titans + miras) / 2.0,
            None => titans,
        }
    }

    /// Clears the base predictor's context window and surprise (block memory is kept).
    pub fn reset(&mut self) {
        self.base.reset();
    }

    /// Fully resets the predictor.
    ///
    /// Resets the base predictor (including block memory), the surprise window, both
    /// message counters and the encoder. The active variant is kept.
    pub fn reset_all(&mut self) {
        self.base.reset_all();
        self.surprise_history.clear();
        self.message_count = 0;
        // Reset alongside `message_count` so stats stay consistent after a reset.
        self.miras_enhanced_predictions = 0;
        if let Some(encoder) = self.miras_encoder.as_mut() {
            encoder.reset();
        }
    }

    /// Snapshot of counters, the active variant and current surprise values.
    pub fn stats(&self) -> MirasPredictorStats {
        MirasPredictorStats {
            message_count: self.message_count,
            miras_enhanced_predictions: self.miras_enhanced_predictions,
            current_variant: self.variant().to_string(),
            anomaly_level: self.anomaly_level(),
            titans_surprise: self.base.get_surprise(),
            miras_surprise: self.get_miras_surprise(),
        }
    }
}
/// Snapshot returned by `MirasTitansPredictor::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirasPredictorStats {
    /// Messages observed since creation or the last `reset_all`.
    pub message_count: u64,
    /// Messages that were also run through the MIRAS encoder.
    pub miras_enhanced_predictions: u64,
    /// Name of the active MIRAS variant.
    pub current_variant: String,
    /// Mean of the recent surprise window.
    pub anomaly_level: f32,
    /// Current surprise of the base Titans predictor.
    pub titans_surprise: f32,
    /// Current MIRAS encoder surprise, if an encoder is present.
    pub miras_surprise: Option<f32>,
}
/// Serde default for skipped `rng` fields.
///
/// Returns a `StdRng` seeded with the fixed value 42, so every deserialized instance starts
/// from the same RNG stream.
fn default_rng() -> StdRng {
    const FALLBACK_SEED: u64 = 42;
    StdRng::seed_from_u64(FALLBACK_SEED)
}
impl TitansPredictor {
    /// Builds a predictor from `config`.
    ///
    /// All weights are drawn from a `StdRng` seeded with `config.seed`, in this order:
    /// tokenizer, each block in turn, then the output projection.
    pub fn new(config: TitansConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let tokenizer = ByteTokenizer::new(config.embed_dim, &mut rng);
        let positional = PositionalEncoding::new(config.max_seq_len, config.embed_dim);
        let blocks: Vec<TitansBlock> = (0..config.num_layers)
            .map(|_| {
                TitansBlock::new(
                    config.embed_dim,
                    config.num_heads,
                    config.ff_dim,
                    config.memory_size,
                    &mut rng,
                )
            })
            .collect();
        let output = OutputProjection::new(config.embed_dim, &mut rng);
        Self {
            tokenizer,
            positional,
            blocks,
            output,
            embed_dim: config.embed_dim,
            max_seq_len: config.max_seq_len,
            memory_size: config.memory_size,
            context_window: VecDeque::with_capacity(config.max_seq_len),
            total_surprise: 0.0,
            rng,
        }
    }

    /// Appends `byte`'s embedding to the context window.
    ///
    /// The positional encoding for the current window length is added first. The oldest
    /// entry is evicted once the window exceeds `max_seq_len`.
    ///
    /// NOTE(review): positions are baked in at insertion time and clamp to
    /// `max_seq_len - 1` once the window is full, so entries that slide towards the front
    /// keep their original positions.
    fn push_byte(&mut self, byte: u8) {
        let mut embedding = self.tokenizer.encode(byte).to_vec();
        let pos_enc = self.positional.get(self.context_window.len());
        for (e, p) in embedding.iter_mut().zip(pos_enc.iter()) {
            *e += *p;
        }
        self.context_window.push_back(embedding);
        if self.context_window.len() > self.max_seq_len {
            self.context_window.pop_front();
        }
    }

    /// Runs `sequence` through the block stack and returns the final hidden vector.
    ///
    /// The first block sees the whole sequence; each later block sees only the previous
    /// block's output. With no blocks (`num_layers == 0`) the last element of `sequence` is
    /// passed through instead of panicking on `blocks[0]`.
    fn run_blocks(blocks: &mut [TitansBlock], sequence: &[Vec<f32>]) -> Vec<f32> {
        let Some((first, rest)) = blocks.split_first_mut() else {
            return sequence.last().cloned().unwrap_or_default();
        };
        let mut hidden = first.forward(sequence);
        for block in rest {
            hidden = block.forward(std::slice::from_ref(&hidden));
        }
        hidden
    }

    /// Computes the next-byte distribution from the current context window.
    ///
    /// Returns `None` when the window is empty. The window is made contiguous and borrowed
    /// in place rather than cloned.
    fn next_distribution(&mut self) -> Option<Vec<f32>> {
        if self.context_window.is_empty() {
            return None;
        }
        let sequence: &[Vec<f32>] = self.context_window.make_contiguous();
        let hidden = Self::run_blocks(&mut self.blocks, sequence);
        Some(self.output.forward(&hidden))
    }

    /// Appends every byte of `message` to the context window and refreshes `total_surprise`.
    ///
    /// `total_surprise` becomes the mean of the blocks' `get_surprise()` values.
    ///
    /// NOTE(review): this method does not run the blocks, so the surprise value reflects
    /// block state from the most recent prediction — confirm that is intended.
    pub fn observe(&mut self, message: &[u8]) {
        for &byte in message {
            self.push_byte(byte);
        }
        self.total_surprise = self.blocks.iter().map(|b| b.get_surprise()).sum::<f32>()
            / self.blocks.len().max(1) as f32;
    }

    /// Returns the most likely next byte and its probability.
    ///
    /// With an empty context window it returns `(0, 1/256)`. The context window is not
    /// modified, but the blocks' `forward` is `&mut self` and may update their state.
    pub fn predict_next(&mut self) -> (u8, f32) {
        match self.next_distribution() {
            None => (0, 1.0 / 256.0),
            Some(probs) => {
                let predicted = self.output.argmax(&probs);
                (predicted, probs[predicted as usize])
            }
        }
    }

    /// Autoregressively generates `length` bytes, appending each one to the context window.
    ///
    /// `greedy` selects argmax instead of sampling. While the window is empty, each step
    /// emits 0 (greedy) or a uniformly random byte, and that byte is not appended.
    pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
        let mut result = Vec::with_capacity(length);
        for _ in 0..length {
            let Some(probs) = self.next_distribution() else {
                let byte = if greedy { 0 } else { self.rng.gen() };
                result.push(byte);
                continue;
            };
            let byte = if greedy {
                self.output.argmax(&probs)
            } else {
                self.output.sample(&probs, &mut self.rng)
            };
            result.push(byte);
            self.push_byte(byte);
        }
        result
    }

    /// Greedily predicts `message.len()` bytes and compares them with `message`.
    ///
    /// Returns `(exact_match, fraction_of_matching_positions)`.
    ///
    /// NOTE(review): this goes through `predict_sequence`, so the *predicted* bytes are
    /// appended to the context window as a side effect.
    pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
        let predicted = self.predict_sequence(message.len(), true);
        let matches = predicted == message;
        let similarity = predicted
            .iter()
            .zip(message.iter())
            .filter(|(p, m)| p == m)
            .count() as f32
            / message.len().max(1) as f32;
        (matches, similarity)
    }

    /// Returns the surprise value computed by the last `observe`.
    pub fn get_surprise(&self) -> f32 {
        self.total_surprise
    }

    /// Returns whether the current surprise exceeds `threshold`.
    pub fn is_anomalous(&self, threshold: f32) -> bool {
        self.total_surprise > threshold
    }

    /// Clears the context window and surprise; block memory is kept.
    pub fn reset(&mut self) {
        self.context_window.clear();
        self.total_surprise = 0.0;
    }

    /// Clears the context window and surprise, and resets every block's memory.
    pub fn reset_all(&mut self) {
        self.context_window.clear();
        self.total_surprise = 0.0;
        for block in &mut self.blocks {
            block.reset_memory();
        }
    }

    /// Sets the output softmax temperature (floored at `0.01`).
    pub fn set_temperature(&mut self, temp: f32) {
        self.output.set_temperature(temp);
    }
}
/// Alias for [`TitansPredictor`].
pub type TransformerPredictor = TitansPredictor;
/// Alias for [`TitansConfig`].
pub type TransformerConfig = TitansConfig;

/// Hyper-parameters for [`TitansPredictor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansConfig {
    /// Width of embeddings and hidden states.
    pub embed_dim: usize,
    /// Attention heads per block.
    pub num_heads: usize,
    /// Number of [`TitansBlock`]s.
    pub num_layers: usize,
    /// Hidden width of each block's feed-forward network.
    pub ff_dim: usize,
    /// Context-window length and number of precomputed positional encodings.
    pub max_seq_len: usize,
    /// Memory size passed to each block's `TitansMemory`.
    pub memory_size: usize,
    /// Seed for all weight initialisation.
    pub seed: u64,
}

impl Default for TitansConfig {
    /// Default configuration:
    /// - 64-wide embeddings, 4 heads, 2 layers.
    /// - 128-wide feed-forward.
    /// - 256-byte context, memory size 64, seed 42.
    fn default() -> Self {
        TitansConfig {
            embed_dim: 64,
            num_heads: 4,
            num_layers: 2,
            ff_dim: 128,
            max_seq_len: 256,
            memory_size: 64,
            seed: 42,
        }
    }
}
/// Parameters for the ring-LWE scheme.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatticeParams {
    /// Number of polynomial coefficients (ring dimension).
    pub n: usize,
    /// Coefficient modulus.
    pub q: u64,
    /// Not read by the visible code. NOTE(review): presumably a plaintext modulus — confirm.
    pub p: u64,
    /// Standard deviation of the rounded-Gaussian error distribution.
    pub sigma: f64,
}

impl Default for LatticeParams {
    /// `n = 1024`, `q = 12289`, `p = 3`, `sigma = 3.2`.
    fn default() -> Self {
        LatticeParams {
            n: 1024,
            q: 12289,
            p: 3,
            sigma: 3.2,
        }
    }
}
/// Polynomial in `Z_q[x] / (x^n + 1)`, stored as `n` signed coefficients.
///
/// Coefficients are not always in `[0, q)`. `random_ternary` and `random_gaussian` produce
/// small signed values, while `mul`, `add`, `sub` and `scale` return reduced results.
/// `ZeroizeOnDrop` wipes the coefficients when the value is dropped.
#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct RingElement {
    coeffs: Vec<i64>,
    n: usize,
    q: u64,
}

impl RingElement {
    /// Largest modulus whose reduced coefficients (`< q`) fit the 16-bit encoding used by
    /// [`Self::to_bytes`] and [`Self::from_bytes`].
    const MAX_ENCODABLE_Q: u64 = 1 << 16;

    /// Creates the zero polynomial.
    pub fn new(n: usize, q: u64) -> Self {
        Self {
            coeffs: vec![0; n],
            n,
            q,
        }
    }

    /// Samples each coefficient uniformly from `[0, q)`.
    pub fn random(n: usize, q: u64, rng: &mut StdRng) -> Self {
        let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(0..q as i64)).collect();
        Self { coeffs, n, q }
    }

    /// Samples each coefficient uniformly from `{-1, 0, 1}` (left unreduced).
    pub fn random_ternary(n: usize, q: u64, rng: &mut StdRng) -> Self {
        let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(-1..=1)).collect();
        Self { coeffs, n, q }
    }

    /// Samples each coefficient as `round(sigma * z)`, left unreduced.
    ///
    /// `z` is a standard normal drawn via Box–Muller; `u1` is clamped to `1e-10` to avoid
    /// `ln(0)`.
    pub fn random_gaussian(n: usize, q: u64, sigma: f64, rng: &mut StdRng) -> Self {
        let coeffs: Vec<i64> = (0..n)
            .map(|_| {
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                (z * sigma).round() as i64
            })
            .collect();
        Self { coeffs, n, q }
    }

    /// Decodes little-endian 16-bit coefficients from `bytes`, reducing each modulo `q`.
    ///
    /// At most `n` coefficients are read; missing ones stay zero. A trailing odd byte is
    /// taken as a single-byte value.
    pub fn from_bytes(bytes: &[u8], n: usize, q: u64) -> Self {
        let mut coeffs = vec![0i64; n];
        for (i, chunk) in bytes.chunks(2).enumerate() {
            if i >= n {
                break;
            }
            let val = if chunk.len() == 2 {
                ((chunk[0] as u16) | ((chunk[1] as u16) << 8)) as i64
            } else {
                chunk[0] as i64
            };
            coeffs[i] = val % q as i64;
        }
        Self { coeffs, n, q }
    }

    /// Encodes each coefficient, reduced into `[0, q)`, as two little-endian bytes.
    ///
    /// # Panics
    /// Panics if `q > 65536`. Larger reduced coefficients would otherwise be silently
    /// truncated to 16 bits, producing an encoding that `from_bytes` cannot invert.
    pub fn to_bytes(&self) -> Vec<u8> {
        assert!(
            self.q <= Self::MAX_ENCODABLE_Q,
            "modulus {} does not fit the 16-bit coefficient encoding",
            self.q
        );
        let mut bytes = Vec::with_capacity(self.n * 2);
        for &c in &self.coeffs {
            let val = ((c % self.q as i64 + self.q as i64) % self.q as i64) as u16;
            bytes.push(val as u8);
            bytes.push((val >> 8) as u8);
        }
        bytes
    }

    /// Maps every coefficient into `[0, q)`.
    fn reduce(&mut self) {
        for c in &mut self.coeffs {
            *c = ((*c % self.q as i64) + self.q as i64) % self.q as i64;
        }
    }

    /// Negacyclic product: terms of degree `>= n` wrap around with a sign flip
    /// (`x^n = -1`), and the result is reduced.
    ///
    /// The products are accumulated in `i64`. Callers in this file multiply a reduced
    /// operand by a ternary one, which keeps the accumulator far from overflow. Two large
    /// unreduced operands with a large `q·n` could overflow.
    ///
    /// # Panics
    /// Panics if the operands have different `n`.
    pub fn mul(&self, other: &RingElement) -> RingElement {
        assert_eq!(self.n, other.n);
        let mut result = vec![0i64; self.n];
        for i in 0..self.n {
            for j in 0..self.n {
                let idx = i + j;
                let coeff = self.coeffs[i] * other.coeffs[j];
                if idx < self.n {
                    result[idx] += coeff;
                } else {
                    result[idx - self.n] -= coeff;
                }
            }
        }
        let mut elem = RingElement {
            coeffs: result,
            n: self.n,
            q: self.q,
        };
        elem.reduce();
        elem
    }

    /// Coefficient-wise sum, reduced into `[0, q)`.
    ///
    /// # Panics
    /// Panics if the operands have different `n`.
    pub fn add(&self, other: &RingElement) -> RingElement {
        assert_eq!(self.n, other.n);
        let coeffs: Vec<i64> = self
            .coeffs
            .iter()
            .zip(other.coeffs.iter())
            .map(|(a, b)| (a + b) % self.q as i64)
            .collect();
        let mut elem = RingElement {
            coeffs,
            n: self.n,
            q: self.q,
        };
        elem.reduce();
        elem
    }

    /// Coefficient-wise difference, reduced into `[0, q)`.
    ///
    /// # Panics
    /// Panics if the operands have different `n`.
    pub fn sub(&self, other: &RingElement) -> RingElement {
        assert_eq!(self.n, other.n);
        let coeffs: Vec<i64> = self
            .coeffs
            .iter()
            .zip(other.coeffs.iter())
            .map(|(a, b)| (a - b) % self.q as i64)
            .collect();
        let mut elem = RingElement {
            coeffs,
            n: self.n,
            q: self.q,
        };
        elem.reduce();
        elem
    }

    /// Multiplies every coefficient by `scalar` and reduces into `[0, q)`.
    ///
    /// The product is formed in `i128`, so large scalars cannot overflow. For moduli below
    /// `2^63` the result is the same as exact arithmetic modulo `q`.
    pub fn scale(&self, scalar: i64) -> RingElement {
        let q = i128::from(self.q);
        let coeffs: Vec<i64> = self
            .coeffs
            .iter()
            .map(|&c| ((i128::from(c) * i128::from(scalar)) % q) as i64)
            .collect();
        let mut elem = RingElement {
            coeffs,
            n: self.n,
            q: self.q,
        };
        elem.reduce();
        elem
    }
}
/// Key-encapsulation mechanism used by `QuantumKeyEvolution`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
pub enum KemAlgorithm {
    /// The in-file ring-LWE scheme built on `RingElement`.
    Rlwe,
    /// ML-KEM-512.
    MlKem512,
    /// ML-KEM-768 (the default).
    #[default]
    MlKem768,
    /// ML-KEM-1024.
    MlKem1024,
    /// RLWE plus ML-KEM-768, with the two shared secrets combined through HKDF-SHA256.
    Hybrid,
}
/// Encoded ML-KEM key pair for one parameter set.
///
/// Both key buffers are wiped on drop; `algorithm` is a plain tag and is skipped. `Debug` is
/// implemented by hand so the decapsulation (secret) key never reaches logs or panic
/// messages.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
struct MlKemKeyPair {
    /// Encoded decapsulation (secret) key.
    dk_bytes: Vec<u8>,
    /// Encoded encapsulation (public) key.
    ek_bytes: Vec<u8>,
    /// Parameter set the keys belong to.
    #[zeroize(skip)]
    algorithm: KemAlgorithm,
}

impl std::fmt::Debug for MlKemKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlKemKeyPair")
            .field("dk_bytes", &"<redacted>")
            .field("ek_bytes_len", &self.ek_bytes.len())
            .field("algorithm", &self.algorithm)
            .finish()
    }
}
/// Thin wrappers over the `ml_kem` crate for the three ML-KEM parameter sets.
///
/// Keys are stored as encoded byte vectors in `MlKemKeyPair` and re-decoded on every use.
/// The encapsulate/decapsulate helpers return `None` when the supplied bytes cannot be
/// converted to the expected encoded key or ciphertext type, or when the underlying
/// operation returns an error.
mod mlkem_ops {
    use super::*;

    /// Generates an ML-KEM-512 key pair from `rng` and stores both keys as encoded bytes.
    pub fn generate_512(rng: &mut StdRng) -> MlKemKeyPair {
        let (dk, ek) = MlKem512::generate(rng);
        MlKemKeyPair {
            dk_bytes: dk.as_bytes().to_vec(),
            ek_bytes: ek.as_bytes().to_vec(),
            algorithm: KemAlgorithm::MlKem512,
        }
    }

    /// Generates an ML-KEM-768 key pair from `rng` and stores both keys as encoded bytes.
    pub fn generate_768(rng: &mut StdRng) -> MlKemKeyPair {
        let (dk, ek) = MlKem768::generate(rng);
        MlKemKeyPair {
            dk_bytes: dk.as_bytes().to_vec(),
            ek_bytes: ek.as_bytes().to_vec(),
            algorithm: KemAlgorithm::MlKem768,
        }
    }

    /// Generates an ML-KEM-1024 key pair from `rng` and stores both keys as encoded bytes.
    pub fn generate_1024(rng: &mut StdRng) -> MlKemKeyPair {
        let (dk, ek) = MlKem1024::generate(rng);
        MlKemKeyPair {
            dk_bytes: dk.as_bytes().to_vec(),
            ek_bytes: ek.as_bytes().to_vec(),
            algorithm: KemAlgorithm::MlKem1024,
        }
    }

    /// Encapsulates against an encoded ML-KEM-512 encapsulation key.
    ///
    /// Returns `(ciphertext, 32-byte shared secret)`.
    pub fn encapsulate_512(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
        let ek_encoded = <Encoded<EncapsulationKey<MlKem512Params>>>::try_from(ek_bytes).ok()?;
        let ek = EncapsulationKey::<MlKem512Params>::from_bytes(&ek_encoded);
        let (ct, ss) = ek.encapsulate(rng).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some((ct.to_vec(), shared))
    }

    /// Encapsulates against an encoded ML-KEM-768 encapsulation key.
    ///
    /// Returns `(ciphertext, 32-byte shared secret)`.
    pub fn encapsulate_768(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
        let ek_encoded = <Encoded<EncapsulationKey<MlKem768Params>>>::try_from(ek_bytes).ok()?;
        let ek = EncapsulationKey::<MlKem768Params>::from_bytes(&ek_encoded);
        let (ct, ss) = ek.encapsulate(rng).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some((ct.to_vec(), shared))
    }

    /// Encapsulates against an encoded ML-KEM-1024 encapsulation key.
    ///
    /// Returns `(ciphertext, 32-byte shared secret)`.
    pub fn encapsulate_1024(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
        let ek_encoded = <Encoded<EncapsulationKey<MlKem1024Params>>>::try_from(ek_bytes).ok()?;
        let ek = EncapsulationKey::<MlKem1024Params>::from_bytes(&ek_encoded);
        let (ct, ss) = ek.encapsulate(rng).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some((ct.to_vec(), shared))
    }

    /// Decapsulates an ML-KEM-512 ciphertext with an encoded decapsulation key.
    pub fn decapsulate_512(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
        let dk_encoded = <Encoded<DecapsulationKey<MlKem512Params>>>::try_from(dk_bytes).ok()?;
        let dk = DecapsulationKey::<MlKem512Params>::from_bytes(&dk_encoded);
        let ct = <ml_kem::Ciphertext<MlKem512>>::try_from(ct_bytes).ok()?;
        let ss = dk.decapsulate(&ct).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some(shared)
    }

    /// Decapsulates an ML-KEM-768 ciphertext with an encoded decapsulation key.
    pub fn decapsulate_768(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
        let dk_encoded = <Encoded<DecapsulationKey<MlKem768Params>>>::try_from(dk_bytes).ok()?;
        let dk = DecapsulationKey::<MlKem768Params>::from_bytes(&dk_encoded);
        let ct = <ml_kem::Ciphertext<MlKem768>>::try_from(ct_bytes).ok()?;
        let ss = dk.decapsulate(&ct).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some(shared)
    }

    /// Decapsulates an ML-KEM-1024 ciphertext with an encoded decapsulation key.
    pub fn decapsulate_1024(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
        let dk_encoded = <Encoded<DecapsulationKey<MlKem1024Params>>>::try_from(dk_bytes).ok()?;
        let dk = DecapsulationKey::<MlKem1024Params>::from_bytes(&dk_encoded);
        let ct = <ml_kem::Ciphertext<MlKem1024>>::try_from(ct_bytes).ok()?;
        let ss = dk.decapsulate(&ct).ok()?;
        let mut shared = [0u8; 32];
        shared.copy_from_slice(ss.as_slice());
        Some(shared)
    }
}
/// RLWE key pair: public polynomial `a`, public key `b = a*s + e`, and ternary secret `s`.
///
/// All ring elements are wiped on drop; `params` is configuration and is skipped. `Debug`
/// is implemented by hand so the secret key is never printed.
///
/// NOTE(review): `Serialize` still emits `secret_key` in the clear — confirm serialized key
/// pairs only go to trusted storage.
#[derive(Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct QuantumKeyPair {
    /// Uniformly random public polynomial.
    pub a: RingElement,
    /// Public key `a*s + e`.
    pub public_key: RingElement,
    /// Ternary secret `s`.
    secret_key: RingElement,
    /// Parameters the pair was generated with.
    #[zeroize(skip)]
    params: LatticeParams,
}

impl std::fmt::Debug for QuantumKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QuantumKeyPair")
            .field("a", &self.a)
            .field("public_key", &self.public_key)
            .field("secret_key", &"<redacted>")
            .field("params", &self.params)
            .finish()
    }
}
/// Deterministic, hash-chained evolution of KEM key pairs.
///
/// Two instances created with the same parameters, seed and algorithm derive the same key
/// sequence, since `evolve` uses only the current key and counter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumKeyEvolution {
    /// Lattice parameters for the RLWE key pairs.
    params: LatticeParams,
    /// Current RLWE key pair.
    current_key: QuantumKeyPair,
    /// Number of completed `evolve` steps.
    evolution_counter: u64,
    /// Hashes returned by recent `evolve` calls, oldest first (at most `max_history`).
    key_history: VecDeque<[u8; 32]>,
    /// Capacity of `key_history`.
    max_history: usize,
    /// Encapsulation randomness. Not serialized; deserialized values get `default_rng()`.
    #[serde(skip, default = "default_rng")]
    rng: StdRng,
    /// KEM used by `encapsulate` / `decapsulate`.
    algorithm: KemAlgorithm,
    /// ML-KEM key pair for non-RLWE algorithms.
    ///
    /// Not serialized, so it is `None` after deserialization. NOTE(review): ML-KEM
    /// encapsulation then panics on its `expect` — confirm this is intended.
    #[serde(skip)]
    mlkem_keypair: Option<MlKemKeyPair>,
}

impl Drop for QuantumKeyEvolution {
    /// Wipes every stored evolution hash in place, then empties the history.
    fn drop(&mut self) {
        self.key_history
            .iter_mut()
            .for_each(|digest| digest.zeroize());
        self.key_history.clear();
    }
}
impl QuantumKeyEvolution {
pub fn new(params: LatticeParams, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let current_key = Self::generate_keypair(¶ms, &mut rng);
Self {
params,
current_key,
evolution_counter: 0,
key_history: VecDeque::new(),
max_history: 100,
rng,
algorithm: KemAlgorithm::Rlwe,
mlkem_keypair: None,
}
}
pub fn new_with_algorithm(params: LatticeParams, seed: u64, algorithm: KemAlgorithm) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let current_key = Self::generate_keypair(¶ms, &mut rng);
let mlkem_keypair = match algorithm {
KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut rng)),
KemAlgorithm::MlKem768 => Some(mlkem_ops::generate_768(&mut rng)),
KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut rng)),
KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut rng)),
KemAlgorithm::Rlwe => None,
};
Self {
params,
current_key,
evolution_counter: 0,
key_history: VecDeque::new(),
max_history: 100,
rng,
algorithm,
mlkem_keypair,
}
}
fn generate_keypair(params: &LatticeParams, rng: &mut StdRng) -> QuantumKeyPair {
let a = RingElement::random(params.n, params.q, rng);
let s = RingElement::random_ternary(params.n, params.q, rng);
let e = RingElement::random_gaussian(params.n, params.q, params.sigma, rng);
let b = a.mul(&s).add(&e);
QuantumKeyPair {
a, public_key: b,
secret_key: s,
params: params.clone(),
}
}
pub fn evolve(&mut self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(self.current_key.public_key.to_bytes());
hasher.update(self.current_key.secret_key.to_bytes());
hasher.update(self.evolution_counter.to_le_bytes());
let hash: [u8; 32] = hasher.finalize().into();
self.key_history.push_back(hash);
if self.key_history.len() > self.max_history {
self.key_history.pop_front();
}
let hk = Hkdf::<Sha256>::new(Some(&self.evolution_counter.to_le_bytes()), &hash);
let mut okm = [0u8; 32];
hk.expand(b"spine-key-evolution", &mut okm)
.expect("HKDF expand failed");
let new_seed = u64::from_le_bytes(okm[0..8].try_into().unwrap());
let mut new_rng = StdRng::seed_from_u64(new_seed);
self.current_key = Self::generate_keypair(&self.params, &mut new_rng);
if self.algorithm != KemAlgorithm::Rlwe {
self.mlkem_keypair = match self.algorithm {
KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut new_rng)),
KemAlgorithm::MlKem768 | KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut new_rng)),
KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut new_rng)),
KemAlgorithm::Rlwe => None,
};
}
self.evolution_counter += 1;
hash
}
pub fn encapsulate(&mut self) -> (Vec<u8>, [u8; 32]) {
match self.algorithm {
KemAlgorithm::Rlwe => self.encapsulate_rlwe(),
KemAlgorithm::MlKem512 => self.encapsulate_mlkem(KemAlgorithm::MlKem512),
KemAlgorithm::MlKem768 => self.encapsulate_mlkem(KemAlgorithm::MlKem768),
KemAlgorithm::MlKem1024 => self.encapsulate_mlkem(KemAlgorithm::MlKem1024),
KemAlgorithm::Hybrid => self.encapsulate_hybrid(),
}
}
fn encapsulate_mlkem(&mut self, alg: KemAlgorithm) -> (Vec<u8>, [u8; 32]) {
let kp = self.mlkem_keypair.as_ref().expect("ML-KEM keypair required");
let result = match alg {
KemAlgorithm::MlKem512 => mlkem_ops::encapsulate_512(&kp.ek_bytes, &mut self.rng),
KemAlgorithm::MlKem768 => mlkem_ops::encapsulate_768(&kp.ek_bytes, &mut self.rng),
KemAlgorithm::MlKem1024 => mlkem_ops::encapsulate_1024(&kp.ek_bytes, &mut self.rng),
_ => unreachable!(),
};
result.unwrap_or_else(|| {
self.encapsulate_rlwe()
})
}
fn encapsulate_hybrid(&mut self) -> (Vec<u8>, [u8; 32]) {
let (rlwe_ct, rlwe_ss) = self.encapsulate_rlwe();
let (mlkem_ct, mlkem_ss) = self.encapsulate_mlkem(KemAlgorithm::MlKem768);
let mut combined_ikm = [0u8; 64];
combined_ikm[..32].copy_from_slice(&rlwe_ss);
combined_ikm[32..].copy_from_slice(&mlkem_ss);
let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
let mut hybrid_ss = [0u8; 32];
hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
let rlwe_len = (rlwe_ct.len() as u32).to_le_bytes();
let mut hybrid_ct = Vec::with_capacity(4 + rlwe_ct.len() + mlkem_ct.len());
hybrid_ct.extend_from_slice(&rlwe_len);
hybrid_ct.extend_from_slice(&rlwe_ct);
hybrid_ct.extend_from_slice(&mlkem_ct);
(hybrid_ct, hybrid_ss)
}
fn encapsulate_rlwe(&mut self) -> (Vec<u8>, [u8; 32]) {
let a = &self.current_key.a;
let r = RingElement::random_ternary(self.params.n, self.params.q, &mut self.rng);
let e1 = RingElement::random_gaussian(
self.params.n,
self.params.q,
self.params.sigma,
&mut self.rng,
);
let e2 = RingElement::random_gaussian(
self.params.n,
self.params.q,
self.params.sigma,
&mut self.rng,
);
let m: Vec<i64> = (0..self.params.n)
.map(|_| self.rng.gen_range(0..2i64))
.collect();
let u = a.mul(&r).add(&e1);
let half_q = (self.params.q / 2) as i64;
let encoded_m = RingElement {
coeffs: m.iter().map(|&mi| mi * half_q).collect(),
n: self.params.n,
q: self.params.q,
};
let v = self.current_key.public_key.mul(&r).add(&e2).add(&encoded_m);
let mut ciphertext = u.to_bytes();
ciphertext.extend(v.to_bytes());
let mut hasher = Sha256::new();
for &mi in &m {
hasher.update(mi.to_le_bytes());
}
let shared_secret: [u8; 32] = hasher.finalize().into();
(ciphertext, shared_secret)
}
pub fn decapsulate(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
match self.algorithm {
KemAlgorithm::Rlwe => self.decapsulate_rlwe(ciphertext),
KemAlgorithm::MlKem512 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem512),
KemAlgorithm::MlKem768 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem768),
KemAlgorithm::MlKem1024 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem1024),
KemAlgorithm::Hybrid => self.decapsulate_hybrid(ciphertext),
}
}
fn decapsulate_mlkem(&self, ciphertext: &[u8], alg: KemAlgorithm) -> Option<[u8; 32]> {
let kp = self.mlkem_keypair.as_ref()?;
match alg {
KemAlgorithm::MlKem512 => mlkem_ops::decapsulate_512(&kp.dk_bytes, ciphertext),
KemAlgorithm::MlKem768 => mlkem_ops::decapsulate_768(&kp.dk_bytes, ciphertext),
KemAlgorithm::MlKem1024 => mlkem_ops::decapsulate_1024(&kp.dk_bytes, ciphertext),
_ => None,
}
}
fn decapsulate_hybrid(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
if ciphertext.len() < 4 { return None; }
let rlwe_len = u32::from_le_bytes(ciphertext[..4].try_into().ok()?) as usize;
if ciphertext.len() < 4 + rlwe_len { return None; }
let rlwe_ct = &ciphertext[4..4+rlwe_len];
let mlkem_ct = &ciphertext[4+rlwe_len..];
let rlwe_ss = self.decapsulate_rlwe(rlwe_ct)?;
let mlkem_ss = self.decapsulate_mlkem(mlkem_ct, KemAlgorithm::MlKem768)?;
let mut combined_ikm = [0u8; 64];
combined_ikm[..32].copy_from_slice(&rlwe_ss);
combined_ikm[32..].copy_from_slice(&mlkem_ss);
let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
let mut hybrid_ss = [0u8; 32];
hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
Some(hybrid_ss)
}
fn decapsulate_rlwe(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
let half = ciphertext.len() / 2;
if half < self.params.n * 2 {
return None;
}
let u = RingElement::from_bytes(&ciphertext[..half], self.params.n, self.params.q);
let v = RingElement::from_bytes(&ciphertext[half..], self.params.n, self.params.q);
let recovered = v.sub(&u.mul(&self.current_key.secret_key));
let half_q = self.params.q as i64 / 2;
let quarter_q = self.params.q as i64 / 4;
let m: Vec<i64> = recovered
.coeffs
.iter()
.map(|&c| {
let c_pos =
((c % self.params.q as i64) + self.params.q as i64) % self.params.q as i64;
if (c_pos - half_q).abs() < quarter_q {
1i64
} else {
0i64
}
})
.collect();
let mut hasher = Sha256::new();
for &mi in &m {
hasher.update(mi.to_le_bytes());
}
Some(hasher.finalize().into())
}
pub fn get_key_hash(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(self.current_key.public_key.to_bytes());
hasher.finalize().into()
}
pub fn verify_evolution(&self, expected_hash: &[u8; 32]) -> bool {
self.key_history
.iter()
.any(|h| h.ct_eq(expected_hash).into())
}
pub fn get_evolution_counter(&self) -> u64 {
self.evolution_counter
}
pub fn export_public_key(&self) -> Vec<u8> {
self.current_key.public_key.to_bytes()
}
}
/// Speculative messaging endpoint.
///
/// Sends a hash-only confirmation when the local predictor reproduces the message, and
/// otherwise a KEM-protected AES-256-GCM payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumSpeculativeProtocol {
    /// Byte-level predictor used to decide whether a confirmation suffices.
    predictor: TransformerPredictor,
    /// Key evolution used for encapsulation and decapsulation.
    key_evolution: QuantumKeyEvolution,
    /// Minimum prediction similarity required to send a confirmation.
    prediction_threshold: f32,
    /// Keys evolve after every `evolution_interval` sent messages.
    evolution_interval: u64,
    /// Messages sent so far.
    message_count: u64,
}
impl QuantumSpeculativeProtocol {
pub fn new(
transformer_config: TransformerConfig,
lattice_params: LatticeParams,
seed: u64,
) -> Self {
Self {
predictor: TransformerPredictor::new(transformer_config),
key_evolution: QuantumKeyEvolution::new(lattice_params, seed),
prediction_threshold: 0.8,
evolution_interval: 10,
message_count: 0,
}
}
/// Creates an endpoint using `algorithm` for key encapsulation.
///
/// Uses a prediction threshold of 0.8 and evolves keys every 10 messages.
pub fn new_with_algorithm(
    transformer_config: TransformerConfig,
    lattice_params: LatticeParams,
    seed: u64,
    algorithm: KemAlgorithm,
) -> Self {
    let predictor = TransformerPredictor::new(transformer_config);
    let key_evolution =
        QuantumKeyEvolution::new_with_algorithm(lattice_params, seed, algorithm);
    Self {
        predictor,
        key_evolution,
        prediction_threshold: 0.8,
        evolution_interval: 10,
        message_count: 0,
    }
}
pub fn algorithm(&self) -> KemAlgorithm {
self.key_evolution.algorithm
}
/// Sends `message` as either a hash-only confirmation or an encrypted payload.
///
/// If the local predictor reproduces `message` exactly and the similarity meets
/// `prediction_threshold`, only the message's SHA-256 hash and length are sent.
///
/// Otherwise the message is sent encrypted:
/// - A fresh KEM encapsulation is made.
/// - An AES-256-GCM key is derived from the shared secret with HKDF-SHA256 (info
///   `"spine-aead-key"`).
/// - The message is encrypted under a 12-byte nonce whose first 8 bytes are this
///   message's little-endian sequence number.
///
/// Every `evolution_interval`-th message advances the key evolution. That step runs before
/// encapsulation, so the payload is always encrypted under the key identified by the
/// returned `evolution_counter`. The receiver catches up to that counter before
/// decapsulating, so encrypting under the pre-evolution key would make these messages
/// undecryptable.
///
/// NOTE(review): `verify_prediction` appends its *predicted* bytes to the predictor's
/// context, and the actual `message` is never observed here, whereas `receive` observes
/// recovered messages. Confirm the two peers' predictors are meant to stay in sync.
pub fn send(&mut self, message: &[u8]) -> QuantumMessage {
    let (matches, similarity) = self.predictor.verify_prediction(message);

    // Sequence number of this message; used as the nonce prefix.
    let sequence = self.message_count;
    self.message_count += 1;

    // Evolve first so the encapsulation below uses the key matching the advertised counter.
    let key_evolution = if self.message_count.is_multiple_of(self.evolution_interval) {
        Some(self.key_evolution.evolve())
    } else {
        None
    };

    let payload = if matches && similarity >= self.prediction_threshold {
        MessagePayload::Confirmation {
            hash: Self::hash_message(message),
            length: message.len(),
        }
    } else {
        let (ciphertext, shared_secret) = self.key_evolution.encapsulate();
        let hk = Hkdf::<Sha256>::new(None, &shared_secret);
        let mut aes_key = [0u8; 32];
        hk.expand(b"spine-aead-key", &mut aes_key)
            .expect("HKDF expand failed");
        let mut nonce_bytes = [0u8; 12];
        nonce_bytes[..8].copy_from_slice(&sequence.to_le_bytes());
        let nonce = Nonce::from_slice(&nonce_bytes);
        let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
        let encrypted = cipher.encrypt(nonce, message).expect("AES-GCM encrypt");
        let mut encrypted_message = nonce_bytes.to_vec();
        encrypted_message.extend(encrypted);
        MessagePayload::Full {
            ciphertext,
            encrypted_message,
        }
    };

    QuantumMessage {
        payload,
        evolution_counter: self.key_evolution.get_evolution_counter(),
        key_evolution,
    }
}
pub fn get_morph_seed(&self) -> u64 {
let key_hash = self.key_evolution.get_key_hash();
u64::from_le_bytes(key_hash[0..8].try_into().unwrap())
}
pub fn receive(&mut self, quantum_msg: &QuantumMessage) -> Option<Vec<u8>> {
while self.key_evolution.get_evolution_counter() < quantum_msg.evolution_counter {
self.key_evolution.evolve();
}
let message = match &quantum_msg.payload {
MessagePayload::Confirmation { hash, length } => {
let predicted = self.predictor.predict_sequence(*length, true);
let predicted_hash = Self::hash_message(&predicted);
if &predicted_hash == hash {
Some(predicted)
} else {
None }
}
MessagePayload::Full {
ciphertext,
encrypted_message,
} => {
let shared_secret = self.key_evolution.decapsulate(ciphertext)?;
let hk = Hkdf::<Sha256>::new(None, &shared_secret);
let mut aes_key = [0u8; 32];
hk.expand(b"spine-aead-key", &mut aes_key)
.expect("HKDF expand failed");
if encrypted_message.len() < 12 {
return None;
}
let nonce = Nonce::from_slice(&encrypted_message[..12]);
let ciphertext_data = &encrypted_message[12..];
let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
cipher.decrypt(nonce, ciphertext_data).ok()
}
};
if let Some(ref msg) = message {
self.predictor.observe(msg);
}
message
}
fn hash_message(message: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(message);
hasher.finalize().into()
}
pub fn set_threshold(&mut self, threshold: f32) {
self.prediction_threshold = threshold.clamp(0.0, 1.0);
}
pub fn set_evolution_interval(&mut self, interval: u64) {
self.evolution_interval = interval.max(1);
}
pub fn reset(&mut self) {
self.predictor.reset();
self.message_count = 0;
}
}
/// Wire unit produced by `QuantumSpeculativeProtocol::send`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumMessage {
    // Confirmation digest or encrypted body.
    pub payload: MessagePayload,
    // Sender's key-evolution counter, read *after* any evolution this send triggered.
    // `send` seals the payload before evolving, so when `key_evolution` is `Some`
    // the payload was sealed under the epoch before this counter.
    pub evolution_counter: u64,
    // Value returned by `QuantumKeyEvolution::evolve` when this send triggered an
    // evolution; `None` otherwise.
    pub key_evolution: Option<[u8; 32]>,
}
/// Body of a `QuantumMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessagePayload {
    /// Sent when the sender's predictor already anticipated the message. The
    /// receiver regenerates `length` bytes from its own predictor and accepts them
    /// only if their SHA-256 digest equals `hash`.
    ///
    /// NOTE(review): the digest is neither encrypted nor authenticated, so it
    /// reveals guessable (low-entropy) plaintexts to an eavesdropper.
    Confirmation { hash: [u8; 32], length: usize },
    /// Encrypted message.
    Full {
        // KEM ciphertext returned by `QuantumKeyEvolution::encapsulate`.
        ciphertext: Vec<u8>,
        // 12-byte AES-GCM nonce followed by the AES-256-GCM `encrypt` output.
        encrypted_message: Vec<u8>,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds `LatticeParams` with the plaintext modulus `p = 3` used throughout.
    // A macro rather than a fn lets each literal adopt the field's declared type.
    macro_rules! lattice {
        ($n:expr, $q:expr, $sigma:expr) => {
            LatticeParams {
                n: $n,
                q: $q,
                p: 3,
                sigma: $sigma,
            }
        };
    }

    /// 32-dimensional Titans configuration shared by the predictor-level tests.
    fn titans_32() -> TitansConfig {
        TitansConfig {
            embed_dim: 32,
            num_heads: 2,
            num_layers: 1,
            ff_dim: 64,
            max_seq_len: 64,
            memory_size: 16,
            seed: 42,
        }
    }

    /// Smaller 16-dimensional configuration used by the end-to-end protocol tests.
    fn titans_16() -> TitansConfig {
        TitansConfig {
            embed_dim: 16,
            num_heads: 2,
            num_layers: 1,
            ff_dim: 32,
            max_seq_len: 32,
            memory_size: 8,
            seed: 42,
        }
    }

    /// Compile-time check that `T` implements `ZeroizeOnDrop`.
    fn assert_zeroize_on_drop<T: ZeroizeOnDrop>() {}

    #[test]
    fn ringelement_implements_zeroize_on_drop() {
        assert_zeroize_on_drop::<RingElement>();
    }

    #[test]
    fn mlkemkeypair_implements_zeroize_on_drop() {
        assert_zeroize_on_drop::<MlKemKeyPair>();
    }

    #[test]
    fn quantumkeypair_implements_zeroize_on_drop() {
        assert_zeroize_on_drop::<QuantumKeyPair>();
    }

    #[test]
    fn ringelement_zeroize_clears_all_coefficients() {
        let mut rng = StdRng::seed_from_u64(0xAB_CD);
        let mut elem = RingElement::random(64, 8_192, &mut rng);
        let has_nonzero = elem.coeffs.iter().any(|&c| c != 0);
        assert!(
            has_nonzero,
            "test precondition: random RingElement should have non-zero coeffs"
        );
        elem.zeroize();
        assert!(
            !elem.coeffs.iter().any(|&c| c != 0),
            "RingElement::zeroize did not clear every coefficient"
        );
    }

    #[test]
    fn mlkemkeypair_zeroize_clears_dk_bytes() {
        let mut rng = StdRng::seed_from_u64(0x12_34);
        let mut pair = mlkem_ops::generate_768(&mut rng);
        let has_nonzero = pair.dk_bytes.iter().any(|&b| b != 0);
        assert!(
            has_nonzero,
            "test precondition: fresh ML-KEM dk should be non-zero"
        );
        pair.zeroize();
        assert!(
            !pair.dk_bytes.iter().any(|&b| b != 0),
            "MlKemKeyPair::zeroize left non-zero bytes in dk_bytes"
        );
    }

    // Exercises `zeroize` on the history entries directly (the type's Drop is not invoked here).
    #[test]
    fn quantumkeyevolution_drop_clears_key_history() {
        let mut evolution = QuantumKeyEvolution::new(LatticeParams::default(), 0xCAFE);
        for fill in [0x11u8, 0x22u8] {
            evolution.key_history.push_back([fill; 32]);
        }
        assert_eq!(evolution.key_history.len(), 2);
        for entry in evolution.key_history.iter_mut() {
            entry.zeroize();
        }
        let all_clear = evolution
            .key_history
            .iter()
            .all(|entry| entry.iter().all(|&b| b == 0));
        assert!(all_clear, "QuantumKeyEvolution key_history not zeroed");
    }

    #[test]
    fn test_positional_encoding() {
        let encoding = PositionalEncoding::new(100, 64);
        let (first, middle) = (encoding.get(0), encoding.get(50));
        assert_eq!(first.len(), 64);
        assert_ne!(first, middle);
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(8);
        let input: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let normalized = norm.forward(&input);
        assert_eq!(normalized.len(), 8);
        let mean = normalized.iter().sum::<f32>() / normalized.len() as f32;
        assert!(mean.abs() < 0.01);
    }

    #[test]
    fn test_titans_predictor() {
        let mut predictor = TitansPredictor::new(titans_32());
        predictor.observe(b"Hello ");
        predictor.observe(b"World");
        let (_, confidence) = predictor.predict_next();
        assert!(confidence > 0.0 && confidence <= 1.0);
        assert!(predictor.get_surprise() >= 0.0);
    }

    #[test]
    fn test_titans_anomaly_detection() {
        let mut predictor = TitansPredictor::new(titans_32());
        for _ in 0..10 {
            predictor.observe(b"GET /api/status\n");
        }
        let _baseline = predictor.get_surprise();
        predictor.observe(b"MALICIOUS_PAYLOAD_XYZ!!!");
        assert!(predictor.get_surprise() >= 0.0);
    }

    #[test]
    fn test_ring_operations() {
        let mut rng = StdRng::seed_from_u64(42);
        let params = lattice!(16, 97, 2.0);
        let lhs = RingElement::random(params.n, params.q, &mut rng);
        let rhs = RingElement::random(params.n, params.q, &mut rng);
        let (sum, product) = (lhs.add(&rhs), lhs.mul(&rhs));
        assert_eq!(sum.coeffs.len(), params.n);
        assert_eq!(product.coeffs.len(), params.n);
        assert!(sum
            .coeffs
            .iter()
            .all(|c| (0..params.q as i64).contains(c)));
    }

    #[test]
    fn test_key_evolution() {
        let mut ke = QuantumKeyEvolution::new(lattice!(32, 257, 2.0), 42);
        let before = ke.get_key_hash();
        ke.evolve();
        assert_ne!(before, ke.get_key_hash());
        assert_eq!(ke.get_evolution_counter(), 1);
    }

    #[test]
    fn test_encapsulation() {
        let mut ke = QuantumKeyEvolution::new(lattice!(32, 257, 2.0), 42);
        let (ct, _secret) = ke.encapsulate();
        assert!(!ct.is_empty());
        assert!(ke.decapsulate(&ct).is_some());
    }

    #[test]
    fn test_quantum_speculative_protocol() {
        let mut alice = QuantumSpeculativeProtocol::new(titans_16(), lattice!(16, 97, 2.0), 42);
        let mut bob = QuantumSpeculativeProtocol::new(titans_16(), lattice!(16, 97, 2.0), 42);
        let msg = b"Hello Bob!";
        let wire = alice.send(msg);
        assert_eq!(bob.receive(&wire), Some(msg.to_vec()));
    }

    #[test]
    fn test_prediction_efficiency() {
        let config = TransformerConfig::default();
        let params = LatticeParams::default();
        let mut sender = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
        let mut receiver = QuantumSpeculativeProtocol::new(config, params, 42);
        let exchange: [&[u8]; 2] = [b"GET /api/status", b"200 OK"];
        for _ in 0..5 {
            for payload in exchange {
                let wire = sender.send(payload);
                receiver.receive(&wire);
            }
        }
        let last = sender.send(b"GET /api/status");
        assert!(receiver.receive(&last).is_some());
    }

    #[test]
    fn test_miras_predictor_basic() {
        let mut predictor = MirasTitansPredictor::new(titans_32());
        predictor.observe(b"Hello World");
        assert_eq!(predictor.variant(), "titans");
        let (_, confidence) = predictor.predict_next();
        assert!(confidence > 0.0 && confidence <= 1.0);
        let stats = predictor.stats();
        assert_eq!(stats.message_count, 1);
        assert!(stats.miras_enhanced_predictions > 0);
    }

    #[test]
    fn test_miras_predictor_variants() {
        let cases = [
            (MirasVariant::Titans, "titans"),
            (MirasVariant::Yaad, "yaad"),
            (MirasVariant::Moneta { p: 2.0 }, "moneta"),
            (MirasVariant::Memora, "memora"),
        ];
        for (variant, expected_name) in cases {
            let predictor = MirasTitansPredictor::new_with_variant(titans_32(), variant);
            assert_eq!(predictor.variant(), expected_name);
        }
        let mut yaad = MirasTitansPredictor::new_with_variant(titans_32(), MirasVariant::Yaad);
        assert_eq!(yaad.variant(), "yaad");
        yaad.observe(b"test");
    }

    #[test]
    fn test_miras_predictor_combined_surprise() {
        let mut predictor = MirasTitansPredictor::new(titans_32());
        for _ in 0..5 {
            predictor.observe(b"normal message pattern");
        }
        assert!(predictor.get_combined_surprise() >= 0.0);
        let titans_surprise = predictor.get_surprise();
        let miras_surprise = predictor.get_miras_surprise();
        assert!(titans_surprise >= 0.0);
        assert!(miras_surprise.is_some());
    }

    #[test]
    fn test_miras_predictor_anomaly_level() {
        let mut predictor = MirasTitansPredictor::new(titans_32());
        assert_eq!(predictor.anomaly_level(), 0.0);
        predictor.observe(b"test");
        assert!(predictor.anomaly_level() >= 0.0);
    }

    #[test]
    fn test_miras_predictor_reset() {
        let mut predictor = MirasTitansPredictor::new(titans_32());
        for _ in 0..10 {
            predictor.observe(b"data");
        }
        assert!(predictor.stats().message_count > 0);
        predictor.reset_all();
        assert_eq!(predictor.stats().message_count, 0);
    }

    #[test]
    fn test_rlwe_ring_arithmetic_correctness() {
        let mut rng = StdRng::seed_from_u64(12345);
        let params = lattice!(32, 257, 2.0);
        let a = RingElement::random(params.n, params.q, &mut rng);
        let b = RingElement::random(params.n, params.q, &mut rng);
        let c = RingElement::random(params.n, params.q, &mut rng);
        assert_eq!(
            a.add(&b).coeffs,
            b.add(&a).coeffs,
            "Addition should be commutative"
        );
        assert_eq!(
            a.add(&b).add(&c).coeffs,
            a.add(&b.add(&c)).coeffs,
            "Addition should be associative"
        );
        let distributed_lhs = a.mul(&b.add(&c));
        let distributed_rhs = a.mul(&b).add(&a.mul(&c));
        assert_eq!(
            distributed_lhs.coeffs, distributed_rhs.coeffs,
            "Multiplication should distribute over addition"
        );
    }

    #[test]
    fn test_rlwe_gaussian_distribution() {
        let mut rng = StdRng::seed_from_u64(54321);
        let params = lattice!(1024, 12289, 3.2);
        let noise = RingElement::random_gaussian(params.n, params.q, params.sigma, &mut rng);
        let samples: Vec<f64> = noise.coeffs.iter().map(|&c| c as f64).collect();
        let count = params.n as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;
        assert!(
            mean.abs() < params.sigma,
            "Gaussian mean should be near 0, got {}",
            mean
        );
        let expected_variance = params.sigma * params.sigma;
        assert!(
            (variance - expected_variance).abs() < expected_variance * 0.5,
            "Variance {} should be close to sigma^2 = {}",
            variance,
            expected_variance
        );
    }

    #[test]
    fn test_rlwe_ternary_distribution() {
        let mut rng = StdRng::seed_from_u64(98765);
        let params = lattice!(256, 257, 2.0);
        let secret = RingElement::random_ternary(params.n, params.q, &mut rng);
        for &coeff in &secret.coeffs {
            assert!(
                (-1..=1).contains(&coeff),
                "Ternary coefficient should be -1, 0, or 1, got {}",
                coeff
            );
        }
        let tally = |target: i64| secret.coeffs.iter().filter(|&&c| c == target).count();
        let (zeros, ones, negatives) = (tally(0), tally(1), tally(-1));
        let expected = params.n / 3;
        let tolerance = params.n / 4;
        assert!(
            zeros.abs_diff(expected) < tolerance,
            "Ternary distribution unbalanced: zeros={}, ones={}, neg={}",
            zeros,
            ones,
            negatives
        );
    }

    #[test]
    fn test_key_evolution_forward_secrecy() {
        let mut leader = QuantumKeyEvolution::new(lattice!(64, 257, 2.0), 42);
        let mut follower = QuantumKeyEvolution::new(lattice!(64, 257, 2.0), 42);
        assert_eq!(leader.get_key_hash(), follower.get_key_hash());
        for _ in 0..5 {
            leader.evolve();
        }
        assert_ne!(leader.get_key_hash(), follower.get_key_hash());
        assert_eq!(leader.get_evolution_counter(), 5);
        assert_eq!(follower.get_evolution_counter(), 0);
        for _ in 0..5 {
            follower.evolve();
        }
        assert_eq!(leader.get_key_hash(), follower.get_key_hash());
    }

    #[test]
    fn test_key_evolution_history_integrity() {
        let mut ke = QuantumKeyEvolution::new(lattice!(32, 257, 2.0), 42);
        let hashes: Vec<_> = (0..10).map(|_| ke.evolve()).collect();
        let distinct: std::collections::HashSet<_> = hashes.iter().collect();
        assert_eq!(distinct.len(), 10, "All evolution hashes should be unique");
        for hash in &hashes {
            assert!(
                ke.verify_evolution(hash),
                "Recent evolution should be verifiable"
            );
        }
    }

    #[test]
    fn test_quantum_protocol_message_integrity() {
        let mut alice = QuantumSpeculativeProtocol::new(titans_16(), lattice!(32, 257, 2.0), 42);
        let mut bob = QuantumSpeculativeProtocol::new(titans_16(), lattice!(32, 257, 2.0), 42);
        let test_messages: [&[u8]; 4] = [
            b"A",
            b"Short",
            b"Medium length message",
            b"This is a longer message to test variable length handling properly",
        ];
        for msg in test_messages {
            let wire = alice.send(msg);
            let received = bob.receive(&wire);
            assert!(received.is_some(), "Should receive message");
            assert_eq!(
                received.unwrap(),
                msg,
                "Received message should match original"
            );
        }
    }

    #[test]
    fn test_tampered_ciphertext_detection() {
        let mut ke = QuantumKeyEvolution::new(lattice!(32, 257, 2.0), 42);
        let (mut ct, original_secret) = ke.encapsulate();
        if let Some(first) = ct.first_mut() {
            *first ^= 0xFF;
        }
        if let Some(tampered) = ke.decapsulate(&ct) {
            assert_ne!(
                tampered, original_secret,
                "Tampered ciphertext should produce different secret"
            );
        }
    }

    #[test]
    fn test_lattice_params_security_levels() {
        let toy = lattice!(16, 97, 2.0);
        let medium = lattice!(256, 7681, 3.19);
        let _high = lattice!(1024, 12289, 3.19);
        for params in [&toy, &medium] {
            assert!(
                params.n.is_power_of_two(),
                "n should be power of 2 for NTT"
            );
        }
        let mut ke_toy = QuantumKeyEvolution::new(toy, 1);
        let mut ke_med = QuantumKeyEvolution::new(medium, 1);
        let (ct_toy, _) = ke_toy.encapsulate();
        let (ct_med, _) = ke_med.encapsulate();
        assert!(!ct_toy.is_empty());
        assert!(!ct_med.is_empty());
        assert!(
            ct_med.len() > ct_toy.len(),
            "Higher security params should produce larger ciphertext"
        );
    }

    #[test]
    fn test_titans_predictor_statistical_properties() {
        let mut predictor = TitansPredictor::new(titans_32());
        for _ in 0..20 {
            predictor.observe(b"ABCABC");
        }
        let (_, confidence) = predictor.predict_next();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be normalized"
        );
        assert!(
            predictor.get_surprise() >= 0.0,
            "Surprise should be non-negative"
        );
    }

    #[test]
    fn test_kem_shared_secret_match() {
        let mut ke = QuantumKeyEvolution::new(lattice!(64, 257, 1.5), 12345);
        let (ct, encapsulated) = ke.encapsulate();
        let decapsulated = ke.decapsulate(&ct).unwrap();
        assert_eq!(
            encapsulated, decapsulated,
            "KEM shared secrets must match between encapsulate and decapsulate"
        );
    }

    #[test]
    fn test_aead_tampered_ciphertext_rejected() {
        let config = TransformerConfig::default();
        let mut alice =
            QuantumSpeculativeProtocol::new(config.clone(), lattice!(32, 257, 2.0), 42);
        let mut bob = QuantumSpeculativeProtocol::new(config, lattice!(32, 257, 2.0), 42);
        let mut wire = alice.send(b"Secret message");
        if let MessagePayload::Full {
            encrypted_message, ..
        } = &mut wire.payload
        {
            if let Some(last) = encrypted_message.last_mut() {
                *last ^= 0xFF;
            }
        }
        assert!(
            bob.receive(&wire).is_none(),
            "Tampered ciphertext must be rejected by AEAD"
        );
    }

    #[test]
    fn test_key_evolution_maintains_kem_invariant() {
        let mut ke = QuantumKeyEvolution::new(lattice!(32, 257, 2.0), 99);
        for _ in 0..5 {
            ke.evolve();
            let (ct, encapsulated) = ke.encapsulate();
            let decapsulated = ke.decapsulate(&ct).unwrap();
            assert_eq!(
                encapsulated, decapsulated,
                "KEM must work after key evolution"
            );
        }
    }

    #[test]
    fn test_key_evolution_deterministic_hkdf() {
        let params = LatticeParams::default();
        let mut first = QuantumKeyEvolution::new(params.clone(), 7777);
        let mut second = QuantumKeyEvolution::new(params, 7777);
        for _ in 0..5 {
            let (h_first, h_second) = (first.evolve(), second.evolve());
            assert_eq!(
                h_first, h_second,
                "Deterministic evolution must produce identical hashes"
            );
        }
        assert_eq!(first.get_key_hash(), second.get_key_hash());
    }

    #[test]
    fn test_aes_gcm_round_trip() {
        let config = TransformerConfig::default();
        let mut alice =
            QuantumSpeculativeProtocol::new(config.clone(), lattice!(64, 257, 1.5), 100);
        let mut bob = QuantumSpeculativeProtocol::new(config, lattice!(64, 257, 1.5), 100);
        for i in 0..5 {
            let msg = format!("Message number {}", i);
            let wire = alice.send(msg.as_bytes());
            let received = bob.receive(&wire);
            assert!(received.is_some(), "Message {} should decrypt", i);
            assert_eq!(
                received.unwrap(),
                msg.as_bytes(),
                "Message {} content mismatch",
                i
            );
        }
    }

    #[test]
    fn test_mlkem_512_round_trip() {
        let mut rng = StdRng::seed_from_u64(1);
        let pair = mlkem_ops::generate_512(&mut rng);
        assert_eq!(pair.algorithm, KemAlgorithm::MlKem512);
        let (ct, sent) = mlkem_ops::encapsulate_512(&pair.ek_bytes, &mut rng).unwrap();
        let recovered = mlkem_ops::decapsulate_512(&pair.dk_bytes, &ct).unwrap();
        assert_eq!(sent.len(), 32);
        assert_eq!(sent, recovered, "ML-KEM-512 shared secret mismatch");
    }

    #[test]
    fn test_mlkem_768_round_trip() {
        let mut rng = StdRng::seed_from_u64(2);
        let pair = mlkem_ops::generate_768(&mut rng);
        assert_eq!(pair.algorithm, KemAlgorithm::MlKem768);
        let (ct, sent) = mlkem_ops::encapsulate_768(&pair.ek_bytes, &mut rng).unwrap();
        let recovered = mlkem_ops::decapsulate_768(&pair.dk_bytes, &ct).unwrap();
        assert_eq!(sent.len(), 32);
        assert_eq!(sent, recovered, "ML-KEM-768 shared secret mismatch");
    }

    #[test]
    fn test_mlkem_1024_round_trip() {
        let mut rng = StdRng::seed_from_u64(3);
        let pair = mlkem_ops::generate_1024(&mut rng);
        assert_eq!(pair.algorithm, KemAlgorithm::MlKem1024);
        let (ct, sent) = mlkem_ops::encapsulate_1024(&pair.ek_bytes, &mut rng).unwrap();
        let recovered = mlkem_ops::decapsulate_1024(&pair.dk_bytes, &ct).unwrap();
        assert_eq!(sent.len(), 32);
        assert_eq!(sent, recovered, "ML-KEM-1024 shared secret mismatch");
    }

    #[test]
    fn test_mlkem_different_keypairs_produce_different_secrets() {
        let mut rng = StdRng::seed_from_u64(4);
        let first = mlkem_ops::generate_768(&mut rng);
        let second = mlkem_ops::generate_768(&mut rng);
        let (_, secret_first) = mlkem_ops::encapsulate_768(&first.ek_bytes, &mut rng).unwrap();
        let (_, secret_second) = mlkem_ops::encapsulate_768(&second.ek_bytes, &mut rng).unwrap();
        assert_ne!(
            secret_first, secret_second,
            "Different keypairs should yield different secrets"
        );
    }

    #[test]
    fn test_mlkem_wrong_key_decapsulation_fails() {
        let mut rng = StdRng::seed_from_u64(5);
        let intended = mlkem_ops::generate_768(&mut rng);
        let stranger = mlkem_ops::generate_768(&mut rng);
        let (ct, sent) = mlkem_ops::encapsulate_768(&intended.ek_bytes, &mut rng).unwrap();
        let mismatched = mlkem_ops::decapsulate_768(&stranger.dk_bytes, &ct).unwrap();
        assert_ne!(
            sent, mismatched,
            "Wrong DK must produce different shared secret (implicit reject)"
        );
    }

    #[test]
    fn test_kem_algorithm_default() {
        assert_eq!(KemAlgorithm::default(), KemAlgorithm::MlKem768);
    }

    #[test]
    fn test_quantum_key_evolution_with_mlkem() {
        let mut ke = QuantumKeyEvolution::new_with_algorithm(
            lattice!(32, 257, 2.0),
            42,
            KemAlgorithm::MlKem768,
        );
        let (ct, encapsulated) = ke.encapsulate();
        let decapsulated = ke.decapsulate(&ct).unwrap();
        assert_eq!(
            encapsulated, decapsulated,
            "ML-KEM encaps/decaps via QuantumKeyEvolution"
        );
        assert!(!ct.is_empty());
    }

    #[test]
    fn test_quantum_key_evolution_hybrid_kem() {
        let mut ke = QuantumKeyEvolution::new_with_algorithm(
            lattice!(32, 257, 2.0),
            42,
            KemAlgorithm::Hybrid,
        );
        let (ct, encapsulated) = ke.encapsulate();
        let decapsulated = ke.decapsulate(&ct).unwrap();
        assert_eq!(
            encapsulated, decapsulated,
            "Hybrid RLWE+ML-KEM shared secret mismatch"
        );
        assert_eq!(
            encapsulated.len(),
            32,
            "Hybrid shared secret should be 32 bytes"
        );
        assert!(ct.len() > 100, "Hybrid ciphertext should be large");
    }

    #[test]
    fn test_mlkem_key_evolution_maintains_invariant() {
        let mut ke = QuantumKeyEvolution::new_with_algorithm(
            lattice!(32, 257, 2.0),
            55,
            KemAlgorithm::MlKem768,
        );
        for step in 0..5 {
            ke.evolve();
            let (ct, encapsulated) = ke.encapsulate();
            let decapsulated = ke.decapsulate(&ct).unwrap();
            assert_eq!(
                encapsulated, decapsulated,
                "ML-KEM must work after evolution step {}",
                step
            );
        }
    }

    #[test]
    fn test_quantum_speculative_protocol_with_mlkem() {
        let config = TransformerConfig::default();
        let build = |cfg| {
            QuantumSpeculativeProtocol::new_with_algorithm(
                cfg,
                lattice!(32, 257, 2.0),
                42,
                KemAlgorithm::MlKem768,
            )
        };
        let mut alice = build(config.clone());
        let mut bob = build(config);
        assert_eq!(alice.algorithm(), KemAlgorithm::MlKem768);
        assert_eq!(bob.algorithm(), KemAlgorithm::MlKem768);
        let msg = b"ML-KEM secured message";
        let wire = alice.send(msg);
        let received = bob.receive(&wire);
        assert!(received.is_some());
        assert_eq!(received.unwrap(), msg);
    }

    #[test]
    fn test_mlkem_ciphertext_sizes() {
        let mut rng = StdRng::seed_from_u64(6);
        let pair_512 = mlkem_ops::generate_512(&mut rng);
        let pair_768 = mlkem_ops::generate_768(&mut rng);
        let pair_1024 = mlkem_ops::generate_1024(&mut rng);
        let (ct_512, _) = mlkem_ops::encapsulate_512(&pair_512.ek_bytes, &mut rng).unwrap();
        let (ct_768, _) = mlkem_ops::encapsulate_768(&pair_768.ek_bytes, &mut rng).unwrap();
        let (ct_1024, _) = mlkem_ops::encapsulate_1024(&pair_1024.ek_bytes, &mut rng).unwrap();
        assert_eq!(ct_512.len(), 768, "ML-KEM-512 ciphertext should be 768 bytes");
        assert_eq!(ct_768.len(), 1088, "ML-KEM-768 ciphertext should be 1088 bytes");
        assert_eq!(ct_1024.len(), 1568, "ML-KEM-1024 ciphertext should be 1568 bytes");
        assert!(ct_512.len() < ct_768.len());
        assert!(ct_768.len() < ct_1024.len());
    }
}