Skip to main content

blueprint_tangle_aggregation_svc/
service.rs

1//! Main aggregation service logic
2
3use crate::state::{AggregationState, TaskConfig, ThresholdType};
4use crate::types::*;
5use alloy_primitives::U256;
6use ark_serialize::CanonicalDeserialize;
7use blueprint_crypto_bn254::{ArkBlsBn254, ArkBlsBn254Public, ArkBlsBn254Signature};
8use blueprint_crypto_core::{aggregation::AggregatableSignature, KeyType};
9use std::sync::Arc;
10use std::time::Duration;
11use thiserror::Error;
12use tokio::sync::watch;
13use tracing::{debug, info, warn};
14
/// Errors from the aggregation service
///
/// The display text of each variant is the string given in its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum ServiceError {
    /// Returned when the state lookup for the requested service/call pair yields nothing.
    #[error("Task not found")]
    TaskNotFound,
    /// A task with the same identity is already registered.
    // NOTE(review): never constructed in this file — task-init failures are currently
    // reported through `Other`; confirm whether this variant is used elsewhere.
    #[error("Task already exists")]
    TaskAlreadyExists,
    /// Returned by `submit_signature` when the task's status reports it as expired.
    #[error("Task has expired")]
    TaskExpired,
    /// The submitted signature bytes could not be decoded as a compressed G1 point.
    #[error("Invalid signature format")]
    InvalidSignature,
    /// The submitted public key bytes could not be decoded as a compressed G2 point.
    #[error("Invalid public key format")]
    InvalidPublicKey,
    /// The signature did not verify against the submitted public key and signing message.
    #[error("Signature verification failed")]
    VerificationFailed,
    /// The output carried by a submission differs from the output stored for the task.
    #[error("Output mismatch: submitted output does not match task output")]
    OutputMismatch,
    /// Aggregation failed; the payload is a free-form description.
    // NOTE(review): never constructed in this file — `get_aggregated_result` returns `None`
    // on aggregation failure instead; confirm intended use.
    #[error("Aggregation failed: {0}")]
    AggregationFailed(String),
    /// Catch-all carrying the stringified error reported by the underlying state.
    #[error("{0}")]
    Other(String),
}
37
/// Behaviour switches and timing settings for the aggregation service
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// When `true`, every submitted signature is verified before it is recorded
    pub verify_on_submit: bool,
    /// When `true`, a submission's output must equal the output stored for the task
    pub validate_output: bool,
    /// TTL applied to tasks created without an explicit one (`None` = no expiry)
    pub default_task_ttl: Option<Duration>,
    /// Period of the background cleanup worker (`None` = no worker is started)
    pub cleanup_interval: Option<Duration>,
    /// When `true`, the cleanup worker runs the full cleanup pass instead of the
    /// expired-only pass
    pub auto_cleanup_submitted: bool,
}

impl Default for ServiceConfig {
    /// Production-oriented defaults: all checks on, one-hour task TTL, cleanup every minute.
    fn default() -> Self {
        let one_hour = Duration::from_secs(3600);
        let one_minute = Duration::from_secs(60);

        ServiceConfig {
            default_task_ttl: Some(one_hour),
            cleanup_interval: Some(one_minute),
            verify_on_submit: true,
            validate_output: true,
            auto_cleanup_submitted: true,
        }
    }
}

