use crate::state::ThresholdType;
use crate::types::TaskId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Errors returned by [`PersistenceBackend`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum PersistenceError {
    /// Underlying filesystem / I/O failure; `#[from]` lets `?` convert
    /// `std::io::Error` into this variant automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Encoding or decoding failure. `FilePersistence` stores the `serde_json`
    /// error message here.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// The requested task does not exist. Not constructed in this module: the
    /// backends here report a missing task from `load_task` as `Ok(None)`.
    #[error("Task not found: {0:?}")]
    NotFound(TaskId),
    /// Backend-specific failure described by a message; not constructed in
    /// this module (presumably for other backend implementations — confirm).
    #[error("Backend error: {0}")]
    Backend(String),
}
/// Result alias used by every persistence operation in this module.
pub type Result<T> = std::result::Result<T, PersistenceError>;
/// Serializable snapshot of a task's state (output, collected signatures,
/// threshold and stake data) as stored by a [`PersistenceBackend`].
///
/// A task is identified by the pair (`service_id`, `call_id`); the backends in
/// this module rebuild its key with `TaskId::new(service_id, call_id)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedTaskState {
    /// First half of the task identity (see `TaskId::new`).
    pub service_id: u64,
    /// Second half of the task identity (see `TaskId::new`).
    pub call_id: u64,
    /// Raw output bytes, serialized as a `0x`-prefixed hex string by `hex_bytes`.
    #[serde(with = "hex_bytes")]
    pub output: Vec<u8>,
    /// Number of operators involved in the task.
    pub operator_count: u32,
    /// Threshold rule, stored in its persisted form.
    pub threshold_type: PersistedThresholdType,
    /// NOTE(review): presumably a hex-encoded bitmap of operators that have
    /// signed (the test fixture uses "0x7" alongside signatures from operators
    /// 0-2) — confirm.
    pub signer_bitmap: String,
    /// Signatures as strings, presumably keyed by operator index (the test
    /// fixture uses `0x`-prefixed hex values) — confirm.
    pub signatures: HashMap<u32, String>,
    /// Public keys as strings, presumably keyed by operator index — confirm.
    pub public_keys: HashMap<u32, String>,
    /// Stake per operator, presumably keyed by operator index — confirm.
    pub operator_stakes: HashMap<u32, u64>,
    /// Total stake; assumed to be the sum of `operator_stakes` (true in the
    /// test fixture) — confirm.
    pub total_stake: u64,
    /// Whether the task result has already been submitted.
    pub submitted: bool,
    /// Creation time in milliseconds, presumably since the Unix epoch like
    /// `now_millis` — confirm.
    pub created_at_ms: u64,
    /// Optional expiry in milliseconds; `None` means `is_expired` never reports
    /// it as expired and `remaining_duration` returns `None`.
    pub expires_at_ms: Option<u64>,
}
/// Serializable mirror of [`ThresholdType`], converted in both directions by
/// the `From` impls in this module.
///
/// NOTE: variant order determines serde's variant index in non-self-describing
/// formats; do not reorder variants without a migration.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PersistedThresholdType {
    /// Mirrors `ThresholdType::Count`; presumably the number of signatures
    /// required — confirm.
    Count(u32),
    /// Mirrors `ThresholdType::StakeWeighted`; the unit is not visible here
    /// (tests use 6700, which suggests basis points) — confirm.
    StakeWeighted(u32),
}
impl From<ThresholdType> for PersistedThresholdType {
    /// Maps each [`ThresholdType`] variant onto the same-named persisted
    /// variant, carrying the inner value through unchanged.
    fn from(threshold: ThresholdType) -> Self {
        match threshold {
            ThresholdType::StakeWeighted(weight) => Self::StakeWeighted(weight),
            ThresholdType::Count(count) => Self::Count(count),
        }
    }
}
impl From<PersistedThresholdType> for ThresholdType {
    /// Restores the runtime [`ThresholdType`] from its persisted form,
    /// variant for variant, keeping the inner value as-is.
    fn from(persisted: PersistedThresholdType) -> Self {
        match persisted {
            PersistedThresholdType::StakeWeighted(weight) => Self::StakeWeighted(weight),
            PersistedThresholdType::Count(count) => Self::Count(count),
        }
    }
}
/// Storage abstraction for [`PersistedTaskState`] records, addressed by [`TaskId`].
///
/// The `Send + Sync` supertraits allow one backend value to be shared across threads.
pub trait PersistenceBackend: Send + Sync {
    /// Stores `task`. (`FilePersistence` inserts or replaces the record keyed
    /// by `task.service_id` / `task.call_id`.)
    fn save_task(&self, task: &PersistedTaskState) -> Result<()>;

    /// Returns the stored task for `task_id`, or `Ok(None)` if there is none.
    fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>>;

