bobr/
multiplexer.rs

1use std::{
2    collections::{BTreeMap, VecDeque},
3    io::{stderr, BufWriter, Write},
4    sync::Arc,
5};
6
7use anyhow::Result;
8use chrono::{DateTime, Utc};
9use crossterm::{
10    cursor::MoveTo,
11    style::{Print, Stylize},
12    terminal::{Clear, ClearType, EnterAlternateScreen, LeaveAlternateScreen},
13};
14use flume::Receiver;
15use parking_lot::RwLock;
16use signal_hook::{
17    consts::{SIGINT, SIGTERM},
18    iterator::Signals,
19};
20use tokio::{
21    io::{AsyncBufReadExt, AsyncReadExt, BufReader},
22    process::Command,
23    sync::Semaphore,
24    task::JoinSet,
25};
26
27#[derive(serde::Serialize, serde::Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub struct MultiplexerResult {
30    pub metadata: MultiplexerResultMetadata,
31    pub tasks: BTreeMap<usize, MultiplexerResultDataTask>,
32}
33
34#[derive(serde::Serialize, serde::Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub struct MultiplexerResultMetadata {
37    pub started: DateTime<Utc>,
38    pub ended: DateTime<Utc>,
39}
40
41#[derive(serde::Serialize, serde::Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub struct MultiplexerResultDataTask {
44    pub stdout: String,
45}
46
/// Final outcome of a task whose process has terminated.
#[derive(PartialEq, Eq, Debug)]
enum TaskStatusCompleted {
    /// The process reported a successful exit status.
    Success,
    /// The process did not succeed; carries the exit code when one is
    /// available.
    Failed(Option<i32>),
}
52
53#[derive(Debug, Eq, PartialEq)]
54enum TaskStatus {
55    Pending,
56    Running,
57    Completed(TaskStatusCompleted),
58}
59
60enum TaskEvent {
61    Update { id: usize, status: TaskStatus },
62    Stderr { id: usize, line: String },
63    Stdout { id: usize, content: String },
64}
65
66struct Task {
67    command: String,
68    status: TaskStatus,
69    stderr: VecDeque<String>,
70    stdout: String,
71}
72
73pub struct Multiplexer {
74    program: Vec<String>,
75    stderr: usize,
76    tasks: BTreeMap<usize, RwLock<Task>>,
77    parallelism: usize,
78}
79
80impl Multiplexer {
81    pub fn new(program: Vec<String>, stderr: usize, tasks: Vec<String>, processes: usize) -> Self {
82        let mut task_map = BTreeMap::<usize, RwLock<Task>>::new();
83        for i in 0..tasks.len() {
84            task_map.insert(
85                i,
86                RwLock::new(Task {
87                    command: tasks[i].clone(),
88                    status: TaskStatus::Pending,
89                    stderr: VecDeque::<_>::new(),
90                    stdout: String::new(),
91                }),
92            );
93        }
94
95        Self {
96            program,
97            stderr,
98            tasks: task_map,
99            parallelism: processes,
100        }
101    }
102
103    pub async fn run(self) -> Result<MultiplexerResult> {
104        let time_start = Utc::now();
105        let (task_event_tx, task_event_rx) = flume::unbounded::<TaskEvent>();
106
107        let mut joins = JoinSet::new();
108        let budget = Arc::new(Semaphore::new(self.parallelism));
109        for command in self.tasks.iter() {
110            let report_channel = task_event_tx.clone();
111            // first item is shell to execute commands in (like "/bin/sh")
112            let mut cmd_proc = Command::new(&self.program[0]);
113            // remaining items are arguments to shell (like "-c")
114            for arg in &self.program[1..] {
115                cmd_proc.arg(arg);
116            }
117            // final argument is the command itself
118            cmd_proc.arg(&command.1.read().command);
119
120            cmd_proc.stdin(std::process::Stdio::null());
121            cmd_proc.stdout(std::process::Stdio::piped());
122            cmd_proc.stderr(std::process::Stdio::piped());
123
124            // spawn child process as member of JoinSet
125            let task_id = command.0.clone();
126            let task_budget = budget.clone();
127            joins.spawn(async move {
128                let _seq_lock = task_budget.acquire().await;
129                let mut child_proc = cmd_proc.spawn().unwrap();
130                // ignore error
131                let _ = report_channel.send(TaskEvent::Update {
132                    id: task_id.clone(),
133                    status: TaskStatus::Running,
134                });
135
136                let stderr = child_proc.stderr.take().unwrap();
137                let mut stderr_reader = BufReader::new(stderr).lines();
138                while let Ok(Some(line)) = stderr_reader.next_line().await {
139                    let _ = report_channel.send(TaskEvent::Stderr {
140                        id: task_id.clone(),
141                        line,
142                    });
143                }
144
145                let stdout = child_proc.stdout.take().unwrap();
146                let mut stdout_out = String::new();
147                let mut stdout_reader = BufReader::new(stdout);
148                stdout_reader.read_to_string(&mut stdout_out).await.unwrap();
149                let _ = report_channel.send(TaskEvent::Stdout {
150                    id: task_id.clone(),
151                    content: stdout_out,
152                });
153
154                let exit_code = child_proc.wait().await.unwrap();
155                let status = if exit_code.success() {
156                    TaskStatusCompleted::Success
157                } else {
158                    TaskStatusCompleted::Failed(exit_code.code())
159                };
160                // ignore error
161                let _ = report_channel.send(TaskEvent::Update {
162                    id: task_id.clone(),
163                    status: TaskStatus::Completed(status),
164                });
165            });
166        }
167        drop(task_event_tx);
168
169        let mut signals = Signals::new([SIGINT, SIGTERM]).unwrap();
170        let signals_handle = signals.handle();
171
172        // task handling abort signals
173        let abort_fut = tokio::spawn(async move { signals.wait() });
174        // task handling command execution
175        let command_fut = tokio::spawn(async move { while let Some(_) = joins.join_next().await {} });
176
177        let event_handler = TaskEventReporter {
178            rx: task_event_rx,
179            stderr: self.stderr,
180            tasks: &self.tasks,
181        };
182
183        tokio::select! {
184            _ = abort_fut => {
185                return Err(anyhow::anyhow!("user interrupt"));
186            }, // abort signal was received
187            _ = command_fut => {}, // all tasks were executed
188            _ = event_handler.run() => {}, // reporting task failed
189        }
190        signals_handle.close();
191        let time_end = Utc::now();
192
193        let mut data = MultiplexerResult {
194            metadata: MultiplexerResultMetadata {
195                started: time_start,
196                ended: time_end,
197            },
198            tasks: BTreeMap::<_, _>::new(),
199        };
200        for t in self.tasks.into_iter() {
201            let task = t.1.into_inner();
202            data.tasks
203                .insert(t.0.clone(), MultiplexerResultDataTask { stdout: task.stdout });
204        }
205
206        Ok(data)
207    }
208}
209
210struct TaskEventReporter<'a> {
211    rx: Receiver<TaskEvent>,
212    stderr: usize,
213    tasks: &'a BTreeMap<usize, RwLock<Task>>,
214}
215
216impl<'a> TaskEventReporter<'a> {
217    pub async fn run(self) {
218        let mut remaining = self.tasks.len();
219        crossterm::execute!(std::io::stderr(), EnterAlternateScreen).unwrap();
220        for event in self.rx {
221            match event {
222                | TaskEvent::Update { id, status } => {
223                    match &status {
224                        | TaskStatus::Completed(_) => remaining -= 1,
225                        | _ => {},
226                    }
227                    self.tasks.get(&id).unwrap().write().status = status;
228                },
229                | TaskEvent::Stderr { id, line } => {
230                    let stderr = &mut self.tasks.get(&id).unwrap().write().stderr;
231                    stderr.push_back(line);
232                    if stderr.len() > self.stderr {
233                        stderr.pop_front();
234                    }
235                },
236                | TaskEvent::Stdout { id, content } => {
237                    let task = &mut self.tasks.get(&id).unwrap().write();
238                    task.stdout = content;
239                },
240            }
241
242            // last should be printed to stderr, therefore exit alternate screen before last
243            // draw
244            if remaining == 0 {
245                crossterm::execute!(std::io::stderr(), LeaveAlternateScreen).unwrap();
246            }
247            Self::draw(&self.tasks, remaining == 0);
248        }
249    }
250
251    fn draw(tasks: &BTreeMap<usize, RwLock<Task>>, completed: bool) {
252        let mut writer = BufWriter::new(stderr());
253        if !completed {
254            crossterm::queue!(writer, Clear(ClearType::All)).unwrap();
255            crossterm::queue!(writer, MoveTo(0, 0)).unwrap();
256        }
257
258        for item in tasks.iter() {
259            let task = item.1.read();
260            crossterm::queue!(writer, Print(format!("⇒ ({})\n", item.0))).unwrap();
261            let lines = task.command.lines();
262            crossterm::queue!(writer, Print(" ↳ Script:\n")).unwrap();
263            for l in lines {
264                crossterm::queue!(writer, Print(format!("   |> {}\n", l.trim()))).unwrap();
265            }
266            let status = match &task.status {
267                | TaskStatus::Pending => "PENDING".to_owned().yellow(),
268                | TaskStatus::Running => "RUNNING".to_owned().yellow(),
269                | TaskStatus::Completed(v) => {
270                    match v {
271                        | TaskStatusCompleted::Success => "SUCCESS (0)".to_owned().green(),
272                        | TaskStatusCompleted::Failed(code) => {
273                            format!(
274                                "FAILED ({})",
275                                code.map(|v| v.to_string()).unwrap_or("unknown".to_owned())
276                            )
277                            .red()
278                        },
279                    }
280                },
281            };
282            crossterm::queue!(writer, Print(" ↳ Status: ")).unwrap();
283            crossterm::queue!(writer, Print(status)).unwrap();
284            crossterm::queue!(writer, Print("\n")).unwrap();
285
286            if task.stderr.len() > 0 {
287                crossterm::queue!(writer, Print(" ↳ Stderr: \n")).unwrap();
288                for line in &task.stderr {
289                    crossterm::queue!(writer, Print(format!("   |> {}\n", line))).unwrap();
290                }
291            }
292        }
293
294        crossterm::queue!(writer, Print("\n")).unwrap();
295        crossterm::queue!(writer, Print("Thinking...")).unwrap();
296        if completed {
297            crossterm::queue!(writer, Print(" DONE\n")).unwrap();
298        }
299        writer.flush().unwrap();
300    }
301}