Skip to main content

flashkraft_gui/core/
flash_subscription.rs

1//! Flash Subscription - Real-time progress streaming
2//!
3//! ## Architecture
4//!
5//! ```text
6//!   Iced async runtime (ThreadPool)       blocking OS thread
7//!   ────────────────────────────────      ──────────────────
8//!   flash_progress()                      std::thread::spawn
9//!        │                                       │
10//!        │  futures::channel::mpsc               │
11//!        │ ◄─────────────────────────── bridge thread
12//!        │        (forwards from std_rx)         │
13//!        │                               run_pipeline(std_tx)
14//!   event = rx.next().await                      │
15//!        │  (yields to executor)          writes image / verifies
16//!        │
17//!   FlashProgress → Message → Iced repaint
18//! ```
19//!
20//! ## Why blocking `recv()` was wrong
21//!
22//! The previous implementation called `std::sync::mpsc::Receiver::recv()`
23//! directly inside the `async` stream block.  `recv()` is a **blocking**
24//! syscall — it parks the OS thread until a message arrives.  Because Iced
25//! drives subscriptions on a `futures::executor::ThreadPool` (not tokio),
26//! blocking that thread starved every other future on the same worker,
27//! including Iced's repaint loop.  Progress events were queued correctly but
28//! the UI never re-rendered until the entire pipeline had finished.
29//!
30//! ## Fix
31//!
32//! We now use a **three-actor design**:
33//!
34//! 1. **Pipeline thread** — calls `run_pipeline` with a `std::sync::mpsc::Sender`.
//! 2. **Bridge thread** — blocks in `std_rx.recv()` (fine here, because this
//!    dedicated OS thread does nothing but forward events) and forwards each
//!    event into a `futures::channel::mpsc` channel.  Because `recv()` parks
//!    the thread until an event arrives, the bridge uses no CPU between events;
//!    no polling or sleeping is needed.
40//! 3. **Async stream** — calls `rx.next().await` on the `futures::channel::mpsc`
41//!    receiver, which is a proper async future that yields between every message
42//!    and lets the Iced executor schedule repaints freely.
43//!
44//! ## Cancellation
45//!
46//! An `Arc<AtomicBool>` cancel token is shared with the pipeline thread.
47//! The pipeline checks it on every 4 MiB write block.
48
49use crate::flash_debug;
50use flashkraft_core::flash_helper::{run_pipeline, FlashEvent};
51use flashkraft_core::FlashUpdate;
52use futures::channel::mpsc as futures_mpsc;
53use futures::SinkExt;
54use futures::StreamExt;
55use iced::stream;
56use iced::Subscription;
57use std::collections::hash_map::DefaultHasher;
58use std::hash::{Hash, Hasher};
59use std::path::PathBuf;
60use std::sync::{
61    atomic::{AtomicBool, Ordering},
62    Arc,
63};
64
65// ---------------------------------------------------------------------------
66// Public types
67// ---------------------------------------------------------------------------
68
69/// Progress event emitted by the flash subscription to the Iced runtime.
70///
/// This is a renamed re-export of [`flashkraft_core::FlashUpdate`] — the
/// unified frontend event defined in core, so that the GUI and TUI share the
/// same representation without duplicating the type.
74pub use flashkraft_core::FlashUpdate as FlashProgress;
75
76// ---------------------------------------------------------------------------
77// Public API
78// ---------------------------------------------------------------------------
79
80/// Create a subscription that streams [`FlashProgress`] events while the
81/// flash operation runs.
82///
83/// `run_id` must be incremented on every new flash attempt so that flashing
84/// the same image to the same device twice always produces a distinct
85/// subscription ID and Iced creates a fresh stream.
86pub fn flash_progress(
87    image_path: PathBuf,
88    device_path: PathBuf,
89    cancel_token: Arc<AtomicBool>,
90    run_id: u64,
91) -> Subscription<FlashProgress> {
92    // Unique subscription ID — changes every flash attempt.
93    let mut hasher = DefaultHasher::new();
94    image_path.hash(&mut hasher);
95    device_path.hash(&mut hasher);
96    run_id.hash(&mut hasher);
97    let id = hasher.finish();
98
99    Subscription::run_with_id(
100        id,
101        stream::channel(64, move |mut output| async move {
102            // ── Validate inputs ───────────────────────────────────────────────
103            let image_size = match image_path.metadata() {
104                Ok(m) if m.len() == 0 => {
105                    let _ = output
106                        .send(FlashProgress::Failed("Image file is empty".into()))
107                        .await;
108                    return std::future::pending().await;
109                }
110                Ok(m) => m.len(),
111                Err(e) => {
112                    let _ = output
113                        .send(FlashProgress::Failed(format!(
114                            "Cannot read image file: {e}"
115                        )))
116                        .await;
117                    return std::future::pending().await;
118                }
119            };
120
121            flash_debug!(
122                "flash_progress: image={image_path:?} dev={device_path:?} size={image_size}"
123            );
124
125            // ── Channel setup ─────────────────────────────────────────────────
126            //
127            // std channel  → bridge thread (blocking recv) → futures channel
128            //                                                       ↓
129            //                                              rx.next().await
130            //                                              (yields to executor)
131            let (std_tx, std_rx) = std::sync::mpsc::channel::<FlashEvent>();
132
133            // futures::channel::mpsc is executor-agnostic — next() is a real
134            // async future that yields between every message.
135            let (mut futures_tx, mut futures_rx) = futures_mpsc::channel::<FlashEvent>(64);
136
137            // ── Pipeline thread ───────────────────────────────────────────────
138            let img_str = image_path.to_string_lossy().into_owned();
139            let dev_str = device_path.to_string_lossy().into_owned();
140            let cancel_pipeline = cancel_token.clone();
141
142            std::thread::Builder::new()
143                .name("flashkraft-pipeline".into())
144                .spawn(move || {
145                    flash_debug!("flash thread: starting pipeline");
146                    run_pipeline(&img_str, &dev_str, std_tx, cancel_pipeline);
147                    flash_debug!("flash thread: pipeline returned");
148                })
149                .expect("failed to spawn flash pipeline thread");
150
151            // ── Bridge thread ─────────────────────────────────────────────────
152            //
153            // Sits in a blocking recv() loop — safe because this is its own
154            // dedicated OS thread and it owns no async resources.  When a
155            // message arrives it forwards it into the futures channel via
156            // try_send (non-blocking from this thread's perspective).
157            std::thread::Builder::new()
158                .name("flashkraft-bridge".into())
159                .spawn(move || {
160                    while let Ok(event) = std_rx.recv() {
161                        // try_send returns Err if the receiver was
162                        // dropped (subscription cancelled) — exit cleanly.
163                        if futures_tx.try_send(event).is_err() {
164                            break;
165                        }
166                    }
167                })
168                .expect("failed to spawn flash bridge thread");
169
170            // ── Async event loop ──────────────────────────────────────────────
171            //
172            // futures_rx.next().await is a genuine async yield point.
173            // The Iced ThreadPool executor is free to run other futures
174            // (repaints, animation ticks, etc.) between every message.
175            loop {
176                match futures_rx.next().await {
177                    Some(FlashEvent::Done) => {
178                        flash_debug!("flash thread: Done");
179                        let _ = output.send(FlashUpdate::Completed).await;
180                        break;
181                    }
182
183                    Some(FlashEvent::Error(e)) => {
184                        flash_debug!("flash thread: Error: {e}");
185                        let _ = output.send(FlashUpdate::Failed(e)).await;
186                        break;
187                    }
188
189                    Some(core_event) => {
190                        let update = FlashUpdate::from(core_event);
191                        flash_debug!("flash event: {update:?}");
192                        let _ = output.send(update).await;
193                    }
194
195                    // Channel closed — bridge thread exited (pipeline done or cancelled).
196                    None => {
197                        flash_debug!("flash channel closed unexpectedly");
198                        if cancel_token.load(Ordering::SeqCst) {
199                            let _ = output
200                                .send(FlashUpdate::Failed(
201                                    "Flash operation cancelled by user".into(),
202                                ))
203                                .await;
204                        } else {
205                            let _ = output
206                                .send(FlashUpdate::Failed(
207                                    "Flash thread terminated unexpectedly".into(),
208                                ))
209                                .await;
210                        }
211                        break;
212                    }
213                }
214            }
215
216            // Park forever — Iced requires the stream future to never return.
217            std::future::pending().await
218        }),
219    )
220}
221
222// ---------------------------------------------------------------------------
223// Tests
224// ---------------------------------------------------------------------------
225
226#[cfg(test)]
227mod tests {
228    use super::*;
229
230    #[test]
231    fn test_flash_progress_clone() {
232        let variants = vec![
233            FlashProgress::Progress {
234                progress: 0.5,
235                bytes_written: 1024,
236                speed_mb_s: 10.0,
237            },
238            FlashProgress::VerifyProgress {
239                phase: "image",
240                overall: 0.25,
241                bytes_read: 512,
242                total_bytes: 1024,
243                speed_mb_s: 100.0,
244            },
245            FlashProgress::VerifyProgress {
246                phase: "device",
247                overall: 0.75,
248                bytes_read: 512,
249                total_bytes: 1024,
250                speed_mb_s: 80.0,
251            },
252            FlashProgress::Message("hello".to_string()),
253            FlashProgress::Completed,
254            FlashProgress::Failed("oops".to_string()),
255        ];
256        for v in &variants {
257            let _ = v.clone();
258        }
259    }
260
261    #[test]
262    fn test_flash_progress_debug() {
263        let p = FlashProgress::Progress {
264            progress: 1.0,
265            bytes_written: 2048,
266            speed_mb_s: 20.0,
267        };
268        assert!(format!("{p:?}").contains("Progress"));
269    }
270
271    #[test]
272    fn test_subscription_id_is_deterministic() {
273        fn compute_id(image: &str, device: &str, run_id: u64) -> u64 {
274            let mut hasher = DefaultHasher::new();
275            PathBuf::from(image).hash(&mut hasher);
276            PathBuf::from(device).hash(&mut hasher);
277            run_id.hash(&mut hasher);
278            hasher.finish()
279        }
280        let id1 = compute_id("/tmp/test.img", "/dev/sdb", 0);
281        let id2 = compute_id("/tmp/test.img", "/dev/sdb", 0);
282        assert_eq!(id1, id2, "subscription ID must be deterministic");
283    }
284
285    #[test]
286    fn test_subscription_id_differs_for_different_devices() {
287        fn compute_id(image: &str, device: &str, run_id: u64) -> u64 {
288            let mut hasher = DefaultHasher::new();
289            PathBuf::from(image).hash(&mut hasher);
290            PathBuf::from(device).hash(&mut hasher);
291            run_id.hash(&mut hasher);
292            hasher.finish()
293        }
294        let id1 = compute_id("/tmp/test.img", "/dev/sdb", 0);
295        let id2 = compute_id("/tmp/test.img", "/dev/sdc", 0);
296        assert_ne!(id1, id2, "different devices must yield different IDs");
297    }
298
299    #[test]
300    fn test_subscription_id_differs_for_different_run_ids() {
301        fn compute_id(image: &str, device: &str, run_id: u64) -> u64 {
302            let mut hasher = DefaultHasher::new();
303            PathBuf::from(image).hash(&mut hasher);
304            PathBuf::from(device).hash(&mut hasher);
305            run_id.hash(&mut hasher);
306            hasher.finish()
307        }
308        let id1 = compute_id("/tmp/test.img", "/dev/sdb", 0);
309        let id2 = compute_id("/tmp/test.img", "/dev/sdb", 1);
310        assert_ne!(
311            id1, id2,
312            "different run IDs must yield different subscription IDs"
313        );
314    }
315
316    #[test]
317    fn test_verify_progress_overall_image_phase() {
318        for pct in [0.0f32, 0.25, 0.5, 1.0] {
319            let overall = flashkraft_core::flash_helper::verify_overall_progress("image", pct);
320            assert!(
321                (0.0..=0.5).contains(&overall),
322                "image phase overall {overall} out of [0, 0.5]"
323            );
324        }
325    }
326
327    #[test]
328    fn test_verify_progress_overall_device_phase() {
329        for pct in [0.0f32, 0.25, 0.5, 1.0] {
330            let overall = flashkraft_core::flash_helper::verify_overall_progress("device", pct);
331            assert!(
332                (0.5..=1.0).contains(&overall),
333                "device phase overall {overall} out of [0.5, 1.0]"
334            );
335        }
336    }
337
338    #[test]
339    fn test_cancelled_maps_to_failed() {
340        let cancel = Arc::new(AtomicBool::new(true));
341        let msg = if cancel.load(Ordering::SeqCst) {
342            "Flash operation cancelled by user"
343        } else {
344            "Flash thread terminated unexpectedly"
345        };
346        assert_eq!(msg, "Flash operation cancelled by user");
347    }
348
349    /// The bridge thread correctly terminates when the futures receiver is dropped.
350    #[test]
351    fn test_bridge_exits_when_receiver_dropped() {
352        let (std_tx, std_rx) = std::sync::mpsc::channel::<FlashEvent>();
353        let (mut futures_tx, futures_rx) = futures_mpsc::channel::<FlashEvent>(4);
354
355        // Drop the receiver immediately — bridge should exit cleanly.
356        drop(futures_rx);
357
358        let bridge = std::thread::spawn(move || {
359            while let Ok(event) = std_rx.recv() {
360                if futures_tx.try_send(event).is_err() {
361                    break;
362                }
363            }
364        });
365
366        // Send one event — bridge will fail try_send and exit.
367        let _ = std_tx.send(FlashEvent::Done);
368        // Give the bridge thread a moment to process.
369        std::thread::sleep(std::time::Duration::from_millis(50));
370        // Drop sender so bridge's recv() returns Err if it didn't exit already.
371        drop(std_tx);
372
373        bridge.join().expect("bridge thread should exit cleanly");
374    }
375
376    /// Verify that all FlashEvent variants are handled (mapping smoke test).
377    #[test]
378    fn test_flash_event_mapping_smoke() {
379        use flashkraft_core::flash_helper::{FlashEvent as CoreFlashEvent, FlashStage};
380
381        let events = vec![
382            CoreFlashEvent::Stage(FlashStage::Writing),
383            CoreFlashEvent::Progress {
384                bytes_written: 512,
385                total_bytes: 1024,
386                speed_mb_s: 42.0,
387            },
388            CoreFlashEvent::VerifyProgress {
389                phase: "image",
390                bytes_read: 256,
391                total_bytes: 1024,
392                speed_mb_s: 100.0,
393            },
394            CoreFlashEvent::VerifyProgress {
395                phase: "device",
396                bytes_read: 512,
397                total_bytes: 1024,
398                speed_mb_s: 80.0,
399            },
400            CoreFlashEvent::Log("Test log".into()),
401            CoreFlashEvent::Done,
402            CoreFlashEvent::Error("boom".into()),
403        ];
404
405        // Verify each variant converts to a FlashUpdate (= FlashProgress) without panicking.
406        for event in events {
407            let _mapped = FlashUpdate::from(event);
408        }
409    }
410}