Skip to main content

blueprint_tangle_aggregation_svc/
persistence.rs

1//! Persistence layer for aggregation state
2//!
3//! This module provides traits and implementations for persisting aggregation state
4//! across service restarts. The default in-memory implementation provides no persistence,
5//! while optional backends can be enabled for production use.
6//!
7//! ## Usage
8//!
9//! ```rust,ignore
10//! use blueprint_tangle_aggregation_svc::persistence::{PersistenceBackend, FilePersistence};
11//!
12//! // Create file-based persistence
13//! let persistence = FilePersistence::new("/var/lib/aggregation/state.json");
14//!
15//! // Create service with persistence
16//! let service = AggregationService::with_persistence(config, persistence);
17//! ```
18
19use crate::state::ThresholdType;
20use crate::types::TaskId;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::time::{Duration, SystemTime, UNIX_EPOCH};
24
/// Error type for persistence operations.
///
/// Every [`PersistenceBackend`] method reports failures with this type through the
/// [`Result`] alias declared below.
#[derive(Debug, thiserror::Error)]
pub enum PersistenceError {
    /// Underlying I/O failure; `#[from]` lets `?` convert a `std::io::Error` directly.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Encoding or decoding persisted state failed; carries the underlying error's message.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// The requested task is not stored.
    ///
    /// Note: `NoPersistence` and `FilePersistence` signal a missing task by returning
    /// `Ok(None)` from `load_task`, not with this variant.
    #[error("Task not found: {0:?}")]
    NotFound(TaskId),
    /// Backend-specific failure not covered by the other variants, as a message.
    #[error("Backend error: {0}")]
    Backend(String),
}

