1use crate::types::TaskId;
4use alloy_primitives::U256;
5use blueprint_crypto_bn254::{ArkBlsBn254Public, ArkBlsBn254Signature};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
/// Rule deciding when a task has gathered enough signatures to be aggregated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThresholdType {
    /// Satisfied once at least this many distinct operators have signed.
    Count(u32),
    /// Satisfied once the signers' combined stake reaches this share of the total stake,
    /// expressed in basis points (10_000 = 100%).
    StakeWeighted(u32),
}
19
20impl Default for ThresholdType {
21 fn default() -> Self {
22 ThresholdType::Count(1)
23 }
24}
25
/// Per-operator metadata: its stake and whether it is registered.
#[derive(Debug, Clone)]
pub struct OperatorInfo {
    /// Stake held by the operator.
    pub stake: u64,
    /// Whether the operator is currently registered.
    pub registered: bool,
}
34
35impl Default for OperatorInfo {
36 fn default() -> Self {
37 Self {
38 stake: 1, registered: true,
40 }
41 }
42}
43
44#[derive(Debug)]
46pub struct TaskState {
47 pub service_id: u64,
49 pub call_id: u64,
51 pub output: Vec<u8>,
53 pub operator_count: u32,
55 pub threshold_type: ThresholdType,
57 pub signer_bitmap: U256,
59 pub signatures: HashMap<u32, ArkBlsBn254Signature>,
61 pub public_keys: HashMap<u32, ArkBlsBn254Public>,
63 pub operator_stakes: HashMap<u32, u64>,
65 pub total_stake: u64,
67 pub submitted: bool,
69 pub created_at: Instant,
71 pub expires_at: Option<Instant>,
73}
74
75impl TaskState {
76 pub fn new(
78 service_id: u64,
79 call_id: u64,
80 output: Vec<u8>,
81 operator_count: u32,
82 threshold: u32,
83 ) -> Self {
84 Self::with_config(
85 service_id,
86 call_id,
87 output,
88 operator_count,
89 ThresholdType::Count(threshold),
90 None,
91 None,
92 )
93 }
94
95 pub fn with_config(
97 service_id: u64,
98 call_id: u64,
99 output: Vec<u8>,
100 operator_count: u32,
101 threshold_type: ThresholdType,
102 operator_stakes: Option<HashMap<u32, u64>>,
103 ttl: Option<Duration>,
104 ) -> Self {
105 let now = Instant::now();
106 let expires_at = ttl.map(|d| now + d);
107
108 let (stakes, total_stake) = if let Some(stakes) = operator_stakes {
110 let total: u64 = stakes.values().sum();
111 (stakes, total)
112 } else {
113 let stakes: HashMap<u32, u64> = (0..operator_count).map(|i| (i, 1u64)).collect();
115 let total = operator_count as u64;
116 (stakes, total)
117 };
118
119 Self {
120 service_id,
121 call_id,
122 output,
123 operator_count,
124 threshold_type,
125 signer_bitmap: U256::ZERO,
126 signatures: HashMap::new(),
127 public_keys: HashMap::new(),
128 operator_stakes: stakes,
129 total_stake,
130 submitted: false,
131 created_at: now,
132 expires_at,
133 }
134 }
135
136 pub fn is_expired(&self) -> bool {
138 self.expires_at.map(|t| Instant::now() > t).unwrap_or(false)
139 }
140
141 pub fn time_remaining(&self) -> Option<Duration> {
143 self.expires_at.and_then(|t| {
144 let now = Instant::now();
145 if now < t {
146 Some(t - now)
147 } else {
148 None
149 }
150 })
151 }
152
153 pub fn add_signature(
155 &mut self,
156 operator_index: u32,
157 signature: ArkBlsBn254Signature,
158 public_key: ArkBlsBn254Public,
159 ) -> Result<(), &'static str> {
160 if operator_index >= self.operator_count {
161 return Err("Operator index out of bounds");
162 }
163
164 if self.has_signed(operator_index) {
165 return Err("Operator already signed");
166 }
167
168 if self.is_expired() {
169 return Err("Task has expired");
170 }
171
172 self.signer_bitmap |= U256::from(1u64) << operator_index as usize;
174
175 self.signatures.insert(operator_index, signature);
177 self.public_keys.insert(operator_index, public_key);
178
179 Ok(())
180 }
181
182 pub fn has_signed(&self, operator_index: u32) -> bool {
184 (self.signer_bitmap >> operator_index as usize) & U256::from(1u64) == U256::from(1u64)
185 }
186
187 pub fn signature_count(&self) -> usize {
189 self.signatures.len()
190 }
191
192 pub fn signed_stake(&self) -> u64 {
194 self.signatures
195 .keys()
196 .map(|idx| self.operator_stakes.get(idx).copied().unwrap_or(0))
197 .sum()
198 }
199
200 pub fn signed_stake_bps(&self) -> u32 {
202 if self.total_stake == 0 {
203 return 0;
204 }
205 ((self.signed_stake() * 10000) / self.total_stake) as u32
206 }
207
208 pub fn threshold_met(&self) -> bool {
210 match self.threshold_type {
211 ThresholdType::Count(n) => self.signature_count() >= n as usize,
212 ThresholdType::StakeWeighted(bps) => self.signed_stake_bps() >= bps,
213 }
214 }
215
216 pub fn threshold_value(&self) -> usize {
218 match self.threshold_type {
219 ThresholdType::Count(n) => n as usize,
220 ThresholdType::StakeWeighted(bps) => bps as usize,
221 }
222 }
223
224 pub fn get_non_signers(&self) -> Vec<u32> {
226 (0..self.operator_count)
227 .filter(|&i| !self.has_signed(i))
228 .collect()
229 }
230
231 pub fn get_signers(&self) -> Vec<u32> {
233 let mut signers: Vec<_> = self.signatures.keys().copied().collect();
234 signers.sort_unstable();
235 signers
236 }
237
238 pub fn get_signatures_for_aggregation(
240 &self,
241 ) -> (Vec<ArkBlsBn254Signature>, Vec<ArkBlsBn254Public>) {
242 let mut sigs = Vec::with_capacity(self.signatures.len());
243 let mut pks = Vec::with_capacity(self.public_keys.len());
244
245 let indices = self.get_signers();
247
248 for idx in indices {
249 if let (Some(sig), Some(pk)) = (self.signatures.get(&idx), self.public_keys.get(&idx)) {
250 sigs.push(sig.clone());
251 pks.push(pk.clone());
252 }
253 }
254
255 (sigs, pks)
256 }
257}
258
259#[derive(Debug, Clone)]
261pub struct TaskConfig {
262 pub threshold_type: ThresholdType,
264 pub operator_stakes: Option<HashMap<u32, u64>>,
266 pub ttl: Option<Duration>,
268}
269
270impl Default for TaskConfig {
271 fn default() -> Self {
272 Self {
273 threshold_type: ThresholdType::Count(1),
274 operator_stakes: None,
275 ttl: None,
276 }
277 }
278}
279
280#[derive(Debug, Clone)]
282pub struct AggregationState {
283 tasks: Arc<RwLock<HashMap<TaskId, TaskState>>>,
285}
286
287impl Default for AggregationState {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293impl AggregationState {
294 pub fn new() -> Self {
296 Self {
297 tasks: Arc::new(RwLock::new(HashMap::new())),
298 }
299 }
300
301 pub fn init_task(
303 &self,
304 service_id: u64,
305 call_id: u64,
306 output: Vec<u8>,
307 operator_count: u32,
308 threshold: u32,
309 ) -> Result<(), &'static str> {
310 self.init_task_with_config(
311 service_id,
312 call_id,
313 output,
314 operator_count,
315 TaskConfig {
316 threshold_type: ThresholdType::Count(threshold),
317 ..Default::default()
318 },
319 )
320 }
321
322 pub fn init_task_with_config(
324 &self,
325 service_id: u64,
326 call_id: u64,
327 output: Vec<u8>,
328 operator_count: u32,
329 config: TaskConfig,
330 ) -> Result<(), &'static str> {
331 let task_id = TaskId::new(service_id, call_id);
332 let mut tasks = self.tasks.write();
333
334 if tasks.contains_key(&task_id) {
335 return Err("Task already exists");
336 }
337
338 let state = TaskState::with_config(
339 service_id,
340 call_id,
341 output,
342 operator_count,
343 config.threshold_type,
344 config.operator_stakes,
345 config.ttl,
346 );
347 tasks.insert(task_id, state);
348 Ok(())
349 }
350
351 pub fn get_task_output(&self, service_id: u64, call_id: u64) -> Option<Vec<u8>> {
353 let task_id = TaskId::new(service_id, call_id);
354 let tasks = self.tasks.read();
355 tasks.get(&task_id).map(|t| t.output.clone())
356 }
357
358 pub fn submit_signature(
360 &self,
361 service_id: u64,
362 call_id: u64,
363 operator_index: u32,
364 signature: ArkBlsBn254Signature,
365 public_key: ArkBlsBn254Public,
366 ) -> Result<(usize, bool), &'static str> {
367 let task_id = TaskId::new(service_id, call_id);
368 let mut tasks = self.tasks.write();
369
370 let task = tasks.get_mut(&task_id).ok_or("Task not found")?;
371
372 if task.submitted {
373 return Err("Task already submitted to chain");
374 }
375
376 if task.is_expired() {
377 return Err("Task has expired");
378 }
379
380 task.add_signature(operator_index, signature, public_key)?;
381
382 Ok((task.signature_count(), task.threshold_met()))
383 }
384
385 pub fn get_status(&self, service_id: u64, call_id: u64) -> Option<TaskStatus> {
387 let task_id = TaskId::new(service_id, call_id);
388 let tasks = self.tasks.read();
389
390 tasks.get(&task_id).map(|task| TaskStatus {
391 signatures_collected: task.signature_count(),
392 threshold_required: task.threshold_value(),
393 threshold_type: task.threshold_type,
394 threshold_met: task.threshold_met(),
395 signer_bitmap: task.signer_bitmap,
396 signed_stake_bps: task.signed_stake_bps(),
397 submitted: task.submitted,
398 is_expired: task.is_expired(),
399 time_remaining_secs: task.time_remaining().map(|d| d.as_secs()),
400 })
401 }
402
403 pub fn get_for_aggregation(&self, service_id: u64, call_id: u64) -> Option<TaskForAggregation> {
405 let task_id = TaskId::new(service_id, call_id);
406 let tasks = self.tasks.read();
407
408 let task = tasks.get(&task_id)?;
409
410 if !task.threshold_met() || task.submitted || task.is_expired() {
411 return None;
412 }
413
414 let (signatures, public_keys) = task.get_signatures_for_aggregation();
415
416 Some(TaskForAggregation {
417 service_id: task.service_id,
418 call_id: task.call_id,
419 output: task.output.clone(),
420 signer_bitmap: task.signer_bitmap,
421 non_signer_indices: task.get_non_signers(),
422 signatures,
423 public_keys,
424 })
425 }
426
427 pub fn mark_submitted(&self, service_id: u64, call_id: u64) -> Result<(), &'static str> {
429 let task_id = TaskId::new(service_id, call_id);
430 let mut tasks = self.tasks.write();
431
432 let task = tasks.get_mut(&task_id).ok_or("Task not found")?;
433 task.submitted = true;
434 Ok(())
435 }
436
437 pub fn remove_task(&self, service_id: u64, call_id: u64) -> bool {
439 let task_id = TaskId::new(service_id, call_id);
440 self.tasks.write().remove(&task_id).is_some()
441 }
442
443 pub fn cleanup_expired(&self) -> usize {
446 let mut tasks = self.tasks.write();
447 let before = tasks.len();
448 tasks.retain(|_, task| !task.is_expired());
449 before - tasks.len()
450 }
451
452 pub fn cleanup_submitted(&self) -> usize {
455 let mut tasks = self.tasks.write();
456 let before = tasks.len();
457 tasks.retain(|_, task| !task.submitted);
458 before - tasks.len()
459 }
460
461 pub fn cleanup(&self) -> usize {
464 let mut tasks = self.tasks.write();
465 let before = tasks.len();
466 tasks.retain(|_, task| !task.is_expired() && !task.submitted);
467 before - tasks.len()
468 }
469
470 pub fn task_count(&self) -> usize {
472 self.tasks.read().len()
473 }
474
475 pub fn task_counts(&self) -> TaskCounts {
477 let tasks = self.tasks.read();
478 let mut counts = TaskCounts::default();
479
480 for task in tasks.values() {
481 counts.total += 1;
482 if task.is_expired() {
483 counts.expired += 1;
484 } else if task.submitted {
485 counts.submitted += 1;
486 } else if task.threshold_met() {
487 counts.ready += 1;
488 } else {
489 counts.pending += 1;
490 }
491 }
492
493 counts
494 }
495}
496
497#[derive(Debug, Clone)]
499pub struct TaskStatus {
500 pub signatures_collected: usize,
501 pub threshold_required: usize,
502 pub threshold_type: ThresholdType,
503 pub threshold_met: bool,
504 pub signer_bitmap: U256,
505 pub signed_stake_bps: u32,
506 pub submitted: bool,
507 pub is_expired: bool,
508 pub time_remaining_secs: Option<u64>,
509}
510
511#[derive(Debug)]
513pub struct TaskForAggregation {
514 pub service_id: u64,
515 pub call_id: u64,
516 pub output: Vec<u8>,
517 pub signer_bitmap: U256,
518 pub non_signer_indices: Vec<u32>,
519 pub signatures: Vec<ArkBlsBn254Signature>,
520 pub public_keys: Vec<ArkBlsBn254Public>,
521}
522
/// Number of tasks in each lifecycle bucket, as produced by `task_counts`.
#[derive(Debug, Clone, Default)]
pub struct TaskCounts {
    /// All registered tasks.
    pub total: usize,
    /// Tasks still below threshold.
    pub pending: usize,
    /// Tasks at or above threshold but not yet submitted.
    pub ready: usize,
    /// Tasks marked as submitted.
    pub submitted: usize,
    /// Tasks whose deadline has passed.
    pub expired: usize,
}
532
#[cfg(test)]
mod tests {
    use super::*;
    use ark_bn254::{G1Affine, G2Affine};
    use ark_ec::AffineRepr;