    /// Removes the stored task for `task_id`.
    fn delete_task(&self, task_id: &TaskId) -> Result<()>;

    /// Returns every stored task, in no particular order.
    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>>;

    /// Reports whether a task is stored for `task_id`.
    ///
    /// The default implementation delegates to [`Self::load_task`] and
    /// propagates its errors unchanged.
    fn task_exists(&self, task_id: &TaskId) -> Result<bool> {
        self.load_task(task_id).map(|found| found.is_some())
    }

    /// Pushes any buffered writes to storage. The default is a no-op.
    fn flush(&self) -> Result<()> {
        Ok(())
    }
}
/// Backend that stores nothing.
///
/// Saves and deletes succeed without effect; every load reports no tasks.
#[derive(Debug, Clone, Default)]
pub struct NoPersistence;

impl PersistenceBackend for NoPersistence {
    /// Discards `task`; always succeeds.
    fn save_task(&self, _: &PersistedTaskState) -> Result<()> {
        Ok(())
    }

    /// Nothing is ever stored, so nothing is ever found.
    fn load_task(&self, _: &TaskId) -> Result<Option<PersistedTaskState>> {
        Ok(None)
    }

    /// Nothing to remove; always succeeds.
    fn delete_task(&self, _: &TaskId) -> Result<()> {
        Ok(())
    }

    /// Always an empty list.
    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
        Ok(vec![])
    }
}
/// [`PersistenceBackend`] that keeps every task in a single JSON file.
///
/// Tasks are stored as one JSON object keyed by `"{service_id}:{call_id}"`.
/// Every mutation reads the whole file, edits the map and rewrites it, so the
/// cost of each call grows with the number of stored tasks.
///
/// Writes go to `<path>.tmp`, which is synced to disk and then renamed over
/// `path`. A crash mid-write therefore leaves either the old or the new
/// contents rather than a truncated file. This relies on `rename` replacing
/// the target atomically, as it does on POSIX filesystems; confirm for other
/// targets.
///
/// `lock` serializes access between threads sharing this value. Mutations hold
/// the write lock for the entire read-modify-write, so concurrent saves and
/// deletes cannot overwrite each other's changes. The lock does not coordinate
/// with other processes, or with other `FilePersistence` values pointing at the
/// same file.
#[derive(Debug)]
pub struct FilePersistence {
    path: std::path::PathBuf,
    lock: parking_lot::RwLock<()>,
}

impl FilePersistence {
    /// Creates a backend stored at `path`.
    ///
    /// Nothing is touched on disk until the first operation; a missing file is
    /// treated as an empty store.
    pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
        Self {
            path: path.into(),
            lock: parking_lot::RwLock::new(()),
        }
    }

    /// Loads every stored task while holding the shared (read) lock.
    fn read_all(&self) -> Result<HashMap<String, PersistedTaskState>> {
        let _guard = self.lock.read();
        self.read_file()
    }

    /// Applies `mutate` to the stored map and writes the result back.
    ///
    /// The exclusive (write) lock is held across the whole read-modify-write.
    /// Releasing it between the read and the write would let two concurrent
    /// mutations both start from the same snapshot, and the later write would
    /// silently drop the earlier change.
    fn update<F>(&self, mutate: F) -> Result<()>
    where
        F: FnOnce(&mut HashMap<String, PersistedTaskState>),
    {
        let _guard = self.lock.write();
        let mut tasks = self.read_file()?;
        mutate(&mut tasks);
        self.write_file(&tasks)
    }

    /// Reads and parses the backing file. The caller must hold `lock`.
    ///
    /// A missing or whitespace-only file is an empty store.
    ///
    /// # Errors
    /// `Io` for read failures other than "not found"; `Serialization` if the
    /// contents are not a valid task map.
    fn read_file(&self) -> Result<HashMap<String, PersistedTaskState>> {
        // Match on NotFound instead of checking `exists()` first, which would
        // race with a concurrent delete/rename of the file.
        let contents = match std::fs::read_to_string(&self.path) {
            Ok(contents) => contents,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(HashMap::new()),
            Err(e) => return Err(e.into()),
        };
        if contents.trim().is_empty() {
            return Ok(HashMap::new());
        }
        serde_json::from_str(&contents).map_err(|e| PersistenceError::Serialization(e.to_string()))
    }

    /// Serializes `tasks` and replaces the backing file via write-to-temp +
    /// rename. The caller must hold `lock` exclusively.
    ///
    /// # Errors
    /// `Io` for directory creation, write, sync or rename failures;
    /// `Serialization` if encoding fails.
    fn write_file(&self, tasks: &HashMap<String, PersistedTaskState>) -> Result<()> {
        use std::io::Write;

        if let Some(parent) = self.path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let contents = serde_json::to_string_pretty(tasks)
            .map_err(|e| PersistenceError::Serialization(e.to_string()))?;

        let temp_path = self.temp_path();
        let mut file = std::fs::File::create(&temp_path)?;
        file.write_all(contents.as_bytes())?;
        // Make the data durable before the rename publishes it. Otherwise a
        // crash can leave `path` naming an empty or partially written file.
        file.sync_all()?;
        // Close before renaming (required on Windows for an open file).
        drop(file);
        std::fs::rename(&temp_path, &self.path)?;
        Ok(())
    }

    /// Temporary file used by `write_file`: the full target path with `.tmp`
    /// appended (e.g. `tasks.json` -> `tasks.json.tmp`).
    ///
    /// Appending, rather than `with_extension("tmp")`, keeps the temp name
    /// distinct from `path` even when `path` already ends in `.tmp`. It also
    /// avoids collisions between targets that differ only in extension.
    fn temp_path(&self) -> std::path::PathBuf {
        let mut name = self.path.clone().into_os_string();
        name.push(".tmp");
        name.into()
    }

    /// Storage key for `task_id`: `"{service_id}:{call_id}"`.
    fn task_key(task_id: &TaskId) -> String {
        format!("{}:{}", task_id.service_id, task_id.call_id)
    }
}

