use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use crate::types::*;
/// Header prepended to every binary (postcard) snapshot: the bytes `B`, `G`
/// followed by `0x01 0x00`. Files lacking it are treated as legacy JSON.
const SNAPSHOT_MAGIC: &[u8; 4] = &[b'B', b'G', 0x01, 0x00];
/// Asynchronous persistence backend for per-chat state.
///
/// The `Send + Sync + 'static` bounds let a single store be shared between
/// tasks (for example behind an `Arc`). Failures are reported as plain `String`s.
#[async_trait]
pub trait StateStore: Send + Sync + 'static {
/// Loads the state stored for `chat_id`.
///
/// `InMemoryStore` returns `Ok(None)` when nothing is stored; other
/// implementations are presumably expected to do the same — confirm.
async fn load(&self, chat_id: ChatId) -> Result<Option<ChatState>, String>;
/// Stores `state`. `InMemoryStore` keys it by `state.chat_id` and replaces any
/// previous entry for that chat.
async fn save(&self, state: &ChatState) -> Result<(), String>;
/// Removes whatever is stored for `chat_id`. `InMemoryStore` treats a chat with
/// no stored state as success.
async fn delete(&self, chat_id: ChatId) -> Result<(), String>;
/// Returns the ids of all chats that currently have stored state. The signature
/// makes no ordering guarantee.
async fn all_chat_ids(&self) -> Result<Vec<ChatId>, String>;
}
/// [`StateStore`] that keeps every chat's state in process memory.
///
/// Backed by a concurrent `DashMap`, so all operations take `&self`. Use
/// `snapshot` / `restore` to persist the contents to a file and reload them.
pub struct InMemoryStore {
/// Current state of each chat, keyed by its chat id.
states: DashMap<ChatId, ChatState>,
}
impl InMemoryStore {
pub fn new() -> Self {
Self {
states: DashMap::new(),
}
}
#[must_use]
pub fn len(&self) -> usize {
self.states.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.states.is_empty()
}
}
impl Default for InMemoryStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl StateStore for InMemoryStore {
async fn load(&self, chat_id: ChatId) -> Result<Option<ChatState>, String> {
Ok(self.states.get(&chat_id).map(|r| r.value().clone()))
}
async fn save(&self, state: &ChatState) -> Result<(), String> {
self.states.insert(state.chat_id, state.clone());
Ok(())
}
async fn delete(&self, chat_id: ChatId) -> Result<(), String> {
self.states.remove(&chat_id);
Ok(())
}
async fn all_chat_ids(&self) -> Result<Vec<ChatId>, String> {
Ok(self.states.iter().map(|r| *r.key()).collect())
}
}
impl InMemoryStore {
    /// Writes every stored chat state to `path` as `SNAPSHOT_MAGIC` followed by a
    /// postcard-encoded `Vec<ChatState>`.
    ///
    /// The data is first written to `{path}.tmp` and fsynced, then renamed over
    /// `path`. A reader therefore sees either the previous snapshot or the
    /// complete new one. Entries are copied out of the map before any await.
    /// `save`s running concurrently with the copy may or may not be included.
    ///
    /// # Errors
    /// Returns serialization errors, I/O errors from writing, syncing or
    /// renaming, and a join error if the blocking write task panics. On failure
    /// the temporary file is removed on a best-effort basis.
    pub async fn snapshot(
        &self,
        path: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let states: Vec<ChatState> = self.states.iter().map(|r| r.value().clone()).collect();
        let payload = postcard::to_allocvec(&states)?;
        let mut buf = Vec::with_capacity(SNAPSHOT_MAGIC.len() + payload.len());
        buf.extend_from_slice(SNAPSHOT_MAGIC);
        buf.extend_from_slice(&payload);

        let path = path.to_owned();
        let tmp = format!("{path}.tmp");
        // create + write + fsync + rename are all blocking syscalls; run them together
        // on the blocking pool (tokio::fs would do the same, one op at a time).
        tokio::task::spawn_blocking(move || Self::write_atomically(&tmp, &path, &buf)).await??;
        Ok(())
    }

    /// Writes `bytes` to `tmp`, fsyncs it, then renames it over `path`.
    /// Removes `tmp` if any step fails.
    ///
    /// Without the fsync a crash after the rename can leave `path` pointing at an
    /// empty or truncated file. The parent directory is not fsynced, so on some
    /// filesystems the rename itself may still be lost on power failure. The
    /// previous snapshot then remains in place.
    fn write_atomically(tmp: &str, path: &str, bytes: &[u8]) -> std::io::Result<()> {
        let result = Self::write_synced(tmp, bytes).and_then(|()| std::fs::rename(tmp, path));
        if result.is_err() {
            // Best-effort cleanup; the original error is what the caller needs.
            let _ = std::fs::remove_file(tmp);
        }
        result
    }

