1#[cfg(unix)]
34use nix::libc;
35use std::io::{self, Read, Write};
36use std::path::Path;
37use std::sync::{
38 atomic::{AtomicBool, Ordering},
39 mpsc, Arc, OnceLock,
40};
41use std::time::{Duration, Instant};
42
/// Chunk size used for every read/write/hash pass over the image and device (4 MiB).
const BLOCK_SIZE: usize = 4 << 20;

/// Minimum time between two progress events sent to the UI.
const PROGRESS_INTERVAL: Duration = Duration::from_millis(400);

/// Real UID of the invoking user, recorded once by [`set_real_uid`]; when unset,
/// `real_uid()` falls back to `getuid()`.
static REAL_UID: OnceLock<u32> = OnceLock::new();
62
63pub fn set_real_uid(uid: u32) {
68 let _ = REAL_UID.set(uid);
69}
70
71pub fn is_privileged() -> bool {
77 #[cfg(unix)]
78 {
79 nix::unistd::geteuid().is_root()
80 }
81 #[cfg(not(unix))]
82 {
83 false
84 }
85}
86
87#[cfg(unix)]
108pub fn reexec_as_root() {
109 if is_running_under_test_harness() {
123 return;
124 }
125
126 #[cfg(test)]
129 return;
130
131 #[cfg(not(test))]
132 reexec_as_root_inner();
133}
134
#[cfg(unix)]
/// Heuristically detects whether this process is a test binary.
///
/// Returns `true` when `FLASHKRAFT_NO_REEXEC` or `NEXTEST_TEST_FILTER` is set (checked in
/// that order), or when the executable lives under a cargo `deps` directory.
fn is_running_under_test_harness() -> bool {
    let flagged_by_env = ["FLASHKRAFT_NO_REEXEC", "NEXTEST_TEST_FILTER"]
        .iter()
        .any(|key| std::env::var(key).is_ok());
    if flagged_by_env {
        return true;
    }

    std::env::current_exe()
        .map(|exe| {
            let exe_path = exe.to_string_lossy();
            exe_path.contains("/deps/") || exe_path.contains("\\deps\\")
        })
        .unwrap_or(false)
}
173
174#[cfg(all(unix, not(test)))]
175fn reexec_as_root_inner() {
176 use std::ffi::CString;
177
178 if std::env::var("FLASHKRAFT_ESCALATED").as_deref() == Ok("1") {
180 return;
181 }
182
183 let self_exe = match std::fs::read_link("/proc/self/exe").or_else(|_| std::env::current_exe()) {
184 Ok(p) => p,
185 Err(_) => return,
186 };
187 let self_exe_str = match self_exe.to_str() {
188 Some(s) => s.to_owned(),
189 None => return,
190 };
191
192 let extra_args: Vec<String> = std::env::args().skip(1).collect();
193
194 std::env::set_var("FLASHKRAFT_ESCALATED", "1");
196
197 if unix_which_exists("pkexec") {
199 let mut argv: Vec<CString> = Vec::new();
200 argv.push(unix_c_str("pkexec"));
201 argv.push(unix_c_str(&self_exe_str));
202 for a in &extra_args {
203 argv.push(unix_c_str(a));
204 }
205 let _ = nix::unistd::execvp(&unix_c_str("pkexec"), &argv);
206 }
207
208 if unix_which_exists("sudo") {
210 let mut argv: Vec<CString> = Vec::new();
211 argv.push(unix_c_str("sudo"));
212 argv.push(unix_c_str("-E")); argv.push(unix_c_str(&self_exe_str));
214 for a in &extra_args {
215 argv.push(unix_c_str(a));
216 }
217 let _ = nix::unistd::execvp(&unix_c_str("sudo"), &argv);
218 }
219
220 std::env::remove_var("FLASHKRAFT_ESCALATED");
222}
223
#[cfg(not(unix))]
/// Re-exec based privilege escalation is only implemented on Unix; elsewhere this is a no-op.
pub fn reexec_as_root() {
    // Intentionally empty.
}
227
#[cfg(all(unix, not(test)))]
/// Returns `true` if some directory in `$PATH` contains an executable regular file `name`.
///
/// Returns `false` when `PATH` is unset or not valid Unicode.
fn unix_which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;

    let is_executable_file = |candidate: std::path::PathBuf| {
        std::fs::metadata(candidate)
            .map(|meta| meta.is_file() && meta.permissions().mode() & 0o111 != 0)
            .unwrap_or(false)
    };

    match std::env::var("PATH") {
        Ok(search_path) => search_path
            .split(':')
            .map(|dir| std::path::Path::new(dir).join(name))
            .any(is_executable_file),
        Err(_) => false,
    }
}
244
#[cfg(all(unix, not(test)))]
/// Converts `s` into a `CString`, replacing any interior NUL byte with `?` so the
/// conversion cannot fail.
fn unix_c_str(s: &str) -> std::ffi::CString {
    let mut raw = s.as_bytes().to_vec();
    for byte in raw.iter_mut().filter(|b| **b == 0) {
        *byte = b'?';
    }
    std::ffi::CString::new(raw).unwrap_or_else(|_| std::ffi::CString::new("?").unwrap())
}
251
252#[cfg(unix)]
254fn real_uid() -> nix::unistd::Uid {
255 let raw = REAL_UID
256 .get()
257 .copied()
258 .unwrap_or_else(|| nix::unistd::getuid().as_raw());
259 nix::unistd::Uid::from_raw(raw)
260}
261
/// Coarse stage of a flash operation, reported through `FlashEvent::Stage`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FlashStage {
    /// Initial stage; `flash_pipeline` itself never emits it.
    Starting,
    /// Unmounting any mounted partitions of the target device.
    Unmounting,
    /// Copying the image onto the device.
    Writing,
    /// Flushing write-back caches.
    Syncing,
    /// Asking the OS to re-read the partition table.
    Rereading,
    /// Hashing image and device to confirm the write.
    Verifying,
    /// The operation completed successfully.
    Done,
    /// The operation failed with the given reason.
    Failed(String),
}

impl std::fmt::Display for FlashStage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            FlashStage::Starting => "Starting…",
            FlashStage::Unmounting => "Unmounting partitions…",
            FlashStage::Writing => "Writing image to device…",
            FlashStage::Syncing => "Flushing write buffers…",
            FlashStage::Rereading => "Refreshing partition table…",
            FlashStage::Verifying => "Verifying written data…",
            FlashStage::Done => "Flash complete!",
            FlashStage::Failed(reason) => return write!(f, "Failed: {reason}"),
        };
        f.write_str(label)
    }
}