    /// Placeholder signature (the G1 generator); never verified cryptographically here.
    fn dummy_signature() -> ArkBlsBn254Signature {
        ArkBlsBn254Signature(G1Affine::generator())
    }

    /// Placeholder public key (the G2 generator).
    fn dummy_public_key() -> ArkBlsBn254Public {
        ArkBlsBn254Public(G2Affine::generator())
    }

    /// Adds a placeholder signature for `operator_index` directly on a `TaskState`.
    fn sign(state: &mut TaskState, operator_index: u32) -> Result<(), &'static str> {
        state.add_signature(operator_index, dummy_signature(), dummy_public_key())
    }

    /// Submits a placeholder signature for `operator_index` through `AggregationState`.
    fn submit(
        state: &AggregationState,
        service_id: u64,
        call_id: u64,
        operator_index: u32,
    ) -> Result<(usize, bool), &'static str> {
        state.submit_signature(
            service_id,
            call_id,
            operator_index,
            dummy_signature(),
            dummy_public_key(),
        )
    }

    #[test]
    fn test_task_state_new() {
        let state = TaskState::new(1, 100, vec![1, 2, 3], 5, 3);

        assert_eq!((state.service_id, state.call_id), (1, 100));
        assert_eq!(state.output, vec![1, 2, 3]);
        assert_eq!(state.operator_count, 5);
        assert_eq!(state.threshold_type, ThresholdType::Count(3));
        assert_eq!(state.signer_bitmap, U256::ZERO);
        assert!(state.signatures.is_empty());
        assert!(state.public_keys.is_empty());
        assert!(!state.submitted);
        assert!(!state.is_expired());
    }

    #[test]
    fn test_task_state_add_signature() {
        let mut state = TaskState::new(1, 100, vec![], 5, 3);

        assert!(sign(&mut state, 0).is_ok());
        assert!(state.has_signed(0));
        assert!(!state.has_signed(1));
        assert_eq!(state.signature_count(), 1);

        assert!(sign(&mut state, 2).is_ok());
        assert!(state.has_signed(2));
        assert_eq!(state.signature_count(), 2);
    }

    #[test]
    fn test_task_state_duplicate_signature() {
        let mut state = TaskState::new(1, 100, vec![], 5, 3);

        assert!(sign(&mut state, 0).is_ok());
        assert_eq!(sign(&mut state, 0), Err("Operator already signed"));
    }

    #[test]
    fn test_task_state_out_of_bounds() {
        let mut state = TaskState::new(1, 100, vec![], 5, 3);

        // Valid indices are 0..5, so 5 is the first invalid one.
        assert_eq!(sign(&mut state, 5), Err("Operator index out of bounds"));
    }

    #[test]
    fn test_task_state_threshold() {
        let mut state = TaskState::new(1, 100, vec![], 5, 3);
        assert!(!state.threshold_met());

        for (idx, expected) in [(0, false), (1, false), (2, true)] {
            sign(&mut state, idx).unwrap();
            assert_eq!(state.threshold_met(), expected, "after operator {} signed", idx);
        }
    }

    #[test]
    fn test_task_state_bitmap() {
        let mut state = TaskState::new(1, 100, vec![], 10, 3);

        for idx in [0, 3, 7] {
            sign(&mut state, idx).unwrap();
        }

        // Bits 0, 3 and 7 set: 1 + 8 + 128 = 137.
        assert_eq!(state.signer_bitmap, U256::from(0b1000_1001u64));
    }

    #[test]
    fn test_task_state_non_signers() {
        let mut state = TaskState::new(1, 100, vec![], 5, 3);

        for idx in [0, 2, 4] {
            sign(&mut state, idx).unwrap();
        }

        assert_eq!(state.get_non_signers(), vec![1, 3]);
        assert_eq!(state.get_signers(), vec![0, 2, 4]);
    }

    #[test]
    fn test_task_state_stake_weighted() {
        let stakes: HashMap<u32, u64> = [(0, 1000), (1, 2000), (2, 3000), (3, 4000)]
            .iter()
            .copied()
            .collect();
        let mut state = TaskState::with_config(
            1,
            100,
            vec![],
            4,
            ThresholdType::StakeWeighted(5000),
            Some(stakes),
            None,
        );

        assert_eq!(state.total_stake, 10_000);
        assert_eq!(state.signed_stake(), 0);
        assert_eq!(state.signed_stake_bps(), 0);
        assert!(!state.threshold_met());

        // Operator 3 alone holds 40% of the stake, below the 50% threshold.
        sign(&mut state, 3).unwrap();
        assert_eq!(state.signed_stake(), 4000);
        assert_eq!(state.signed_stake_bps(), 4000);
        assert!(!state.threshold_met());

        // Operator 1 raises the signed share to 60%.
        sign(&mut state, 1).unwrap();
        assert_eq!(state.signed_stake(), 6000);
        assert_eq!(state.signed_stake_bps(), 6000);
        assert!(state.threshold_met());
    }

    #[test]
    fn test_task_state_expiry() {
        let state = TaskState::with_config(
            1,
            100,
            vec![],
            5,
            ThresholdType::Count(3),
            None,
            Some(Duration::from_millis(50)),
        );

        assert!(!state.is_expired());
        assert!(state.time_remaining().is_some());

        std::thread::sleep(Duration::from_millis(60));

        assert!(state.is_expired());
        assert_eq!(state.time_remaining(), None);
    }

    #[test]
    fn test_task_state_expired_signature_rejected() {
        let mut state = TaskState::with_config(
            1,
            100,
            vec![],
            5,
            ThresholdType::Count(3),
            None,
            Some(Duration::from_millis(10)),
        );

        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(sign(&mut state, 0), Err("Task has expired"));
    }

    #[test]
    fn test_aggregation_state_init_task() {
        let state = AggregationState::new();

        assert!(state.init_task(1, 100, vec![1, 2, 3], 5, 3).is_ok());
        assert_eq!(
            state.init_task(1, 100, vec![1, 2, 3], 5, 3),
            Err("Task already exists")
        );
    }

    #[test]
    fn test_aggregation_state_submit_signature() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![], 5, 3).unwrap();

        assert_eq!(submit(&state, 1, 100, 0), Ok((1, false)));
        assert_eq!(submit(&state, 1, 100, 1), Ok((2, false)));
        assert_eq!(submit(&state, 1, 100, 2), Ok((3, true)));
    }

    #[test]
    fn test_aggregation_state_get_status() {
        let state = AggregationState::new();
        assert!(state.get_status(1, 100).is_none());

        state.init_task(1, 100, vec![], 5, 3).unwrap();

        let status = state.get_status(1, 100).expect("task was just initialised");
        assert_eq!(status.signatures_collected, 0);
        assert_eq!(status.threshold_required, 3);
        assert!(!status.threshold_met);
        assert!(!status.submitted);
        assert!(!status.is_expired);
    }

    #[test]
    fn test_aggregation_state_mark_submitted() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(state.mark_submitted(1, 100).is_ok());
        assert!(state.get_status(1, 100).unwrap().submitted);

        assert_eq!(
            submit(&state, 1, 100, 0),
            Err("Task already submitted to chain")
        );
    }

    #[test]
    fn test_aggregation_state_get_for_aggregation() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![1, 2, 3], 5, 2).unwrap();

        // Below threshold: nothing to aggregate yet.
        assert!(state.get_for_aggregation(1, 100).is_none());

        for idx in [0, 1] {
            submit(&state, 1, 100, idx).unwrap();
        }

        let task = state
            .get_for_aggregation(1, 100)
            .expect("threshold of 2 should be met");
        assert_eq!((task.service_id, task.call_id), (1, 100));
        assert_eq!(task.output, vec![1, 2, 3]);
        assert_eq!(task.signatures.len(), 2);
        assert_eq!(task.public_keys.len(), 2);
        assert_eq!(task.non_signer_indices, vec![2, 3, 4]);
    }

    #[test]
    fn test_aggregation_state_get_for_aggregation_submitted() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![], 5, 2).unwrap();
        for idx in [0, 1] {
            submit(&state, 1, 100, idx).unwrap();
        }

        assert!(state.get_for_aggregation(1, 100).is_some());

        state.mark_submitted(1, 100).unwrap();

        // Submitted tasks are no longer offered for aggregation.
        assert!(state.get_for_aggregation(1, 100).is_none());
    }

    #[test]
    fn test_aggregation_state_remove_task() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(state.get_status(1, 100).is_some());
        assert!(state.remove_task(1, 100));
        assert!(state.get_status(1, 100).is_none());

        // A second removal finds nothing.
        assert!(!state.remove_task(1, 100));
    }

    #[test]
    fn test_multiple_tasks() {
        let state = AggregationState::new();

        for (service_id, call_id, byte) in [(1, 100, 1u8), (1, 101, 2), (2, 100, 3)] {
            state.init_task(service_id, call_id, vec![byte], 5, 3).unwrap();
        }

        submit(&state, 1, 100, 0).unwrap();

        let collected = |service_id: u64, call_id: u64| {
            state
                .get_status(service_id, call_id)
                .unwrap()
                .signatures_collected
        };
        assert_eq!(collected(1, 100), 1);
        assert_eq!(collected(1, 101), 0);
        assert_eq!(collected(2, 100), 0);
    }

    #[test]
    fn test_cleanup_expired() {
        let state = AggregationState::new();

        let short_lived = TaskConfig {
            threshold_type: ThresholdType::Count(3),
            ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        };
        state
            .init_task_with_config(1, 100, vec![], 5, short_lived)
            .unwrap();
        state.init_task(1, 101, vec![], 5, 3).unwrap();
        assert_eq!(state.task_count(), 2);

        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(state.cleanup_expired(), 1);
        assert_eq!(state.task_count(), 1);
        assert!(state.get_status(1, 101).is_some());
    }

    #[test]
    fn test_cleanup_submitted() {
        let state = AggregationState::new();
        state.init_task(1, 100, vec![], 5, 1).unwrap();
        state.init_task(1, 101, vec![], 5, 1).unwrap();

        submit(&state, 1, 100, 0).unwrap();
        state.mark_submitted(1, 100).unwrap();
        assert_eq!(state.task_count(), 2);

        assert_eq!(state.cleanup_submitted(), 1);
        assert_eq!(state.task_count(), 1);
        assert!(state.get_status(1, 101).is_some());
    }

    #[test]
    fn test_task_counts() {
        let state = AggregationState::new();

        // Pending: needs 3 signatures, has none.
        state.init_task(1, 100, vec![], 5, 3).unwrap();

        // Ready: needs 1 signature, has it.
        state.init_task(1, 101, vec![], 5, 1).unwrap();
        submit(&state, 1, 101, 0).unwrap();

        // Submitted: ready and then marked as submitted.
        state.init_task(1, 102, vec![], 5, 1).unwrap();
        submit(&state, 1, 102, 0).unwrap();
        state.mark_submitted(1, 102).unwrap();

        let counts = state.task_counts();
        assert_eq!(
            (counts.total, counts.pending, counts.ready, counts.submitted, counts.expired),
            (3, 1, 1, 1, 0)
        );
    }

    #[test]
    fn test_get_task_output() {
        let state = AggregationState::new();
        let output = vec![1u8, 2, 3, 4, 5];

        state.init_task(1, 100, output.clone(), 5, 3).unwrap();

        assert_eq!(state.get_task_output(1, 100), Some(output));
        assert_eq!(state.get_task_output(1, 999), None);
    }
}