1use std::{
2 collections::{BTreeMap, VecDeque},
3 io::{stderr, BufWriter, Write},
4 sync::Arc,
5};
6
7use anyhow::Result;
8use chrono::{DateTime, Utc};
9use crossterm::{
10 cursor::MoveTo,
11 style::{Print, Stylize},
12 terminal::{Clear, ClearType, EnterAlternateScreen, LeaveAlternateScreen},
13};
14use flume::Receiver;
15use parking_lot::RwLock;
16use signal_hook::{
17 consts::{SIGINT, SIGTERM},
18 iterator::Signals,
19};
20use tokio::{
21 io::{AsyncBufReadExt, AsyncReadExt, BufReader},
22 process::Command,
23 sync::Semaphore,
24 task::JoinSet,
25};
26
27#[derive(serde::Serialize, serde::Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub struct MultiplexerResult {
30 pub metadata: MultiplexerResultMetadata,
31 pub tasks: BTreeMap<usize, MultiplexerResultDataTask>,
32}
33
34#[derive(serde::Serialize, serde::Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub struct MultiplexerResultMetadata {
37 pub started: DateTime<Utc>,
38 pub ended: DateTime<Utc>,
39}
40
41#[derive(serde::Serialize, serde::Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub struct MultiplexerResultDataTask {
44 pub stdout: String,
45}
46
/// Final outcome of a task's child process.
#[derive(Debug, PartialEq, Eq)]
enum TaskStatusCompleted {
    /// The process exited successfully.
    Success,
    /// The process did not exit successfully; carries its exit code, or `None` when no exit
    /// code was available.
    Failed(Option<i32>),
}
52
53#[derive(Debug, Eq, PartialEq)]
54enum TaskStatus {
55 Pending,
56 Running,
57 Completed(TaskStatusCompleted),
58}
59
60enum TaskEvent {
61 Update { id: usize, status: TaskStatus },
62 Stderr { id: usize, line: String },
63 Stdout { id: usize, content: String },
64}
65
66struct Task {
67 command: String,
68 status: TaskStatus,
69 stderr: VecDeque<String>,
70 stdout: String,
71}
72
73pub struct Multiplexer {
74 program: Vec<String>,
75 stderr: usize,
76 tasks: BTreeMap<usize, RwLock<Task>>,
77 parallelism: usize,
78}
79
80impl Multiplexer {
81 pub fn new(program: Vec<String>, stderr: usize, tasks: Vec<String>, processes: usize) -> Self {
82 let mut task_map = BTreeMap::<usize, RwLock<Task>>::new();
83 for i in 0..tasks.len() {
84 task_map.insert(
85 i,
86 RwLock::new(Task {
87 command: tasks[i].clone(),
88 status: TaskStatus::Pending,
89 stderr: VecDeque::<_>::new(),
90 stdout: String::new(),
91 }),
92 );
93 }
94
95 Self {
96 program,
97 stderr,
98 tasks: task_map,
99 parallelism: processes,
100 }
101 }
102
103 pub async fn run(self) -> Result<MultiplexerResult> {
104 let time_start = Utc::now();
105 let (task_event_tx, task_event_rx) = flume::unbounded::<TaskEvent>();
106
107 let mut joins = JoinSet::new();
108 let budget = Arc::new(Semaphore::new(self.parallelism));
109 for command in self.tasks.iter() {
110 let report_channel = task_event_tx.clone();
111 let mut cmd_proc = Command::new(&self.program[0]);
113 for arg in &self.program[1..] {
115 cmd_proc.arg(arg);
116 }
117 cmd_proc.arg(&command.1.read().command);
119
120 cmd_proc.stdin(std::process::Stdio::null());
121 cmd_proc.stdout(std::process::Stdio::piped());
122 cmd_proc.stderr(std::process::Stdio::piped());
123
124 let task_id = command.0.clone();
126 let task_budget = budget.clone();
127 joins.spawn(async move {
128 let _seq_lock = task_budget.acquire().await;
129 let mut child_proc = cmd_proc.spawn().unwrap();
130 let _ = report_channel.send(TaskEvent::Update {
132 id: task_id.clone(),
133 status: TaskStatus::Running,
134 });
135
136 let stderr = child_proc.stderr.take().unwrap();
137 let mut stderr_reader = BufReader::new(stderr).lines();
138 while let Ok(Some(line)) = stderr_reader.next_line().await {
139 let _ = report_channel.send(TaskEvent::Stderr {
140 id: task_id.clone(),
141 line,
142 });
143 }
144
145 let stdout = child_proc.stdout.take().unwrap();
146 let mut stdout_out = String::new();
147 let mut stdout_reader = BufReader::new(stdout);
148 stdout_reader.read_to_string(&mut stdout_out).await.unwrap();
149 let _ = report_channel.send(TaskEvent::Stdout {
150 id: task_id.clone(),
151 content: stdout_out,
152 });
153
154 let exit_code = child_proc.wait().await.unwrap();
155 let status = if exit_code.success() {
156 TaskStatusCompleted::Success
157 } else {
158 TaskStatusCompleted::Failed(exit_code.code())
159 };
160 let _ = report_channel.send(TaskEvent::Update {
162 id: task_id.clone(),
163 status: TaskStatus::Completed(status),
164 });
165 });
166 }
167 drop(task_event_tx);
168
169 let mut signals = Signals::new([SIGINT, SIGTERM]).unwrap();
170 let signals_handle = signals.handle();
171
172 let abort_fut = tokio::spawn(async move { signals.wait() });
174 let command_fut = tokio::spawn(async move { while let Some(_) = joins.join_next().await {} });
176
177 let event_handler = TaskEventReporter {
178 rx: task_event_rx,
179 stderr: self.stderr,
180 tasks: &self.tasks,
181 };
182
183 tokio::select! {
184 _ = abort_fut => {
185 return Err(anyhow::anyhow!("user interrupt"));
186 }, _ = command_fut => {}, _ = event_handler.run() => {}, }
190 signals_handle.close();
191 let time_end = Utc::now();
192
193 let mut data = MultiplexerResult {
194 metadata: MultiplexerResultMetadata {
195 started: time_start,
196 ended: time_end,
197 },
198 tasks: BTreeMap::<_, _>::new(),
199 };
200 for t in self.tasks.into_iter() {
201 let task = t.1.into_inner();
202 data.tasks
203 .insert(t.0.clone(), MultiplexerResultDataTask { stdout: task.stdout });
204 }
205
206 Ok(data)
207 }
208}
209
210struct TaskEventReporter<'a> {
211 rx: Receiver<TaskEvent>,
212 stderr: usize,
213 tasks: &'a BTreeMap<usize, RwLock<Task>>,
214}
215
216impl<'a> TaskEventReporter<'a> {
217 pub async fn run(self) {
218 let mut remaining = self.tasks.len();
219 crossterm::execute!(std::io::stderr(), EnterAlternateScreen).unwrap();
220 for event in self.rx {
221 match event {
222 | TaskEvent::Update { id, status } => {
223 match &status {
224 | TaskStatus::Completed(_) => remaining -= 1,
225 | _ => {},
226 }
227 self.tasks.get(&id).unwrap().write().status = status;
228 },
229 | TaskEvent::Stderr { id, line } => {
230 let stderr = &mut self.tasks.get(&id).unwrap().write().stderr;
231 stderr.push_back(line);
232 if stderr.len() > self.stderr {
233 stderr.pop_front();
234 }
235 },
236 | TaskEvent::Stdout { id, content } => {
237 let task = &mut self.tasks.get(&id).unwrap().write();
238 task.stdout = content;
239 },
240 }
241
242 if remaining == 0 {
245 crossterm::execute!(std::io::stderr(), LeaveAlternateScreen).unwrap();
246 }
247 Self::draw(&self.tasks, remaining == 0);
248 }
249 }
250
251 fn draw(tasks: &BTreeMap<usize, RwLock<Task>>, completed: bool) {
252 let mut writer = BufWriter::new(stderr());
253 if !completed {
254 crossterm::queue!(writer, Clear(ClearType::All)).unwrap();
255 crossterm::queue!(writer, MoveTo(0, 0)).unwrap();
256 }
257
258 for item in tasks.iter() {
259 let task = item.1.read();
260 crossterm::queue!(writer, Print(format!("⇒ ({})\n", item.0))).unwrap();
261 let lines = task.command.lines();
262 crossterm::queue!(writer, Print(" ↳ Script:\n")).unwrap();
263 for l in lines {
264 crossterm::queue!(writer, Print(format!(" |> {}\n", l.trim()))).unwrap();
265 }
266 let status = match &task.status {
267 | TaskStatus::Pending => "PENDING".to_owned().yellow(),
268 | TaskStatus::Running => "RUNNING".to_owned().yellow(),
269 | TaskStatus::Completed(v) => {
270 match v {
271 | TaskStatusCompleted::Success => "SUCCESS (0)".to_owned().green(),
272 | TaskStatusCompleted::Failed(code) => {
273 format!(
274 "FAILED ({})",
275 code.map(|v| v.to_string()).unwrap_or("unknown".to_owned())
276 )
277 .red()
278 },
279 }
280 },
281 };
282 crossterm::queue!(writer, Print(" ↳ Status: ")).unwrap();
283 crossterm::queue!(writer, Print(status)).unwrap();
284 crossterm::queue!(writer, Print("\n")).unwrap();
285
286 if task.stderr.len() > 0 {
287 crossterm::queue!(writer, Print(" ↳ Stderr: \n")).unwrap();
288 for line in &task.stderr {
289 crossterm::queue!(writer, Print(format!(" |> {}\n", line))).unwrap();
290 }
291 }
292 }
293
294 crossterm::queue!(writer, Print("\n")).unwrap();
295 crossterm::queue!(writer, Print("Thinking...")).unwrap();
296 if completed {
297 crossterm::queue!(writer, Print(" DONE\n")).unwrap();
298 }
299 writer.flush().unwrap();
300 }
301}