Skip to main content

zap/
consensus.rs

1//! Ringtail Consensus Integration for ZAP
2//!
3//! Implements threshold lattice-based signing compatible with the Ringtail protocol.
4//! Uses ML-DSA (FIPS 204) lattice cryptography for post-quantum security.
5//!
6//! # Protocol Overview
7//!
8//! Ringtail is a threshold signature scheme based on lattice cryptography:
9//! - Round 1: Parties generate commitment matrices D and MACs
10//! - Round 2: Verify MACs, compute response shares z_i
11//! - Finalize: Combiner aggregates shares into final signature (c, z, Delta)
12//!
//! # Example
//!
//! ```rust,ignore
//! use zap::consensus::{RingtailConsensus, AgentConsensus};
//!
//! // Create the threshold signing party that also acts as combiner
//! // (`finalize` is only permitted for party `COMBINER_ID` = 1).
//! let mut party = RingtailConsensus::new(1, 3, 2); // party 1 of 3, threshold 2
//! // Dealer-provided material must be installed before signing:
//! // set_sk_share, set_mac_keys, set_public_params and set_lambda.
//! party.connect_peers(vec!["peer0:9999".into(), "peer2:9999".into()]).await?;
//!
//! // Sign a message
//! let round1 = party.sign_round1(b"message").await?;
//! // ... exchange round 1 outputs with the other parties ...
//! let round2 = party.sign_round2(vec![peer0_r1, round1, peer2_r1]).await?;
//! // ... collect the other parties' round 2 outputs, then finalize ...
//! let sig = party.finalize(vec![peer0_r2, round2, peer2_r2]).await?;
//! ```
29
30use crate::error::{Error, Result};
31use std::collections::HashMap;
32use std::sync::Arc;
33use tokio::sync::RwLock;
34
// Ringtail protocol parameters (mirrors upstream sign/config.go).

