1use crate::error::FloxideError;
4use crate::Merge;
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::future::Future;
9use std::time::Duration;
10use std::{fmt::Debug, sync::Arc};
11use tokio::sync::Mutex;
12use tokio_util::sync::CancellationToken;
13
14pub trait Context: Default + DeserializeOwned + Serialize + Debug + Clone + Send + Sync {}
15impl<T: Default + DeserializeOwned + Serialize + Debug + Clone + Send + Sync> Context for T {}
16
17#[derive(Clone, Debug)]
19pub struct WorkflowCtx<S: Context> {
22 pub store: S,
24 cancel: CancellationToken,
26 timeout: Option<Duration>,
28}
29
30impl<S: Context> WorkflowCtx<S> {
31 pub fn new(store: S) -> Self {
33 Self {
34 store,
35 cancel: CancellationToken::new(),
36 timeout: None,
37 }
38 }
39
40 pub fn with_store<F, R>(&self, f: F) -> R
42 where
43 F: FnOnce(&S) -> R,
44 {
45 f(&self.store)
46 }
47
48 pub fn cancel_token(&self) -> &CancellationToken {
50 &self.cancel
51 }
52
53 pub fn set_timeout(&mut self, d: Duration) {
55 self.timeout = Some(d);
56 }
57
58 pub fn cancel(&self) {
60 self.cancel.cancel();
61 }
62
63 pub fn is_cancelled(&self) -> bool {
65 self.cancel.is_cancelled()
66 }
67
68 pub async fn cancelled(&self) {
70 self.cancel.cancelled().await;
71 }
72
73 pub async fn run_future<R, F>(&self, fut: F) -> Result<R, FloxideError>
75 where
76 F: Future<Output = Result<R, FloxideError>>,
77 {
78 if let Some(duration) = self.timeout {
79 tokio::select! {
80 _ = self.cancel.cancelled() => Err(FloxideError::Cancelled),
81 _ = tokio::time::sleep(duration) => Err(FloxideError::Timeout(duration)),
82 res = fut => res,
83 }
84 } else {
85 tokio::select! {
86 _ = self.cancel.cancelled() => Err(FloxideError::Cancelled),
87 res = fut => res,
88 }
89 }
90 }
91}
92
93#[derive(Clone, Default)]
95pub struct SharedState<T>(Arc<Mutex<T>>);
96
97impl<T> SharedState<T> {
98 pub fn new(value: T) -> Self {
99 SharedState(Arc::new(Mutex::new(value)))
100 }
101
102 pub async fn get(&self) -> tokio::sync::MutexGuard<'_, T> {
103 self.0.lock().await
104 }
105
106 pub async fn set(&self, value: T) {
107 *self.0.lock().await = value;
108 }
109}
110
111impl<T: Serialize + Clone> Serialize for SharedState<T> {
112 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
113 where
114 S: serde::Serializer,
115 {
116 let value = self
117 .0
118 .try_lock()
119 .expect("Failed to lock mutex on SharedState while serializing");
120 T::serialize(&*value, serializer)
122 }
123}
124
125impl<'de, T: Deserialize<'de> + Clone> Deserialize<'de> for SharedState<T> {
126 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 let value = T::deserialize(deserializer)?;
132 Ok(SharedState(Arc::new(Mutex::new(value))))
133 }
134}
135
136impl<T: Debug> Debug for SharedState<T> {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 match self.0.try_lock() {
139 Ok(value) => write!(f, "{:?}", value),
140 Err(_) => write!(f, "SharedState(Locked)"),
141 }
142 }
143}
144
145impl<T: Merge> Merge for SharedState<T> {
146 fn merge(&mut self, other: Self) {
147 let self_ptr = Arc::as_ptr(&self.0) as usize;
148 let other_ptr = Arc::as_ptr(&other.0) as usize;
149 if self_ptr == other_ptr {
150 return;
152 }
153 let (first, second) = if self_ptr < other_ptr {
155 (&self.0, &other.0)
156 } else {
157 (&other.0, &self.0)
158 };
159 let mut first_guard = first.blocking_lock();
160 let mut second_guard = second.blocking_lock();
161 if self_ptr < other_ptr {
163 let other_val = std::mem::take(&mut *second_guard);
164 first_guard.merge(other_val);
165 } else {
166 let mut temp = std::mem::take(&mut *first_guard);
167 temp.merge(std::mem::take(&mut *second_guard));
168 *first_guard = temp;
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json;

    /// A bare `SharedState` serializes as its inner value and round-trips unchanged.
    #[tokio::test]
    async fn test_shared_state_serde_direct() {
        let expected = vec![10, 20, 30];
        let state = SharedState::new(expected.clone());

        let json = serde_json::to_string(&state).expect("Serialization failed");
        // The wrapper is transparent: only the inner vector appears in the output.
        assert_eq!(json, "[10,20,30]");

        let restored = serde_json::from_str::<SharedState<Vec<i32>>>(&json)
            .expect("Deserialization failed");

        let guard = restored.get().await;
        assert_eq!(*guard, expected);
    }

    /// A `SharedState` field inside a derived struct serializes as its inner value.
    #[tokio::test]
    async fn test_shared_state_serde_within_struct() {
        #[derive(Serialize, Deserialize, Debug)]
        struct Container {
            id: u32,
            state: SharedState<String>,
        }

        let greeting = "hello".to_string();
        let original = Container {
            id: 1,
            state: SharedState::new(greeting.clone()),
        };

        let json = serde_json::to_string(&original).expect("Serialization failed");
        // The `state` field is emitted as a plain string, not a nested object.
        assert_eq!(json, r#"{"id":1,"state":"hello"}"#);

        let restored = serde_json::from_str::<Container>(&json)
            .expect("Deserialization failed");
        assert_eq!(restored.id, original.id);

        let guard = restored.state.get().await;
        assert_eq!(*guard, greeting);
    }
}