Skip to main content

flashkraft_gui/core/
flash_subscription.rs

1//! Flash Subscription - Real-time progress streaming
2//!
3//! This module implements an Iced [`Subscription`] that drives the flash
4//! pipeline and forwards structured progress events to the UI.
5//!
6//! ## Architecture
7//!
8//! The flash operation runs entirely in-process on a dedicated blocking
9//! `std::thread`.  No child process, no pkexec, no sudo, no IPC protocol.
10//!
11//! ```text
12//!   Iced async runtime                 blocking OS thread
13//!   ─────────────────                  ──────────────────
14//!   flash_progress()                   std::thread::spawn
15//!        │                                    │
16//!        │   std::sync::mpsc::channel         │
17//!        │ ◄────────────────────────── run_pipeline(tx, …)
18//!        │                                    │
19//!   FlashEvent → FlashProgress          writes image
20//!        │
21//!   Iced UI update
22//! ```
23//!
24//! ## Privilege model
25//!
26//! The installed binary is **setuid-root** (`chmod u+s /usr/bin/flashkraft`).
27//! `main.rs` captures the real (unprivileged) UID at startup and stores it via
28//! [`flashkraft_core::flash_helper::set_real_uid`].  The pipeline calls
29//! `seteuid(0)` only for the instant needed to open the block device, then
30//! immediately drops back to the real UID.
31//!
32//! ## Cancellation
33//!
34//! An [`AtomicBool`] cancel token is shared between the Iced update loop and
35//! the flash thread.  The pipeline checks the flag on every write block
36//! (~4 MiB) and exits early when it is set.
37//!
38//! ## `FlashProgress` enum (unchanged)
39//!
40//! The variants `Progress`, `Message`, `Completed`, and `Failed` are identical
41//! to the previous implementation — no changes to `update.rs`, `state.rs`,
42//! `message.rs`, or any UI code are required.
43
44use crate::flash_debug;
45use flashkraft_core::flash_helper::{run_pipeline, FlashEvent};
46use futures::SinkExt;
47use iced::stream;
48use iced::Subscription;
49use std::collections::hash_map::DefaultHasher;
50use std::hash::{Hash, Hasher};
51use std::path::PathBuf;
52use std::sync::{
53    atomic::{AtomicBool, Ordering},
54    Arc,
55};
56
57// ---------------------------------------------------------------------------
58// Public types
59// ---------------------------------------------------------------------------
60
/// Event handed from the flash subscription to the Iced runtime.
///
/// The variant set matches the earlier pkexec-based design one-for-one, so
/// `update.rs`, `state.rs` and `message.rs` did not have to change when the
/// in-process pipeline replaced it.
#[derive(Clone, Debug)]
pub enum FlashProgress {
    /// Write progress as `(fraction in 0.0–1.0, bytes written so far, speed in MB/s)`.
    Progress(f32, u64, f32),
    /// Free-form status text for the UI, such as a stage name or a log line.
    Message(String),
    /// Terminal event: the flash operation completed successfully.
    Completed,
    /// Terminal event: the flash operation failed; carries a human-readable reason.
    Failed(String),
}
77
78// ---------------------------------------------------------------------------
79// Public API
80// ---------------------------------------------------------------------------
81
82/// Create a subscription that streams [`FlashProgress`] events while the
83/// flash operation runs.
84///
85/// The subscription is uniquely identified by hashing `image_path` and
86/// `device_path` so Iced can deduplicate it across recompositions.
87pub fn flash_progress(
88    image_path: PathBuf,
89    device_path: PathBuf,
90    cancel_token: Arc<AtomicBool>,
91) -> Subscription<FlashProgress> {
92    // ── Stable subscription ID ────────────────────────────────────────────────
93    let mut hasher = DefaultHasher::new();
94    image_path.hash(&mut hasher);
95    device_path.hash(&mut hasher);
96    let id = hasher.finish();
97
98    Subscription::run_with_id(
99        id,
100        stream::channel(64, move |mut output| async move {
101            let img = image_path.clone();
102            let dev = device_path.clone();
103            let cancel = cancel_token.clone();
104
105            // ── Validate inputs before spinning up a thread ───────────────────
106            let image_size = match img.metadata() {
107                Ok(m) if m.len() == 0 => {
108                    let _ = output
109                        .send(FlashProgress::Failed("Image file is empty".into()))
110                        .await;
111                    return std::future::pending().await;
112                }
113                Ok(m) => m.len(),
114                Err(e) => {
115                    let _ = output
116                        .send(FlashProgress::Failed(format!(
117                            "Cannot read image file: {e}"
118                        )))
119                        .await;
120                    return std::future::pending().await;
121                }
122            };
123
124            flash_debug!("flash_progress: image={img:?} dev={dev:?} size={image_size}");
125
126            // ── Bridge: blocking thread → async ───────────────────────────────
127            // std::sync::mpsc is used on the thread side (blocking send);
128            // we convert it to async by polling with try_recv + yield.
129            let (tx, rx) = std::sync::mpsc::channel::<FlashEvent>();
130
131            let img_str = img.to_string_lossy().into_owned();
132            let dev_str = dev.to_string_lossy().into_owned();
133
134            std::thread::spawn(move || {
135                flash_debug!("flash thread: starting pipeline");
136                run_pipeline(&img_str, &dev_str, tx, cancel);
137                flash_debug!("flash thread: pipeline returned");
138            });
139
140            // ── Forward FlashEvents → FlashProgress ───────────────────────────
141            loop {
142                match rx.recv() {
143                    // ── Progress update ───────────────────────────────────────
144                    Ok(FlashEvent::Progress {
145                        bytes_written,
146                        total_bytes,
147                        speed_mb_s,
148                    }) => {
149                        let progress = if total_bytes > 0 {
150                            (bytes_written as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
151                        } else {
152                            0.0
153                        };
154                        flash_debug!(
155                            "progress: {:.1}% ({bytes_written}/{total_bytes}) @ {speed_mb_s:.1} MB/s",
156                            progress * 100.0
157                        );
158                        let _ = output
159                            .send(FlashProgress::Progress(progress, bytes_written, speed_mb_s))
160                            .await;
161                    }
162
163                    // ── Stage transition ──────────────────────────────────────
164                    Ok(FlashEvent::Stage(stage)) => {
165                        let msg = stage.to_string();
166                        flash_debug!("stage: {msg}");
167                        let _ = output.send(FlashProgress::Message(msg)).await;
168                    }
169
170                    // ── Informational log ─────────────────────────────────────
171                    Ok(FlashEvent::Log(msg)) => {
172                        flash_debug!("log: {msg}");
173                        let _ = output.send(FlashProgress::Message(msg)).await;
174                    }
175
176                    // ── Success ───────────────────────────────────────────────
177                    Ok(FlashEvent::Done) => {
178                        flash_debug!("flash thread: Done");
179                        let _ = output.send(FlashProgress::Completed).await;
180                        break;
181                    }
182
183                    // ── Pipeline error ────────────────────────────────────────
184                    Ok(FlashEvent::Error(e)) => {
185                        flash_debug!("flash thread: Error: {e}");
186                        let _ = output.send(FlashProgress::Failed(e)).await;
187                        break;
188                    }
189
190                    // ── Sender dropped (thread panicked or returned early) ─────
191                    Err(_) => {
192                        flash_debug!("flash thread: channel closed unexpectedly");
193
194                        // Only report failure if we haven't already sent a
195                        // terminal event (Done / Error) that would have broken
196                        // the loop above.  The cancel flag covers intentional
197                        // cancellation.
198                        if cancel_token.load(Ordering::SeqCst) {
199                            let _ = output
200                                .send(FlashProgress::Failed(
201                                    "Flash operation cancelled by user".into(),
202                                ))
203                                .await;
204                        } else {
205                            let _ = output
206                                .send(FlashProgress::Failed(
207                                    "Flash thread terminated unexpectedly".into(),
208                                ))
209                                .await;
210                        }
211                        break;
212                    }
213                }
214            }
215
216            // Keep the subscription alive — Iced requires the async block to
217            // never return (it is driven as a Stream).
218            std::future::pending().await
219        }),
220    )
221}
222
223// ---------------------------------------------------------------------------
224// Tests
225// ---------------------------------------------------------------------------
226
227#[cfg(test)]
228mod tests {
229    use super::*;
230
231    /// All `FlashProgress` variants must be `Clone` (Iced requirement).
232    #[test]
233    fn test_flash_progress_clone() {
234        let variants = vec![
235            FlashProgress::Progress(0.5, 1024, 10.0),
236            FlashProgress::Message("hello".to_string()),
237            FlashProgress::Completed,
238            FlashProgress::Failed("oops".to_string()),
239        ];
240        for v in &variants {
241            let _ = v.clone();
242        }
243    }
244
245    #[test]
246    fn test_flash_progress_debug() {
247        let p = FlashProgress::Progress(1.0, 2048, 20.0);
248        assert!(format!("{p:?}").contains("Progress"));
249    }
250
251    /// The subscription ID must be deterministic for a given (image, device) pair.
252    #[test]
253    fn test_subscription_id_is_deterministic() {
254        fn compute_id(image: &str, device: &str) -> u64 {
255            let mut hasher = DefaultHasher::new();
256            PathBuf::from(image).hash(&mut hasher);
257            PathBuf::from(device).hash(&mut hasher);
258            hasher.finish()
259        }
260        let id1 = compute_id("/tmp/test.img", "/dev/sdb");
261        let id2 = compute_id("/tmp/test.img", "/dev/sdb");
262        assert_eq!(id1, id2, "subscription ID must be deterministic");
263    }
264
265    /// Different (image, device) pairs must produce different IDs.
266    #[test]
267    fn test_subscription_id_differs_for_different_devices() {
268        fn compute_id(image: &str, device: &str) -> u64 {
269            let mut hasher = DefaultHasher::new();
270            PathBuf::from(image).hash(&mut hasher);
271            PathBuf::from(device).hash(&mut hasher);
272            hasher.finish()
273        }
274        let id1 = compute_id("/tmp/test.img", "/dev/sdb");
275        let id2 = compute_id("/tmp/test.img", "/dev/sdc");
276        assert_ne!(id1, id2, "different devices must yield different IDs");
277    }
278
279    /// FlashEvent channel bridge: verify that the mpsc bridge correctly maps
280    /// every FlashEvent variant to the expected FlashProgress variant.
281    #[test]
282    fn test_flash_event_mapping() {
283        use flashkraft_core::flash_helper::FlashStage;
284
285        // Simulate what the async loop does — map FlashEvent → FlashProgress.
286        let events = vec![
287            FlashEvent::Stage(FlashStage::Writing),
288            FlashEvent::Progress {
289                bytes_written: 512,
290                total_bytes: 1024,
291                speed_mb_s: 42.0,
292            },
293            FlashEvent::Log("Test log".into()),
294            FlashEvent::Done,
295        ];
296
297        for event in events {
298            let _progress: Option<FlashProgress> = match event {
299                FlashEvent::Progress {
300                    bytes_written,
301                    total_bytes,
302                    speed_mb_s,
303                } => {
304                    let p = if total_bytes > 0 {
305                        (bytes_written as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
306                    } else {
307                        0.0
308                    };
309                    Some(FlashProgress::Progress(p, bytes_written, speed_mb_s))
310                }
311                FlashEvent::Stage(s) => Some(FlashProgress::Message(s.to_string())),
312                FlashEvent::Log(m) => Some(FlashProgress::Message(m)),
313                FlashEvent::Done => Some(FlashProgress::Completed),
314                FlashEvent::Error(e) => Some(FlashProgress::Failed(e)),
315            };
316            // Just verify the mapping doesn't panic.
317        }
318    }
319
320    /// The channel bridge correctly handles the cancelled case.
321    #[test]
322    fn test_cancelled_maps_to_failed() {
323        let cancel = Arc::new(AtomicBool::new(true));
324        assert!(cancel.load(Ordering::SeqCst));
325        // When cancel is true and Err(_) is received, we send Failed.
326        let msg = if cancel.load(Ordering::SeqCst) {
327            "Flash operation cancelled by user"
328        } else {
329            "Flash thread terminated unexpectedly"
330        };
331        assert_eq!(msg, "Flash operation cancelled by user");
332    }
333}