1use crate::state::{AggregationState, TaskConfig, ThresholdType};
4use crate::types::*;
5use alloy_primitives::U256;
6use ark_serialize::CanonicalDeserialize;
7use blueprint_crypto_bn254::{ArkBlsBn254, ArkBlsBn254Public, ArkBlsBn254Signature};
8use blueprint_crypto_core::{aggregation::AggregatableSignature, KeyType};
9use std::sync::Arc;
10use std::time::Duration;
11use thiserror::Error;
12use tokio::sync::watch;
13use tracing::{debug, info, warn};
14
15#[derive(Debug, Error)]
17pub enum ServiceError {
18 #[error("Task not found")]
19 TaskNotFound,
20 #[error("Task already exists")]
21 TaskAlreadyExists,
22 #[error("Task has expired")]
23 TaskExpired,
24 #[error("Invalid signature format")]
25 InvalidSignature,
26 #[error("Invalid public key format")]
27 InvalidPublicKey,
28 #[error("Signature verification failed")]
29 VerificationFailed,
30 #[error("Output mismatch: submitted output does not match task output")]
31 OutputMismatch,
32 #[error("Aggregation failed: {0}")]
33 AggregationFailed(String),
34 #[error("{0}")]
35 Other(String),
36}
37
/// Behaviour switches for [`AggregationService`].
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Check each submitted BLS signature against its public key before storing it.
    pub verify_on_submit: bool,
    /// Reject submissions whose `output` differs from the output the task was created with.
    pub validate_output: bool,
    /// TTL handed to tasks created through `AggregationService::init_task`
    /// (forwarded as `TaskConfig::ttl`).
    pub default_task_ttl: Option<Duration>,
    /// Tick period of the background cleanup worker; `None` means
    /// `AggregationService::start_cleanup_worker` spawns nothing and returns `None`.
    pub cleanup_interval: Option<Duration>,
    /// When `true` the cleanup worker calls the state's `cleanup` (presumably expired
    /// *and* submitted tasks); when `false` it only calls `cleanup_expired`.
    pub auto_cleanup_submitted: bool,
}

impl Default for ServiceConfig {
    /// Production preset: every check enabled, one-hour task TTL, cleanup once a minute.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const ONE_MINUTE: Duration = Duration::from_secs(60);

        Self {
            verify_on_submit: true,
            validate_output: true,
            default_task_ttl: Some(ONE_HOUR),
            cleanup_interval: Some(ONE_MINUTE),
            auto_cleanup_submitted: true,
        }
    }
}

