1use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19
20use crate::FsmState;
21use crate::error::StorageError;
22use crate::key::StateKey;
23use crate::storage::StateStorage;
24
25#[derive(Clone)]
30pub struct StateContext {
31 storage: Arc<dyn StateStorage>,
32 key: StateKey,
33 pub current_state: String,
35}
36
37impl StateContext {
38 pub fn new(storage: Arc<dyn StateStorage>, key: StateKey, current_state: String) -> Self {
40 Self {
41 storage,
42 key,
43 current_state,
44 }
45 }
46
47 pub async fn transition(&self, new_state: impl FsmState) -> Result<(), StorageError> {
49 self.storage
50 .set_state(self.key.clone(), new_state.as_key())
51 .await
52 }
53
54 pub async fn clear_state(&self) -> Result<(), StorageError> {
56 self.storage.clear_state(self.key.clone()).await
57 }
58
59 pub async fn set_data<T: Serialize>(&self, field: &str, value: T) -> Result<(), StorageError> {
61 let json = serde_json::to_value(value).map_err(|e| {
62 StorageError::with_source(format!("failed to serialize field `{field}`"), e)
63 })?;
64 self.storage.set_data(self.key.clone(), field, json).await
65 }
66
67 pub async fn get_data<T: DeserializeOwned>(
69 &self,
70 field: &str,
71 ) -> Result<Option<T>, StorageError> {
72 let raw = self.storage.get_data(self.key.clone(), field).await?;
73 match raw {
74 None => Ok(None),
75 Some(val) => {
76 let typed = serde_json::from_value(val).map_err(|e| {
77 StorageError::with_source(format!("failed to deserialize field `{field}`"), e)
78 })?;
79 Ok(Some(typed))
80 }
81 }
82 }
83
84 pub async fn get_all_data(&self) -> Result<HashMap<String, serde_json::Value>, StorageError> {
86 self.storage.get_all_data(self.key.clone()).await
87 }
88
89 pub async fn clear_data(&self) -> Result<(), StorageError> {
91 self.storage.clear_data(self.key.clone()).await
92 }
93
94 pub async fn clear_all(&self) -> Result<(), StorageError> {
96 self.storage.clear_all(self.key.clone()).await
97 }
98
99 pub fn key(&self) -> &StateKey {
101 &self.key
102 }
103}
104
105impl fmt::Debug for StateContext {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 f.debug_struct("StateContext")
108 .field("key", &self.key)
109 .field("current_state", &self.current_state)
110 .finish()
111 }
112}