impl FlashStage {
    /// Lower bound for the overall progress fraction while in this stage:
    /// 0.80 while syncing, 0.88 while re-reading, 0.92 while verifying, 0.0 otherwise.
    pub fn progress_floor(&self) -> f32 {
        match *self {
            FlashStage::Syncing => 0.80,
            FlashStage::Rereading => 0.88,
            FlashStage::Verifying => 0.92,
            FlashStage::Starting
            | FlashStage::Unmounting
            | FlashStage::Writing
            | FlashStage::Done
            | FlashStage::Failed(_) => 0.0,
        }
    }
}
324
/// Maps the progress of one verification pass onto overall verification progress.
///
/// The `"image"` hashing pass covers the first half (0.0–0.5); any other phase (the
/// `"device"` read-back) covers the second half (0.5–1.0).
pub fn verify_overall_progress(phase: &str, pass_fraction: f32) -> f32 {
    match phase {
        "image" => pass_fraction * 0.5,
        _ => 0.5 + pass_fraction * 0.5,
    }
}
349
350#[derive(Debug, Clone)]
355pub enum FlashEvent {
356 Stage(FlashStage),
358 Progress {
360 bytes_written: u64,
361 total_bytes: u64,
362 speed_mb_s: f32,
363 },
364 VerifyProgress {
376 phase: &'static str,
377 bytes_read: u64,
378 total_bytes: u64,
379 speed_mb_s: f32,
380 },
381 Log(String),
383 Done,
385 Error(String),
387}
388
/// UI-facing update derived from a `FlashEvent` (see the `From<FlashEvent>` impl).
#[derive(Clone, Debug)]
pub enum FlashUpdate {
    /// Write progress; `progress` is `bytes_written / total` clamped to `[0, 1]`
    /// (0 when the total is unknown).
    Progress {
        progress: f32,
        bytes_written: u64,
        speed_mb_s: f32,
    },
    /// Verification progress; `overall` spans both hashing passes (image then device).
    VerifyProgress {
        phase: &'static str,
        overall: f32,
        bytes_read: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// A stage description or log line to display.
    Message(String),
    /// The flash finished successfully.
    Completed,
    /// The flash failed with this message.
    Failed(String),
}
440
441impl From<FlashEvent> for FlashUpdate {
442 fn from(event: FlashEvent) -> Self {
450 match event {
451 FlashEvent::Progress {
452 bytes_written,
453 total_bytes,
454 speed_mb_s,
455 } => {
456 let progress = if total_bytes > 0 {
457 (bytes_written as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
458 } else {
459 0.0
460 };
461 FlashUpdate::Progress {
462 progress,
463 bytes_written,
464 speed_mb_s,
465 }
466 }
467
468 FlashEvent::VerifyProgress {
469 phase,
470 bytes_read,
471 total_bytes,
472 speed_mb_s,
473 } => {
474 let pass_fraction = if total_bytes > 0 {
475 (bytes_read as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
476 } else {
477 0.0
478 };
479 let overall = verify_overall_progress(phase, pass_fraction);
480 FlashUpdate::VerifyProgress {
481 phase,
482 overall,
483 bytes_read,
484 total_bytes,
485 speed_mb_s,
486 }
487 }
488
489 FlashEvent::Stage(stage) => FlashUpdate::Message(stage.to_string()),
490 FlashEvent::Log(msg) => FlashUpdate::Message(msg),
491 FlashEvent::Done => FlashUpdate::Completed,
492 FlashEvent::Error(e) => FlashUpdate::Failed(e),
493 }
494 }
495}
496
497pub fn run_pipeline(
513 image_path: &str,
514 device_path: &str,
515 tx: mpsc::Sender<FlashEvent>,
516 cancel: Arc<AtomicBool>,
517) {
518 if let Err(e) = flash_pipeline(image_path, device_path, &tx, cancel) {
519 let _ = tx.send(FlashEvent::Error(e));
520 }
521}
522
523fn send(tx: &mpsc::Sender<FlashEvent>, event: FlashEvent) {
528 let _ = tx.send(event);
530}
531
532fn flash_pipeline(
533 image_path: &str,
534 device_path: &str,
535 tx: &mpsc::Sender<FlashEvent>,
536 cancel: Arc<AtomicBool>,
537) -> Result<(), String> {
538 if !Path::new(image_path).is_file() {
540 return Err(format!("Image file not found: {image_path}"));
541 }
542
543 if !Path::new(device_path).exists() {
544 return Err(format!("Target device not found: {device_path}"));
545 }
546
547 #[cfg(target_os = "linux")]
549 reject_partition_node(device_path)?;
550
551 let image_size = std::fs::metadata(image_path)
552 .map_err(|e| format!("Cannot stat image: {e}"))?
553 .len();
554
555 if image_size == 0 {
556 return Err("Image file is empty".to_string());
557 }
558
559 #[cfg(target_os = "linux")]
565 {
566 use std::os::unix::fs::OpenOptionsExt;
567 let busy = std::fs::OpenOptions::new()
568 .read(true)
569 .custom_flags(libc::O_EXCL)
570 .open(device_path);
571 if let Err(e) = busy {
572 if e.raw_os_error() == Some(libc::EBUSY) {
573 return Err(format!(
574 "Device '{device_path}' is already in use by another process.\n\
575 Is another flash operation already running?"
576 ));
577 }
578 }
581 }
582
583 send(tx, FlashEvent::Stage(FlashStage::Unmounting));
585 unmount_device(device_path, tx);
586
587 send(tx, FlashEvent::Stage(FlashStage::Writing));
589 send(
590 tx,
591 FlashEvent::Log(format!(
592 "Writing {image_size} bytes from {image_path} → {device_path}"
593 )),
594 );
595 write_image(image_path, device_path, image_size, tx, &cancel)?;
596
597 send(tx, FlashEvent::Stage(FlashStage::Syncing));
599 sync_device(device_path, tx);
600
601 send(tx, FlashEvent::Stage(FlashStage::Rereading));
603 reread_partition_table(device_path, tx);
604
605 send(tx, FlashEvent::Stage(FlashStage::Verifying));
607 verify(image_path, device_path, image_size, tx)?;
608
609 send(tx, FlashEvent::Done);
611 Ok(())
612}
613
#[cfg(target_os = "linux")]
/// Refuses partition nodes (e.g. `/dev/sda1`, `/dev/nvme0n1p1`) so only whole disks are
/// flashed, while accepting whole disks whose names end in a digit (`nvme0n1`, `mmcblk0`,
/// `loop0`).
///
/// Symlinks such as `/dev/disk/by-id/…-part1` are resolved first. The kernel's own answer
/// (`/sys/class/block/<name>/partition`) is preferred; the name-based rule is only a
/// fallback when sysfs has no entry for the device.
///
/// # Errors
/// Returns a message naming the parent whole-disk device when `device_path` is a partition.
fn reject_partition_node(device_path: &str) -> Result<(), String> {
    /// Name-only fallback. Linux names partitions `<disk>p<N>` when the disk name ends in a
    /// digit and `<disk><N>` otherwise. Because whole disks like `nvme0n1` or `mmcblk0` also
    /// end in digits, the `<disk><N>` form is only trusted for the letter-suffixed families
    /// (sd*, hd*, vd*, xvd*). Returns the parent disk name for a partition.
    fn parent_disk_from_name(name: &str) -> Option<String> {
        let stem = name.trim_end_matches(|c: char| c.is_ascii_digit());
        if stem.is_empty() || stem.len() == name.len() {
            return None;
        }
        if let Some(disk) = stem.strip_suffix('p') {
            if disk.ends_with(|c: char| c.is_ascii_digit()) {
                return Some(disk.to_string());
            }
        }
        let lettered_disk = ["sd", "hd", "vd", "xvd"].iter().any(|&prefix| {
            stem.strip_prefix(prefix).map_or(false, |letters| {
                !letters.is_empty() && letters.bytes().all(|b| b.is_ascii_lowercase())
            })
        });
        if lettered_disk {
            Some(stem.to_string())
        } else {
            None
        }
    }

    /// Consults sysfs: `None` when the kernel has no entry for `name`, `Some(None)` for a
    /// whole device, `Some(Some(disk))` for a partition of `disk`.
    fn parent_disk_from_sysfs(name: &str) -> Option<Option<String>> {
        let entry = Path::new("/sys/class/block").join(name);
        if !entry.exists() {
            return None;
        }
        if !entry.join("partition").exists() {
            return Some(None);
        }
        // /sys/class/block/<part> is a symlink into …/block/<disk>/<part>.
        let disk = std::fs::canonicalize(&entry).ok().and_then(|resolved| {
            resolved
                .parent()
                .and_then(Path::file_name)
                .map(|n| n.to_string_lossy().into_owned())
        });
        Some(Some(
            disk.or_else(|| parent_disk_from_name(name))
                .unwrap_or_else(|| "sdX".to_string()),
        ))
    }

    let resolved = std::fs::canonicalize(device_path)
        .unwrap_or_else(|_| Path::new(device_path).to_path_buf());
    let dev_name = resolved
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_default();
    if dev_name.is_empty() {
        return Ok(());
    }

    let parent = parent_disk_from_sysfs(&dev_name)
        .unwrap_or_else(|| parent_disk_from_name(&dev_name));

    match parent {
        Some(whole) => Err(format!(
            "Refusing to write to partition node '{device_path}'. \
             Select the whole-disk device (e.g. /dev/{whole}) instead."
        )),
        None => Ok(()),
    }
}
646
647fn open_device_for_writing(device_path: &str) -> Result<std::fs::File, String> {
654 #[cfg(unix)]
655 {
656 use nix::unistd::seteuid;
657
658 let escalated = seteuid(nix::unistd::Uid::from_raw(0)).is_ok();
665
666 let result = std::fs::OpenOptions::new()
667 .write(true)
668 .open(device_path)
669 .map_err(|e| {
670 let raw = e.raw_os_error().unwrap_or(0);
671 if raw == libc::EACCES || raw == libc::EPERM {
672 if escalated {
673 format!(
674 "Permission denied opening '{device_path}'.\n\
675 Even with setuid-root the device refused access — \
676 check that the device exists and is not in use."
677 )
678 } else {
679 format!(
680 "Permission denied opening '{device_path}'.\n\
681 FlashKraft needs root access to write to block devices.\n\
682 Install setuid-root so it can escalate automatically:\n\
683 sudo chown root:root /usr/bin/flashkraft\n\
684 sudo chmod u+s /usr/bin/flashkraft"
685 )
686 }
687 } else if raw == libc::EBUSY {
688 format!(
689 "Device '{device_path}' is busy. \
690 Ensure all partitions are unmounted before flashing."
691 )
692 } else {
693 format!("Cannot open device '{device_path}' for writing: {e}")
694 }
695 });
696
697 if escalated {
699 let _ = seteuid(real_uid());
700 }
701
702 result
703 }
704
705 #[cfg(not(unix))]
706 {
707 std::fs::OpenOptions::new()
708 .write(true)
709 .open(device_path)
710 .map_err(|e| {
711 let raw = e.raw_os_error().unwrap_or(0);
712 if raw == 5 || raw == 1314 {
714 format!(
715 "Access denied opening '{device_path}'.\n\
716 FlashKraft must be run as Administrator on Windows.\n\
717 Right-click the application and choose \
718 'Run as administrator'."
719 )
720 } else if raw == 32 {
721 format!(
723 "Device '{device_path}' is in use by another process.\n\
724 Close any applications using the drive and try again."
725 )
726 } else {
727 format!("Cannot open device '{device_path}' for writing: {e}")
728 }
729 })
730 }
731}
732
733fn unmount_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
738 let device_name = Path::new(device_path)
739 .file_name()
740 .map(|n| n.to_string_lossy().to_string())
741 .unwrap_or_default();
742
743 let partitions = find_mounted_partitions(&device_name, device_path);
744
745 if partitions.is_empty() {
746 send(tx, FlashEvent::Log("No mounted partitions found".into()));
747 } else {
748 for partition in &partitions {
749 send(tx, FlashEvent::Log(format!("Unmounting {partition}")));
750 do_unmount(partition, tx);
751 }
752 }
753}
754
755fn find_mounted_partitions(
763 #[cfg_attr(target_os = "windows", allow(unused_variables))] device_name: &str,
764 device_path: &str,
765) -> Vec<String> {
766 #[cfg(not(target_os = "windows"))]
767 {
768 let mounts = std::fs::read_to_string("/proc/mounts")
769 .or_else(|_| std::fs::read_to_string("/proc/self/mounts"))
770 .unwrap_or_default();
771
772 let mut mount_points = Vec::new();
773 for line in mounts.lines() {
774 let mut fields = line.split_whitespace();
775 let dev = match fields.next() {
776 Some(d) => d,
777 None => continue,
778 };
779 let mount_point = match fields.next() {
782 Some(m) => m,
783 None => continue,
784 };
785 if dev == device_path || is_partition_of(dev, device_name) {
786 mount_points.push(mount_point.to_string());
787 }
788 }
789 mount_points
790 }
791
792 #[cfg(target_os = "windows")]
793 {
794 windows::find_volumes_on_physical_drive(device_path)
795 }
796}
797
#[cfg(not(target_os = "windows"))]
/// Returns `true` when `dev` (a path or bare name) names a partition of the whole-disk
/// device `device_name`.
///
/// Follows the Linux naming rule: disks whose name ends in a letter get `<disk><N>`
/// partitions (`sda` → `sda1`); disks whose name ends in a digit get `<disk>p<N>`
/// (`nvme0n1` → `nvme0n1p1`, `mmcblk0` → `mmcblk0p1`). The partition number must be all
/// digits, so `nvme0n10` is *not* a partition of `nvme0n1`. An empty `device_name` never
/// matches.
fn is_partition_of(dev: &str, device_name: &str) -> bool {
    if device_name.is_empty() {
        return false;
    }

    let dev_base = Path::new(dev)
        .file_name()
        .map(|n| n.to_string_lossy())
        .unwrap_or_default();

    let suffix = match dev_base.strip_prefix(device_name) {
        Some(rest) => rest,
        None => return false,
    };

    let number = if device_name.ends_with(|c: char| c.is_ascii_digit()) {
        suffix.strip_prefix('p')
    } else {
        Some(suffix)
    };

    number.map_or(false, |n| !n.is_empty() && n.bytes().all(|b| b.is_ascii_digit()))
}
816
#[cfg(target_os = "linux")]
/// Returns `true` if an executable regular file called `name` exists in some `$PATH` entry.
///
/// An unset `PATH` is treated as a single empty entry (i.e. the current directory).
fn which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;

    let search_path = std::env::var("PATH").unwrap_or_default();
    for dir in search_path.split(':') {
        let candidate = std::path::Path::new(dir).join(name);
        let executable = match std::fs::metadata(&candidate) {
            Ok(meta) => meta.is_file() && meta.permissions().mode() & 0o111 != 0,
            Err(_) => false,
        };
        if executable {
            return true;
        }
    }
    false
}
832
833fn do_unmount(partition: &str, tx: &mpsc::Sender<FlashEvent>) {
834 #[cfg(target_os = "linux")]
835 {
836 use nix::unistd::seteuid;
837 use std::ffi::CString;
838
839 if which_exists("udisksctl") {
845 let result = std::process::Command::new("udisksctl")
848 .args(["unmount", "--no-user-interaction", "-b", partition])
849 .stdout(std::process::Stdio::null())
850 .stderr(std::process::Stdio::null())
851 .spawn();
852
853 let udisks_ok = match result {
854 Ok(mut child) => {
855 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
857 loop {
858 match child.try_wait() {
859 Ok(Some(status)) => break status.success(),
860 Ok(None) if std::time::Instant::now() < deadline => {
861 std::thread::sleep(std::time::Duration::from_millis(100));
862 }
863 _ => {
864 let _ = child.kill();
866 send(
867 tx,
868 FlashEvent::Log(
869 "udisksctl timed out — falling back to umount2".into(),
870 ),
871 );
872 break false;
873 }
874 }
875 }
876 }
877 Err(_) => false,
878 };
879
880 if udisks_ok {
881 send(
882 tx,
883 FlashEvent::Log(format!("Unmounted {partition} via udisksctl")),
884 );
885 return;
886 }
887 }
888
889 let _ = seteuid(nix::unistd::Uid::from_raw(0));
893
894 if let Ok(c_path) = CString::new(partition) {
895 let ret = unsafe { libc::umount2(c_path.as_ptr(), libc::MNT_DETACH) };
896 if ret != 0 {
897 let raw = std::io::Error::last_os_error().raw_os_error().unwrap_or(0);
898 match raw {
899 libc::EINVAL => {}
901 libc::ENOENT => {}
903 _ => {
904 let err = std::io::Error::from_raw_os_error(raw);
905 send(
906 tx,
907 FlashEvent::Log(format!(
908 "Warning — could not unmount {partition}: {err}"
909 )),
910 );
911 }
912 }
913 }
914 }
915
916 let _ = seteuid(real_uid());
917 }
918
919 #[cfg(target_os = "macos")]
920 {
921 let out = std::process::Command::new("diskutil")
922 .args(["unmount", partition])
923 .output();
924 if let Ok(o) = out {
925 if !o.status.success() {
926 send(
927 tx,
928 FlashEvent::Log(format!("Warning — diskutil unmount {partition} failed")),
929 );
930 }
931 }
932 }
933
934 #[cfg(target_os = "windows")]
937 {
938 match windows::lock_and_dismount_volume(partition) {
939 Ok(()) => send(
940 tx,
941 FlashEvent::Log(format!("Dismounted volume {partition}")),
942 ),
943 Err(e) => send(
944 tx,
945 FlashEvent::Log(format!("Warning — could not dismount {partition}: {e}")),
946 ),
947 }
948 }
949}
950
951fn write_image(
956 image_path: &str,
957 device_path: &str,
958 image_size: u64,
959 tx: &mpsc::Sender<FlashEvent>,
960 cancel: &Arc<AtomicBool>,
961) -> Result<(), String> {
962 let image_file =
963 std::fs::File::open(image_path).map_err(|e| format!("Cannot open image: {e}"))?;
964
965 let device_file = open_device_for_writing(device_path)?;
966
967 let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, image_file);
968 let mut writer = io::BufWriter::with_capacity(BLOCK_SIZE, device_file);
969 let mut buf = vec![0u8; BLOCK_SIZE];
970
971 let mut bytes_written: u64 = 0;
972 let start = Instant::now();
973 let mut last_report = Instant::now();
974
975 loop {
976 if cancel.load(Ordering::SeqCst) {
978 return Err("Flash operation cancelled by user".to_string());
979 }
980
981 let n = reader
982 .read(&mut buf)
983 .map_err(|e| format!("Read error on image: {e}"))?;
984
985 if n == 0 {
986 break; }
988
989 writer
990 .write_all(&buf[..n])
991 .map_err(|e| format!("Write error on device: {e}"))?;
992
993 bytes_written += n as u64;
994
995 let now = Instant::now();
996 if now.duration_since(last_report) >= PROGRESS_INTERVAL || bytes_written >= image_size {
997 let elapsed_s = now.duration_since(start).as_secs_f32();
998 let speed_mb_s = if elapsed_s > 0.001 {
999 (bytes_written as f32 / (1024.0 * 1024.0)) / elapsed_s
1000 } else {
1001 0.0
1002 };
1003
1004 send(
1005 tx,
1006 FlashEvent::Progress {
1007 bytes_written,
1008 total_bytes: image_size,
1009 speed_mb_s,
1010 },
1011 );
1012 last_report = now;
1013 }
1014 }
1015
1016 writer
1018 .flush()
1019 .map_err(|e| format!("Buffer flush error: {e}"))?;
1020
1021 #[cfg_attr(not(unix), allow(unused_variables))]
1023 let device_file = writer
1024 .into_inner()
1025 .map_err(|e| format!("BufWriter error: {e}"))?;
1026
1027 #[cfg(unix)]
1031 {
1032 use std::os::unix::io::AsRawFd;
1033 let fd = device_file.as_raw_fd();
1034 let ret = unsafe { libc::fsync(fd) };
1035 if ret != 0 {
1036 let err = std::io::Error::last_os_error();
1037 return Err(format!(
1038 "fsync failed on '{device_path}': {err} — \
1039 data may not have been fully written to the device"
1040 ));
1041 }
1042 }
1043
1044 let elapsed_s = start.elapsed().as_secs_f32();
1046 let speed_mb_s = if elapsed_s > 0.001 {
1047 (bytes_written as f32 / (1024.0 * 1024.0)) / elapsed_s
1048 } else {
1049 0.0
1050 };
1051 send(
1052 tx,
1053 FlashEvent::Progress {
1054 bytes_written,
1055 total_bytes: image_size,
1056 speed_mb_s,
1057 },
1058 );
1059
1060 send(tx, FlashEvent::Log("Image write complete".into()));
1061 Ok(())
1062}
1063
1064fn sync_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
1069 #[cfg(unix)]
1070 if let Ok(f) = std::fs::OpenOptions::new().write(true).open(device_path) {
1071 use std::os::unix::io::AsRawFd;
1072 let fd = f.as_raw_fd();
1073 #[cfg(target_os = "linux")]
1074 unsafe {
1075 libc::fdatasync(fd);
1076 }
1077 #[cfg(not(target_os = "linux"))]
1078 unsafe {
1079 libc::fsync(fd);
1080 }
1081 drop(f);
1082 }
1083
1084 #[cfg(target_os = "linux")]
1085 unsafe {
1086 libc::sync();
1087 }
1088
1089 #[cfg(target_os = "windows")]
1092 {
1093 match windows::flush_device_buffers(device_path) {
1094 Ok(()) => {}
1095 Err(e) => send(
1096 tx,
1097 FlashEvent::Log(format!(
1098 "Warning — FlushFileBuffers on '{device_path}' failed: {e}"
1099 )),
1100 ),
1101 }
1102 }
1103
1104 send(tx, FlashEvent::Log("Write-back caches flushed".into()));
1105}
1106
1107#[cfg(target_os = "linux")]
1112fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
1113 use nix::ioctl_none;
1114 use std::os::unix::io::AsRawFd;
1115
1116 ioctl_none!(blkrrpart, 0x12, 95);
1117
1118 std::thread::sleep(Duration::from_millis(500));
1120
1121 match std::fs::OpenOptions::new().write(true).open(device_path) {
1122 Ok(f) => {
1123 let result = unsafe { blkrrpart(f.as_raw_fd()) };
1124 match result {
1125 Ok(_) => send(
1126 tx,
1127 FlashEvent::Log("Kernel partition table refreshed".into()),
1128 ),
1129 Err(e) => send(
1130 tx,
1131 FlashEvent::Log(format!(
1132 "Warning — BLKRRPART ioctl failed \
1133 (device may not be partitioned): {e}"
1134 )),
1135 ),
1136 }
1137 }
1138 Err(e) => send(
1139 tx,
1140 FlashEvent::Log(format!(
1141 "Warning — could not open device for BLKRRPART: {e}"
1142 )),
1143 ),
1144 }
1145}
1146
#[cfg(target_os = "macos")]
/// Requests a partition-table refresh with `diskutil rereadPartitionTable`.
///
/// Best effort: the command's exit status and output are deliberately ignored.
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let request = std::process::Command::new("diskutil")
        .arg("rereadPartitionTable")
        .arg(device_path)
        .output();
    drop(request);

    send(
        tx,
        FlashEvent::Log("Partition table refresh requested (macOS)".into()),
    );
}
1157
#[cfg(target_os = "windows")]
/// Refreshes the disk's partition information with `IOCTL_DISK_UPDATE_PROPERTIES` after a
/// 500 ms pause, logging success or a warning.
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    std::thread::sleep(Duration::from_millis(500));

    let line = match windows::update_disk_properties(device_path) {
        Ok(()) => "Partition table refreshed (IOCTL_DISK_UPDATE_PROPERTIES)".to_string(),
        Err(e) => format!("Warning — IOCTL_DISK_UPDATE_PROPERTIES failed: {e}"),
    };
    send(tx, FlashEvent::Log(line));
}
1178
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
/// Fallback for platforms without a partition-table refresh mechanism: only logs a note.
fn reread_partition_table(_device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let note = "Partition table refresh not supported on this platform";
    send(tx, FlashEvent::Log(note.into()));
}
1186
1187fn verify(
1192 image_path: &str,
1193 device_path: &str,
1194 image_size: u64,
1195 tx: &mpsc::Sender<FlashEvent>,
1196) -> Result<(), String> {
1197 send(
1198 tx,
1199 FlashEvent::Log("Computing SHA-256 of source image".into()),
1200 );
1201 let image_hash = sha256_with_progress(image_path, image_size, "image", tx)?;
1202
1203 send(
1204 tx,
1205 FlashEvent::Log(format!(
1206 "Reading back {image_size} bytes from device for verification"
1207 )),
1208 );
1209 let device_hash = sha256_with_progress(device_path, image_size, "device", tx)?;
1210
1211 if image_hash != device_hash {
1212 return Err(format!(
1213 "Verification failed — data mismatch \
1214 (image={image_hash} device={device_hash})"
1215 ));
1216 }
1217
1218 send(
1219 tx,
1220 FlashEvent::Log(format!("Verification passed ({image_hash})")),
1221 );
1222 Ok(())
1223}
1224
1225fn sha256_with_progress(
1232 path: &str,
1233 max_bytes: u64,
1234 phase: &'static str,
1235 tx: &mpsc::Sender<FlashEvent>,
1236) -> Result<String, String> {
1237 use sha2::{Digest, Sha256};
1238
1239 let file =
1240 std::fs::File::open(path).map_err(|e| format!("Cannot open {path} for hashing: {e}"))?;
1241
1242 let mut hasher = Sha256::new();
1243 let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, file);
1244 let mut buf = vec![0u8; BLOCK_SIZE];
1245 let mut remaining = max_bytes;
1246 let mut bytes_read: u64 = 0;
1247
1248 let start = Instant::now();
1249 let mut last_report = Instant::now();
1250
1251 while remaining > 0 {
1252 let to_read = (remaining as usize).min(buf.len());
1253 let n = reader
1254 .read(&mut buf[..to_read])
1255 .map_err(|e| format!("Read error while hashing {path}: {e}"))?;
1256 if n == 0 {
1257 break;
1258 }
1259 hasher.update(&buf[..n]);
1260 bytes_read += n as u64;
1261 remaining -= n as u64;
1262
1263 let now = Instant::now();
1264 if now.duration_since(last_report) >= PROGRESS_INTERVAL || remaining == 0 {
1265 let elapsed_s = now.duration_since(start).as_secs_f32();
1266 let speed_mb_s = if elapsed_s > 0.001 {
1267 (bytes_read as f32 / (1024.0 * 1024.0)) / elapsed_s
1268 } else {
1269 0.0
1270 };
1271 send(
1272 tx,
1273 FlashEvent::VerifyProgress {
1274 phase,
1275 bytes_read,
1276 total_bytes: max_bytes,
1277 speed_mb_s,
1278 },
1279 );
1280 last_report = now;
1281 }
1282 }
1283
1284 Ok(format!("{:x}", hasher.finalize()))
1285}
1286
#[cfg(test)]
/// Test helper: hex SHA-256 of the first `max_bytes` of `path`, discarding progress events.
fn sha256_first_n_bytes(path: &str, max_bytes: u64) -> Result<String, String> {
    // Keep the receiver alive for the duration of the call.
    let (sender, _receiver) = mpsc::channel::<FlashEvent>();
    sha256_with_progress(path, max_bytes, "image", &sender)
}
1293
#[cfg(target_os = "windows")]
/// Thin Win32 helpers used by the flash pipeline on Windows: volume discovery, volume
/// lock/dismount, buffer flushing and partition-table refresh.
mod windows {
    use windows_sys::Win32::{
        Foundation::{
            CloseHandle, FALSE, GENERIC_READ, GENERIC_WRITE, HANDLE, INVALID_HANDLE_VALUE,
        },
        Storage::FileSystem::{
            CreateFileW, FlushFileBuffers, FILE_FLAG_WRITE_THROUGH, FILE_SHARE_READ,
            FILE_SHARE_WRITE, OPEN_EXISTING,
        },
        System::{
            Ioctl::{FSCTL_DISMOUNT_VOLUME, FSCTL_LOCK_VOLUME, IOCTL_DISK_UPDATE_PROPERTIES},
            IO::DeviceIoControl,
        },
    };