/// Result type for persistence operations, using [`PersistenceError`] as the error.
pub type Result<T> = std::result::Result<T, PersistenceError>;
44
/// Serializable snapshot of one aggregation task, as stored by a [`PersistenceBackend`].
///
/// Byte-valued data (the signed output, signatures, public keys, signer bitmap) is
/// stored as hex strings so the record stays readable in text formats such as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedTaskState {
    /// Service ID; together with `call_id` this forms the task's `TaskId`.
    pub service_id: u64,
    /// Call ID within the service.
    pub call_id: u64,
    /// The output being signed, serialized as a `0x`-prefixed hex string.
    #[serde(with = "hex_bytes")]
    pub output: Vec<u8>,
    /// Number of operators in the service.
    pub operator_count: u32,
    /// Threshold rule that decides when enough signatures have been collected.
    pub threshold_type: PersistedThresholdType,
    /// Bitmap of which operators have signed, stored as a hex string.
    ///
    /// NOTE(review): bit `i` appears to correspond to operator index `i` (the tests use
    /// `"0x7"` for operators 0, 1 and 2) — confirm against the in-memory state.
    pub signer_bitmap: String, // U256 as hex string
    /// Collected signatures keyed by operator index (hex encoded).
    pub signatures: HashMap<u32, String>,
    /// Collected public keys keyed by operator index (hex encoded).
    pub public_keys: HashMap<u32, String>,
    /// Stake of each operator keyed by operator index, for stake-weighted thresholds.
    pub operator_stakes: HashMap<u32, u64>,
    /// Total stake of all operators (presumably the sum of `operator_stakes` — confirm).
    pub total_stake: u64,
    /// Whether this task has been submitted to chain.
    pub submitted: bool,
    /// When this task was created (unix timestamp, milliseconds).
    pub created_at_ms: u64,
    /// When this task expires (unix timestamp, milliseconds); `None` means never.
    /// Same format accepted by [`is_expired`] and [`remaining_duration`].
    pub expires_at_ms: Option<u64>,
}
76
77/// Serializable threshold type
78#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
79pub enum PersistedThresholdType {
80    Count(u32),
81    StakeWeighted(u32),
82}
83
84impl From<ThresholdType> for PersistedThresholdType {
85    fn from(t: ThresholdType) -> Self {
86        match t {
87            ThresholdType::Count(n) => PersistedThresholdType::Count(n),
88            ThresholdType::StakeWeighted(n) => PersistedThresholdType::StakeWeighted(n),
89        }
90    }
91}
92
93impl From<PersistedThresholdType> for ThresholdType {
94    fn from(t: PersistedThresholdType) -> Self {
95        match t {
96            PersistedThresholdType::Count(n) => ThresholdType::Count(n),
97            PersistedThresholdType::StakeWeighted(n) => ThresholdType::StakeWeighted(n),
98        }
99    }
100}
101
102/// Trait for persistence backends
103///
104/// Implement this trait to provide custom storage for aggregation state.
105pub trait PersistenceBackend: Send + Sync {
106    /// Save a task to persistent storage
107    fn save_task(&self, task: &PersistedTaskState) -> Result<()>;
108
109    /// Load a task from persistent storage
110    fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>>;
111
112    /// Delete a task from persistent storage
113    fn delete_task(&self, task_id: &TaskId) -> Result<()>;
114
115    /// Load all tasks from persistent storage
116    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>>;
117
118    /// Check if a task exists
119    fn task_exists(&self, task_id: &TaskId) -> Result<bool> {
120        Ok(self.load_task(task_id)?.is_some())
121    }
122
123    /// Flush any buffered writes (optional, default is no-op)
124    fn flush(&self) -> Result<()> {
125        Ok(())
126    }
127}
128
/// No-op persistence backend (in-memory only).
///
/// This is the default backend and provides no persistence: saves and deletes succeed
/// without storing anything, and loads always come back empty. Tasks are lost on
/// service restart.
#[derive(Debug, Clone, Default)]
pub struct NoPersistence;
135
136impl PersistenceBackend for NoPersistence {
137    fn save_task(&self, _task: &PersistedTaskState) -> Result<()> {
138        Ok(())
139    }
140
141    fn load_task(&self, _task_id: &TaskId) -> Result<Option<PersistedTaskState>> {
142        Ok(None)
143    }
144
145    fn delete_task(&self, _task_id: &TaskId) -> Result<()> {
146        Ok(())
147    }
148
149    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
150        Ok(Vec::new())
151    }
152}
153
/// File-based persistence backend.
///
/// Stores all tasks in a single JSON file, keyed by `"<service_id>:<call_id>"`. Every
/// operation reads the whole file, and every save/delete rewrites it, so cost grows with
/// the number of stored tasks. Suitable for small deployments; for high-throughput
/// scenarios, consider using a database backend.
#[derive(Debug)]
pub struct FilePersistence {
    /// Location of the JSON state file.
    path: std::path::PathBuf,
    /// Guards file access within this instance only. Other processes, or other
    /// `FilePersistence` values pointing at the same path, are not coordinated.
    lock: parking_lot::RwLock<()>,
}
163
164impl FilePersistence {
165    /// Create a new file persistence backend
166    pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
167        Self {
168            path: path.into(),
169            lock: parking_lot::RwLock::new(()),
170        }
171    }
172
173    fn read_all(&self) -> Result<HashMap<String, PersistedTaskState>> {
174        let _guard = self.lock.read();
175
176        if !self.path.exists() {
177            return Ok(HashMap::new());
178        }
179
180        let contents = std::fs::read_to_string(&self.path)?;
181        if contents.is_empty() {
182            return Ok(HashMap::new());
183        }
184
185        serde_json::from_str(&contents).map_err(|e| PersistenceError::Serialization(e.to_string()))
186    }
187
188    fn write_all(&self, tasks: &HashMap<String, PersistedTaskState>) -> Result<()> {
189        let _guard = self.lock.write();
190
191        // Create parent directory if needed
192        if let Some(parent) = self.path.parent() {
193            std::fs::create_dir_all(parent)?;
194        }
195
196        let contents = serde_json::to_string_pretty(tasks)
197            .map_err(|e| PersistenceError::Serialization(e.to_string()))?;
198
199        // Write to temp file first, then rename for atomicity
200        let temp_path = self.path.with_extension("tmp");
201        std::fs::write(&temp_path, contents)?;
202        std::fs::rename(&temp_path, &self.path)?;
203
204        Ok(())
205    }
206
207    fn task_key(task_id: &TaskId) -> String {
208        format!("{}:{}", task_id.service_id, task_id.call_id)
209    }
210}
211
212impl PersistenceBackend for FilePersistence {
213    fn save_task(&self, task: &PersistedTaskState) -> Result<()> {
214        let mut tasks = self.read_all()?;
215        let key = Self::task_key(&TaskId::new(task.service_id, task.call_id));
216        tasks.insert(key, task.clone());
217        self.write_all(&tasks)
218    }
219
220    fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>> {
221        let tasks = self.read_all()?;
222        let key = Self::task_key(task_id);
223        Ok(tasks.get(&key).cloned())
224    }
225
226    fn delete_task(&self, task_id: &TaskId) -> Result<()> {
227        let mut tasks = self.read_all()?;
228        let key = Self::task_key(task_id);
229        tasks.remove(&key);
230        self.write_all(&tasks)
231    }
232
233    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
234        let tasks = self.read_all()?;
235        Ok(tasks.into_values().collect())
236    }
237
238    fn flush(&self) -> Result<()> {
239        // File writes are already flushed on each operation
240        Ok(())
241    }
242}
243
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reads earlier than the epoch.
pub fn now_millis() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    // `as_millis` returns u128; the cast keeps the low 64 bits, ample for real clocks.
    elapsed.as_millis() as u64
}
251
252/// Helper to convert timestamp to remaining duration
253pub fn remaining_duration(expires_at_ms: Option<u64>) -> Option<Duration> {
254    expires_at_ms.and_then(|expires| {
255        let now = now_millis();
256        if expires > now {
257            Some(Duration::from_millis(expires - now))
258        } else {
259            None
260        }
261    })
262}
263
264/// Helper to check if expired
265pub fn is_expired(expires_at_ms: Option<u64>) -> bool {
266    expires_at_ms
267        .map(|expires| now_millis() > expires)
268        .unwrap_or(false)
269}
270
271/// Hex encoding for byte arrays in persistence
272mod hex_bytes {
273    use serde::{Deserialize, Deserializer, Serializer};
274
275    pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
276    where
277        S: Serializer,
278    {
279        serializer.serialize_str(&format!("0x{}", hex::encode(bytes)))
280    }
281
282    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
283    where
284        D: Deserializer<'de>,
285    {
286        let s = String::deserialize(deserializer)?;
287        let s = s.strip_prefix("0x").unwrap_or(&s);
288        hex::decode(s).map_err(serde::de::Error::custom)
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Representative task: 5 operators, count threshold of 3, operators 0–2 have signed.
    fn sample_task() -> PersistedTaskState {
        PersistedTaskState {
            service_id: 1,
            call_id: 100,
            output: vec![1, 2, 3, 4],
            operator_count: 5,
            threshold_type: PersistedThresholdType::Count(3),
            signer_bitmap: "0x7".to_string(), // operators 0, 1, 2 signed
            signatures: HashMap::from([
                (0, "0xabc123".to_string()),
                (1, "0xdef456".to_string()),
                (2, "0x789abc".to_string()),
            ]),
            public_keys: HashMap::from([
                (0, "0xpk1".to_string()),
                (1, "0xpk2".to_string()),
                (2, "0xpk3".to_string()),
            ]),
            operator_stakes: HashMap::from([(0, 100), (1, 100), (2, 100), (3, 100), (4, 100)]),
            total_stake: 500,
            submitted: false,
            created_at_ms: 1700000000000,
            expires_at_ms: Some(1700001000000),
        }
    }

    #[test]
    fn test_no_persistence() {
        let backend = NoPersistence;
        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);

        // Save should succeed but not persist
        assert!(backend.save_task(&task).is_ok());

        // Load should return None
        assert!(backend.load_task(&task_id).unwrap().is_none());

        // Delete should succeed
        assert!(backend.delete_task(&task_id).is_ok());

        // Load all should return empty
        assert!(backend.load_all_tasks().unwrap().is_empty());
    }

    #[test]
    fn test_file_persistence() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);

        // Save task
        backend.save_task(&task).unwrap();

        // Load task
        let loaded = backend.load_task(&task_id).unwrap().unwrap();
        assert_eq!(loaded.service_id, task.service_id);
        assert_eq!(loaded.call_id, task.call_id);
        assert_eq!(loaded.output, task.output);
        assert_eq!(loaded.operator_count, task.operator_count);
        assert_eq!(loaded.signatures.len(), 3);

        // Load all
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);

        // Delete task
        backend.delete_task(&task_id).unwrap();
        assert!(backend.load_task(&task_id).unwrap().is_none());
    }

    #[test]
    fn test_file_persistence_multiple_tasks() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        // Create and save multiple tasks
        for i in 0..5 {
            let mut task = sample_task();
            task.call_id = 100 + i;
            backend.save_task(&task).unwrap();
        }

        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 5);

        // Delete one
        backend.delete_task(&TaskId::new(1, 102)).unwrap();
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 4);
    }

    #[test]
    fn test_file_persistence_overwrites_existing_task() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        let mut task = sample_task();
        backend.save_task(&task).unwrap();

        // Saving again under the same (service_id, call_id) replaces the entry.
        task.submitted = true;
        backend.save_task(&task).unwrap();

        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);
        assert!(all[0].submitted);
    }

    #[test]
    fn test_file_persistence_missing_file_and_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nested").join("state.json");
        let backend = FilePersistence::new(path.clone());
        let task_id = TaskId::new(1, 100);

        // Nothing on disk yet: reads behave like an empty store.
        assert!(backend.load_all_tasks().unwrap().is_empty());
        assert!(!backend.task_exists(&task_id).unwrap());

        // The first write creates the missing parent directory and the file.
        backend.save_task(&sample_task()).unwrap();
        assert!(path.exists());
        assert!(backend.task_exists(&task_id).unwrap());
    }

    #[test]
    fn test_file_persistence_delete_missing_task_is_ok() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        backend.delete_task(&TaskId::new(9, 9)).unwrap();
        assert!(backend.load_all_tasks().unwrap().is_empty());
    }

    #[test]
    fn test_file_persistence_corrupt_file_is_serialization_error() {
        let temp_file = NamedTempFile::new().unwrap();
        std::fs::write(temp_file.path(), "not json").unwrap();
        let backend = FilePersistence::new(temp_file.path());

        assert!(matches!(
            backend.load_all_tasks(),
            Err(PersistenceError::Serialization(_))
        ));
    }

    #[test]
    fn test_threshold_type_conversion() {
        let count = ThresholdType::Count(5);
        let persisted: PersistedThresholdType = count.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(count, recovered);

        let stake = ThresholdType::StakeWeighted(6700);
        let persisted: PersistedThresholdType = stake.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(stake, recovered);
    }

    #[test]
    fn test_time_helpers() {
        let now = now_millis();
        assert!(now > 0);

        // Not expired
        let future = Some(now + 10000);
        assert!(!is_expired(future));
        assert!(remaining_duration(future).is_some());

        // Expired
        let past = Some(now - 10000);
        assert!(is_expired(past));
        assert!(remaining_duration(past).is_none());

        // Never expires
        assert!(!is_expired(None));
        assert!(remaining_duration(None).is_none());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let task = sample_task();
        let json = serde_json::to_string(&task).unwrap();
        let recovered: PersistedTaskState = serde_json::from_str(&json).unwrap();

        // Every field must survive the roundtrip, not just the identifiers.
        assert_eq!(task.service_id, recovered.service_id);
        assert_eq!(task.call_id, recovered.call_id);
        assert_eq!(task.output, recovered.output);
        assert_eq!(task.operator_count, recovered.operator_count);
        assert!(matches!(recovered.threshold_type, PersistedThresholdType::Count(3)));
        assert_eq!(task.signer_bitmap, recovered.signer_bitmap);
        assert_eq!(task.signatures, recovered.signatures);
        assert_eq!(task.public_keys, recovered.public_keys);
        assert_eq!(task.operator_stakes, recovered.operator_stakes);
        assert_eq!(task.total_stake, recovered.total_stake);
        assert_eq!(task.submitted, recovered.submitted);
        assert_eq!(task.created_at_ms, recovered.created_at_ms);
        assert_eq!(task.expires_at_ms, recovered.expires_at_ms);
    }

    #[test]
    fn test_output_hex_encoding() {
        let mut json = serde_json::to_value(sample_task()).unwrap();
        assert_eq!(json["output"], "0x01020304");

        // The `0x` prefix is optional when reading.
        json["output"] = serde_json::Value::from("01020304");
        let recovered: PersistedTaskState = serde_json::from_value(json).unwrap();
        assert_eq!(recovered.output, vec![1, 2, 3, 4]);
    }
}