barley_runtime/
runtime.rs

1use tokio::sync::RwLock;
2use tokio::sync::Barrier;
3use tokio::task::JoinSet;
4
5use std::any::{Any, TypeId};
6use tracing::{debug, info, error};
7use std::{
8    sync::Arc,
9    collections::HashMap
10};
11
12use crate::Operation;
13use crate::{
14    ActionObject, Id,
15    ActionOutput,
16    ActionError,
17    context::Context
18};
19
20
21/// The runtime for a workflow.
22/// 
23/// This struct is used to run a workflow. It contains
24/// all of the actions that need to be run, and it
25/// ensures that all dependencies are run before the
26/// actions that depend on them.
27/// 
28/// # Example
29/// 
30/// ```
31/// use barley_runtime::prelude::*;
32/// 
33/// let runtime = RuntimeBuilder::new().build();
34/// ```
35#[derive(Clone)]
36pub struct Runtime {
37    ctx: Context,
38    barriers: HashMap<Id, Arc<Barrier>>,
39    outputs: Arc<RwLock<HashMap<Id, ActionOutput>>>,
40    state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>
41}
42
43impl Runtime {
44    /// Run the workflow.
45    pub async fn perform(mut self) -> Result<(), ActionError> {
46        let actions = self.ctx.actions.clone();
47        let mut dependents: HashMap<Id, usize> = HashMap::new();
48
49        // Get the dependents for each action. For
50        // example, if action A depends on action B,
51        // then 1 action is dependent on B (A) and 0
52        // actions are dependent on A.
53        for action in actions.iter() {
54            dependents.insert(action.id, 0);
55
56            action.deps()
57                .iter()
58                .map(|dep| dep.id())
59                .for_each(|id| {
60                    let count = dependents.entry(id).or_insert(0);
61                    *count += 1;
62                });
63        }
64
65        // Create a barrier for each action that has
66        // any dependents. The barrier will be used
67        // to wait for the dependent actions to finish.
68        for (id, dependents) in dependents.clone() {
69            if dependents == 0 {
70                continue;
71            }
72
73            let barrier = Arc::new(Barrier::new(dependents + 1));
74            self.barriers.insert(id, barrier);
75        }
76
77        let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
78
79        debug!("Starting actions");
80        for action in actions {
81            let runtime_clone = self.clone();
82
83            let action = action.clone();
84
85            let deps = action.deps();
86
87            let barriers = deps
88                .iter()
89                .map(|dep| dep.id());
90
91            let barriers = barriers
92                .map(|id| self.barriers.get(&id).unwrap().clone())
93                .collect::<Vec<_>>();
94
95            let self_barriers = self.barriers.clone();
96
97            join_set.spawn(async move {
98                let self_barrier = self_barriers.get(&action.id).cloned();
99
100                for barrier in barriers {
101                    barrier.wait().await;
102                }
103
104                let probe = action.probe(runtime_clone.clone()).await?;
105                if !probe.needs_run {
106                    return Ok(())
107                }
108
109                let display_name = action.display_name();
110                info!("Starting action: {}", display_name);
111
112                let output = action.run(runtime_clone.clone(), Operation::Perform).await;
113
114                if let Err(err) = &output {
115                    error!("Action failed: {}", display_name);
116                    error!("Error: {}", err);
117
118                    return Err(err.clone())
119                } else {
120                    info!("Action finished: {}", display_name);
121                }
122
123                let output = output.unwrap();
124
125                if let Some(barrier) = self_barrier {
126                    barrier.wait().await;
127                }
128
129                if let Some(output) = output {
130                    runtime_clone.outputs.write().await.insert(action.id, output);
131                }
132
133                Ok(())
134            });
135        }
136
137        while let Some(result) = join_set.join_next().await {
138            match result {
139                Ok(Ok(())) => {},
140                Ok(Err(err)) => {
141                    join_set.abort_all();
142
143                    if let ActionError::ActionFailed(_, long) = err.clone() {
144                        println!("{}", long);
145                    }
146
147                    return Err(err)
148                },
149                Err(_) => {
150                    join_set.abort_all();
151
152                    return Err(ActionError::InternalError("JOIN_SET_ERROR"))
153                }
154            }
155        }
156
157        Ok(())
158    }
159
160    /// Rollback the workflow.
161    /// 
162    /// This will undo all of the actions that have
163    /// been performed, if possible.
164    pub async fn rollback(self) -> Result<(), ActionError> {
165        let actions = self.ctx.actions.clone();
166        let mut dependencies: HashMap<Id, Vec<Id>> = HashMap::new();
167
168        // Check if all of the actions have a rollback
169        // function. If not, then the rollback cannot
170        // be performed.
171        for action in actions.iter() {
172            if !action.probe(self.clone()).await?.can_rollback {
173                return Err(ActionError::InternalError("NO_ROLLBACK"))
174            }
175        }
176
177        // Get the dependencies for each action. For
178        // example, if action A depends on action B,
179        // then B is a dependency of A.
180        for action in actions.iter() {
181            dependencies.insert(action.id, Vec::new());
182
183            action.deps()
184                .iter()
185                .map(|dep| dep.id())
186                .for_each(|id| {
187                    let deps = dependencies.entry(id).or_insert(Vec::new());
188                    deps.push(action.id);
189                });
190        }
191
192        // Sort the actions by their dependencies.
193        let mut actions = actions;
194        actions.sort_by(|a, b| {
195            let a_deps = dependencies.get(&a.id).unwrap();
196            let b_deps = dependencies.get(&b.id).unwrap();
197
198            if a_deps.contains(&b.id) {
199                return std::cmp::Ordering::Greater
200            }
201
202            if b_deps.contains(&a.id) {
203                return std::cmp::Ordering::Less
204            }
205
206            std::cmp::Ordering::Equal
207        });
208
209        // Create spawns
210        let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
211
212        for action in actions {
213            let runtime_clone = self.clone();
214
215            join_set.spawn(async move {
216                action.run(runtime_clone.clone(), Operation::Rollback).await?;
217
218                Ok(())
219            });
220        }
221
222        while let Some(result) = join_set.join_next().await {
223            match result {
224                Ok(Ok(())) => {},
225                Ok(Err(err)) => {
226                    join_set.abort_all();
227
228                    if let ActionError::ActionFailed(_, long) = err.clone() {
229                        println!("{}", long);
230                    }
231
232                    return Err(err)
233                },
234                Err(_) => {
235                    join_set.abort_all();
236
237                    return Err(ActionError::InternalError("JOIN_SET_ERROR"))
238                }
239            }
240        }
241
242        Ok(())
243    }
244
245    /// Get the output of an action.
246    pub async fn get_output(&self, obj: ActionObject) -> Option<ActionOutput> {
247        self.outputs.read().await.get(&obj.id()).cloned()
248    }
249
250    /// Get the state object of a type.
251    pub fn get_state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
252        self.state.get(&TypeId::of::<T>()).cloned().map(|state| {
253            state.downcast::<T>().unwrap()
254        })
255    }
256}
257
258/// A builder for a runtime.
259pub struct RuntimeBuilder {
260    ctx: Context,
261    state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>
262}
263
264impl RuntimeBuilder {
265    /// Create a new runtime builder.
266    pub fn new() -> Self {
267        Self {
268            ctx: Context::new(),
269            state: HashMap::new()
270        }
271    }
272
273    /// Add an action to the runtime.
274    pub async fn add_action(mut self, action: ActionObject) -> Self {
275        action.load_state(&mut self).await;
276        self.ctx.add_action(action);
277        self
278    }
279
280    /// Build the runtime.
281    pub fn build(self) -> Runtime {
282        Runtime {
283            ctx: self.ctx,
284            barriers: HashMap::new(),
285            outputs: Arc::new(RwLock::new(HashMap::new())),
286            state: self.state
287        }
288    }
289
290    /// Add a state object to the runtime.
291    pub fn add_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
292        self.state.insert(TypeId::of::<T>(), Arc::new(state));
293        self
294    }
295}
296
297impl Default for RuntimeBuilder {
298    fn default() -> Self {
299        Self::new()
300    }
301}