hojicha_runtime/program/
command_executor.rs1use super::error_handler::{DefaultErrorHandler, ErrorHandler};
4use hojicha_core::core::Cmd;
5use hojicha_core::event::Event;
6use std::panic::{self, AssertUnwindSafe};
7use crate::panic_utils;
8use std::sync::{Arc, mpsc};
9use tokio::runtime::Runtime;
10
11#[derive(Clone)]
13pub struct CommandExecutor<M = ()> {
14 runtime: Arc<Runtime>,
15 error_handler: Arc<dyn ErrorHandler<M> + Send + Sync>,
16}
17
18impl<M> CommandExecutor<M>
19where
20 M: Clone + Send + 'static,
21{
22 pub fn new() -> std::io::Result<Self> {
24 Ok(Self {
25 runtime: Arc::new(Runtime::new()?),
26 error_handler: Arc::new(DefaultErrorHandler),
27 })
28 }
29
30 pub fn with_error_handler<H>(error_handler: H) -> std::io::Result<Self>
32 where
33 H: ErrorHandler<M> + Send + Sync + 'static,
34 {
35 Ok(Self {
36 runtime: Arc::new(Runtime::new()?),
37 error_handler: Arc::new(error_handler),
38 })
39 }
40
41 pub fn execute(&self, cmd: Cmd<M>, tx: mpsc::SyncSender<Event<M>>) {
43 if cmd.is_noop() {
44 } else if cmd.is_quit() {
46 let _ = tx.send(Event::Quit);
48 } else if cmd.is_exec_process() {
49 if let Some((_program, _args, _callback)) = cmd.take_exec_process() {
51 let _ = tx.send(Event::ExecProcess);
54 }
55 } else if cmd.is_batch() {
56 if let Some(cmds) = cmd.take_batch() {
58 self.execute_batch(cmds, tx);
59 }
60 } else if cmd.is_sequence() {
61 if let Some(cmds) = cmd.take_sequence() {
63 self.execute_sequence(cmds, tx);
64 }
65 } else if cmd.is_tick() {
66 if let Some((duration, callback)) = cmd.take_tick() {
68 let tx_clone = tx.clone();
69 self.runtime.spawn(async move {
70 tokio::time::sleep(duration).await;
71 let result = panic::catch_unwind(AssertUnwindSafe(|| callback()));
73 match result {
74 Ok(msg) => {
75 let _ = tx_clone.send(Event::User(msg));
76 }
77 Err(panic) => {
78 let panic_msg = panic_utils::format_panic_message(panic, "Tick callback panicked");
79 eprintln!("{}", panic_msg);
80 }
82 }
83 });
84 }
85 } else if cmd.is_every() {
86 if let Some((duration, callback)) = cmd.take_every() {
88 let tx_clone = tx.clone();
89 self.runtime.spawn(async move {
90 tokio::time::sleep(duration).await;
93 let result = panic::catch_unwind(AssertUnwindSafe(|| callback(std::time::Instant::now())));
95 match result {
96 Ok(msg) => {
97 let _ = tx_clone.send(Event::User(msg));
98 }
99 Err(panic) => {
100 let panic_msg = panic_utils::format_panic_message(panic, "Every callback panicked");
101 eprintln!("{}", panic_msg);
102 }
104 }
105 });
106 }
107 } else if cmd.is_async() {
108 if let Some(future) = cmd.take_async() {
110 let tx_clone = tx.clone();
111 self.runtime.spawn(async move {
112 use std::pin::Pin;
114 let mut future = future;
115 let future = unsafe { Pin::new_unchecked(&mut *future) };
116
117 if let Some(msg) = future.await {
119 let _ = tx_clone.send(Event::User(msg));
120 }
121 });
122 }
123 } else {
124 let tx_clone = tx.clone();
126 let error_handler = self.error_handler.clone();
127 self.runtime.spawn(async move {
128 let result = panic::catch_unwind(AssertUnwindSafe(|| cmd.execute()));
130
131 match result {
132 Ok(Ok(Some(msg))) => {
133 let _ = tx_clone.send(Event::User(msg));
134 }
135 Ok(Ok(None)) => {
136 }
138 Ok(Err(error)) => {
139 error_handler.handle_error(error, &tx_clone);
141 }
142 Err(panic) => {
143 let panic_msg = panic_utils::format_panic_message(panic, "Command execution panicked");
145 eprintln!("{}", panic_msg);
146 }
148 }
149 });
150 }
151 }
152
153 pub fn execute_batch(&self, commands: Vec<Cmd<M>>, tx: mpsc::SyncSender<Event<M>>) {
155 for cmd in commands {
157 self.execute(cmd, tx.clone());
159 }
160 }
161
162 pub fn execute_sequence(&self, commands: Vec<Cmd<M>>, tx: mpsc::SyncSender<Event<M>>) {
164 let tx_clone = tx.clone();
166 let error_handler = self.error_handler.clone();
167 self.runtime.spawn(async move {
168 for cmd in commands {
169 let tx_inner = tx_clone.clone();
170
171 if cmd.is_tick() {
173 if let Some((duration, callback)) = cmd.take_tick() {
174 tokio::time::sleep(duration).await;
175 let result = panic::catch_unwind(AssertUnwindSafe(|| callback()));
177 match result {
178 Ok(msg) => {
179 let _ = tx_inner.send(Event::User(msg));
180 }
181 Err(panic) => {
182 let panic_msg = if let Some(s) = panic.downcast_ref::<String>() {
183 s.clone()
184 } else if let Some(s) = panic.downcast_ref::<&str>() {
185 s.to_string()
186 } else {
187 "Unknown panic in tick callback".to_string()
188 };
189 eprintln!("Tick callback panicked: {}", panic_msg);
190 }
192 }
193 }
194 } else if cmd.is_every() {
195 if let Some((duration, callback)) = cmd.take_every() {
196 tokio::time::sleep(duration).await;
197 let result = panic::catch_unwind(AssertUnwindSafe(|| callback(std::time::Instant::now())));
199 match result {
200 Ok(msg) => {
201 let _ = tx_inner.send(Event::User(msg));
202 }
203 Err(panic) => {
204 let panic_msg = if let Some(s) = panic.downcast_ref::<String>() {
205 s.clone()
206 } else if let Some(s) = panic.downcast_ref::<&str>() {
207 s.to_string()
208 } else {
209 "Unknown panic in every callback".to_string()
210 };
211 eprintln!("Every callback panicked: {}", panic_msg);
212 }
214 }
215 }
216 } else {
217 let result = panic::catch_unwind(AssertUnwindSafe(|| cmd.execute()));
219 match result {
220 Ok(Ok(Some(msg))) => {
221 let _ = tx_inner.send(Event::User(msg));
222 }
223 Ok(Ok(None)) => {}
224 Ok(Err(error)) => {
225 error_handler.handle_error(error, &tx_inner);
227 }
228 Err(panic) => {
229 let panic_msg = panic_utils::format_panic_message(panic, "Sequence command panicked");
230 eprintln!("{}", panic_msg);
231 }
233 }
234 }
235 }
236 });
237 }
238
239 pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
241 where
242 F: std::future::Future + Send + 'static,
243 F::Output: Send + 'static,
244 {
245 self.runtime.spawn(future)
246 }
247
248 pub fn block_on<F: std::future::Future>(&self, future: F) -> F::Output {
250 self.runtime.block_on(future)
251 }
252}
253
254impl<M> Default for CommandExecutor<M>
255where
256 M: Clone + Send + 'static,
257{
258 fn default() -> Self {
259 Self::new().expect("Failed to create runtime")
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::{AsyncTestHarness, CmdTestExt};
    use hojicha_core::commands;
    use std::time::Duration;

    /// Upper bound on how long a test waits for an event. Tests return as soon
    /// as the event arrives, so a generous bound only costs time on failure.
    /// A fixed sleep followed by `try_recv` races the runtime instead.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    #[derive(Debug, Clone, PartialEq)]
    enum TestMsg {
        Inc,
        Dec,
        Text(String),
    }

    #[test]
    fn test_execute_custom_command() {
        let harness = AsyncTestHarness::new();
        let cmd = commands::custom(|| Some(TestMsg::Inc));

        let messages = harness.execute_command(cmd);
        assert_eq!(messages, vec![TestMsg::Inc]);
    }

    #[test]
    fn test_execute_custom_command_raw() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd = commands::custom(|| Some(TestMsg::Inc));
        executor.execute(cmd, tx);

        let event = rx
            .recv_timeout(RECV_TIMEOUT)
            .expect("custom command produced no event");
        assert_eq!(event, Event::User(TestMsg::Inc));
    }

    #[test]
    fn test_execute_quit_command() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd: Cmd<TestMsg> = commands::quit();
        executor.execute(cmd, tx);

        let event = rx
            .recv_timeout(RECV_TIMEOUT)
            .expect("quit command produced no event");
        assert_eq!(event, Event::Quit);
    }

    #[test]
    fn test_execute_batch_commands() {
        let harness = AsyncTestHarness::new();

        let batch = commands::batch(vec![
            commands::custom(|| Some(TestMsg::Inc)),
            commands::custom(|| Some(TestMsg::Dec)),
            commands::custom(|| Some(TestMsg::Text("test".to_string()))),
        ]);

        let messages = harness.execute_command(batch);

        assert_eq!(messages.len(), 3);
        assert!(messages.contains(&TestMsg::Inc));
        assert!(messages.contains(&TestMsg::Dec));
        assert!(messages.contains(&TestMsg::Text("test".to_string())));
    }

    #[test]
    fn test_execute_none_command() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd: Cmd<TestMsg> = commands::custom(|| None);
        executor.execute(cmd, tx);

        // The only sender is consumed by the executor, so the channel
        // disconnects once the command has finished. A stray event or a
        // command that never completes would both fail this check.
        assert!(matches!(
            rx.recv_timeout(RECV_TIMEOUT),
            Err(mpsc::RecvTimeoutError::Disconnected)
        ));
    }

    #[test]
    fn test_execute_tick_command() {
        let harness = AsyncTestHarness::new();
        let cmd = commands::tick(Duration::from_millis(10), || TestMsg::Inc);

        let messages = harness.execute_command(cmd);
        assert_eq!(messages, vec![TestMsg::Inc]);
    }

    #[test]
    fn test_execute_tick_command_raw() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd = commands::tick(Duration::from_millis(10), || TestMsg::Inc);
        executor.execute(cmd, tx);

        let event = rx
            .recv_timeout(RECV_TIMEOUT)
            .expect("tick command produced no event");
        assert_eq!(event, Event::User(TestMsg::Inc));
    }

    #[test]
    fn test_execute_sequence() {
        let harness = AsyncTestHarness::new();

        let seq = commands::sequence(vec![
            commands::custom(|| Some(TestMsg::Inc)),
            commands::custom(|| Some(TestMsg::Dec)),
        ]);

        let messages = harness.execute_and_wait(seq, Duration::from_millis(50));

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0], TestMsg::Inc);
        assert_eq!(messages[1], TestMsg::Dec);
    }

    #[test]
    fn test_multiple_executors() {
        let executor1 = CommandExecutor::<TestMsg>::new().unwrap();
        let executor2 = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        executor1.execute(commands::custom(|| Some(TestMsg::Inc)), tx.clone());
        executor2.execute(commands::custom(|| Some(TestMsg::Dec)), tx.clone());

        // Wait for exactly the two expected events rather than sleeping and
        // draining whatever happens to have arrived.
        let events: Vec<TestMsg> = (0..2)
            .map(|_| match rx.recv_timeout(RECV_TIMEOUT) {
                Ok(Event::User(msg)) => msg,
                other => panic!("expected a user event, got {:?}", other),
            })
            .collect();

        assert_eq!(events.len(), 2);
        assert!(events.contains(&TestMsg::Inc));
        assert!(events.contains(&TestMsg::Dec));
    }
}
419
420