floxide_core/
context.rs

1//! The context for a workflow execution.
2
3use crate::error::FloxideError;
4use crate::Merge;
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::future::Future;
9use std::time::Duration;
10use std::{fmt::Debug, sync::Arc};
11use tokio::sync::Mutex;
12use tokio_util::sync::CancellationToken;
13
14pub trait Context: Default + DeserializeOwned + Serialize + Debug + Clone + Send + Sync {}
15impl<T: Default + DeserializeOwned + Serialize + Debug + Clone + Send + Sync> Context for T {}
16
17/// The context for a workflow execution.
18#[derive(Clone, Debug)]
19///
20/// The context contains the store, cancellation token, and optional timeout.
21pub struct WorkflowCtx<S: Context> {
22    /// The store for the workflow.
23    pub store: S,
24    /// The cancellation token for the workflow.
25    cancel: CancellationToken,
26    /// The optional timeout for the workflow.
27    timeout: Option<Duration>,
28}
29
30impl<S: Context> WorkflowCtx<S> {
31    /// Creates a new workflow context with the given store.
32    pub fn new(store: S) -> Self {
33        Self {
34            store,
35            cancel: CancellationToken::new(),
36            timeout: None,
37        }
38    }
39
40    /// Runs the provided function with a reference to the store.
41    pub fn with_store<F, R>(&self, f: F) -> R
42    where
43        F: FnOnce(&S) -> R,
44    {
45        f(&self.store)
46    }
47
48    /// Returns a reference to the cancellation token.
49    pub fn cancel_token(&self) -> &CancellationToken {
50        &self.cancel
51    }
52
53    /// Sets a timeout for the workflow.
54    pub fn set_timeout(&mut self, d: Duration) {
55        self.timeout = Some(d);
56    }
57
58    /// Cancel the workflow execution.
59    pub fn cancel(&self) {
60        self.cancel.cancel();
61    }
62
63    /// Returns true if the workflow has been cancelled.
64    pub fn is_cancelled(&self) -> bool {
65        self.cancel.is_cancelled()
66    }
67
68    /// Asynchronously wait until the workflow is cancelled.
69    pub async fn cancelled(&self) {
70        self.cancel.cancelled().await;
71    }
72
73    /// Runs the provided future, respecting cancellation and optional timeout.
74    pub async fn run_future<R, F>(&self, fut: F) -> Result<R, FloxideError>
75    where
76        F: Future<Output = Result<R, FloxideError>>,
77    {
78        if let Some(duration) = self.timeout {
79            tokio::select! {
80                _ = self.cancel.cancelled() => Err(FloxideError::Cancelled),
81                _ = tokio::time::sleep(duration) => Err(FloxideError::Timeout(duration)),
82                res = fut => res,
83            }
84        } else {
85            tokio::select! {
86                _ = self.cancel.cancelled() => Err(FloxideError::Cancelled),
87                res = fut => res,
88            }
89        }
90    }
91}
92
93/// Arc<Mutex<T>> wrapper with custom (de)serialization and debug support
94#[derive(Clone, Default)]
95pub struct SharedState<T>(Arc<Mutex<T>>);
96
97impl<T> SharedState<T> {
98    pub fn new(value: T) -> Self {
99        SharedState(Arc::new(Mutex::new(value)))
100    }
101
102    pub async fn get(&self) -> tokio::sync::MutexGuard<'_, T> {
103        self.0.lock().await
104    }
105
106    pub async fn set(&self, value: T) {
107        *self.0.lock().await = value;
108    }
109}
110
111impl<T: Serialize + Clone> Serialize for SharedState<T> {
112    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
113    where
114        S: serde::Serializer,
115    {
116        let value = self
117            .0
118            .try_lock()
119            .expect("Failed to lock mutex on SharedState while serializing");
120        // Directly serialize the inner value T
121        T::serialize(&*value, serializer)
122    }
123}
124
125impl<'de, T: Deserialize<'de> + Clone> Deserialize<'de> for SharedState<T> {
126    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127    where
128        D: serde::Deserializer<'de>,
129    {
130        // Directly deserialize into T
131        let value = T::deserialize(deserializer)?;
132        Ok(SharedState(Arc::new(Mutex::new(value))))
133    }
134}
135
136impl<T: Debug> Debug for SharedState<T> {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        match self.0.try_lock() {
139            Ok(value) => write!(f, "{:?}", value),
140            Err(_) => write!(f, "SharedState(Locked)"),
141        }
142    }
143}
144
145impl<T: Merge> Merge for SharedState<T> {
146    fn merge(&mut self, other: Self) {
147        let self_ptr = Arc::as_ptr(&self.0) as usize;
148        let other_ptr = Arc::as_ptr(&other.0) as usize;
149        if self_ptr == other_ptr {
150            // Prevent self-deadlock: merging with itself is a no-op
151            return;
152        }
153        // Lock in address order to prevent lock order inversion deadlocks
154        let (first, second) = if self_ptr < other_ptr {
155            (&self.0, &other.0)
156        } else {
157            (&other.0, &self.0)
158        };
159        let mut first_guard = first.blocking_lock();
160        let mut second_guard = second.blocking_lock();
161        // Always merge into self
162        if self_ptr < other_ptr {
163            let other_val = std::mem::take(&mut *second_guard);
164            first_guard.merge(other_val);
165        } else {
166            let mut temp = std::mem::take(&mut *first_guard);
167            temp.merge(std::mem::take(&mut *second_guard));
168            *first_guard = temp;
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_shared_state_serde_direct() {
        let initial_data = vec![10, 20, 30];
        let shared_state = SharedState::new(initial_data.clone());

        // Serialize
        let serialized = serde_json::to_string(&shared_state).expect("Serialization failed");
        // Should serialize directly as the inner Vec<i32>
        assert_eq!(serialized, "[10,20,30]");

        // Deserialize
        let deserialized: SharedState<Vec<i32>> =
            serde_json::from_str(&serialized).expect("Deserialization failed");

        // Verify data
        let final_data = deserialized.get().await;
        assert_eq!(*final_data, initial_data);
    }

    #[tokio::test]
    async fn test_shared_state_serde_within_struct() {
        // No `PartialEq` derive: `SharedState` does not implement it, so the fields are
        // compared by hand below.
        #[derive(Serialize, Deserialize, Debug)]
        struct Container {
            id: u32,
            state: SharedState<String>,
        }

        let initial_string = "hello".to_string();
        let container = Container {
            id: 1,
            state: SharedState::new(initial_string.clone()),
        };

        // Serialize
        let serialized = serde_json::to_string(&container).expect("Serialization failed");
        // state should be serialized directly as the string
        assert_eq!(serialized, r#"{"id":1,"state":"hello"}"#);

        // Deserialize
        let deserialized: Container =
            serde_json::from_str(&serialized).expect("Deserialization failed");

        // Verify data manually
        assert_eq!(deserialized.id, container.id);
        let final_string = deserialized.state.get().await;
        assert_eq!(*final_string, initial_string);
    }

    #[tokio::test]
    async fn test_shared_state_clones_share_value() {
        let original = SharedState::new(1);
        let handle = original.clone();

        // A write through one handle must be visible through the other.
        handle.set(2).await;
        assert_eq!(*original.get().await, 2);
    }

    #[tokio::test]
    async fn test_shared_state_debug_reports_lock() {
        let state = SharedState::new(5);
        assert_eq!(format!("{state:?}"), "5");

        // While a guard is held, `Debug` cannot read the value and prints a placeholder.
        let _guard = state.get().await;
        assert_eq!(format!("{state:?}"), "SharedState(Locked)");
    }
}