    /// `FILE_DEVICE_DISK` as reported in `STORAGE_DEVICE_NUMBER::DeviceType`. Device
    /// numbers are only unique within one device type (a CD-ROM and a disk can both be
    /// number 0), so volume matching must be restricted to disks.
    const FILE_DEVICE_DISK: u32 = 0x0000_0007;

    /// Encodes `s` as a NUL-terminated UTF-16 string for the `*W` Win32 APIs.
    fn to_wide(s: &str) -> Vec<u16> {
        use std::os::windows::ffi::OsStrExt;
        std::ffi::OsStr::new(s)
            .encode_wide()
            .chain(std::iter::once(0))
            .collect()
    }

    /// Opens `path` (a `\\.\…` device or volume path) with the given access rights,
    /// sharing read/write with other handles.
    fn open_device_handle(path: &str, access: u32) -> Result<HANDLE, String> {
        let wide = to_wide(path);
        // SAFETY: `wide` is NUL-terminated and outlives the call; null pointers are valid
        // for the optional security-attributes and template arguments.
        let handle = unsafe {
            CreateFileW(
                wide.as_ptr(),
                access,
                FILE_SHARE_READ | FILE_SHARE_WRITE,
                std::ptr::null(),
                OPEN_EXISTING,
                FILE_FLAG_WRITE_THROUGH,
                std::ptr::null_mut(),
            )
        };
        if handle == INVALID_HANDLE_VALUE {
            Err(format!(
                "Cannot open device '{}': {}",
                path,
                std::io::Error::last_os_error()
            ))
        } else {
            Ok(handle)
        }
    }

