1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
use futures::future::join_all;
use futures::future::try_join_all;
use tokio::sync::RwLock;
use tokio::sync::Barrier;
use anyhow::Result;
use tokio::task::JoinHandle;
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};

use std::{
    sync::Arc,
    collections::{HashMap, HashSet}
};
use crate::{
    ActionObject, Id,
    ActionOutput,
    context::Context
};


/// The runtime for a workflow.
/// 
/// This struct is used to run a workflow. It contains
/// all of the actions that need to be run, and it
/// ensures that all dependencies are run before the
/// actions that depend on them.
/// 
/// A `Runtime` is normally created with `RuntimeBuilder`.
/// Cloning it is cheap: the context, outputs and progress
/// display are shared behind `Arc`s, so clones observe the
/// same state.
/// 
/// # Example
/// 
/// ```
/// use barley_runtime::prelude::*;
/// 
/// let runtime = RuntimeBuilder::new().build();
/// ```
#[derive(Clone)]
pub struct Runtime {
    // Shared context; `run` reads the list of actions to execute from it.
    ctx: Arc<RwLock<Context>>,
    // One barrier per action that has dependents, keyed by that action's id.
    // Empty until `run` fills it in.
    barriers: HashMap<Id, Arc<Barrier>>,
    // Outputs produced by actions, keyed by the producing action's id.
    outputs: Arc<RwLock<HashMap<Id, ActionOutput>>>,
    // Multi-bar display that each running action adds its spinner to.
    progress: Arc<RwLock<MultiProgress>>
}

impl Runtime {
    /// Run the workflow.
    pub async fn run(mut self) -> Result<()> {
        let actions = self.ctx.read().await.actions.clone();
        let mut dependents: HashMap<Id, usize> = HashMap::new();

        // Get the dependents for each action. For
        // example, if action A depends on action B,
        // then 1 action is dependent on B (A) and 0
        // actions are dependent on A.
        for action in actions.iter() {
            dependents.insert(action.id, 0);

            action.deps()
                .iter()
                .map(|dep| dep.id())
                .for_each(|id| {
                    let count = dependents.entry(id).or_insert(0);
                    *count += 1;
                });
        }

        // Create a barrier for each action that has
        // any dependents. The barrier will be used
        // to wait for the dependent actions to finish.
        for (id, dependents) in dependents.clone() {
            if dependents == 0 {
                continue;
            }

            let barrier = Arc::new(Barrier::new(dependents + 1));
            self.barriers.insert(id, barrier);
        }

        let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
        let bars = Arc::new(RwLock::new(Vec::new()));
        let bars_clone = bars.clone();

        let tick_loop = tokio::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                bars_clone.write().await.iter().for_each(|bar: &ProgressBar| bar.tick());
            }
        });

        for action in actions {
            let runtime_clone = self.clone();
            let bars = bars.clone();

            let action = action.clone();

            let deps = action.deps();

            let barriers = deps
                .iter()
                .map(|dep| dep.id());

            let barriers = barriers
                .map(|id| self.barriers.get(&id).unwrap().clone())
                .collect::<Vec<_>>();

            let self_barriers = self.barriers.clone();

            handles.push(tokio::spawn(async move {
                let self_barrier = self_barriers.get(&action.id).cloned();

                for barrier in barriers {
                    barrier.wait().await;
                }

                if action.check(runtime_clone.clone()).await? {
                    return Ok(())
                }

                let display_name = action.display_name();

                let progress = runtime_clone.progress.write().await.add(ProgressBar::new_spinner());
                progress.set_style(ProgressStyle::default_spinner().template(" {spinner} [{elapsed_precise}] {wide_msg}")?);
                progress.set_message(display_name.clone());
                bars.write().await.push(progress.clone());

                let output = action.perform(runtime_clone.clone()).await;

                if output.is_err() {
                    progress.finish_with_message(format!("Error: {}", display_name));
                    return Err(output.err().unwrap())
                }

                progress.finish_and_clear();

                let output = output.unwrap();

                if let Some(barrier) = self_barrier {
                    barrier.wait().await;
                }

                if let Some(output) = output {
                    runtime_clone.outputs.write().await.insert(action.id, output);
                }

                Ok(())
            }));
        }

        let results = try_join_all(handles).await;
        tick_loop.abort();
        bars.write().await.iter().for_each(|bar: &ProgressBar| bar.finish());

        results?;

        Ok(())
    }

    /// Get the output of an action.
    pub async fn get_output(&self, obj: ActionObject) -> Option<ActionOutput> {
        self.outputs.read().await.get(&obj.id()).cloned()
    }
}

/// A builder for a runtime.
///
/// Collect actions with `add_action`, then call `build`
/// to obtain a `Runtime` that shares the resulting context.
pub struct RuntimeBuilder {
    // Context that accumulates every action passed to `add_action`.
    ctx: Context
}

impl RuntimeBuilder {
    /// Start a builder around a brand-new, empty context.
    pub fn new() -> Self {
        let ctx = Context::new();
        Self { ctx }
    }

    /// Register `action` with the builder's context.
    ///
    /// The builder is returned again so calls can be chained.
    pub fn add_action(mut self, action: ActionObject) -> Self {
        self.ctx.add_action(action);
        self
    }

    /// Consume the builder and produce a ready-to-run `Runtime`.
    ///
    /// The runtime starts with no barriers, no outputs and an
    /// empty progress display.
    pub fn build(self) -> Runtime {
        let Self { ctx } = self;

        let shared_ctx = Arc::new(RwLock::new(ctx));
        let outputs = Arc::new(RwLock::new(HashMap::new()));
        let progress = Arc::new(RwLock::new(MultiProgress::new()));

        Runtime {
            ctx: shared_ctx,
            barriers: HashMap::new(),
            outputs,
            progress,
        }
    }
}

impl Default for RuntimeBuilder {
    fn default() -> Self {
        Self::new()
    }
}