Skip to main content

blueprint_tangle_aggregation_svc/
state.rs

1//! In-memory aggregation state management
2
3use crate::types::TaskId;
4use alloy_primitives::U256;
5use blueprint_crypto_bn254::{ArkBlsBn254Public, ArkBlsBn254Signature};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
/// How many signatures a task needs before its aggregate can be produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThresholdType {
    /// Absolute count: at least this many operators must sign.
    Count(u32),
    /// Stake share: the signers must hold at least this many basis points
    /// (0-10000, where 10000 = 100%) of the total stake.
    StakeWeighted(u32),
}
19
20impl Default for ThresholdType {
21    fn default() -> Self {
22        ThresholdType::Count(1)
23    }
24}
25
/// Per-operator weighting information for stake-weighted aggregation.
#[derive(Debug, Clone)]
pub struct OperatorInfo {
    /// Weight of this operator; may be a real stake amount or a relative weight.
    pub stake: u64,
    /// `true` once this operator has been registered.
    pub registered: bool,
}
34
35impl Default for OperatorInfo {
36    fn default() -> Self {
37        Self {
38            stake: 1, // Default weight of 1 for count-based
39            registered: true,
40        }
41    }
42}
43
44/// State for a single aggregation task
45#[derive(Debug)]
46pub struct TaskState {
47    /// Service ID
48    pub service_id: u64,
49    /// Call ID
50    pub call_id: u64,
51    /// The output being signed
52    pub output: Vec<u8>,
53    /// Number of operators in the service
54    pub operator_count: u32,
55    /// Threshold type and value
56    pub threshold_type: ThresholdType,
57    /// Bitmap of which operators have signed (bit i = operator i signed)
58    pub signer_bitmap: U256,
59    /// Collected signatures indexed by operator index
60    pub signatures: HashMap<u32, ArkBlsBn254Signature>,
61    /// Collected public keys indexed by operator index
62    pub public_keys: HashMap<u32, ArkBlsBn254Public>,
63    /// Operator stakes for stake-weighted thresholds
64    pub operator_stakes: HashMap<u32, u64>,
65    /// Total stake of all operators
66    pub total_stake: u64,
67    /// Whether this task has been submitted to chain
68    pub submitted: bool,
69    /// When this task was created
70    pub created_at: Instant,
71    /// When this task expires (None = never)
72    pub expires_at: Option<Instant>,
73}
74
75impl TaskState {
76    /// Create a new task state with count-based threshold
77    pub fn new(
78        service_id: u64,
79        call_id: u64,
80        output: Vec<u8>,
81        operator_count: u32,
82        threshold: u32,
83    ) -> Self {
84        Self::with_config(
85            service_id,
86            call_id,
87            output,
88            operator_count,
89            ThresholdType::Count(threshold),
90            None,
91            None,
92        )
93    }
94
95    /// Create a new task state with full configuration
96    pub fn with_config(
97        service_id: u64,
98        call_id: u64,
99        output: Vec<u8>,
100        operator_count: u32,
101        threshold_type: ThresholdType,
102        operator_stakes: Option<HashMap<u32, u64>>,
103        ttl: Option<Duration>,
104    ) -> Self {
105        let now = Instant::now();
106        let expires_at = ttl.map(|d| now + d);
107
108        // Calculate total stake
109        let (stakes, total_stake) = if let Some(stakes) = operator_stakes {
110            let total: u64 = stakes.values().sum();
111            (stakes, total)
112        } else {
113            // Default: each operator has stake of 1
114            let stakes: HashMap<u32, u64> = (0..operator_count).map(|i| (i, 1u64)).collect();
115            let total = operator_count as u64;
116            (stakes, total)
117        };
118
119        Self {
120            service_id,
121            call_id,
122            output,
123            operator_count,
124            threshold_type,
125            signer_bitmap: U256::ZERO,
126            signatures: HashMap::new(),
127            public_keys: HashMap::new(),
128            operator_stakes: stakes,
129            total_stake,
130            submitted: false,
131            created_at: now,
132            expires_at,
133        }
134    }
135
136    /// Check if this task has expired
137    pub fn is_expired(&self) -> bool {
138        self.expires_at.map(|t| Instant::now() > t).unwrap_or(false)
139    }
140
141    /// Get remaining time until expiry
142    pub fn time_remaining(&self) -> Option<Duration> {
143        self.expires_at.and_then(|t| {
144            let now = Instant::now();
145            if now < t {
146                Some(t - now)
147            } else {
148                None
149            }
150        })
151    }
152
153    /// Add a signature from an operator
154    pub fn add_signature(
155        &mut self,
156        operator_index: u32,
157        signature: ArkBlsBn254Signature,
158        public_key: ArkBlsBn254Public,
159    ) -> Result<(), &'static str> {
160        if operator_index >= self.operator_count {
161            return Err("Operator index out of bounds");
162        }
163
164        if self.has_signed(operator_index) {
165            return Err("Operator already signed");
166        }
167
168        if self.is_expired() {
169            return Err("Task has expired");
170        }
171
172        // Set bit in bitmap
173        self.signer_bitmap |= U256::from(1u64) << operator_index as usize;
174
175        // Store signature and public key
176        self.signatures.insert(operator_index, signature);
177        self.public_keys.insert(operator_index, public_key);
178
179        Ok(())
180    }
181
182    /// Check if an operator has already signed
183    pub fn has_signed(&self, operator_index: u32) -> bool {
184        (self.signer_bitmap >> operator_index as usize) & U256::from(1u64) == U256::from(1u64)
185    }
186
187    /// Get the number of signatures collected
188    pub fn signature_count(&self) -> usize {
189        self.signatures.len()
190    }
191
192    /// Get the total stake that has signed
193    pub fn signed_stake(&self) -> u64 {
194        self.signatures
195            .keys()
196            .map(|idx| self.operator_stakes.get(idx).copied().unwrap_or(0))
197            .sum()
198    }
199
200    /// Get the signed stake as basis points (0-10000) of total stake
201    pub fn signed_stake_bps(&self) -> u32 {
202        if self.total_stake == 0 {
203            return 0;
204        }
205        ((self.signed_stake() * 10000) / self.total_stake) as u32
206    }
207
208    /// Check if threshold is met
209    pub fn threshold_met(&self) -> bool {
210        match self.threshold_type {
211            ThresholdType::Count(n) => self.signature_count() >= n as usize,
212            ThresholdType::StakeWeighted(bps) => self.signed_stake_bps() >= bps,
213        }
214    }
215
216    /// Get the threshold value (for API responses)
217    pub fn threshold_value(&self) -> usize {
218        match self.threshold_type {
219            ThresholdType::Count(n) => n as usize,
220            ThresholdType::StakeWeighted(bps) => bps as usize,
221        }
222    }
223
224    /// Get list of operators who haven't signed (non-signers)
225    pub fn get_non_signers(&self) -> Vec<u32> {
226        (0..self.operator_count)
227            .filter(|&i| !self.has_signed(i))
228            .collect()
229    }
230
231    /// Get list of operators who have signed
232    pub fn get_signers(&self) -> Vec<u32> {
233        let mut signers: Vec<_> = self.signatures.keys().copied().collect();
234        signers.sort_unstable();
235        signers
236    }
237
238    /// Get all signatures and public keys in order for aggregation
239    pub fn get_signatures_for_aggregation(
240        &self,
241    ) -> (Vec<ArkBlsBn254Signature>, Vec<ArkBlsBn254Public>) {
242        let mut sigs = Vec::with_capacity(self.signatures.len());
243        let mut pks = Vec::with_capacity(self.public_keys.len());
244
245        // Collect in sorted order by operator index
246        let indices = self.get_signers();
247
248        for idx in indices {
249            if let (Some(sig), Some(pk)) = (self.signatures.get(&idx), self.public_keys.get(&idx)) {
250                sigs.push(sig.clone());
251                pks.push(pk.clone());
252            }
253        }
254
255        (sigs, pks)
256    }
257}
258
259/// Configuration for task initialization
260#[derive(Debug, Clone)]
261pub struct TaskConfig {
262    /// Threshold type
263    pub threshold_type: ThresholdType,
264    /// Operator stakes (optional, defaults to equal weight)
265    pub operator_stakes: Option<HashMap<u32, u64>>,
266    /// Time-to-live for the task
267    pub ttl: Option<Duration>,
268}
269
270impl Default for TaskConfig {
271    fn default() -> Self {
272        Self {
273            threshold_type: ThresholdType::Count(1),
274            operator_stakes: None,
275            ttl: None,
276        }
277    }
278}
279
280/// Global aggregation state manager
281#[derive(Debug, Clone)]
282pub struct AggregationState {
283    /// All active tasks
284    tasks: Arc<RwLock<HashMap<TaskId, TaskState>>>,
285}
286
287impl Default for AggregationState {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293impl AggregationState {
294    /// Create a new aggregation state manager
295    pub fn new() -> Self {
296        Self {
297            tasks: Arc::new(RwLock::new(HashMap::new())),
298        }
299    }
300
301    /// Initialize a new aggregation task (simple API)
302    pub fn init_task(
303        &self,
304        service_id: u64,
305        call_id: u64,
306        output: Vec<u8>,
307        operator_count: u32,
308        threshold: u32,
309    ) -> Result<(), &'static str> {
310        self.init_task_with_config(
311            service_id,
312            call_id,
313            output,
314            operator_count,
315            TaskConfig {
316                threshold_type: ThresholdType::Count(threshold),
317                ..Default::default()
318            },
319        )
320    }
321
322    /// Initialize a new aggregation task with full configuration
323    pub fn init_task_with_config(
324        &self,
325        service_id: u64,
326        call_id: u64,
327        output: Vec<u8>,
328        operator_count: u32,
329        config: TaskConfig,
330    ) -> Result<(), &'static str> {
331        let task_id = TaskId::new(service_id, call_id);
332        let mut tasks = self.tasks.write();
333
334        if tasks.contains_key(&task_id) {
335            return Err("Task already exists");
336        }
337
338        let state = TaskState::with_config(
339            service_id,
340            call_id,
341            output,
342            operator_count,
343            config.threshold_type,
344            config.operator_stakes,
345            config.ttl,
346        );
347        tasks.insert(task_id, state);
348        Ok(())
349    }
350
351    /// Get the expected output for a task (for validation)
352    pub fn get_task_output(&self, service_id: u64, call_id: u64) -> Option<Vec<u8>> {
353        let task_id = TaskId::new(service_id, call_id);
354        let tasks = self.tasks.read();
355        tasks.get(&task_id).map(|t| t.output.clone())
356    }
357
358    /// Submit a signature for a task
359    pub fn submit_signature(
360        &self,
361        service_id: u64,
362        call_id: u64,
363        operator_index: u32,
364        signature: ArkBlsBn254Signature,
365        public_key: ArkBlsBn254Public,
366    ) -> Result<(usize, bool), &'static str> {
367        let task_id = TaskId::new(service_id, call_id);
368        let mut tasks = self.tasks.write();
369
370        let task = tasks.get_mut(&task_id).ok_or("Task not found")?;
371
372        if task.submitted {
373            return Err("Task already submitted to chain");
374        }
375
376        if task.is_expired() {
377            return Err("Task has expired");
378        }
379
380        task.add_signature(operator_index, signature, public_key)?;
381
382        Ok((task.signature_count(), task.threshold_met()))
383    }
384
385    /// Get task status
386    pub fn get_status(&self, service_id: u64, call_id: u64) -> Option<TaskStatus> {
387        let task_id = TaskId::new(service_id, call_id);
388        let tasks = self.tasks.read();
389
390        tasks.get(&task_id).map(|task| TaskStatus {
391            signatures_collected: task.signature_count(),
392            threshold_required: task.threshold_value(),
393            threshold_type: task.threshold_type,
394            threshold_met: task.threshold_met(),
395            signer_bitmap: task.signer_bitmap,
396            signed_stake_bps: task.signed_stake_bps(),
397            submitted: task.submitted,
398            is_expired: task.is_expired(),
399            time_remaining_secs: task.time_remaining().map(|d| d.as_secs()),
400        })
401    }
402
403    /// Get task for aggregation (if threshold met)
404    pub fn get_for_aggregation(&self, service_id: u64, call_id: u64) -> Option<TaskForAggregation> {
405        let task_id = TaskId::new(service_id, call_id);
406        let tasks = self.tasks.read();
407
408        let task = tasks.get(&task_id)?;
409
410        if !task.threshold_met() || task.submitted || task.is_expired() {
411            return None;
412        }
413
414        let (signatures, public_keys) = task.get_signatures_for_aggregation();
415
416        Some(TaskForAggregation {
417            service_id: task.service_id,
418            call_id: task.call_id,
419            output: task.output.clone(),
420            signer_bitmap: task.signer_bitmap,
421            non_signer_indices: task.get_non_signers(),
422            signatures,
423            public_keys,
424        })
425    }
426
427    /// Mark task as submitted
428    pub fn mark_submitted(&self, service_id: u64, call_id: u64) -> Result<(), &'static str> {
429        let task_id = TaskId::new(service_id, call_id);
430        let mut tasks = self.tasks.write();
431
432        let task = tasks.get_mut(&task_id).ok_or("Task not found")?;
433        task.submitted = true;
434        Ok(())
435    }
436
437    /// Remove a completed task
438    pub fn remove_task(&self, service_id: u64, call_id: u64) -> bool {
439        let task_id = TaskId::new(service_id, call_id);
440        self.tasks.write().remove(&task_id).is_some()
441    }
442
443    /// Cleanup expired tasks
444    /// Returns the number of tasks removed
445    pub fn cleanup_expired(&self) -> usize {
446        let mut tasks = self.tasks.write();
447        let before = tasks.len();
448        tasks.retain(|_, task| !task.is_expired());
449        before - tasks.len()
450    }
451
452    /// Cleanup submitted tasks
453    /// Returns the number of tasks removed
454    pub fn cleanup_submitted(&self) -> usize {
455        let mut tasks = self.tasks.write();
456        let before = tasks.len();
457        tasks.retain(|_, task| !task.submitted);
458        before - tasks.len()
459    }
460
461    /// Cleanup both expired and submitted tasks
462    /// Returns the number of tasks removed
463    pub fn cleanup(&self) -> usize {
464        let mut tasks = self.tasks.write();
465        let before = tasks.len();
466        tasks.retain(|_, task| !task.is_expired() && !task.submitted);
467        before - tasks.len()
468    }
469
470    /// Get count of active tasks
471    pub fn task_count(&self) -> usize {
472        self.tasks.read().len()
473    }
474
475    /// Get count of tasks by status
476    pub fn task_counts(&self) -> TaskCounts {
477        let tasks = self.tasks.read();
478        let mut counts = TaskCounts::default();
479
480        for task in tasks.values() {
481            counts.total += 1;
482            if task.is_expired() {
483                counts.expired += 1;
484            } else if task.submitted {
485                counts.submitted += 1;
486            } else if task.threshold_met() {
487                counts.ready += 1;
488            } else {
489                counts.pending += 1;
490            }
491        }
492
493        counts
494    }
495}
496
497/// Simplified task status for API responses
498#[derive(Debug, Clone)]
499pub struct TaskStatus {
500    pub signatures_collected: usize,
501    pub threshold_required: usize,
502    pub threshold_type: ThresholdType,
503    pub threshold_met: bool,
504    pub signer_bitmap: U256,
505    pub signed_stake_bps: u32,
506    pub submitted: bool,
507    pub is_expired: bool,
508    pub time_remaining_secs: Option<u64>,
509}
510
511/// Task data ready for aggregation
512#[derive(Debug)]
513pub struct TaskForAggregation {
514    pub service_id: u64,
515    pub call_id: u64,
516    pub output: Vec<u8>,
517    pub signer_bitmap: U256,
518    pub non_signer_indices: Vec<u32>,
519    pub signatures: Vec<ArkBlsBn254Signature>,
520    pub public_keys: Vec<ArkBlsBn254Public>,
521}
522
/// Tally of tracked tasks grouped by lifecycle stage.
///
/// As produced by `AggregationState::task_counts`, every task lands in
/// exactly one of `pending`, `ready`, `submitted` or `expired`.
#[derive(Debug, Clone, Default)]
pub struct TaskCounts {
    /// Every tracked task.
    pub total: usize,
    /// Not expired, not submitted, threshold not yet met.
    pub pending: usize,
    /// Not expired, not submitted, threshold met.
    pub ready: usize,
    /// Submitted to chain and not expired.
    pub submitted: usize,
    /// Past their deadline, regardless of any other state.
    pub expired: usize,
}
532
#[cfg(test)]
mod tests {
    use super::*;
    use ark_bn254::{G1Affine, G2Affine};
    use ark_ec::AffineRepr;

