1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use crate::context::TransitionContext;
6
/// Error produced when a [`TransitionAction`] fails.
///
/// Carries a free-form, human-readable message. Its `Display` form is
/// `ActionError: <message>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActionError {
    /// Description of what went wrong.
    pub message: String,
}

impl ActionError {
    /// Creates an error from anything convertible into a `String`
    /// (for example `&str`, `String`, or the output of `format!`).
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl std::fmt::Display for ActionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ActionError: {}", self.message)
    }
}

/// `ActionError` has no underlying cause, so `source()` is always `None`.
impl std::error::Error for ActionError {}
28
29pub trait TransitionAction: Send + Sync {
35 fn is_sync(&self) -> bool {
37 false
38 }
39
40 fn run_sync(&self, ctx: &mut TransitionContext) -> Result<(), ActionError>;
42
43 fn run_async<'a>(
46 &'a self,
47 mut ctx: TransitionContext,
48 ) -> Pin<Box<dyn Future<Output = Result<TransitionContext, ActionError>> + Send + 'a>> {
49 Box::pin(async move {
50 self.run_sync(&mut ctx)?;
51 Ok(ctx)
52 })
53 }
54}
55
56pub type BoxedAction = Arc<dyn TransitionAction>;
58
59pub fn passthrough() -> BoxedAction {
63 Arc::new(Passthrough)
64}
65
66struct Passthrough;
67
68impl TransitionAction for Passthrough {
69 fn is_sync(&self) -> bool {
70 true
71 }
72
73 fn run_sync(&self, _ctx: &mut TransitionContext) -> Result<(), ActionError> {
74 Ok(())
75 }
76}
77
78pub fn transform<F>(f: F) -> BoxedAction
80where
81 F: Fn(&mut TransitionContext) -> Arc<dyn std::any::Any + Send + Sync> + Send + Sync + 'static,
82{
83 Arc::new(Transform(f))
84}
85
86struct Transform<F>(F);
87
88impl<F> TransitionAction for Transform<F>
89where
90 F: Fn(&mut TransitionContext) -> Arc<dyn std::any::Any + Send + Sync> + Send + Sync + 'static,
91{
92 fn is_sync(&self) -> bool {
93 true
94 }
95
96 fn run_sync(&self, ctx: &mut TransitionContext) -> Result<(), ActionError> {
97 let result = (self.0)(ctx);
98 for place_name in ctx.output_place_names() {
99 ctx.output_raw(&place_name, Arc::clone(&result))?;
100 }
101 Ok(())
102 }
103}
104
105pub fn fork() -> BoxedAction {
109 Arc::new(Fork)
110}
111
112struct Fork;
113
114impl TransitionAction for Fork {
115 fn is_sync(&self) -> bool {
116 true
117 }
118
119 fn run_sync(&self, ctx: &mut TransitionContext) -> Result<(), ActionError> {
120 let input_places = ctx.input_place_names();
121 if input_places.len() != 1 {
122 return Err(ActionError::new(format!(
123 "Fork requires exactly 1 input place, found {}",
124 input_places.len()
125 )));
126 }
127 let place_name = input_places.into_iter().next().unwrap();
128 let value = ctx.input_raw(&place_name)?;
129 for output_name in ctx.output_place_names() {
130 ctx.output_raw(&output_name, Arc::clone(&value))?;
131 }
132 Ok(())
133 }
134}
135
136pub fn produce<T: Send + Sync + 'static>(place_name: Arc<str>, value: T) -> BoxedAction {
138 let value = Arc::new(value) as Arc<dyn std::any::Any + Send + Sync>;
139 Arc::new(Produce { place_name, value })
140}
141
142struct Produce {
143 place_name: Arc<str>,
144 value: Arc<dyn std::any::Any + Send + Sync>,
145}
146
147impl TransitionAction for Produce {
148 fn is_sync(&self) -> bool {
149 true
150 }
151
152 fn run_sync(&self, ctx: &mut TransitionContext) -> Result<(), ActionError> {
153 ctx.output_raw(&self.place_name, Arc::clone(&self.value))?;
154 Ok(())
155 }
156}
157
158pub fn sync_action<F>(f: F) -> BoxedAction
160where
161 F: Fn(&mut TransitionContext) -> Result<(), ActionError> + Send + Sync + 'static,
162{
163 Arc::new(SyncAction(f))
164}
165
166struct SyncAction<F>(F);
167
168impl<F> TransitionAction for SyncAction<F>
169where
170 F: Fn(&mut TransitionContext) -> Result<(), ActionError> + Send + Sync + 'static,
171{
172 fn is_sync(&self) -> bool {
173 true
174 }
175
176 fn run_sync(&self, ctx: &mut TransitionContext) -> Result<(), ActionError> {
177 (self.0)(ctx)
178 }
179}
180
181pub fn async_action<F, Fut>(f: F) -> BoxedAction
183where
184 F: Fn(TransitionContext) -> Fut + Send + Sync + 'static,
185 Fut: Future<Output = Result<TransitionContext, ActionError>> + Send + 'static,
186{
187 Arc::new(AsyncAction(f))
188}
189
190struct AsyncAction<F>(F);
191
192impl<F, Fut> TransitionAction for AsyncAction<F>
193where
194 F: Fn(TransitionContext) -> Fut + Send + Sync + 'static,
195 Fut: Future<Output = Result<TransitionContext, ActionError>> + Send + 'static,
196{
197 fn run_sync(&self, _ctx: &mut TransitionContext) -> Result<(), ActionError> {
198 Err(ActionError::new("Async action cannot run synchronously"))
199 }
200
201 fn run_async<'a>(
202 &'a self,
203 ctx: TransitionContext,
204 ) -> Pin<Box<dyn Future<Output = Result<TransitionContext, ActionError>> + Send + 'a>> {
205 Box::pin((self.0)(ctx))
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn passthrough_is_sync() {
        let action = passthrough();
        assert!(action.is_sync());
    }

    #[test]
    fn fork_is_sync() {
        let action = fork();
        assert!(action.is_sync());
    }

    #[test]
    fn transform_is_sync() {
        let action =
            transform(|_ctx| Arc::new(1_u32) as Arc<dyn std::any::Any + Send + Sync>);
        assert!(action.is_sync());
    }

    #[test]
    fn produce_is_sync() {
        let action = produce(Arc::from("out"), 42_u32);
        assert!(action.is_sync());
    }

    #[test]
    fn sync_action_is_sync() {
        let action = sync_action(|_ctx| Ok(()));
        assert!(action.is_sync());
    }

    #[test]
    fn async_action_is_not_sync() {
        // Async-only actions must be routed through `run_async`.
        let action = async_action(|ctx| async move { Ok::<_, ActionError>(ctx) });
        assert!(!action.is_sync());
    }

    #[test]
    fn action_error_display_includes_message() {
        let err = ActionError::new("boom");
        assert_eq!(err.message, "boom");
        assert_eq!(err.to_string(), "ActionError: boom");
    }
}