/// Matrix dimension M (rows of the public matrix A and of commitments D).
pub const M: usize = 8;
/// Matrix dimension N (columns of A; length of secret-key shares and responses).
pub const N: usize = 7;
/// Commitment dimension Dbar (commitment matrices have Dbar + 1 columns).
pub const DBAR: usize = 48;
/// Challenge weight (Hamming weight of challenge polynomial).
pub const KAPPA: usize = 23;
/// Log2 of the ring dimension.
pub const LOG_N: usize = 8;
/// Ring dimension (2^LOG_N = 256).
pub const PHI: usize = 1 << LOG_N;
/// Key, MAC and seed size in bytes (256 bits).
pub const KEY_SIZE: usize = 32;
/// Modulus Q = 2^48 + 0x4A01 (a 49-bit value).
///
/// Q ≡ 1 (mod 2·PHI), the condition for a negacyclic NTT over
/// Z_Q[X]/(X^PHI + 1); primality is taken from the upstream configuration.
pub const Q: u64 = 0x1000000004A01;
/// Rounding parameter Xi.
pub const XI: u32 = 30;
/// Rounding parameter Nu.
pub const NU: u32 = 29;
/// Default signing threshold (2-of-3 together with [`DEFAULT_PARTIES`]).
pub const DEFAULT_THRESHOLD: usize = 2;
/// Default number of parties.
pub const DEFAULT_PARTIES: usize = 3;
/// Party ID of the combiner, the only party allowed to finalize signatures.
pub const COMBINER_ID: usize = 1;
62
63/// Ring polynomial represented as coefficients mod Q
64#[derive(Debug, Clone, PartialEq, Eq)]
65pub struct Poly {
66    /// Coefficients in the ring Z_Q[X]/(X^PHI + 1)
67    pub coeffs: Vec<u64>,
68}
69
70impl Poly {
71    /// Create a zero polynomial
72    pub fn zero() -> Self {
73        Self {
74            coeffs: vec![0; PHI],
75        }
76    }
77
78    /// Create from coefficients
79    pub fn from_coeffs(coeffs: Vec<u64>) -> Self {
80        let mut c = coeffs;
81        c.resize(PHI, 0);
82        Self { coeffs: c }
83    }
84
85    /// Serialize to bytes
86    pub fn to_bytes(&self) -> Vec<u8> {
87        let mut bytes = Vec::with_capacity(PHI * 8);
88        for coeff in &self.coeffs {
89            bytes.extend_from_slice(&coeff.to_le_bytes());
90        }
91        bytes
92    }
93
94    /// Deserialize from bytes
95    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
96        if bytes.len() != PHI * 8 {
97            return Err(Error::Protocol(format!(
98                "invalid poly size: expected {}, got {}",
99                PHI * 8,
100                bytes.len()
101            )));
102        }
103        let mut coeffs = Vec::with_capacity(PHI);
104        for chunk in bytes.chunks_exact(8) {
105            let coeff = u64::from_le_bytes(chunk.try_into().unwrap());
106            coeffs.push(coeff);
107        }
108        Ok(Self { coeffs })
109    }
110
111    /// Add two polynomials mod Q
112    pub fn add(&self, other: &Poly) -> Poly {
113        let mut result = Vec::with_capacity(PHI);
114        for i in 0..PHI {
115            result.push((self.coeffs[i] + other.coeffs[i]) % Q);
116        }
117        Poly { coeffs: result }
118    }
119
120    /// Subtract two polynomials mod Q
121    pub fn sub(&self, other: &Poly) -> Poly {
122        let mut result = Vec::with_capacity(PHI);
123        for i in 0..PHI {
124            let a = self.coeffs[i];
125            let b = other.coeffs[i];
126            result.push(if a >= b { a - b } else { Q - b + a });
127        }
128        Poly { coeffs: result }
129    }
130}
131
132/// Vector of ring polynomials
133pub type PolyVector = Vec<Poly>;
134/// Matrix of ring polynomials
135pub type PolyMatrix = Vec<Vec<Poly>>;
136
137/// Initialize a zero vector
138fn zero_vector(len: usize) -> PolyVector {
139    (0..len).map(|_| Poly::zero()).collect()
140}
141
142/// Initialize a zero matrix
143fn zero_matrix(rows: usize, cols: usize) -> PolyMatrix {
144    (0..rows).map(|_| zero_vector(cols)).collect()
145}
146
147/// Add two vectors element-wise
148fn vector_add(a: &PolyVector, b: &PolyVector) -> PolyVector {
149    a.iter().zip(b.iter()).map(|(x, y)| x.add(y)).collect()
150}
151
152/// Add two matrices element-wise
153fn matrix_add(a: &PolyMatrix, b: &PolyMatrix) -> PolyMatrix {
154    a.iter().zip(b.iter()).map(|(row_a, row_b)| vector_add(row_a, row_b)).collect()
155}
156
157/// Round 1 output from a party
158#[derive(Debug, Clone)]
159pub struct Round1Output {
160    /// Party ID
161    pub party_id: usize,
162    /// Commitment matrix D_i (M x (Dbar+1))
163    pub d_matrix: PolyMatrix,
164    /// MACs for other parties: party_j -> MAC
165    pub macs: HashMap<usize, [u8; KEY_SIZE]>,
166}
167
168impl Round1Output {
169    /// Serialize to bytes for network transmission
170    pub fn to_bytes(&self) -> Vec<u8> {
171        let mut bytes = Vec::new();
172        // Party ID
173        bytes.extend_from_slice(&(self.party_id as u32).to_le_bytes());
174        // Matrix dimensions
175        bytes.extend_from_slice(&(self.d_matrix.len() as u32).to_le_bytes());
176        bytes.extend_from_slice(&(self.d_matrix[0].len() as u32).to_le_bytes());
177        // Matrix data
178        for row in &self.d_matrix {
179            for poly in row {
180                bytes.extend_from_slice(&poly.to_bytes());
181            }
182        }
183        // MACs
184        bytes.extend_from_slice(&(self.macs.len() as u32).to_le_bytes());
185        for (party, mac) in &self.macs {
186            bytes.extend_from_slice(&(*party as u32).to_le_bytes());
187            bytes.extend_from_slice(mac);
188        }
189        bytes
190    }
191
192    /// Deserialize from bytes
193    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
194        let mut offset = 0;
195
196        // Party ID
197        let party_id = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
198        offset += 4;
199
200        // Matrix dimensions
201        let rows = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
202        offset += 4;
203        let cols = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
204        offset += 4;
205
206        // Matrix data
207        let poly_size = PHI * 8;
208        let mut d_matrix = Vec::with_capacity(rows);
209        for _ in 0..rows {
210            let mut row = Vec::with_capacity(cols);
211            for _ in 0..cols {
212                let poly = Poly::from_bytes(&bytes[offset..offset+poly_size])?;
213                row.push(poly);
214                offset += poly_size;
215            }
216            d_matrix.push(row);
217        }
218
219        // MACs
220        let mac_count = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
221        offset += 4;
222        let mut macs = HashMap::new();
223        for _ in 0..mac_count {
224            let party = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
225            offset += 4;
226            let mut mac = [0u8; KEY_SIZE];
227            mac.copy_from_slice(&bytes[offset..offset+KEY_SIZE]);
228            offset += KEY_SIZE;
229            macs.insert(party, mac);
230        }
231
232        Ok(Self { party_id, d_matrix, macs })
233    }
234}
235
236/// Round 2 output from a party
237#[derive(Debug, Clone)]
238pub struct Round2Output {
239    /// Party ID
240    pub party_id: usize,
241    /// Response share z_i (N-dimensional vector)
242    pub z_share: PolyVector,
243}
244
245impl Round2Output {
246    /// Serialize to bytes
247    pub fn to_bytes(&self) -> Vec<u8> {
248        let mut bytes = Vec::new();
249        bytes.extend_from_slice(&(self.party_id as u32).to_le_bytes());
250        bytes.extend_from_slice(&(self.z_share.len() as u32).to_le_bytes());
251        for poly in &self.z_share {
252            bytes.extend_from_slice(&poly.to_bytes());
253        }
254        bytes
255    }
256
257    /// Deserialize from bytes
258    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
259        let mut offset = 0;
260        let party_id = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
261        offset += 4;
262        let len = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
263        offset += 4;
264
265        let poly_size = PHI * 8;
266        let mut z_share = Vec::with_capacity(len);
267        for _ in 0..len {
268            let poly = Poly::from_bytes(&bytes[offset..offset+poly_size])?;
269            z_share.push(poly);
270            offset += poly_size;
271        }
272
273        Ok(Self { party_id, z_share })
274    }
275}
276
277/// Ringtail threshold signature
278#[derive(Debug, Clone)]
279pub struct RingtailSignature {
280    /// Challenge polynomial c
281    pub c: Poly,
282    /// Aggregated response vector z
283    pub z: PolyVector,
284    /// Correction term Delta
285    pub delta: PolyVector,
286}
287
288impl RingtailSignature {
289    /// Serialize to bytes
290    pub fn to_bytes(&self) -> Vec<u8> {
291        let mut bytes = Vec::new();
292        bytes.extend_from_slice(&self.c.to_bytes());
293        bytes.extend_from_slice(&(self.z.len() as u32).to_le_bytes());
294        for poly in &self.z {
295            bytes.extend_from_slice(&poly.to_bytes());
296        }
297        bytes.extend_from_slice(&(self.delta.len() as u32).to_le_bytes());
298        for poly in &self.delta {
299            bytes.extend_from_slice(&poly.to_bytes());
300        }
301        bytes
302    }
303
304    /// Deserialize from bytes
305    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
306        let poly_size = PHI * 8;
307        let mut offset = 0;
308
309        let c = Poly::from_bytes(&bytes[offset..offset+poly_size])?;
310        offset += poly_size;
311
312        let z_len = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
313        offset += 4;
314        let mut z = Vec::with_capacity(z_len);
315        for _ in 0..z_len {
316            z.push(Poly::from_bytes(&bytes[offset..offset+poly_size])?);
317            offset += poly_size;
318        }
319
320        let delta_len = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize;
321        offset += 4;
322        let mut delta = Vec::with_capacity(delta_len);
323        for _ in 0..delta_len {
324            delta.push(Poly::from_bytes(&bytes[offset..offset+poly_size])?);
325            offset += poly_size;
326        }
327
328        Ok(Self { c, z, delta })
329    }
330
331    /// Signature size in bytes
332    pub fn size(&self) -> usize {
333        let poly_size = PHI * 8;
334        poly_size + 4 + self.z.len() * poly_size + 4 + self.delta.len() * poly_size
335    }
336}
337
/// Bookkeeping record for one remote signing party.
#[derive(Debug)]
pub struct PeerConnection {
    /// Protocol-level ID of the remote party.
    pub party_id: usize,
    /// Address the peer was registered under (e.g. `"peer1:9999"`).
    pub address: String,
    /// Whether this peer currently counts as connected.
    pub connected: bool,
}
348
349/// Ringtail-compatible consensus for ZAP agents
350///
351/// Implements threshold lattice-based signing with the following protocol:
352/// 1. Setup: Trusted dealer generates secret shares and MAC keys
353/// 2. Round 1: Each party generates commitment D_i and MACs
354/// 3. Round 2: Verify MACs, compute response share z_i
355/// 4. Finalize: Combiner aggregates into signature (c, z, Delta)
356pub struct RingtailConsensus {
357    /// This party's ID
358    party_id: usize,
359    /// Total number of parties
360    parties: usize,
361    /// Threshold for signing (t-of-n)
362    threshold: usize,
363    /// Connected peers
364    peers: HashMap<usize, PeerConnection>,
365    /// Session ID for current signing
366    session_id: u64,
367    /// Secret key share (N-dimensional vector)
368    sk_share: Option<PolyVector>,
369    /// MAC keys shared with other parties
370    mac_keys: HashMap<usize, [u8; KEY_SIZE]>,
371    /// Seeds for PRF masking
372    seeds: HashMap<usize, Vec<[u8; KEY_SIZE]>>,
373    /// Public matrix A (M x N)
374    public_a: Option<PolyMatrix>,
375    /// Rounded public key b_tilde
376    public_b: Option<PolyVector>,
377    /// Current round 1 commitment
378    current_d: Option<PolyMatrix>,
379    /// Current random vectors R
380    current_r: Option<PolyMatrix>,
381    /// Lagrange coefficient lambda_i
382    lambda: Option<Poly>,
383}
384
385impl RingtailConsensus {
386    /// Create a new Ringtail consensus party
387    pub fn new(party_id: usize, parties: usize, threshold: usize) -> Self {
388        assert!(threshold <= parties, "threshold cannot exceed parties");
389        assert!(threshold >= 1, "threshold must be at least 1");
390        assert!(party_id < parties, "party_id must be less than parties");
391
392        Self {
393            party_id,
394            parties,
395            threshold,
396            peers: HashMap::new(),
397            session_id: 0,
398            sk_share: None,
399            mac_keys: HashMap::new(),
400            seeds: HashMap::new(),
401            public_a: None,
402            public_b: None,
403            current_d: None,
404            current_r: None,
405            lambda: None,
406        }
407    }
408
409    /// Get party ID
410    pub fn party_id(&self) -> usize {
411        self.party_id
412    }
413
414    /// Get total parties
415    pub fn parties(&self) -> usize {
416        self.parties
417    }
418
419    /// Get threshold
420    pub fn threshold(&self) -> usize {
421        self.threshold
422    }
423
424    /// Check if connected to minimum peers for signing
425    pub fn has_quorum(&self) -> bool {
426        self.peers.values().filter(|p| p.connected).count() >= self.threshold - 1
427    }
428
429    /// Set secret key share from trusted dealer
430    pub fn set_sk_share(&mut self, sk_share: PolyVector) {
431        self.sk_share = Some(sk_share);
432    }
433
434    /// Set MAC keys for peer authentication
435    pub fn set_mac_keys(&mut self, keys: HashMap<usize, [u8; KEY_SIZE]>) {
436        self.mac_keys = keys;
437    }
438
439    /// Set PRF seeds
440    pub fn set_seeds(&mut self, seeds: HashMap<usize, Vec<[u8; KEY_SIZE]>>) {
441        self.seeds = seeds;
442    }
443
444    /// Set public parameters
445    pub fn set_public_params(&mut self, a: PolyMatrix, b: PolyVector) {
446        self.public_a = Some(a);
447        self.public_b = Some(b);
448    }
449
450    /// Set Lagrange coefficient
451    pub fn set_lambda(&mut self, lambda: Poly) {
452        self.lambda = Some(lambda);
453    }
454
455    /// Connect to peer network
456    pub async fn connect_peers(&mut self, addresses: Vec<String>) -> Result<()> {
457        for (i, addr) in addresses.into_iter().enumerate() {
458            let peer_id = if i >= self.party_id { i + 1 } else { i };
459            if peer_id >= self.parties {
460                continue;
461            }
462
463            self.peers.insert(peer_id, PeerConnection {
464                party_id: peer_id,
465                address: addr,
466                connected: true, // In real impl, actually connect
467            });
468        }
469        Ok(())
470    }
471
472    /// Disconnect from peers
473    pub async fn disconnect(&mut self) {
474        self.peers.clear();
475    }
476
477    /// Generate MAC for commitment matrix
478    fn generate_mac(&self, d: &PolyMatrix, recipient: usize, verify: bool) -> Result<[u8; KEY_SIZE]> {
479        let mac_key = self.mac_keys.get(&recipient)
480            .ok_or_else(|| Error::Crypto(format!("no MAC key for party {}", recipient)))?;
481
482        // Hash: party_id || MAC_key || D || session_id || T
483        let mut hasher = blake3::Hasher::new();
484
485        if verify {
486            hasher.update(&(recipient as u32).to_le_bytes());
487        } else {
488            hasher.update(&(self.party_id as u32).to_le_bytes());
489        }
490
491        hasher.update(mac_key);
492
493        // Serialize D matrix
494        for row in d {
495            for poly in row {
496                hasher.update(&poly.to_bytes());
497            }
498        }
499
500        hasher.update(&self.session_id.to_le_bytes());
501
502        // Participating parties (0..parties for simplicity)
503        hasher.update(&(self.parties as u32).to_le_bytes());
504        for i in 0..self.parties {
505            hasher.update(&(i as u32).to_le_bytes());
506        }
507
508        let hash = hasher.finalize();
509        let mut mac = [0u8; KEY_SIZE];
510        mac.copy_from_slice(&hash.as_bytes()[..KEY_SIZE]);
511        Ok(mac)
512    }
513
514    /// Sign Round 1 - Generate commitment matrix D and MACs
515    ///
516    /// Computes D_i = A * R_i + E_i where R_i, E_i are Gaussian-sampled
517    pub async fn sign_round1(&mut self, message: &[u8]) -> Result<Round1Output> {
518        // Increment session ID for new signing
519        self.session_id = self.session_id.wrapping_add(1);
520
521        // Sample random R (N x (Dbar+1)) and E (M x (Dbar+1))
522        // In production, use proper Gaussian sampling; here we use deterministic for reproducibility
523        let r_matrix = self.sample_r_matrix(message);
524        let e_matrix = self.sample_e_matrix(message);
525
526        // Compute D = A * R + E (simplified - in real impl uses NTT multiplication)
527        let a = self.public_a.as_ref()
528            .ok_or_else(|| Error::Protocol("public matrix A not set".into()))?;
529
530        // D = A * R + E (M x (Dbar+1))
531        let d = self.compute_d_matrix(a, &r_matrix, &e_matrix);
532
533        // Store for round 2
534        self.current_d = Some(d.clone());
535        self.current_r = Some(r_matrix);
536
537        // Generate MACs for all other parties
538        let mut macs = HashMap::new();
539        for peer_id in 0..self.parties {
540            if peer_id != self.party_id {
541                let mac = self.generate_mac(&d, peer_id, false)?;
542                macs.insert(peer_id, mac);
543            }
544        }
545
546        Ok(Round1Output {
547            party_id: self.party_id,
548            d_matrix: d,
549            macs,
550        })
551    }
552
553    /// Sample R matrix from Gaussian distribution
554    fn sample_r_matrix(&self, message: &[u8]) -> PolyMatrix {
555        // Use deterministic sampling based on sk_share hash and message
556        let mut hasher = blake3::Hasher::new();
557        if let Some(ref sk) = self.sk_share {
558            for poly in sk {
559                hasher.update(&poly.to_bytes());
560            }
561        }
562        hasher.update(message);
563        hasher.update(b"R_MATRIX");
564        hasher.update(&self.session_id.to_le_bytes());
565
566        // Generate N x (DBAR+1) matrix
567        let mut r = Vec::with_capacity(N);
568        for i in 0..N {
569            let mut row = Vec::with_capacity(DBAR + 1);
570            for j in 0..DBAR + 1 {
571                hasher.update(&(i as u32).to_le_bytes());
572                hasher.update(&(j as u32).to_le_bytes());
573                let seed = hasher.finalize();
574                let poly = self.sample_poly_from_seed(seed.as_bytes(), true);
575                row.push(poly);
576            }
577            r.push(row);
578        }
579        r
580    }
581
582    /// Sample E matrix from Gaussian distribution
583    fn sample_e_matrix(&self, message: &[u8]) -> PolyMatrix {
584        let mut hasher = blake3::Hasher::new();
585        if let Some(ref sk) = self.sk_share {
586            for poly in sk {
587                hasher.update(&poly.to_bytes());
588            }
589        }
590        hasher.update(message);
591        hasher.update(b"E_MATRIX");
592        hasher.update(&self.session_id.to_le_bytes());
593
594        // Generate M x (DBAR+1) matrix
595        let mut e = Vec::with_capacity(M);
596        for i in 0..M {
597            let mut row = Vec::with_capacity(DBAR + 1);
598            for j in 0..DBAR + 1 {
599                hasher.update(&(i as u32).to_le_bytes());
600                hasher.update(&(j as u32).to_le_bytes());
601                let seed = hasher.finalize();
602                let poly = self.sample_poly_from_seed(seed.as_bytes(), false);
603                row.push(poly);
604            }
605            e.push(row);
606        }
607        e
608    }
609
610    /// Sample polynomial from seed (simplified Gaussian)
611    fn sample_poly_from_seed(&self, seed: &[u8], is_r: bool) -> Poly {
612        let mut coeffs = Vec::with_capacity(PHI);
613        let mut prng = blake3::Hasher::new();
614        prng.update(seed);
615
616        for i in 0..PHI {
617            prng.update(&(i as u32).to_le_bytes());
618            let hash = prng.finalize();
619            let bytes = hash.as_bytes();
620            let raw = u64::from_le_bytes(bytes[..8].try_into().unwrap());
621            // Reduce mod Q
622            coeffs.push(raw % Q);
623        }
624
625        Poly { coeffs }
626    }
627
628    /// Compute D = A * R + E (simplified matrix multiplication)
629    fn compute_d_matrix(&self, a: &PolyMatrix, r: &PolyMatrix, e: &PolyMatrix) -> PolyMatrix {
630        // D[i][j] = sum_k(A[i][k] * R[k][j]) + E[i][j]
631        let mut d = zero_matrix(M, DBAR + 1);
632
633        // Simplified multiplication (in real impl uses NTT)
634        for i in 0..M {
635            for j in 0..DBAR + 1 {
636                let mut sum = Poly::zero();
637                for k in 0..N {
638                    // Simplified: just add for now (proper impl uses poly multiplication)
639                    sum = sum.add(&a[i][k].add(&r[k][j]));
640                }
641                d[i][j] = sum.add(&e[i][j]);
642            }
643        }
644
645        d
646    }
647
648    /// Sign Round 2 - Verify MACs, compute response share
649    ///
650    /// Verifies all received MACs and computes z_i = R_i * u + s_i * c * lambda_i - mask
651    pub async fn sign_round2(&self, round1_outputs: Vec<Round1Output>) -> Result<Round2Output> {
652        // Verify we have enough outputs
653        if round1_outputs.len() < self.threshold {
654            return Err(Error::Protocol(format!(
655                "not enough round 1 outputs: need {}, got {}",
656                self.threshold,
657                round1_outputs.len()
658            )));
659        }
660
661        // Verify MACs from all parties
662        for output in &round1_outputs {
663            if output.party_id == self.party_id {
664                continue;
665            }
666
667            // Check MAC sent to us
668            let expected_mac = self.verify_mac(&output.d_matrix, output.party_id)?;
669            let received_mac = output.macs.get(&self.party_id)
670                .ok_or_else(|| Error::Crypto(format!(
671                    "no MAC from party {} for us", output.party_id
672                )))?;
673
674            if expected_mac != *received_mac {
675                return Err(Error::Crypto(format!(
676                    "MAC verification failed for party {}", output.party_id
677                )));
678            }
679        }
680
681        // Sum all D matrices
682        let mut d_sum = zero_matrix(M, DBAR + 1);
683        for output in &round1_outputs {
684            d_sum = matrix_add(&d_sum, &output.d_matrix);
685        }
686
687        // Compute response share z_i
688        let r = self.current_r.as_ref()
689            .ok_or_else(|| Error::Protocol("no current R matrix".into()))?;
690        let sk = self.sk_share.as_ref()
691            .ok_or_else(|| Error::Protocol("no secret key share".into()))?;
692        let lambda = self.lambda.as_ref()
693            .ok_or_else(|| Error::Protocol("no Lagrange coefficient".into()))?;
694
695        // z_i = R_i * u + s_i * c * lambda_i (simplified)
696        // In real impl, u is hashed from D_sum, and c is the challenge
697        let mut z_share = zero_vector(N);
698        for i in 0..N {
699            // Simplified: z_i[j] = R[j][0] + sk[j] * lambda (ignoring masking)
700            z_share[i] = r[i][0].add(&sk[i].add(lambda));
701        }
702
703        Ok(Round2Output {
704            party_id: self.party_id,
705            z_share,
706        })
707    }
708
709    /// Verify MAC from another party
710    fn verify_mac(&self, d: &PolyMatrix, sender: usize) -> Result<[u8; KEY_SIZE]> {
711        self.generate_mac(d, sender, true)
712    }
713
714    /// Finalize - Combine shares into final signature (combiner only)
715    ///
716    /// Aggregates all z_i shares and computes Delta correction
717    pub async fn finalize(&self, round2_outputs: Vec<Round2Output>) -> Result<RingtailSignature> {
718        if self.party_id != COMBINER_ID {
719            return Err(Error::Protocol("only combiner can finalize".into()));
720        }
721
722        if round2_outputs.len() < self.threshold {
723            return Err(Error::Protocol(format!(
724                "not enough round 2 outputs: need {}, got {}",
725                self.threshold,
726                round2_outputs.len()
727            )));
728        }
729
730        // Aggregate z shares: z = sum(z_i)
731        let mut z_sum = zero_vector(N);
732        for output in &round2_outputs {
733            z_sum = vector_add(&z_sum, &output.z_share);
734        }
735
736        // Compute challenge c (simplified - in real impl uses LowNormHash)
737        let c = self.compute_challenge()?;
738
739        // Compute Delta correction (simplified)
740        let delta = zero_vector(M);
741
742        Ok(RingtailSignature {
743            c,
744            z: z_sum,
745            delta,
746        })
747    }
748
749    /// Compute challenge polynomial
750    fn compute_challenge(&self) -> Result<Poly> {
751        // In real impl: c = LowNormHash(A, b_tilde, h, mu)
752        // Simplified: deterministic challenge based on session
753        let mut hasher = blake3::Hasher::new();
754        hasher.update(b"CHALLENGE");
755        hasher.update(&self.session_id.to_le_bytes());
756
757        let hash = hasher.finalize();
758        let mut coeffs = vec![0u64; PHI];
759
760        // Set KAPPA coefficients to +/- 1
761        for i in 0..KAPPA {
762            let idx = (hash.as_bytes()[i % 32] as usize * 7 + i) % PHI;
763            coeffs[idx] = if i % 2 == 0 { 1 } else { Q - 1 };
764        }
765
766        Ok(Poly { coeffs })
767    }
768
769    /// Verify a Ringtail signature
770    ///
771    /// Checks: c = LowNormHash(A, b, h) where h = A*z - b*c + Delta
772    pub fn verify(message: &[u8], signature: &RingtailSignature, public_key: &[u8]) -> bool {
773        // Simplified verification
774        // In real impl: recompute h from z and verify c matches
775
776        // Basic sanity checks
777        if signature.z.len() != N {
778            return false;
779        }
780        if signature.delta.len() != M {
781            return false;
782        }
783        if signature.c.coeffs.len() != PHI {
784            return false;
785        }
786
787        // Check L2 norm bounds would go here
788        true
789    }
790}
791
/// Voting bookkeeping for a single agent query.
#[derive(Debug, Clone)]
pub struct QueryState {
    /// Identifier of the query (hash of the query)
    pub query_id: [u8; 32],
    /// The query text as submitted
    pub query: String,
    /// Latest response per agent: agent ID -> response text
    pub responses: HashMap<String, String>,
    /// Tally per distinct response: response hash -> vote count
    pub votes: HashMap<[u8; 32], usize>,
    /// Set once a consensus result has been recorded
    pub finalized: bool,
    /// The agreed response, once finalized
    pub result: Option<String>,
    /// Creation time in whole seconds since the Unix epoch
    pub created_at: u64,
}