    /// Creates/truncates `path`, writes `bytes`, and flushes it to stable storage.
    fn write_synced(path: &str, bytes: &[u8]) -> std::io::Result<()> {
        let mut file = std::fs::File::create(path)?;
        std::io::Write::write_all(&mut file, bytes)?;
        file.sync_all()
    }

    /// Loads a snapshot written by [`InMemoryStore::snapshot`] from `path`.
    /// Older JSON snapshots without a header are also accepted.
    ///
    /// Restored entries are merged into the current contents. Chats present in
    /// the file overwrite existing entries with the same id, and other entries
    /// are kept. Returns the number of records read from the file. A missing
    /// file is not an error and yields `Ok(0)`.
    ///
    /// # Errors
    /// - I/O errors other than `NotFound`.
    /// - A header that starts like `SNAPSHOT_MAGIC` but does not match it fully,
    ///   i.e. written by an incompatible version.
    /// - postcard decode errors.
    /// - Unparseable legacy content.
    pub async fn restore(
        &self,
        path: &str,
    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
        let bytes = match tokio::fs::read(path).await {
            Ok(b) => b,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
            Err(e) => return Err(e.into()),
        };
        let payload = match bytes.strip_prefix(&SNAPSHOT_MAGIC[..]) {
            Some(payload) => payload,
            // JSON text can never begin with `BG`. A same-family header that isn't
            // the exact magic must come from another format revision, not from the
            // legacy JSON writer.
            None if bytes.starts_with(&SNAPSHOT_MAGIC[..2]) => {
                let header = &bytes[..bytes.len().min(SNAPSHOT_MAGIC.len())];
                return Err(format!(
                    "snapshot header {header:02x?} does not match supported header \
                     {SNAPSHOT_MAGIC:02x?} — written by an incompatible version"
                )
                .into());
            }
            None => {
                tracing::warn!("snapshot has no version header — attempting legacy JSON migration");
                return self.restore_legacy_json(&bytes);
            }
        };
        let states: Vec<ChatState> = postcard::from_bytes(payload)?;
        Ok(self.insert_all(states))
    }

    /// Parses `bytes` as a legacy JSON `Vec<ChatState>` and merges it into the store.
    ///
    /// # Errors
    /// Returns an "unrecognized snapshot format" error that includes the
    /// underlying JSON parse error.
    fn restore_legacy_json(
        &self,
        bytes: &[u8],
    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
        let states: Vec<ChatState> = serde_json::from_slice(bytes).map_err(|e| {
            format!("unrecognized snapshot format ({e}) — delete the file and restart")
        })?;
        let count = self.insert_all(states);
        tracing::info!(count, "migrated from legacy JSON snapshot");
        Ok(count)
    }

    /// Inserts each state under its `chat_id` and returns how many were given.
    /// Duplicate ids within `states` are counted but only the last one is kept.
    fn insert_all(&self, states: Vec<ChatState>) -> usize {
        let count = states.len();
        for state in states {
            self.states.insert(state.chat_id, state);
        }
        count
    }

    /// Spawns a background task that calls [`InMemoryStore::snapshot`] on `path`
    /// every `interval`, starting one interval after this call.
    ///
    /// Failures are logged and the loop continues. If a snapshot takes longer
    /// than `interval`, the next one is delayed instead of fired back-to-back to
    /// catch up. The task runs until the returned handle is aborted or the
    /// runtime shuts down.
    ///
    /// # Panics
    /// Panics if `interval` is zero (`tokio::time::interval` rejects a zero
    /// period; checking here reports it to the caller instead of silently killing
    /// the task). Also panics if called outside a Tokio runtime (`tokio::spawn`).
    pub fn start_snapshot_task(
        self: &Arc<Self>,
        path: String,
        interval: std::time::Duration,
    ) -> tokio::task::JoinHandle<()> {
        assert!(!interval.is_zero(), "snapshot interval must be non-zero");
        let store = Arc::clone(self);
        tokio::spawn(async move {
            let mut tick = tokio::time::interval(interval);
            tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately; consume it so the first
            // snapshot happens one full interval after start-up.
            tick.tick().await;
            loop {
                tick.tick().await;
                if let Err(e) = store.snapshot(&path).await {
                    tracing::error!(error = %e, "snapshot failed");
                }
            }
        })
    }
}