impl PersistenceBackend for FilePersistence {
    /// Inserts or replaces the task keyed by its `(service_id, call_id)`.
    fn save_task(&self, task: &PersistedTaskState) -> Result<()> {
        let key = Self::task_key(&TaskId::new(task.service_id, task.call_id));
        self.update(|tasks| {
            tasks.insert(key, task.clone());
        })
    }

    fn load_task(&self, task_id: &TaskId) -> Result<Option<PersistedTaskState>> {
        // The freshly parsed map is owned and discarded afterwards, so move the
        // entry out instead of cloning it.
        Ok(self.read_all()?.remove(&Self::task_key(task_id)))
    }

    /// Removes the task if present. The file is rewritten (and created if it
    /// was missing) even when the task does not exist.
    fn delete_task(&self, task_id: &TaskId) -> Result<()> {
        let key = Self::task_key(task_id);
        self.update(|tasks| {
            tasks.remove(&key);
        })
    }

    fn load_all_tasks(&self) -> Result<Vec<PersistedTaskState>> {
        Ok(self.read_all()?.into_values().collect())
    }

    /// No-op: every mutation is written through to disk before it returns.
    fn flush(&self) -> Result<()> {
        Ok(())
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The
/// `u128` millisecond count is narrowed with `as`, which could only truncate
/// far beyond any realistic date.
pub fn now_millis() -> u64 {
    let clock = SystemTime::now();
    clock
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
/// Time left until `expires_at_ms`, measured against [`now_millis`].
///
/// Returns `None` when there is no expiry, or when the deadline is now or
/// already in the past. The clock is only read when an expiry is present.
pub fn remaining_duration(expires_at_ms: Option<u64>) -> Option<Duration> {
    expires_at_ms
        .and_then(|deadline| deadline.checked_sub(now_millis()))
        .filter(|&left_ms| left_ms > 0)
        .map(Duration::from_millis)
}
/// Whether the deadline `expires_at_ms` lies strictly before [`now_millis`].
///
/// A missing deadline never expires. At exactly the deadline millisecond the
/// task is still considered live.
pub fn is_expired(expires_at_ms: Option<u64>) -> bool {
    match expires_at_ms {
        Some(deadline) => deadline < now_millis(),
        None => false,
    }
}
mod hex_bytes {
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&format!("0x{}", hex::encode(bytes)))
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let s = s.strip_prefix("0x").unwrap_or(&s);
hex::decode(s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Fixture: task (service 1, call 100) with signatures from operators 0-2
    /// out of 5 equally staked operators and a count threshold of 3.
    fn sample_task() -> PersistedTaskState {
        PersistedTaskState {
            service_id: 1,
            call_id: 100,
            output: vec![1, 2, 3, 4],
            operator_count: 5,
            threshold_type: PersistedThresholdType::Count(3),
            signer_bitmap: "0x7".to_string(),
            signatures: HashMap::from([
                (0, "0xabc123".to_string()),
                (1, "0xdef456".to_string()),
                (2, "0x789abc".to_string()),
            ]),
            public_keys: HashMap::from([
                (0, "0xpk1".to_string()),
                (1, "0xpk2".to_string()),
                (2, "0xpk3".to_string()),
            ]),
            operator_stakes: HashMap::from([(0, 100), (1, 100), (2, 100), (3, 100), (4, 100)]),
            total_stake: 500,
            submitted: false,
            created_at_ms: 1700000000000,
            expires_at_ms: Some(1700001000000),
        }
    }

    #[test]
    fn test_no_persistence() {
        let backend = NoPersistence;
        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);
        assert!(backend.save_task(&task).is_ok());
        assert!(backend.load_task(&task_id).unwrap().is_none());
        assert!(backend.delete_task(&task_id).is_ok());
        assert!(backend.load_all_tasks().unwrap().is_empty());
    }

    #[test]
    fn test_file_persistence() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());
        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);
        backend.save_task(&task).unwrap();
        let loaded = backend.load_task(&task_id).unwrap().unwrap();
        assert_eq!(loaded.service_id, task.service_id);
        assert_eq!(loaded.call_id, task.call_id);
        assert_eq!(loaded.output, task.output);
        assert_eq!(loaded.operator_count, task.operator_count);
        assert_eq!(loaded.signatures.len(), 3);
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);
        backend.delete_task(&task_id).unwrap();
        assert!(backend.load_task(&task_id).unwrap().is_none());
    }

    #[test]
    fn test_file_persistence_multiple_tasks() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());
        for i in 0..5 {
            let mut task = sample_task();
            task.call_id = 100 + i;
            backend.save_task(&task).unwrap();
        }
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 5);
        backend.delete_task(&TaskId::new(1, 102)).unwrap();
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 4);
    }

    /// Saving a task with an existing (service_id, call_id) replaces it.
    #[test]
    fn test_file_persistence_overwrites_existing_task() {
        let temp_file = NamedTempFile::new().unwrap();
        let backend = FilePersistence::new(temp_file.path());
        let mut task = sample_task();
        backend.save_task(&task).unwrap();
        task.submitted = true;
        backend.save_task(&task).unwrap();
        let all = backend.load_all_tasks().unwrap();
        assert_eq!(all.len(), 1);
        assert!(all[0].submitted);
    }

    /// A missing file reads as an empty store, and the first save creates the
    /// missing parent directories.
    #[test]
    fn test_file_persistence_missing_file_and_nested_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let backend = FilePersistence::new(dir.path().join("nested").join("tasks.json"));
        assert!(backend.load_all_tasks().unwrap().is_empty());
        let task = sample_task();
        let task_id = TaskId::new(task.service_id, task.call_id);
        assert!(!backend.task_exists(&task_id).unwrap());
        backend.save_task(&task).unwrap();
        assert!(backend.task_exists(&task_id).unwrap());
    }

    #[test]
    fn test_threshold_type_conversion() {
        let count = ThresholdType::Count(5);
        let persisted: PersistedThresholdType = count.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(count, recovered);
        let stake = ThresholdType::StakeWeighted(6700);
        let persisted: PersistedThresholdType = stake.into();
        let recovered: ThresholdType = persisted.into();
        assert_eq!(stake, recovered);
    }

    #[test]
    fn test_time_helpers() {
        let now = now_millis();
        assert!(now > 0);
        let future = Some(now + 10000);
        assert!(!is_expired(future));
        assert!(remaining_duration(future).is_some());
        let past = Some(now - 10000);
        assert!(is_expired(past));
        assert!(remaining_duration(past).is_none());
        assert!(!is_expired(None));
        assert!(remaining_duration(None).is_none());
    }

    /// Every field survives a JSON round trip, and `output` is hex-encoded.
    #[test]
    fn test_serialization_roundtrip() {
        let task = sample_task();
        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("\"output\":\"0x01020304\""));
        let recovered: PersistedTaskState = serde_json::from_str(&json).unwrap();
        assert_eq!(task.service_id, recovered.service_id);
        assert_eq!(task.call_id, recovered.call_id);
        assert_eq!(task.output, recovered.output);
        assert_eq!(task.operator_count, recovered.operator_count);
        assert!(matches!(recovered.threshold_type, PersistedThresholdType::Count(3)));
        assert_eq!(task.signer_bitmap, recovered.signer_bitmap);
        assert_eq!(task.signatures, recovered.signatures);
        assert_eq!(task.public_keys, recovered.public_keys);
        assert_eq!(task.operator_stakes, recovered.operator_stakes);
        assert_eq!(task.total_stake, recovered.total_stake);
        assert_eq!(task.submitted, recovered.submitted);
        assert_eq!(task.created_at_ms, recovered.created_at_ms);
        assert_eq!(task.expires_at_ms, recovered.expires_at_ms);
    }

    /// `hex_bytes` also accepts output hex without the `0x` prefix.
    #[test]
    fn test_hex_bytes_accepts_unprefixed_input() {
        let json = serde_json::to_string(&sample_task())
            .unwrap()
            .replace("\"0x01020304\"", "\"01020304\"");
        let recovered: PersistedTaskState = serde_json::from_str(&json).unwrap();
        assert_eq!(recovered.output, vec![1, 2, 3, 4]);
    }
}