use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::FsmState;
use crate::error::StorageError;
use crate::key::StateKey;
use crate::storage::StateStorage;
/// Handle onto the stored FSM state and data for a single [`StateKey`].
///
/// Every operation is forwarded to the shared [`StateStorage`] backend, always addressed by
/// this context's key. Cloning the context clones the `Arc`, so clones share the same backend.
#[derive(Clone)]
pub struct StateContext {
    /// Backend that all state/data operations are delegated to.
    storage: Arc<dyn StateStorage>,
    /// Key passed to the backend on every call.
    key: StateKey,
    /// State string supplied when the context was built.
    ///
    /// NOTE(review): this is a snapshot — the state-changing methods on this type only write
    /// to storage and do not update this field.
    pub current_state: String,
}
impl StateContext {
    /// Builds a context bound to `key`, backed by `storage`, recording `current_state` as given.
    pub fn new(storage: Arc<dyn StateStorage>, key: StateKey, current_state: String) -> Self {
        Self { storage, key, current_state }
    }

    /// Persists `new_state` (via [`FsmState::as_key`]) as the state for this context's key.
    ///
    /// Only the backend is written; the in-memory `current_state` field is left untouched.
    ///
    /// # Errors
    /// Propagates whatever error the backend's `set_state` returns.
    pub async fn transition(&self, new_state: impl FsmState) -> Result<(), StorageError> {
        let owner = self.key.clone();
        let target = new_state.as_key();
        self.storage.set_state(owner, target).await
    }

    /// Asks the backend to clear the state stored for this context's key.
    ///
    /// The in-memory `current_state` field is not modified.
    ///
    /// # Errors
    /// Propagates whatever error the backend's `clear_state` returns.
    pub async fn clear_state(&self) -> Result<(), StorageError> {
        let owner = self.key.clone();
        self.storage.clear_state(owner).await
    }

    /// Serializes `value` to JSON and stores it under `field` for this context's key.
    ///
    /// # Errors
    /// Returns a [`StorageError`] wrapping the serde_json error if `value` cannot be converted
    /// to a JSON value, otherwise propagates the backend's `set_data` error.
    pub async fn set_data<T: Serialize>(&self, field: &str, value: T) -> Result<(), StorageError> {
        let encoded = serde_json::to_value(value).map_err(|err| {
            StorageError::with_source(format!("failed to serialize field `{field}`"), err)
        })?;
        let owner = self.key.clone();
        self.storage.set_data(owner, field, encoded).await
    }

    /// Loads the JSON stored under `field` and deserializes it into `T`.
    ///
    /// Returns `Ok(None)` when the backend reports no value for the field.
    ///
    /// # Errors
    /// Propagates the backend's `get_data` error, or returns a [`StorageError`] wrapping the
    /// serde_json error when the stored value does not deserialize into `T`.
    pub async fn get_data<T: DeserializeOwned>(
        &self,
        field: &str,
    ) -> Result<Option<T>, StorageError> {
        let owner = self.key.clone();
        let stored = self.storage.get_data(owner, field).await?;
        // Option<Value> -> Option<Result<T, _>> -> Result<Option<T>, _>: a missing field stays
        // `Ok(None)`, a present-but-malformed one becomes an error.
        stored
            .map(serde_json::from_value::<T>)
            .transpose()
            .map_err(|err| {
                StorageError::with_source(format!("failed to deserialize field `{field}`"), err)
            })
    }

    /// Returns every stored data field for this context's key as raw JSON values.
    ///
    /// # Errors
    /// Propagates whatever error the backend's `get_all_data` returns.
    pub async fn get_all_data(&self) -> Result<HashMap<String, serde_json::Value>, StorageError> {
        let owner = self.key.clone();
        self.storage.get_all_data(owner).await
    }

    /// Asks the backend to clear the data fields stored for this context's key.
    ///
    /// # Errors
    /// Propagates whatever error the backend's `clear_data` returns.
    pub async fn clear_data(&self) -> Result<(), StorageError> {
        let owner = self.key.clone();
        self.storage.clear_data(owner).await
    }

    /// Delegates to the backend's `clear_all` for this context's key.
    ///
    /// NOTE(review): presumably clears both state and data; the exact scope is defined by the
    /// `StateStorage` implementation — confirm there.
    ///
    /// # Errors
    /// Propagates whatever error the backend's `clear_all` returns.
    pub async fn clear_all(&self) -> Result<(), StorageError> {
        let owner = self.key.clone();
        self.storage.clear_all(owner).await
    }

    /// Borrows the key this context is bound to.
    pub fn key(&self) -> &StateKey {
        &self.key
    }
}
/// Diagnostic formatting for [`StateContext`].
///
/// The `storage` trait object is intentionally left out (it is not required to implement
/// `Debug`); `finish_non_exhaustive` appends `..` so readers can tell a field was elided
/// rather than assuming the struct only has `key` and `current_state`.
impl fmt::Debug for StateContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StateContext")
            .field("key", &self.key)
            .field("current_state", &self.current_state)
            .finish_non_exhaustive()
    }
}