Skip to main content

egui_cha/
cmd.rs

1//! Command type for side effects (async tasks, HTTP, timers, etc.)
2
3use std::future::Future;
4use std::pin::Pin;
5
/// A command representing a side effect to be executed
///
/// Commands are declarative descriptions of side effects. The runtime
/// executes them and feeds resulting messages back into the update loop.
///
/// `Cmd::None` is the [`Default`] value. The constructors [`Cmd::none`],
/// [`Cmd::batch`], [`Cmd::task`] and [`Cmd::msg`] are the usual way to
/// build each variant.
#[derive(Default)]
pub enum Cmd<Msg> {
    /// No side effect
    #[default]
    None,

    /// Multiple commands to execute
    Batch(Vec<Cmd<Msg>>),

    /// An async task that produces a message
    ///
    /// The future is stored pinned and boxed, and must be `Send + 'static`.
    Task(Pin<Box<dyn Future<Output = Msg> + Send + 'static>>),

    /// Emit a message immediately (next frame)
    Msg(Msg),
}
25
26impl<Msg> Cmd<Msg> {
27    /// Create an empty command (no side effect)
28    #[inline]
29    pub fn none() -> Self {
30        Cmd::None
31    }
32
33    /// Create a batch of commands
34    pub fn batch(cmds: impl IntoIterator<Item = Cmd<Msg>>) -> Self {
35        let cmds: Vec<_> = cmds.into_iter().collect();
36        if cmds.is_empty() {
37            Cmd::None
38        } else if cmds.len() == 1 {
39            cmds.into_iter().next().unwrap()
40        } else {
41            Cmd::Batch(cmds)
42        }
43    }
44
45    /// Create a command from an async task
46    pub fn task<F>(future: F) -> Self
47    where
48        F: Future<Output = Msg> + Send + 'static,
49    {
50        Cmd::Task(Box::pin(future))
51    }
52
53    /// Create a command that emits a message immediately
54    pub fn msg(msg: Msg) -> Self {
55        Cmd::Msg(msg)
56    }
57
58    /// Map the message type
59    pub fn map<F, NewMsg>(self, f: F) -> Cmd<NewMsg>
60    where
61        F: Fn(Msg) -> NewMsg + Send + Sync + Clone + 'static,
62        Msg: Send + 'static,
63        NewMsg: Send + 'static,
64    {
65        match self {
66            Cmd::None => Cmd::None,
67            Cmd::Batch(cmds) => Cmd::Batch(cmds.into_iter().map(|c| c.map(f.clone())).collect()),
68            Cmd::Task(fut) => {
69                let f = f.clone();
70                Cmd::Task(Box::pin(async move { f(fut.await) }))
71            }
72            Cmd::Msg(msg) => Cmd::Msg(f(msg)),
73        }
74    }
75}
76
77// ============================================================
78// Test helpers
79// ============================================================
80
81impl<Msg> Cmd<Msg> {
82    /// Check if this is Cmd::None
83    #[inline]
84    pub fn is_none(&self) -> bool {
85        matches!(self, Cmd::None)
86    }
87
88    /// Check if this is Cmd::Task
89    #[inline]
90    pub fn is_task(&self) -> bool {
91        matches!(self, Cmd::Task(_))
92    }
93
94    /// Check if this is Cmd::Msg
95    #[inline]
96    pub fn is_msg(&self) -> bool {
97        matches!(self, Cmd::Msg(_))
98    }
99
100    /// Check if this is Cmd::Batch
101    #[inline]
102    pub fn is_batch(&self) -> bool {
103        matches!(self, Cmd::Batch(_))
104    }
105
106    /// Get the message if this is Cmd::Msg, panics otherwise
107    ///
108    /// # Panics
109    /// Panics if the command is not Cmd::Msg
110    pub fn unwrap_msg(self) -> Msg {
111        match self {
112            Cmd::Msg(msg) => msg,
113            Cmd::None => panic!("called unwrap_msg on Cmd::None"),
114            Cmd::Task(_) => panic!("called unwrap_msg on Cmd::Task"),
115            Cmd::Batch(_) => panic!("called unwrap_msg on Cmd::Batch"),
116        }
117    }
118
119    /// Get the message if this is Cmd::Msg
120    pub fn as_msg(&self) -> Option<&Msg> {
121        match self {
122            Cmd::Msg(msg) => Some(msg),
123            _ => None,
124        }
125    }
126
127    /// Get the batch if this is Cmd::Batch
128    pub fn as_batch(&self) -> Option<&[Cmd<Msg>]> {
129        match self {
130            Cmd::Batch(cmds) => Some(cmds),
131            _ => None,
132        }
133    }
134
135    /// Get the number of commands (1 for non-batch, n for batch, 0 for none)
136    pub fn len(&self) -> usize {
137        match self {
138            Cmd::None => 0,
139            Cmd::Batch(cmds) => cmds.len(),
140            _ => 1,
141        }
142    }
143
144    /// Check if empty (only true for Cmd::None)
145    pub fn is_empty(&self) -> bool {
146        matches!(self, Cmd::None)
147    }
148}
149
150impl<Msg: PartialEq> Cmd<Msg> {
151    /// Check if this is Cmd::Msg with a specific message
152    pub fn is_msg_eq(&self, expected: &Msg) -> bool {
153        match self {
154            Cmd::Msg(msg) => msg == expected,
155            _ => false,
156        }
157    }
158}
159
160impl<Msg: std::fmt::Debug> Cmd<Msg> {
161    /// Assert this is Cmd::None, panics with debug info otherwise
162    #[track_caller]
163    pub fn assert_none(&self) {
164        assert!(
165            self.is_none(),
166            "expected Cmd::None, got {:?}",
167            self.variant_name()
168        );
169    }
170
171    /// Assert this is Cmd::Task, panics with debug info otherwise
172    #[track_caller]
173    pub fn assert_task(&self) {
174        assert!(
175            self.is_task(),
176            "expected Cmd::Task, got {:?}",
177            self.variant_name()
178        );
179    }
180
181    /// Assert this is Cmd::Msg, panics with debug info otherwise
182    #[track_caller]
183    pub fn assert_msg(&self) {
184        assert!(
185            self.is_msg(),
186            "expected Cmd::Msg, got {:?}",
187            self.variant_name()
188        );
189    }
190
191    fn variant_name(&self) -> &'static str {
192        match self {
193            Cmd::None => "Cmd::None",
194            Cmd::Task(_) => "Cmd::Task",
195            Cmd::Msg(_) => "Cmd::Msg",
196            Cmd::Batch(_) => "Cmd::Batch",
197        }
198    }
199}
200
201// Utility functions for common patterns
202impl<Msg: Send + 'static> Cmd<Msg> {
203    /// Create a command from an async task that returns Result
204    ///
205    /// # Example
206    /// ```ignore
207    /// Cmd::try_task(
208    ///     async { fetch_user(id).await },
209    ///     |user| Msg::UserLoaded(user),
210    ///     |err| Msg::Error(err.to_string()),
211    /// )
212    /// ```
213    pub fn try_task<F, T, E, FnOk, FnErr>(future: F, on_ok: FnOk, on_err: FnErr) -> Self
214    where
215        F: Future<Output = Result<T, E>> + Send + 'static,
216        FnOk: FnOnce(T) -> Msg + Send + 'static,
217        FnErr: FnOnce(E) -> Msg + Send + 'static,
218    {
219        Cmd::task(async move {
220            match future.await {
221                Ok(value) => on_ok(value),
222                Err(err) => on_err(err),
223            }
224        })
225    }
226
227    /// Create a command from a Result, converting to Msg immediately
228    ///
229    /// # Example
230    /// ```ignore
231    /// Cmd::from_result(
232    ///     parse_config(),
233    ///     |config| Msg::ConfigLoaded(config),
234    ///     |err| Msg::Error(err.to_string()),
235    /// )
236    /// ```
237    pub fn from_result<T, E, FnOk, FnErr>(result: Result<T, E>, on_ok: FnOk, on_err: FnErr) -> Self
238    where
239        FnOk: FnOnce(T) -> Msg,
240        FnErr: FnOnce(E) -> Msg,
241    {
242        match result {
243            Ok(value) => Cmd::Msg(on_ok(value)),
244            Err(err) => Cmd::Msg(on_err(err)),
245        }
246    }
247}
248
249// Tokio-dependent utility functions
#[cfg(feature = "tokio")]
impl<Msg: Send + 'static> Cmd<Msg> {
    /// Create a delayed message (timer)
    ///
    /// The returned task sleeps for `duration` and then yields `msg`.
    ///
    /// Requires the `tokio` feature.
    pub fn delay(duration: std::time::Duration, msg: Msg) -> Self {
        // `msg` is moved into the task, so no `Clone` bound is required.
        Cmd::task(async move {
            tokio::time::sleep(duration).await;
            msg
        })
    }

    /// Create an async task with a timeout
    ///
    /// If the task completes before the timeout, `on_ok` is called with the result.
    /// If the timeout expires first, `on_timeout` is returned.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::with_timeout(
    ///     Duration::from_secs(5),
    ///     fetch_user(user_id),
    ///     |user| Msg::UserLoaded(user),
    ///     Msg::FetchTimeout,
    /// )
    /// ```
    pub fn with_timeout<F, T>(
        timeout: std::time::Duration,
        future: F,
        on_ok: impl FnOnce(T) -> Msg + Send + 'static,
        on_timeout: Msg,
    ) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        Cmd::task(async move {
            match tokio::time::timeout(timeout, future).await {
                Ok(value) => on_ok(value),
                Err(_elapsed) => on_timeout,
            }
        })
    }

    /// Create an async task with timeout that returns Result
    ///
    /// Combines `try_task` with timeout handling: `on_ok` / `on_err` handle
    /// the future's own result, while `on_timeout` is returned if the
    /// timeout expires before the future completes.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::try_with_timeout(
    ///     Duration::from_secs(5),
    ///     api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err| Msg::FetchError(err.to_string()),
    ///     Msg::FetchTimeout,
    /// )
    /// ```
    pub fn try_with_timeout<F, T, E, FnOk, FnErr>(
        timeout: std::time::Duration,
        future: F,
        on_ok: FnOk,
        on_err: FnErr,
        on_timeout: Msg,
    ) -> Self
    where
        F: Future<Output = Result<T, E>> + Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnErr: FnOnce(E) -> Msg + Send + 'static,
    {
        Cmd::task(async move {
            match tokio::time::timeout(timeout, future).await {
                Ok(Ok(value)) => on_ok(value),
                Ok(Err(err)) => on_err(err),
                Err(_elapsed) => on_timeout,
            }
        })
    }

    /// Retry an async task with exponential backoff
    ///
    /// Attempts the task up to `max_attempts` times. On failure, waits with
    /// exponential backoff (doubling the delay each time) before retrying.
    /// The doubling saturates at `Duration::MAX` instead of overflowing.
    ///
    /// # Arguments
    /// - `max_attempts`: Maximum number of attempts (must be >= 1)
    /// - `initial_delay`: Delay before first retry (doubles each retry)
    /// - `make_future`: Factory function that creates the future for each attempt
    /// - `on_ok`: Called with the successful result
    /// - `on_fail`: Called with the last error and total attempt count
    ///
    /// # Panics
    /// Panics immediately (when the command is created, not when it runs)
    /// if `max_attempts` is 0.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::retry(
    ///     3, // max 3 attempts
    ///     Duration::from_millis(100), // start with 100ms delay
    ///     || api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err, attempts| Msg::FetchFailed(err.to_string(), attempts),
    /// )
    /// ```
    ///
    /// With 3 attempts and 100ms initial delay:
    /// - Attempt 1: immediate
    /// - Attempt 2: after 100ms (if attempt 1 failed)
    /// - Attempt 3: after 200ms (if attempt 2 failed)
    pub fn retry<F, Fut, T, E, FnOk, FnFail>(
        max_attempts: u32,
        initial_delay: std::time::Duration,
        make_future: F,
        on_ok: FnOk,
        on_fail: FnFail,
    ) -> Self
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send + 'static,
        E: Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnFail: FnOnce(E, u32) -> Msg + Send + 'static,
    {
        assert!(max_attempts >= 1, "max_attempts must be at least 1");

        Cmd::task(async move {
            let mut delay = initial_delay;
            let mut attempt: u32 = 1;

            loop {
                match make_future().await {
                    Ok(value) => return on_ok(value),
                    Err(err) => {
                        // Final attempt: report this error directly, so no
                        // `Option` / `unwrap` bookkeeping is needed.
                        if attempt >= max_attempts {
                            return on_fail(err, max_attempts);
                        }
                    }
                }

                tokio::time::sleep(delay).await;
                // Exponential backoff; saturate rather than panic when the
                // doubled delay would overflow `Duration`.
                delay = delay.saturating_mul(2);
                attempt += 1;
            }
        })
    }

    /// Retry with custom backoff strategy
    ///
    /// Like `retry`, but allows custom delay calculation. After a failed
    /// attempt number `n` (1-based) that is not the last one, the task
    /// sleeps for `backoff(n)` before trying again.
    ///
    /// # Panics
    /// Panics immediately (when the command is created, not when it runs)
    /// if `max_attempts` is 0.
    ///
    /// # Example
    /// ```ignore
    /// // Linear backoff: 100ms, 200ms, 300ms, ...
    /// Cmd::retry_with_backoff(
    ///     5,
    ///     |attempt| Duration::from_millis(100 * attempt as u64),
    ///     || api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err, attempts| Msg::FetchFailed(err.to_string(), attempts),
    /// )
    /// ```
    pub fn retry_with_backoff<F, Fut, T, E, B, FnOk, FnFail>(
        max_attempts: u32,
        backoff: B,
        make_future: F,
        on_ok: FnOk,
        on_fail: FnFail,
    ) -> Self
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send + 'static,
        E: Send + 'static,
        B: Fn(u32) -> std::time::Duration + Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnFail: FnOnce(E, u32) -> Msg + Send + 'static,
    {
        assert!(max_attempts >= 1, "max_attempts must be at least 1");

        Cmd::task(async move {
            let mut attempt: u32 = 1;

            loop {
                match make_future().await {
                    Ok(value) => return on_ok(value),
                    Err(err) => {
                        // Final attempt: hand the last error to `on_fail`.
                        if attempt >= max_attempts {
                            return on_fail(err, max_attempts);
                        }
                    }
                }

                let delay = backoff(attempt);
                tokio::time::sleep(delay).await;
                attempt += 1;
            }
        })
    }
}
449}