    /// Issues a `DeviceIoControl` with no input or output buffer.
    fn device_ioctl(handle: HANDLE, code: u32) -> Result<(), String> {
        let mut bytes_returned: u32 = 0;
        // SAFETY: no buffers are passed (null/0), `bytes_returned` is a valid out-pointer,
        // and no OVERLAPPED structure is used.
        let ok = unsafe {
            DeviceIoControl(
                handle,
                code,
                std::ptr::null(),
                0,
                std::ptr::null_mut(),
                0,
                &mut bytes_returned,
                std::ptr::null_mut(),
            )
        };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    /// Returns the `\\.\X:` paths of lettered volumes that live on `physical_drive`
    /// (`\\.\PhysicalDriveN`, case-insensitive). Unparsable input yields no matches.
    pub fn find_volumes_on_physical_drive(physical_drive: &str) -> Vec<String> {
        use windows_sys::Win32::{
            Storage::FileSystem::GetLogicalDriveStringsW,
            System::Ioctl::{IOCTL_STORAGE_GET_DEVICE_NUMBER, STORAGE_DEVICE_NUMBER},
        };

        let target_index: u32 = physical_drive
            .to_ascii_lowercase()
            .trim_start_matches(r"\\.\physicaldrive")
            .parse()
            .unwrap_or(u32::MAX);

        let mut buf = vec![0u16; 512];
        // SAFETY: `buf` is writable for `buf.len()` UTF-16 units.
        let len = unsafe { GetLogicalDriveStringsW(buf.len() as u32, buf.as_mut_ptr()) };
        if len == 0 || len > buf.len() as u32 {
            return Vec::new();
        }

        // The buffer holds "A:\\\0C:\\\0…\0"; keep the first character of each entry.
        let drive_letters: Vec<String> = buf[..len as usize]
            .split(|&c| c == 0)
            .filter(|s| !s.is_empty())
            .map(|s| {
                let letter: String = std::char::from_u32(s[0] as u32)
                    .map(|c| c.to_string())
                    .unwrap_or_default();
                format!(r"\\.\{}:", letter)
            })
            .collect();

        let mut matching = Vec::new();

        for vol_path in &drive_letters {
            let wide = to_wide(vol_path);
            // SAFETY: `wide` is NUL-terminated and outlives the call.
            let handle = unsafe {
                CreateFileW(
                    wide.as_ptr(),
                    GENERIC_READ,
                    FILE_SHARE_READ | FILE_SHARE_WRITE,
                    std::ptr::null(),
                    OPEN_EXISTING,
                    0,
                    std::ptr::null_mut(),
                )
            };
            if handle == INVALID_HANDLE_VALUE {
                continue;
            }

            let mut dev_num = STORAGE_DEVICE_NUMBER {
                DeviceType: 0,
                DeviceNumber: u32::MAX,
                PartitionNumber: 0,
            };
            let mut bytes_returned: u32 = 0;

            // SAFETY: the output buffer is `dev_num`, sized exactly as passed.
            let ok = unsafe {
                DeviceIoControl(
                    handle,
                    IOCTL_STORAGE_GET_DEVICE_NUMBER,
                    std::ptr::null(),
                    0,
                    &mut dev_num as *mut _ as *mut _,
                    std::mem::size_of::<STORAGE_DEVICE_NUMBER>() as u32,
                    &mut bytes_returned,
                    std::ptr::null_mut(),
                )
            };

            // SAFETY: `handle` is a valid handle opened above and is closed exactly once.
            unsafe { CloseHandle(handle) };

            // Require a disk device type: DeviceNumber alone would also match e.g. a CD-ROM
            // or mounted ISO that happens to share the index.
            if ok != FALSE
                && dev_num.DeviceType as u32 == FILE_DEVICE_DISK
                && dev_num.DeviceNumber == target_index
            {
                matching.push(vol_path.clone());
            }
        }

        matching
    }

    /// Locks (`FSCTL_LOCK_VOLUME`) and dismounts (`FSCTL_DISMOUNT_VOLUME`) `volume_path`.
    ///
    /// A failed lock is printed to stderr and the dismount is still attempted; the result
    /// is `Err` if either step failed (the lock error takes precedence).
    ///
    /// NOTE(review): a volume lock lasts only while its handle is open, and the handle is
    /// closed before returning — Windows may remount the volume on the next access.
    pub fn lock_and_dismount_volume(volume_path: &str) -> Result<(), String> {
        let handle = open_device_handle(volume_path, GENERIC_READ | GENERIC_WRITE)?;

        let lock_result = device_ioctl(handle, FSCTL_LOCK_VOLUME);
        if let Err(ref e) = lock_result {
            eprintln!(
                "[flash] FSCTL_LOCK_VOLUME on '{volume_path}' failed ({e}); \
                 attempting dismount anyway"
            );
        }

        let dismount_result = device_ioctl(handle, FSCTL_DISMOUNT_VOLUME);

        // SAFETY: `handle` was opened above and is closed exactly once.
        unsafe { CloseHandle(handle) };

        lock_result.and(dismount_result)
    }

    /// Flushes the OS buffers of `device_path` with `FlushFileBuffers`.
    pub fn flush_device_buffers(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_WRITE)?;
        // SAFETY: `handle` is valid; it is closed exactly once right after the flush.
        let ok = unsafe { FlushFileBuffers(handle) };
        unsafe { CloseHandle(handle) };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    /// Asks Windows to re-read the disk's partition information
    /// (`IOCTL_DISK_UPDATE_PROPERTIES`).
    pub fn update_disk_properties(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_READ | GENERIC_WRITE)?;
        let result = device_ioctl(handle, IOCTL_DISK_UPDATE_PROPERTIES);
        // SAFETY: `handle` was opened above and is closed exactly once.
        unsafe { CloseHandle(handle) };
        result
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_to_wide_null_terminated() {
            let wide = to_wide("ABC");
            assert_eq!(wide.last(), Some(&0u16), "must be null-terminated");
            assert_eq!(&wide[..3], &[b'A' as u16, b'B' as u16, b'C' as u16]);
        }

        #[test]
        fn test_to_wide_empty() {
            let wide = to_wide("");
            assert_eq!(wide, vec![0u16]);
        }

        #[test]
        fn test_open_device_handle_bad_path_returns_error() {
            let result = open_device_handle(r"\\.\NonExistentDevice999", GENERIC_READ);
            assert!(result.is_err(), "expected error for nonexistent device");
        }

        #[test]
        fn test_flush_device_buffers_bad_path() {
            let result = flush_device_buffers(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }

        #[test]
        fn test_update_disk_properties_bad_path() {
            let result = update_disk_properties(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }

        #[test]
        fn test_lock_and_dismount_bad_path() {
            let result = lock_and_dismount_volume(r"\\.\Z99:");
            assert!(result.is_err());
        }

        #[test]
        fn test_find_volumes_bad_path_no_panic() {
            let result = find_volumes_on_physical_drive("not-a-valid-path");
            let _ = result;
        }

        #[test]
        fn test_find_volumes_nonexistent_drive_returns_empty() {
            let result = find_volumes_on_physical_drive(r"\\.\PhysicalDrive999");
            assert!(
                result.is_empty(),
                "expected no volumes for PhysicalDrive999"
            );
        }
    }
}
1614
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Creates a fresh channel for collecting pipeline events.
    fn make_channel() -> (mpsc::Sender<FlashEvent>, mpsc::Receiver<FlashEvent>) {
        mpsc::channel()
    }

    /// Collects every event currently queued on `rx` without blocking.
    fn drain(rx: &mpsc::Receiver<FlashEvent>) -> Vec<FlashEvent> {
        rx.try_iter().collect()
    }

    /// Returns `true` if `events` contains a `Stage` event equal to `stage`.
    fn has_stage(events: &[FlashEvent], stage: &FlashStage) -> bool {
        events
            .iter()
            .any(|e| matches!(e, FlashEvent::Stage(s) if s == stage))
    }

    /// Returns the message of the first `Error` event, if any.
    fn find_error(events: &[FlashEvent]) -> Option<&str> {
        events.iter().find_map(|e| match e {
            FlashEvent::Error(msg) => Some(msg.as_str()),
            _ => None,
        })
    }

    #[test]
    fn test_is_privileged_returns_bool() {
        let first = is_privileged();
        let second = is_privileged();
        assert_eq!(first, second, "is_privileged must be deterministic");
    }

    #[test]
    fn test_reexec_as_root_does_not_panic_when_already_escalated() {
        /// Clears the escalation marker on drop so a failing assertion or panic
        /// cannot leak the variable into other tests in this process.
        struct ResetEscalated;
        impl Drop for ResetEscalated {
            fn drop(&mut self) {
                std::env::remove_var("FLASHKRAFT_ESCALATED");
            }
        }

        std::env::set_var("FLASHKRAFT_ESCALATED", "1");
        let _reset = ResetEscalated;
        reexec_as_root();
    }

    #[test]
    fn test_set_real_uid_stores_value() {
        set_real_uid(1000);
        // A second call must not panic even though the value is already set.
        set_real_uid(1001);
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_is_partition_of_sda() {
        assert!(is_partition_of("/dev/sda1", "sda"));
        assert!(is_partition_of("/dev/sda2", "sda"));
        assert!(!is_partition_of("/dev/sdb1", "sda"));
        assert!(!is_partition_of("/dev/sda", "sda"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_is_partition_of_nvme() {
        assert!(is_partition_of("/dev/nvme0n1p1", "nvme0n1"));
        assert!(is_partition_of("/dev/nvme0n1p2", "nvme0n1"));
        assert!(!is_partition_of("/dev/nvme0n1", "nvme0n1"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_is_partition_of_mmcblk() {
        assert!(is_partition_of("/dev/mmcblk0p1", "mmcblk0"));
        assert!(!is_partition_of("/dev/mmcblk0", "mmcblk0"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_is_partition_of_no_false_prefix_match() {
        assert!(!is_partition_of("/dev/sda1", "sd"));
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_reject_partition_node_sda1() {
        let result = reject_partition_node("/dev/sda1");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Refusing"));
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_reject_partition_node_nvme() {
        let result = reject_partition_node("/dev/nvme0n1p1");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Refusing"));
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_reject_partition_node_accepts_whole_disk() {
        let result = reject_partition_node("/dev/sdb");
        assert!(result.is_ok(), "whole-disk node should not be rejected");
    }

    #[test]
    fn test_find_mounted_partitions_parses_proc_mounts_format() {
        // Smoke test: the result depends on the host's mounts, so only the
        // absence of a panic is checked.
        let _ = find_mounted_partitions("sda", "/dev/sda");
    }

    #[test]
    fn test_sha256_full_file() {
        use sha2::{Digest, Sha256};

        let dir = std::env::temp_dir();
        let path = dir.join("fk_sha256_full.bin");
        let data: Vec<u8> = (0u8..=255u8).cycle().take(4096).collect();
        std::fs::write(&path, &data).unwrap();

        let result = sha256_first_n_bytes(path.to_str().unwrap(), data.len() as u64).unwrap();
        let expected = format!("{:x}", Sha256::digest(&data));
        assert_eq!(result, expected);

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_sha256_partial() {
        use sha2::{Digest, Sha256};

        let dir = std::env::temp_dir();
        let path = dir.join("fk_sha256_partial.bin");
        let data: Vec<u8> = (0u8..=255u8).cycle().take(8192).collect();
        std::fs::write(&path, &data).unwrap();

        let n = 4096u64;
        let result = sha256_first_n_bytes(path.to_str().unwrap(), n).unwrap();
        let expected = format!("{:x}", Sha256::digest(&data[..n as usize]));
        assert_eq!(result, expected);

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_sha256_nonexistent_returns_error() {
        let result = sha256_first_n_bytes("/nonexistent/path.bin", 1024);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Cannot open"));
    }

    #[test]
    fn test_sha256_empty_read_is_hash_of_empty() {
        use sha2::{Digest, Sha256};

        let dir = std::env::temp_dir();
        let path = dir.join("fk_sha256_empty.bin");
        std::fs::write(&path, b"hello world extended data").unwrap();

        // Hashing zero bytes must ignore the file contents entirely.
        let result = sha256_first_n_bytes(path.to_str().unwrap(), 0).unwrap();
        let expected = format!("{:x}", Sha256::digest(b""));
        assert_eq!(result, expected);

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_write_image_to_temp_file() {
        let dir = std::env::temp_dir();
        let img_path = dir.join("fk_write_img.bin");
        let dev_path = dir.join("fk_write_dev.bin");

        // 2 MiB repeating 0..=255 pattern so byte shifts or truncation are detectable.
        let image: Vec<u8> = (0u8..=255).cycle().take(2 * 1024 * 1024).collect();
        std::fs::write(&img_path, &image).unwrap();
        std::fs::File::create(&dev_path).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));

        let result = write_image(
            img_path.to_str().unwrap(),
            dev_path.to_str().unwrap(),
            image.len() as u64,
            &tx,
            &cancel,
        );
        assert!(result.is_ok(), "write_image failed: {result:?}");

        let written = std::fs::read(&dev_path).unwrap();
        assert_eq!(written, image, "written data must match image exactly");

        let events = drain(&rx);
        let has_progress = events
            .iter()
            .any(|e| matches!(e, FlashEvent::Progress { .. }));
        assert!(has_progress, "must emit at least one Progress event");

        let _ = std::fs::remove_file(img_path);
        let _ = std::fs::remove_file(dev_path);
    }

    #[test]
    fn test_write_image_cancelled_mid_write() {
        let dir = std::env::temp_dir();
        let img_path = dir.join("fk_cancel_img.bin");
        let dev_path = dir.join("fk_cancel_dev.bin");

        // 8 MiB: more than one BLOCK_SIZE (4 MiB) worth of data.
        let image = vec![0xAAu8; 8 * 1024 * 1024];
        std::fs::write(&img_path, &image).unwrap();
        std::fs::File::create(&dev_path).unwrap();

        let (tx, _rx) = make_channel();
        // The flag is raised before the call, so write_image is expected to
        // notice it and bail out with a cancellation error.
        let cancel = Arc::new(AtomicBool::new(true));
        let result = write_image(
            img_path.to_str().unwrap(),
            dev_path.to_str().unwrap(),
            image.len() as u64,
            &tx,
            &cancel,
        );

        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("cancelled"),
            "error should mention cancellation"
        );

        let _ = std::fs::remove_file(img_path);
        let _ = std::fs::remove_file(dev_path);
    }

    #[test]
    fn test_write_image_missing_image_returns_error() {
        let dir = std::env::temp_dir();
        let dev_path = dir.join("fk_noimg_dev.bin");
        std::fs::File::create(&dev_path).unwrap();

        let (tx, _rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));

        let result = write_image(
            "/nonexistent/image.img",
            dev_path.to_str().unwrap(),
            1024,
            &tx,
            &cancel,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Cannot open image"));

        let _ = std::fs::remove_file(dev_path);
    }

    #[test]
    fn test_verify_matching_files() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_verify_img.bin");
        let dev = dir.join("fk_verify_dev.bin");
        let data = vec![0xBBu8; 64 * 1024];
        std::fs::write(&img, &data).unwrap();
        std::fs::write(&dev, &data).unwrap();

        let (tx, _rx) = make_channel();
        let result = verify(
            img.to_str().unwrap(),
            dev.to_str().unwrap(),
            data.len() as u64,
            &tx,
        );
        assert!(result.is_ok());

        let _ = std::fs::remove_file(img);
        let _ = std::fs::remove_file(dev);
    }

    #[test]
    fn test_verify_mismatch_returns_error() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_mismatch_img.bin");
        let dev = dir.join("fk_mismatch_dev.bin");
        std::fs::write(&img, vec![0x00u8; 64 * 1024]).unwrap();
        std::fs::write(&dev, vec![0xFFu8; 64 * 1024]).unwrap();

        let (tx, _rx) = make_channel();
        let result = verify(img.to_str().unwrap(), dev.to_str().unwrap(), 64 * 1024, &tx);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Verification failed"));

        let _ = std::fs::remove_file(img);
        let _ = std::fs::remove_file(dev);
    }

    #[test]
    fn test_verify_only_checks_image_size_bytes() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_trunc_img.bin");
        let dev = dir.join("fk_trunc_dev.bin");
        let image_data = vec![0xCCu8; 32 * 1024];
        // The "device" carries trailing bytes beyond the image that must be ignored.
        let mut device_data = image_data.clone();
        device_data.extend_from_slice(&[0xDDu8; 32 * 1024]);
        std::fs::write(&img, &image_data).unwrap();
        std::fs::write(&dev, &device_data).unwrap();

        let (tx, _rx) = make_channel();
        let result = verify(
            img.to_str().unwrap(),
            dev.to_str().unwrap(),
            image_data.len() as u64,
            &tx,
        );
        assert!(
            result.is_ok(),
            "should pass when first N bytes match: {result:?}"
        );

        let _ = std::fs::remove_file(img);
        let _ = std::fs::remove_file(dev);
    }

    #[test]
    fn test_pipeline_rejects_missing_image() {
        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline("/nonexistent/image.iso", "/dev/null", tx, cancel);
        let events = drain(&rx);
        let err = find_error(&events);
        assert!(err.is_some(), "must emit an Error event");
        assert!(err.unwrap().contains("Image file not found"), "err={err:?}");
    }

    #[test]
    fn test_pipeline_rejects_empty_image() {
        let dir = std::env::temp_dir();
        let empty = dir.join("fk_empty.img");
        std::fs::write(&empty, b"").unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(empty.to_str().unwrap(), "/dev/null", tx, cancel);

        let events = drain(&rx);
        let err = find_error(&events);
        assert!(err.is_some());
        assert!(err.unwrap().contains("empty"), "err={err:?}");

        let _ = std::fs::remove_file(empty);
    }

    #[test]
    fn test_pipeline_rejects_missing_device() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_nodev_img.bin");
        std::fs::write(&img, vec![0u8; 1024]).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), "/nonexistent/device", tx, cancel);

        let events = drain(&rx);
        let err = find_error(&events);
        assert!(err.is_some());
        assert!(
            err.unwrap().contains("Target device not found"),
            "err={err:?}"
        );

        let _ = std::fs::remove_file(img);
    }

    #[test]
    fn test_pipeline_end_to_end_temp_files() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_e2e_img.bin");
        let dev = dir.join("fk_e2e_dev.bin");

        let image_data: Vec<u8> = (0u8..=255u8).cycle().take(1024 * 1024).collect();
        std::fs::write(&img, &image_data).unwrap();
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);

        let events = drain(&rx);

        let has_progress = events
            .iter()
            .any(|e| matches!(e, FlashEvent::Progress { .. }));
        assert!(has_progress, "must emit Progress events");

        assert!(
            has_stage(&events, &FlashStage::Unmounting),
            "must emit Unmounting stage"
        );
        assert!(
            has_stage(&events, &FlashStage::Writing),
            "must emit Writing stage"
        );
        assert!(
            has_stage(&events, &FlashStage::Syncing),
            "must emit Syncing stage"
        );

        let has_done = events.iter().any(|e| matches!(e, FlashEvent::Done));
        let has_error = events.iter().any(|e| matches!(e, FlashEvent::Error(_)));
        assert!(
            has_done || has_error,
            "pipeline must end with Done or Error"
        );

        if has_done {
            let written = std::fs::read(&dev).unwrap();
            assert_eq!(written, image_data, "written data must match image");
        } else if let Some(err_msg) = find_error(&events) {
            // A plain file is not a real block device, so later platform steps
            // may legitimately fail; I/O and data-integrity errors may not.
            assert!(
                !err_msg.contains("Cannot open")
                    && !err_msg.contains("Verification failed")
                    && !err_msg.contains("Write error"),
                "unexpected error: {err_msg}"
            );
        }

        let _ = std::fs::remove_file(img);
        let _ = std::fs::remove_file(dev);
    }

    #[test]
    fn test_flash_stage_display() {
        assert!(FlashStage::Writing.to_string().contains("Writing"));
        assert!(FlashStage::Syncing.to_string().contains("Flushing"));
        assert!(FlashStage::Done.to_string().contains("complete"));
        assert!(FlashStage::Failed("oops".into())
            .to_string()
            .contains("oops"));
    }

    #[test]
    fn test_flash_stage_eq() {
        assert_eq!(FlashStage::Writing, FlashStage::Writing);
        assert_ne!(FlashStage::Writing, FlashStage::Syncing);
        assert_eq!(
            FlashStage::Failed("x".into()),
            FlashStage::Failed("x".into())
        );
        assert_ne!(
            FlashStage::Failed("x".into()),
            FlashStage::Failed("y".into())
        );
    }

    #[test]
    fn test_flash_event_clone() {
        let events = vec![
            FlashEvent::Stage(FlashStage::Writing),
            FlashEvent::Progress {
                bytes_written: 1024,
                total_bytes: 4096,
                speed_mb_s: 12.5,
            },
            FlashEvent::Log("hello".into()),
            FlashEvent::Done,
            FlashEvent::Error("boom".into()),
        ];
        for e in &events {
            // Compare via Debug, since FlashEvent is not known to implement PartialEq.
            let copy = e.clone();
            assert_eq!(
                format!("{copy:?}"),
                format!("{e:?}"),
                "clone must preserve every field"
            );
        }
    }

    #[test]
    fn test_find_mounted_partitions_nonexistent_device_returns_empty() {
        #[cfg(target_os = "windows")]
        let result = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
        #[cfg(not(target_os = "windows"))]
        let result = find_mounted_partitions("sdzzz", "/dev/sdzzz");

        assert!(
            result.is_empty(),
            "a device that does not exist cannot have mounted partitions"
        );
    }

    #[test]
    fn test_find_mounted_partitions_empty_name_no_panic() {
        let _ = find_mounted_partitions("", "");
    }

    #[test]
    fn test_is_partition_of_windows_style_paths() {
        assert!(!is_partition_of(r"\\.\PhysicalDrive0", "PhysicalDrive0"));
        assert!(!is_partition_of(r"\\.\PhysicalDrive1", "PhysicalDrive0"));
    }

    #[test]
    fn test_pipeline_emits_syncing_stage() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_sync_stage_img.bin");
        let dev = dir.join("fk_sync_stage_dev.bin");

        let data: Vec<u8> = (0u8..=255).cycle().take(512 * 1024).collect();
        std::fs::write(&img, &data).unwrap();
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);

        let events = drain(&rx);
        assert!(
            has_stage(&events, &FlashStage::Syncing),
            "Syncing stage must be emitted on every platform"
        );

        let _ = std::fs::remove_file(&img);
        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_pipeline_emits_rereading_stage() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_reread_stage_img.bin");
        let dev = dir.join("fk_reread_stage_dev.bin");

        let data: Vec<u8> = vec![0xABu8; 256 * 1024];
        std::fs::write(&img, &data).unwrap();
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);

        let events = drain(&rx);
        assert!(
            has_stage(&events, &FlashStage::Rereading),
            "Rereading stage must be emitted on every platform"
        );

        let _ = std::fs::remove_file(&img);
        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_pipeline_emits_verifying_stage() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_verify_stage_img.bin");
        let dev = dir.join("fk_verify_stage_dev.bin");

        let data: Vec<u8> = vec![0xCDu8; 256 * 1024];
        std::fs::write(&img, &data).unwrap();
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);

        let events = drain(&rx);
        assert!(
            has_stage(&events, &FlashStage::Verifying),
            "Verifying stage must be emitted on every platform"
        );

        let _ = std::fs::remove_file(&img);
        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_open_device_for_writing_nonexistent_mentions_path() {
        let bad = if cfg!(target_os = "windows") {
            r"\\.\PhysicalDrive999".to_string()
        } else {
            "/nonexistent/fk_bad_device".to_string()
        };

        let dir = std::env::temp_dir();
        let img = dir.join("fk_open_err_img.bin");
        std::fs::write(&img, vec![1u8; 512]).unwrap();

        let (tx, _rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        let result = write_image(img.to_str().unwrap(), &bad, 512, &tx, &cancel);

        assert!(result.is_err(), "must fail for nonexistent device");
        let msg = result.unwrap_err();
        assert!(
            msg.contains("PhysicalDrive999")
                || msg.contains("fk_bad_device")
                || msg.contains("Cannot open"),
            "error should reference the bad path: {msg:?}"
        );

        let _ = std::fs::remove_file(&img);
    }

    #[test]
    fn test_sync_device_emits_log() {
        let dir = std::env::temp_dir();
        let dev = dir.join("fk_sync_log_dev.bin");
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        sync_device(dev.to_str().unwrap(), &tx);

        let events = drain(&rx);
        let has_flush_log = events.iter().any(|e| match e {
            FlashEvent::Log(msg) => {
                let lower = msg.to_lowercase();
                lower.contains("flush") || lower.contains("cache")
            }
            _ => false,
        });
        assert!(
            has_flush_log,
            "sync_device must emit a flush/cache log event"
        );

        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_reread_partition_table_emits_log() {
        let dir = std::env::temp_dir();
        let dev = dir.join("fk_reread_log_dev.bin");
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        reread_partition_table(dev.to_str().unwrap(), &tx);

        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(
            has_log,
            "reread_partition_table must emit at least one Log event"
        );

        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_unmount_device_no_partitions_emits_log() {
        let dir = std::env::temp_dir();
        let dev = dir.join("fk_unmount_log_dev.bin");
        std::fs::File::create(&dev).unwrap();

        let path_str = dev.to_str().unwrap();
        let (tx, rx) = make_channel();
        unmount_device(path_str, &tx);

        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(has_log, "unmount_device must emit at least one Log event");

        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    fn test_pipeline_stage_ordering() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_order_img.bin");
        let dev = dir.join("fk_order_dev.bin");

        let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
        std::fs::write(&img, &data).unwrap();
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);

        let events = drain(&rx);

        let stages: Vec<&FlashStage> = events
            .iter()
            .filter_map(|e| match e {
                FlashEvent::Stage(s) => Some(s),
                _ => None,
            })
            .collect();

        // Index of the first occurrence of `target`. A missing stage fails the
        // test outright; a sentinel index would let some ordering checks pass
        // vacuously (e.g. `rereading < usize::MAX`).
        let pos = |target: &FlashStage| {
            stages
                .iter()
                .position(|s| *s == target)
                .unwrap_or_else(|| panic!("stage {target:?} was never emitted; got {stages:?}"))
        };

        let unmounting = pos(&FlashStage::Unmounting);
        let writing = pos(&FlashStage::Writing);
        let syncing = pos(&FlashStage::Syncing);
        let rereading = pos(&FlashStage::Rereading);
        let verifying = pos(&FlashStage::Verifying);

        assert!(unmounting < writing, "Unmounting must precede Writing");
        assert!(writing < syncing, "Writing must precede Syncing");
        assert!(syncing < rereading, "Syncing must precede Rereading");
        assert!(rereading < verifying, "Rereading must precede Verifying");

        let _ = std::fs::remove_file(&img);
        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_find_mounted_partitions_linux_no_panic() {
        let _ = find_mounted_partitions("sda", "/dev/sda");
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_find_mounted_partitions_linux_reads_proc_mounts() {
        let content = std::fs::read_to_string("/proc/mounts").unwrap_or_default();
        // Use the first /dev/ entry actually mounted on this host; skip if none.
        let Some(dev) = content
            .lines()
            .find(|l| l.starts_with("/dev/"))
            .and_then(|line| line.split_whitespace().next())
        else {
            return;
        };
        let name = std::path::Path::new(dev)
            .file_name()
            .map(|n| n.to_string_lossy().to_string())
            .unwrap_or_default();
        let _ = find_mounted_partitions(&name, dev);
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_do_unmount_not_mounted_does_not_panic() {
        let (tx, rx) = make_channel();
        do_unmount("/dev/fk_nonexistent_part", &tx);
        let events = drain(&rx);
        let has_warning = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(
            !has_warning,
            "do_unmount must not emit a warning for EINVAL/ENOENT: {events:?}"
        );
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_do_unmount_macos_bad_path_emits_warning() {
        let (tx, rx) = make_channel();
        do_unmount("/dev/fk_nonexistent_part", &tx);
        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(has_log, "do_unmount must emit a Log event on failure");
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_find_mounted_partitions_macos_no_panic() {
        let _ = find_mounted_partitions("disk2", "/dev/disk2");
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_reread_partition_table_macos_emits_log() {
        let dir = std::env::temp_dir();
        let dev = dir.join("fk_macos_reread_dev.bin");
        std::fs::File::create(&dev).unwrap();

        let (tx, rx) = make_channel();
        reread_partition_table(dev.to_str().unwrap(), &tx);

        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(has_log, "reread_partition_table must emit a log on macOS");

        let _ = std::fs::remove_file(&dev);
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn test_find_mounted_partitions_windows_nonexistent() {
        let result = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
        assert!(
            result.is_empty(),
            "nonexistent physical drive should have no volumes"
        );
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn test_do_unmount_windows_bad_volume_emits_log() {
        let (tx, rx) = make_channel();
        do_unmount(r"\\.\Z99:", &tx);
        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(has_log, "do_unmount on bad volume must emit a Log event");
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn test_sync_device_windows_bad_path_no_panic() {
        let (tx, rx) = make_channel();
        sync_device(r"\\.\PhysicalDrive999", &tx);
        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(has_log, "sync_device must emit a Log event on Windows");
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn test_reread_partition_table_windows_bad_path_no_panic() {
        let (tx, rx) = make_channel();
        reread_partition_table(r"\\.\PhysicalDrive999", &tx);
        let events = drain(&rx);
        let has_log = events.iter().any(|e| matches!(e, FlashEvent::Log(_)));
        assert!(
            has_log,
            "reread_partition_table must emit a Log event on Windows"
        );
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn test_open_device_for_writing_windows_access_denied_message() {
        let dir = std::env::temp_dir();
        let img = dir.join("fk_win_open_img.bin");
        std::fs::write(&img, vec![1u8; 512]).unwrap();

        let (tx, _rx) = make_channel();
        let cancel = Arc::new(AtomicBool::new(false));
        let result = write_image(
            img.to_str().unwrap(),
            r"\\.\PhysicalDrive999",
            512,
            &tx,
            &cancel,
        );

        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("PhysicalDrive999")
                || msg.contains("Access denied")
                || msg.contains("Cannot open"),
            "error must be descriptive: {msg}"
        );

        let _ = std::fs::remove_file(&img);
    }

    #[test]
    fn flash_stage_progress_floor_syncing() {
        assert!((FlashStage::Syncing.progress_floor() - 0.80).abs() < f32::EPSILON);
    }

    #[test]
    fn flash_stage_progress_floor_rereading() {
        assert!((FlashStage::Rereading.progress_floor() - 0.88).abs() < f32::EPSILON);
    }

    #[test]
    fn flash_stage_progress_floor_verifying() {
        assert!((FlashStage::Verifying.progress_floor() - 0.92).abs() < f32::EPSILON);
    }

    #[test]
    fn flash_stage_progress_floor_other_stages_are_zero() {
        for stage in [
            FlashStage::Starting,
            FlashStage::Unmounting,
            FlashStage::Writing,
            FlashStage::Done,
        ] {
            assert_eq!(
                stage.progress_floor(),
                0.0,
                "{stage:?} should have floor 0.0"
            );
        }
    }

    #[test]
    fn verify_overall_image_phase_start() {
        assert_eq!(verify_overall_progress("image", 0.0), 0.0);
    }

    #[test]
    fn verify_overall_image_phase_end() {
        assert!((verify_overall_progress("image", 1.0) - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn verify_overall_image_phase_midpoint() {
        assert!((verify_overall_progress("image", 0.5) - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn verify_overall_device_phase_start() {
        assert!((verify_overall_progress("device", 0.0) - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn verify_overall_device_phase_end() {
        assert!((verify_overall_progress("device", 1.0) - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn verify_overall_device_phase_midpoint() {
        assert!((verify_overall_progress("device", 0.5) - 0.75).abs() < f32::EPSILON);
    }

    #[test]
    fn verify_overall_unknown_phase_treated_as_device() {
        assert!((verify_overall_progress("other", 0.0) - 0.5).abs() < f32::EPSILON);
    }
}