#[cfg(unix)]
use nix::libc;
use std::io::{self, Read, Write};
use std::path::Path;
use std::sync::{
atomic::{AtomicBool, Ordering},
mpsc, Arc, OnceLock,
};
use std::time::{Duration, Instant};
/// Chunk size (4 MiB) used for image reads, device writes and hashing reads.
const BLOCK_SIZE: usize = 4 * 1024 * 1024;
/// Minimum time between two consecutive progress events of the same kind.
const PROGRESS_INTERVAL: Duration = Duration::from_millis(400);
/// Real (invoking) user id recorded at startup; read back by `real_uid`.
static REAL_UID: OnceLock<u32> = OnceLock::new();
/// Records the real user id to return to after temporary privilege escalation.
///
/// Only the first call has an effect; later calls are ignored.
pub fn set_real_uid(uid: u32) {
    let _ = REAL_UID.get_or_init(|| uid);
}
/// Reports whether the process currently runs with an effective UID of root.
///
/// Always `false` on non-Unix targets.
pub fn is_privileged() -> bool {
    #[cfg(unix)]
    fn euid_is_root() -> bool {
        nix::unistd::geteuid().is_root()
    }
    #[cfg(not(unix))]
    fn euid_is_root() -> bool {
        false
    }
    euid_is_root()
}
/// Attempts to re-launch the current executable with root privileges.
///
/// Does nothing when a test harness is detected (`is_running_under_test_harness`)
/// or in `cfg(test)` builds. Otherwise delegates to `reexec_as_root_inner`.
#[cfg(unix)]
pub fn reexec_as_root() {
    if is_running_under_test_harness() || cfg!(test) {
        return;
    }
    #[cfg(not(test))]
    reexec_as_root_inner();
}
/// Heuristically detects that the binary is running as part of a test run.
///
/// True when `FLASHKRAFT_NO_REEXEC` or `NEXTEST_TEST_FILTER` is set (to a
/// Unicode value), or when the executable path contains a `deps` directory
/// component (where Cargo typically places test binaries).
#[cfg(unix)]
fn is_running_under_test_harness() -> bool {
    let env_opt_out = ["FLASHKRAFT_NO_REEXEC", "NEXTEST_TEST_FILTER"]
        .iter()
        .any(|key| std::env::var(key).is_ok());
    if env_opt_out {
        return true;
    }
    match std::env::current_exe() {
        Ok(exe) => {
            let exe = exe.to_string_lossy();
            exe.contains("/deps/") || exe.contains("\\deps\\")
        }
        Err(_) => false,
    }
}
/// Re-executes this program with root privileges via `pkexec`, falling back to
/// `sudo -E`, forwarding the original command-line arguments.
///
/// Returns immediately when `FLASHKRAFT_ESCALATED=1` is already set (loop
/// guard) or when the executable path cannot be determined. A successful
/// `execvp` replaces the process, so this function does not return in that
/// case. If no launcher is found, or both `execvp` calls fail, the guard
/// variable is removed again and the caller continues unprivileged.
///
/// The executable path and the arguments are forwarded as raw OS bytes
/// (`args_os` + `OsStrExt`). `std::env::args` would panic on a non-UTF-8
/// argument, such as an image path with a Latin-1 file name.
#[cfg(all(unix, not(test)))]
fn reexec_as_root_inner() {
    use std::ffi::{CString, OsStr};
    use std::os::unix::ffi::OsStrExt;

    // (launcher binary, flags placed before the program path), tried in order.
    const LAUNCHERS: [(&str, &[&str]); 2] = [("pkexec", &[]), ("sudo", &["-E"])];

    if std::env::var("FLASHKRAFT_ESCALATED").as_deref() == Ok("1") {
        return;
    }
    let self_exe = match std::fs::read_link("/proc/self/exe").or_else(|_| std::env::current_exe()) {
        Ok(p) => p,
        Err(_) => return,
    };

    // Unix OS strings cannot contain NUL, but sanitise defensively like `unix_c_str`.
    let os_to_c = |s: &OsStr| -> CString {
        let bytes: Vec<u8> = s
            .as_bytes()
            .iter()
            .map(|&b| if b == 0 { b'?' } else { b })
            .collect();
        CString::new(bytes).unwrap_or_else(|_| unix_c_str("?"))
    };
    let forwarded: Vec<CString> = std::iter::once(os_to_c(self_exe.as_os_str()))
        .chain(std::env::args_os().skip(1).map(|arg| os_to_c(&arg)))
        .collect();

    // Loop guard for the re-executed process.
    // NOTE(review): pkexec is believed to sanitise the environment, so this
    // variable may only reach the child on the `sudo -E` path — confirm.
    std::env::set_var("FLASHKRAFT_ESCALATED", "1");
    for (launcher, flags) in LAUNCHERS {
        if !unix_which_exists(launcher) {
            continue;
        }
        let program = unix_c_str(launcher);
        let mut argv: Vec<CString> = Vec::with_capacity(1 + flags.len() + forwarded.len());
        argv.push(program.clone());
        argv.extend(flags.iter().map(|flag| unix_c_str(flag)));
        argv.extend(forwarded.iter().cloned());
        // `execvp` only returns on failure; fall through to the next launcher.
        let _ = nix::unistd::execvp(&program, &argv);
    }
    std::env::remove_var("FLASHKRAFT_ESCALATED");
}
/// Non-Unix stand-in for the privilege re-exec: intentionally does nothing.
#[cfg(not(unix))]
pub fn reexec_as_root() {
    // Re-exec escalation (pkexec / sudo) only exists on Unix.
}
/// Returns `true` if some `PATH` entry contains a regular file named `name`
/// with at least one execute bit set. An unset or non-Unicode `PATH` yields `false`.
#[cfg(all(unix, not(test)))]
fn unix_which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;
    let Ok(search_path) = std::env::var("PATH") else {
        return false;
    };
    search_path
        .split(':')
        .map(|dir| std::path::Path::new(dir).join(name))
        .any(|candidate| {
            matches!(
                std::fs::metadata(&candidate),
                Ok(meta) if meta.is_file() && meta.permissions().mode() & 0o111 != 0
            )
        })
}
/// Converts `s` into a `CString`, replacing any interior NUL byte with `?`
/// so the conversion cannot fail.
#[cfg(all(unix, not(test)))]
fn unix_c_str(s: &str) -> std::ffi::CString {
    let without_nuls = s.replace('\0', "?");
    std::ffi::CString::new(without_nuls).unwrap_or_else(|_| std::ffi::CString::new("?").unwrap())
}
/// The UID to drop back to after temporary escalation.
///
/// Uses the value stored by `set_real_uid`, or the process's real UID
/// (`getuid`) when none was recorded.
#[cfg(unix)]
fn real_uid() -> nix::unistd::Uid {
    let raw = match REAL_UID.get() {
        Some(&recorded) => recorded,
        None => nix::unistd::getuid().as_raw(),
    };
    nix::unistd::Uid::from_raw(raw)
}
/// Coarse stage of a flash operation, sent as `FlashEvent::Stage`.
///
/// `Display` renders a short human-readable label for each stage, and
/// `progress_floor` gives a per-stage minimum progress value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FlashStage {
    /// Initial stage (not emitted by `flash_pipeline` itself).
    Starting,
    /// Unmounting any mounted partitions of the target device.
    Unmounting,
    /// Copying the image onto the device.
    Writing,
    /// Flushing write-back caches.
    Syncing,
    /// Asking the OS to re-read the partition table.
    Rereading,
    /// Hashing image and device to compare them.
    Verifying,
    /// Everything finished.
    Done,
    /// Failure, shown as `Failed: <message>`.
    Failed(String),
}
impl std::fmt::Display for FlashStage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FlashStage::Starting => write!(f, "Starting…"),
FlashStage::Unmounting => write!(f, "Unmounting partitions…"),
FlashStage::Writing => write!(f, "Writing image to device…"),
FlashStage::Syncing => write!(f, "Flushing write buffers…"),
FlashStage::Rereading => write!(f, "Refreshing partition table…"),
FlashStage::Verifying => write!(f, "Verifying written data…"),
FlashStage::Done => write!(f, "Flash complete!"),
FlashStage::Failed(m) => write!(f, "Failed: {m}"),
}
}
}
impl FlashStage {
pub fn progress_floor(&self) -> f32 {
match self {
FlashStage::Syncing => 0.80,
FlashStage::Rereading => 0.88,
FlashStage::Verifying => 0.92,
_ => 0.0,
}
}
}
/// Maps progress within one verification pass onto the whole verification stage.
///
/// The `"image"` pass covers the first half (0.0–0.5). Any other phase covers
/// the second half (0.5–1.0).
pub fn verify_overall_progress(phase: &str, pass_fraction: f32) -> f32 {
    const HALF: f32 = 0.5;
    match phase {
        "image" => pass_fraction * HALF,
        _ => HALF + pass_fraction * HALF,
    }
}
/// Raw events emitted by the flash pipeline over its `mpsc` channel.
#[derive(Debug, Clone)]
pub enum FlashEvent {
    /// The pipeline entered a new stage.
    Stage(FlashStage),
    /// Write progress: bytes copied so far out of `total_bytes` (the image size).
    /// `speed_mb_s` is the average rate since the write began, in MiB/s.
    Progress {
        bytes_written: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Progress of one verification hashing pass. The pipeline uses the phases
    /// `"image"` and `"device"`. `speed_mb_s` is the average rate of this pass, in MiB/s.
    VerifyProgress {
        phase: &'static str,
        bytes_read: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Free-form informational or warning line.
    Log(String),
    /// The whole pipeline completed successfully.
    Done,
    /// The pipeline aborted with this message.
    Error(String),
}
/// UI-facing form of a `FlashEvent`, produced by the `From<FlashEvent>` impl.
#[derive(Debug, Clone)]
pub enum FlashUpdate {
    /// Write progress; `progress` is `bytes_written / total_bytes` clamped to 0.0–1.0.
    Progress {
        progress: f32,
        bytes_written: u64,
        speed_mb_s: f32,
    },
    /// Verification progress; `overall` spans both passes (see `verify_overall_progress`).
    VerifyProgress {
        phase: &'static str,
        overall: f32,
        bytes_read: u64,
        total_bytes: u64,
        speed_mb_s: f32,
    },
    /// Stage label or log line to display.
    Message(String),
    /// The flash finished successfully.
    Completed,
    /// The flash failed with this message.
    Failed(String),
}
impl From<FlashEvent> for FlashUpdate {
fn from(event: FlashEvent) -> Self {
match event {
FlashEvent::Progress {
bytes_written,
total_bytes,
speed_mb_s,
} => {
let progress = if total_bytes > 0 {
(bytes_written as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
} else {
0.0
};
FlashUpdate::Progress {
progress,
bytes_written,
speed_mb_s,
}
}
FlashEvent::VerifyProgress {
phase,
bytes_read,
total_bytes,
speed_mb_s,
} => {
let pass_fraction = if total_bytes > 0 {
(bytes_read as f64 / total_bytes as f64).clamp(0.0, 1.0) as f32
} else {
0.0
};
let overall = verify_overall_progress(phase, pass_fraction);
FlashUpdate::VerifyProgress {
phase,
overall,
bytes_read,
total_bytes,
speed_mb_s,
}
}
FlashEvent::Stage(stage) => FlashUpdate::Message(stage.to_string()),
FlashEvent::Log(msg) => FlashUpdate::Message(msg),
FlashEvent::Done => FlashUpdate::Completed,
FlashEvent::Error(e) => FlashUpdate::Failed(e),
}
}
}
pub fn run_pipeline(
image_path: &str,
device_path: &str,
tx: mpsc::Sender<FlashEvent>,
cancel: Arc<AtomicBool>,
) {
if let Err(e) = flash_pipeline(image_path, device_path, &tx, cancel) {
let _ = tx.send(FlashEvent::Error(e));
}
}
/// Sends `event`, silently dropping it if the receiving side has gone away.
fn send(tx: &mpsc::Sender<FlashEvent>, event: FlashEvent) {
    tx.send(event).ok();
}
fn flash_pipeline(
image_path: &str,
device_path: &str,
tx: &mpsc::Sender<FlashEvent>,
cancel: Arc<AtomicBool>,
) -> Result<(), String> {
if !Path::new(image_path).is_file() {
return Err(format!("Image file not found: {image_path}"));
}
if !Path::new(device_path).exists() {
return Err(format!("Target device not found: {device_path}"));
}
#[cfg(target_os = "linux")]
reject_partition_node(device_path)?;
let image_size = std::fs::metadata(image_path)
.map_err(|e| format!("Cannot stat image: {e}"))?
.len();
if image_size == 0 {
return Err("Image file is empty".to_string());
}
send(tx, FlashEvent::Stage(FlashStage::Unmounting));
unmount_device(device_path, tx);
#[cfg(target_os = "linux")]
check_device_not_busy(device_path)?;
send(tx, FlashEvent::Stage(FlashStage::Writing));
send(
tx,
FlashEvent::Log(format!(
"Writing {image_size} bytes from {image_path} → {device_path}"
)),
);
write_image(image_path, device_path, image_size, tx, &cancel)?;
send(tx, FlashEvent::Stage(FlashStage::Syncing));
sync_device(device_path, tx);
send(tx, FlashEvent::Stage(FlashStage::Rereading));
reread_partition_table(device_path, tx);
send(tx, FlashEvent::Stage(FlashStage::Verifying));
verify(image_path, device_path, image_size, tx)?;
send(tx, FlashEvent::Done);
Ok(())
}
/// Fails if `device_path` cannot be opened exclusively because it is busy.
///
/// Probes by opening the device read-only with `O_EXCL`. The probe handle is
/// closed immediately; only `EBUSY` is treated as an error.
#[cfg(target_os = "linux")]
fn check_device_not_busy(device_path: &str) -> Result<(), String> {
    use std::os::unix::fs::OpenOptionsExt;
    let exclusive_probe = |path: &str| -> std::io::Result<()> {
        let mut options = std::fs::OpenOptions::new();
        options.read(true).custom_flags(libc::O_EXCL);
        options.open(path)?;
        Ok(())
    };
    check_device_not_busy_with(device_path, exclusive_probe)
}
/// Runs `open_fn` on `device_path` and converts an `EBUSY` failure into a
/// user-facing error. Every other outcome, including other errors, is `Ok`.
#[cfg(target_os = "linux")]
fn check_device_not_busy_with<F>(device_path: &str, open_fn: F) -> Result<(), String>
where
    F: FnOnce(&str) -> std::io::Result<()>,
{
    match open_fn(device_path) {
        Err(e) if e.raw_os_error() == Some(libc::EBUSY) => Err(format!(
            "Device '{device_path}' is already in use by another process.\nIs another flash operation already running?"
        )),
        _ => Ok(()),
    }
}
/// Refuses a `device_path` that names a partition rather than a whole disk.
///
/// The path is canonicalised first, so symlinks such as `/dev/disk/by-id/…`
/// are judged by the kernel device name they resolve to. If
/// `/sys/class/block/<name>` exists, the presence of its `partition`
/// attribute decides. Otherwise a name-based heuristic is used: `<disk>p<N>`
/// where `<disk>` ends in a digit, or `sdX<N>` / `hdX<N>` / `vdX<N>` / `xvdX<N>`.
///
/// Whole disks whose names end in a digit (`nvme0n1`, `mmcblk0`, `loop0`,
/// `dm-0`) are accepted. The previous heuristic rejected them.
///
/// # Errors
/// Returns a message starting with "Refusing to write to partition node" and,
/// when it can be determined, suggests the parent whole-disk device.
#[cfg(target_os = "linux")]
fn reject_partition_node(device_path: &str) -> Result<(), String> {
    let resolved = std::fs::canonicalize(device_path)
        .unwrap_or_else(|_| std::path::PathBuf::from(device_path));
    let dev_name = resolved
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_default();
    if dev_name.is_empty() {
        return Ok(());
    }

    let sysfs_entry = Path::new("/sys/class/block").join(&dev_name);
    let (is_partition, whole_disk) = if sysfs_entry.exists() {
        // Authoritative: the kernel exposes a `partition` attribute only on partitions.
        if sysfs_entry.join("partition").exists() {
            let parent = sysfs_partition_parent(&sysfs_entry)
                .or_else(|| partition_parent_by_name(&dev_name));
            (true, parent)
        } else {
            (false, None)
        }
    } else {
        let parent = partition_parent_by_name(&dev_name);
        (parent.is_some(), parent)
    };

    if !is_partition {
        return Ok(());
    }
    let hint = match whole_disk {
        Some(whole) => format!("Select the whole-disk device (e.g. /dev/{whole}) instead."),
        None => "Select the whole-disk device instead.".to_string(),
    };
    Err(format!(
        "Refusing to write to partition node '{device_path}'. {hint}"
    ))
}

/// Parent disk name of a partition, taken from its sysfs location
/// (`…/block/<disk>/<partition>`).
#[cfg(target_os = "linux")]
fn sysfs_partition_parent(sysfs_entry: &Path) -> Option<String> {
    let real = std::fs::canonicalize(sysfs_entry).ok()?;
    let parent = real.parent()?.file_name()?;
    Some(parent.to_string_lossy().into_owned())
}

/// Name-only fallback. Returns the whole-disk name when `name` looks like a
/// partition, otherwise `None`.
#[cfg(target_os = "linux")]
fn partition_parent_by_name(name: &str) -> Option<String> {
    let stem = name.trim_end_matches(|c: char| c.is_ascii_digit());
    if stem.is_empty() || stem.len() == name.len() {
        // No trailing partition number (or nothing but digits).
        return None;
    }
    // nvme0n1p1, mmcblk0p2, loop0p1, md127p1: "<disk ending in a digit>p<N>".
    if let Some(whole) = stem.strip_suffix('p') {
        if whole.ends_with(|c: char| c.is_ascii_digit()) {
            return Some(whole.to_string());
        }
    }
    // sda1, hdb2, vdc3, xvda1: families whose whole-disk names end in letters.
    const LETTER_NAMED_DISKS: [&str; 4] = ["sd", "hd", "vd", "xvd"];
    let letter_named = LETTER_NAMED_DISKS.iter().any(|prefix| {
        matches!(
            stem.strip_prefix(*prefix),
            Some(rest) if !rest.is_empty() && rest.bytes().all(|b| b.is_ascii_lowercase())
        )
    });
    letter_named.then(|| stem.to_string())
}
/// Opens `device_path` write-only for raw image writing.
///
/// On Unix the effective UID is raised to 0 (`seteuid(0)`) just for the
/// `open` call, then restored to `real_uid()`. This lets a setuid-root binary
/// open block devices while otherwise running as the invoking user. Elsewhere
/// the device is opened directly.
///
/// # Errors
/// Returns a user-facing message whose wording depends on the OS error:
/// - permission denied (`EACCES`/`EPERM`, or Windows 5/1314): text differs
///   by whether escalation succeeded;
/// - busy (`EBUSY`, or Windows 32 sharing violation);
/// - anything else: a generic "Cannot open device" message.
fn open_device_for_writing(device_path: &str) -> Result<std::fs::File, String> {
    #[cfg(unix)]
    {
        use nix::unistd::seteuid;
        // `escalated` records whether the temporary switch to euid 0 worked;
        // it selects the error text and whether euid must be restored below.
        let escalated = seteuid(nix::unistd::Uid::from_raw(0)).is_ok();
        let result = std::fs::OpenOptions::new()
            .write(true)
            .open(device_path)
            .map_err(|e| {
                let raw = e.raw_os_error().unwrap_or(0);
                if raw == libc::EACCES || raw == libc::EPERM {
                    if escalated {
                        format!(
                            "Permission denied opening '{device_path}'.\n\
                            Even with setuid-root the device refused access — \
                            check that the device exists and is not in use."
                        )
                    } else {
                        format!(
                            "Permission denied opening '{device_path}'.\n\
                            FlashKraft needs root access to write to block devices.\n\
                            Install setuid-root so it can escalate automatically:\n\
                            sudo chown root:root /usr/bin/flashkraft\n\
                            sudo chmod u+s /usr/bin/flashkraft"
                        )
                    }
                } else if raw == libc::EBUSY {
                    format!(
                        "Device '{device_path}' is busy. \
                        Ensure all partitions are unmounted before flashing."
                    )
                } else {
                    format!("Cannot open device '{device_path}' for writing: {e}")
                }
            });
        // Drop back to the real user as soon as the descriptor exists; the
        // already-open file keeps working without privileges.
        // NOTE(review): a failure to restore is ignored, which would leave the
        // process running with euid 0 — consider at least logging it.
        if escalated {
            let _ = seteuid(real_uid());
        }
        result
    }
    #[cfg(not(unix))]
    {
        // Windows error codes checked below:
        // 5 = ERROR_ACCESS_DENIED, 1314 = ERROR_PRIVILEGE_NOT_HELD,
        // 32 = ERROR_SHARING_VIOLATION.
        std::fs::OpenOptions::new()
            .write(true)
            .open(device_path)
            .map_err(|e| {
                let raw = e.raw_os_error().unwrap_or(0);
                if raw == 5 || raw == 1314 {
                    format!(
                        "Access denied opening '{device_path}'.\n\
                        FlashKraft must be run as Administrator on Windows.\n\
                        Right-click the application and choose \
                        'Run as administrator'."
                    )
                } else if raw == 32 {
                    format!(
                        "Device '{device_path}' is in use by another process.\n\
                        Close any applications using the drive and try again."
                    )
                } else {
                    format!("Cannot open device '{device_path}' for writing: {e}")
                }
            })
    }
}
/// Unmounts every mount of `device_path` and of its partitions, logging each step.
///
/// This never fails; problems with individual unmounts are only logged by `do_unmount`.
fn unmount_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let device_name: String = match Path::new(device_path).file_name() {
        Some(name) => name.to_string_lossy().to_string(),
        None => String::new(),
    };
    let mounted = find_mounted_partitions(&device_name, device_path);
    if mounted.is_empty() {
        send(tx, FlashEvent::Log("No mounted partitions found".into()));
        return;
    }
    for target in &mounted {
        send(tx, FlashEvent::Log(format!("Unmounting {target}")));
        do_unmount(target, tx);
    }
}
/// Lists the mount points of `device_path` itself and of its partitions.
///
/// On non-Windows targets this parses `/proc/mounts`, falling back to
/// `/proc/self/mounts`; an unreadable table yields an empty list. The kernel
/// octal-escapes whitespace and backslashes in that file, e.g. `/media/u/My
/// Disk` appears as `/media/u/My\040Disk`. Both fields are therefore
/// unescaped; otherwise the returned path would not exist and the unmount
/// would fail. The file is read as bytes, so one non-UTF-8 entry no longer
/// discards the whole table (such entries are converted lossily).
///
/// On Windows, the drive-letter volumes on the physical drive are returned
/// instead (`windows::find_volumes_on_physical_drive`).
fn find_mounted_partitions(
    #[cfg_attr(target_os = "windows", allow(unused_variables))] device_name: &str,
    device_path: &str,
) -> Vec<String> {
    #[cfg(not(target_os = "windows"))]
    {
        let raw = std::fs::read("/proc/mounts")
            .or_else(|_| std::fs::read("/proc/self/mounts"))
            .unwrap_or_default();
        let mounts = String::from_utf8_lossy(&raw);
        mounts
            .lines()
            .filter_map(|line| {
                let mut fields = line.split_whitespace();
                let dev = unescape_mount_field(fields.next()?);
                let mount_point = fields.next()?;
                (dev == device_path || is_partition_of(&dev, device_name))
                    .then(|| unescape_mount_field(mount_point))
            })
            .collect()
    }
    #[cfg(target_os = "windows")]
    {
        windows::find_volumes_on_physical_drive(device_path)
    }
}

/// Decodes the three-digit `\NNN` octal escapes used in `/proc/mounts` fields.
/// Malformed or out-of-range escapes are kept literally.
#[cfg(not(target_os = "windows"))]
fn unescape_mount_field(field: &str) -> String {
    let octal = |digits: &[u8]| -> Option<u8> {
        digits
            .iter()
            .try_fold(0u16, |acc, &d| match d {
                b'0'..=b'7' => Some(acc * 8 + u16::from(d - b'0')),
                _ => None,
            })
            .and_then(|value| u8::try_from(value).ok())
    };
    let bytes = field.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'\\' {
            if let Some(decoded) = bytes.get(i + 1..i + 4).and_then(octal) {
                out.push(decoded);
                i += 4;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
/// Returns `true` if the device node `dev` is a partition of the disk named
/// `device_name`.
///
/// The kernel names partitions in two ways:
/// - disks ending in a letter (`sda`) get a plain number suffix (`sda1`);
/// - disks ending in a digit (`nvme0n1`, `mmcblk0`, `loop1`) get `p<N>`
///   (`nvme0n1p1`).
///
/// The whole suffix must match that shape. A bare prefix match would wrongly
/// treat `nvme0n10p1` as a partition of `nvme0n1`, or `mmcblk10p1` as one of
/// `mmcblk1`, and unmount another disk's partitions.
#[cfg(not(target_os = "windows"))]
fn is_partition_of(dev: &str, device_name: &str) -> bool {
    if device_name.is_empty() {
        return false;
    }
    let dev_base = Path::new(dev)
        .file_name()
        .map(|n| n.to_string_lossy())
        .unwrap_or_default();
    let Some(suffix) = dev_base.strip_prefix(device_name) else {
        return false;
    };
    let number = if device_name.ends_with(|c: char| c.is_ascii_digit()) {
        match suffix.strip_prefix('p') {
            Some(n) => n,
            None => return false,
        }
    } else {
        suffix
    };
    !number.is_empty() && number.bytes().all(|b| b.is_ascii_digit())
}
/// Returns `true` if `name` is a regular file with an execute bit in some
/// `PATH` directory. An unset `PATH` is treated as empty.
#[cfg(target_os = "linux")]
fn which_exists(name: &str) -> bool {
    use std::os::unix::fs::PermissionsExt;
    let path_var = std::env::var("PATH").unwrap_or_default();
    for dir in path_var.split(':') {
        let candidate = std::path::Path::new(dir).join(name);
        if let Ok(meta) = std::fs::metadata(&candidate) {
            if meta.is_file() && meta.permissions().mode() & 0o111 != 0 {
                return true;
            }
        }
    }
    false
}
fn do_unmount(partition: &str, tx: &mpsc::Sender<FlashEvent>) {
#[cfg(target_os = "linux")]
{
use nix::unistd::seteuid;
use std::ffi::CString;
if which_exists("udisksctl") {
let result = std::process::Command::new("udisksctl")
.args(["unmount", "--no-user-interaction", "-b", partition])
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn();
let udisks_ok = match result {
Ok(mut child) => {
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
loop {
match child.try_wait() {
Ok(Some(status)) => break status.success(),
Ok(None) if std::time::Instant::now() < deadline => {
std::thread::sleep(std::time::Duration::from_millis(100));
}
_ => {
let _ = child.kill();
send(
tx,
FlashEvent::Log(
"udisksctl timed out — falling back to umount2".into(),
),
);
break false;
}
}
}
}
Err(_) => false,
};
if udisks_ok {
send(
tx,
FlashEvent::Log(format!("Unmounted {partition} via udisksctl")),
);
return;
}
}
let _ = seteuid(nix::unistd::Uid::from_raw(0));
if let Ok(c_path) = CString::new(partition) {
let ret = unsafe { libc::umount2(c_path.as_ptr(), libc::MNT_DETACH) };
if ret != 0 {
let raw = std::io::Error::last_os_error().raw_os_error().unwrap_or(0);
match raw {
libc::EINVAL => {}
libc::ENOENT => {}
libc::EPERM => {}
_ => {
let err = std::io::Error::from_raw_os_error(raw);
send(
tx,
FlashEvent::Log(format!(
"Warning — could not unmount {partition}: {err}"
)),
);
}
}
}
}
let _ = seteuid(real_uid());
}
#[cfg(target_os = "macos")]
{
let out = std::process::Command::new("diskutil")
.args(["unmount", partition])
.output();
if let Ok(o) = out {
if !o.status.success() {
send(
tx,
FlashEvent::Log(format!("Warning — diskutil unmount {partition} failed")),
);
}
}
}
#[cfg(target_os = "windows")]
{
match windows::lock_and_dismount_volume(partition) {
Ok(()) => send(
tx,
FlashEvent::Log(format!("Dismounted volume {partition}")),
),
Err(e) => send(
tx,
FlashEvent::Log(format!("Warning — could not dismount {partition}: {e}")),
),
}
}
}
/// Copies the image file onto the device in `BLOCK_SIZE` chunks.
///
/// The cancel flag is polled before every chunk; once set, the copy stops with
/// an error and whatever was already written stays on the device. Progress is
/// emitted:
/// - at most every `PROGRESS_INTERVAL`;
/// - when `image_size` bytes have been written;
/// - once more after the final flush.
///
/// On Unix the device is `fsync`ed before returning.
///
/// # Errors
/// Returns a message when:
/// - the image cannot be opened or read;
/// - the device cannot be opened or written;
/// - the user cancels;
/// - the final flush or fsync fails.
fn write_image(
    image_path: &str,
    device_path: &str,
    image_size: u64,
    tx: &mpsc::Sender<FlashEvent>,
    cancel: &Arc<AtomicBool>,
) -> Result<(), String> {
    // Average MiB/s for `bytes` over `elapsed_s`; 0 for sub-millisecond spans.
    fn throughput(bytes: u64, elapsed_s: f32) -> f32 {
        if elapsed_s > 0.001 {
            (bytes as f32 / (1024.0 * 1024.0)) / elapsed_s
        } else {
            0.0
        }
    }

    let image = std::fs::File::open(image_path).map_err(|e| format!("Cannot open image: {e}"))?;
    let device = open_device_for_writing(device_path)?;
    let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, image);
    let mut writer = io::BufWriter::with_capacity(BLOCK_SIZE, device);
    let mut chunk = vec![0u8; BLOCK_SIZE];
    let mut written: u64 = 0;
    let start = Instant::now();
    let mut last_report = Instant::now();

    loop {
        if cancel.load(Ordering::SeqCst) {
            return Err("Flash operation cancelled by user".to_string());
        }
        let count = reader
            .read(&mut chunk)
            .map_err(|e| format!("Read error on image: {e}"))?;
        if count == 0 {
            break;
        }
        writer
            .write_all(&chunk[..count])
            .map_err(|e| format!("Write error on device: {e}"))?;
        written += count as u64;

        let now = Instant::now();
        let interval_elapsed = now.duration_since(last_report) >= PROGRESS_INTERVAL;
        if interval_elapsed || written >= image_size {
            let speed_mb_s = throughput(written, now.duration_since(start).as_secs_f32());
            send(
                tx,
                FlashEvent::Progress {
                    bytes_written: written,
                    total_bytes: image_size,
                    speed_mb_s,
                },
            );
            last_report = now;
        }
    }

    writer
        .flush()
        .map_err(|e| format!("Buffer flush error: {e}"))?;
    #[cfg_attr(not(unix), allow(unused_variables))]
    let device = writer
        .into_inner()
        .map_err(|e| format!("BufWriter error: {e}"))?;
    #[cfg(unix)]
    {
        use std::os::unix::io::AsRawFd;
        // SAFETY: the descriptor belongs to `device`, which is alive for the call.
        if unsafe { libc::fsync(device.as_raw_fd()) } != 0 {
            let err = std::io::Error::last_os_error();
            return Err(format!(
                "fsync failed on '{device_path}': {err} — data may not have been fully written to the device"
            ));
        }
    }

    let speed_mb_s = throughput(written, start.elapsed().as_secs_f32());
    send(
        tx,
        FlashEvent::Progress {
            bytes_written: written,
            total_bytes: image_size,
            speed_mb_s,
        },
    );
    send(tx, FlashEvent::Log("Image write complete".into()));
    Ok(())
}
fn sync_device(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
#[cfg(unix)]
if let Ok(f) = std::fs::OpenOptions::new().write(true).open(device_path) {
use std::os::unix::io::AsRawFd;
let fd = f.as_raw_fd();
#[cfg(target_os = "linux")]
unsafe {
libc::fdatasync(fd);
}
#[cfg(not(target_os = "linux"))]
unsafe {
libc::fsync(fd);
}
drop(f);
}
#[cfg(target_os = "linux")]
unsafe {
libc::sync();
}
#[cfg(target_os = "windows")]
{
match windows::flush_device_buffers(device_path) {
Ok(()) => {}
Err(e) => send(
tx,
FlashEvent::Log(format!(
"Warning — FlushFileBuffers on '{device_path}' failed: {e}"
)),
),
}
}
send(tx, FlashEvent::Log("Write-back caches flushed".into()));
}
#[cfg(target_os = "linux")]
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
use nix::ioctl_none;
use std::os::unix::io::AsRawFd;
ioctl_none!(blkrrpart, 0x12, 95);
std::thread::sleep(Duration::from_millis(500));
match std::fs::OpenOptions::new().write(true).open(device_path) {
Ok(f) => {
let result = unsafe { blkrrpart(f.as_raw_fd()) };
match result {
Ok(_) => send(
tx,
FlashEvent::Log("Kernel partition table refreshed".into()),
),
Err(e) => send(
tx,
FlashEvent::Log(format!(
"Warning — BLKRRPART ioctl failed \
(device may not be partitioned): {e}"
)),
),
}
}
Err(e) => send(
tx,
FlashEvent::Log(format!(
"Warning — could not open device for BLKRRPART: {e}"
)),
),
}
}
/// Requests a partition-table refresh on macOS via `diskutil`.
///
/// The command's outcome is ignored; the log line is sent unconditionally.
#[cfg(target_os = "macos")]
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let mut refresh = std::process::Command::new("diskutil");
    refresh.args(["rereadPartitionTable", device_path]);
    let _ = refresh.output();
    let notice = String::from("Partition table refresh requested (macOS)");
    send(tx, FlashEvent::Log(notice));
}
/// Refreshes the disk's partition information on Windows via
/// `IOCTL_DISK_UPDATE_PROPERTIES`, after a short pause. A failure is logged as a warning.
#[cfg(target_os = "windows")]
fn reread_partition_table(device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    std::thread::sleep(Duration::from_millis(500));
    let message = match windows::update_disk_properties(device_path) {
        Ok(()) => String::from("Partition table refreshed (IOCTL_DISK_UPDATE_PROPERTIES)"),
        Err(e) => format!("Warning — IOCTL_DISK_UPDATE_PROPERTIES failed: {e}"),
    };
    send(tx, FlashEvent::Log(message));
}
/// Fallback for platforms without a partition-table refresh: just logs that
/// the step is unsupported.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
fn reread_partition_table(_device_path: &str, tx: &mpsc::Sender<FlashEvent>) {
    let notice = String::from("Partition table refresh not supported on this platform");
    send(tx, FlashEvent::Log(notice));
}
/// Verifies the flash by comparing SHA-256 digests of the first `image_size`
/// bytes of the image and of the device. Each pass reports `VerifyProgress`
/// under the phases `"image"` and `"device"`.
///
/// # Errors
/// Returns a hashing error unchanged, or a "Verification failed" message
/// containing both digests when they differ.
fn verify(
    image_path: &str,
    device_path: &str,
    image_size: u64,
    tx: &mpsc::Sender<FlashEvent>,
) -> Result<(), String> {
    send(tx, FlashEvent::Log("Computing SHA-256 of source image".into()));
    let expected = sha256_with_progress(image_path, image_size, "image", tx)?;

    send(
        tx,
        FlashEvent::Log(format!(
            "Reading back {image_size} bytes from device for verification"
        )),
    );
    let actual = sha256_with_progress(device_path, image_size, "device", tx)?;

    if expected == actual {
        send(tx, FlashEvent::Log(format!("Verification passed ({expected})")));
        Ok(())
    } else {
        Err(format!(
            "Verification failed — data mismatch (image={expected} device={actual})"
        ))
    }
}
/// Computes the lowercase-hex SHA-256 of at most the first `max_bytes` bytes of `path`.
///
/// Reading stops early at end of file. `VerifyProgress` events tagged with
/// `phase` are emitted at most every `PROGRESS_INTERVAL`, plus once when
/// `max_bytes` bytes have been consumed.
///
/// NOTE(review): on Linux, reads of a block device just written by
/// `write_image` may be served from the page cache rather than the media.
/// Consider dropping cached pages (e.g. `posix_fadvise(POSIX_FADV_DONTNEED)`)
/// for a true read-back.
///
/// # Errors
/// Returns a message if `path` cannot be opened or a read fails.
fn sha256_with_progress(
    path: &str,
    max_bytes: u64,
    phase: &'static str,
    tx: &mpsc::Sender<FlashEvent>,
) -> Result<String, String> {
    use sha2::{Digest, Sha256};
    let file =
        std::fs::File::open(path).map_err(|e| format!("Cannot open {path} for hashing: {e}"))?;
    let mut hasher = Sha256::new();
    let mut reader = io::BufReader::with_capacity(BLOCK_SIZE, file);
    let mut buf = vec![0u8; BLOCK_SIZE];
    let mut remaining = max_bytes;
    let mut bytes_read: u64 = 0;
    let start = Instant::now();
    let mut last_report = Instant::now();
    while remaining > 0 {
        // Clamp in u64 before narrowing. `remaining as usize` would truncate on
        // 32-bit targets: a 4 GiB remainder becomes 0, the read returns 0 and
        // the loop ends, hashing almost nothing (and falsely "verifying").
        let to_read = remaining.min(buf.len() as u64) as usize;
        let n = reader
            .read(&mut buf[..to_read])
            .map_err(|e| format!("Read error while hashing {path}: {e}"))?;
        if n == 0 {
            break;
        }
        hasher.update(&buf[..n]);
        bytes_read += n as u64;
        remaining -= n as u64;
        let now = Instant::now();
        if now.duration_since(last_report) >= PROGRESS_INTERVAL || remaining == 0 {
            let elapsed_s = now.duration_since(start).as_secs_f32();
            let speed_mb_s = if elapsed_s > 0.001 {
                (bytes_read as f32 / (1024.0 * 1024.0)) / elapsed_s
            } else {
                0.0
            };
            send(
                tx,
                FlashEvent::VerifyProgress {
                    phase,
                    bytes_read,
                    total_bytes: max_bytes,
                    speed_mb_s,
                },
            );
            last_report = now;
        }
    }
    Ok(hasher
        .finalize()
        .iter()
        .map(|b| format!("{:02x}", b))
        .collect())
}
/// Test helper: SHA-256 of the first `max_bytes` bytes of `path`, with the
/// progress events sent to a receiver nobody reads.
#[cfg(test)]
fn sha256_first_n_bytes(path: &str, max_bytes: u64) -> Result<String, String> {
    let (sink, _keep_alive) = mpsc::channel::<FlashEvent>();
    sha256_with_progress(path, max_bytes, "image", &sink)
}
/// Win32 helpers for the Windows code paths: volume discovery, volume
/// lock/dismount, buffer flushing and partition-table refresh.
#[cfg(target_os = "windows")]
mod windows {
    use windows_sys::Win32::{
        Foundation::{
            CloseHandle, FALSE, GENERIC_READ, GENERIC_WRITE, HANDLE, INVALID_HANDLE_VALUE,
        },
        Storage::FileSystem::{
            CreateFileW, FlushFileBuffers, FILE_FLAG_WRITE_THROUGH, FILE_SHARE_READ,
            FILE_SHARE_WRITE, OPEN_EXISTING,
        },
        System::{
            Ioctl::{FSCTL_DISMOUNT_VOLUME, FSCTL_LOCK_VOLUME, IOCTL_DISK_UPDATE_PROPERTIES},
            IO::DeviceIoControl,
        },
    };

    /// Encodes `s` as UTF-16 with a trailing NUL, for the `…W` Win32 APIs.
    fn to_wide(s: &str) -> Vec<u16> {
        use std::os::windows::ffi::OsStrExt;
        std::ffi::OsStr::new(s)
            .encode_wide()
            .chain(std::iter::once(0))
            .collect()
    }

    /// Opens an existing device or volume path with the given access mask.
    ///
    /// Other openers may keep reading/writing (`FILE_SHARE_READ | FILE_SHARE_WRITE`),
    /// and `FILE_FLAG_WRITE_THROUGH` is requested. The caller owns the returned
    /// raw handle and must release it with `CloseHandle`.
    ///
    /// # Errors
    /// Returns a message with the OS error when `CreateFileW` fails.
    fn open_device_handle(path: &str, access: u32) -> Result<HANDLE, String> {
        let wide = to_wide(path);
        // SAFETY: `wide` is NUL-terminated and outlives the call; the remaining
        // pointer arguments are null.
        let handle = unsafe {
            CreateFileW(
                wide.as_ptr(),
                access,
                FILE_SHARE_READ | FILE_SHARE_WRITE,
                std::ptr::null(),
                OPEN_EXISTING,
                FILE_FLAG_WRITE_THROUGH,
                std::ptr::null_mut(),
            )
        };
        if handle == INVALID_HANDLE_VALUE {
            Err(format!(
                "Cannot open device '{}': {}",
                path,
                std::io::Error::last_os_error()
            ))
        } else {
            Ok(handle)
        }
    }

    /// Sends control `code` to `handle` with no input or output buffer.
    ///
    /// # Errors
    /// Returns the OS error text when `DeviceIoControl` reports failure.
    fn device_ioctl(handle: HANDLE, code: u32) -> Result<(), String> {
        let mut bytes_returned: u32 = 0;
        // SAFETY: both buffers are null with zero length, `bytes_returned` is a
        // valid out-pointer, and the final pointer argument is null.
        let ok = unsafe {
            DeviceIoControl(
                handle,
                code,
                std::ptr::null(), 0,
                std::ptr::null_mut(), 0,
                &mut bytes_returned,
                std::ptr::null_mut(), )
        };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    /// Returns `\\.\X:` paths of drive-letter volumes whose storage device
    /// number equals `N` in `physical_drive` (`\\.\PhysicalDriveN`, matched
    /// case-insensitively).
    ///
    /// Unparseable input maps to `u32::MAX` as the target number. Volumes that
    /// cannot be opened or queried are skipped. A failed or oversized
    /// `GetLogicalDriveStringsW` result yields an empty list.
    pub fn find_volumes_on_physical_drive(physical_drive: &str) -> Vec<String> {
        use windows_sys::Win32::{
            Storage::FileSystem::GetLogicalDriveStringsW,
            System::Ioctl::{IOCTL_STORAGE_GET_DEVICE_NUMBER, STORAGE_DEVICE_NUMBER},
        };
        let target_index: u32 = physical_drive
            .to_ascii_lowercase()
            .trim_start_matches(r"\\.\physicaldrive")
            .parse()
            .unwrap_or(u32::MAX);
        // The API fills `buf` with NUL-separated root strings ("C:\", "D:\", …).
        let mut buf = vec![0u16; 512];
        let len = unsafe { GetLogicalDriveStringsW(buf.len() as u32, buf.as_mut_ptr()) };
        // 0 = failure; a value larger than the buffer = buffer too small.
        if len == 0 || len > buf.len() as u32 {
            return Vec::new();
        }
        // Keep only the first character (the letter) and build a volume path.
        let drive_letters: Vec<String> = buf[..len as usize]
            .split(|&c| c == 0)
            .filter(|s| !s.is_empty())
            .map(|s| {
                let letter: String = std::char::from_u32(s[0] as u32)
                    .map(|c| c.to_string())
                    .unwrap_or_default();
                format!(r"\\.\{}:", letter)
            })
            .collect();
        let mut matching = Vec::new();
        for vol_path in &drive_letters {
            let wide = to_wide(vol_path);
            // SAFETY: `wide` is NUL-terminated and outlives the call.
            let handle = unsafe {
                CreateFileW(
                    wide.as_ptr(),
                    GENERIC_READ,
                    FILE_SHARE_READ | FILE_SHARE_WRITE,
                    std::ptr::null(),
                    OPEN_EXISTING,
                    0,
                    std::ptr::null_mut(),
                )
            };
            if handle == INVALID_HANDLE_VALUE {
                continue;
            }
            // Pre-filled with u32::MAX so an unfilled struct never matches a real index.
            let mut dev_num = STORAGE_DEVICE_NUMBER {
                DeviceType: 0,
                DeviceNumber: u32::MAX,
                PartitionNumber: 0,
            };
            let mut bytes_returned: u32 = 0;
            // SAFETY: the output pointer and size describe `dev_num`, which outlives the call.
            let ok = unsafe {
                DeviceIoControl(
                    handle,
                    IOCTL_STORAGE_GET_DEVICE_NUMBER,
                    std::ptr::null(),
                    0,
                    &mut dev_num as *mut _ as *mut _,
                    std::mem::size_of::<STORAGE_DEVICE_NUMBER>() as u32,
                    &mut bytes_returned,
                    std::ptr::null_mut(),
                )
            };
            unsafe { CloseHandle(handle) };
            if ok != FALSE && dev_num.DeviceNumber == target_index {
                matching.push(vol_path.clone());
            }
        }
        matching
    }

    /// Locks (best effort) and then dismounts the volume at `volume_path`.
    ///
    /// A lock failure is printed to stderr and the dismount is still attempted.
    /// The result is `Err` if either step failed, even when the dismount succeeded.
    ///
    /// NOTE(review): the handle — and with it presumably the volume lock — is
    /// closed before returning, so the lock is not held while the physical
    /// drive is written afterwards; confirm Windows cannot remount in between.
    pub fn lock_and_dismount_volume(volume_path: &str) -> Result<(), String> {
        let handle = open_device_handle(volume_path, GENERIC_READ | GENERIC_WRITE)?;
        let lock_result = device_ioctl(handle, FSCTL_LOCK_VOLUME);
        if let Err(ref e) = lock_result {
            eprintln!(
                "[flash] FSCTL_LOCK_VOLUME on '{volume_path}' failed ({e}); \
                attempting dismount anyway"
            );
        }
        let dismount_result = device_ioctl(handle, FSCTL_DISMOUNT_VOLUME);
        unsafe { CloseHandle(handle) };
        lock_result.and(dismount_result)
    }

    /// Opens the device for writing and calls `FlushFileBuffers` on it.
    ///
    /// # Errors
    /// Returns an open error, or the OS error text if the flush fails.
    pub fn flush_device_buffers(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_WRITE)?;
        let ok = unsafe { FlushFileBuffers(handle) };
        unsafe { CloseHandle(handle) };
        if ok == FALSE {
            Err(format!("{}", std::io::Error::last_os_error()))
        } else {
            Ok(())
        }
    }

    /// Issues `IOCTL_DISK_UPDATE_PROPERTIES` on the device so Windows refreshes
    /// its view of the disk's partition table.
    ///
    /// # Errors
    /// Returns an open error or the ioctl's OS error text.
    pub fn update_disk_properties(device_path: &str) -> Result<(), String> {
        let handle = open_device_handle(device_path, GENERIC_READ | GENERIC_WRITE)?;
        let result = device_ioctl(handle, IOCTL_DISK_UPDATE_PROPERTIES);
        unsafe { CloseHandle(handle) };
        result
    }

    // Smoke tests: encoding helpers plus "bad path → error, no panic" checks.
    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_to_wide_null_terminated() {
            let wide = to_wide("ABC");
            assert_eq!(wide.last(), Some(&0u16), "must be null-terminated");
            assert_eq!(&wide[..3], &[b'A' as u16, b'B' as u16, b'C' as u16]);
        }
        #[test]
        fn test_to_wide_empty() {
            let wide = to_wide("");
            assert_eq!(wide, vec![0u16]);
        }
        #[test]
        fn test_open_device_handle_bad_path_returns_error() {
            let result = open_device_handle(r"\\.\NonExistentDevice999", GENERIC_READ);
            assert!(result.is_err(), "expected error for nonexistent device");
        }
        #[test]
        fn test_flush_device_buffers_bad_path() {
            let result = flush_device_buffers(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }
        #[test]
        fn test_update_disk_properties_bad_path() {
            let result = update_disk_properties(r"\\.\PhysicalDrive999");
            assert!(result.is_err());
        }
        #[test]
        fn test_lock_and_dismount_bad_path() {
            let result = lock_and_dismount_volume(r"\\.\Z99:");
            assert!(result.is_err());
        }
        // Only checks that malformed input does not panic.
        #[test]
        fn test_find_volumes_bad_path_no_panic() {
            let result = find_volumes_on_physical_drive("not-a-valid-path");
            let _ = result;
        }
        #[test]
        fn test_find_volumes_nonexistent_drive_returns_empty() {
            let result = find_volumes_on_physical_drive(r"\\.\PhysicalDrive999");
            assert!(
                result.is_empty(),
                "expected no volumes for PhysicalDrive999"
            );
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use std::sync::mpsc;
fn make_channel() -> (mpsc::Sender<FlashEvent>, mpsc::Receiver<FlashEvent>) {
mpsc::channel()
}
fn drain(rx: &mpsc::Receiver<FlashEvent>) -> Vec<FlashEvent> {
let mut events = Vec::new();
while let Ok(e) = rx.try_recv() {
events.push(e);
}
events
}
fn has_stage(events: &[FlashEvent], stage: &FlashStage) -> bool {
events
.iter()
.any(|e| matches!(e, FlashEvent::Stage(s) if s == stage))
}
fn find_error(events: &[FlashEvent]) -> Option<&str> {
events.iter().find_map(|e| {
if let FlashEvent::Error(msg) = e {
Some(msg.as_str())
} else {
None
}
})
}
#[test]
fn test_is_privileged_returns_bool() {
let first = is_privileged();
let second = is_privileged();
assert_eq!(first, second, "is_privileged must be deterministic");
}
#[test]
fn test_reexec_as_root_does_not_panic_when_already_escalated() {
std::env::set_var("FLASHKRAFT_ESCALATED", "1");
reexec_as_root(); std::env::remove_var("FLASHKRAFT_ESCALATED");
}
#[test]
fn test_set_real_uid_stores_value() {
set_real_uid(1000);
}
#[test]
#[cfg(not(target_os = "windows"))]
fn test_is_partition_of_sda() {
assert!(is_partition_of("/dev/sda1", "sda"));
assert!(is_partition_of("/dev/sda2", "sda"));
assert!(!is_partition_of("/dev/sdb1", "sda"));
assert!(!is_partition_of("/dev/sda", "sda"));
}
#[test]
#[cfg(not(target_os = "windows"))]
fn test_is_partition_of_nvme() {
assert!(is_partition_of("/dev/nvme0n1p1", "nvme0n1"));
assert!(is_partition_of("/dev/nvme0n1p2", "nvme0n1"));
assert!(!is_partition_of("/dev/nvme0n1", "nvme0n1"));
}
#[test]
#[cfg(not(target_os = "windows"))]
fn test_is_partition_of_mmcblk() {
assert!(is_partition_of("/dev/mmcblk0p1", "mmcblk0"));
assert!(!is_partition_of("/dev/mmcblk0", "mmcblk0"));
}
#[test]
#[cfg(not(target_os = "windows"))]
fn test_is_partition_of_no_false_prefix_match() {
assert!(!is_partition_of("/dev/sda1", "sd"));
}
#[test]
#[cfg(target_os = "linux")]
fn test_reject_partition_node_sda1() {
let dir = std::env::temp_dir();
let img = dir.join("fk_reject_img.bin");
std::fs::write(&img, vec![0u8; 1024]).unwrap();
let result = reject_partition_node("/dev/sda1");
assert!(result.is_err());
assert!(result.unwrap_err().contains("Refusing"));
let _ = std::fs::remove_file(img);
}
#[test]
#[cfg(target_os = "linux")]
fn test_reject_partition_node_nvme() {
let result = reject_partition_node("/dev/nvme0n1p1");
assert!(result.is_err());
assert!(result.unwrap_err().contains("Refusing"));
}
#[test]
#[cfg(target_os = "linux")]
fn test_reject_partition_node_accepts_whole_disk() {
let result = reject_partition_node("/dev/sdb");
assert!(result.is_ok(), "whole-disk node should not be rejected");
}
#[test]
fn test_find_mounted_partitions_parses_proc_mounts_format() {
let result = find_mounted_partitions("sda", "/dev/sda");
let _ = result; }
#[test]
fn test_sha256_full_file() {
use sha2::{Digest, Sha256};
let dir = std::env::temp_dir();
let path = dir.join("fk_sha256_full.bin");
let data: Vec<u8> = (0u8..=255u8).cycle().take(4096).collect();
std::fs::write(&path, &data).unwrap();
let result = sha256_first_n_bytes(path.to_str().unwrap(), data.len() as u64).unwrap();
let expected: String = Sha256::digest(&data)
.iter()
.map(|b| format!("{:02x}", b))
.collect();
assert_eq!(result, expected);
let _ = std::fs::remove_file(path);
}
#[test]
fn test_sha256_partial() {
use sha2::{Digest, Sha256};
let dir = std::env::temp_dir();
let path = dir.join("fk_sha256_partial.bin");
let data: Vec<u8> = (0u8..=255u8).cycle().take(8192).collect();
std::fs::write(&path, &data).unwrap();
let n = 4096u64;
let result = sha256_first_n_bytes(path.to_str().unwrap(), n).unwrap();
let expected: String = Sha256::digest(&data[..n as usize])
.iter()
.map(|b| format!("{:02x}", b))
.collect();
assert_eq!(result, expected);
let _ = std::fs::remove_file(path);
}
#[test]
fn test_sha256_nonexistent_returns_error() {
let result = sha256_first_n_bytes("/nonexistent/path.bin", 1024);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Cannot open"));
}
#[test]
fn test_sha256_empty_read_is_hash_of_empty() {
use sha2::{Digest, Sha256};
let dir = std::env::temp_dir();
let path = dir.join("fk_sha256_empty.bin");
std::fs::write(&path, b"hello world extended data").unwrap();
let result = sha256_first_n_bytes(path.to_str().unwrap(), 0).unwrap();
let expected: String = Sha256::digest(b"")
.iter()
.map(|b| format!("{:02x}", b))
.collect();
assert_eq!(result, expected);
let _ = std::fs::remove_file(path);
}
#[test]
fn test_write_image_to_temp_file() {
let dir = std::env::temp_dir();
let img_path = dir.join("fk_write_img.bin");
let dev_path = dir.join("fk_write_dev.bin");
let image_size: u64 = 2 * 1024 * 1024; {
let mut f = std::fs::File::create(&img_path).unwrap();
let block: Vec<u8> = (0u8..=255u8).cycle().take(BLOCK_SIZE).collect();
let mut rem = image_size;
while rem > 0 {
let n = rem.min(BLOCK_SIZE as u64) as usize;
f.write_all(&block[..n]).unwrap();
rem -= n as u64;
}
}
std::fs::File::create(&dev_path).unwrap();
let (tx, rx) = make_channel();
let cancel = Arc::new(AtomicBool::new(false));
let result = write_image(
img_path.to_str().unwrap(),
dev_path.to_str().unwrap(),
image_size,
&tx,
&cancel,
);
assert!(result.is_ok(), "write_image failed: {result:?}");
let written = std::fs::read(&dev_path).unwrap();
let original = std::fs::read(&img_path).unwrap();
assert_eq!(written, original, "written data must match image exactly");
let events = drain(&rx);
let has_progress = events
.iter()
.any(|e| matches!(e, FlashEvent::Progress { .. }));
assert!(has_progress, "must emit at least one Progress event");
let _ = std::fs::remove_file(img_path);
let _ = std::fs::remove_file(dev_path);
}
/// With the cancel flag raised, `write_image` must abort and return an error
/// that mentions cancellation.
///
/// NOTE(review): despite the name, the flag is set *before* the call, so this
/// exercises cancellation at the very first check rather than part-way through.
#[test]
fn test_write_image_cancelled_mid_write() {
    let tmp = std::env::temp_dir();
    let img_path = tmp.join("fk_cancel_img.bin");
    let dev_path = tmp.join("fk_cancel_dev.bin");
    let image_size: u64 = 8 * 1024 * 1024;
    std::fs::write(&img_path, vec![0xAAu8; image_size as usize]).unwrap();
    std::fs::File::create(&dev_path).unwrap();

    // Keep the receiver alive for the duration of the write.
    let (tx, _rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(true));
    let result = write_image(
        img_path.to_str().unwrap(),
        dev_path.to_str().unwrap(),
        image_size,
        &tx,
        &cancel,
    );

    assert!(result.is_err());
    let message = result.unwrap_err();
    assert!(
        message.contains("cancelled"),
        "error should mention cancellation"
    );

    let _ = std::fs::remove_file(img_path);
    let _ = std::fs::remove_file(dev_path);
}
/// A nonexistent source image must be rejected with a "Cannot open image" error.
#[test]
fn test_write_image_missing_image_returns_error() {
    let dev_path = std::env::temp_dir().join("fk_noimg_dev.bin");
    std::fs::File::create(&dev_path).unwrap();

    let (tx, _rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(false));
    let outcome = write_image(
        "/nonexistent/image.img",
        dev_path.to_str().unwrap(),
        1024,
        &tx,
        &cancel,
    );

    assert!(outcome.is_err());
    let reason = outcome.unwrap_err();
    assert!(reason.contains("Cannot open image"));

    let _ = std::fs::remove_file(dev_path);
}
/// Byte-identical image and device contents must pass verification.
#[test]
fn test_verify_matching_files() {
    let tmp = std::env::temp_dir();
    let (img, dev) = (tmp.join("fk_verify_img.bin"), tmp.join("fk_verify_dev.bin"));
    let payload = vec![0xBBu8; 64 * 1024];
    for target in [&img, &dev] {
        std::fs::write(target, &payload).unwrap();
    }

    let (tx, _rx) = make_channel();
    let verdict = verify(
        img.to_str().unwrap(),
        dev.to_str().unwrap(),
        payload.len() as u64,
        &tx,
    );
    assert!(verdict.is_ok());

    let _ = std::fs::remove_file(img);
    let _ = std::fs::remove_file(dev);
}
/// Completely different image and device contents must fail verification
/// with a "Verification failed" error.
#[test]
fn test_verify_mismatch_returns_error() {
    let tmp = std::env::temp_dir();
    let img = tmp.join("fk_mismatch_img.bin");
    let dev = tmp.join("fk_mismatch_dev.bin");
    let size = 64 * 1024;
    std::fs::write(&img, vec![0x00u8; size]).unwrap();
    std::fs::write(&dev, vec![0xFFu8; size]).unwrap();

    let (tx, _rx) = make_channel();
    let verdict = verify(img.to_str().unwrap(), dev.to_str().unwrap(), size as u64, &tx);
    assert!(verdict.is_err());
    let reason = verdict.unwrap_err();
    assert!(reason.contains("Verification failed"));

    let _ = std::fs::remove_file(img);
    let _ = std::fs::remove_file(dev);
}
/// Verification compares only the first `image_size` bytes. Trailing device
/// bytes beyond the image (here 32 KiB of 0xDD) must not cause a failure.
#[test]
fn test_verify_only_checks_image_size_bytes() {
    let tmp = std::env::temp_dir();
    let img = tmp.join("fk_trunc_img.bin");
    let dev = tmp.join("fk_trunc_dev.bin");

    let image_data = vec![0xCCu8; 32 * 1024];
    let device_data: Vec<u8> = image_data
        .iter()
        .copied()
        .chain(std::iter::repeat(0xDDu8).take(32 * 1024))
        .collect();
    std::fs::write(&img, &image_data).unwrap();
    std::fs::write(&dev, &device_data).unwrap();

    let (tx, _rx) = make_channel();
    let result = verify(
        img.to_str().unwrap(),
        dev.to_str().unwrap(),
        image_data.len() as u64,
        &tx,
    );
    assert!(
        result.is_ok(),
        "should pass when first N bytes match: {result:?}"
    );

    let _ = std::fs::remove_file(img);
    let _ = std::fs::remove_file(dev);
}
/// A missing image path must produce an Error event mentioning
/// "Image file not found".
#[test]
fn test_pipeline_rejects_missing_image() {
    let (tx, rx) = make_channel();
    run_pipeline(
        "/nonexistent/image.iso",
        "/dev/null",
        tx,
        Arc::new(AtomicBool::new(false)),
    );
    let events = drain(&rx);
    let err = find_error(&events);
    match err {
        Some(text) => assert!(text.contains("Image file not found"), "err={err:?}"),
        None => panic!("must emit an Error event"),
    }
}
/// A zero-byte image must be rejected with an error mentioning "empty".
#[test]
fn test_pipeline_rejects_empty_image() {
    let empty = std::env::temp_dir().join("fk_empty.img");
    std::fs::write(&empty, b"").unwrap();

    let (tx, rx) = make_channel();
    run_pipeline(
        empty.to_str().unwrap(),
        "/dev/null",
        tx,
        Arc::new(AtomicBool::new(false)),
    );
    let events = drain(&rx);
    let err = find_error(&events);
    assert!(err.is_some());
    let text = err.unwrap();
    assert!(text.contains("empty"), "err={err:?}");

    let _ = std::fs::remove_file(empty);
}
/// A valid image aimed at a nonexistent device must produce an error
/// mentioning "Target device not found".
#[test]
fn test_pipeline_rejects_missing_device() {
    let img = std::env::temp_dir().join("fk_nodev_img.bin");
    std::fs::write(&img, vec![0u8; 1024]).unwrap();

    let (tx, rx) = make_channel();
    run_pipeline(
        img.to_str().unwrap(),
        "/nonexistent/device",
        tx,
        Arc::new(AtomicBool::new(false)),
    );
    let events = drain(&rx);
    let err = find_error(&events);
    assert!(err.is_some());
    let text = err.unwrap();
    assert!(text.contains("Target device not found"), "err={err:?}");

    let _ = std::fs::remove_file(img);
}
/// Runs the whole pipeline against a 1 MiB image and a plain-file "device".
///
/// The pipeline must emit Progress plus the Unmounting/Writing/Syncing stages
/// and finish with either Done or Error.
/// - On Done, the device file must equal the image.
/// - On Error, the message must not be one of the core write/verify failures.
#[test]
fn test_pipeline_end_to_end_temp_files() {
    let tmp = std::env::temp_dir();
    let img = tmp.join("fk_e2e_img.bin");
    let dev = tmp.join("fk_e2e_dev.bin");
    let image_data: Vec<u8> = (0u8..=255u8).cycle().take(1024 * 1024).collect();
    std::fs::write(&img, &image_data).unwrap();
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(false));
    run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
    let events = drain(&rx);

    let saw_progress = events
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Progress { .. }));
    assert!(saw_progress, "must emit Progress events");

    let required = [
        (FlashStage::Unmounting, "must emit Unmounting stage"),
        (FlashStage::Writing, "must emit Writing stage"),
        (FlashStage::Syncing, "must emit Syncing stage"),
    ];
    for (stage, complaint) in &required {
        assert!(has_stage(&events, stage), "{}", complaint);
    }

    let finished = events.iter().any(|ev| matches!(ev, FlashEvent::Done));
    let failed = events.iter().any(|ev| matches!(ev, FlashEvent::Error(_)));
    assert!(finished || failed, "pipeline must end with Done or Error");

    if finished {
        let written = std::fs::read(&dev).unwrap();
        assert_eq!(written, image_data, "written data must match image");
    } else if let Some(err_msg) = find_error(&events) {
        // Environmental failures are tolerated; core I/O/verify failures are not.
        let fatal_markers = ["Cannot open", "Verification failed", "Write error"];
        assert!(
            fatal_markers.iter().all(|&marker| !err_msg.contains(marker)),
            "unexpected error: {err_msg}"
        );
    }

    let _ = std::fs::remove_file(img);
    let _ = std::fs::remove_file(dev);
}
/// Each stage's `Display` text must contain its expected keyword.
/// `Failed` must include the wrapped reason.
#[test]
fn test_flash_stage_display() {
    let cases = [
        (FlashStage::Writing, "Writing"),
        (FlashStage::Syncing, "Flushing"),
        (FlashStage::Done, "complete"),
        (FlashStage::Failed("oops".into()), "oops"),
    ];
    for (stage, needle) in &cases {
        assert!(stage.to_string().contains(*needle));
    }
}
/// `FlashStage` equality is variant-wise, and `Failed` also compares its reason.
#[test]
fn test_flash_stage_eq() {
    let failed = |why: &'static str| FlashStage::Failed(why.into());
    assert_eq!(FlashStage::Writing, FlashStage::Writing);
    assert_ne!(FlashStage::Writing, FlashStage::Syncing);
    assert_eq!(failed("x"), failed("x"));
    assert_ne!(failed("x"), failed("y"));
}
/// Cloning any `FlashEvent` must yield the same variant with the same payload.
///
/// `FlashEvent` is not assumed to implement `PartialEq`. Variant identity is
/// checked with `std::mem::discriminant`, and payload equality via the `Debug`
/// rendering (which `FlashEvent` implements, as other tests format it with `{:?}`).
#[test]
fn test_flash_event_clone() {
    let events = vec![
        FlashEvent::Stage(FlashStage::Writing),
        FlashEvent::Progress {
            bytes_written: 1024,
            total_bytes: 4096,
            speed_mb_s: 12.5,
        },
        FlashEvent::Log("hello".into()),
        FlashEvent::Done,
        FlashEvent::Error("boom".into()),
    ];
    for original in &events {
        let copy = original.clone();
        assert_eq!(
            std::mem::discriminant(original),
            std::mem::discriminant(&copy),
            "clone must preserve the variant"
        );
        assert_eq!(
            format!("{original:?}"),
            format!("{copy:?}"),
            "clone must preserve the payload"
        );
    }
}
/// A device that does not exist cannot have mounted partitions, so the lookup
/// must return an empty collection on every platform. Previously the result
/// was discarded, so the test only proved "does not panic" despite its name.
#[test]
fn test_find_mounted_partitions_nonexistent_device_returns_empty() {
    #[cfg(target_os = "windows")]
    let result = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
    #[cfg(not(target_os = "windows"))]
    let result = find_mounted_partitions("sdzzz", "/dev/sdzzz");
    assert!(
        result.is_empty(),
        "a nonexistent device must not report any mounted partitions"
    );
}
/// Degenerate empty inputs must not panic. The result is deliberately ignored,
/// because an empty device name gives no meaningful expectation.
#[test]
fn test_find_mounted_partitions_empty_name_no_panic() {
    let _ = find_mounted_partitions("", "");
}
/// Windows physical-drive paths are never treated as partitions: not of
/// themselves, and not of a different drive.
#[test]
fn test_is_partition_of_windows_style_paths() {
    for candidate in [r"\\.\PhysicalDrive0", r"\\.\PhysicalDrive1"] {
        assert!(!is_partition_of(candidate, "PhysicalDrive0"));
    }
}
/// Runs the full pipeline on a throwaway image/device pair in the system temp
/// directory and evaluates to every event it emitted.
///
/// Both scratch files are removed before the events are returned.
macro_rules! pipeline_test_events {
    ($img_name:literal, $dev_name:literal, $data:expr) => {{
        let scratch = std::env::temp_dir();
        let (img, dev) = (scratch.join($img_name), scratch.join($dev_name));
        std::fs::write(&img, $data).unwrap();
        std::fs::File::create(&dev).unwrap();
        let (tx, rx) = make_channel();
        run_pipeline(
            img.to_str().unwrap(),
            dev.to_str().unwrap(),
            tx,
            Arc::new(AtomicBool::new(false)),
        );
        let collected = drain(&rx);
        for leftover in [&img, &dev] {
            let _ = std::fs::remove_file(leftover);
        }
        collected
    }};
}
/// Generates a `#[test]` that runs the pipeline over a 256 KiB patterned image
/// and asserts that `$stage` appears among the emitted events.
macro_rules! assert_pipeline_emits_stage {
    ($name:ident, $img:literal, $dev:literal, $stage:expr) => {
        #[test]
        fn $name() {
            let pattern: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
            let events = pipeline_test_events!($img, $dev, &pattern);
            let expected = $stage;
            assert!(
                has_stage(&events, &expected),
                "{} stage must be emitted on every platform",
                stringify!($stage)
            );
        }
    };
}
// One test per post-write stage, each with its own scratch files, so a
// failure names the missing stage directly and tests can run in parallel.
assert_pipeline_emits_stage!(
    test_pipeline_emits_syncing_stage,
    "fk_sync_stage_img.bin",
    "fk_sync_stage_dev.bin",
    FlashStage::Syncing
);
assert_pipeline_emits_stage!(
    test_pipeline_emits_rereading_stage,
    "fk_reread_stage_img.bin",
    "fk_reread_stage_dev.bin",
    FlashStage::Rereading
);
assert_pipeline_emits_stage!(
    test_pipeline_emits_verifying_stage,
    "fk_verify_stage_img.bin",
    "fk_verify_stage_dev.bin",
    FlashStage::Verifying
);
/// Writing to a nonexistent device must fail with a descriptive error.
///
/// The error must name the path or at least say "Cannot open".
#[test]
fn test_open_device_for_writing_nonexistent_mentions_path() {
    let bad = String::from(if cfg!(target_os = "windows") {
        r"\\.\PhysicalDrive999"
    } else {
        "/nonexistent/fk_bad_device"
    });
    let img = std::env::temp_dir().join("fk_open_err_img.bin");
    std::fs::write(&img, vec![1u8; 512]).unwrap();

    let (tx, _rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(false));
    let result = write_image(img.to_str().unwrap(), &bad, 512, &tx, &cancel);
    assert!(result.is_err(), "must fail for nonexistent device");

    let descriptive = {
        let msg = result.as_ref().unwrap_err();
        ["PhysicalDrive999", "fk_bad_device", "Cannot open"]
            .iter()
            .any(|&needle| msg.contains(needle))
    };
    assert!(
        descriptive,
        "error should reference the bad path: {:?}",
        result
    );

    let _ = std::fs::remove_file(&img);
}
/// `sync_device` must log something that mentions flushing or caches
/// (case-insensitive).
#[test]
fn test_sync_device_emits_log() {
    let dev = std::env::temp_dir().join("fk_sync_log_dev.bin");
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    sync_device(dev.to_str().unwrap(), &tx);
    let events = drain(&rx);

    let has_flush_log = events
        .iter()
        .filter_map(|ev| match ev {
            FlashEvent::Log(msg) => Some(msg.to_lowercase()),
            _ => None,
        })
        .any(|lower| lower.contains("flush") || lower.contains("cache"));
    assert!(
        has_flush_log,
        "sync_device must emit a flush/cache log event"
    );

    let _ = std::fs::remove_file(&dev);
}
/// `reread_partition_table` on a plain file must still report what it did
/// through at least one Log event.
#[test]
fn test_reread_partition_table_emits_log() {
    let dev = std::env::temp_dir().join("fk_reread_log_dev.bin");
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    reread_partition_table(dev.to_str().unwrap(), &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(
        logged,
        "reread_partition_table must emit at least one Log event"
    );

    let _ = std::fs::remove_file(&dev);
}
/// Unmounting a "device" with no mounted partitions must still emit a Log event.
#[test]
fn test_unmount_device_no_partitions_emits_log() {
    let dev = std::env::temp_dir().join("fk_unmount_log_dev.bin");
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    unmount_device(dev.to_str().unwrap(), &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(logged, "unmount_device must emit at least one Log event");

    let _ = std::fs::remove_file(&dev);
}
/// Every pipeline stage must be emitted in this order:
/// Unmounting → Writing → Syncing → Rereading → Verifying.
///
/// A stage that never appears fails the test explicitly. A `usize::MAX`
/// sentinel would make a missing *later* stage (e.g. Verifying) satisfy
/// `earlier < later` vacuously, and would produce misleading messages when
/// several stages are missing.
#[test]
fn test_pipeline_stage_ordering() {
    let dir = std::env::temp_dir();
    let img = dir.join("fk_order_img.bin");
    let dev = dir.join("fk_order_dev.bin");
    let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();
    std::fs::write(&img, &data).unwrap();
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(false));
    run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
    let events = drain(&rx);

    // Events are fully collected. Remove the scratch files before asserting so
    // a failure below does not leave them behind in the temp directory.
    let _ = std::fs::remove_file(&img);
    let _ = std::fs::remove_file(&dev);

    let stages: Vec<&FlashStage> = events
        .iter()
        .filter_map(|e| {
            if let FlashEvent::Stage(s) = e {
                Some(s)
            } else {
                None
            }
        })
        .collect();

    // Index of the first occurrence of `target`; panics if it never appears.
    let pos = |target: &FlashStage| -> usize {
        stages
            .iter()
            .position(|s| *s == target)
            .unwrap_or_else(|| panic!("stage {target:?} was never emitted; stages={stages:?}"))
    };
    let unmounting = pos(&FlashStage::Unmounting);
    let writing = pos(&FlashStage::Writing);
    let syncing = pos(&FlashStage::Syncing);
    let rereading = pos(&FlashStage::Rereading);
    let verifying = pos(&FlashStage::Verifying);

    assert!(unmounting < writing, "Unmounting must precede Writing");
    assert!(writing < syncing, "Writing must precede Syncing");
    assert!(syncing < rereading, "Syncing must precede Rereading");
    assert!(rereading < verifying, "Rereading must precede Verifying");
}
/// Looking up a conventional device name must not panic, whether or not it
/// exists on this machine. The result is deliberately ignored.
#[test]
#[cfg(target_os = "linux")]
fn test_find_mounted_partitions_linux_no_panic() {
    let _ = find_mounted_partitions("sda", "/dev/sda");
}
/// Takes the first `/dev/...` source listed in `/proc/mounts` (if any) and
/// checks that looking it up does not panic.
#[test]
#[cfg(target_os = "linux")]
fn test_find_mounted_partitions_linux_reads_proc_mounts() {
    let mounts = std::fs::read_to_string("/proc/mounts").unwrap_or_default();
    let first_dev = mounts
        .lines()
        .find(|line| line.starts_with("/dev/"))
        .and_then(|line| line.split_whitespace().next());
    let Some(dev) = first_dev else {
        return;
    };
    let name = std::path::Path::new(dev)
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_default();
    let _ = find_mounted_partitions(&name, dev);
}
/// Unmounting a path that is not mounted must stay silent on Linux: per the
/// assertion message, EINVAL/ENOENT are expected and must not be logged.
#[test]
#[cfg(target_os = "linux")]
fn test_do_unmount_not_mounted_does_not_panic() {
    let (tx, rx) = make_channel();
    do_unmount("/dev/fk_nonexistent_part", &tx);
    let events = drain(&rx);
    let logged_anything = events.iter().any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(
        !logged_anything,
        "do_unmount must not emit a warning for EINVAL/ENOENT: {events:?}"
    );
}
/// On macOS a failed unmount of a bogus path must be reported as a Log event.
#[test]
#[cfg(target_os = "macos")]
fn test_do_unmount_macos_bad_path_emits_warning() {
    let (tx, rx) = make_channel();
    do_unmount("/dev/fk_nonexistent_part", &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(logged, "do_unmount must emit a Log event on failure");
}
/// Looking up an arbitrary disk on macOS must not panic. The result is ignored.
#[test]
#[cfg(target_os = "macos")]
fn test_find_mounted_partitions_macos_no_panic() {
    let _ = find_mounted_partitions("disk2", "/dev/disk2");
}
/// On macOS, `reread_partition_table` on a plain file must still log.
#[test]
#[cfg(target_os = "macos")]
fn test_reread_partition_table_macos_emits_log() {
    let dev = std::env::temp_dir().join("fk_macos_reread_dev.bin");
    std::fs::File::create(&dev).unwrap();

    let (tx, rx) = make_channel();
    reread_partition_table(dev.to_str().unwrap(), &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(logged, "reread_partition_table must emit a log on macOS");

    let _ = std::fs::remove_file(&dev);
}
/// A physical drive that does not exist must map to no volumes.
#[test]
#[cfg(target_os = "windows")]
fn test_find_mounted_partitions_windows_nonexistent() {
    let volumes = find_mounted_partitions("PhysicalDrive999", r"\\.\PhysicalDrive999");
    assert!(
        volumes.is_empty(),
        "nonexistent physical drive should have no volumes"
    );
}
/// Dismounting a bogus volume must be reported as a Log event, not a panic.
#[test]
#[cfg(target_os = "windows")]
fn test_do_unmount_windows_bad_volume_emits_log() {
    let (tx, rx) = make_channel();
    do_unmount(r"\\.\Z99:", &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(logged, "do_unmount on bad volume must emit a Log event");
}
/// Flushing a nonexistent drive must log rather than panic.
#[test]
#[cfg(target_os = "windows")]
fn test_sync_device_windows_bad_path_no_panic() {
    let (tx, rx) = make_channel();
    sync_device(r"\\.\PhysicalDrive999", &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(logged, "sync_device must emit a Log event on Windows");
}
/// Re-reading the partition table of a nonexistent drive must log rather
/// than panic.
#[test]
#[cfg(target_os = "windows")]
fn test_reread_partition_table_windows_bad_path_no_panic() {
    let (tx, rx) = make_channel();
    reread_partition_table(r"\\.\PhysicalDrive999", &tx);
    let logged = drain(&rx)
        .iter()
        .any(|ev| matches!(ev, FlashEvent::Log(_)));
    assert!(
        logged,
        "reread_partition_table must emit a Log event on Windows"
    );
}
/// On Windows, writing to a nonexistent physical drive must fail with a message
/// that names the drive, says "Access denied", or says "Cannot open".
#[test]
#[cfg(target_os = "windows")]
fn test_open_device_for_writing_windows_access_denied_message() {
    let img = std::env::temp_dir().join("fk_win_open_img.bin");
    std::fs::write(&img, vec![1u8; 512]).unwrap();

    let (tx, _rx) = make_channel();
    let cancel = Arc::new(AtomicBool::new(false));
    let outcome = write_image(
        img.to_str().unwrap(),
        r"\\.\PhysicalDrive999",
        512,
        &tx,
        &cancel,
    );
    assert!(outcome.is_err());

    let msg = outcome.unwrap_err();
    let descriptive = ["PhysicalDrive999", "Access denied", "Cannot open"]
        .iter()
        .any(|&needle| msg.contains(needle));
    assert!(descriptive, "error must be descriptive: {msg}");

    let _ = std::fs::remove_file(&img);
}
/// Syncing starts the progress bar at 80 %.
#[test]
fn flash_stage_progress_floor_syncing() {
    let drift = (FlashStage::Syncing.progress_floor() - 0.80).abs();
    assert!(drift < f32::EPSILON);
}
/// Rereading starts the progress bar at 88 %.
#[test]
fn flash_stage_progress_floor_rereading() {
    let drift = (FlashStage::Rereading.progress_floor() - 0.88).abs();
    assert!(drift < f32::EPSILON);
}
/// Verifying starts the progress bar at 92 %.
#[test]
fn flash_stage_progress_floor_verifying() {
    let drift = (FlashStage::Verifying.progress_floor() - 0.92).abs();
    assert!(drift < f32::EPSILON);
}
/// Stages other than Syncing/Rereading/Verifying have no progress floor.
#[test]
fn flash_stage_progress_floor_other_stages_are_zero() {
    let unfloored = [
        FlashStage::Starting,
        FlashStage::Unmounting,
        FlashStage::Writing,
        FlashStage::Done,
    ];
    for stage in &unfloored {
        assert_eq!(
            stage.progress_floor(),
            0.0,
            "{stage:?} should have floor 0.0"
        );
    }
}
/// The image-hash phase maps onto the first half [0, 0.5] of overall
/// verification progress.
#[test]
fn verify_overall_image_phase_start() {
    let overall = verify_overall_progress("image", 0.0);
    assert_eq!(overall, 0.0);
}
#[test]
fn verify_overall_image_phase_end() {
    let drift = (verify_overall_progress("image", 1.0) - 0.5).abs();
    assert!(drift < f32::EPSILON);
}
#[test]
fn verify_overall_image_phase_midpoint() {
    let drift = (verify_overall_progress("image", 0.5) - 0.25).abs();
    assert!(drift < f32::EPSILON);
}
/// The device-readback phase maps onto the second half [0.5, 1.0].
#[test]
fn verify_overall_device_phase_start() {
    let drift = (verify_overall_progress("device", 0.0) - 0.5).abs();
    assert!(drift < f32::EPSILON);
}
#[test]
fn verify_overall_device_phase_end() {
    let drift = (verify_overall_progress("device", 1.0) - 1.0).abs();
    assert!(drift < f32::EPSILON);
}
#[test]
fn verify_overall_device_phase_midpoint() {
    let drift = (verify_overall_progress("device", 0.5) - 0.75).abs();
    assert!(drift < f32::EPSILON);
}
/// Any phase label other than "image" is treated like "device".
#[test]
fn verify_overall_unknown_phase_treated_as_device() {
    let drift = (verify_overall_progress("other", 0.0) - 0.5).abs();
    assert!(drift < f32::EPSILON);
}
/// An EBUSY result from the opener must become a user-facing error. The error
/// must name the device and hint that another flash may be running.
#[test]
#[cfg(target_os = "linux")]
fn check_device_not_busy_ebusy_returns_error() {
    let msg = check_device_not_busy_with("/dev/sdz", |_| {
        Err(std::io::Error::from_raw_os_error(libc::EBUSY))
    })
    .expect_err("EBUSY must be reported as an error");

    let expectations = [
        ("already in use", "mention 'already in use'"),
        ("/dev/sdz", "include the device path"),
        ("another flash operation", "hint at another flash operation"),
    ];
    for (needle, requirement) in expectations {
        assert!(msg.contains(needle), "error must {requirement}: {msg}");
    }
}
/// Generates a Linux-only test that feeds `check_device_not_busy_with` an
/// opener always failing with `$errno`. It asserts whether the guard returns
/// `Ok` (`expect_ok: true`) or an error.
macro_rules! busy_check_test {
    ($name:ident, errno: $errno:expr, expect_ok: $ok:expr) => {
        #[test]
        #[cfg(target_os = "linux")]
        fn $name() {
            let result = check_device_not_busy_with("/dev/sdz", |_| {
                Err(std::io::Error::from_raw_os_error($errno))
            });
            let passed: bool = result.is_ok();
            assert_eq!(
                passed,
                $ok,
                "errno {} — expected is_ok()={}, got {:?}",
                $errno,
                $ok,
                result
            );
        }
    };
}
// Permission errors are not evidence that the device is busy, so the guard
// must not block the flash on them.
busy_check_test!(check_device_not_busy_eperm_is_ignored, errno: libc::EPERM, expect_ok: true);
busy_check_test!(check_device_not_busy_eacces_is_ignored, errno: libc::EACCES, expect_ok: true);
/// An opener that succeeds means the device is free, so the guard passes.
#[test]
#[cfg(target_os = "linux")]
fn check_device_not_busy_success_returns_ok() {
    let outcome = check_device_not_busy_with("/dev/sdz", |_| Ok(()));
    assert!(outcome.is_ok(), "successful open must return Ok");
}
/// The real busy check, run against an ordinary temp file, must never report
/// EBUSY.
#[test]
#[cfg(target_os = "linux")]
fn check_device_not_busy_regular_file_never_ebusy() {
    let scratch = tempfile::NamedTempFile::new().expect("tempfile");
    let path = scratch.path().to_str().unwrap();
    let result = check_device_not_busy(path);
    assert!(
        result.is_ok(),
        "regular file must never trigger the EBUSY guard: {result:?}"
    );
}
/// Generates a platform-gated test for a pipeline run against plain temp files.
/// The run must:
/// - emit `Unmounting` before `Writing`;
/// - never report a spurious "already in use" error.
///
/// Permission problems (e.g. a non-root run) print a notice and return early
/// instead of failing.
macro_rules! pipeline_stage_order_test {
    ($name:ident, $os:meta) => {
        #[$os]
        #[test]
        fn $name() {
            let dir = tempfile::tempdir().expect("tempdir");
            let img = dir.path().join("img.bin");
            let dev = dir.path().join("dev.bin");
            let data: Vec<u8> = (0u8..=255).cycle().take(256 * 1024).collect();

            match std::fs::write(&img, &data) {
                Ok(()) => {}
                Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
                    eprintln!("Skipping: cannot write temp files (PermissionDenied): {e}");
                    return;
                }
                Err(e) => panic!("unexpected error writing temp image: {e}"),
            }
            match std::fs::File::create(&dev) {
                Ok(_) => {}
                Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
                    eprintln!("Skipping: cannot create temp files (PermissionDenied): {e}");
                    return;
                }
                Err(e) => panic!("unexpected error creating temp device: {e}"),
            }

            let (tx, rx) = make_channel();
            let cancel = Arc::new(AtomicBool::new(false));
            run_pipeline(img.to_str().unwrap(), dev.to_str().unwrap(), tx, cancel);
            let events = drain(&rx);

            if let Some(msg) = find_error(&events) {
                let permission_problem =
                    msg.contains("Permission denied") || msg.contains("permission denied");
                if permission_problem {
                    eprintln!("Skipping: pipeline returned permission error (non-root): {msg}");
                    return;
                }
                assert!(
                    !msg.contains("already in use"),
                    "must not emit a false-positive busy error: {msg}"
                );
            }

            let stages: Vec<&FlashStage> = events
                .iter()
                .filter_map(|e| match e {
                    FlashEvent::Stage(s) => Some(s),
                    _ => None,
                })
                .collect();
            let index_of = |wanted: &FlashStage| stages.iter().position(|s| *s == wanted);
            let pos_unmounting = index_of(&FlashStage::Unmounting);
            let pos_writing = index_of(&FlashStage::Writing);

            assert!(
                pos_unmounting.is_some(),
                "pipeline must emit Unmounting stage"
            );
            assert!(pos_writing.is_some(), "pipeline must emit Writing stage");
            // Both are `Some` here, and `Some(a) < Some(b)` exactly when `a < b`.
            assert!(
                pos_unmounting < pos_writing,
                "Unmounting must precede Writing"
            );
        }
    };
}
// One instantiation per supported OS; each is compiled only on its platform.
pipeline_stage_order_test!(
    pipeline_unmounting_precedes_busy_check_in_stage_stream,
    cfg(target_os = "linux")
);
pipeline_stage_order_test!(
    pipeline_unmounting_precedes_writing_macos,
    cfg(target_os = "macos")
);
pipeline_stage_order_test!(
    pipeline_unmounting_precedes_writing_windows,
    cfg(target_os = "windows")
);
/// On Windows, a pipeline aimed at a device path that does not exist must fail.
/// It must not misreport the failure as the device being "already in use".
#[test]
#[cfg(target_os = "windows")]
fn open_device_for_writing_sharing_violation_message() {
    let dir = tempfile::tempdir().expect("tempdir");
    let img = dir.path().join("img.bin");
    let nonexistent_dev = dir.path().join("no_such_device");
    std::fs::write(&img, vec![0u8; 512]).unwrap();

    let (tx, rx) = make_channel();
    run_pipeline(
        img.to_str().unwrap(),
        nonexistent_dev.to_str().unwrap(),
        tx,
        Arc::new(AtomicBool::new(false)),
    );
    let events = drain(&rx);

    let failed = events.iter().any(|ev| matches!(ev, FlashEvent::Error(_)));
    assert!(failed, "pipeline must fail for a non-existent device");

    let Some(msg) = find_error(&events) else {
        return;
    };
    assert!(
        !msg.contains("already in use"),
        "non-existent device must not emit a spurious 'already in use' message: {msg}"
    );
}
}