barley_runtime/
runtime.rs1use tokio::sync::RwLock;
2use tokio::sync::Barrier;
3use tokio::task::JoinSet;
4
5use std::any::{Any, TypeId};
6use tracing::{debug, info, error};
7use std::{
8 sync::Arc,
9 collections::HashMap
10};
11
12use crate::Operation;
13use crate::{
14 ActionObject, Id,
15 ActionOutput,
16 ActionError,
17 context::Context
18};
19
20
21#[derive(Clone)]
36pub struct Runtime {
37 ctx: Context,
38 barriers: HashMap<Id, Arc<Barrier>>,
39 outputs: Arc<RwLock<HashMap<Id, ActionOutput>>>,
40 state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>
41}
42
43impl Runtime {
44 pub async fn perform(mut self) -> Result<(), ActionError> {
46 let actions = self.ctx.actions.clone();
47 let mut dependents: HashMap<Id, usize> = HashMap::new();
48
49 for action in actions.iter() {
54 dependents.insert(action.id, 0);
55
56 action.deps()
57 .iter()
58 .map(|dep| dep.id())
59 .for_each(|id| {
60 let count = dependents.entry(id).or_insert(0);
61 *count += 1;
62 });
63 }
64
65 for (id, dependents) in dependents.clone() {
69 if dependents == 0 {
70 continue;
71 }
72
73 let barrier = Arc::new(Barrier::new(dependents + 1));
74 self.barriers.insert(id, barrier);
75 }
76
77 let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
78
79 debug!("Starting actions");
80 for action in actions {
81 let runtime_clone = self.clone();
82
83 let action = action.clone();
84
85 let deps = action.deps();
86
87 let barriers = deps
88 .iter()
89 .map(|dep| dep.id());
90
91 let barriers = barriers
92 .map(|id| self.barriers.get(&id).unwrap().clone())
93 .collect::<Vec<_>>();
94
95 let self_barriers = self.barriers.clone();
96
97 join_set.spawn(async move {
98 let self_barrier = self_barriers.get(&action.id).cloned();
99
100 for barrier in barriers {
101 barrier.wait().await;
102 }
103
104 let probe = action.probe(runtime_clone.clone()).await?;
105 if !probe.needs_run {
106 return Ok(())
107 }
108
109 let display_name = action.display_name();
110 info!("Starting action: {}", display_name);
111
112 let output = action.run(runtime_clone.clone(), Operation::Perform).await;
113
114 if let Err(err) = &output {
115 error!("Action failed: {}", display_name);
116 error!("Error: {}", err);
117
118 return Err(err.clone())
119 } else {
120 info!("Action finished: {}", display_name);
121 }
122
123 let output = output.unwrap();
124
125 if let Some(barrier) = self_barrier {
126 barrier.wait().await;
127 }
128
129 if let Some(output) = output {
130 runtime_clone.outputs.write().await.insert(action.id, output);
131 }
132
133 Ok(())
134 });
135 }
136
137 while let Some(result) = join_set.join_next().await {
138 match result {
139 Ok(Ok(())) => {},
140 Ok(Err(err)) => {
141 join_set.abort_all();
142
143 if let ActionError::ActionFailed(_, long) = err.clone() {
144 println!("{}", long);
145 }
146
147 return Err(err)
148 },
149 Err(_) => {
150 join_set.abort_all();
151
152 return Err(ActionError::InternalError("JOIN_SET_ERROR"))
153 }
154 }
155 }
156
157 Ok(())
158 }
159
160 pub async fn rollback(self) -> Result<(), ActionError> {
165 let actions = self.ctx.actions.clone();
166 let mut dependencies: HashMap<Id, Vec<Id>> = HashMap::new();
167
168 for action in actions.iter() {
172 if !action.probe(self.clone()).await?.can_rollback {
173 return Err(ActionError::InternalError("NO_ROLLBACK"))
174 }
175 }
176
177 for action in actions.iter() {
181 dependencies.insert(action.id, Vec::new());
182
183 action.deps()
184 .iter()
185 .map(|dep| dep.id())
186 .for_each(|id| {
187 let deps = dependencies.entry(id).or_insert(Vec::new());
188 deps.push(action.id);
189 });
190 }
191
192 let mut actions = actions;
194 actions.sort_by(|a, b| {
195 let a_deps = dependencies.get(&a.id).unwrap();
196 let b_deps = dependencies.get(&b.id).unwrap();
197
198 if a_deps.contains(&b.id) {
199 return std::cmp::Ordering::Greater
200 }
201
202 if b_deps.contains(&a.id) {
203 return std::cmp::Ordering::Less
204 }
205
206 std::cmp::Ordering::Equal
207 });
208
209 let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
211
212 for action in actions {
213 let runtime_clone = self.clone();
214
215 join_set.spawn(async move {
216 action.run(runtime_clone.clone(), Operation::Rollback).await?;
217
218 Ok(())
219 });
220 }
221
222 while let Some(result) = join_set.join_next().await {
223 match result {
224 Ok(Ok(())) => {},
225 Ok(Err(err)) => {
226 join_set.abort_all();
227
228 if let ActionError::ActionFailed(_, long) = err.clone() {
229 println!("{}", long);
230 }
231
232 return Err(err)
233 },
234 Err(_) => {
235 join_set.abort_all();
236
237 return Err(ActionError::InternalError("JOIN_SET_ERROR"))
238 }
239 }
240 }
241
242 Ok(())
243 }
244
245 pub async fn get_output(&self, obj: ActionObject) -> Option<ActionOutput> {
247 self.outputs.read().await.get(&obj.id()).cloned()
248 }
249
250 pub fn get_state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
252 self.state.get(&TypeId::of::<T>()).cloned().map(|state| {
253 state.downcast::<T>().unwrap()
254 })
255 }
256}
257
258pub struct RuntimeBuilder {
260 ctx: Context,
261 state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>
262}
263
264impl RuntimeBuilder {
265 pub fn new() -> Self {
267 Self {
268 ctx: Context::new(),
269 state: HashMap::new()
270 }
271 }
272
273 pub async fn add_action(mut self, action: ActionObject) -> Self {
275 action.load_state(&mut self).await;
276 self.ctx.add_action(action);
277 self
278 }
279
280 pub fn build(self) -> Runtime {
282 Runtime {
283 ctx: self.ctx,
284 barriers: HashMap::new(),
285 outputs: Arc::new(RwLock::new(HashMap::new())),
286 state: self.state
287 }
288 }
289
290 pub fn add_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
292 self.state.insert(TypeId::of::<T>(), Arc::new(state));
293 self
294 }
295}
296
297impl Default for RuntimeBuilder {
298 fn default() -> Self {
299 Self::new()
300 }
301}