    /// Placeholder BLS signature (the G1 generator). These tests exercise only
    /// bookkeeping, never signature verification.
    fn dummy_signature() -> ArkBlsBn254Signature {
        let point = G1Affine::generator();
        ArkBlsBn254Signature(point)
    }

    /// Placeholder BLS public key (the G2 generator).
    fn dummy_public_key() -> ArkBlsBn254Public {
        let point = G2Affine::generator();
        ArkBlsBn254Public(point)
    }

    /// Record a placeholder signature from operator `idx` directly on a task.
    fn sign(task: &mut TaskState, idx: u32) -> Result<(), &'static str> {
        task.add_signature(idx, dummy_signature(), dummy_public_key())
    }

    /// Submit a placeholder signature from operator `idx` through the manager.
    fn submit(
        manager: &AggregationState,
        service_id: u64,
        call_id: u64,
        idx: u32,
    ) -> Result<(usize, bool), &'static str> {
        manager.submit_signature(service_id, call_id, idx, dummy_signature(), dummy_public_key())
    }

    /// Task for service 1 / call 100 with 5 operators and a count threshold of
    /// 3, expiring after `ttl`.
    fn expiring_task(ttl: Duration) -> TaskState {
        TaskState::with_config(1, 100, vec![], 5, ThresholdType::Count(3), None, Some(ttl))
    }

    #[test]
    fn test_task_state_new() {
        let task = TaskState::new(1, 100, vec![1, 2, 3], 5, 3);
        assert_eq!(task.service_id, 1);
        assert_eq!(task.call_id, 100);
        assert_eq!(task.output, vec![1, 2, 3]);
        assert_eq!(task.operator_count, 5);
        assert_eq!(task.threshold_type, ThresholdType::Count(3));
        assert_eq!(task.signer_bitmap, U256::ZERO);
        assert!(task.signatures.is_empty() && task.public_keys.is_empty());
        assert!(!task.submitted);
        assert!(!task.is_expired());
    }

    #[test]
    fn test_task_state_add_signature() {
        let mut task = TaskState::new(1, 100, vec![], 5, 3);

        sign(&mut task, 0).expect("first signature should be accepted");
        assert!(task.has_signed(0));
        assert!(!task.has_signed(1));
        assert_eq!(task.signature_count(), 1);

        sign(&mut task, 2).expect("second signature should be accepted");
        assert!(task.has_signed(2));
        assert_eq!(task.signature_count(), 2);
    }

    #[test]
    fn test_task_state_duplicate_signature() {
        let mut task = TaskState::new(1, 100, vec![], 5, 3);
        assert_eq!(sign(&mut task, 0), Ok(()));
        assert_eq!(sign(&mut task, 0), Err("Operator already signed"));
    }

    #[test]
    fn test_task_state_out_of_bounds() {
        let mut task = TaskState::new(1, 100, vec![], 5, 3);
        assert_eq!(sign(&mut task, 5), Err("Operator index out of bounds"));
    }

    #[test]
    fn test_task_state_threshold() {
        let mut task = TaskState::new(1, 100, vec![], 5, 3);
        assert!(!task.threshold_met());

        // With a threshold of 3, only the third signature tips it over.
        for (idx, met_after) in [(0, false), (1, false), (2, true)] {
            sign(&mut task, idx).unwrap();
            assert_eq!(task.threshold_met(), met_after);
        }
    }

    #[test]
    fn test_task_state_bitmap() {
        let mut task = TaskState::new(1, 100, vec![], 10, 3);
        for idx in [0, 3, 7] {
            sign(&mut task, idx).unwrap();
        }

        // Bits 0, 3 and 7 set: 0b10001001 == 137.
        assert_eq!(task.signer_bitmap, U256::from(137));
    }

    #[test]
    fn test_task_state_non_signers() {
        let mut task = TaskState::new(1, 100, vec![], 5, 3);
        for idx in [0, 2, 4] {
            sign(&mut task, idx).unwrap();
        }

        assert_eq!(task.get_non_signers(), vec![1, 3]);
        assert_eq!(task.get_signers(), vec![0, 2, 4]);
    }

    #[test]
    fn test_task_state_stake_weighted() {
        // Stakes sum to 10_000, so raw stake equals basis points.
        let stakes = HashMap::from([(0, 1000), (1, 2000), (2, 3000), (3, 4000)]);
        let mut task = TaskState::with_config(
            1,
            100,
            vec![],
            4,
            ThresholdType::StakeWeighted(5000), // 50% required
            Some(stakes),
            None,
        );

        assert_eq!(task.total_stake, 10000);
        assert_eq!(task.signed_stake(), 0);
        assert_eq!(task.signed_stake_bps(), 0);
        assert!(!task.threshold_met());

        // Operator 3 alone holds 40%: below the bar.
        sign(&mut task, 3).unwrap();
        assert_eq!((task.signed_stake(), task.signed_stake_bps()), (4000, 4000));
        assert!(!task.threshold_met());

        // Adding operator 1 (20%) reaches 60%.
        sign(&mut task, 1).unwrap();
        assert_eq!((task.signed_stake(), task.signed_stake_bps()), (6000, 6000));
        assert!(task.threshold_met());
    }

    #[test]
    fn test_task_state_expiry() {
        let task = expiring_task(Duration::from_millis(50)); // Very short TTL
        assert!(!task.is_expired());
        assert!(task.time_remaining().is_some());

        std::thread::sleep(Duration::from_millis(60));

        assert!(task.is_expired());
        assert!(task.time_remaining().is_none());
    }

    #[test]
    fn test_task_state_expired_signature_rejected() {
        let mut task = expiring_task(Duration::from_millis(10));
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(sign(&mut task, 0), Err("Task has expired"));
    }

    #[test]
    fn test_aggregation_state_init_task() {
        let manager = AggregationState::new();
        assert!(manager.init_task(1, 100, vec![1, 2, 3], 5, 3).is_ok());

        // Re-initializing the same (service, call) pair is rejected.
        assert_eq!(
            manager.init_task(1, 100, vec![1, 2, 3], 5, 3),
            Err("Task already exists")
        );
    }

    #[test]
    fn test_aggregation_state_submit_signature() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![], 5, 3).unwrap();

        for (idx, expected) in [(0, (1, false)), (1, (2, false)), (2, (3, true))] {
            assert_eq!(submit(&manager, 1, 100, idx).unwrap(), expected);
        }
    }

    #[test]
    fn test_aggregation_state_get_status() {
        let manager = AggregationState::new();
        assert!(manager.get_status(1, 100).is_none());

        manager.init_task(1, 100, vec![], 5, 3).unwrap();
        let status = manager.get_status(1, 100).expect("task was just created");
        assert_eq!(status.signatures_collected, 0);
        assert_eq!(status.threshold_required, 3);
        assert!(!status.threshold_met);
        assert!(!status.submitted);
        assert!(!status.is_expired);
    }

    #[test]
    fn test_aggregation_state_mark_submitted() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(manager.mark_submitted(1, 100).is_ok());
        assert!(manager.get_status(1, 100).unwrap().submitted);

        // Signatures are refused once the task is on chain.
        assert_eq!(
            submit(&manager, 1, 100, 0),
            Err("Task already submitted to chain")
        );
    }

    #[test]
    fn test_aggregation_state_get_for_aggregation() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![1, 2, 3], 5, 2).unwrap();

        // Nothing to aggregate before the threshold is met.
        assert!(manager.get_for_aggregation(1, 100).is_none());

        for idx in [0, 1] {
            submit(&manager, 1, 100, idx).unwrap();
        }

        let ready = manager.get_for_aggregation(1, 100).unwrap();
        assert_eq!((ready.service_id, ready.call_id), (1, 100));
        assert_eq!(ready.output, vec![1, 2, 3]);
        assert_eq!(ready.signatures.len(), 2);
        assert_eq!(ready.public_keys.len(), 2);
        assert_eq!(ready.non_signer_indices, vec![2, 3, 4]);
    }

    #[test]
    fn test_aggregation_state_get_for_aggregation_submitted() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![], 5, 2).unwrap();
        for idx in [0, 1] {
            submit(&manager, 1, 100, idx).unwrap();
        }

        assert!(manager.get_for_aggregation(1, 100).is_some());

        manager.mark_submitted(1, 100).unwrap();

        // Once submitted, the task is no longer offered for aggregation.
        assert!(manager.get_for_aggregation(1, 100).is_none());
    }

    #[test]
    fn test_aggregation_state_remove_task() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(manager.get_status(1, 100).is_some());
        assert!(manager.remove_task(1, 100));
        assert!(manager.get_status(1, 100).is_none());

        // A second removal finds nothing.
        assert!(!manager.remove_task(1, 100));
    }

    #[test]
    fn test_multiple_tasks() {
        let manager = AggregationState::new();
        for (service_id, call_id, output) in [(1, 100, vec![1]), (1, 101, vec![2]), (2, 100, vec![3])] {
            manager.init_task(service_id, call_id, output, 5, 3).unwrap();
        }

        // Signing one task leaves the others untouched.
        submit(&manager, 1, 100, 0).unwrap();

        let collected = |s, c| manager.get_status(s, c).unwrap().signatures_collected;
        assert_eq!(collected(1, 100), 1);
        assert_eq!(collected(1, 101), 0);
        assert_eq!(collected(2, 100), 0);
    }

    #[test]
    fn test_cleanup_expired() {
        let manager = AggregationState::new();

        let short_lived = TaskConfig {
            threshold_type: ThresholdType::Count(3),
            ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        };
        manager
            .init_task_with_config(1, 100, vec![], 5, short_lived)
            .unwrap();
        manager.init_task(1, 101, vec![], 5, 3).unwrap();
        assert_eq!(manager.task_count(), 2);

        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(manager.cleanup_expired(), 1);
        assert_eq!(manager.task_count(), 1);
        assert!(manager.get_status(1, 101).is_some());
    }

    #[test]
    fn test_cleanup_submitted() {
        let manager = AggregationState::new();
        manager.init_task(1, 100, vec![], 5, 1).unwrap();
        manager.init_task(1, 101, vec![], 5, 1).unwrap();

        submit(&manager, 1, 100, 0).unwrap();
        manager.mark_submitted(1, 100).unwrap();
        assert_eq!(manager.task_count(), 2);

        assert_eq!(manager.cleanup_submitted(), 1);
        assert_eq!(manager.task_count(), 1);
        assert!(manager.get_status(1, 101).is_some());
    }

    #[test]
    fn test_task_counts() {
        let manager = AggregationState::new();

        // Call 100: pending (threshold 3, no signatures).
        manager.init_task(1, 100, vec![], 5, 3).unwrap();

        // Call 101: ready (threshold 1, one signature).
        manager.init_task(1, 101, vec![], 5, 1).unwrap();
        submit(&manager, 1, 101, 0).unwrap();

        // Call 102: submitted.
        manager.init_task(1, 102, vec![], 5, 1).unwrap();
        submit(&manager, 1, 102, 0).unwrap();
        manager.mark_submitted(1, 102).unwrap();

        let counts = manager.task_counts();
        assert_eq!(counts.total, 3);
        assert_eq!(
            (counts.pending, counts.ready, counts.submitted, counts.expired),
            (1, 1, 1, 0)
        );
    }

    #[test]
    fn test_get_task_output() {
        let manager = AggregationState::new();
        let output = vec![1, 2, 3, 4, 5];
        manager.init_task(1, 100, output.clone(), 5, 3).unwrap();

        assert_eq!(manager.get_task_output(1, 100), Some(output));
        assert!(manager.get_task_output(1, 999).is_none());
    }
}