impl ServiceConfig {
    /// Create a minimal config (no verification, no output validation, no TTL, no cleanup)
    pub fn minimal() -> Self {
        ServiceConfig {
            default_task_ttl: None,
            cleanup_interval: None,
            verify_on_submit: false,
            validate_output: false,
            auto_cleanup_submitted: false,
        }
    }
}
77
/// The main aggregation service
///
/// Combines the task store with the configuration that controls submission checks
/// and cleanup.
#[derive(Debug)]
pub struct AggregationService {
    // Task store that the service's methods delegate to.
    state: AggregationState,
    // Settings read by submission handling, task creation and the cleanup worker.
    config: ServiceConfig,
}
84
85impl AggregationService {
    /// Create a new aggregation service
    ///
    /// The service starts with a newly constructed [`AggregationState`] and keeps `config`
    /// for the lifetime of the service.
    pub fn new(config: ServiceConfig) -> Self {
        Self {
            state: AggregationState::new(),
            config,
        }
    }
93
    /// Create a new aggregation service wrapped in Arc
    ///
    /// Equivalent to `Arc::new(Self::new(config))`; `start_cleanup_worker` requires the
    /// service to be held in an `Arc`.
    pub fn new_shared(config: ServiceConfig) -> Arc<Self> {
        Arc::new(Self::new(config))
    }
98
99    /// Start the background cleanup worker
100    /// Returns a handle that can be used to stop the worker
101    pub fn start_cleanup_worker(self: &Arc<Self>) -> Option<CleanupWorkerHandle> {
102        let interval = self.config.cleanup_interval?;
103        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
104
105        let service = Arc::clone(self);
106
107        let handle = tokio::spawn(async move {
108            let mut interval_timer = tokio::time::interval(interval);
109
110            loop {
111                tokio::select! {
112                    _ = interval_timer.tick() => {
113                        let removed = if service.config.auto_cleanup_submitted {
114                            service.state.cleanup()
115                        } else {
116                            service.state.cleanup_expired()
117                        };
118                        if removed > 0 {
119                            debug!(removed, "Cleaned up tasks");
120                        }
121                    }
122                    _ = shutdown_rx.changed() => {
123                        if *shutdown_rx.borrow() {
124                            info!("Cleanup worker shutting down");
125                            break;
126                        }
127                    }
128                }
129            }
130        });
131
132        Some(CleanupWorkerHandle {
133            shutdown_tx,
134            handle,
135        })
136    }
137
138    /// Initialize a new aggregation task
139    pub fn init_task(
140        &self,
141        service_id: u64,
142        call_id: u64,
143        output: Vec<u8>,
144        operator_count: u32,
145        threshold: u32,
146    ) -> Result<(), ServiceError> {
147        self.init_task_with_config(
148            service_id,
149            call_id,
150            output,
151            operator_count,
152            TaskConfig {
153                threshold_type: ThresholdType::Count(threshold),
154                ttl: self.config.default_task_ttl,
155                ..Default::default()
156            },
157        )
158    }
159
160    /// Initialize a new aggregation task with full configuration
161    pub fn init_task_with_config(
162        &self,
163        service_id: u64,
164        call_id: u64,
165        output: Vec<u8>,
166        operator_count: u32,
167        config: TaskConfig,
168    ) -> Result<(), ServiceError> {
169        info!(
170            service_id,
171            call_id,
172            operator_count,
173            ?config.threshold_type,
174            "Initializing aggregation task"
175        );
176
177        self.state
178            .init_task_with_config(service_id, call_id, output, operator_count, config)
179            .map_err(|e| ServiceError::Other(e.to_string()))
180    }
181
182    /// Submit a signature for aggregation
183    pub fn submit_signature(
184        &self,
185        req: SubmitSignatureRequest,
186    ) -> Result<SubmitSignatureResponse, ServiceError> {
187        debug!(
188            service_id = req.service_id,
189            call_id = req.call_id,
190            operator_index = req.operator_index,
191            "Received signature submission"
192        );
193
194        // Validate output matches task output (if enabled)
195        if self.config.validate_output {
196            let expected_output = self
197                .state
198                .get_task_output(req.service_id, req.call_id)
199                .ok_or(ServiceError::TaskNotFound)?;
200
201            if req.output != expected_output {
202                warn!(
203                    service_id = req.service_id,
204                    call_id = req.call_id,
205                    operator_index = req.operator_index,
206                    "Output mismatch"
207                );
208                return Err(ServiceError::OutputMismatch);
209            }
210        }
211
212        // Parse signature (G1 point)
213        let signature: ArkBlsBn254Signature = ArkBlsBn254Signature(
214            ark_bn254::G1Affine::deserialize_compressed(&req.signature[..])
215                .map_err(|_| ServiceError::InvalidSignature)?,
216        );
217
218        // Parse public key (G2 point)
219        let public_key: ArkBlsBn254Public = ArkBlsBn254Public(
220            ark_bn254::G2Affine::deserialize_compressed(&req.public_key[..])
221                .map_err(|_| ServiceError::InvalidPublicKey)?,
222        );
223
224        // Optionally verify the signature
225        if self.config.verify_on_submit {
226            // Create the message that should have been signed
227            let message = create_signing_message(req.service_id, req.call_id, &req.output);
228
229            if !ArkBlsBn254::verify(&public_key, &message, &signature) {
230                warn!(
231                    service_id = req.service_id,
232                    call_id = req.call_id,
233                    operator_index = req.operator_index,
234                    "Signature verification failed"
235                );
236                return Err(ServiceError::VerificationFailed);
237            }
238        }
239
240        // Get task status for response
241        let status = self
242            .state
243            .get_status(req.service_id, req.call_id)
244            .ok_or(ServiceError::TaskNotFound)?;
245
246        if status.is_expired {
247            return Err(ServiceError::TaskExpired);
248        }
249
250        // Submit to state
251        let (count, threshold_met) = self
252            .state
253            .submit_signature(
254                req.service_id,
255                req.call_id,
256                req.operator_index,
257                signature,
258                public_key,
259            )
260            .map_err(|e| ServiceError::Other(e.to_string()))?;
261
262        info!(
263            service_id = req.service_id,
264            call_id = req.call_id,
265            operator_index = req.operator_index,
266            signatures_collected = count,
267            threshold_met,
268            "Signature accepted"
269        );
270
271        Ok(SubmitSignatureResponse {
272            accepted: true,
273            signatures_collected: count,
274            threshold_required: status.threshold_required,
275            threshold_met,
276            error: None,
277        })
278    }
279
280    /// Get status of an aggregation task
281    pub fn get_status(&self, service_id: u64, call_id: u64) -> GetStatusResponse {
282        match self.state.get_status(service_id, call_id) {
283            Some(status) => GetStatusResponse {
284                exists: true,
285                signatures_collected: status.signatures_collected,
286                threshold_required: status.threshold_required,
287                threshold_met: status.threshold_met,
288                signer_bitmap: status.signer_bitmap,
289                signed_stake_bps: Some(status.signed_stake_bps),
290                submitted: status.submitted,
291                is_expired: Some(status.is_expired),
292                time_remaining_secs: status.time_remaining_secs,
293            },
294            None => GetStatusResponse {
295                exists: false,
296                signatures_collected: 0,
297                threshold_required: 0,
298                threshold_met: false,
299                signer_bitmap: U256::ZERO,
300                signed_stake_bps: None,
301                submitted: false,
302                is_expired: None,
303                time_remaining_secs: None,
304            },
305        }
306    }
307
308    /// Get aggregated result if threshold is met
309    pub fn get_aggregated_result(
310        &self,
311        service_id: u64,
312        call_id: u64,
313    ) -> Option<AggregatedResultResponse> {
314        let task = self.state.get_for_aggregation(service_id, call_id)?;
315
316        // Aggregate signatures and public keys
317        let (agg_sig, agg_pk) = ArkBlsBn254::aggregate(&task.signatures, &task.public_keys)
318            .map_err(|e| {
319                warn!(
320                    service_id,
321                    call_id,
322                    error = %e,
323                    "Aggregation failed"
324                );
325                e
326            })
327            .ok()?;
328
329        // Serialize for response
330        let mut sig_bytes = Vec::new();
331        ark_serialize::CanonicalSerialize::serialize_compressed(&agg_sig.0, &mut sig_bytes).ok()?;
332
333        let mut pk_bytes = Vec::new();
334        ark_serialize::CanonicalSerialize::serialize_compressed(&agg_pk.0, &mut pk_bytes).ok()?;
335
336        info!(
337            service_id,
338            call_id,
339            signers = task.signatures.len(),
340            non_signers = task.non_signer_indices.len(),
341            "Returning aggregated result"
342        );
343
344        Some(AggregatedResultResponse {
345            service_id: task.service_id,
346            call_id: task.call_id,
347            output: task.output,
348            signer_bitmap: task.signer_bitmap,
349            non_signer_indices: task.non_signer_indices,
350            aggregated_signature: sig_bytes,
351            aggregated_pubkey: pk_bytes,
352        })
353    }
354
355    /// Mark a task as submitted to chain
356    pub fn mark_submitted(&self, service_id: u64, call_id: u64) -> Result<(), ServiceError> {
357        self.state
358            .mark_submitted(service_id, call_id)
359            .map_err(|e| ServiceError::Other(e.to_string()))
360    }
361
    /// Remove a task
    ///
    /// Delegates to the state and returns its result unchanged.
    // NOTE(review): the returned flag presumably means "a task was removed" (the unit test
    // expects `true` for an existing task) — confirm against `AggregationState::remove_task`.
    pub fn remove_task(&self, service_id: u64, call_id: u64) -> bool {
        self.state.remove_task(service_id, call_id)
    }
366
    /// Get task statistics
    ///
    /// Copies the per-category task counts reported by the state into a [`ServiceStats`].
    pub fn get_stats(&self) -> ServiceStats {
        let counts = self.state.task_counts();
        ServiceStats {
            total_tasks: counts.total,
            pending_tasks: counts.pending,
            ready_tasks: counts.ready,
            submitted_tasks: counts.submitted,
            expired_tasks: counts.expired,
        }
    }
378
    /// Manually trigger cleanup
    ///
    /// Runs the same pass the background worker uses when `auto_cleanup_submitted` is set
    /// and returns the count reported by the state (the worker logs it as the number of
    /// removed tasks).
    // NOTE(review): presumably covers both expired and submitted tasks — confirm against
    // `AggregationState::cleanup`.
    pub fn cleanup(&self) -> usize {
        self.state.cleanup()
    }
383
    /// Cleanup only expired tasks
    ///
    /// Runs the same pass the background worker uses when `auto_cleanup_submitted` is not
    /// set and returns the count reported by the state.
    pub fn cleanup_expired(&self) -> usize {
        self.state.cleanup_expired()
    }
388
    /// Cleanup only submitted tasks
    ///
    /// Delegates to the state and returns the count it reports.
    pub fn cleanup_submitted(&self) -> usize {
        self.state.cleanup_submitted()
    }
393}
394
395/// Create the message that operators sign
396///
397/// Format: serviceId (8 bytes BE) || callId (8 bytes BE) || keccak256(output)
398pub fn create_signing_message(service_id: u64, call_id: u64, output: &[u8]) -> Vec<u8> {
399    use alloy_primitives::keccak256;
400
401    let output_hash = keccak256(output);
402    let mut message = Vec::with_capacity(8 + 8 + 32);
403    message.extend_from_slice(&service_id.to_be_bytes());
404    message.extend_from_slice(&call_id.to_be_bytes());
405    message.extend_from_slice(output_hash.as_slice());
406    message
407}
408
impl Default for AggregationService {
    /// Build a service configured with `ServiceConfig::default()`.
    fn default() -> Self {
        Self::new(ServiceConfig::default())
    }
}
414
415/// Handle for the cleanup worker
416pub struct CleanupWorkerHandle {
417    shutdown_tx: watch::Sender<bool>,
418    handle: tokio::task::JoinHandle<()>,
419}
420
421impl CleanupWorkerHandle {
422    /// Stop the cleanup worker
423    pub async fn stop(self) {
424        let _ = self.shutdown_tx.send(true);
425        let _ = self.handle.await;
426    }
427}
428
/// Service statistics
///
/// Snapshot of per-category task counts, filled in by `AggregationService::get_stats`
/// from the counts reported by the state.
// NOTE(review): the exact definition of each category lives in the state's `task_counts`;
// the descriptions below are inferred from the field names — confirm there.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ServiceStats {
    /// Number of tasks currently tracked.
    pub total_tasks: usize,
    /// Tasks counted as pending (newly created tasks with no signatures land here).
    pub pending_tasks: usize,
    /// Tasks counted as ready — presumably threshold met but not yet submitted.
    pub ready_tasks: usize,
    /// Tasks counted as submitted.
    pub submitted_tasks: usize,
    /// Tasks counted as expired.
    pub expired_tasks: usize,
}
438
// Unit tests for the signing-message layout, the config presets, and the service's
// task-lifecycle wrappers. None of them submit real signatures.
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::keccak256;

    // The signing message is 48 bytes: service id, call id, then the output hash.
    #[test]
    fn test_create_signing_message() {
        let service_id = 1u64;
        let call_id = 100u64;
        let output = vec![1, 2, 3, 4];

        let message = create_signing_message(service_id, call_id, &output);

        // Message should be 48 bytes: 8 + 8 + 32
        assert_eq!(message.len(), 48);

        // Check service_id encoding (big-endian)
        assert_eq!(&message[0..8], &service_id.to_be_bytes());

        // Check call_id encoding (big-endian)
        assert_eq!(&message[8..16], &call_id.to_be_bytes());

        // Check output hash
        let expected_hash = keccak256(&output);
        assert_eq!(&message[16..48], expected_hash.as_slice());
    }

    // Identical inputs always produce the identical message.
    #[test]
    fn test_create_signing_message_deterministic() {
        let msg1 = create_signing_message(1, 100, &[1, 2, 3]);
        let msg2 = create_signing_message(1, 100, &[1, 2, 3]);
        assert_eq!(msg1, msg2);
    }

    // Changing any one of service id, call id or output changes the message.
    #[test]
    fn test_create_signing_message_different_inputs() {
        let msg1 = create_signing_message(1, 100, &[1, 2, 3]);
        let msg2 = create_signing_message(2, 100, &[1, 2, 3]);
        let msg3 = create_signing_message(1, 101, &[1, 2, 3]);
        let msg4 = create_signing_message(1, 100, &[1, 2, 4]);

        assert_ne!(msg1, msg2);
        assert_ne!(msg1, msg3);
        assert_ne!(msg1, msg4);
    }

    // The default config enables every check and both timers.
    #[test]
    fn test_service_config_default() {
        let config = ServiceConfig::default();
        assert!(config.verify_on_submit);
        assert!(config.validate_output);
        assert!(config.default_task_ttl.is_some());
        assert!(config.cleanup_interval.is_some());
        assert!(config.auto_cleanup_submitted);
    }

    // The minimal config disables every check and both timers.
    #[test]
    fn test_service_config_minimal() {
        let config = ServiceConfig::minimal();
        assert!(!config.verify_on_submit);
        assert!(!config.validate_output);
        assert!(config.default_task_ttl.is_none());
        assert!(config.cleanup_interval.is_none());
        assert!(!config.auto_cleanup_submitted);
    }

    // A task can be initialized once; a second init with the same ids is rejected.
    #[test]
    fn test_aggregation_service_init_task() {
        let service = AggregationService::new(ServiceConfig::minimal());

        assert!(service.init_task(1, 100, vec![1, 2, 3], 5, 3).is_ok());

        // Duplicate should fail
        let result = service.init_task(1, 100, vec![1, 2, 3], 5, 3);
        assert!(result.is_err());
    }

    // Status of an unknown task reports `exists: false` with zero signatures.
    #[test]
    fn test_aggregation_service_get_status_nonexistent() {
        let service = AggregationService::default();

        let status = service.get_status(1, 100);
        assert!(!status.exists);
        assert_eq!(status.signatures_collected, 0);
    }

    // Status of a fresh task reflects its threshold and shows no progress yet.
    #[test]
    fn test_aggregation_service_get_status_exists() {
        let service = AggregationService::new(ServiceConfig::minimal());
        service.init_task(1, 100, vec![], 5, 3).unwrap();

        let status = service.get_status(1, 100);
        assert!(status.exists);
        assert_eq!(status.signatures_collected, 0);
        assert_eq!(status.threshold_required, 3);
        assert!(!status.threshold_met);
        assert!(!status.submitted);
    }

    // Marking an existing task as submitted is visible through its status.
    #[test]
    fn test_aggregation_service_mark_submitted() {
        let service = AggregationService::new(ServiceConfig::minimal());
        service.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(service.mark_submitted(1, 100).is_ok());

        let status = service.get_status(1, 100);
        assert!(status.submitted);
    }

    // Marking an unknown task as submitted is an error.
    #[test]
    fn test_aggregation_service_mark_submitted_nonexistent() {
        let service = AggregationService::default();

        let result = service.mark_submitted(1, 100);
        assert!(result.is_err());
    }

    // No aggregated result is available for an unknown task.
    #[test]
    fn test_aggregation_service_get_aggregated_result_nonexistent() {
        let service = AggregationService::default();

        let result = service.get_aggregated_result(1, 100);
        assert!(result.is_none());
    }

    // No aggregated result is available while the threshold has not been reached.
    #[test]
    fn test_aggregation_service_get_aggregated_result_threshold_not_met() {
        let service = AggregationService::new(ServiceConfig::minimal());
        service.init_task(1, 100, vec![], 5, 3).unwrap();

        // No signatures submitted
        let result = service.get_aggregated_result(1, 100);
        assert!(result.is_none());
    }

    // Two fresh tasks are counted as total = 2 and pending = 2, nothing else.
    #[test]
    fn test_aggregation_service_stats() {
        let service = AggregationService::new(ServiceConfig::minimal());

        service.init_task(1, 100, vec![], 5, 3).unwrap();
        service.init_task(1, 101, vec![], 5, 3).unwrap();

        let stats = service.get_stats();
        assert_eq!(stats.total_tasks, 2);
        assert_eq!(stats.pending_tasks, 2);
        assert_eq!(stats.ready_tasks, 0);
        assert_eq!(stats.submitted_tasks, 0);
        assert_eq!(stats.expired_tasks, 0);
    }

    // Removing a task makes it disappear from status queries.
    #[test]
    fn test_aggregation_service_remove_task() {
        let service = AggregationService::new(ServiceConfig::minimal());
        service.init_task(1, 100, vec![], 5, 3).unwrap();

        assert!(service.get_status(1, 100).exists);
        assert!(service.remove_task(1, 100));
        assert!(!service.get_status(1, 100).exists);
    }
}