Skip to main content

cougr_core/zk/
advanced.rs

//! Experimental phase-3 ZK patterns.
//!
//! These APIs make the phase-3 roadmap concrete without overstating maturity:
//!
//! - fog-of-war orchestration around Merkle roots
//! - multiplayer state-channel transition contracts
//! - recursive proof-composition descriptors
//!
//! They remain part of `zk::experimental` because the repository only commits
//! to explicit orchestration and public-input contracts here, not to
//! production-ready confidentiality guarantees.
12
13use alloc::vec::Vec;
14use soroban_sdk::{contracttype, BytesN, Env};
15
16use super::error::ZKError;
17use super::merkle::tree::MerkleTree;
18use super::traits::{bytes32_to_scalar, i32_to_scalar, u32_to_scalar, u64_to_scalar, GameCircuit};
19use super::types::{Groth16Proof, Scalar, VerificationKey};
20
21/// Snapshot of a player's currently visible fog-of-war state.
22#[contracttype]
23#[derive(Clone, Debug)]
24pub struct FogOfWarSnapshot {
25    /// Merkle root of the hidden map or board state.
26    pub map_root: BytesN<32>,
27    /// Merkle root of the tiles the player has already explored.
28    pub explored_root: BytesN<32>,
29    /// Player origin used by the exploration circuit.
30    pub origin_x: i32,
31    pub origin_y: i32,
32    /// Maximum Euclidean distance the player may reveal from the origin.
33    pub visibility_radius: u32,
34}
35
36impl FogOfWarSnapshot {
37    /// Returns `true` when the target tile is within the visible window.
38    pub fn can_reveal(&self, tile_x: i32, tile_y: i32) -> bool {
39        let dx = i64::from(tile_x) - i64::from(self.origin_x);
40        let dy = i64::from(tile_y) - i64::from(self.origin_y);
41        let radius = i64::from(self.visibility_radius);
42
43        dx * dx + dy * dy <= radius * radius
44    }
45}
46
47/// Root transition for a single fog-of-war exploration update.
48#[contracttype]
49#[derive(Clone, Debug)]
50pub struct FogOfWarTransition {
51    pub prior_explored_root: BytesN<32>,
52    pub next_explored_root: BytesN<32>,
53    pub tile_x: i32,
54    pub tile_y: i32,
55}
56
57/// Apply a validated exploration update to a fog-of-war snapshot.
58pub fn apply_fog_of_war_transition(
59    snapshot: &FogOfWarSnapshot,
60    transition: &FogOfWarTransition,
61) -> Result<FogOfWarSnapshot, ZKError> {
62    if snapshot.explored_root != transition.prior_explored_root {
63        return Err(ZKError::InvalidStateTransition);
64    }
65
66    if !snapshot.can_reveal(transition.tile_x, transition.tile_y) {
67        return Err(ZKError::InvalidVisibility);
68    }
69
70    let mut updated = snapshot.clone();
71    updated.explored_root = transition.next_explored_root.clone();
72    Ok(updated)
73}
74
75/// Experimental circuit contract for fog-of-war exploration proofs.
76pub struct FogOfWarCircuit {
77    pub vk: VerificationKey,
78    pub max_visibility_radius: u32,
79}
80
81impl GameCircuit for FogOfWarCircuit {
82    fn verification_key(&self) -> &VerificationKey {
83        &self.vk
84    }
85}
86
87impl FogOfWarCircuit {
88    pub fn new(vk: VerificationKey, max_visibility_radius: u32) -> Self {
89        Self {
90            vk,
91            max_visibility_radius,
92        }
93    }
94
95    /// Verify that a fog-of-war transition is valid for the provided snapshot.
96    ///
97    /// Public inputs:
98    /// `[map_root, prior_explored_root, next_explored_root, origin_x, origin_y, tile_x, tile_y, visibility_radius]`.
99    pub fn verify_exploration(
100        &self,
101        env: &Env,
102        proof: &Groth16Proof,
103        snapshot: &FogOfWarSnapshot,
104        transition: &FogOfWarTransition,
105    ) -> Result<bool, ZKError> {
106        if snapshot.visibility_radius > self.max_visibility_radius {
107            return Err(ZKError::InvalidVisibility);
108        }
109
110        let _ = apply_fog_of_war_transition(snapshot, transition)?;
111
112        let public_inputs = Vec::from([
113            bytes32_to_scalar(&snapshot.map_root),
114            bytes32_to_scalar(&transition.prior_explored_root),
115            bytes32_to_scalar(&transition.next_explored_root),
116            i32_to_scalar(env, snapshot.origin_x),
117            i32_to_scalar(env, snapshot.origin_y),
118            i32_to_scalar(env, transition.tile_x),
119            i32_to_scalar(env, transition.tile_y),
120            u32_to_scalar(env, snapshot.visibility_radius),
121        ]);
122
123        self.verify_with_inputs(env, proof, &public_inputs)
124    }
125}
126
127/// Off-chain state channel tracked by on-chain commitments and dispute metadata.
128#[contracttype]
129#[derive(Clone, Debug)]
130pub struct ZkStateChannel {
131    pub channel_id: BytesN<32>,
132    pub participants_root: BytesN<32>,
133    pub state_root: BytesN<32>,
134    pub round: u64,
135    pub dispute_deadline: u64,
136    pub closed: bool,
137}
138
139/// Proposed state transition for a multiplayer ZK state channel.
140#[contracttype]
141#[derive(Clone, Debug)]
142pub struct StateChannelTransition {
143    pub prior_state_root: BytesN<32>,
144    pub next_state_root: BytesN<32>,
145    pub round: u64,
146    pub submitted_at: u64,
147}
148
149/// Open a new experimental ZK state channel.
150pub fn open_state_channel(
151    channel_id: BytesN<32>,
152    participants_root: BytesN<32>,
153    initial_state_root: BytesN<32>,
154    dispute_deadline: u64,
155) -> Result<ZkStateChannel, ZKError> {
156    if dispute_deadline == 0 {
157        return Err(ZKError::InvalidInput);
158    }
159
160    Ok(ZkStateChannel {
161        channel_id,
162        participants_root,
163        state_root: initial_state_root,
164        round: 0,
165        dispute_deadline,
166        closed: false,
167    })
168}
169
170/// Apply a verified transition to the state channel.
171pub fn apply_state_channel_transition(
172    channel: &ZkStateChannel,
173    transition: &StateChannelTransition,
174) -> Result<ZkStateChannel, ZKError> {
175    if channel.closed {
176        return Err(ZKError::ChannelClosed);
177    }
178
179    if transition.prior_state_root != channel.state_root {
180        return Err(ZKError::InvalidStateTransition);
181    }
182
183    let expected_round = channel
184        .round
185        .checked_add(1)
186        .ok_or(ZKError::InvalidStateTransition)?;
187    if transition.round != expected_round {
188        return Err(ZKError::InvalidStateTransition);
189    }
190
191    if transition.submitted_at > channel.dispute_deadline {
192        return Err(ZKError::DeadlineExpired);
193    }
194
195    let mut updated = channel.clone();
196    updated.state_root = transition.next_state_root.clone();
197    updated.round = transition.round;
198    Ok(updated)
199}
200
201/// Close a channel with the latest accepted state root.
202pub fn close_state_channel(
203    channel: &ZkStateChannel,
204    final_state_root: &BytesN<32>,
205    final_round: u64,
206    closed_at: u64,
207) -> Result<ZkStateChannel, ZKError> {
208    if channel.closed {
209        return Err(ZKError::ChannelClosed);
210    }
211
212    if final_round < channel.round {
213        return Err(ZKError::InvalidStateTransition);
214    }
215
216    let mut closed = channel.clone();
217    closed.state_root = final_state_root.clone();
218    closed.round = final_round;
219    closed.dispute_deadline = closed_at;
220    closed.closed = true;
221    Ok(closed)
222}
223
224/// Experimental circuit contract for channel transition proofs.
225pub struct StateChannelCircuit {
226    pub vk: VerificationKey,
227}
228
229impl GameCircuit for StateChannelCircuit {
230    fn verification_key(&self) -> &VerificationKey {
231        &self.vk
232    }
233}
234
235impl StateChannelCircuit {
236    pub fn new(vk: VerificationKey) -> Self {
237        Self { vk }
238    }
239
240    /// Verify a state transition for a channel.
241    ///
242    /// Public inputs:
243    /// `[channel_id, participants_root, prior_state_root, next_state_root, round, submitted_at]`.
244    pub fn verify_transition(
245        &self,
246        env: &Env,
247        proof: &Groth16Proof,
248        channel: &ZkStateChannel,
249        transition: &StateChannelTransition,
250    ) -> Result<bool, ZKError> {
251        if channel.closed {
252            return Err(ZKError::ChannelClosed);
253        }
254
255        let _ = apply_state_channel_transition(channel, transition)?;
256        let public_inputs = Vec::from([
257            bytes32_to_scalar(&channel.channel_id),
258            bytes32_to_scalar(&channel.participants_root),
259            bytes32_to_scalar(&transition.prior_state_root),
260            bytes32_to_scalar(&transition.next_state_root),
261            u64_to_scalar(env, transition.round),
262            u64_to_scalar(env, transition.submitted_at),
263        ]);
264
265        self.verify_with_inputs(env, proof, &public_inputs)
266    }
267}
268
269/// Experimental descriptor for a recursive proof batch.
270#[contracttype]
271#[derive(Clone, Debug)]
272pub struct RecursiveProofLayout {
273    pub initial_state_root: BytesN<32>,
274    pub final_state_root: BytesN<32>,
275    pub accumulator_root: BytesN<32>,
276    pub proof_count: u32,
277}
278
279impl RecursiveProofLayout {
280    /// Build a layout by folding per-step statement roots into a Merkle accumulator.
281    pub fn from_step_roots(
282        env: &Env,
283        initial_state_root: BytesN<32>,
284        final_state_root: BytesN<32>,
285        step_roots: &[BytesN<32>],
286    ) -> Result<Self, ZKError> {
287        let accumulator_root = compose_statement_roots(env, step_roots)?;
288        Ok(Self {
289            initial_state_root,
290            final_state_root,
291            accumulator_root,
292            proof_count: step_roots.len() as u32,
293        })
294    }
295}
296
297/// Fold per-step statement roots into a deterministic Merkle accumulator root.
298pub fn compose_statement_roots(
299    env: &Env,
300    step_roots: &[BytesN<32>],
301) -> Result<BytesN<32>, ZKError> {
302    if step_roots.is_empty() {
303        return Err(ZKError::InvalidProofComposition);
304    }
305
306    let mut leaves = Vec::with_capacity(step_roots.len());
307    for root in step_roots {
308        leaves.push(root.to_array());
309    }
310
311    let tree = MerkleTree::from_leaves(env, &leaves)?;
312    Ok(tree.root_bytes(env))
313}
314
315/// Experimental circuit contract for recursive proof aggregation.
316pub struct RecursiveProofCircuit {
317    pub vk: VerificationKey,
318    pub max_proof_count: u32,
319}
320
321impl GameCircuit for RecursiveProofCircuit {
322    fn verification_key(&self) -> &VerificationKey {
323        &self.vk
324    }
325}
326
327impl RecursiveProofCircuit {
328    pub fn new(vk: VerificationKey, max_proof_count: u32) -> Self {
329        Self {
330            vk,
331            max_proof_count,
332        }
333    }
334
335    /// Verify an aggregated recursive-proof layout.
336    ///
337    /// Public inputs:
338    /// `[initial_state_root, final_state_root, accumulator_root, proof_count]`.
339    pub fn verify_composition(
340        &self,
341        env: &Env,
342        proof: &Groth16Proof,
343        layout: &RecursiveProofLayout,
344    ) -> Result<bool, ZKError> {
345        if layout.proof_count == 0 || layout.proof_count > self.max_proof_count {
346            return Err(ZKError::InvalidProofComposition);
347        }
348
349        let public_inputs: [Scalar; 4] = [
350            bytes32_to_scalar(&layout.initial_state_root),
351            bytes32_to_scalar(&layout.final_state_root),
352            bytes32_to_scalar(&layout.accumulator_root),
353            u32_to_scalar(env, layout.proof_count),
354        ];
355
356        self.verify_with_inputs(env, proof, &public_inputs)
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zk::types::{G1Point, G2Point};

    /// 32-byte value with every byte equal to `fill`, used as a stand-in root or id.
    fn bytes32(env: &Env, fill: u8) -> BytesN<32> {
        BytesN::from_array(env, &[fill; 32])
    }

    /// All-zero G1 placeholder; these tests only exercise pre-verification checks.
    fn zero_g1(env: &Env) -> G1Point {
        G1Point {
            bytes: BytesN::from_array(env, &[0u8; 64]),
        }
    }

    /// All-zero G2 placeholder; these tests only exercise pre-verification checks.
    fn zero_g2(env: &Env) -> G2Point {
        G2Point {
            bytes: BytesN::from_array(env, &[0u8; 128]),
        }
    }

    /// Placeholder verification key with `ic_count` IC entries.
    fn make_vk(env: &Env, ic_count: u32) -> VerificationKey {
        let mut ic = soroban_sdk::Vec::new(env);
        for _ in 0..ic_count {
            ic.push_back(zero_g1(env));
        }

        VerificationKey {
            alpha: zero_g1(env),
            beta: zero_g2(env),
            gamma: zero_g2(env),
            delta: zero_g2(env),
            ic,
        }
    }

    /// Placeholder proof built from zero points.
    fn make_proof(env: &Env) -> Groth16Proof {
        Groth16Proof {
            a: zero_g1(env),
            b: zero_g2(env),
            c: zero_g1(env),
        }
    }

    /// Fog-of-war snapshot at the origin with the given visibility radius.
    fn fog_snapshot(env: &Env, visibility_radius: u32) -> FogOfWarSnapshot {
        FogOfWarSnapshot {
            map_root: bytes32(env, 1),
            explored_root: bytes32(env, 2),
            origin_x: 0,
            origin_y: 0,
            visibility_radius,
        }
    }

    /// Exploration of `(tile_x, tile_y)` that builds on `snapshot`'s explored root.
    fn fog_step(
        env: &Env,
        snapshot: &FogOfWarSnapshot,
        tile_x: i32,
        tile_y: i32,
    ) -> FogOfWarTransition {
        FogOfWarTransition {
            prior_explored_root: snapshot.explored_root.clone(),
            next_explored_root: bytes32(env, 3),
            tile_x,
            tile_y,
        }
    }

    #[test]
    fn test_apply_fog_of_war_transition_rejects_hidden_tile() {
        let env = Env::default();
        let snapshot = fog_snapshot(&env, 2);
        // (3, 0) is at distance 3 from the origin, beyond radius 2.
        let transition = fog_step(&env, &snapshot, 3, 0);

        let result = apply_fog_of_war_transition(&snapshot, &transition);
        assert!(matches!(result, Err(ZKError::InvalidVisibility)));
    }

    #[test]
    fn test_fog_of_war_circuit_rejects_snapshot_above_max_radius() {
        let env = Env::default();
        let circuit = FogOfWarCircuit::new(make_vk(&env, 9), 3);
        // Radius 4 exceeds the circuit's cap of 3, even though the tile is close.
        let snapshot = fog_snapshot(&env, 4);
        let transition = fog_step(&env, &snapshot, 1, 1);

        let result = circuit.verify_exploration(&env, &make_proof(&env), &snapshot, &transition);
        assert_eq!(result, Err(ZKError::InvalidVisibility));
    }

    #[test]
    fn test_open_state_channel_requires_deadline() {
        let env = Env::default();
        let result = open_state_channel(bytes32(&env, 1), bytes32(&env, 2), bytes32(&env, 3), 0);

        assert!(matches!(result, Err(ZKError::InvalidInput)));
    }

    #[test]
    fn test_apply_state_channel_transition_rejects_wrong_round() {
        let env = Env::default();
        let channel =
            open_state_channel(bytes32(&env, 1), bytes32(&env, 2), bytes32(&env, 3), 10).unwrap();
        // A fresh channel is at round 0, so only round 1 would be accepted.
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: bytes32(&env, 4),
            round: 2,
            submitted_at: 5,
        };

        let result = apply_state_channel_transition(&channel, &transition);
        assert!(matches!(result, Err(ZKError::InvalidStateTransition)));
    }

    #[test]
    fn test_state_channel_circuit_rejects_closed_channel() {
        let env = Env::default();
        let channel = ZkStateChannel {
            channel_id: bytes32(&env, 1),
            participants_root: bytes32(&env, 2),
            state_root: bytes32(&env, 3),
            round: 1,
            dispute_deadline: 5,
            closed: true,
        };
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: bytes32(&env, 4),
            round: 2,
            submitted_at: 5,
        };
        let circuit = StateChannelCircuit::new(make_vk(&env, 7));

        let result = circuit.verify_transition(&env, &make_proof(&env), &channel, &transition);
        assert_eq!(result, Err(ZKError::ChannelClosed));
    }

    #[test]
    fn test_compose_statement_roots_is_deterministic() {
        let env = Env::default();
        let steps = [bytes32(&env, 1), bytes32(&env, 2), bytes32(&env, 3)];

        let first = compose_statement_roots(&env, &steps).unwrap();
        let second = compose_statement_roots(&env, &steps).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_recursive_proof_layout_requires_non_empty_steps() {
        let env = Env::default();
        let result =
            RecursiveProofLayout::from_step_roots(&env, bytes32(&env, 1), bytes32(&env, 2), &[]);

        assert!(matches!(result, Err(ZKError::InvalidProofComposition)));
    }

    #[test]
    fn test_recursive_proof_circuit_rejects_out_of_bounds_proof_count() {
        let env = Env::default();
        let circuit = RecursiveProofCircuit::new(make_vk(&env, 5), 2);
        // Three folded proofs exceed the circuit's cap of two.
        let layout = RecursiveProofLayout {
            initial_state_root: bytes32(&env, 1),
            final_state_root: bytes32(&env, 2),
            accumulator_root: bytes32(&env, 3),
            proof_count: 3,
        };

        let result = circuit.verify_composition(&env, &make_proof(&env), &layout);
        assert_eq!(result, Err(ZKError::InvalidProofComposition));
    }
}