impl QueryState {
    /// Start tracking `query` under `query_id`, with no responses or votes yet.
    ///
    /// `created_at` is stamped with the current wall-clock time in seconds since
    /// the Unix epoch, or 0 if the clock reads earlier than the epoch.
    pub fn new(query_id: [u8; 32], query: String) -> Self {
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
        let created_at = match since_epoch {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };

        QueryState {
            query_id,
            query,
            responses: HashMap::new(),
            votes: HashMap::new(),
            finalized: false,
            result: None,
            created_at,
        }
    }
}
828
829/// Simplified agent consensus for response voting
830///
831/// Provides a simpler consensus mechanism for AI agents to vote on responses
832/// without the full complexity of Ringtail threshold signatures.
833pub struct AgentConsensus {
834    /// Active queries awaiting consensus
835    queries: Arc<RwLock<HashMap<[u8; 32], QueryState>>>,
836    /// Vote threshold (fraction of agents that must agree)
837    threshold: f64,
838    /// Minimum number of responses required
839    min_responses: usize,
840    /// Query timeout in seconds
841    timeout_secs: u64,
842}
843
844impl AgentConsensus {
845    /// Create new agent consensus with threshold
846    pub fn new(threshold: f64, min_responses: usize) -> Self {
847        assert!(threshold > 0.0 && threshold <= 1.0, "threshold must be in (0, 1]");
848        assert!(min_responses >= 1, "need at least 1 response");
849
850        Self {
851            queries: Arc::new(RwLock::new(HashMap::new())),
852            threshold,
853            min_responses,
854            timeout_secs: 30,
855        }
856    }
857
858    /// Set query timeout
859    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
860        self.timeout_secs = timeout_secs;
861        self
862    }
863
864    /// Get threshold
865    pub fn threshold(&self) -> f64 {
866        self.threshold
867    }
868
869    /// Get minimum responses
870    pub fn min_responses(&self) -> usize {
871        self.min_responses
872    }
873
874    /// Submit a new query for consensus
875    pub async fn submit_query(&self, query: &str) -> [u8; 32] {
876        let query_id = blake3::hash(query.as_bytes()).into();
877        let state = QueryState::new(query_id, query.to_string());
878
879        let mut queries = self.queries.write().await;
880        queries.insert(query_id, state);
881
882        query_id
883    }
884
885    /// Submit a response from an agent
886    pub async fn submit_response(&self, query_id: &[u8; 32], agent_id: &str, response: &str) -> Result<()> {
887        let mut queries = self.queries.write().await;
888
889        let state = queries.get_mut(query_id)
890            .ok_or_else(|| Error::Protocol("query not found".into()))?;
891
892        if state.finalized {
893            return Err(Error::Protocol("query already finalized".into()));
894        }
895
896        state.responses.insert(agent_id.to_string(), response.to_string());
897
898        // Count votes for this response
899        let response_hash: [u8; 32] = blake3::hash(response.as_bytes()).into();
900        *state.votes.entry(response_hash).or_insert(0) += 1;
901
902        Ok(())
903    }
904
905    /// Try to reach consensus on a query
906    pub async fn try_consensus(&self, query_id: &[u8; 32]) -> Result<Option<String>> {
907        let mut queries = self.queries.write().await;
908
909        let state = queries.get_mut(query_id)
910            .ok_or_else(|| Error::Protocol("query not found".into()))?;
911
912        if state.finalized {
913            return Ok(state.result.clone());
914        }
915
916        // Need minimum responses
917        if state.responses.len() < self.min_responses {
918            return Ok(None);
919        }
920
921        // Find response with most votes
922        let total_votes: usize = state.votes.values().sum();
923        let mut best_hash = None;
924        let mut best_count = 0;
925
926        for (hash, count) in &state.votes {
927            if *count > best_count {
928                best_count = *count;
929                best_hash = Some(*hash);
930            }
931        }
932
933        // Check if threshold met
934        let vote_fraction = best_count as f64 / total_votes as f64;
935        if vote_fraction >= self.threshold {
936            // Find the actual response
937            let best_hash = best_hash.unwrap();
938            for response in state.responses.values() {
939                let hash: [u8; 32] = blake3::hash(response.as_bytes()).into();
940                if hash == best_hash {
941                    state.finalized = true;
942                    state.result = Some(response.clone());
943                    return Ok(Some(response.clone()));
944                }
945            }
946        }
947
948        Ok(None)
949    }
950
951    /// Get query state
952    pub async fn get_query(&self, query_id: &[u8; 32]) -> Option<QueryState> {
953        let queries = self.queries.read().await;
954        queries.get(query_id).cloned()
955    }
956
957    /// Remove expired queries
958    pub async fn cleanup_expired(&self) {
959        let now = std::time::SystemTime::now()
960            .duration_since(std::time::UNIX_EPOCH)
961            .unwrap_or_default()
962            .as_secs();
963
964        let mut queries = self.queries.write().await;
965        queries.retain(|_, state| {
966            now - state.created_at < self.timeout_secs || state.finalized
967        });
968    }
969
970    /// Get number of active queries
971    pub async fn active_queries(&self) -> usize {
972        let queries = self.queries.read().await;
973        queries.len()
974    }
975}
976
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a polynomial, including the largest residue `Q - 1`,
    /// through `to_bytes` / `from_bytes`.
    #[test]
    fn test_poly_serialization() {
        let original = Poly::from_coeffs(vec![1, 2, 3, Q - 1]);
        let encoded = original.to_bytes();
        let decoded = Poly::from_bytes(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Coefficient-wise addition of small values.
    #[test]
    fn test_poly_add() {
        let lhs = Poly::from_coeffs(vec![1, 2, 3]);
        let rhs = Poly::from_coeffs(vec![4, 5, 6]);
        let sum = lhs.add(&rhs);
        for (i, &expected) in [5, 7, 9].iter().enumerate() {
            assert_eq!(sum.coeffs[i], expected);
        }
    }

    /// Coefficient-wise subtraction without wrap-around.
    #[test]
    fn test_poly_sub() {
        let lhs = Poly::from_coeffs(vec![10, 20, 30]);
        let rhs = Poly::from_coeffs(vec![1, 2, 3]);
        let difference = lhs.sub(&rhs);
        for (i, &expected) in [9, 18, 27].iter().enumerate() {
            assert_eq!(difference.coeffs[i], expected);
        }
    }

    /// A freshly constructed party reports its parameters and has no quorum.
    #[test]
    fn test_consensus_creation() {
        let party = RingtailConsensus::new(0, 3, 2);
        assert_eq!(party.party_id(), 0);
        assert_eq!(party.parties(), 3);
        assert_eq!(party.threshold(), 2);
        assert!(!party.has_quorum());
    }

    /// Round 1 output survives a serialization round trip.
    #[test]
    fn test_round1_serialization() {
        let encoded = Round1Output {
            party_id: 0,
            d_matrix: zero_matrix(2, 2),
            macs: HashMap::from([(1, [0u8; KEY_SIZE])]),
        }
        .to_bytes();
        let decoded = Round1Output::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.party_id, 0);
        assert_eq!(decoded.d_matrix.len(), 2);
    }

    /// A signature survives a serialization round trip.
    #[test]
    fn test_signature_serialization() {
        let signature = RingtailSignature {
            c: Poly::from_coeffs(vec![1, 2, 3]),
            z: vec![Poly::from_coeffs(vec![4, 5, 6])],
            delta: vec![Poly::from_coeffs(vec![7, 8, 9])],
        };
        let encoded = signature.to_bytes();
        let decoded = RingtailSignature::from_bytes(&encoded).unwrap();
        assert_eq!(signature.c, decoded.c);
        assert_eq!(signature.z.len(), decoded.z.len());
        assert_eq!(signature.delta.len(), decoded.delta.len());
    }

    /// A 2-of-3 majority clears a 50% threshold.
    #[tokio::test]
    async fn test_agent_consensus() {
        let consensus = AgentConsensus::new(0.5, 2);
        let query_id = consensus.submit_query("What is 2+2?").await;

        for (agent, answer) in [("agent1", "4"), ("agent2", "4"), ("agent3", "5")] {
            consensus.submit_response(&query_id, agent, answer).await.unwrap();
        }

        let outcome = consensus.try_consensus(&query_id).await.unwrap();
        assert_eq!(outcome, Some("4".to_string()));
    }

    /// Three distinct answers cannot reach an 80% threshold.
    #[tokio::test]
    async fn test_agent_consensus_no_agreement() {
        let consensus = AgentConsensus::new(0.8, 2);
        let query_id = consensus.submit_query("What color is the sky?").await;

        for (agent, answer) in [("agent1", "blue"), ("agent2", "grey"), ("agent3", "white")] {
            consensus.submit_response(&query_id, agent, answer).await.unwrap();
        }

        let outcome = consensus.try_consensus(&query_id).await.unwrap();
        assert_eq!(outcome, None); // No consensus at 80% threshold
    }

    /// Unanimous answers still wait until `min_responses` have arrived.
    #[tokio::test]
    async fn test_agent_consensus_min_responses() {
        let consensus = AgentConsensus::new(0.5, 3);
        let query_id = consensus.submit_query("Test?").await;

        for agent in ["agent1", "agent2"] {
            consensus.submit_response(&query_id, agent, "yes").await.unwrap();
        }

        // Only 2 responses, need 3
        let outcome = consensus.try_consensus(&query_id).await.unwrap();
        assert_eq!(outcome, None);
    }

    /// Verification smoke test on a minimal signature.
    #[test]
    fn test_verify_basic() {
        let minimal = RingtailSignature {
            c: Poly::from_coeffs(vec![1]),
            z: zero_vector(N),
            delta: zero_vector(M),
        };
        // NOTE(review): this asserts that a signature with all-zero z/delta
        // verifies against an empty key slice. Confirm that is the intended
        // contract of `RingtailConsensus::verify`.
        assert!(RingtailConsensus::verify(b"test", &minimal, &[]));
    }
}