1use super::error_handler::{DefaultErrorHandler, ErrorHandler};
4use crate::panic_utils;
5use crate::resource_limits::{ResourceLimits, ResourceMonitor};
6use crate::shared_runtime::shared_runtime;
7use hojicha_core::core::Cmd;
8use hojicha_core::event::Event;
9use log::error;
10use std::panic::{self, AssertUnwindSafe};
11use std::sync::atomic::AtomicUsize;
12use std::sync::{mpsc, Arc};
13use tokio::runtime::Runtime;
14
15#[derive(Clone)]
17pub struct CommandExecutor<M = ()> {
18 runtime: Arc<Runtime>,
19 error_handler: Arc<dyn ErrorHandler<M> + Send + Sync>,
20 resource_monitor: Arc<ResourceMonitor>,
21 _recursion_depth: Arc<AtomicUsize>,
22}
23
24impl<M> CommandExecutor<M>
25where
26 M: Clone + Send + 'static,
27{
28 pub fn new() -> std::io::Result<Self> {
30 Ok(Self {
31 runtime: shared_runtime(),
32 error_handler: Arc::new(DefaultErrorHandler),
33 resource_monitor: Arc::new(ResourceMonitor::new()),
34 _recursion_depth: Arc::new(AtomicUsize::new(0)),
35 })
36 }
37
38 pub fn with_error_handler<H>(error_handler: H) -> std::io::Result<Self>
40 where
41 H: ErrorHandler<M> + Send + Sync + 'static,
42 {
43 Ok(Self {
44 runtime: shared_runtime(),
45 error_handler: Arc::new(error_handler),
46 resource_monitor: Arc::new(ResourceMonitor::new()),
47 _recursion_depth: Arc::new(AtomicUsize::new(0)),
48 })
49 }
50
51 pub fn with_resource_limits(limits: ResourceLimits) -> std::io::Result<Self> {
53 Ok(Self {
54 runtime: shared_runtime(),
55 error_handler: Arc::new(DefaultErrorHandler),
56 resource_monitor: Arc::new(ResourceMonitor::with_limits(limits)),
57 _recursion_depth: Arc::new(AtomicUsize::new(0)),
58 })
59 }
60
61 pub fn resource_stats(&self) -> crate::resource_limits::ResourceStats {
63 self.resource_monitor.stats()
64 }
65
66 fn spawn_with_limits<F>(&self, f: F)
68 where
69 F: std::future::Future<Output = ()> + Send + 'static,
70 {
71 let monitor = self.resource_monitor.clone();
72 let runtime = self.runtime.clone();
73
74 runtime.spawn(async move {
76 match monitor.try_acquire_task_permit().await {
77 Ok(_permit) => {
78 f.await;
80 }
81 Err(e) => {
82 error!("Failed to spawn task: {}", e);
83 }
84 }
85 });
86 }
87
88 pub fn execute(&self, cmd: Cmd<M>, tx: &mpsc::SyncSender<Event<M>>) {
90 if cmd.is_noop() {
91 } else if cmd.is_quit() {
93 let _ = tx.send(Event::Quit);
95 } else if cmd.is_exec_process() {
96 if let Some((_program, _args, _callback)) = cmd.take_exec_process() {
98 let _ = tx.send(Event::ExecProcess);
101 }
102 } else if cmd.is_batch() {
103 if let Some(cmds) = cmd.take_batch() {
105 self.execute_batch(cmds, tx);
106 }
107 } else if cmd.is_sequence() {
108 if let Some(cmds) = cmd.take_sequence() {
110 self.execute_sequence(cmds, tx);
111 }
112 } else if cmd.is_tick() {
113 if let Some((duration, callback)) = cmd.take_tick() {
115 let tx_clone = tx.clone();
116 self.spawn_with_limits(async move {
117 tokio::time::sleep(duration).await;
118 let result = panic::catch_unwind(AssertUnwindSafe(callback));
120 match result {
121 Ok(msg) => {
122 let _ = tx_clone.send(Event::User(msg));
123 }
124 Err(panic) => {
125 let panic_msg =
126 panic_utils::format_panic_message(panic, "Tick callback panicked");
127 eprintln!("{}", panic_msg);
128 }
130 }
131 });
132 }
133 } else if cmd.is_every() {
134 if let Some((duration, callback)) = cmd.take_every() {
136 let tx_clone = tx.clone();
137 self.spawn_with_limits(async move {
138 tokio::time::sleep(duration).await;
141 let result = panic::catch_unwind(AssertUnwindSafe(|| {
143 callback(std::time::Instant::now())
144 }));
145 match result {
146 Ok(msg) => {
147 let _ = tx_clone.send(Event::User(msg));
148 }
149 Err(panic) => {
150 let panic_msg =
151 panic_utils::format_panic_message(panic, "Every callback panicked");
152 eprintln!("{}", panic_msg);
153 }
155 }
156 });
157 }
158 } else if cmd.is_async() {
159 if let Some(future) = cmd.take_async() {
161 let tx_clone = tx.clone();
162 self.spawn_with_limits(async move {
163 use std::pin::Pin;
167 let mut future = future;
168 let future = unsafe { Pin::new_unchecked(&mut *future) };
169
170 if let Some(msg) = future.await {
172 let _ = tx_clone.send(Event::User(msg));
173 }
174 });
175 }
176 } else {
177 let tx_clone = tx.clone();
179 let error_handler = self.error_handler.clone();
180 self.spawn_with_limits(async move {
181 let result = panic::catch_unwind(AssertUnwindSafe(|| cmd.execute()));
183
184 match result {
185 Ok(Ok(Some(msg))) => {
186 let _ = tx_clone.send(Event::User(msg));
187 }
188 Ok(Ok(None)) => {
189 }
191 Ok(Err(error)) => {
192 error_handler.handle_error(error, &tx_clone);
194 }
195 Err(panic) => {
196 let panic_msg =
198 panic_utils::format_panic_message(panic, "Command execution panicked");
199 eprintln!("{}", panic_msg);
200 }
202 }
203 });
204 }
205 }
206
207 pub fn execute_batch(&self, commands: Vec<Cmd<M>>, tx: &mpsc::SyncSender<Event<M>>) {
209 for cmd in commands {
211 self.execute(cmd, tx);
213 }
214 }
215
216 pub fn execute_sequence(&self, commands: Vec<Cmd<M>>, tx: &mpsc::SyncSender<Event<M>>) {
218 let tx_clone = tx.clone();
220 let error_handler = self.error_handler.clone();
221 self.spawn_with_limits(async move {
222 for cmd in commands {
223 let tx_inner = tx_clone.clone();
224
225 if cmd.is_tick() {
227 if let Some((duration, callback)) = cmd.take_tick() {
228 tokio::time::sleep(duration).await;
229 let result = panic::catch_unwind(AssertUnwindSafe(callback));
231 match result {
232 Ok(msg) => {
233 let _ = tx_inner.send(Event::User(msg));
234 }
235 Err(panic) => {
236 let panic_msg = if let Some(s) = panic.downcast_ref::<String>() {
237 s.clone()
238 } else if let Some(s) = panic.downcast_ref::<&str>() {
239 s.to_string()
240 } else {
241 "Unknown panic in tick callback".to_string()
242 };
243 eprintln!("Tick callback panicked: {}", panic_msg);
244 }
246 }
247 }
248 } else if cmd.is_every() {
249 if let Some((duration, callback)) = cmd.take_every() {
250 tokio::time::sleep(duration).await;
251 let result = panic::catch_unwind(AssertUnwindSafe(|| {
253 callback(std::time::Instant::now())
254 }));
255 match result {
256 Ok(msg) => {
257 let _ = tx_inner.send(Event::User(msg));
258 }
259 Err(panic) => {
260 let panic_msg = if let Some(s) = panic.downcast_ref::<String>() {
261 s.clone()
262 } else if let Some(s) = panic.downcast_ref::<&str>() {
263 s.to_string()
264 } else {
265 "Unknown panic in every callback".to_string()
266 };
267 eprintln!("Every callback panicked: {}", panic_msg);
268 }
270 }
271 }
272 } else {
273 let result = panic::catch_unwind(AssertUnwindSafe(|| cmd.execute()));
275 match result {
276 Ok(Ok(Some(msg))) => {
277 let _ = tx_inner.send(Event::User(msg));
278 }
279 Ok(Ok(None)) => {}
280 Ok(Err(error)) => {
281 error_handler.handle_error(error, &tx_inner);
283 }
284 Err(panic) => {
285 let panic_msg = panic_utils::format_panic_message(
286 panic,
287 "Sequence command panicked",
288 );
289 eprintln!("{}", panic_msg);
290 }
292 }
293 }
294 }
295 });
296 }
297
298 pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
300 where
301 F: std::future::Future + Send + 'static,
302 F::Output: Send + 'static,
303 {
304 let monitor = self.resource_monitor.clone();
305 let runtime = self.runtime.clone();
306
307 runtime.spawn(async move {
309 match monitor.try_acquire_task_permit().await {
310 Ok(_permit) => {
311 future.await
313 }
314 Err(e) => {
315 error!("Failed to spawn task due to resource limits: {}", e);
316 std::future::pending().await
325 }
326 }
327 })
328 }
329
330 pub fn block_on<F: std::future::Future>(&self, future: F) -> F::Output {
332 self.runtime.block_on(future)
333 }
334}
335
336impl<M> Default for CommandExecutor<M>
337where
338 M: Clone + Send + 'static,
339{
340 fn default() -> Self {
341 Self::new().expect("Failed to create runtime")
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::AsyncTestHarness;
    use hojicha_core::commands;
    use std::time::Duration;

    /// Upper bound on waiting for an event that must arrive. It is generous so
    /// that slow CI machines do not cause spurious failures; `recv_timeout`
    /// returns as soon as the event arrives, so passing runs stay fast.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    /// How long to wait before concluding that no event is coming.
    const QUIET_PERIOD: Duration = Duration::from_millis(50);

    #[derive(Debug, Clone, PartialEq)]
    enum TestMsg {
        Inc,
        Dec,
        Text(String),
    }

    #[test]
    fn test_execute_custom_command() {
        let harness = AsyncTestHarness::new();
        let cmd = commands::custom(|| Some(TestMsg::Inc));

        let messages = harness.execute_command(cmd);
        assert_eq!(messages, vec![TestMsg::Inc]);
    }

    #[test]
    fn test_execute_custom_command_raw() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd = commands::custom(|| Some(TestMsg::Inc));
        executor.execute(cmd, &tx);

        // Block until the spawned task delivers instead of sleeping a fixed time.
        let event = rx.recv_timeout(RECV_TIMEOUT).unwrap();
        assert_eq!(event, Event::User(TestMsg::Inc));
    }

    #[test]
    fn test_execute_quit_command() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd: Cmd<TestMsg> = commands::quit();
        executor.execute(cmd, &tx);

        let event = rx.recv_timeout(RECV_TIMEOUT).unwrap();
        assert_eq!(event, Event::Quit);
    }

    #[test]
    fn test_execute_batch_commands() {
        let harness = AsyncTestHarness::new();

        let batch = commands::batch(vec![
            commands::custom(|| Some(TestMsg::Inc)),
            commands::custom(|| Some(TestMsg::Dec)),
            commands::custom(|| Some(TestMsg::Text("test".to_string()))),
        ]);

        let messages = harness.execute_command(batch);

        assert_eq!(messages.len(), 3);
        assert!(messages.contains(&TestMsg::Inc));
        assert!(messages.contains(&TestMsg::Dec));
        assert!(messages.contains(&TestMsg::Text("test".to_string())));
    }

    #[test]
    fn test_execute_none_command() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd: Cmd<TestMsg> = commands::custom(|| None);
        executor.execute(cmd, &tx);

        // Give the spawned task a chance to (wrongly) emit something.
        assert!(rx.recv_timeout(QUIET_PERIOD).is_err());
    }

    #[test]
    fn test_execute_tick_command() {
        let harness = AsyncTestHarness::new();
        let cmd = commands::tick(Duration::from_millis(10), || TestMsg::Inc);

        let messages = harness.execute_command(cmd);
        assert_eq!(messages, vec![TestMsg::Inc]);
    }

    #[test]
    fn test_execute_tick_command_raw() {
        let executor = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        let cmd = commands::tick(Duration::from_millis(10), || TestMsg::Inc);
        executor.execute(cmd, &tx);

        // A 10ms tick can easily overshoot a tight deadline under load.
        let event = rx.recv_timeout(RECV_TIMEOUT).unwrap();
        assert_eq!(event, Event::User(TestMsg::Inc));
    }

    #[test]
    fn test_execute_sequence() {
        let harness = AsyncTestHarness::new();

        let seq = commands::sequence(vec![
            commands::custom(|| Some(TestMsg::Inc)),
            commands::custom(|| Some(TestMsg::Dec)),
        ]);

        let messages = harness.execute_and_wait(seq, Duration::from_millis(50));

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0], TestMsg::Inc);
        assert_eq!(messages[1], TestMsg::Dec);
    }

    #[test]
    fn test_multiple_executors() {
        let executor1 = CommandExecutor::<TestMsg>::new().unwrap();
        let executor2 = CommandExecutor::<TestMsg>::new().unwrap();
        let (tx, rx) = mpsc::sync_channel(10);

        executor1.execute(commands::custom(|| Some(TestMsg::Inc)), &tx);
        executor2.execute(commands::custom(|| Some(TestMsg::Dec)), &tx);

        // Wait for exactly the two expected events rather than sleeping and
        // hoping both tasks have already finished.
        let events: Vec<TestMsg> = (0..2)
            .map(|_| match rx.recv_timeout(RECV_TIMEOUT) {
                Ok(Event::User(msg)) => msg,
                other => panic!("expected a user event, got {:?}", other),
            })
            .collect();

        assert_eq!(events.len(), 2);
        assert!(events.contains(&TestMsg::Inc));
        assert!(events.contains(&TestMsg::Dec));
    }
}