impl ServiceConfig {
    /// Bare-bones preset: no signature verification, no output validation, no task TTL
    /// and no background cleanup. Useful for tests.
    pub fn minimal() -> Self {
        Self {
            default_task_ttl: None,
            cleanup_interval: None,
            verify_on_submit: false,
            validate_output: false,
            auto_cleanup_submitted: false,
        }
    }
}
77
78#[derive(Debug)]
80pub struct AggregationService {
81 state: AggregationState,
82 config: ServiceConfig,
83}
84
85impl AggregationService {
86 pub fn new(config: ServiceConfig) -> Self {
88 Self {
89 state: AggregationState::new(),
90 config,
91 }
92 }
93
94 pub fn new_shared(config: ServiceConfig) -> Arc<Self> {
96 Arc::new(Self::new(config))
97 }
98
99 pub fn start_cleanup_worker(self: &Arc<Self>) -> Option<CleanupWorkerHandle> {
102 let interval = self.config.cleanup_interval?;
103 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
104
105 let service = Arc::clone(self);
106
107 let handle = tokio::spawn(async move {
108 let mut interval_timer = tokio::time::interval(interval);
109
110 loop {
111 tokio::select! {
112 _ = interval_timer.tick() => {
113 let removed = if service.config.auto_cleanup_submitted {
114 service.state.cleanup()
115 } else {
116 service.state.cleanup_expired()
117 };
118 if removed > 0 {
119 debug!(removed, "Cleaned up tasks");
120 }
121 }
122 _ = shutdown_rx.changed() => {
123 if *shutdown_rx.borrow() {
124 info!("Cleanup worker shutting down");
125 break;
126 }
127 }
128 }
129 }
130 });
131
132 Some(CleanupWorkerHandle {
133 shutdown_tx,
134 handle,
135 })
136 }
137
138 pub fn init_task(
140 &self,
141 service_id: u64,
142 call_id: u64,
143 output: Vec<u8>,
144 operator_count: u32,
145 threshold: u32,
146 ) -> Result<(), ServiceError> {
147 self.init_task_with_config(
148 service_id,
149 call_id,
150 output,
151 operator_count,
152 TaskConfig {
153 threshold_type: ThresholdType::Count(threshold),
154 ttl: self.config.default_task_ttl,
155 ..Default::default()
156 },
157 )
158 }
159
160 pub fn init_task_with_config(
162 &self,
163 service_id: u64,
164 call_id: u64,
165 output: Vec<u8>,
166 operator_count: u32,
167 config: TaskConfig,
168 ) -> Result<(), ServiceError> {
169 info!(
170 service_id,
171 call_id,
172 operator_count,
173 ?config.threshold_type,
174 "Initializing aggregation task"
175 );
176
177 self.state
178 .init_task_with_config(service_id, call_id, output, operator_count, config)
179 .map_err(|e| ServiceError::Other(e.to_string()))
180 }
181
182 pub fn submit_signature(
184 &self,
185 req: SubmitSignatureRequest,
186 ) -> Result<SubmitSignatureResponse, ServiceError> {
187 debug!(
188 service_id = req.service_id,
189 call_id = req.call_id,
190 operator_index = req.operator_index,
191 "Received signature submission"
192 );
193
194 if self.config.validate_output {
196 let expected_output = self
197 .state
198 .get_task_output(req.service_id, req.call_id)
199 .ok_or(ServiceError::TaskNotFound)?;
200
201 if req.output != expected_output {
202 warn!(
203 service_id = req.service_id,
204 call_id = req.call_id,
205 operator_index = req.operator_index,
206 "Output mismatch"
207 );
208 return Err(ServiceError::OutputMismatch);
209 }
210 }
211
212 let signature: ArkBlsBn254Signature = ArkBlsBn254Signature(
214 ark_bn254::G1Affine::deserialize_compressed(&req.signature[..])
215 .map_err(|_| ServiceError::InvalidSignature)?,
216 );
217
218 let public_key: ArkBlsBn254Public = ArkBlsBn254Public(
220 ark_bn254::G2Affine::deserialize_compressed(&req.public_key[..])
221 .map_err(|_| ServiceError::InvalidPublicKey)?,
222 );
223
224 if self.config.verify_on_submit {
226 let message = create_signing_message(req.service_id, req.call_id, &req.output);
228
229 if !ArkBlsBn254::verify(&public_key, &message, &signature) {
230 warn!(
231 service_id = req.service_id,
232 call_id = req.call_id,
233 operator_index = req.operator_index,
234 "Signature verification failed"
235 );
236 return Err(ServiceError::VerificationFailed);
237 }
238 }
239
240 let status = self
242 .state
243 .get_status(req.service_id, req.call_id)
244 .ok_or(ServiceError::TaskNotFound)?;
245
246 if status.is_expired {
247 return Err(ServiceError::TaskExpired);
248 }
249
250 let (count, threshold_met) = self
252 .state
253 .submit_signature(
254 req.service_id,
255 req.call_id,
256 req.operator_index,
257 signature,
258 public_key,
259 )
260 .map_err(|e| ServiceError::Other(e.to_string()))?;
261
262 info!(
263 service_id = req.service_id,
264 call_id = req.call_id,
265 operator_index = req.operator_index,
266 signatures_collected = count,
267 threshold_met,
268 "Signature accepted"
269 );
270
271 Ok(SubmitSignatureResponse {
272 accepted: true,
273 signatures_collected: count,
274 threshold_required: status.threshold_required,
275 threshold_met,
276 error: None,
277 })
278 }
279
280 pub fn get_status(&self, service_id: u64, call_id: u64) -> GetStatusResponse {
282 match self.state.get_status(service_id, call_id) {
283 Some(status) => GetStatusResponse {
284 exists: true,
285 signatures_collected: status.signatures_collected,
286 threshold_required: status.threshold_required,
287 threshold_met: status.threshold_met,
288 signer_bitmap: status.signer_bitmap,
289 signed_stake_bps: Some(status.signed_stake_bps),
290 submitted: status.submitted,
291 is_expired: Some(status.is_expired),
292 time_remaining_secs: status.time_remaining_secs,
293 },
294 None => GetStatusResponse {
295 exists: false,
296 signatures_collected: 0,
297 threshold_required: 0,
298 threshold_met: false,
299 signer_bitmap: U256::ZERO,
300 signed_stake_bps: None,
301 submitted: false,
302 is_expired: None,
303 time_remaining_secs: None,
304 },
305 }
306 }
307
308 pub fn get_aggregated_result(
310 &self,
311 service_id: u64,
312 call_id: u64,
313 ) -> Option<AggregatedResultResponse> {
314 let task = self.state.get_for_aggregation(service_id, call_id)?;
315
316 let (agg_sig, agg_pk) = ArkBlsBn254::aggregate(&task.signatures, &task.public_keys)
318 .map_err(|e| {
319 warn!(
320 service_id,
321 call_id,
322 error = %e,
323 "Aggregation failed"
324 );
325 e
326 })
327 .ok()?;
328
329 let mut sig_bytes = Vec::new();
331 ark_serialize::CanonicalSerialize::serialize_compressed(&agg_sig.0, &mut sig_bytes).ok()?;
332
333 let mut pk_bytes = Vec::new();
334 ark_serialize::CanonicalSerialize::serialize_compressed(&agg_pk.0, &mut pk_bytes).ok()?;
335
336 info!(
337 service_id,
338 call_id,
339 signers = task.signatures.len(),
340 non_signers = task.non_signer_indices.len(),
341 "Returning aggregated result"
342 );
343
344 Some(AggregatedResultResponse {
345 service_id: task.service_id,
346 call_id: task.call_id,
347 output: task.output,
348 signer_bitmap: task.signer_bitmap,
349 non_signer_indices: task.non_signer_indices,
350 aggregated_signature: sig_bytes,
351 aggregated_pubkey: pk_bytes,
352 })
353 }
354
355 pub fn mark_submitted(&self, service_id: u64, call_id: u64) -> Result<(), ServiceError> {
357 self.state
358 .mark_submitted(service_id, call_id)
359 .map_err(|e| ServiceError::Other(e.to_string()))
360 }
361
362 pub fn remove_task(&self, service_id: u64, call_id: u64) -> bool {
364 self.state.remove_task(service_id, call_id)
365 }
366
367 pub fn get_stats(&self) -> ServiceStats {
369 let counts = self.state.task_counts();
370 ServiceStats {
371 total_tasks: counts.total,
372 pending_tasks: counts.pending,
373 ready_tasks: counts.ready,
374 submitted_tasks: counts.submitted,
375 expired_tasks: counts.expired,
376 }
377 }
378
379 pub fn cleanup(&self) -> usize {
381 self.state.cleanup()
382 }
383
384 pub fn cleanup_expired(&self) -> usize {
386 self.state.cleanup_expired()
387 }
388
389 pub fn cleanup_submitted(&self) -> usize {
391 self.state.cleanup_submitted()
392 }
393}
394
395pub fn create_signing_message(service_id: u64, call_id: u64, output: &[u8]) -> Vec<u8> {
399 use alloy_primitives::keccak256;
400
401 let output_hash = keccak256(output);
402 let mut message = Vec::with_capacity(8 + 8 + 32);
403 message.extend_from_slice(&service_id.to_be_bytes());
404 message.extend_from_slice(&call_id.to_be_bytes());
405 message.extend_from_slice(output_hash.as_slice());
406 message
407}
408
409impl Default for AggregationService {
410 fn default() -> Self {
411 Self::new(ServiceConfig::default())
412 }
413}
414
415pub struct CleanupWorkerHandle {
417 shutdown_tx: watch::Sender<bool>,
418 handle: tokio::task::JoinHandle<()>,
419}
420
421impl CleanupWorkerHandle {
422 pub async fn stop(self) {
424 let _ = self.shutdown_tx.send(true);
425 let _ = self.handle.await;
426 }
427}
428
429#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
431pub struct ServiceStats {
432 pub total_tasks: usize,
433 pub pending_tasks: usize,
434 pub ready_tasks: usize,
435 pub submitted_tasks: usize,
436 pub expired_tasks: usize,
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::keccak256;

    /// Minimal-config service holding a single task:
    /// service 1, call 100, empty output, 5 operators, threshold 3.
    fn minimal_service_with_task() -> AggregationService {
        let service = AggregationService::new(ServiceConfig::minimal());
        service.init_task(1, 100, vec![], 5, 3).unwrap();
        service
    }

    #[test]
    fn test_create_signing_message() {
        let (service_id, call_id) = (1u64, 100u64);
        let output = vec![1, 2, 3, 4];

        let message = create_signing_message(service_id, call_id, &output);

        // service_id (8 BE bytes) || call_id (8 BE bytes) || keccak256(output) (32 bytes)
        assert_eq!(message.len(), 48);
        let (service_part, rest) = message.split_at(8);
        let (call_part, hash_part) = rest.split_at(8);
        assert_eq!(service_part, &service_id.to_be_bytes()[..]);
        assert_eq!(call_part, &call_id.to_be_bytes()[..]);
        assert_eq!(hash_part, keccak256(&output).as_slice());
    }

    #[test]
    fn test_create_signing_message_deterministic() {
        let first = create_signing_message(1, 100, &[1, 2, 3]);
        let second = create_signing_message(1, 100, &[1, 2, 3]);
        assert_eq!(first, second);
    }

    #[test]
    fn test_create_signing_message_different_inputs() {
        let base = create_signing_message(1, 100, &[1, 2, 3]);
        let variants = [
            create_signing_message(2, 100, &[1, 2, 3]),
            create_signing_message(1, 101, &[1, 2, 3]),
            create_signing_message(1, 100, &[1, 2, 4]),
        ];
        for variant in variants.iter() {
            assert_ne!(&base, variant);
        }
    }

    #[test]
    fn test_service_config_default() {
        let config = ServiceConfig::default();
        assert!(config.verify_on_submit);
        assert!(config.validate_output);
        assert!(config.default_task_ttl.is_some());
        assert!(config.cleanup_interval.is_some());
        assert!(config.auto_cleanup_submitted);
    }

    #[test]
    fn test_service_config_minimal() {
        let config = ServiceConfig::minimal();
        assert!(!config.verify_on_submit);
        assert!(!config.validate_output);
        assert!(config.default_task_ttl.is_none());
        assert!(config.cleanup_interval.is_none());
        assert!(!config.auto_cleanup_submitted);
    }

    #[test]
    fn test_aggregation_service_init_task() {
        let service = AggregationService::new(ServiceConfig::minimal());
        assert!(service.init_task(1, 100, vec![1, 2, 3], 5, 3).is_ok());

        // Initialising the same (service_id, call_id) a second time must be rejected.
        assert!(service.init_task(1, 100, vec![1, 2, 3], 5, 3).is_err());
    }

    #[test]
    fn test_aggregation_service_get_status_nonexistent() {
        let status = AggregationService::default().get_status(1, 100);
        assert!(!status.exists);
        assert_eq!(status.signatures_collected, 0);
    }

    #[test]
    fn test_aggregation_service_get_status_exists() {
        let status = minimal_service_with_task().get_status(1, 100);
        assert!(status.exists);
        assert_eq!(status.signatures_collected, 0);
        assert_eq!(status.threshold_required, 3);
        assert!(!status.threshold_met);
        assert!(!status.submitted);
    }

    #[test]
    fn test_aggregation_service_mark_submitted() {
        let service = minimal_service_with_task();
        assert!(service.mark_submitted(1, 100).is_ok());
        assert!(service.get_status(1, 100).submitted);
    }

    #[test]
    fn test_aggregation_service_mark_submitted_nonexistent() {
        let service = AggregationService::default();
        assert!(service.mark_submitted(1, 100).is_err());
    }

    #[test]
    fn test_aggregation_service_get_aggregated_result_nonexistent() {
        let service = AggregationService::default();
        assert!(service.get_aggregated_result(1, 100).is_none());
    }

    #[test]
    fn test_aggregation_service_get_aggregated_result_threshold_not_met() {
        // No signatures collected, so the threshold of 3 cannot be met yet.
        let service = minimal_service_with_task();
        assert!(service.get_aggregated_result(1, 100).is_none());
    }

    #[test]
    fn test_aggregation_service_stats() {
        let service = minimal_service_with_task();
        service.init_task(1, 101, vec![], 5, 3).unwrap();

        let stats = service.get_stats();
        assert_eq!(stats.total_tasks, 2);
        assert_eq!(stats.pending_tasks, 2);
        assert_eq!(stats.ready_tasks, 0);
        assert_eq!(stats.submitted_tasks, 0);
        assert_eq!(stats.expired_tasks, 0);
    }

    #[test]
    fn test_aggregation_service_remove_task() {
        let service = minimal_service_with_task();
        assert!(service.get_status(1, 100).exists);
        assert!(service.remove_task(1, 100));
        assert!(!service.get_status(1, 100).exists);
    }
}