Skip to main content

egui_cha/
cmd.rs

1//! Command type for side effects (async tasks, HTTP, timers, etc.)
2
3use std::future::Future;
4use std::pin::Pin;
5
/// A declarative description of a side effect.
///
/// A `Cmd` does nothing by itself: the runtime executes it and feeds any
/// resulting messages back into the update loop.
pub enum Cmd<Msg> {
    /// The empty command: no side effect at all.
    None,

    /// Several commands to be executed together.
    Batch(Vec<Self>),

    /// A boxed, sendable future whose output is the message produced by this command.
    Task(Pin<Box<dyn Future<Output = Msg> + Send + 'static>>),

    /// A message emitted right away (on the next frame), with no async work involved.
    Msg(Msg),
}
23
24impl<Msg> Cmd<Msg> {
25    /// Create an empty command (no side effect)
26    #[inline]
27    pub fn none() -> Self {
28        Cmd::None
29    }
30
31    /// Create a batch of commands
32    pub fn batch(cmds: impl IntoIterator<Item = Cmd<Msg>>) -> Self {
33        let cmds: Vec<_> = cmds.into_iter().collect();
34        if cmds.is_empty() {
35            Cmd::None
36        } else if cmds.len() == 1 {
37            cmds.into_iter().next().unwrap()
38        } else {
39            Cmd::Batch(cmds)
40        }
41    }
42
43    /// Create a command from an async task
44    pub fn task<F>(future: F) -> Self
45    where
46        F: Future<Output = Msg> + Send + 'static,
47    {
48        Cmd::Task(Box::pin(future))
49    }
50
51    /// Create a command that emits a message immediately
52    pub fn msg(msg: Msg) -> Self {
53        Cmd::Msg(msg)
54    }
55
56    /// Map the message type
57    pub fn map<F, NewMsg>(self, f: F) -> Cmd<NewMsg>
58    where
59        F: Fn(Msg) -> NewMsg + Send + Sync + Clone + 'static,
60        Msg: Send + 'static,
61        NewMsg: Send + 'static,
62    {
63        match self {
64            Cmd::None => Cmd::None,
65            Cmd::Batch(cmds) => Cmd::Batch(cmds.into_iter().map(|c| c.map(f.clone())).collect()),
66            Cmd::Task(fut) => {
67                let f = f.clone();
68                Cmd::Task(Box::pin(async move { f(fut.await) }))
69            }
70            Cmd::Msg(msg) => Cmd::Msg(f(msg)),
71        }
72    }
73}
74
75impl<Msg> Default for Cmd<Msg> {
76    fn default() -> Self {
77        Cmd::None
78    }
79}
80
81// ============================================================
82// Test helpers
83// ============================================================
84
85impl<Msg> Cmd<Msg> {
86    /// Check if this is Cmd::None
87    #[inline]
88    pub fn is_none(&self) -> bool {
89        matches!(self, Cmd::None)
90    }
91
92    /// Check if this is Cmd::Task
93    #[inline]
94    pub fn is_task(&self) -> bool {
95        matches!(self, Cmd::Task(_))
96    }
97
98    /// Check if this is Cmd::Msg
99    #[inline]
100    pub fn is_msg(&self) -> bool {
101        matches!(self, Cmd::Msg(_))
102    }
103
104    /// Check if this is Cmd::Batch
105    #[inline]
106    pub fn is_batch(&self) -> bool {
107        matches!(self, Cmd::Batch(_))
108    }
109
110    /// Get the message if this is Cmd::Msg, panics otherwise
111    ///
112    /// # Panics
113    /// Panics if the command is not Cmd::Msg
114    pub fn unwrap_msg(self) -> Msg {
115        match self {
116            Cmd::Msg(msg) => msg,
117            Cmd::None => panic!("called unwrap_msg on Cmd::None"),
118            Cmd::Task(_) => panic!("called unwrap_msg on Cmd::Task"),
119            Cmd::Batch(_) => panic!("called unwrap_msg on Cmd::Batch"),
120        }
121    }
122
123    /// Get the message if this is Cmd::Msg
124    pub fn as_msg(&self) -> Option<&Msg> {
125        match self {
126            Cmd::Msg(msg) => Some(msg),
127            _ => None,
128        }
129    }
130
131    /// Get the batch if this is Cmd::Batch
132    pub fn as_batch(&self) -> Option<&[Cmd<Msg>]> {
133        match self {
134            Cmd::Batch(cmds) => Some(cmds),
135            _ => None,
136        }
137    }
138
139    /// Get the number of commands (1 for non-batch, n for batch, 0 for none)
140    pub fn len(&self) -> usize {
141        match self {
142            Cmd::None => 0,
143            Cmd::Batch(cmds) => cmds.len(),
144            _ => 1,
145        }
146    }
147
148    /// Check if empty (only true for Cmd::None)
149    pub fn is_empty(&self) -> bool {
150        matches!(self, Cmd::None)
151    }
152}
153
154impl<Msg: PartialEq> Cmd<Msg> {
155    /// Check if this is Cmd::Msg with a specific message
156    pub fn is_msg_eq(&self, expected: &Msg) -> bool {
157        match self {
158            Cmd::Msg(msg) => msg == expected,
159            _ => false,
160        }
161    }
162}
163
164impl<Msg: std::fmt::Debug> Cmd<Msg> {
165    /// Assert this is Cmd::None, panics with debug info otherwise
166    #[track_caller]
167    pub fn assert_none(&self) {
168        assert!(
169            self.is_none(),
170            "expected Cmd::None, got {:?}",
171            self.variant_name()
172        );
173    }
174
175    /// Assert this is Cmd::Task, panics with debug info otherwise
176    #[track_caller]
177    pub fn assert_task(&self) {
178        assert!(
179            self.is_task(),
180            "expected Cmd::Task, got {:?}",
181            self.variant_name()
182        );
183    }
184
185    /// Assert this is Cmd::Msg, panics with debug info otherwise
186    #[track_caller]
187    pub fn assert_msg(&self) {
188        assert!(
189            self.is_msg(),
190            "expected Cmd::Msg, got {:?}",
191            self.variant_name()
192        );
193    }
194
195    fn variant_name(&self) -> &'static str {
196        match self {
197            Cmd::None => "Cmd::None",
198            Cmd::Task(_) => "Cmd::Task",
199            Cmd::Msg(_) => "Cmd::Msg",
200            Cmd::Batch(_) => "Cmd::Batch",
201        }
202    }
203}
204
205// Utility functions for common patterns
206impl<Msg: Send + 'static> Cmd<Msg> {
207    /// Create a command from an async task that returns Result
208    ///
209    /// # Example
210    /// ```ignore
211    /// Cmd::try_task(
212    ///     async { fetch_user(id).await },
213    ///     |user| Msg::UserLoaded(user),
214    ///     |err| Msg::Error(err.to_string()),
215    /// )
216    /// ```
217    pub fn try_task<F, T, E, FnOk, FnErr>(future: F, on_ok: FnOk, on_err: FnErr) -> Self
218    where
219        F: Future<Output = Result<T, E>> + Send + 'static,
220        FnOk: FnOnce(T) -> Msg + Send + 'static,
221        FnErr: FnOnce(E) -> Msg + Send + 'static,
222    {
223        Cmd::task(async move {
224            match future.await {
225                Ok(value) => on_ok(value),
226                Err(err) => on_err(err),
227            }
228        })
229    }
230
231    /// Create a command from a Result, converting to Msg immediately
232    ///
233    /// # Example
234    /// ```ignore
235    /// Cmd::from_result(
236    ///     parse_config(),
237    ///     |config| Msg::ConfigLoaded(config),
238    ///     |err| Msg::Error(err.to_string()),
239    /// )
240    /// ```
241    pub fn from_result<T, E, FnOk, FnErr>(result: Result<T, E>, on_ok: FnOk, on_err: FnErr) -> Self
242    where
243        FnOk: FnOnce(T) -> Msg,
244        FnErr: FnOnce(E) -> Msg,
245    {
246        match result {
247            Ok(value) => Cmd::Msg(on_ok(value)),
248            Err(err) => Cmd::Msg(on_err(err)),
249        }
250    }
251}
252
// Tokio-dependent utility functions (timers, timeouts, retries)
#[cfg(feature = "tokio")]
impl<Msg: Send + 'static> Cmd<Msg> {
    /// Create a delayed message (timer): emits `msg` once `duration` has elapsed.
    ///
    /// Requires the `tokio` feature. `msg` is moved into the task and returned
    /// after the sleep, so it does not need to be `Clone`.
    pub fn delay(duration: std::time::Duration, msg: Msg) -> Self {
        Cmd::task(async move {
            tokio::time::sleep(duration).await;
            msg
        })
    }

    /// Create an async task with a timeout.
    ///
    /// If `future` completes within `timeout`, its output is passed to `on_ok`.
    /// If the timeout expires first, `on_timeout` is emitted instead.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::with_timeout(
    ///     Duration::from_secs(5),
    ///     fetch_user(user_id),
    ///     |user| Msg::UserLoaded(user),
    ///     Msg::FetchTimeout,
    /// )
    /// ```
    pub fn with_timeout<F, T>(
        timeout: std::time::Duration,
        future: F,
        on_ok: impl FnOnce(T) -> Msg + Send + 'static,
        on_timeout: Msg,
    ) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        Cmd::task(async move {
            match tokio::time::timeout(timeout, future).await {
                Ok(value) => on_ok(value),
                Err(_elapsed) => on_timeout,
            }
        })
    }

    /// Create an async task with a timeout whose output is a `Result`.
    ///
    /// Combines `try_task` with timeout handling: `Ok` goes to `on_ok`, `Err`
    /// goes to `on_err`, and `on_timeout` is emitted if `timeout` elapses first.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::try_with_timeout(
    ///     Duration::from_secs(5),
    ///     api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err| Msg::FetchError(err.to_string()),
    ///     Msg::FetchTimeout,
    /// )
    /// ```
    pub fn try_with_timeout<F, T, E, FnOk, FnErr>(
        timeout: std::time::Duration,
        future: F,
        on_ok: FnOk,
        on_err: FnErr,
        on_timeout: Msg,
    ) -> Self
    where
        F: Future<Output = Result<T, E>> + Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnErr: FnOnce(E) -> Msg + Send + 'static,
    {
        Cmd::task(async move {
            match tokio::time::timeout(timeout, future).await {
                Ok(Ok(value)) => on_ok(value),
                Ok(Err(err)) => on_err(err),
                Err(_elapsed) => on_timeout,
            }
        })
    }

    /// Retry an async task with exponential backoff.
    ///
    /// Calls `make_future` to start each attempt, up to `max_attempts` times.
    /// After every failed attempt except the last, it sleeps for the current
    /// delay and then doubles it. The doubling saturates at `Duration::MAX`
    /// instead of panicking on overflow.
    ///
    /// # Arguments
    /// - `max_attempts`: Maximum number of attempts (must be >= 1)
    /// - `initial_delay`: Delay before the first retry (doubles each retry)
    /// - `make_future`: Factory function that creates the future for each attempt
    /// - `on_ok`: Called with the first successful result
    /// - `on_fail`: Called with the final attempt's error and the number of
    ///   attempts made (which is `max_attempts`)
    ///
    /// # Panics
    /// Panics when called, before any task is created, if `max_attempts` is 0.
    ///
    /// # Example
    /// ```ignore
    /// Cmd::retry(
    ///     3, // max 3 attempts
    ///     Duration::from_millis(100), // start with 100ms delay
    ///     || api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err, attempts| Msg::FetchFailed(err.to_string(), attempts),
    /// )
    /// ```
    ///
    /// With 3 attempts and 100ms initial delay:
    /// - Attempt 1: immediate
    /// - Attempt 2: after 100ms (if attempt 1 failed)
    /// - Attempt 3: after 200ms (if attempt 2 failed)
    pub fn retry<F, Fut, T, E, FnOk, FnFail>(
        max_attempts: u32,
        initial_delay: std::time::Duration,
        make_future: F,
        on_ok: FnOk,
        on_fail: FnFail,
    ) -> Self
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send + 'static,
        E: Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnFail: FnOnce(E, u32) -> Msg + Send + 'static,
    {
        assert!(max_attempts >= 1, "max_attempts must be at least 1");

        Cmd::task(async move {
            let mut delay = initial_delay;
            let mut attempt: u32 = 1;
            loop {
                match make_future().await {
                    Ok(value) => return on_ok(value),
                    // Out of attempts: report the final error (attempt == max_attempts).
                    Err(err) if attempt >= max_attempts => return on_fail(err, attempt),
                    // Errors from non-final attempts are discarded before retrying.
                    Err(_) => {}
                }
                tokio::time::sleep(delay).await;
                // `Duration * 2` panics on overflow, which a large `max_attempts`
                // can reach; saturate instead.
                delay = delay.saturating_mul(2);
                // Cannot overflow: attempt < max_attempts <= u32::MAX here.
                attempt += 1;
            }
        })
    }

    /// Retry with a custom backoff strategy.
    ///
    /// Like `retry`, but the delay after failed attempt `n` (1-based) is
    /// `backoff(n)`. No delay follows the final attempt.
    ///
    /// # Panics
    /// Panics when called, before any task is created, if `max_attempts` is 0.
    ///
    /// # Example
    /// ```ignore
    /// // Linear backoff: 100ms, 200ms, 300ms, ...
    /// Cmd::retry_with_backoff(
    ///     5,
    ///     |attempt| Duration::from_millis(100 * attempt as u64),
    ///     || api::fetch_data(),
    ///     |data| Msg::DataLoaded(data),
    ///     |err, attempts| Msg::FetchFailed(err.to_string(), attempts),
    /// )
    /// ```
    pub fn retry_with_backoff<F, Fut, T, E, B, FnOk, FnFail>(
        max_attempts: u32,
        backoff: B,
        make_future: F,
        on_ok: FnOk,
        on_fail: FnFail,
    ) -> Self
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send + 'static,
        E: Send + 'static,
        B: Fn(u32) -> std::time::Duration + Send + 'static,
        FnOk: FnOnce(T) -> Msg + Send + 'static,
        FnFail: FnOnce(E, u32) -> Msg + Send + 'static,
    {
        assert!(max_attempts >= 1, "max_attempts must be at least 1");

        Cmd::task(async move {
            let mut attempt: u32 = 1;
            loop {
                match make_future().await {
                    Ok(value) => return on_ok(value),
                    // Out of attempts: report the final error (attempt == max_attempts).
                    Err(err) if attempt >= max_attempts => return on_fail(err, attempt),
                    // Errors from non-final attempts are discarded before retrying.
                    Err(_) => {}
                }
                tokio::time::sleep(backoff(attempt)).await;
                // Cannot overflow: attempt < max_attempts <= u32::MAX here.
                attempt += 1;
            }
        })
    }
}