blueprint_tangle_aggregation_svc/
persistence.rs1use crate::state::ThresholdType;
20use crate::types::TaskId;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::time::{Duration, SystemTime, UNIX_EPOCH};
24
25#[derive(Debug, thiserror::Error)]
27pub enum PersistenceError {
28 #[error("IO error: {0}")]
30 Io(#[from] std::io::Error),
31 #[error("Serialization error: {0}")]
33 Serialization(String),
34 #[error("Task not found: {0:?}")]
36 NotFound(TaskId),
37 #[error("Backend error: {0}")]
39 Backend(String),
40}
41
42pub type Result<T> = std::result::Result<T, PersistenceError>;
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct PersistedTaskState {
48 pub service_id: u64,
50 pub call_id: u64,
52 #[serde(with = "hex_bytes")]
54 pub output: Vec<u8>,
55 pub operator_count: u32,
57 pub threshold_type: PersistedThresholdType,
59 pub signer_bitmap: String, pub signatures: HashMap<u32, String>,
63 pub public_keys: HashMap<u32, String>,
65 pub operator_stakes: HashMap<u32, u64>,
67 pub total_stake: u64,
69 pub submitted: bool,
71 pub created_at_ms: u64,
73 pub expires_at_ms: Option<u64>,
75}
76
77#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
79pub enum PersistedThresholdType {
80 Count(u32),
81 StakeWeighted(u32),
82}
83
84impl From<ThresholdType> for PersistedThresholdType {
85 fn from(t: ThresholdType) -> Self {
86 match t {
87 ThresholdType::Count(n) => PersistedThresholdType::Count(n),
88 ThresholdType::StakeWeighted(n) => PersistedThresholdType::StakeWeighted(n),
89 }
90 }
91}
92
93impl From<PersistedThresholdType> for ThresholdType {
94 fn from(t: PersistedThresholdType) -> Self {
95 match t {
96 PersistedThresholdType::Count(n) => ThresholdType::Count(n),
97 PersistedThresholdType::StakeWeighted(n) => ThresholdType::StakeWeighted(n),
98 }
99 }
100}
101
102pub trait PersistenceBackend: Send + Sync {
106 fn save_task(&self, task: &PersistedTaskState) -> Result<()>;
108
109 fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>>;
111
112 fn delete_task(&self, task_id: &TaskId) -> Result<()>;
114
115 fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>>;
117
118 fn task_exists(&self, task_id: &TaskId) -> Result<bool> {
120 Ok(self.load_task(task_id)?.is_some())
121 }
122
123 fn flush(&self) -> Result<()> {
125 Ok(())
126 }
127}
128
/// Backend that keeps nothing: writes succeed and are discarded, and reads find nothing.
#[derive(Clone, Debug, Default)]
pub struct NoPersistence;
135
136impl PersistenceBackend for NoPersistence {
137 fn save_task(&self, _task: &PersistedTaskState) -> Result<()> {
138 Ok(())
139 }
140
141 fn load_task(&self, _task_id: &TaskId) -> Result<Option<PersistedTaskState>> {
142 Ok(None)
143 }
144
145 fn delete_task(&self, _task_id: &TaskId) -> Result<()> {
146 Ok(())
147 }
148
149 fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
150 Ok(Vec::new())
151 }
152}
153
154#[derive(Debug)]
159pub struct FilePersistence {
160 path: std::path::PathBuf,
161 lock: parking_lot::RwLock<()>,
162}
163
164impl FilePersistence {
165 pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
167 Self {
168 path: path.into(),
169 lock: parking_lot::RwLock::new(()),
170 }
171 }
172
173 fn read_all(&self) -> Result<HashMap<String, PersistedTaskState>> {
174 let _guard = self.lock.read();
175
176 if !self.path.exists() {
177 return Ok(HashMap::new());
178 }
179
180 let contents = std::fs::read_to_string(&self.path)?;
181 if contents.is_empty() {
182 return Ok(HashMap::new());
183 }
184
185 serde_json::from_str(&contents).map_err(|e| PersistenceError::Serialization(e.to_string()))
186 }
187
188 fn write_all(&self, tasks: &HashMap<String, PersistedTaskState>) -> Result<()> {
189 let _guard = self.lock.write();
190
191 if let Some(parent) = self.path.parent() {
193 std::fs::create_dir_all(parent)?;
194 }
195
196 let contents = serde_json::to_string_pretty(tasks)
197 .map_err(|e| PersistenceError::Serialization(e.to_string()))?;
198
199 let temp_path = self.path.with_extension("tmp");
201 std::fs::write(&temp_path, contents)?;
202 std::fs::rename(&temp_path, &self.path)?;
203
204 Ok(())
205 }
206
207 fn task_key(task_id: &TaskId) -> String {
208 format!("{}:{}", task_id.service_id, task_id.call_id)
209 }
210}
211
212impl PersistenceBackend for FilePersistence {
213 fn save_task(&self, task: &PersistedTaskState) -> Result<()> {
214 let mut tasks = self.read_all()?;
215 let key = Self::task_key(&TaskId::new(task.service_id, task.call_id));
216 tasks.insert(key, task.clone());
217 self.write_all(&tasks)
218 }
219
220 fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>> {
221 let tasks = self.read_all()?;
222 let key = Self::task_key(task_id);
223 Ok(tasks.get(&key).cloned())
224 }
225
226 fn delete_task(&self, task_id: &TaskId) -> Result<()> {
227 let mut tasks = self.read_all()?;
228 let key = Self::task_key(task_id);
229 tasks.remove(&key);
230 self.write_all(&tasks)
231 }
232
233 fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
234 let tasks = self.read_all()?;
235 Ok(tasks.into_values().collect())
236 }
237
238 fn flush(&self) -> Result<()> {
239 Ok(())
241 }
242}
243
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
pub fn now_millis() -> u64 {
    let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    elapsed.as_millis() as u64
}
251
252pub fn remaining_duration(expires_at_ms: Option<u64>) -> Option<Duration> {
254 expires_at_ms.and_then(|expires| {
255 let now = now_millis();
256 if expires > now {
257 Some(Duration::from_millis(expires - now))
258 } else {
259 None
260 }
261 })
262}
263
264pub fn is_expired(expires_at_ms: Option<u64>) -> bool {
266 expires_at_ms
267 .map(|expires| now_millis() > expires)
268 .unwrap_or(false)
269}
270
271mod hex_bytes {
273 use serde::{Deserialize, Deserializer, Serializer};
274
275 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
276 where
277 S: Serializer,
278 {
279 serializer.serialize_str(&format!("0x{}", hex::encode(bytes)))
280 }
281
282 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
283 where
284 D: Deserializer<'de>,
285 {
286 let s = String::deserialize(deserializer)?;
287 let s = s.strip_prefix("0x").unwrap_or(&s);
288 hex::decode(s).map_err(serde::de::Error::custom)
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a representative task: 5 operators with equal stake, 3 of which have signed.
    fn sample_task() -> PersistedTaskState {
        PersistedTaskState {
            service_id: 1,
            call_id: 100,
            output: vec![1, 2, 3, 4],
            operator_count: 5,
            threshold_type: PersistedThresholdType::Count(3),
            signer_bitmap: "0x7".to_string(),
            signatures: HashMap::from([
                (0, "0xabc123".to_string()),
                (1, "0xdef456".to_string()),
                (2, "0x789abc".to_string()),
            ]),
            public_keys: HashMap::from([
                (0, "0xpk1".to_string()),
                (1, "0xpk2".to_string()),
                (2, "0xpk3".to_string()),
            ]),
            operator_stakes: HashMap::from([(0, 100), (1, 100), (2, 100), (3, 100), (4, 100)]),
            total_stake: 500,
            submitted: false,
            created_at_ms: 1700000000000,
            expires_at_ms: Some(1700001000000),
        }
    }

    #[test]
    fn test_no_persistence() {
        let backend = NoPersistence;
        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);

        assert!(backend.save_task(&task).is_ok());
        assert!(backend.load_task(&task_id).unwrap().is_none());
        assert!(backend.delete_task(&task_id).is_ok());
        assert!(backend.load_all_tasks().unwrap().is_empty());
    }

    #[test]
    fn test_file_persistence() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);

        backend.save_task(&task).unwrap();

        let loaded = backend.load_task(&task_id).unwrap().unwrap();
        assert_eq!(loaded.service_id, task.service_id);
        assert_eq!(loaded.call_id, task.call_id);
        assert_eq!(loaded.output, task.output);
        assert_eq!(loaded.operator_count, task.operator_count);
        assert_eq!(loaded.signatures.len(), 3);

        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);

        backend.delete_task(&task_id).unwrap();
        assert!(backend.load_task(&task_id).unwrap().is_none());
    }

    #[test]
    fn test_file_persistence_multiple_tasks() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        for i in 0..5 {
            let mut task = sample_task();
            task.call_id = 100 + i;
            backend.save_task(&task).unwrap();
        }

        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 5);

        backend.delete_task(&TaskId::new(1, 102)).unwrap();
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 4);
    }

    #[test]
    fn test_file_persistence_missing_file_reads_empty() {
        let dir = tempfile::tempdir().unwrap();
        let backend = FilePersistence::new(dir.path().join("missing.json"));

        assert!(backend.load_all_tasks().unwrap().is_empty());
        assert!(backend.load_task(&TaskId::new(1, 100)).unwrap().is_none());
    }

    #[test]
    fn test_file_persistence_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nested").join("state").join("tasks.json");
        let backend = FilePersistence::new(path.clone());

        backend.save_task(&sample_task()).unwrap();

        assert!(path.exists());
        assert_eq!(backend.load_all_tasks().unwrap().len(), 1);
    }

    #[test]
    fn test_file_persistence_save_overwrites_same_task() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());

        let mut task = sample_task();
        backend.save_task(&task).unwrap();

        task.submitted = true;
        task.signatures.insert(3, "0xfeed".to_string());
        backend.save_task(&task).unwrap();

        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);
        assert!(all[0].submitted);
        assert_eq!(all[0].signatures.len(), 4);
    }

    #[test]
    fn test_file_persistence_corrupt_file_is_serialization_error() {
        let temp_file = NamedTempFile::new().unwrap();
        std::fs::write(temp_file.path(), "not json").unwrap();
        let backend = FilePersistence::new(temp_file.path());

        assert!(matches!(
            backend.load_all_tasks(),
            Err(PersistenceError::Serialization(_))
        ));
    }

    #[test]
    fn test_threshold_type_conversion() {
        let count = ThresholdType::Count(5);
        let persisted: PersistedThresholdType = count.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(count, recovered);

        let stake = ThresholdType::StakeWeighted(6700);
        let persisted: PersistedThresholdType = stake.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(stake, recovered);
    }

    #[test]
    fn test_time_helpers() {
        let now = now_millis();
        assert!(now > 0);

        let future = Some(now + 10000);
        assert!(!is_expired(future));
        assert!(remaining_duration(future).is_some());

        let past = Some(now - 10000);
        assert!(is_expired(past));
        assert!(remaining_duration(past).is_none());

        assert!(!is_expired(None));
        assert!(remaining_duration(None).is_none());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let task = sample_task();
        let json = serde_json::to_string(&task).unwrap();

        // `output` goes through `hex_bytes`, so it must appear as a 0x-prefixed hex string.
        assert!(json.contains("\"output\":\"0x01020304\""));

        let recovered: PersistedTaskState = serde_json::from_str(&json).unwrap();
        assert_eq!(task.service_id, recovered.service_id);
        assert_eq!(task.call_id, recovered.call_id);
        assert_eq!(task.output, recovered.output);
        assert_eq!(task.operator_stakes, recovered.operator_stakes);
        assert_eq!(task.total_stake, recovered.total_stake);
        assert_eq!(task.expires_at_ms, recovered.expires_at_ms);
        assert!(matches!(
            recovered.threshold_type,
            PersistedThresholdType::Count(3)
        ));
    }

    #[test]
    fn test_hex_output_accepts_unprefixed_input() {
        let json = serde_json::to_string(&sample_task()).unwrap();
        assert!(json.contains("\"output\":\"0x01020304\""));
        let json = json.replace("\"output\":\"0x01020304\"", "\"output\":\"01020304\"");

        let recovered: PersistedTaskState = serde_json::from_str(&json).unwrap();
        assert_eq!(recovered.output, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_hex_output_rejects_invalid_hex() {
        let json = serde_json::to_string(&sample_task()).unwrap();
        assert!(json.contains("\"output\":\"0x01020304\""));
        let json = json.replace("\"output\":\"0x01020304\"", "\"output\":\"0xzz\"");

        assert!(serde_json::from_str::<PersistedTaskState>(&json).is_err());
    }
}