Skip to main content

flashkraft_core/
flash_helper.rs

1//! Flash Pipeline
2//!
3//! Implements the entire privileged flash pipeline in-process.
4//!
5//! ## Privilege model
6//!
7//! The installed binary carries the **setuid-root** bit
8//! (`sudo chmod u+s /usr/bin/flashkraft`).  At process startup `main.rs`
9//! calls [`set_real_uid`] to record the unprivileged user's UID.
10//!
11//! When the pipeline needs to open a raw block device it temporarily
12//! escalates to root via `nix::unistd::seteuid(0)`, opens the file
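//! descriptor, then immediately drops back to the real UID.  Root is held
//! only for the duration of that `open` call (note that `seteuid` changes the
//! credentials of the whole process, so other threads share the elevated
//! effective UID during that window).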
15//!
16//! ## Progress reporting
17//!
18//! The pipeline runs on a dedicated blocking thread spawned by the flash
19//! subscription.  Progress is reported by sending [`FlashEvent`] values
20//! through a [`std::sync::mpsc::Sender`] — no child process, no stdout
21//! parsing, no IPC protocol.
22//!
23//! ## Pipeline stages
24//!
25//! 1. Validate inputs (image exists and is non-empty, device exists, not a partition node)
26//! 2. Unmount all partitions of the target device (lazy / `MNT_DETACH`)
27//! 3. Write the image in 4 MiB blocks, reporting progress every 400 ms
28//! 4. `fsync` the device fd (hard error on failure)
29//! 5. `fdatasync` + global `sync()` (belt-and-suspenders)
30//! 6. `BLKRRPART` ioctl — ask the kernel to re-read the partition table
31//! 7. SHA-256 verify: hash the source image, hash the first N bytes of the device, compare
32
33#[cfg(unix)]
34use nix::libc;
35use std::io::{self, Read, Write};
36use std::path::Path;
37use std::sync::{
38    atomic::{AtomicBool, Ordering},
39    mpsc, Arc, OnceLock,
40};
41use std::time::{Duration, Instant};
42
43// ---------------------------------------------------------------------------
44// Constants
45// ---------------------------------------------------------------------------
46
/// Chunk size used for both writing the image and reading it back (4 MiB),
/// a good balance for USB mass-storage throughput.
const BLOCK_SIZE: usize = 4 << 20;

/// Throttle for `FlashEvent::Progress`: at most one emission per 400 ms.
const PROGRESS_INTERVAL: Duration = Duration::from_millis(400);
52
53// ---------------------------------------------------------------------------
54// Real-UID registry
55// ---------------------------------------------------------------------------
56
/// The unprivileged UID of the user who launched the process.
///
/// Captured in `main.rs` via `nix::unistd::getuid()` before any `seteuid`
/// call and stored here via [`set_real_uid`].  It is the UID the privilege
/// helpers return to after temporarily escalating to root.
static REAL_UID: OnceLock<u32> = OnceLock::new();

/// Store the real (unprivileged) UID of the process owner.
///
/// Must be called from `main()` before any flash operation.  Only the first
/// call has an effect; because the value lives in a [`OnceLock`], later
/// calls are silently ignored.
///
/// The value is recorded on every platform, but within this module it is
/// only read by the Unix privilege-dropping code.
pub fn set_real_uid(uid: u32) {
    // A second registration is deliberately ignored: the first value wins.
    let _ = REAL_UID.set(uid);
}
70
71/// Return `true` when the process is currently running with effective root
72/// privileges (i.e. `geteuid() == 0`).
73///
74/// On non-Unix platforms this always returns `false` — callers should use
75/// the Windows Administrator check instead.
76pub fn is_privileged() -> bool {
77    #[cfg(unix)]
78    {
79        nix::unistd::geteuid().is_root()
80    }
81    #[cfg(not(unix))]
82    {
83        false
84    }
85}
86
87/// Attempt to re-exec the current binary with root privileges via `pkexec`
88/// or `sudo -E`, whichever is found first on `PATH`.
89///
90/// This is called **on demand** — e.g. when the user clicks Flash and we
91/// detect that `is_privileged()` is `false` — rather than unconditionally
92/// at startup.  Because `execvp` replaces the current process image on
93/// success, this function only returns when neither escalation helper is
94/// available or the user declined (cancelled the polkit dialog / Ctrl-C'd
95/// the sudo prompt).
96///
97/// `FLASHKRAFT_ESCALATED=1` is injected into the child environment so that
98/// the re-exec'd process skips this call and does not loop.
99///
100/// # Safety
101///
102/// Safe to call from any thread, but must be called before the Iced event
103/// loop has started spawning threads that hold OS resources (file
104/// descriptors, mutexes) that `execvp` would implicitly close/reset.
105/// Calling it from the `update` handler (on the Iced main thread, before
106/// the flash subscription starts) satisfies this requirement.
107#[cfg(unix)]
108pub fn reexec_as_root() {
109    // Never attempt privilege escalation during `cargo test` — sudo/pkexec
110    // would block the test runner waiting for a password prompt.
111    //
112    // IMPORTANT: `#[cfg(test)]` is only set on the *root* crate being tested.
113    // When `flashkraft-core` is compiled as a *dependency* of another crate's
114    // test binary (e.g. `flashkraft-gui`'s tests), it is compiled in normal
115    // (non-test) mode, so `#[cfg(test)]` does NOT fire here.
116    //
117    // We therefore use a runtime heuristic: cargo test binary paths contain
118    // a hash-suffixed name under `target/debug/deps/`, e.g.:
119    //   …/target/debug/deps/flashkraft_gui-a76f74e119b55607
120    // We also check for the FLASHKRAFT_NO_REEXEC env var as an explicit opt-out,
121    // and for NEXTEST_TEST_FILTER which nextest sets.
122    if is_running_under_test_harness() {
123        return;
124    }
125
126    // Compile-time guard for crate-local unit tests (when core IS the root
127    // test crate and #[cfg(test)] IS honoured).
128    #[cfg(test)]
129    return;
130
131    #[cfg(not(test))]
132    reexec_as_root_inner();
133}
134
/// Detect, at runtime, whether this process is a `cargo test` / nextest
/// test binary.
///
/// `#[cfg(test)]` only applies to the root crate under test, never to its
/// dependencies, so [`reexec_as_root`] needs a runtime check as well.
///
/// The process is treated as a test harness when any of these holds:
/// * `FLASHKRAFT_NO_REEXEC` is set (explicit opt-out, any value);
/// * `NEXTEST_TEST_FILTER` is set (nextest exports it to every test process);
/// * the executable path contains a `deps` directory component.  Cargo puts
///   test binaries at `target/<profile>/deps/<name>-<hash>`; both `/` and `\`
///   separators are recognised.
#[cfg(unix)]
fn is_running_under_test_harness() -> bool {
    let flagged_by_env = ["FLASHKRAFT_NO_REEXEC", "NEXTEST_TEST_FILTER"]
        .iter()
        .any(|&key| std::env::var(key).is_ok());
    if flagged_by_env {
        return true;
    }

    std::env::current_exe()
        .map(|exe| {
            let exe_text = exe.to_string_lossy();
            ["/deps/", "\\deps\\"]
                .iter()
                .any(|&marker| exe_text.contains(marker))
        })
        .unwrap_or(false)
}
173
174#[cfg(all(unix, not(test)))]
175fn reexec_as_root_inner() {
176    use std::ffi::CString;
177
178    // Guard: the re-exec'd copy sets this so we don't loop forever.
179    if std::env::var("FLASHKRAFT_ESCALATED").as_deref() == Ok("1") {
180        return;
181    }
182
183    let self_exe = match std::fs::read_link("/proc/self/exe").or_else(|_| std::env::current_exe()) {
184        Ok(p) => p,
185        Err(_) => return,
186    };
187    let self_exe_str = match self_exe.to_str() {
188        Some(s) => s.to_owned(),
189        None => return,
190    };
191
192    let extra_args: Vec<String> = std::env::args().skip(1).collect();
193
194    // Tell the child it was already escalated so it won't recurse.
195    std::env::set_var("FLASHKRAFT_ESCALATED", "1");
196
197    // ── Try pkexec first (graphical polkit dialog) ────────────────────────────
198    if unix_which_exists("pkexec") {
199        let mut argv: Vec<CString> = Vec::new();
200        argv.push(unix_c_str("pkexec"));
201        argv.push(unix_c_str(&self_exe_str));
202        for a in &extra_args {
203            argv.push(unix_c_str(a));
204        }
205        let _ = nix::unistd::execvp(&unix_c_str("pkexec"), &argv);
206    }
207
208    // ── Try sudo -E (terminal fallback) ───────────────────────────────────────
209    if unix_which_exists("sudo") {
210        let mut argv: Vec<CString> = Vec::new();
211        argv.push(unix_c_str("sudo"));
212        argv.push(unix_c_str("-E")); // preserve DISPLAY / WAYLAND_DISPLAY
213        argv.push(unix_c_str(&self_exe_str));
214        for a in &extra_args {
215            argv.push(unix_c_str(a));
216        }
217        let _ = nix::unistd::execvp(&unix_c_str("sudo"), &argv);
218    }
219
220    // Neither helper available — remove the guard and fall through unprivileged.
221    std::env::remove_var("FLASHKRAFT_ESCALATED");
222}
223
/// Stub for non-Unix targets so call sites compile without `#[cfg]` guards.
///
/// It does nothing: no privilege elevation is attempted on these platforms
/// (on Windows the user is instead told to run the application as
/// Administrator when opening the device fails).
#[cfg(not(unix))]
pub fn reexec_as_root() {}
227
/// Return `true` if `name` is an executable regular file in some `PATH`
/// directory.
///
/// `PATH` is read with [`std::env::var_os`] and split with
/// [`std::env::split_paths`], so a `PATH` containing non-UTF-8 entries is
/// still searched.  The old `env::var` version reported "not found" for
/// every program in that case.  As with POSIX shells, an empty `PATH` entry
/// refers to the current directory.  "Executable" means any of the three
/// execute bits is set.
#[cfg(all(unix, not(test)))]
fn unix_which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;

    let Some(path_var) = std::env::var_os("PATH") else {
        return false;
    };
    std::env::split_paths(&path_var).any(|dir| {
        std::fs::metadata(dir.join(name))
            .map(|meta| meta.is_file() && meta.permissions().mode() & 0o111 != 0)
            .unwrap_or(false)
    })
}
244
/// Convert `s` into a `CString` suitable for `execvp`, substituting `?` for
/// every interior NUL (C strings cannot contain one).
#[cfg(all(unix, not(test)))]
fn unix_c_str(s: &str) -> std::ffi::CString {
    use std::ffi::CString;

    // In UTF-8 the byte 0x00 only ever encodes U+0000, so replacing the char
    // is the same as replacing the byte.
    let without_nuls = s.replace('\0', "?");
    match CString::new(without_nuls) {
        Ok(c_string) => c_string,
        // Cannot happen after the replacement above; defensive fallback.
        Err(_) => CString::new("?").unwrap(),
    }
}
251
252/// Retrieve the stored real UID, falling back to the current effective UID.
253#[cfg(unix)]
254fn real_uid() -> nix::unistd::Uid {
255    let raw = REAL_UID
256        .get()
257        .copied()
258        .unwrap_or_else(|| nix::unistd::getuid().as_raw());
259    nix::unistd::Uid::from_raw(raw)
260}
261
262// ---------------------------------------------------------------------------
263// Public types
264// ---------------------------------------------------------------------------
265
/// A stage of the flash pipeline.
///
/// The five working steps run in this order: `Unmounting` → `Writing` →
/// `Syncing` → `Rereading` → `Verifying`.  `flash_pipeline` announces each
/// of them via `FlashEvent::Stage`.  `Starting`, `Done` and `Failed` are
/// bookend states that the pipeline itself does not emit as stages.
/// NOTE(review): presumably frontends use them to track overall state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FlashStage {
    /// Initial state before the pipeline starts.
    Starting,
    /// All partitions of the target device are being lazily unmounted.
    Unmounting,
    /// The image is being written to the block device in 4 MiB chunks.
    Writing,
    /// Kernel write-back caches are being flushed (`fsync` / `sync`).
    Syncing,
    /// The kernel is asked to re-read the partition table (`BLKRRPART`).
    Rereading,
    /// SHA-256 of the source image is compared against a read-back of the device.
    Verifying,
    /// The entire pipeline completed successfully.
    Done,
    /// The pipeline terminated with an error; carries the human-readable reason.
    Failed(String),
}
286
287impl std::fmt::Display for FlashStage {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        match self {
290            FlashStage::Starting => write!(f, "Starting…"),
291            FlashStage::Unmounting => write!(f, "Unmounting partitions…"),
292            FlashStage::Writing => write!(f, "Writing image to device…"),
293            FlashStage::Syncing => write!(f, "Flushing write buffers…"),
294            FlashStage::Rereading => write!(f, "Refreshing partition table…"),
295            FlashStage::Verifying => write!(f, "Verifying written data…"),
296            FlashStage::Done => write!(f, "Flash complete!"),
297            FlashStage::Failed(m) => write!(f, "Failed: {m}"),
298        }
299    }
300}
301
302impl FlashStage {
303    /// Map this pipeline stage to the minimum overall-progress-bar floor
304    /// it should hold when it starts.
305    ///
306    /// The write phase occupies 0–80 % via [`FlashEvent::Progress`] events;
307    /// post-write stages advance the floor so the bar keeps moving:
308    ///
309    /// | Stage        | Floor |
310    /// |--------------|-------|
311    /// | Syncing      | 80 %  |
312    /// | Rereading    | 88 %  |
313    /// | Verifying    | 92 %  |
314    /// | everything else | 0 % |
315    pub fn progress_floor(&self) -> f32 {
316        match self {
317            FlashStage::Syncing => 0.80,
318            FlashStage::Rereading => 0.88,
319            FlashStage::Verifying => 0.92,
320            _ => 0.0,
321        }
322    }
323}
324
/// Compute the overall verification progress (0.0–1.0) from a single-pass
/// fraction.
///
/// The verify pipeline runs two passes:
/// - `"image"` pass  → hashes the source image file    (contributes 0.0–0.5)
/// - `"device"` pass → reads back the written device   (contributes 0.5–1.0)
///
/// Any `phase` other than `"image"` is treated as the device pass.
/// `pass_fraction` is clamped to `[0.0, 1.0]`, so the result always stays
/// within `[0.0, 1.0]` even when a caller passes an out-of-range fraction.
/// A NaN fraction propagates as NaN.
///
/// # Example
/// ```
/// use flashkraft_core::flash_helper::verify_overall_progress;
/// assert_eq!(verify_overall_progress("image",  0.0), 0.0);
/// assert_eq!(verify_overall_progress("image",  1.0), 0.5);
/// assert_eq!(verify_overall_progress("device", 0.0), 0.5);
/// assert_eq!(verify_overall_progress("device", 1.0), 1.0);
/// assert_eq!(verify_overall_progress("device", 1.5), 1.0);
/// ```
pub fn verify_overall_progress(phase: &str, pass_fraction: f32) -> f32 {
    // Guard the documented [0, 1] output range against overshooting inputs.
    let pass_fraction = pass_fraction.clamp(0.0, 1.0);
    if phase == "image" {
        pass_fraction * 0.5
    } else {
        0.5 + pass_fraction * 0.5
    }
}
349
/// A typed event emitted by the flash pipeline.
///
/// Sent over [`std::sync::mpsc`] to the async Iced subscription — no
/// serialisation, no text parsing.  Frontends can normalise it into a
/// [`FlashUpdate`] via `From`.
#[derive(Debug, Clone)]
pub enum FlashEvent {
    /// A pipeline stage transition.
    Stage(FlashStage),
    /// Write-progress update.
    ///
    /// Carries raw byte counts only; the `0.0–1.0` fraction is derived when
    /// converting to [`FlashUpdate`].  NOTE(review): `speed_mb_s` is
    /// presumably megabytes per second — confirm decimal vs binary units
    /// against the writer.
    Progress {
        bytes_written: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Verification read-back progress update.
    ///
    /// Emitted during both the image-hash pass and the device read-back pass.
    /// `phase` is `"image"` for the source-hash pass and `"device"` for the
    /// read-back pass.  `bytes_read` and `total_bytes` are the counts for the
    /// current pass only; the overall verify progress should be computed as:
    ///
    /// ```text
    /// if phase == "image"  { bytes_read / total_bytes * 0.5 }
    /// if phase == "device" { 0.5 + bytes_read / total_bytes * 0.5 }
    /// ```
    ///
    /// ([`verify_overall_progress`] implements this formula.)
    VerifyProgress {
        phase: &'static str,
        bytes_read: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Informational log message (not an error).
    Log(String),
    /// The pipeline finished successfully.
    Done,
    /// The pipeline failed; the string is a human-readable error.
    Error(String),
}
388
389// ---------------------------------------------------------------------------
390// Unified frontend event type
391// ---------------------------------------------------------------------------
392
/// A normalised progress event suitable for consumption by any frontend
/// (Iced GUI, Ratatui TUI, or future integrations).
///
/// Both the GUI's `FlashProgress` and the TUI's `FlashEvent` wrapper types
/// were independently duplicating this shape.  By defining it once in core
/// and converting from [`FlashEvent`] via [`From`], each frontend only needs
/// to bridge this into its own message/command type.
///
/// Key differences from the raw [`FlashEvent`]:
/// - `Progress` carries a normalised `0.0–1.0` write fraction (the raw event
///   only carries raw byte counts; the fraction is computed here).
/// - `VerifyProgress` carries a pre-computed `overall` spanning both passes
///   (via [`verify_overall_progress`]) so frontends never need to duplicate
///   that formula.
/// - `Stage` and `Log` are both collapsed into `Message(String)` since both
///   frontends treat them as human-readable status text.
/// - `Done` / `Error` become `Completed` / `Failed` to match conventional
///   naming in UI code.
#[derive(Debug, Clone)]
pub enum FlashUpdate {
    /// Write progress.
    ///
    /// `progress` is `bytes_written / total_bytes` clamped to `[0.0, 1.0]`,
    /// or `0.0` when `total_bytes` was zero.
    Progress {
        progress: f32,
        bytes_written: u64,
        speed_mb_s: f32,
    },
    /// Verification read-back progress spanning both passes.
    ///
    /// `phase` is `"image"` or `"device"`, copied from the raw event.
    /// `overall` is in `[0.0, 1.0]`:
    ///   - image pass  → `[0.0, 0.5]`
    ///   - device pass → `[0.5, 1.0]`
    ///
    /// `bytes_read` / `total_bytes` remain per-pass counts.
    VerifyProgress {
        phase: &'static str,
        overall: f32,
        bytes_read: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Human-readable status text (stage label or log line).
    Message(String),
    /// The flash pipeline finished successfully.
    Completed,
    /// The flash pipeline failed; the string is a human-readable error.
    Failed(String),
}
440
441impl From<FlashEvent> for FlashUpdate {
442    /// Convert a raw pipeline [`FlashEvent`] into a [`FlashUpdate`].
443    ///
444    /// - `Progress` raw byte counts → normalised `0.0–1.0` fraction.
445    /// - `VerifyProgress` per-pass fraction → overall `0.0–1.0` via
446    ///   [`verify_overall_progress`].
447    /// - `Stage` display string and `Log` string → `Message`.
448    /// - `Done` → `Completed`, `Error` → `Failed`.
449    fn from(event: FlashEvent) -> Self {
450        match event {
451            FlashEvent::Progress {
452                bytes_written,
453                total_bytes,
454                speed_mb_s,
455            } => {
456                let progress = if total_bytes > 0 {
457                    (bytes_written as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
458                } else {
459                    0.0
460                };
461                FlashUpdate::Progress {
462                    progress,
463                    bytes_written,
464                    speed_mb_s,
465                }
466            }
467
468            FlashEvent::VerifyProgress {
469                phase,
470                bytes_read,
471                total_bytes,
472                speed_mb_s,
473            } => {
474                let pass_fraction = if total_bytes > 0 {
475                    (bytes_read as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
476                } else {
477                    0.0
478                };
479                let overall = verify_overall_progress(phase, pass_fraction);
480                FlashUpdate::VerifyProgress {
481                    phase,
482                    overall,
483                    bytes_read,
484                    total_bytes,
485                    speed_mb_s,
486                }
487            }
488
489            FlashEvent::Stage(stage) => FlashUpdate::Message(stage.to_string()),
490            FlashEvent::Log(msg) => FlashUpdate::Message(msg),
491            FlashEvent::Done => FlashUpdate::Completed,
492            FlashEvent::Error(e) => FlashUpdate::Failed(e),
493        }
494    }
495}
496
497// ---------------------------------------------------------------------------
498// Public entry point
499// ---------------------------------------------------------------------------
500
501/// Run the full flash pipeline in the **calling thread**.
502///
503/// This function is blocking and must be called from a dedicated
504/// `std::thread::spawn` thread, not from an async executor.
505///
506/// # Arguments
507///
508/// * `image_path`  – path to the source image file
509/// * `device_path` – path to the target block device (e.g. `/dev/sdb`)
510/// * `tx`          – channel to send [`FlashEvent`] progress updates
511/// * `cancel`      – set to `true` to abort the pipeline between blocks
512pub fn run_pipeline(
513    image_path: &str,
514    device_path: &str,
515    tx: mpsc::Sender<FlashEvent>,
516    cancel: Arc<AtomicBool>,
517) {
518    if let Err(e) = flash_pipeline(image_path, device_path, &tx, cancel) {
519        let _ = tx.send(FlashEvent::Error(e));
520    }
521}
522
523// ---------------------------------------------------------------------------
524// Top-level pipeline
525// ---------------------------------------------------------------------------
526
527fn send(tx: &mpsc::Sender<FlashEvent>, event: FlashEvent) {
528    // If the receiver is gone the GUI has been closed — ignore silently.
529    let _ = tx.send(event);
530}
531
532fn flash_pipeline(
533    image_path: &str,
534    device_path: &str,
535    tx: &mpsc::Sender<FlashEvent>,
536    cancel: Arc<AtomicBool>,
537) -> Result<(), String> {
538    // ── Validate inputs ──────────────────────────────────────────────────────
539    if !Path::new(image_path).is_file() {
540        return Err(format!("Image file not found: {image_path}"));
541    }
542
543    if !Path::new(device_path).exists() {
544        return Err(format!("Target device not found: {device_path}"));
545    }
546
547    // Guard against partition nodes (e.g. /dev/sdb1 instead of /dev/sdb).
548    #[cfg(target_os = "linux")]
549    reject_partition_node(device_path)?;
550
551    let image_size = std::fs::metadata(image_path)
552        .map_err(|e| format!("Cannot stat image: {e}"))?
553        .len();
554
555    if image_size == 0 {
556        return Err("Image file is empty".to_string());
557    }
558
559    // ── Step 1: Unmount ──────────────────────────────────────────────────────
560    send(tx, FlashEvent::Stage(FlashStage::Unmounting));
561    unmount_device(device_path, tx);
562
563    // ── Check device is not already in use ───────────────────────────────────
564    // Open the device O_RDONLY | O_EXCL *after* unmounting. If a partition was
565    // merely mounted beforehand the unmount above will have cleared it. If
566    // this still returns EBUSY it means a genuinely foreign process (e.g. a
567    // second flashkraft instance) has the device open for writing.
568    #[cfg(target_os = "linux")]
569    check_device_not_busy(device_path)?;
570
571    // ── Step 2: Write ────────────────────────────────────────────────────────
572    send(tx, FlashEvent::Stage(FlashStage::Writing));
573    send(
574        tx,
575        FlashEvent::Log(format!(
576            "Writing {image_size} bytes from {image_path} → {device_path}"
577        )),
578    );
579    write_image(image_path, device_path, image_size, tx, &cancel)?;
580
581    // ── Step 3: Sync ─────────────────────────────────────────────────────────
582    send(tx, FlashEvent::Stage(FlashStage::Syncing));
583    sync_device(device_path, tx);
584
585    // ── Step 4: Re-read partition table ──────────────────────────────────────
586    send(tx, FlashEvent::Stage(FlashStage::Rereading));
587    reread_partition_table(device_path, tx);
588
589    // ── Step 5: Verify ───────────────────────────────────────────────────────
590    send(tx, FlashEvent::Stage(FlashStage::Verifying));
591    verify(image_path, device_path, image_size, tx)?;
592
593    // ── Done ─────────────────────────────────────────────────────────────────
594    send(tx, FlashEvent::Done);
595    Ok(())
596}
597
598// ---------------------------------------------------------------------------
599// Device-busy guard (Linux only)
600// ---------------------------------------------------------------------------
601
602/// Try to open `device_path` with `O_RDONLY | O_EXCL`. On Linux this fails
603/// with `EBUSY` if a foreign process already holds the device open exclusively
604/// (e.g. a second flashkraft instance). Any other error is ignored here and
605/// will be caught properly when we open the device for writing.
606///
607/// This function is separate so it can be unit-tested by injecting a synthetic
608/// `io::Error`.
609#[cfg(target_os = "linux")]
610fn check_device_not_busy(device_path: &str) -> Result<(), String> {
611    check_device_not_busy_with(device_path, |path| {
612        use std::os::unix::fs::OpenOptionsExt;
613        std::fs::OpenOptions::new()
614            .read(true)
615            .custom_flags(libc::O_EXCL)
616            .open(path)
617            .map(|_| ())
618    })
619}
620
621/// Inner implementation — accepts an injectable opener so tests can supply a
622/// synthetic `EBUSY` without needing a real block device.
623#[cfg(target_os = "linux")]
624fn check_device_not_busy_with<F>(device_path: &str, open_fn: F) -> Result<(), String>
625where
626    F: FnOnce(&str) -> std::io::Result<()>,
627{
628    if let Err(e) = open_fn(device_path) {
629        if e.raw_os_error() == Some(libc::EBUSY) {
630            return Err(format!(
631                "Device '{device_path}' is already in use by another process.\n\
632                 Is another flash operation already running?"
633            ));
634        }
635        // Any other error (EPERM, EACCES) is fine — handled when opening for writing.
636    }
637    Ok(())
638}
639
640// ---------------------------------------------------------------------------
641// Partition-node guard (Linux only)
642// ---------------------------------------------------------------------------
643
/// Refuse to flash a partition node (e.g. `/dev/sdb1`) instead of a whole disk.
///
/// Symlinks such as `/dev/disk/by-id/…-part1` are resolved first so the real
/// kernel name is examined.
///
/// * If the kernel exposes the node under `/sys/class/block/<name>`, the
///   presence of a `partition` attribute decides authoritatively.  The
///   parent sysfs directory then names the whole disk.
/// * Otherwise (e.g. the path is a regular file used in tests) a name-based
///   heuristic is used.  It recognises both `sdb1`-style and
///   `nvme0n1p1` / `mmcblk0p1`-style partition names.  Unlike the previous
///   heuristic, it does not reject whole disks whose names end in a number
///   (`nvme0n1`, `mmcblk0`, `loop0`, …), and it suggests the right disk for
///   `p`-separated names (`nvme0n1p1` → `nvme0n1`).
///
/// # Errors
///
/// Returns a human-readable message suggesting the whole-disk device when
/// `device_path` is (or looks like) a partition.
#[cfg(target_os = "linux")]
fn reject_partition_node(device_path: &str) -> Result<(), String> {
    /// Name-based fallback: does kernel block-device name `name` look like a
    /// partition?
    fn looks_like_partition(name: &str) -> bool {
        let stem = name.trim_end_matches(|c: char| c.is_ascii_digit());
        // Partition names always end in the partition number.
        if stem.is_empty() || stem.len() == name.len() {
            return false;
        }
        // `<disk>p<N>` where the disk name itself ends in a digit
        // (nvme0n1p1, mmcblk0p2, loop0p1, md127p1).
        if let Some(base) = stem.strip_suffix('p') {
            if base.ends_with(|c: char| c.is_ascii_digit()) {
                return true;
            }
        }
        // Whole-disk families whose own names end in a number.
        const NUMBERED_DISKS: [&str; 9] =
            ["nvme", "mmcblk", "loop", "md", "sr", "nbd", "zram", "ram", "dm-"];
        if NUMBERED_DISKS.iter().any(|&prefix| name.starts_with(prefix)) {
            return false;
        }
        // Letter-indexed disks: sdb1, vda2, hdc3, xvdf1.
        stem.chars().any(|c| c.is_ascii_alphabetic())
    }

    /// Name-based fallback: derive the whole-disk name from a partition name.
    fn whole_disk_from_name(name: &str) -> String {
        let stem = name.trim_end_matches(|c: char| c.is_ascii_digit());
        match stem.strip_suffix('p') {
            Some(base) if base.ends_with(|c: char| c.is_ascii_digit()) => base.to_string(),
            _ => stem.to_string(),
        }
    }

    /// Parent disk of a sysfs partition entry.  For example,
    /// `/sys/class/block/sdb1` resolves to `…/block/sdb/sdb1`, whose parent
    /// is `sdb`.
    fn sysfs_parent_disk(entry: &Path) -> Option<String> {
        let real = entry.canonicalize().ok()?;
        let disk = real.parent()?.file_name()?;
        Some(disk.to_string_lossy().into_owned())
    }

    let given = Path::new(device_path);
    let resolved = given.canonicalize().unwrap_or_else(|_| given.to_path_buf());
    let dev_name = resolved
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_default();
    if dev_name.is_empty() {
        return Ok(());
    }

    let sysfs_entry = Path::new("/sys/class/block").join(&dev_name);
    let whole_disk = if sysfs_entry.exists() {
        // Authoritative: only partition entries carry a `partition` attribute.
        sysfs_entry.join("partition").exists().then(|| {
            sysfs_parent_disk(&sysfs_entry).unwrap_or_else(|| whole_disk_from_name(&dev_name))
        })
    } else if looks_like_partition(&dev_name) {
        Some(whole_disk_from_name(&dev_name))
    } else {
        None
    };

    match whole_disk {
        Some(whole) => Err(format!(
            "Refusing to write to partition node '{device_path}'. \
             Select the whole-disk device (e.g. /dev/{whole}) instead."
        )),
        None => Ok(()),
    }
}
672
673// ---------------------------------------------------------------------------
674// Privilege helpers
675// ---------------------------------------------------------------------------
676
677/// Open `device_path` for raw writing, temporarily escalating to root if the
678/// binary is setuid-root, then immediately dropping back to the real UID.
679fn open_device_for_writing(device_path: &str) -> Result<std::fs::File, String> {
680    #[cfg(unix)]
681    {
682        use nix::unistd::seteuid;
683
684        // Attempt to escalate to root.
685        //
686        // This only succeeds when the binary carries the setuid-root bit
687        // (`chmod u+s`).  If escalation fails we still try to open the file —
688        // it may be a regular writable file (e.g. during tests) or the user
689        // may already have write permission on the device.
690        let escalated = seteuid(nix::unistd::Uid::from_raw(0)).is_ok();
691
692        let result = std::fs::OpenOptions::new()
693            .write(true)
694            .open(device_path)
695            .map_err(|e| {
696                let raw = e.raw_os_error().unwrap_or(0);
697                if raw == libc::EACCES || raw == libc::EPERM {
698                    if escalated {
699                        format!(
700                            "Permission denied opening '{device_path}'.\n\
701                             Even with setuid-root the device refused access — \
702                             check that the device exists and is not in use."
703                        )
704                    } else {
705                        format!(
706                            "Permission denied opening '{device_path}'.\n\
707                             FlashKraft needs root access to write to block devices.\n\
708                             Install setuid-root so it can escalate automatically:\n\
709                             sudo chown root:root /usr/bin/flashkraft\n\
710                             sudo chmod u+s /usr/bin/flashkraft"
711                        )
712                    }
713                } else if raw == libc::EBUSY {
714                    format!(
715                        "Device '{device_path}' is busy. \
716                         Ensure all partitions are unmounted before flashing."
717                    )
718                } else {
719                    format!("Cannot open device '{device_path}' for writing: {e}")
720                }
721            });
722
723        // Drop back to the real (unprivileged) user immediately.
724        if escalated {
725            let _ = seteuid(real_uid());
726        }
727
728        result
729    }
730
731    #[cfg(not(unix))]
732    {
733        std::fs::OpenOptions::new()
734            .write(true)
735            .open(device_path)
736            .map_err(|e| {
737                let raw = e.raw_os_error().unwrap_or(0);
738                // ERROR_ACCESS_DENIED (5) or ERROR_PRIVILEGE_NOT_HELD (1314)
739                if raw == 5 || raw == 1314 {
740                    format!(
741                        "Access denied opening '{device_path}'.\n\
742                         FlashKraft must be run as Administrator on Windows.\n\
743                         Right-click the application and choose \
744                         'Run as administrator'."
745                    )
746                } else if raw == 32 {
747                    // ERROR_SHARING_VIOLATION
748                    format!(
749                        "Device '{device_path}' is in use by another process.\n\
750                         Close any applications using the drive and try again."
751                    )
752                } else {
753                    format!("Cannot open device '{device_path}' for writing: {e}")
754                }
755            })
756    }
757}
758
759// ---------------------------------------------------------------------------
760// Step 1 – Unmount
761// ---------------------------------------------------------------------------
762
763fn unmount_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
764    let device_name = Path::new(device_path)
765        .file_name()
766        .map(|n| n.to_string_lossy().to_string())
767        .unwrap_or_default();
768
769    let partitions = find_mounted_partitions(&device_name, device_path);
770
771    if partitions.is_empty() {
772        send(tx, FlashEvent::Log("No mounted partitions found".into()));
773    } else {
774        for partition in &partitions {
775            send(tx, FlashEvent::Log(format!("Unmounting {partition}")));
776            do_unmount(partition, tx);
777        }
778    }
779}
780
781/// Returns the list of mounted partitions/volumes that belong to `device_path`.
782///
783/// On Linux/macOS this parses `/proc/mounts`.
784/// On Windows this enumerates logical drive letters, resolves each to its
785/// underlying physical device via `QueryDosDeviceW`, and returns the volume
786/// paths (e.g. `\\.\C:`) whose physical device number matches `device_path`
787/// (e.g. `\\.\PhysicalDrive1`).
788fn find_mounted_partitions(
789    #[cfg_attr(target_os = "windows", allow(unused_variables))] device_name: &str,
790    device_path: &str,
791) -> Vec<String> {
792    #[cfg(not(target_os = "windows"))]
793    {
794        let mounts = std::fs::read_to_string("/proc/mounts")
795            .or_else(|_| std::fs::read_to_string("/proc/self/mounts"))
796            .unwrap_or_default();
797
798        let mut mount_points = Vec::new();
799        for line in mounts.lines() {
800            let mut fields = line.split_whitespace();
801            let dev = match fields.next() {
802                Some(d) => d,
803                None => continue,
804            };
805            // Second field in /proc/mounts is the mount point directory —
806            // that is what umount2 requires, not the device path.
807            let mount_point = match fields.next() {
808                Some(m) => m,
809                None => continue,
810            };
811            if dev == device_path || is_partition_of(dev, device_name) {
812                mount_points.push(mount_point.to_string());
813            }
814        }
815        mount_points
816    }
817
818    #[cfg(target_os = "windows")]
819    {
820        windows::find_volumes_on_physical_drive(device_path)
821    }
822}
823
/// Return `true` if `dev` names a partition of the whole-disk device
/// `device_name` (e.g. `sda` → `sda1`, `nvme0n1` → `nvme0n1p2`).
///
/// `dev` may be a full path such as `/dev/sda1`; only its basename is
/// compared.  Follows the Linux kernel's partition-naming rule:
/// * when the disk name ends in a digit, partitions are `<disk>p<N>`;
/// * otherwise they are `<disk><N>`.
///
/// This stops `loop1` from claiming `loop10`, and `nvme0n1` from claiming
/// `nvme0n10`.  Those are separate disks, and unmounting them would touch the
/// wrong device.
#[cfg(not(target_os = "windows"))]
fn is_partition_of(dev: &str, device_name: &str) -> bool {
    if device_name.is_empty() {
        return false;
    }

    // `dev` may be a full path like "/dev/sda1"; compare only the basename.
    let dev_base = Path::new(dev)
        .file_name()
        .map(|n| n.to_string_lossy())
        .unwrap_or_default();

    let Some(suffix) = dev_base.strip_prefix(device_name) else {
        return false;
    };

    // Disk names ending in a digit need the `p` separator before the number.
    let digits = if device_name.ends_with(|c: char| c.is_ascii_digit()) {
        match suffix.strip_prefix('p') {
            Some(rest) => rest,
            None => return false,
        }
    } else {
        suffix
    };

    !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
}
842
/// Report whether an executable regular file called `name` exists in any
/// `:`-separated directory of `$PATH`.  An unset or non-UTF-8 `PATH` is
/// treated as empty.  `do_unmount` uses this to decide whether `udisksctl`
/// can be tried.
#[cfg(target_os = "linux")]
fn which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;

    let search_path = std::env::var("PATH").unwrap_or_default();
    for dir in search_path.split(':') {
        let candidate = std::path::Path::new(dir).join(name);
        let is_executable = match std::fs::metadata(&candidate) {
            Ok(meta) => meta.is_file() && meta.permissions().mode() & 0o111 != 0,
            Err(_) => false,
        };
        if is_executable {
            return true;
        }
    }
    false
}
858
859fn do_unmount(partition: &str, tx: &mpsc::Sender<FlashEvent>) {
860    #[cfg(target_os = "linux")]
861    {
862        use nix::unistd::seteuid;
863        use std::ffi::CString;
864
865        // ── Strategy 1: udisksctl ─────────────────────────────────────────────
866        // Prefer udisksctl when available — it signals udisks2/systemd-mount to
867        // properly release the mount and prevents the automounter from
868        // immediately re-mounting the partition after we detach it.
869        // `--no-user-interaction` prevents it from blocking on a password prompt.
870        if which_exists("udisksctl") {
871            // Spawn with a timeout — udisksctl can stall if udisks2 is busy.
872            // We give it 5 seconds before falling through to umount2.
873            let result = std::process::Command::new("udisksctl")
874                .args(["unmount", "--no-user-interaction", "-b", partition])
875                .stdout(std::process::Stdio::null())
876                .stderr(std::process::Stdio::null())
877                .spawn();
878
879            let udisks_ok = match result {
880                Ok(mut child) => {
881                    // Poll for up to 5 s in 100 ms increments.
882                    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
883                    loop {
884                        match child.try_wait() {
885                            Ok(Some(status)) => break status.success(),
886                            Ok(None) if std::time::Instant::now() < deadline => {
887                                std::thread::sleep(std::time::Duration::from_millis(100));
888                            }
889                            _ => {
890                                // Timed out or error — kill and fall through.
891                                let _ = child.kill();
892                                send(
893                                    tx,
894                                    FlashEvent::Log(
895                                        "udisksctl timed out — falling back to umount2".into(),
896                                    ),
897                                );
898                                break false;
899                            }
900                        }
901                    }
902                }
903                Err(_) => false,
904            };
905
906            if udisks_ok {
907                send(
908                    tx,
909                    FlashEvent::Log(format!("Unmounted {partition} via udisksctl")),
910                );
911                return;
912            }
913        }
914
915        // ── Strategy 2: umount2 MNT_DETACH (fallback) ────────────────────────
916        // Lazy unmount: detaches the filesystem immediately even if busy.
917        // umount2 never blocks — MNT_DETACH returns right away.
918        let _ = seteuid(nix::unistd::Uid::from_raw(0));
919
920        if let Ok(c_path) = CString::new(partition) {
921            let ret = unsafe { libc::umount2(c_path.as_ptr(), libc::MNT_DETACH) };
922            if ret != 0 {
923                let raw = std::io::Error::last_os_error().raw_os_error().unwrap_or(0);
924                match raw {
925                    // EINVAL — not a mount point or already unmounted, harmless.
926                    libc::EINVAL => {}
927                    // ENOENT — path doesn't exist, also harmless.
928                    libc::ENOENT => {}
929                    _ => {
930                        let err = std::io::Error::from_raw_os_error(raw);
931                        send(
932                            tx,
933                            FlashEvent::Log(format!(
934                                "Warning — could not unmount {partition}: {err}"
935                            )),
936                        );
937                    }
938                }
939            }
940        }
941
942        let _ = seteuid(real_uid());
943    }
944
945    #[cfg(target_os = "macos")]
946    {
947        let out = std::process::Command::new("diskutil")
948            .args(["unmount", partition])
949            .output();
950        if let Ok(o) = out {
951            if !o.status.success() {
952                send(
953                    tx,
954                    FlashEvent::Log(format!("Warning — diskutil unmount {partition} failed")),
955                );
956            }
957        }
958    }
959
960    // Windows: open the volume with exclusive access, lock it, then dismount.
961    // The volume path is expected to be of the form `\\.\C:` (no trailing slash).
962    #[cfg(target_os = "windows")]
963    {
964        match windows::lock_and_dismount_volume(partition) {
965            Ok(()) => send(
966                tx,
967                FlashEvent::Log(format!("Dismounted volume {partition}")),
968            ),
969            Err(e) => send(
970                tx,
971                FlashEvent::Log(format!("Warning — could not dismount {partition}: {e}")),
972            ),
973        }
974    }
975}
976
977// ---------------------------------------------------------------------------
978// Step 2 – Write image
979// ---------------------------------------------------------------------------
980
981fn write_image(
982    image_path: &str,
983    device_path: &str,
984    image_size: u64,
985    tx: &mpsc::Sender<FlashEvent>,
986    cancel: &Arc<AtomicBool>,
987) -> Result<(), String> {
988    let image_file =
989        std::fs::File::open(image_path).map_err(|e| format!("Cannot open image: {e}"))?;
990
991    let device_file = open_device_for_writing(device_path)?;
992
993    let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, image_file);
994    let mut writer = io::BufWriter::with_capacity(BLOCK_SIZE, device_file);
995    let mut buf = vec![0u8; BLOCK_SIZE];
996
997    let mut bytes_written: u64 = 0;
998    let start = Instant::now();
999    let mut last_report = Instant::now();
1000
1001    loop {
1002        // Honour cancellation requests between blocks.
1003        if cancel.load(Ordering::SeqCst) {
1004            return Err("Flash operation cancelled by user".to_string());
1005        }
1006
1007        let n = reader
1008            .read(&mut buf)
1009            .map_err(|e| format!("Read error on image: {e}"))?;
1010
1011        if n == 0 {
1012            break; // EOF
1013        }
1014
1015        writer
1016            .write_all(&buf[..n])
1017            .map_err(|e| format!("Write error on device: {e}"))?;
1018
1019        bytes_written += n as u64;
1020
1021        let now = Instant::now();
1022        if now.duration_since(last_report) >= PROGRESS_INTERVAL || bytes_written >= image_size {
1023            let elapsed_s = now.duration_since(start).as_secs_f32();
1024            let speed_mb_s = if elapsed_s > 0.001 {
1025                (bytes_written as f32 / (1024.0 * 1024.0)) / elapsed_s
1026            } else {
1027                0.0
1028            };
1029
1030            send(
1031                tx,
1032                FlashEvent::Progress {
1033                    bytes_written,
1034                    total_bytes: image_size,
1035                    speed_mb_s,
1036                },
1037            );
1038            last_report = now;
1039        }
1040    }
1041
1042    // Flush BufWriter → kernel page cache.
1043    writer
1044        .flush()
1045        .map_err(|e| format!("Buffer flush error: {e}"))?;
1046
1047    // Retrieve the underlying File for fsync.
1048    #[cfg_attr(not(unix), allow(unused_variables))]
1049    let device_file = writer
1050        .into_inner()
1051        .map_err(|e| format!("BufWriter error: {e}"))?;
1052
1053    // fsync: push all dirty pages to the physical medium.
1054    // Treated as a hard error — a failed fsync means we cannot trust the
1055    // data reached the device.
1056    #[cfg(unix)]
1057    {
1058        use std::os::unix::io::AsRawFd;
1059        let fd = device_file.as_raw_fd();
1060        let ret = unsafe { libc::fsync(fd) };
1061        if ret != 0 {
1062            let err = std::io::Error::last_os_error();
1063            return Err(format!(
1064                "fsync failed on '{device_path}': {err} — \
1065                 data may not have been fully written to the device"
1066            ));
1067        }
1068    }
1069
1070    // Emit a final progress event at 100 %.
1071    let elapsed_s = start.elapsed().as_secs_f32();
1072    let speed_mb_s = if elapsed_s > 0.001 {
1073        (bytes_written as f32 / (1024.0 * 1024.0)) / elapsed_s
1074    } else {
1075        0.0
1076    };
1077    send(
1078        tx,
1079        FlashEvent::Progress {
1080            bytes_written,
1081            total_bytes: image_size,
1082            speed_mb_s,
1083        },
1084    );
1085
1086    send(tx, FlashEvent::Log("Image write complete".into()));
1087    Ok(())
1088}
1089
1090// ---------------------------------------------------------------------------
1091// Step 3 – Sync
1092// ---------------------------------------------------------------------------
1093
1094fn sync_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
1095    #[cfg(unix)]
1096    if let Ok(f) = std::fs::OpenOptions::new().write(true).open(device_path) {
1097        use std::os::unix::io::AsRawFd;
1098        let fd = f.as_raw_fd();
1099        #[cfg(target_os = "linux")]
1100        unsafe {
1101            libc::fdatasync(fd);
1102        }
1103        #[cfg(not(target_os = "linux"))]
1104        unsafe {
1105            libc::fsync(fd);
1106        }
1107        drop(f);
1108    }
1109
1110    #[cfg(target_os = "linux")]
1111    unsafe {
1112        libc::sync();
1113    }
1114
1115    // Windows: open the physical drive and call FlushFileBuffers.
1116    // This forces the OS to flush all dirty pages for the device to hardware.
1117    #[cfg(target_os = "windows")]
1118    {
1119        match windows::flush_device_buffers(device_path) {
1120            Ok(()) => {}
1121            Err(e) => send(
1122                tx,
1123                FlashEvent::Log(format!(
1124                    "Warning — FlushFileBuffers on '{device_path}' failed: {e}"
1125                )),
1126            ),
1127        }
1128    }
1129
1130    send(tx, FlashEvent::Log("Write-back caches flushed".into()));
1131}
1132
1133// ---------------------------------------------------------------------------
1134// Step 4 – Re-read partition table
1135// ---------------------------------------------------------------------------
1136
1137#[cfg(target_os = "linux")]
1138fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
1139    use nix::ioctl_none;
1140    use std::os::unix::io::AsRawFd;
1141
1142    ioctl_none!(blkrrpart, 0x12, 95);
1143
1144    // Brief pause so any pending I/O completes before we poke the kernel.
1145    std::thread::sleep(Duration::from_millis(500));
1146
1147    match std::fs::OpenOptions::new().write(true).open(device_path) {
1148        Ok(f) => {
1149            let result = unsafe { blkrrpart(f.as_raw_fd()) };
1150            match result {
1151                Ok(_) => send(
1152                    tx,
1153                    FlashEvent::Log("Kernel partition table refreshed".into()),
1154                ),
1155                Err(e) => send(
1156                    tx,
1157                    FlashEvent::Log(format!(
1158                        "Warning — BLKRRPART ioctl failed \
1159                         (device may not be partitioned): {e}"
1160                    )),
1161                ),
1162            }
1163        }
1164        Err(e) => send(
1165            tx,
1166            FlashEvent::Log(format!(
1167                "Warning — could not open device for BLKRRPART: {e}"
1168            )),
1169        ),
1170    }
1171}
1172
/// macOS: ask `diskutil` to refresh its view of the partition table.
///
/// Best-effort.  Success is logged only when `diskutil` actually ran and
/// exited successfully.  A spawn failure or a non-zero exit is reported as a
/// `Warning — …` log line instead of being silently discarded.
#[cfg(target_os = "macos")]
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let outcome = std::process::Command::new("diskutil")
        .args(["rereadPartitionTable", device_path])
        .output();
    let message = match outcome {
        Ok(out) if out.status.success() => {
            String::from("Partition table refresh requested (macOS)")
        }
        Ok(out) => format!(
            "Warning — diskutil rereadPartitionTable {device_path} exited with {}",
            out.status
        ),
        Err(e) => format!("Warning — could not run diskutil: {e}"),
    };
    send(tx, FlashEvent::Log(message));
}
1183
/// Windows: send `IOCTL_DISK_UPDATE_PROPERTIES` so the partition manager
/// re-enumerates the partition table from the on-disk data.  A failure is
/// logged as a warning only.
#[cfg(target_os = "windows")]
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    // Let pending writes drain before asking the partition manager to rescan.
    std::thread::sleep(Duration::from_millis(500));

    let message = if let Err(e) = windows::update_disk_properties(device_path) {
        format!("Warning — IOCTL_DISK_UPDATE_PROPERTIES failed: {e}")
    } else {
        String::from("Partition table refreshed (IOCTL_DISK_UPDATE_PROPERTIES)")
    };
    send(tx, FlashEvent::Log(message));
}
1204
/// Fallback for targets with no partition-table refresh mechanism: the
/// step is skipped and the user is told so.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
fn reread_partition_table(_device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let note = String::from("Partition table refresh not supported on this platform");
    send(tx, FlashEvent::Log(note));
}
1212
1213// ---------------------------------------------------------------------------
1214// Step 5 – Verify
1215// ---------------------------------------------------------------------------
1216
1217fn verify(
1218    image_path: &str,
1219    device_path: &str,
1220    image_size: u64,
1221    tx: &mpsc::Sender<FlashEvent>,
1222) -> Result<(), String> {
1223    send(
1224        tx,
1225        FlashEvent::Log("Computing SHA-256 of source image".into()),
1226    );
1227    let image_hash = sha256_with_progress(image_path, image_size, "image", tx)?;
1228
1229    send(
1230        tx,
1231        FlashEvent::Log(format!(
1232            "Reading back {image_size} bytes from device for verification"
1233        )),
1234    );
1235    let device_hash = sha256_with_progress(device_path, image_size, "device", tx)?;
1236
1237    if image_hash != device_hash {
1238        return Err(format!(
1239            "Verification failed — data mismatch \
1240             (image={image_hash} device={device_hash})"
1241        ));
1242    }
1243
1244    send(
1245        tx,
1246        FlashEvent::Log(format!("Verification passed ({image_hash})")),
1247    );
1248    Ok(())
1249}
1250
1251/// Compute the SHA-256 digest of the first `max_bytes` of `path`, emitting
1252/// [`FlashEvent::VerifyProgress`] events at [`PROGRESS_INTERVAL`] intervals.
1253///
1254/// `phase` is forwarded verbatim into every `VerifyProgress` event so the
1255/// UI can distinguish the image-hash pass (`"image"`) from the device
1256/// read-back pass (`"device"`).
1257fn sha256_with_progress(
1258    path: &str,
1259    max_bytes: u64,
1260    phase: &'static str,
1261    tx: &mpsc::Sender<FlashEvent>,
1262) -> Result<String, String> {
1263    use sha2::{Digest, Sha256};
1264
1265    let file =
1266        std::fs::File::open(path).map_err(|e| format!("Cannot open {path} for hashing: {e}"))?;
1267
1268    let mut hasher = Sha256::new();
1269    let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, file);
1270    let mut buf = vec![0u8; BLOCK_SIZE];
1271    let mut remaining = max_bytes;
1272    let mut bytes_read: u64 = 0;
1273
1274    let start = Instant::now();
1275    let mut last_report = Instant::now();
1276
1277    while remaining > 0 {
1278        let to_read = (remaining as usize).min(buf.len());
1279        let n = reader
1280            .read(&mut buf[..to_read])
1281            .map_err(|e| format!("Read error while hashing {path}: {e}"))?;
1282        if n == 0 {
1283            break;
1284        }
1285        hasher.update(&buf[..n]);
1286        bytes_read += n as u64;
1287        remaining -= n as u64;
1288
1289        let now = Instant::now();
1290        if now.duration_since(last_report) >= PROGRESS_INTERVAL || remaining == 0 {
1291            let elapsed_s = now.duration_since(start).as_secs_f32();
1292            let speed_mb_s = if elapsed_s > 0.001 {
1293                (bytes_read as f32 / (1024.0 * 1024.0)) / elapsed_s
1294            } else {
1295                0.0
1296            };
1297            send(
1298                tx,
1299                FlashEvent::VerifyProgress {
1300                    phase,
1301                    bytes_read,
1302                    total_bytes: max_bytes,
1303                    speed_mb_s,
1304                },
1305            );
1306            last_report = now;
1307        }
1308    }
1309
1310    Ok(format!("{:x}", hasher.finalize()))
1311}
1312
/// Test-only convenience wrapper around [`sha256_with_progress`].  Progress
/// events are tagged as the `"image"` phase and never read.
#[cfg(test)]
fn sha256_first_n_bytes(path: &str, max_bytes: u64) -> Result<String, String> {
    // The receiver stays bound (unread) until the hash call returns.
    let (sender, _receiver) = mpsc::channel::<FlashEvent>();
    sha256_with_progress(path, max_bytes, "image", &sender)
}
1319
1320// ---------------------------------------------------------------------------
1321// Windows implementation helpers
1322// ---------------------------------------------------------------------------
1323
/// All Windows-specific raw-device operations are collected here.
///
/// ## Privilege
/// The binary must be run as Administrator.  The UAC manifest embedded by
/// `build.rs` makes Windows prompt for elevation on launch.  Raw physical
/// drive access (`\\.\PhysicalDriveN`) and volume lock/dismount both require
/// the `SeManageVolumePrivilege` that is only present in an elevated token.
///
/// ## Volume vs physical drive paths
/// - **Physical drive**: `\\.\PhysicalDrive0`, `\\.\PhysicalDrive1`, …
///   Used for writing the image, flushing, and partition-table refresh.
/// - **Volume (drive letter)**: `\\.\C:`, `\\.\D:`, …
///   Used for locking and dismounting before we write.
#[cfg(target_os = "windows")]
mod windows {
    // ── Win32 type aliases ────────────────────────────────────────────────────
    // windows-sys uses raw C types; give them readable names.
    use windows_sys::Win32::{
        Foundation::{
            CloseHandle, FALSE, GENERIC_READ, GENERIC_WRITE, HANDLE, INVALID_HANDLE_VALUE,
        },
        Storage::FileSystem::{
            CreateFileW, FlushFileBuffers, FILE_FLAG_WRITE_THROUGH, FILE_SHARE_READ,
            FILE_SHARE_WRITE, OPEN_EXISTING,
        },
        System::{
            Ioctl::{FSCTL_DISMOUNT_VOLUME, FSCTL_LOCK_VOLUME, IOCTL_DISK_UPDATE_PROPERTIES},
            IO::DeviceIoControl,
        },
    };

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Encode a Rust `&str` as a null-terminated UTF-16 `Vec<u16>`.
    fn to_wide(s: &str) -> Vec<u16> {
        use std::os::windows::ffi::OsStrExt;
        std::ffi::OsStr::new(s)
            .encode_wide()
            .chain(std::iter::once(0))
            .collect()
    }

    /// Open a device path (`\\.\PhysicalDriveN` or `\\.\C:`) and return its
    /// Win32 `HANDLE`.  The handle must be closed with `CloseHandle` when done.
    ///
    /// `access` should be `GENERIC_READ`, `GENERIC_WRITE`, or both OR-ed.
    fn open_device_handle(path: &str, access: u32) -> Result<HANDLE, String> {
        let wide = to_wide(path);
        let handle = unsafe {
            CreateFileW(
                wide.as_ptr(),
                access,
                FILE_SHARE_READ | FILE_SHARE_WRITE,
                std::ptr::null(),
                OPEN_EXISTING,
                FILE_FLAG_WRITE_THROUGH,
                std::ptr::null_mut(),
            )
        };
        if handle == INVALID_HANDLE_VALUE {
            Err(format!(
                "Cannot open device '{}': {}",
                path,
                std::io::Error::last_os_error()
            ))
        } else {
            Ok(handle)
        }
    }

    /// Issue a simple `DeviceIoControl` call with no input or output buffer.
    ///
    /// Returns `Ok(())` on success, or an `Err` with the Win32 error message.
    fn device_ioctl(handle: HANDLE, code: u32) -> Result<(), String> {
        let mut bytes_returned: u32 = 0;
        let ok = unsafe {
            DeviceIoControl(
                handle,
                code,
                std::ptr::null(), // no input buffer
                0,
                std::ptr::null_mut(), // no output buffer
                0,
                &mut bytes_returned,
                std::ptr::null_mut(), // synchronous (no OVERLAPPED)
            )
        };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    // ── Public helpers called from flash_helper ───────────────────────────────

    /// Enumerate all logical drive letters whose underlying physical device
    /// path matches `physical_drive` (e.g. `\\.\PhysicalDrive1`).
    ///
    /// Returns a list of volume paths suitable for passing to
    /// `lock_and_dismount_volume`, e.g. `["\\.\C:", "\\.\D:"]`.
    ///
    /// Algorithm:
    /// 1. Obtain the physical drive number from `physical_drive`.
    /// 2. Call `GetLogicalDriveStringsW` to list all drive letters.
    /// 3. For each letter, open the volume and call `IOCTL_STORAGE_GET_DEVICE_NUMBER`
    ///    to get its physical drive number.
    /// 4. Collect those whose number matches.
    pub fn find_volumes_on_physical_drive(physical_drive: &str) -> Vec<String> {
        use windows_sys::Win32::{
            Storage::FileSystem::GetLogicalDriveStringsW,
            System::Ioctl::{IOCTL_STORAGE_GET_DEVICE_NUMBER, STORAGE_DEVICE_NUMBER},
        };

        // Extract the drive index from "\\.\PhysicalDriveN".
        let target_index: u32 = physical_drive
            .to_ascii_lowercase()
            .trim_start_matches(r"\\.\physicaldrive")
            .parse()
            .unwrap_or(u32::MAX);

        // Get all logical drive strings ("C:\", "D:\", …).
        let mut buf = vec![0u16; 512];
        let len = unsafe { GetLogicalDriveStringsW(buf.len() as u32, buf.as_mut_ptr()) };
        if len == 0 || len > buf.len() as u32 {
            return Vec::new();
        }

        // Parse the null-separated, double-null-terminated list.
        let drive_letters: Vec<String> = buf[..len as usize]
            .split(|&c| c == 0)
            .filter(|s| !s.is_empty())
            .map(|s| {
                // "C:\" → "\\.\C:"  (no trailing backslash — required for
                // CreateFileW on a volume)
                let letter: String = std::char::from_u32(s[0] as u32)
                    .map(|c| c.to_string())
                    .unwrap_or_default();
                format!(r"\\.\{}:", letter)
            })
            .collect();

        let mut matching = Vec::new();

        for vol_path in &drive_letters {
            let wide = to_wide(vol_path);
            let handle = unsafe {
                CreateFileW(
                    wide.as_ptr(),
                    GENERIC_READ,
                    FILE_SHARE_READ | FILE_SHARE_WRITE,
                    std::ptr::null(),
                    OPEN_EXISTING,
                    0,
                    std::ptr::null_mut(),
                )
            };
            if handle == INVALID_HANDLE_VALUE {
                continue;
            }

            let mut dev_num = STORAGE_DEVICE_NUMBER {
                DeviceType: 0,
                DeviceNumber: u32::MAX,
                PartitionNumber: 0,
            };
            let mut bytes_returned: u32 = 0;

            let ok = unsafe {
                DeviceIoControl(
                    handle,
                    IOCTL_STORAGE_GET_DEVICE_NUMBER,
                    std::ptr::null(),
                    0,
                    &mut dev_num as *mut _ as *mut _,
                    std::mem::size_of::<STORAGE_DEVICE_NUMBER>() as u32,
                    &mut bytes_returned,
                    std::ptr::null_mut(),
                )
            };

            unsafe { CloseHandle(handle) };

            if ok != FALSE && dev_num.DeviceNumber == target_index {
                matching.push(vol_path.clone());
            }
        }

        matching
    }

    /// Lock a volume exclusively and dismount it, so writes to the underlying
    /// physical disk can proceed without the filesystem intercepting I/O.
    ///
    /// Steps (mirrors what Rufus / dd for Windows do):
    /// 1. Open the volume with `GENERIC_READ | GENERIC_WRITE`.
    /// 2. `FSCTL_LOCK_VOLUME`   — exclusive lock; fails if files are open.
    /// 3. `FSCTL_DISMOUNT_VOLUME` — tell the FS driver to flush and detach.
    ///
    /// The lock is held for the lifetime of the handle.  Because we close the
    /// handle immediately after dismounting, the volume is automatically
    /// unlocked (`FSCTL_UNLOCK_VOLUME` is implicit on handle close).
    /// NOTE(review): an unlocked, dismounted volume may be remounted by the
    /// next access to it; confirm this cannot race with the raw write.
    ///
    /// # Errors
    /// Returns `Err` if the volume cannot be opened or if
    /// `FSCTL_DISMOUNT_VOLUME` fails.  A failed lock is deliberately
    /// non-fatal: it is only reported on stderr, and the function still
    /// returns `Ok(())` if the dismount succeeds.
    pub fn lock_and_dismount_volume(volume_path: &str) -> Result<(), String> {
        let handle = open_device_handle(volume_path, GENERIC_READ | GENERIC_WRITE)?;

        // Lock — exclusive; if this fails (files open) we still try to
        // dismount because the user may have opened Explorer on the drive.
        if let Err(e) = device_ioctl(handle, FSCTL_LOCK_VOLUME) {
            eprintln!(
                "[flash] FSCTL_LOCK_VOLUME on '{volume_path}' failed ({e}); \
                 attempting dismount anyway"
            );
        }

        // Dismount — detaches the filesystem; flushes dirty data first.
        let dismount_result = device_ioctl(handle, FSCTL_DISMOUNT_VOLUME);

        unsafe { CloseHandle(handle) };

        // Only the dismount decides success; reporting a lock failure here
        // would make callers log "could not dismount" after a successful dismount.
        dismount_result
    }

    /// Call `FlushFileBuffers` on the physical drive to force the OS to push
    /// all dirty write-back pages to the device hardware.
    pub fn flush_device_buffers(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_WRITE)?;
        let ok = unsafe { FlushFileBuffers(handle) };
        unsafe { CloseHandle(handle) };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    /// Send `IOCTL_DISK_UPDATE_PROPERTIES` to the physical drive, asking the
    /// Windows partition manager to re-read the partition table from disk.
    pub fn update_disk_properties(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_READ | GENERIC_WRITE)?;
        let result = device_ioctl(handle, IOCTL_DISK_UPDATE_PROPERTIES);
        unsafe { CloseHandle(handle) };
        result
    }

    // ── Unit tests ────────────────────────────────────────────────────────────

    #[cfg(test)]
    mod tests {
        use super::*;

        /// `to_wide` must produce a null-terminated UTF-16 sequence.
        #[test]
        fn test_to_wide_null_terminated() {
            let wide = to_wide("ABC");
            assert_eq!(wide.last(), Some(&0u16), "must be null-terminated");
            assert_eq!(&wide[..3], &[b'A' as u16, b'B' as u16, b'C' as u16]);
        }

        /// `to_wide` on an empty string produces exactly one null.
        #[test]
        fn test_to_wide_empty() {
            let wide = to_wide("");
            assert_eq!(wide, vec![0u16]);
        }

        /// `open_device_handle` on a nonexistent path must return an error.
        #[test]
        fn test_open_device_handle_bad_path_returns_error() {
            let result = open_device_handle(r"\\.\NonExistentDevice999", GENERIC_READ);
            assert!(result.is_err(), "expected error for nonexistent device");
        }

        /// `flush_device_buffers` on a nonexistent drive must return an error.
        #[test]
        fn test_flush_device_buffers_bad_path() {
            let result = flush_device_buffers(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }

        /// `update_disk_properties` on a nonexistent drive must return an error.
        #[test]
        fn test_update_disk_properties_bad_path() {
            let result = update_disk_properties(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }

        /// `lock_and_dismount_volume` on a nonexistent path must return an error.
        #[test]
        fn test_lock_and_dismount_bad_path() {
            let result = lock_and_dismount_volume(r"\\.\Z99:");
            assert!(result.is_err());
        }

        /// `find_volumes_on_physical_drive` with an unparseable path should
        /// return an empty Vec (no panic).
        #[test]
        fn test_find_volumes_bad_path_no_panic() {
            let result = find_volumes_on_physical_drive("not-a-valid-path");
            // May be empty or contain volumes; must not panic.
            let _ = result;
        }

        /// `find_volumes_on_physical_drive` for a very high drive number
        /// (almost certainly nonexistent) should return an empty list.
        #[test]
        fn test_find_volumes_nonexistent_drive_returns_empty() {
            let result = find_volumes_on_physical_drive(r"\\.\PhysicalDrive999");
            assert!(
                result.is_empty(),
                "expected no volumes for PhysicalDrive999"
            );
        }
    }
}
1640
1641// ---------------------------------------------------------------------------
1642// Tests
1643// ---------------------------------------------------------------------------
1644
1645#[cfg(test)]
1646mod tests {
1647    use super::*;
1648    use std::io::Write;
1649    use std::sync::mpsc;
1650
1651    fn make_channel() -> (mpsc::Sender<FlashEvent>, mpsc::Receiver<FlashEvent>) {
1652        mpsc::channel()
1653    }
1654
1655    fn drain(rx: &mpsc::Receiver<FlashEvent>) -> Vec<FlashEvent> {
1656        let mut events = Vec::new();
1657        while let Ok(e) = rx.try_recv() {
1658            events.push(e);
1659        }
1660        events
1661    }
1662
1663    fn has_stage(events: &[FlashEvent], stage: &FlashStage) -> bool {
1664        events
1665            .iter()
1666            .any(|e| matches!(e, FlashEvent::Stage(s) if s == stage))
1667    }
1668
1669    fn find_error(events: &[FlashEvent]) -> Option<&str> {
1670        events.iter().find_map(|e| {
1671            if let FlashEvent::Error(msg) = e {
1672                Some(msg.as_str())
1673            } else {
1674                None
1675            }
1676        })
1677    }
1678
    // ── Privilege helpers (is_privileged / reexec_as_root / set_real_uid) ───
1680
1681    #[test]
1682    fn test_is_privileged_returns_bool() {
1683        // Just verify it doesn't panic and returns a consistent value.
1684        let first = is_privileged();
1685        let second = is_privileged();
1686        assert_eq!(first, second, "is_privileged must be deterministic");
1687    }
1688
1689    #[test]
1690    fn test_reexec_as_root_does_not_panic_when_already_escalated() {
1691        // With the guard env-var set, reexec_as_root must return immediately
1692        // without panicking or actually exec-ing anything.
1693        std::env::set_var("FLASHKRAFT_ESCALATED", "1");
1694        reexec_as_root(); // must not exec — guard fires immediately
1695        std::env::remove_var("FLASHKRAFT_ESCALATED");
1696    }
1697
1698    #[test]
1699    fn test_set_real_uid_stores_value() {
1700        // OnceLock only sets once; in tests the first call wins.
1701        // Just verify it doesn't panic.
1702        set_real_uid(1000);
1703    }
1704
1705    // ── is_partition_of ─────────────────────────────────────────────────────
1706
1707    #[test]
1708    #[cfg(not(target_os = "windows"))]
1709    fn test_is_partition_of_sda() {
1710        assert!(is_partition_of("/dev/sda1", "sda"));
1711        assert!(is_partition_of("/dev/sda2", "sda"));
1712        assert!(!is_partition_of("/dev/sdb1", "sda"));
1713        assert!(!is_partition_of("/dev/sda", "sda"));
1714    }
1715
1716    #[test]
1717    #[cfg(not(target_os = "windows"))]
1718    fn test_is_partition_of_nvme() {
1719        assert!(is_partition_of("/dev/nvme0n1p1", "nvme0n1"));
1720        assert!(is_partition_of("/dev/nvme0n1p2", "nvme0n1"));
1721        assert!(!is_partition_of("/dev/nvme0n1", "nvme0n1"));
1722    }
1723
1724    #[test]
1725    #[cfg(not(target_os = "windows"))]
1726    fn test_is_partition_of_mmcblk() {
1727        assert!(is_partition_of("/dev/mmcblk0p1", "mmcblk0"));
1728        assert!(!is_partition_of("/dev/mmcblk0", "mmcblk0"));
1729    }
1730
1731    #[test]
1732    #[cfg(not(target_os = "windows"))]
1733    fn test_is_partition_of_no_false_prefix_match() {
1734        assert!(!is_partition_of("/dev/sda1", "sd"));
1735    }
1736
1737    // ── reject_partition_node ───────────────────────────────────────────────
1738
1739    #[test]
1740    #[cfg(target_os = "linux")]
1741    fn test_reject_partition_node_sda1() {
1742        let dir = std::env::temp_dir();
1743        let img = dir.join("fk_reject_img.bin");
1744        std::fs::write(&img, vec![0u8; 1024]).unwrap();
1745
1746        let result = reject_partition_node("/dev/sda1");
1747        assert!(result.is_err());
1748        assert!(result.unwrap_err().contains("Refusing"));
1749
1750        let _ = std::fs::remove_file(img);
1751    }
1752
1753    #[test]
1754    #[cfg(target_os = "linux")]
1755    fn test_reject_partition_node_nvme() {
1756        let result = reject_partition_node("/dev/nvme0n1p1");
1757        assert!(result.is_err());
1758        assert!(result.unwrap_err().contains("Refusing"));
1759    }
1760
1761    #[test]
1762    #[cfg(target_os = "linux")]
1763    fn test_reject_partition_node_accepts_whole_disk() {
1764        // /dev/sdb does not exist in CI — the function only checks the name
1765        // pattern, not whether the path exists.
1766        let result = reject_partition_node("/dev/sdb");
1767        assert!(result.is_ok(), "whole-disk node should not be rejected");
1768    }
1769
1770    // ── find_mounted_partitions ──────────────────────────────────────────────
1771
1772    #[test]
1773    fn test_find_mounted_partitions_parses_proc_mounts_format() {
1774        // We cannot mock /proc/mounts in a unit test so we just verify the
1775        // function doesn't panic and returns a Vec.
1776        let result = find_mounted_partitions("sda", "/dev/sda");
1777        let _ = result; // any result is valid
1778    }
1779
1780    // ── sha256_first_n_bytes ─────────────────────────────────────────────────
1781
1782    #[test]
1783    fn test_sha256_full_file() {
1784        use sha2::{Digest, Sha256};
1785
1786        let dir = std::env::temp_dir();
1787        let path = dir.join("fk_sha256_full.bin");
1788        let data: Vec<u8> = (0u8..=255u8).cycle().take(4096).collect();
1789        std::fs::write(&path, &data).unwrap();
1790
1791        let result = sha256_first_n_bytes(path.to_str().unwrap(), data.len() as u64).unwrap();
1792        let expected = format!("{:x}", Sha256::digest(&data));
1793        assert_eq!(result, expected);
1794
1795        let _ = std::fs::remove_file(path);
1796    }
1797
1798    #[test]
1799    fn test_sha256_partial() {
1800        use sha2::{Digest, Sha256};
1801
1802        let dir = std::env::temp_dir();
1803        let path = dir.join("fk_sha256_partial.bin");
1804        let data: Vec<u8> = (0u8..=255u8).cycle().take(8192).collect();
1805        std::fs::write(&path, &data).unwrap();
1806
1807        let n = 4096u64;
1808        let result = sha256_first_n_bytes(path.to_str().unwrap(), n).unwrap();
1809        let expected = format!("{:x}", Sha256::digest(&data[..n as usize]));
1810        assert_eq!(result, expected);
1811
1812        let _ = std::fs::remove_file(path);
1813    }
1814
1815    #[test]
1816    fn test_sha256_nonexistent_returns_error() {
1817        let result = sha256_first_n_bytes("/nonexistent/path.bin", 1024);
1818        assert!(result.is_err());
1819        assert!(result.unwrap_err().contains("Cannot open"));
1820    }
1821
1822    #[test]
1823    fn test_sha256_empty_read_is_hash_of_empty() {
1824        use sha2::{Digest, Sha256};
1825
1826        let dir = std::env::temp_dir();
1827        let path = dir.join("fk_sha256_empty.bin");
1828        std::fs::write(&path, b"hello world extended data").unwrap();
1829
1830        // max_bytes = 0 → nothing is read → hash of empty input
1831        let result = sha256_first_n_bytes(path.to_str().unwrap(), 0).unwrap();
1832        let expected = format!("{:x}", Sha256::digest(b""));
1833        assert_eq!(result, expected);
1834
1835        let _ = std::fs::remove_file(path);
1836    }
1837
1838    // ── write_image (via temp files) ─────────────────────────────────────────
1839
1840    #[test]
1841    fn test_write_image_to_temp_file() {
1842        let dir = std::env::temp_dir();
1843        let img_path = dir.join("fk_write_img.bin");
1844        let dev_path = dir.join("fk_write_dev.bin");
1845
1846        let image_size: u64 = 2 * 1024 * 1024; // 2 MiB
1847        {
1848            let mut f = std::fs::File::create(&img_path).unwrap();
1849            let block: Vec<u8> = (0u8..=255u8).cycle().take(BLOCK_SIZE).collect();
1850            let mut rem = image_size;
1851            while rem > 0 {
1852                let n = rem.min(BLOCK_SIZE as u64) as usize;
1853                f.write_all(&block[..n]).unwrap();
1854                rem -= n as u64;
1855            }
1856        }
1857        std::fs::File::create(&dev_path).unwrap();
1858
1859        let (tx, rx) = make_channel();
1860        let cancel = Arc::new(AtomicBool::new(false));
1861
1862        let result = write_image(
1863            img_path.to_str().unwrap(),
1864            dev_path.to_str().unwrap(),
1865            image_size,
1866            &tx,
1867            &cancel,
1868        );
1869
1870        assert!(result.is_ok(), "write_image failed: {result:?}");
1871
1872        let written = std::fs::read(&dev_path).unwrap();
1873        let original = std::fs::read(&img_path).unwrap();
1874        assert_eq!(written, original, "written data must match image exactly");
1875
1876        let events = drain(&rx);
1877        let has_progress = events
1878            .iter()
1879            .any(|e| matches!(e, FlashEvent::Progress { .. }));
1880        assert!(has_progress, "must emit at least one Progress event");
1881
1882        let _ = std::fs::remove_file(img_path);
1883        let _ = std::fs::remove_file(dev_path);
1884    }
1885
1886    #[test]
1887    fn test_write_image_cancelled_mid_write() {
1888        let dir = std::env::temp_dir();
1889        let img_path = dir.join("fk_cancel_img.bin");
1890        let dev_path = dir.join("fk_cancel_dev.bin");
1891
1892        // Large enough that we definitely hit the cancel check.
1893        let image_size: u64 = 8 * 1024 * 1024; // 8 MiB
1894        {
1895            let mut f = std::fs::File::create(&img_path).unwrap();
1896            let block = vec![0xAAu8; BLOCK_SIZE];
1897            let mut rem = image_size;
1898            while rem > 0 {
1899                let n = rem.min(BLOCK_SIZE as u64) as usize;
1900                f.write_all(&block[..n]).unwrap();
1901                rem -= n as u64;
1902            }
1903        }
1904        std::fs::File::create(&dev_path).unwrap();
1905
1906        let (tx, _rx) = make_channel();
1907        let cancel = Arc::new(AtomicBool::new(true)); // pre-cancelled
1908
1909        let result = write_image(
1910            img_path.to_str().unwrap(),
1911            dev_path.to_str().unwrap(),
1912            image_size,
1913            &tx,
1914            &cancel,
1915        );
1916
1917        assert!(result.is_err());
1918        assert!(
1919            result.unwrap_err().contains("cancelled"),
1920            "error should mention cancellation"
1921        );
1922
1923        let _ = std::fs::remove_file(img_path);
1924        let _ = std::fs::remove_file(dev_path);
1925    }
1926
1927    #[test]
1928    fn test_write_image_missing_image_returns_error() {
1929        let dir = std::env::temp_dir();
1930        let dev_path = dir.join("fk_noimg_dev.bin");
1931        std::fs::File::create(&dev_path).unwrap();
1932
1933        let (tx, _rx) = make_channel();
1934        let cancel = Arc::new(AtomicBool::new(false));
1935
1936        let result = write_image(
1937            "/nonexistent/image.img",
1938            dev_path.to_str().unwrap(),
1939            1024,
1940            &tx,
1941            &cancel,
1942        );
1943
1944        assert!(result.is_err());
1945        assert!(result.unwrap_err().contains("Cannot open image"));
1946
1947        let _ = std::fs::remove_file(dev_path);
1948    }
1949
1950    // ── verify ───────────────────────────────────────────────────────────────
1951
1952    #[test]
1953    fn test_verify_matching_files() {
1954        let dir = std::env::temp_dir();
1955        let img = dir.join("fk_verify_img.bin");
1956        let dev = dir.join("fk_verify_dev.bin");
1957        let data = vec![0xBBu8; 64 * 1024];
1958        std::fs::write(&img, &data).unwrap();
1959        std::fs::write(&dev, &data).unwrap();
1960
1961        let (tx, _rx) = make_channel();
1962        let result = verify(
1963            img.to_str().unwrap(),
1964            dev.to_str().unwrap(),
1965            data.len() as u64,
1966            &tx,
1967        );
1968        assert!(result.is_ok());
1969
1970        let _ = std::fs::remove_file(img);
1971        let _ = std::fs::remove_file(dev);
1972    }
1973
1974    #[test]
1975    fn test_verify_mismatch_returns_error() {
1976        let dir = std::env::temp_dir();
1977        let img = dir.join("fk_mismatch_img.bin");
1978        let dev = dir.join("fk_mismatch_dev.bin");
1979        std::fs::write(&img, vec![0x00u8; 64 * 1024]).unwrap();
1980        std::fs::write(&dev, vec![0xFFu8; 64 * 1024]).unwrap();
1981
1982        let (tx, _rx) = make_channel();
1983        let result = verify(img.to_str().unwrap(), dev.to_str().unwrap(), 64 * 1024, &tx);
1984        assert!(result.is_err());
1985        assert!(result.unwrap_err().contains("Verification failed"));
1986
1987        let _ = std::fs::remove_file(img);
1988        let _ = std::fs::remove_file(dev);
1989    }
1990
1991    #[test]
1992    fn test_verify_only_checks_image_size_bytes() {
1993        let dir = std::env::temp_dir();
1994        let img = dir.join("fk_trunc_img.bin");
1995        let dev = dir.join("fk_trunc_dev.bin");
1996        let image_data = vec![0xCCu8; 32 * 1024];
1997        let mut device_data = image_data.clone();
1998        device_data.extend_from_slice(&[0xDDu8; 32 * 1024]);
1999        std::fs::write(&img, &image_data).unwrap();
2000        std::fs::write(&dev, &device_data).unwrap();
2001
2002        let (tx, _rx) = make_channel();
2003        let result = verify(
2004            img.to_str().unwrap(),
2005            dev.to_str().unwrap(),
2006            image_data.len() as u64,
2007            &tx,
2008        );
2009        assert!(
2010            result.is_ok(),
2011            "should pass when first N bytes match: {result:?}"
2012        );
2013
2014        let _ = std::fs::remove_file(img);
2015        let _ = std::fs::remove_file(dev);
2016    }
2017
2018    // ── flash_pipeline validation ────────────────────────────────────────────
2019
2020    #[test]
2021    fn test_pipeline_rejects_missing_image() {
2022        let (tx, rx) = make_channel();
2023        let cancel = Arc::new(AtomicBool::new(false));
2024        run_pipeline("/nonexistent/image.iso", "/dev/null", tx, cancel);
2025        let events = drain(&rx);
2026        let err = find_error(&events);
2027        assert!(err.is_some(), "must emit an Error event");
2028        assert!(err.unwrap().contains("Image file not found"), "err={err:?}");
2029    }
2030
2031    #[test]
2032    fn test_pipeline_rejects_empty_image() {
2033        let dir = std::env::temp_dir();
2034        let empty = dir.join("fk_empty.img");
2035        std::fs::write(&empty, b"").unwrap();
2036
2037        let (tx, rx) = make_channel();
2038        let cancel = Arc::new(AtomicBool::new(false));
2039        run_pipeline(empty.to_str().unwrap(), "/dev/null", tx, cancel);
2040
2041        let events = drain(&rx);
2042        let err = find_error(&events);
2043        assert!(err.is_some());
2044        assert!(err.unwrap().contains("empty"), "err={err:?}");
2045
2046        let _ = std::fs::remove_file(empty);
2047    }
2048
2049    #[test]
2050    fn test_pipeline_rejects_missing_device() {
2051        let dir = std::env::temp_dir();
2052        let img = dir.join("fk_nodev_img.bin");
2053        std::fs::write(&img, vec![0u8; 1024]).unwrap();
2054
2055        let (tx, rx) = make_channel();
2056        let cancel = Arc::new(AtomicBool::new(false));
2057        run_pipeline(img.to_str().unwrap(), "/nonexistent/device", tx, cancel);
2058
2059        let events = drain(&rx);
2060        let err = find_error(&events);
2061        assert!(err.is_some());
2062        assert!(
2063            err.unwrap().contains("Target device not found"),
2064            "err={err:?}"
2065        );
2066
2067        let _ = std::fs::remove_file(img);
2068    }
2069
2070    /// End-to-end pipeline test using only temp files (no real hardware).
2071    #[test]
2072    fn test_pipeline_end_to_end_temp_files() {
2073        let dir = std::env::temp_dir();
2074        let img = dir.join("fk_e2e_img.bin");
2075        let dev = dir.join("fk_e2e_dev.bin");
2076
2077        let image_data: Vec<u8> = (0u8..=255u8).cycle().take(1024 * 1024).collect();
2078        std::fs::write(&img, &image_data).unwrap();
2079        std::fs::File::create(&dev).unwrap();
2080
2081        let (tx, rx) = make_channel();
2082        let cancel = Arc::new(AtomicBool::new(false));
2083        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2084
2085        let events = drain(&rx);
2086
2087        // Must have seen at least one Progress event.
2088        let has_progress = events
2089            .iter()
2090            .any(|e| matches!(e, FlashEvent::Progress { .. }));
2091        assert!(has_progress, "must emit Progress events");
2092
2093        // Must have passed through the core pipeline stages.
2094        assert!(
2095            has_stage(&events, &FlashStage::Unmounting),
2096            "must emit Unmounting stage"
2097        );
2098        assert!(
2099            has_stage(&events, &FlashStage::Writing),
2100            "must emit Writing stage"
2101        );
2102        assert!(
2103            has_stage(&events, &FlashStage::Syncing),
2104            "must emit Syncing stage"
2105        );
2106
2107        // On temp files the pipeline either completes (Done) or fails after
2108        // the write/verify stage (e.g. BLKRRPART on a regular file).
2109        let has_done = events.iter().any(|e| matches!(e, FlashEvent::Done));
2110        let has_error = events.iter().any(|e| matches!(e, FlashEvent::Error(_)));
2111        assert!(
2112            has_done || has_error,
2113            "pipeline must end with Done or Error"
2114        );
2115
2116        if has_done {
2117            let written = std::fs::read(&dev).unwrap();
2118            assert_eq!(written, image_data, "written data must match image");
2119        } else if let Some(err_msg) = find_error(&events) {
2120            // Error must NOT be from write or verify.
2121            assert!(
2122                !err_msg.contains("Cannot open")
2123                    && !err_msg.contains("Verification failed")
2124                    && !err_msg.contains("Write error"),
2125                "unexpected error: {err_msg}"
2126            );
2127        }
2128
2129        let _ = std::fs::remove_file(img);
2130        let _ = std::fs::remove_file(dev);
2131    }
2132
2133    // ── FlashStage Display ───────────────────────────────────────────────────
2134
2135    #[test]
2136    fn test_flash_stage_display() {
2137        assert!(FlashStage::Writing.to_string().contains("Writing"));
2138        assert!(FlashStage::Syncing.to_string().contains("Flushing"));
2139        assert!(FlashStage::Done.to_string().contains("complete"));
2140        assert!(FlashStage::Failed("oops".into())
2141            .to_string()
2142            .contains("oops"));
2143    }
2144
2145    // ── FlashStage equality ──────────────────────────────────────────────────
2146
2147    #[test]
2148    fn test_flash_stage_eq() {
2149        assert_eq!(FlashStage::Writing, FlashStage::Writing);
2150        assert_ne!(FlashStage::Writing, FlashStage::Syncing);
2151        assert_eq!(
2152            FlashStage::Failed("x".into()),
2153            FlashStage::Failed("x".into())
2154        );
2155        assert_ne!(
2156            FlashStage::Failed("x".into()),
2157            FlashStage::Failed("y".into())
2158        );
2159    }
2160
2161    // ── FlashEvent Clone ─────────────────────────────────────────────────────
2162
2163    #[test]
2164    fn test_flash_event_clone() {
2165        let events = vec![
2166            FlashEvent::Stage(FlashStage::Writing),
2167            FlashEvent::Progress {
2168                bytes_written: 1024,
2169                total_bytes: 4096,
2170                speed_mb_s: 12.5,
2171            },
2172            FlashEvent::Log("hello".into()),
2173            FlashEvent::Done,
2174            FlashEvent::Error("boom".into()),
2175        ];
2176        for e in &events {
2177            let _ = e.clone(); // must not panic
2178        }
2179    }
2180
2181    // ── find_mounted_partitions (platform-neutral contracts) ─────────────────
2182
2183    /// Calling find_mounted_partitions with a device name that almost
2184    /// certainly isn't mounted must return an empty Vec without panicking.
2185    #[test]
2186    fn test_find_mounted_partitions_nonexistent_device_returns_empty() {
2187        // PhysicalDrive999 / sdzzz are both guaranteed not to exist anywhere.
2188        #[cfg(target_os = "windows")]
2189        let result = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
2190        #[cfg(not(target_os = "windows"))]
2191        let result = find_mounted_partitions("sdzzz", "/dev/sdzzz");
2192
2193        // Result can be empty or non-empty depending on the OS, but must not panic.
2194        let _ = result;
2195    }
2196
2197    /// find_mounted_partitions must return a Vec (never panic) even when
2198    /// called with an empty device name.
2199    #[test]
2200    fn test_find_mounted_partitions_empty_name_no_panic() {
2201        let result = find_mounted_partitions("", "");
2202        let _ = result;
2203    }
2204
    // ── is_partition_of (Windows physical-drive paths are not partitions) ───
2206
2207    /// On Windows the caller never passes Unix-style paths, so these should
2208    /// all return false (no false positives from the partition-suffix logic).
2209    #[test]
2210    fn test_is_partition_of_windows_style_paths() {
2211        // Windows physical drive paths have no numeric suffix after the name.
2212        assert!(!is_partition_of(r"\\.\PhysicalDrive0", "PhysicalDrive0"));
2213        assert!(!is_partition_of(r"\\.\PhysicalDrive1", "PhysicalDrive0"));
2214    }
2215
    // ── Pipeline stage emission on temp files (all platforms) ────────────────
2217
2218    /// sync_device must emit a "caches flushed" log event regardless of
2219    /// platform.  We test this indirectly via the full pipeline on temp files.
2220    #[test]
2221    fn test_pipeline_emits_syncing_stage() {
2222        let dir = std::env::temp_dir();
2223        let img = dir.join("fk_sync_stage_img.bin");
2224        let dev = dir.join("fk_sync_stage_dev.bin");
2225
2226        let data: Vec<u8> = (0u8..=255).cycle().take(512 * 1024).collect();
2227        std::fs::write(&img, &data).unwrap();
2228        std::fs::File::create(&dev).unwrap();
2229
2230        let (tx, rx) = make_channel();
2231        let cancel = Arc::new(AtomicBool::new(false));
2232        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2233
2234        let events = drain(&rx);
2235        assert!(
2236            has_stage(&events, &FlashStage::Syncing),
2237            "Syncing stage must be emitted on every platform"
2238        );
2239
2240        let _ = std::fs::remove_file(&img);
2241        let _ = std::fs::remove_file(&dev);
2242    }
2243
2244    /// The pipeline must emit the Rereading stage on every platform.
2245    #[test]
2246    fn test_pipeline_emits_rereading_stage() {
2247        let dir = std::env::temp_dir();
2248        let img = dir.join("fk_reread_stage_img.bin");
2249        let dev = dir.join("fk_reread_stage_dev.bin");
2250
2251        let data: Vec<u8> = vec![0xABu8; 256 * 1024];
2252        std::fs::write(&img, &data).unwrap();
2253        std::fs::File::create(&dev).unwrap();
2254
2255        let (tx, rx) = make_channel();
2256        let cancel = Arc::new(AtomicBool::new(false));
2257        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2258
2259        let events = drain(&rx);
2260        assert!(
2261            has_stage(&events, &FlashStage::Rereading),
2262            "Rereading stage must be emitted on every platform"
2263        );
2264
2265        let _ = std::fs::remove_file(&img);
2266        let _ = std::fs::remove_file(&dev);
2267    }
2268
2269    /// The pipeline must emit the Verifying stage on every platform.
2270    #[test]
2271    fn test_pipeline_emits_verifying_stage() {
2272        let dir = std::env::temp_dir();
2273        let img = dir.join("fk_verify_stage_img.bin");
2274        let dev = dir.join("fk_verify_stage_dev.bin");
2275
2276        let data: Vec<u8> = vec![0xCDu8; 256 * 1024];
2277        std::fs::write(&img, &data).unwrap();
2278        std::fs::File::create(&dev).unwrap();
2279
2280        let (tx, rx) = make_channel();
2281        let cancel = Arc::new(AtomicBool::new(false));
2282        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2283
2284        let events = drain(&rx);
2285        assert!(
2286            has_stage(&events, &FlashStage::Verifying),
2287            "Verifying stage must be emitted on every platform"
2288        );
2289
2290        let _ = std::fs::remove_file(&img);
2291        let _ = std::fs::remove_file(&dev);
2292    }
2293
2294    // ── open_device_for_writing error messages ───────────────────────────────
2295
2296    /// Opening a path that does not exist must produce an error that mentions
2297    /// the device path — verified on all platforms.
2298    #[test]
2299    fn test_open_device_for_writing_nonexistent_mentions_path() {
2300        let bad = if cfg!(target_os = "windows") {
2301            r"\\.\PhysicalDrive999".to_string()
2302        } else {
2303            "/nonexistent/fk_bad_device".to_string()
2304        };
2305
2306        // open_device_for_writing is private; exercise it via write_image.
2307        let dir = std::env::temp_dir();
2308        let img = dir.join("fk_open_err_img.bin");
2309        std::fs::write(&img, vec![1u8; 512]).unwrap();
2310
2311        let (tx, _rx) = make_channel();
2312        let cancel = Arc::new(AtomicBool::new(false));
2313        let result = write_image(img.to_str().unwrap(), &bad, 512, &tx, &cancel);
2314
2315        assert!(result.is_err(), "must fail for nonexistent device");
2316        // The error string should mention the device path.
2317        assert!(
2318            result.as_ref().unwrap_err().contains("PhysicalDrive999")
2319                || result.as_ref().unwrap_err().contains("fk_bad_device")
2320                || result.as_ref().unwrap_err().contains("Cannot open"),
2321            "error should reference the bad path: {:?}",
2322            result
2323        );
2324
2325        let _ = std::fs::remove_file(&img);
2326    }
2327
2328    // ── sync_device emits a log message ─────────────────────────────────────
2329
2330    /// sync_device must emit at least one FlashEvent::Log containing the
2331    /// word "flushed" or "flush" on every platform.
2332    #[test]
2333    fn test_sync_device_emits_log() {
2334        let dir = std::env::temp_dir();
2335        let dev = dir.join("fk_sync_log_dev.bin");
2336        std::fs::File::create(&dev).unwrap();
2337
2338        let (tx, rx) = make_channel();
2339        sync_device(dev.to_str().unwrap(), &tx);
2340
2341        let events = drain(&rx);
2342        let has_flush_log = events.iter().any(|e| {
2343            if let FlashEvent::Log(msg) = e {
2344                let lower = msg.to_lowercase();
2345                lower.contains("flush") || lower.contains("cache")
2346            } else {
2347                false
2348            }
2349        });
2350        assert!(
2351            has_flush_log,
2352            "sync_device must emit a flush/cache log event"
2353        );
2354
2355        let _ = std::fs::remove_file(&dev);
2356    }
2357
2358    // ── reread_partition_table emits a log message ───────────────────────────
2359
2360    /// reread_partition_table must emit at least one FlashEvent::Log on every
2361    /// platform — either a success message or a warning.
2362    #[test]
2363    fn test_reread_partition_table_emits_log() {
2364        let dir = std::env::temp_dir();
2365        let dev = dir.join("fk_reread_log_dev.bin");
2366        std::fs::File::create(&dev).unwrap();
2367
2368        let (tx, rx) = make_channel();
2369        reread_partition_table(dev.to_str().unwrap(), &tx);
2370
2371        let events = drain(&rx);
2372        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2373        assert!(
2374            has_log,
2375            "reread_partition_table must emit at least one Log event"
2376        );
2377
2378        let _ = std::fs::remove_file(&dev);
2379    }
2380
2381    // ── unmount_device emits a log message ───────────────────────────────────
2382
2383    /// unmount_device on a temp-file path (which is never mounted) must emit
2384    /// the "no mounted partitions" log without panicking on any platform.
2385    #[test]
2386    fn test_unmount_device_no_partitions_emits_log() {
2387        let dir = std::env::temp_dir();
2388        let dev = dir.join("fk_unmount_log_dev.bin");
2389        std::fs::File::create(&dev).unwrap();
2390
2391        let path_str = dev.to_str().unwrap();
2392        let (tx, rx) = make_channel();
2393        unmount_device(path_str, &tx);
2394
2395        let events = drain(&rx);
2396        // Must emit at least one Log event (either "no partitions" or a warning).
2397        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2398        assert!(has_log, "unmount_device must emit at least one Log event");
2399
2400        let _ = std::fs::remove_file(&dev);
2401    }
2402
2403    // ── Pipeline all-stages ordering ─────────────────────────────────────────
2404
2405    /// The pipeline must emit stages in the documented order:
2406    /// Unmounting → Writing → Syncing → Rereading → Verifying.
2407    #[test]
2408    fn test_pipeline_stage_ordering() {
2409        let dir = std::env::temp_dir();
2410        let img = dir.join("fk_order_img.bin");
2411        let dev = dir.join("fk_order_dev.bin");
2412
2413        let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
2414        std::fs::write(&img, &data).unwrap();
2415        std::fs::File::create(&dev).unwrap();
2416
2417        let (tx, rx) = make_channel();
2418        let cancel = Arc::new(AtomicBool::new(false));
2419        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2420
2421        let events = drain(&rx);
2422
2423        // Collect all Stage events in order.
2424        let stages: Vec<&FlashStage> = events
2425            .iter()
2426            .filter_map(|e| {
2427                if let FlashEvent::Stage(s) = e {
2428                    Some(s)
2429                } else {
2430                    None
2431                }
2432            })
2433            .collect();
2434
2435        // Verify the mandatory stages appear and in correct relative order.
2436        let pos = |target: &FlashStage| {
2437            stages
2438                .iter()
2439                .position(|s| *s == target)
2440                .unwrap_or(usize::MAX)
2441        };
2442
2443        let unmounting = pos(&FlashStage::Unmounting);
2444        let writing = pos(&FlashStage::Writing);
2445        let syncing = pos(&FlashStage::Syncing);
2446        let rereading = pos(&FlashStage::Rereading);
2447        let verifying = pos(&FlashStage::Verifying);
2448
2449        assert!(unmounting < writing, "Unmounting must precede Writing");
2450        assert!(writing < syncing, "Writing must precede Syncing");
2451        assert!(syncing < rereading, "Syncing must precede Rereading");
2452        assert!(rereading < verifying, "Rereading must precede Verifying");
2453
2454        let _ = std::fs::remove_file(&img);
2455        let _ = std::fs::remove_file(&dev);
2456    }
2457
2458    // ── Linux-specific tests ─────────────────────────────────────────────────
2459
2460    /// On Linux, find_mounted_partitions reads /proc/mounts.
2461    /// Verify it returns a Vec without panicking (live test).
2462    #[test]
2463    #[cfg(target_os = "linux")]
2464    fn test_find_mounted_partitions_linux_no_panic() {
2465        // sda is unlikely to be mounted in CI, but the function must not panic.
2466        let result = find_mounted_partitions("sda", "/dev/sda");
2467        let _ = result;
2468    }
2469
2470    /// On Linux, /proc/mounts always contains at least one line (the root
2471    /// filesystem), so reading a clearly-mounted device (e.g. something at /)
2472    /// should find entries.
2473    #[test]
2474    #[cfg(target_os = "linux")]
2475    fn test_find_mounted_partitions_linux_reads_proc_mounts() {
2476        // We can't know exactly which device is at /, but we can verify
2477        // that the function can parse whatever /proc/mounts contains.
2478        let content = std::fs::read_to_string("/proc/mounts").unwrap_or_default();
2479        // If /proc/mounts is non-empty there must be at least one entry parseable.
2480        if !content.is_empty() {
2481            // Parse first real /dev/ device from /proc/mounts and verify
2482            // find_mounted_partitions does not panic on it.
2483            if let Some(line) = content.lines().find(|l| l.starts_with("/dev/")) {
2484                if let Some(dev) = line.split_whitespace().next() {
2485                    let name = std::path::Path::new(dev)
2486                        .file_name()
2487                        .map(|n| n.to_string_lossy().to_string())
2488                        .unwrap_or_default();
2489                    let _ = find_mounted_partitions(&name, dev);
2490                }
2491            }
2492        }
2493    }
2494
2495    /// On Linux, do_unmount on a path that is not mounted must not panic.
2496    /// EINVAL (not a mount point) and ENOENT (path doesn't exist) are both
2497    /// silenced — they are normal/harmless conditions, not warnings to surface
2498    /// to the user.
2499    #[test]
2500    #[cfg(target_os = "linux")]
2501    fn test_do_unmount_not_mounted_does_not_panic() {
2502        let (tx, rx) = make_channel();
2503        do_unmount("/dev/fk_nonexistent_part", &tx);
2504        let events = drain(&rx);
2505        // EINVAL / ENOENT must NOT produce a warning log — they are expected
2506        // silent outcomes when a partition is already detached or never mounted.
2507        let has_warning = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2508        assert!(
2509            !has_warning,
2510            "do_unmount must not emit a warning for EINVAL/ENOENT: {events:?}"
2511        );
2512    }
2513
2514    // ── macOS-specific tests ─────────────────────────────────────────────────
2515
2516    /// On macOS, do_unmount with a bogus partition path must emit a warning
2517    /// log (diskutil will fail) but must not panic.
2518    #[test]
2519    #[cfg(target_os = "macos")]
2520    fn test_do_unmount_macos_bad_path_emits_warning() {
2521        let (tx, rx) = make_channel();
2522        do_unmount("/dev/fk_nonexistent_part", &tx);
2523        let events = drain(&rx);
2524        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2525        assert!(has_log, "do_unmount must emit a Log event on failure");
2526    }
2527
2528    /// On macOS, find_mounted_partitions reads /proc/mounts (which doesn't
2529    /// exist) or falls back gracefully — must not panic.
2530    #[test]
2531    #[cfg(target_os = "macos")]
2532    fn test_find_mounted_partitions_macos_no_panic() {
2533        let result = find_mounted_partitions("disk2", "/dev/disk2");
2534        let _ = result;
2535    }
2536
2537    /// On macOS, reread_partition_table calls diskutil — must emit a log even
2538    /// if the path is a temp file (diskutil will fail gracefully).
2539    #[test]
2540    #[cfg(target_os = "macos")]
2541    fn test_reread_partition_table_macos_emits_log() {
2542        let dir = std::env::temp_dir();
2543        let dev = dir.join("fk_macos_reread_dev.bin");
2544        std::fs::File::create(&dev).unwrap();
2545
2546        let (tx, rx) = make_channel();
2547        reread_partition_table(dev.to_str().unwrap(), &tx);
2548
2549        let events = drain(&rx);
2550        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2551        assert!(has_log, "reread_partition_table must emit a log on macOS");
2552
2553        let _ = std::fs::remove_file(&dev);
2554    }
2555
2556    // ── Windows-specific pipeline tests ─────────────────────────────────────
2557
2558    /// On Windows, find_mounted_partitions delegates to
2559    /// windows::find_volumes_on_physical_drive — verify it does not panic
2560    /// for a well-formed but nonexistent drive.
2561    #[test]
2562    #[cfg(target_os = "windows")]
2563    fn test_find_mounted_partitions_windows_nonexistent() {
2564        let result = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
2565        assert!(
2566            result.is_empty(),
2567            "nonexistent physical drive should have no volumes"
2568        );
2569    }
2570
2571    /// On Windows, do_unmount on a bad volume path must emit a warning log
2572    /// and not panic.
2573    #[test]
2574    #[cfg(target_os = "windows")]
2575    fn test_do_unmount_windows_bad_volume_emits_log() {
2576        let (tx, rx) = make_channel();
2577        do_unmount(r"\\.\Z99:", &tx);
2578        let events = drain(&rx);
2579        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2580        assert!(has_log, "do_unmount on bad volume must emit a Log event");
2581    }
2582
2583    /// On Windows, sync_device on a nonexistent physical drive path should
2584    /// emit a warning log (FlushFileBuffers will fail) but not panic.
2585    #[test]
2586    #[cfg(target_os = "windows")]
2587    fn test_sync_device_windows_bad_path_no_panic() {
2588        let (tx, rx) = make_channel();
2589        sync_device(r"\\.\PhysicalDrive999", &tx);
2590        let events = drain(&rx);
2591        // Must emit at least one log event (either flush warning or the
2592        // normal "caches flushed" message).
2593        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2594        assert!(has_log, "sync_device must emit a Log event on Windows");
2595    }
2596
2597    /// On Windows, reread_partition_table on a nonexistent drive must emit
2598    /// a warning log and not panic.
2599    #[test]
2600    #[cfg(target_os = "windows")]
2601    fn test_reread_partition_table_windows_bad_path_no_panic() {
2602        let (tx, rx) = make_channel();
2603        reread_partition_table(r"\\.\PhysicalDrive999", &tx);
2604        let events = drain(&rx);
2605        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
2606        assert!(
2607            has_log,
2608            "reread_partition_table must emit a Log event on Windows"
2609        );
2610    }
2611
2612    /// On Windows, open_device_for_writing on a nonexistent physical drive
2613    /// must return an Err containing a meaningful message.
2614    #[test]
2615    #[cfg(target_os = "windows")]
2616    fn test_open_device_for_writing_windows_access_denied_message() {
2617        let dir = std::env::temp_dir();
2618        let img = dir.join("fk_win_open_img.bin");
2619        std::fs::write(&img, vec![1u8; 512]).unwrap();
2620
2621        let (tx, _rx) = make_channel();
2622        let cancel = Arc::new(AtomicBool::new(false));
2623        let result = write_image(
2624            img.to_str().unwrap(),
2625            r"\\.\PhysicalDrive999",
2626            512,
2627            &tx,
2628            &cancel,
2629        );
2630
2631        assert!(result.is_err());
2632        let msg = result.unwrap_err();
2633        // Must mention either the path, or give a clear error.
2634        assert!(
2635            msg.contains("PhysicalDrive999")
2636                || msg.contains("Access denied")
2637                || msg.contains("Cannot open"),
2638            "error must be descriptive: {msg}"
2639        );
2640
2641        let _ = std::fs::remove_file(&img);
2642    }
2643    // ── FlashStage::progress_floor ────────────────────────────────────────────
2644
2645    #[test]
2646    fn flash_stage_progress_floor_syncing() {
2647        assert!((FlashStage::Syncing.progress_floor() - 0.80).abs() < f32::EPSILON);
2648    }
2649
2650    #[test]
2651    fn flash_stage_progress_floor_rereading() {
2652        assert!((FlashStage::Rereading.progress_floor() - 0.88).abs() < f32::EPSILON);
2653    }
2654
2655    #[test]
2656    fn flash_stage_progress_floor_verifying() {
2657        assert!((FlashStage::Verifying.progress_floor() - 0.92).abs() < f32::EPSILON);
2658    }
2659
2660    #[test]
2661    fn flash_stage_progress_floor_other_stages_are_zero() {
2662        for stage in [
2663            FlashStage::Starting,
2664            FlashStage::Unmounting,
2665            FlashStage::Writing,
2666            FlashStage::Done,
2667        ] {
2668            assert_eq!(
2669                stage.progress_floor(),
2670                0.0,
2671                "{stage:?} should have floor 0.0"
2672            );
2673        }
2674    }
2675
2676    // ── verify_overall_progress ───────────────────────────────────────────────
2677
2678    #[test]
2679    fn verify_overall_image_phase_start() {
2680        assert_eq!(verify_overall_progress("image", 0.0), 0.0);
2681    }
2682
2683    #[test]
2684    fn verify_overall_image_phase_end() {
2685        assert!((verify_overall_progress("image", 1.0) - 0.5).abs() < f32::EPSILON);
2686    }
2687
2688    #[test]
2689    fn verify_overall_image_phase_midpoint() {
2690        assert!((verify_overall_progress("image", 0.5) - 0.25).abs() < f32::EPSILON);
2691    }
2692
2693    #[test]
2694    fn verify_overall_device_phase_start() {
2695        assert!((verify_overall_progress("device", 0.0) - 0.5).abs() < f32::EPSILON);
2696    }
2697
2698    #[test]
2699    fn verify_overall_device_phase_end() {
2700        assert!((verify_overall_progress("device", 1.0) - 1.0).abs() < f32::EPSILON);
2701    }
2702
2703    #[test]
2704    fn verify_overall_device_phase_midpoint() {
2705        assert!((verify_overall_progress("device", 0.5) - 0.75).abs() < f32::EPSILON);
2706    }
2707
2708    #[test]
2709    fn verify_overall_unknown_phase_treated_as_device() {
2710        // Any phase that is not "image" falls into the device branch.
2711        assert!((verify_overall_progress("other", 0.0) - 0.5).abs() < f32::EPSILON);
2712    }
2713
2714    // ── check_device_not_busy ────────────────────────────────────────────────
2715
2716    /// A synthetic EBUSY error must produce the "already in use" message.
2717    #[test]
2718    #[cfg(target_os = "linux")]
2719    fn check_device_not_busy_ebusy_returns_error() {
2720        let err = check_device_not_busy_with("/dev/sdz", |_| {
2721            Err(std::io::Error::from_raw_os_error(libc::EBUSY))
2722        });
2723        assert!(err.is_err(), "EBUSY must be reported as an error");
2724        let msg = err.unwrap_err();
2725        assert!(
2726            msg.contains("already in use"),
2727            "error must mention 'already in use': {msg}"
2728        );
2729        assert!(
2730            msg.contains("/dev/sdz"),
2731            "error must include the device path: {msg}"
2732        );
2733        assert!(
2734            msg.contains("another flash operation"),
2735            "error must hint at another flash operation: {msg}"
2736        );
2737    }
2738
2739    /// A non-EBUSY error (e.g. EPERM) must be silently ignored — it will be
2740    /// handled properly when the device is opened for writing.
2741    #[test]
2742    #[cfg(target_os = "linux")]
2743    fn check_device_not_busy_eperm_is_ignored() {
2744        let result = check_device_not_busy_with("/dev/sdz", |_| {
2745            Err(std::io::Error::from_raw_os_error(libc::EPERM))
2746        });
2747        assert!(
2748            result.is_ok(),
2749            "EPERM must be silently ignored, got: {result:?}"
2750        );
2751    }
2752
2753    /// EACCES must also be silently ignored.
2754    #[test]
2755    #[cfg(target_os = "linux")]
2756    fn check_device_not_busy_eacces_is_ignored() {
2757        let result = check_device_not_busy_with("/dev/sdz", |_| {
2758            Err(std::io::Error::from_raw_os_error(libc::EACCES))
2759        });
2760        assert!(
2761            result.is_ok(),
2762            "EACCES must be silently ignored, got: {result:?}"
2763        );
2764    }
2765
2766    /// When the open succeeds the function must return Ok.
2767    #[test]
2768    #[cfg(target_os = "linux")]
2769    fn check_device_not_busy_success_returns_ok() {
2770        let result = check_device_not_busy_with("/dev/sdz", |_| Ok(()));
2771        assert!(result.is_ok(), "successful open must return Ok");
2772    }
2773
2774    /// On a real regular file (not a block device) O_EXCL never returns EBUSY,
2775    /// so the pipeline must not emit the "already in use" error for temp files.
2776    #[test]
2777    #[cfg(target_os = "linux")]
2778    fn check_device_not_busy_regular_file_never_ebusy() {
2779        let f = tempfile::NamedTempFile::new().expect("tempfile");
2780        let result = check_device_not_busy(f.path().to_str().unwrap());
2781        assert!(
2782            result.is_ok(),
2783            "regular file must never trigger the EBUSY guard: {result:?}"
2784        );
2785    }
2786
2787    /// The pipeline must emit Unmounting *before* it could ever hit the busy
2788    /// check — i.e. the stage order must be Unmounting → Writing, not a
2789    /// premature Error before Unmounting.
2790    #[test]
2791    #[cfg(target_os = "linux")]
2792    fn pipeline_unmounting_precedes_busy_check_in_stage_stream() {
2793        let dir = tempfile::tempdir().expect("tempdir");
2794        let img = dir.path().join("img.bin");
2795        let dev = dir.path().join("dev.bin");
2796
2797        let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
2798        std::fs::write(&img, &data).unwrap();
2799        std::fs::File::create(&dev).unwrap();
2800
2801        let (tx, rx) = make_channel();
2802        let cancel = Arc::new(AtomicBool::new(false));
2803        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2804
2805        let events = drain(&rx);
2806
2807        // Must never see the "already in use" error on a temp file.
2808        if let Some(msg) = find_error(&events) {
2809            assert!(
2810                !msg.contains("already in use"),
2811                "temp file pipeline must not emit a false-positive busy error: {msg}"
2812            );
2813        }
2814
2815        // Unmounting must appear before Writing (busy check sits between them).
2816        let stages: Vec<&FlashStage> = events
2817            .iter()
2818            .filter_map(|e| {
2819                if let FlashEvent::Stage(s) = e {
2820                    Some(s)
2821                } else {
2822                    None
2823                }
2824            })
2825            .collect();
2826
2827        let pos_unmounting = stages.iter().position(|s| **s == FlashStage::Unmounting);
2828        let pos_writing = stages.iter().position(|s| **s == FlashStage::Writing);
2829
2830        assert!(
2831            pos_unmounting.is_some(),
2832            "pipeline must emit Unmounting stage"
2833        );
2834        assert!(pos_writing.is_some(), "pipeline must emit Writing stage");
2835        assert!(
2836            pos_unmounting.unwrap() < pos_writing.unwrap(),
2837            "Unmounting must precede Writing (busy check lives between them)"
2838        );
2839    }
2840
2841    // ── Non-Linux pipeline stage ordering ────────────────────────────────────
2842
2843    /// On macOS, `O_EXCL` on a block device does not produce `EBUSY` for
2844    /// mounted partitions — the kernel uses a different locking model. The
2845    /// busy-device scenario is handled by `open_device_for_writing` returning
2846    /// `EBUSY` at write time. We therefore have no pre-flight guard on macOS,
2847    /// but we still verify the stage ordering and the absence of a spurious
2848    /// busy error on a temp file.
2849    #[test]
2850    #[cfg(target_os = "macos")]
2851    fn pipeline_unmounting_precedes_writing_macos() {
2852        let dir = tempfile::tempdir().expect("tempdir");
2853        let img = dir.path().join("img.bin");
2854        let dev = dir.path().join("dev.bin");
2855
2856        let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
2857        std::fs::write(&img, &data).unwrap();
2858        std::fs::File::create(&dev).unwrap();
2859
2860        let (tx, rx) = make_channel();
2861        let cancel = Arc::new(AtomicBool::new(false));
2862        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2863
2864        let events = drain(&rx);
2865
2866        // Must never see "already in use" on a plain temp file.
2867        if let Some(msg) = find_error(&events) {
2868            assert!(
2869                !msg.contains("already in use"),
2870                "macOS pipeline must not emit a false-positive busy error: {msg}"
2871            );
2872        }
2873
2874        let stages: Vec<&FlashStage> = events
2875            .iter()
2876            .filter_map(|e| {
2877                if let FlashEvent::Stage(s) = e {
2878                    Some(s)
2879                } else {
2880                    None
2881                }
2882            })
2883            .collect();
2884
2885        let pos_unmounting = stages.iter().position(|s| **s == FlashStage::Unmounting);
2886        let pos_writing = stages.iter().position(|s| **s == FlashStage::Writing);
2887
2888        assert!(
2889            pos_unmounting.is_some(),
2890            "pipeline must emit Unmounting stage"
2891        );
2892        assert!(pos_writing.is_some(), "pipeline must emit Writing stage");
2893        assert!(
2894            pos_unmounting.unwrap() < pos_writing.unwrap(),
2895            "Unmounting must precede Writing on macOS"
2896        );
2897    }
2898
2899    /// On Windows, device-busy is caught by `ERROR_SHARING_VIOLATION` (error
2900    /// code 32) inside `open_device_for_writing` — there is no pre-flight
2901    /// `O_EXCL` guard. Verify stage ordering and no false busy error on a
2902    /// temp file.
2903    #[test]
2904    #[cfg(target_os = "windows")]
2905    fn pipeline_unmounting_precedes_writing_windows() {
2906        let dir = tempfile::tempdir().expect("tempdir");
2907        let img = dir.path().join("img.bin");
2908        let dev = dir.path().join("dev.bin");
2909
2910        let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
2911        std::fs::write(&img, &data).unwrap();
2912        std::fs::File::create(&dev).unwrap();
2913
2914        let (tx, rx) = make_channel();
2915        let cancel = Arc::new(AtomicBool::new(false));
2916        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
2917
2918        let events = drain(&rx);
2919
2920        // Must never see "already in use" on a plain temp file.
2921        if let Some(msg) = find_error(&events) {
2922            assert!(
2923                !msg.contains("already in use"),
2924                "Windows pipeline must not emit a false-positive busy error: {msg}"
2925            );
2926        }
2927
2928        let stages: Vec<&FlashStage> = events
2929            .iter()
2930            .filter_map(|e| {
2931                if let FlashEvent::Stage(s) = e {
2932                    Some(s)
2933                } else {
2934                    None
2935                }
2936            })
2937            .collect();
2938
2939        let pos_unmounting = stages.iter().position(|s| **s == FlashStage::Unmounting);
2940        let pos_writing = stages.iter().position(|s| **s == FlashStage::Writing);
2941
2942        assert!(
2943            pos_unmounting.is_some(),
2944            "pipeline must emit Unmounting stage"
2945        );
2946        assert!(pos_writing.is_some(), "pipeline must emit Writing stage");
2947        assert!(
2948            pos_unmounting.unwrap() < pos_writing.unwrap(),
2949            "Unmounting must precede Writing on Windows"
2950        );
2951    }
2952
2953    /// Verify that `open_device_for_writing` on Windows produces a descriptive
2954    /// message for `ERROR_SHARING_VIOLATION` (win32 error 32) — the Windows
2955    /// equivalent of EBUSY.
2956    #[test]
2957    #[cfg(target_os = "windows")]
2958    fn open_device_for_writing_sharing_violation_message() {
2959        // We cannot easily synthesise error 32 without a real locked device,
2960        // but we can confirm the non-existent-path error is descriptive and
2961        // does NOT accidentally claim "already in use".
2962        let dir = tempfile::tempdir().expect("tempdir");
2963        let img = dir.path().join("img.bin");
2964        let nonexistent_dev = dir.path().join("no_such_device");
2965
2966        let data: Vec<u8> = vec![0u8; 512];
2967        std::fs::write(&img, &data).unwrap();
2968
2969        let (tx, rx) = make_channel();
2970        let cancel = Arc::new(AtomicBool::new(false));
2971        run_pipeline(
2972            img.to_str().unwrap(),
2973            nonexistent_dev.to_str().unwrap(),
2974            tx,
2975            cancel,
2976        );
2977
2978        let events = drain(&rx);
2979        // The pipeline must fail (device not found or cannot open).
2980        let has_error = events.iter().any(|e| matches!(e, FlashEvent::Error(_)));
2981        assert!(has_error, "pipeline must fail for a non-existent device");
2982
2983        if let Some(msg) = find_error(&events) {
2984            assert!(
2985                !msg.contains("already in use"),
2986                "non-existent device must not emit a spurious 'already in use' message: {msg}"
2987            );
2988        }
2989    }
2990}