1use alloc::vec::Vec;
14use soroban_sdk::{contracttype, BytesN, Env};
15
16use super::error::ZKError;
17use super::merkle::tree::MerkleTree;
18use super::traits::{bytes32_to_scalar, i32_to_scalar, u32_to_scalar, u64_to_scalar, GameCircuit};
19use super::types::{Groth16Proof, Scalar, VerificationKey};
20
/// Committed fog-of-war state against which tile reveals are validated.
///
/// A tile may be revealed only if it lies inside or on the circle centred on
/// (`origin_x`, `origin_y`) with radius `visibility_radius`; see
/// [`FogOfWarSnapshot::can_reveal`].
#[contracttype]
#[derive(Clone, Debug)]
pub struct FogOfWarSnapshot {
    /// Commitment to the map, forwarded to the circuit as a public input.
    /// NOTE(review): presumably a Merkle root over the map tiles — confirm against the circuit.
    pub map_root: BytesN<32>,
    /// Root of the explored-tile state; a transition must name this value as its prior root.
    pub explored_root: BytesN<32>,
    /// X coordinate of the centre of the visibility circle.
    pub origin_x: i32,
    /// Y coordinate of the centre of the visibility circle.
    pub origin_y: i32,
    /// Inclusive Euclidean visibility radius, in the same units as the coordinates.
    pub visibility_radius: u32,
}
35
36impl FogOfWarSnapshot {
37 pub fn can_reveal(&self, tile_x: i32, tile_y: i32) -> bool {
39 let dx = i64::from(tile_x) - i64::from(self.origin_x);
40 let dy = i64::from(tile_y) - i64::from(self.origin_y);
41 let radius = i64::from(self.visibility_radius);
42
43 dx * dx + dy * dy <= radius * radius
44 }
45}
46
/// A request to reveal one tile and advance the explored-tile root.
#[contracttype]
#[derive(Clone, Debug)]
pub struct FogOfWarTransition {
    /// Must equal the snapshot's current `explored_root`.
    pub prior_explored_root: BytesN<32>,
    /// Becomes the snapshot's `explored_root` once the transition is accepted.
    pub next_explored_root: BytesN<32>,
    /// X coordinate of the tile being revealed.
    pub tile_x: i32,
    /// Y coordinate of the tile being revealed.
    pub tile_y: i32,
}
56
57pub fn apply_fog_of_war_transition(
59 snapshot: &FogOfWarSnapshot,
60 transition: &FogOfWarTransition,
61) -> Result<FogOfWarSnapshot, ZKError> {
62 if snapshot.explored_root != transition.prior_explored_root {
63 return Err(ZKError::InvalidStateTransition);
64 }
65
66 if !snapshot.can_reveal(transition.tile_x, transition.tile_y) {
67 return Err(ZKError::InvalidVisibility);
68 }
69
70 let mut updated = snapshot.clone();
71 updated.explored_root = transition.next_explored_root.clone();
72 Ok(updated)
73}
74
/// Groth16 verifier for fog-of-war exploration proofs.
pub struct FogOfWarCircuit {
    /// Verification key that exploration proofs are checked against.
    pub vk: VerificationKey,
    /// Largest `FogOfWarSnapshot::visibility_radius` that `verify_exploration` accepts.
    pub max_visibility_radius: u32,
}
80
81impl GameCircuit for FogOfWarCircuit {
82 fn verification_key(&self) -> &VerificationKey {
83 &self.vk
84 }
85}
86
87impl FogOfWarCircuit {
88 pub fn new(vk: VerificationKey, max_visibility_radius: u32) -> Self {
89 Self {
90 vk,
91 max_visibility_radius,
92 }
93 }
94
95 pub fn verify_exploration(
100 &self,
101 env: &Env,
102 proof: &Groth16Proof,
103 snapshot: &FogOfWarSnapshot,
104 transition: &FogOfWarTransition,
105 ) -> Result<bool, ZKError> {
106 if snapshot.visibility_radius > self.max_visibility_radius {
107 return Err(ZKError::InvalidVisibility);
108 }
109
110 let _ = apply_fog_of_war_transition(snapshot, transition)?;
111
112 let public_inputs = Vec::from([
113 bytes32_to_scalar(&snapshot.map_root),
114 bytes32_to_scalar(&transition.prior_explored_root),
115 bytes32_to_scalar(&transition.next_explored_root),
116 i32_to_scalar(env, snapshot.origin_x),
117 i32_to_scalar(env, snapshot.origin_y),
118 i32_to_scalar(env, transition.tile_x),
119 i32_to_scalar(env, transition.tile_y),
120 u32_to_scalar(env, snapshot.visibility_radius),
121 ]);
122
123 self.verify_with_inputs(env, proof, &public_inputs)
124 }
125}
126
/// On-chain record of a ZK state channel.
#[contracttype]
#[derive(Clone, Debug)]
pub struct ZkStateChannel {
    /// Channel identifier, forwarded to the circuit as a public input.
    pub channel_id: BytesN<32>,
    /// Forwarded to the circuit as a public input.
    /// NOTE(review): presumably a commitment to the participant set — confirm.
    pub participants_root: BytesN<32>,
    /// Current state root; the next transition must name it as its prior root.
    pub state_root: BytesN<32>,
    /// Round counter: starts at 0 when opened and advances by exactly one per transition.
    pub round: u64,
    /// Transitions with `submitted_at` greater than this are rejected; overwritten with
    /// `closed_at` on close. Units are not established here (likely a ledger timestamp — confirm).
    pub dispute_deadline: u64,
    /// Set once the channel is closed; closed channels reject transitions and re-closing.
    pub closed: bool,
}
138
/// A proposed single-round update to a [`ZkStateChannel`].
#[contracttype]
#[derive(Clone, Debug)]
pub struct StateChannelTransition {
    /// Must equal the channel's current `state_root`.
    pub prior_state_root: BytesN<32>,
    /// Becomes the channel's `state_root` once the transition is accepted.
    pub next_state_root: BytesN<32>,
    /// Must equal the channel's current `round` plus one.
    pub round: u64,
    /// Submission time, compared against the channel's `dispute_deadline` (inclusive).
    pub submitted_at: u64,
}
148
149pub fn open_state_channel(
151 channel_id: BytesN<32>,
152 participants_root: BytesN<32>,
153 initial_state_root: BytesN<32>,
154 dispute_deadline: u64,
155) -> Result<ZkStateChannel, ZKError> {
156 if dispute_deadline == 0 {
157 return Err(ZKError::InvalidInput);
158 }
159
160 Ok(ZkStateChannel {
161 channel_id,
162 participants_root,
163 state_root: initial_state_root,
164 round: 0,
165 dispute_deadline,
166 closed: false,
167 })
168}
169
170pub fn apply_state_channel_transition(
172 channel: &ZkStateChannel,
173 transition: &StateChannelTransition,
174) -> Result<ZkStateChannel, ZKError> {
175 if channel.closed {
176 return Err(ZKError::ChannelClosed);
177 }
178
179 if transition.prior_state_root != channel.state_root {
180 return Err(ZKError::InvalidStateTransition);
181 }
182
183 let expected_round = channel
184 .round
185 .checked_add(1)
186 .ok_or(ZKError::InvalidStateTransition)?;
187 if transition.round != expected_round {
188 return Err(ZKError::InvalidStateTransition);
189 }
190
191 if transition.submitted_at > channel.dispute_deadline {
192 return Err(ZKError::DeadlineExpired);
193 }
194
195 let mut updated = channel.clone();
196 updated.state_root = transition.next_state_root.clone();
197 updated.round = transition.round;
198 Ok(updated)
199}
200
201pub fn close_state_channel(
203 channel: &ZkStateChannel,
204 final_state_root: &BytesN<32>,
205 final_round: u64,
206 closed_at: u64,
207) -> Result<ZkStateChannel, ZKError> {
208 if channel.closed {
209 return Err(ZKError::ChannelClosed);
210 }
211
212 if final_round < channel.round {
213 return Err(ZKError::InvalidStateTransition);
214 }
215
216 let mut closed = channel.clone();
217 closed.state_root = final_state_root.clone();
218 closed.round = final_round;
219 closed.dispute_deadline = closed_at;
220 closed.closed = true;
221 Ok(closed)
222}
223
/// Groth16 verifier for state-channel transition proofs.
pub struct StateChannelCircuit {
    /// Verification key that transition proofs are checked against.
    pub vk: VerificationKey,
}
228
229impl GameCircuit for StateChannelCircuit {
230 fn verification_key(&self) -> &VerificationKey {
231 &self.vk
232 }
233}
234
235impl StateChannelCircuit {
236 pub fn new(vk: VerificationKey) -> Self {
237 Self { vk }
238 }
239
240 pub fn verify_transition(
245 &self,
246 env: &Env,
247 proof: &Groth16Proof,
248 channel: &ZkStateChannel,
249 transition: &StateChannelTransition,
250 ) -> Result<bool, ZKError> {
251 if channel.closed {
252 return Err(ZKError::ChannelClosed);
253 }
254
255 let _ = apply_state_channel_transition(channel, transition)?;
256 let public_inputs = Vec::from([
257 bytes32_to_scalar(&channel.channel_id),
258 bytes32_to_scalar(&channel.participants_root),
259 bytes32_to_scalar(&transition.prior_state_root),
260 bytes32_to_scalar(&transition.next_state_root),
261 u64_to_scalar(env, transition.round),
262 u64_to_scalar(env, transition.submitted_at),
263 ]);
264
265 self.verify_with_inputs(env, proof, &public_inputs)
266 }
267}
268
/// Public statement for a recursive proof that composes several step proofs.
#[contracttype]
#[derive(Clone, Debug)]
pub struct RecursiveProofLayout {
    /// State root before the first composed step; a public input.
    pub initial_state_root: BytesN<32>,
    /// State root after the last composed step; a public input.
    pub final_state_root: BytesN<32>,
    /// Root produced by `compose_statement_roots` over the step roots (see `from_step_roots`).
    pub accumulator_root: BytesN<32>,
    /// Number of composed steps; `verify_composition` requires it to be in `1..=max_proof_count`.
    pub proof_count: u32,
}
278
279impl RecursiveProofLayout {
280 pub fn from_step_roots(
282 env: &Env,
283 initial_state_root: BytesN<32>,
284 final_state_root: BytesN<32>,
285 step_roots: &[BytesN<32>],
286 ) -> Result<Self, ZKError> {
287 let accumulator_root = compose_statement_roots(env, step_roots)?;
288 Ok(Self {
289 initial_state_root,
290 final_state_root,
291 accumulator_root,
292 proof_count: step_roots.len() as u32,
293 })
294 }
295}
296
297pub fn compose_statement_roots(
299 env: &Env,
300 step_roots: &[BytesN<32>],
301) -> Result<BytesN<32>, ZKError> {
302 if step_roots.is_empty() {
303 return Err(ZKError::InvalidProofComposition);
304 }
305
306 let mut leaves = Vec::with_capacity(step_roots.len());
307 for root in step_roots {
308 leaves.push(root.to_array());
309 }
310
311 let tree = MerkleTree::from_leaves(env, &leaves)?;
312 Ok(tree.root_bytes(env))
313}
314
/// Groth16 verifier for recursive (composed) proofs.
pub struct RecursiveProofCircuit {
    /// Verification key that composition proofs are checked against.
    pub vk: VerificationKey,
    /// Largest `RecursiveProofLayout::proof_count` that `verify_composition` accepts.
    pub max_proof_count: u32,
}
320
321impl GameCircuit for RecursiveProofCircuit {
322 fn verification_key(&self) -> &VerificationKey {
323 &self.vk
324 }
325}
326
327impl RecursiveProofCircuit {
328 pub fn new(vk: VerificationKey, max_proof_count: u32) -> Self {
329 Self {
330 vk,
331 max_proof_count,
332 }
333 }
334
335 pub fn verify_composition(
340 &self,
341 env: &Env,
342 proof: &Groth16Proof,
343 layout: &RecursiveProofLayout,
344 ) -> Result<bool, ZKError> {
345 if layout.proof_count == 0 || layout.proof_count > self.max_proof_count {
346 return Err(ZKError::InvalidProofComposition);
347 }
348
349 let public_inputs: [Scalar; 4] = [
350 bytes32_to_scalar(&layout.initial_state_root),
351 bytes32_to_scalar(&layout.final_state_root),
352 bytes32_to_scalar(&layout.accumulator_root),
353 u32_to_scalar(env, layout.proof_count),
354 ];
355
356 self.verify_with_inputs(env, proof, &public_inputs)
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zk::types::{G1Point, G2Point};

    /// A 32-byte root with every byte set to `byte`.
    fn root(env: &Env, byte: u8) -> BytesN<32> {
        BytesN::from_array(env, &[byte; 32])
    }

    /// All-zero G1 point; adequate for tests that fail before any pairing check.
    fn zero_g1(env: &Env) -> G1Point {
        G1Point {
            bytes: BytesN::from_array(env, &[0u8; 64]),
        }
    }

    /// All-zero G2 point; adequate for tests that fail before any pairing check.
    fn zero_g2(env: &Env) -> G2Point {
        G2Point {
            bytes: BytesN::from_array(env, &[0u8; 128]),
        }
    }

    /// Verification key of zero points with `ic_count` IC entries
    /// (callers pass public-input count + 1).
    fn make_vk(env: &Env, ic_count: u32) -> VerificationKey {
        let mut ic = soroban_sdk::Vec::new(env);
        for _ in 0..ic_count {
            ic.push_back(zero_g1(env));
        }

        VerificationKey {
            alpha: zero_g1(env),
            beta: zero_g2(env),
            gamma: zero_g2(env),
            delta: zero_g2(env),
            ic,
        }
    }

    /// Proof made of zero points; only used where validation fails first.
    fn make_proof(env: &Env) -> Groth16Proof {
        Groth16Proof {
            a: zero_g1(env),
            b: zero_g2(env),
            c: zero_g1(env),
        }
    }

    /// Snapshot centred on the origin with the given radius.
    fn make_snapshot(env: &Env, visibility_radius: u32) -> FogOfWarSnapshot {
        FogOfWarSnapshot {
            map_root: root(env, 1),
            explored_root: root(env, 2),
            origin_x: 0,
            origin_y: 0,
            visibility_radius,
        }
    }

    /// Transition revealing (`tile_x`, `tile_y`) on top of `snapshot`'s explored root.
    fn make_reveal(
        env: &Env,
        snapshot: &FogOfWarSnapshot,
        tile_x: i32,
        tile_y: i32,
    ) -> FogOfWarTransition {
        FogOfWarTransition {
            prior_explored_root: snapshot.explored_root.clone(),
            next_explored_root: root(env, 3),
            tile_x,
            tile_y,
        }
    }

    /// Freshly opened channel with a dispute deadline of 10.
    fn open_channel(env: &Env) -> ZkStateChannel {
        open_state_channel(root(env, 1), root(env, 2), root(env, 3), 10).unwrap()
    }

    #[test]
    fn test_can_reveal_includes_radius_boundary() {
        let env = Env::default();
        let snapshot = make_snapshot(&env, 2);

        assert!(snapshot.can_reveal(2, 0));
        assert!(snapshot.can_reveal(0, -2));
        assert!(snapshot.can_reveal(-1, -1));
        assert!(!snapshot.can_reveal(2, 1));
    }

    #[test]
    fn test_apply_fog_of_war_transition_rejects_hidden_tile() {
        let env = Env::default();
        let snapshot = make_snapshot(&env, 2);
        let transition = make_reveal(&env, &snapshot, 3, 0);

        let result = apply_fog_of_war_transition(&snapshot, &transition);
        assert!(matches!(result, Err(ZKError::InvalidVisibility)));
    }

    #[test]
    fn test_apply_fog_of_war_transition_rejects_stale_root() {
        let env = Env::default();
        let snapshot = make_snapshot(&env, 2);
        let mut transition = make_reveal(&env, &snapshot, 1, 0);
        transition.prior_explored_root = root(&env, 9);

        let result = apply_fog_of_war_transition(&snapshot, &transition);
        assert!(matches!(result, Err(ZKError::InvalidStateTransition)));
    }

    #[test]
    fn test_apply_fog_of_war_transition_updates_explored_root() {
        let env = Env::default();
        let snapshot = make_snapshot(&env, 2);
        let transition = make_reveal(&env, &snapshot, 1, 1);

        let updated = apply_fog_of_war_transition(&snapshot, &transition).unwrap();
        assert_eq!(updated.explored_root, root(&env, 3));
        assert_eq!(updated.map_root, snapshot.map_root);
        assert_eq!(updated.visibility_radius, snapshot.visibility_radius);
    }

    #[test]
    fn test_fog_of_war_circuit_rejects_snapshot_above_max_radius() {
        let env = Env::default();
        let circuit = FogOfWarCircuit::new(make_vk(&env, 9), 3);
        let snapshot = make_snapshot(&env, 4);
        let transition = make_reveal(&env, &snapshot, 1, 1);

        let result = circuit.verify_exploration(&env, &make_proof(&env), &snapshot, &transition);
        assert_eq!(result, Err(ZKError::InvalidVisibility));
    }

    #[test]
    fn test_open_state_channel_requires_deadline() {
        let env = Env::default();
        let result = open_state_channel(root(&env, 1), root(&env, 2), root(&env, 3), 0);

        assert!(matches!(result, Err(ZKError::InvalidInput)));
    }

    #[test]
    fn test_apply_state_channel_transition_rejects_wrong_round() {
        let env = Env::default();
        let channel = open_channel(&env);
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: root(&env, 4),
            round: 2,
            submitted_at: 5,
        };

        let result = apply_state_channel_transition(&channel, &transition);
        assert!(matches!(result, Err(ZKError::InvalidStateTransition)));
    }

    #[test]
    fn test_apply_state_channel_transition_accepts_submission_at_deadline() {
        let env = Env::default();
        let channel = open_channel(&env);
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: root(&env, 4),
            round: 1,
            submitted_at: 10,
        };

        let updated = apply_state_channel_transition(&channel, &transition).unwrap();
        assert_eq!(updated.round, 1);
        assert_eq!(updated.state_root, root(&env, 4));
        assert!(!updated.closed);
    }

    #[test]
    fn test_apply_state_channel_transition_rejects_late_submission() {
        let env = Env::default();
        let channel = open_channel(&env);
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: root(&env, 4),
            round: 1,
            submitted_at: 11,
        };

        let result = apply_state_channel_transition(&channel, &transition);
        assert!(matches!(result, Err(ZKError::DeadlineExpired)));
    }

    #[test]
    fn test_close_state_channel_rejects_stale_round() {
        let env = Env::default();
        let mut channel = open_channel(&env);
        channel.round = 3;

        let result = close_state_channel(&channel, &root(&env, 5), 2, 4);
        assert!(matches!(result, Err(ZKError::InvalidStateTransition)));
    }

    #[test]
    fn test_state_channel_circuit_rejects_closed_channel() {
        let env = Env::default();
        let channel = ZkStateChannel {
            channel_id: root(&env, 1),
            participants_root: root(&env, 2),
            state_root: root(&env, 3),
            round: 1,
            dispute_deadline: 5,
            closed: true,
        };
        let transition = StateChannelTransition {
            prior_state_root: channel.state_root.clone(),
            next_state_root: root(&env, 4),
            round: 2,
            submitted_at: 5,
        };
        let circuit = StateChannelCircuit::new(make_vk(&env, 7));

        let result = circuit.verify_transition(&env, &make_proof(&env), &channel, &transition);
        assert_eq!(result, Err(ZKError::ChannelClosed));
    }

    #[test]
    fn test_compose_statement_roots_is_deterministic() {
        let env = Env::default();
        let steps = [root(&env, 1), root(&env, 2), root(&env, 3)];

        let root_a = compose_statement_roots(&env, &steps).unwrap();
        let root_b = compose_statement_roots(&env, &steps).unwrap();
        assert_eq!(root_a, root_b);
    }

    #[test]
    fn test_recursive_proof_layout_requires_non_empty_steps() {
        let env = Env::default();
        let result =
            RecursiveProofLayout::from_step_roots(&env, root(&env, 1), root(&env, 2), &[]);

        assert!(matches!(result, Err(ZKError::InvalidProofComposition)));
    }

    #[test]
    fn test_recursive_proof_layout_counts_steps() {
        let env = Env::default();
        let steps = [root(&env, 4), root(&env, 5), root(&env, 6)];

        let layout =
            RecursiveProofLayout::from_step_roots(&env, root(&env, 1), root(&env, 2), &steps)
                .unwrap();
        assert_eq!(layout.proof_count, 3);
        assert_eq!(
            layout.accumulator_root,
            compose_statement_roots(&env, &steps).unwrap()
        );
    }

    #[test]
    fn test_recursive_proof_circuit_rejects_out_of_bounds_proof_count() {
        let env = Env::default();
        let circuit = RecursiveProofCircuit::new(make_vk(&env, 5), 2);
        let mut layout = RecursiveProofLayout {
            initial_state_root: root(&env, 1),
            final_state_root: root(&env, 2),
            accumulator_root: root(&env, 3),
            proof_count: 3,
        };

        let result = circuit.verify_composition(&env, &make_proof(&env), &layout);
        assert_eq!(result, Err(ZKError::InvalidProofComposition));

        layout.proof_count = 0;
        let result = circuit.verify_composition(&env, &make_proof(&env), &layout);
        assert_eq!(result, Err(ZKError::InvalidProofComposition));
    }
}