#![cfg(target_os = "windows")]
#![allow(dead_code)]
use crate::{Error, Result};
use std::os::windows::ffi::OsStrExt;
use std::path::{Path, PathBuf};
use windows_sys::Win32::Foundation::{
CloseHandle, GENERIC_READ, GENERIC_WRITE, INVALID_HANDLE_VALUE,
};
use windows_sys::Win32::Storage::FileSystem::{
CreateFileW, FILE_ATTRIBUTE_NORMAL, FILE_SHARE_READ, FILE_SHARE_WRITE, OPEN_EXISTING,
};
use windows_sys::Win32::System::Ioctl::{
ProtocolTypeNvme, IOCTL_STORAGE_PROTOCOL_COMMAND, STORAGE_PROTOCOL_COMMAND,
STORAGE_PROTOCOL_STRUCTURE_VERSION,
};
use windows_sys::Win32::System::IO::DeviceIoControl;
/// Result of a successful NVMe pass-through capability probe for one volume.
///
/// Only the volume's device path is stored; no OS handle is kept open between
/// calls, so each flush reopens the volume from this path.
#[derive(Debug, Clone)]
pub(crate) struct NvmeAccess {
    /// Win32 device path of the volume, in `\\.\X:` form.
    pub(crate) volume_root: PathBuf,
}
/// Decides whether NVMe flush pass-through can be used for the volume holding `path`.
///
/// Returns `None` when:
/// - the `FSYS_DISABLE_NVME_PASSTHROUGH` environment variable is set (to any value);
/// - `path` cannot be mapped to a drive-letter volume;
/// - that volume cannot be opened for read/write;
/// - an Identify Controller probe sent through the volume handle fails.
///
/// The probe handle is always closed before returning.
pub(crate) fn nvme_flush_capable(path: &Path) -> Option<NvmeAccess> {
    let disabled = std::env::var_os("FSYS_DISABLE_NVME_PASSTHROUGH").is_some();
    if disabled {
        return None;
    }

    let volume_root = volume_root_for(path)?;
    let probe_handle = open_volume(&volume_root)?;
    let probe = issue_identify_controller(probe_handle);
    // SAFETY: `probe_handle` came from a successful `open_volume` and is closed exactly
    // once here; a failed close is deliberately ignored (best-effort cleanup).
    let _ = unsafe { CloseHandle(probe_handle) };

    probe.is_ok().then_some(NvmeAccess { volume_root })
}
pub(crate) fn nvme_flush(access: &NvmeAccess) -> Result<()> {
let handle = open_volume(&access.volume_root).ok_or_else(|| {
Error::Io(std::io::Error::other(
"failed to reopen volume for NVMe flush",
))
})?;
let result = issue_flush_command(handle);
let _ = unsafe { CloseHandle(handle) };
result
}
/// Raw Win32 handle as returned by `open_volume`; the code that opens one is
/// responsible for releasing it with `CloseHandle`.
type WinHandle = windows_sys::Win32::Foundation::HANDLE;
/// Maps `path` to the Win32 device path of the drive-letter volume it lives on,
/// e.g. `C:\data\file.bin` becomes `\\.\C:`.
///
/// The path is canonicalized first, so it must exist. Returns `None` when:
/// - canonicalization fails;
/// - the canonical path has no drive-letter prefix (UNC shares, `\\?\Volume{GUID}\`
///   mounts, device paths, and every path on non-Windows hosts, where paths carry
///   no prefix).
fn volume_root_for(path: &Path) -> Option<PathBuf> {
    let canonical = std::fs::canonicalize(path).ok()?;
    // Inspect the parsed prefix rather than the string form so paths whose later
    // components are not valid UTF-8 are still mapped to their drive.
    let letter = match canonical.components().next()? {
        std::path::Component::Prefix(prefix) => match prefix.kind() {
            std::path::Prefix::VerbatimDisk(letter) | std::path::Prefix::Disk(letter) => letter,
            _ => return None,
        },
        _ => return None,
    };
    // Drive letters are ASCII letters; refuse anything else instead of
    // building a bogus device path.
    letter
        .is_ascii_alphabetic()
        .then(|| PathBuf::from(format!(r"\\.\{}:", char::from(letter))))
}
/// Opens `volume_root` (a `\\.\X:` device path) with read/write access and
/// read/write sharing, returning `None` if `CreateFileW` fails.
///
/// The caller owns the returned handle and must release it with `CloseHandle`.
fn open_volume(volume_root: &Path) -> Option<WinHandle> {
    // CreateFileW expects a NUL-terminated UTF-16 string.
    let mut path_utf16: Vec<u16> = volume_root.as_os_str().encode_wide().collect();
    path_utf16.push(0);

    let desired_access = GENERIC_READ | GENERIC_WRITE;
    let share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE;
    // SAFETY: `path_utf16` is NUL-terminated and outlives the call; the security
    // attributes and template-file arguments are optional and passed as null.
    let raw = unsafe {
        CreateFileW(
            path_utf16.as_ptr(),
            desired_access,
            share_mode,
            std::ptr::null(),
            OPEN_EXISTING,
            FILE_ATTRIBUTE_NORMAL,
            std::ptr::null_mut(),
        )
    };

    (raw != INVALID_HANDLE_VALUE).then_some(raw)
}
/// Size in bytes of an NVMe submission-queue entry
/// (`STORAGE_PROTOCOL_COMMAND_LENGTH_NVME` in `ntddstor.h`).
const NVME_SQE_LEN: usize = 64;

/// `STORAGE_PROTOCOL_COMMAND_FLAG_ADAPTER_REQUEST` (`ntddstor.h`): the command is
/// addressed to the controller (admin queue) rather than to a namespace.
const PROTOCOL_FLAG_ADAPTER_REQUEST: u32 = 0x8000_0000;

/// `STORAGE_PROTOCOL_SPECIFIC_NVME_ADMIN_COMMAND` (`ntddstor.h`).
const PROTOCOL_SPECIFIC_NVME_ADMIN: u32 = 0x01;

/// `STORAGE_PROTOCOL_SPECIFIC_NVME_NVM_COMMAND` (`ntddstor.h`).
const PROTOCOL_SPECIFIC_NVME_NVM: u32 = 0x02;

/// Builds an `IOCTL_STORAGE_PROTOCOL_COMMAND` buffer carrying one NVMe command.
///
/// The layout follows Microsoft's NVMe pass-through sample:
/// 1. the `STORAGE_PROTOCOL_COMMAND` header;
/// 2. the 64-byte SQE, starting at the header's `Command` field
///    (C's `FIELD_OFFSET(STORAGE_PROTOCOL_COMMAND, Command)`);
/// 3. `data_from_device_len` bytes for device-to-host data.
///
/// No error-info or host-to-device data areas are reserved.
fn build_nvme_protocol_command(
    command: &[u8; NVME_SQE_LEN],
    flags: u32,
    command_specific: u32,
    data_from_device_len: u32,
) -> Vec<u8> {
    let header_len = std::mem::size_of::<STORAGE_PROTOCOL_COMMAND>();
    // The driver reads the SQE from the `Command` field itself. That field lies before
    // `size_of::<STORAGE_PROTOCOL_COMMAND>()` because the struct ends with
    // `Command: [u8; 1]` plus padding. Placing the SQE at `size_of` would shift it.
    let command_offset = std::mem::offset_of!(STORAGE_PROTOCOL_COMMAND, Command);
    let data_offset = command_offset + NVME_SQE_LEN;
    let total_len = (data_offset + data_from_device_len as usize).max(header_len);
    let mut buf = vec![0u8; total_len];

    // SAFETY: every field of STORAGE_PROTOCOL_COMMAND is an integer or an integer
    // array, so the all-zero bit pattern is a valid value.
    let mut header: STORAGE_PROTOCOL_COMMAND = unsafe { std::mem::zeroed() };
    header.Version = STORAGE_PROTOCOL_STRUCTURE_VERSION;
    header.Length = header_len as u32;
    header.ProtocolType = ProtocolTypeNvme;
    header.Flags = flags;
    header.CommandLength = NVME_SQE_LEN as u32;
    header.DataFromDeviceTransferLength = data_from_device_len;
    header.DataFromDeviceBufferOffset = if data_from_device_len == 0 {
        0
    } else {
        data_offset as u32
    };
    header.TimeOutValue = 30; // seconds
    header.CommandSpecific = command_specific;

    // SAFETY: `buf` is at least `header_len` bytes long. A `Vec<u8>` is only guaranteed
    // 1-byte alignment, so the header is stored with `write_unaligned`. Assigning fields
    // through a `*mut STORAGE_PROTOCOL_COMMAND` would be a misaligned write.
    unsafe {
        std::ptr::write_unaligned(buf.as_mut_ptr().cast::<STORAGE_PROTOCOL_COMMAND>(), header);
    }
    // Copy the SQE after the header. It overlaps the header's `Command` field and
    // trailing padding, and this overwrites those bytes with defined values.
    buf[command_offset..data_offset].copy_from_slice(command);
    buf
}

/// Probes pass-through support by sending an NVMe Identify Controller admin command.
///
/// The 4 KiB identify data is received into a scratch buffer and discarded; only
/// success or failure is reported.
///
/// NOTE(review): Microsoft documents `IOCTL_STORAGE_PROTOCOL_COMMAND` on the inbox
/// `stornvme` driver as intended for vendor-specific commands. Confirm that standard
/// opcodes such as Identify are accepted by the drivers this must support.
fn issue_identify_controller(handle: WinHandle) -> Result<()> {
    const OPC_IDENTIFY: u8 = 0x06;
    const CNS_IDENTIFY_CONTROLLER: u32 = 0x01;
    const IDENTIFY_DATA_LEN: u32 = 4096;

    let mut command = [0u8; NVME_SQE_LEN];
    command[0] = OPC_IDENTIFY; // CDW0.OPC
    // CDW10 starts at byte 40 of the SQE; CNS is its low byte. NVMe is little-endian.
    command[40..44].copy_from_slice(&CNS_IDENTIFY_CONTROLLER.to_le_bytes());

    let mut buf = build_nvme_protocol_command(
        &command,
        PROTOCOL_FLAG_ADAPTER_REQUEST,
        PROTOCOL_SPECIFIC_NVME_ADMIN,
        IDENTIFY_DATA_LEN,
    );
    issue_protocol_command(handle, &mut buf)
}

/// Issues an NVMe Flush (NVM command set, opcode `0x00`) through the volume handle.
///
/// NOTE(review): NSID (SQE bytes 4..8) is left 0, as in the previous implementation.
/// The NVMe spec defines Flush per namespace, with `0xFFFF_FFFF` meaning all namespaces
/// when the controller supports it. Confirm whether the driver fills NSID in or whether
/// it must be set here.
fn issue_flush_command(handle: WinHandle) -> Result<()> {
    const OPC_FLUSH: u8 = 0x00;

    let mut command = [0u8; NVME_SQE_LEN];
    command[0] = OPC_FLUSH; // CDW0.OPC

    // NVM (I/O) commands target the namespace, so no adapter-request flag is set.
    let mut buf = build_nvme_protocol_command(&command, 0, PROTOCOL_SPECIFIC_NVME_NVM, 0);
    issue_protocol_command(handle, &mut buf)
}
fn issue_protocol_command(handle: WinHandle, buf: &mut [u8]) -> Result<()> {
let mut bytes_returned: u32 = 0;
let len = buf.len() as u32;
let ok = unsafe {
DeviceIoControl(
handle,
IOCTL_STORAGE_PROTOCOL_COMMAND,
buf.as_mut_ptr().cast(),
len,
buf.as_mut_ptr().cast(),
len,
&mut bytes_returned,
std::ptr::null_mut(),
)
};
if ok != 0 {
Ok(())
} else {
Err(Error::Io(std::io::Error::last_os_error()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opt-out variable read by `nvme_flush_capable`.
    const DISABLE_VAR: &str = "FSYS_DISABLE_NVME_PASSTHROUGH";

    #[test]
    fn volume_root_for_local_path_returns_drive_form() {
        // The temp dir may sit on a non-drive volume (e.g. a UNC share), in which case
        // `None` is acceptable; the shape is only checked when a root is produced.
        let p = std::env::temp_dir();
        if let Some(root) = volume_root_for(&p) {
            let s = root.to_string_lossy();
            assert!(
                s.starts_with(r"\\.\") && s.ends_with(':'),
                "expected device-namespace volume root, got {s}"
            );
        }
    }

    #[test]
    fn capability_probe_returns_some_or_none_without_panic() {
        // The outcome depends on hardware, drivers and privileges; only the absence
        // of a panic is asserted.
        let p = std::env::temp_dir();
        let _ = nvme_flush_capable(&p);
    }

    #[test]
    fn env_override_forces_none() {
        let prior = std::env::var_os(DISABLE_VAR);
        // SAFETY: this file is compiled only for Windows (crate-level `cfg`), where
        // std documents `set_var`/`remove_var` as safe even in multi-threaded programs.
        unsafe {
            std::env::set_var(DISABLE_VAR, "1");
        }
        let result = nvme_flush_capable(&std::env::temp_dir());
        // Restore before asserting so a failure cannot leak the override into other
        // tests running in the same process.
        // SAFETY: as above.
        unsafe {
            match prior {
                Some(v) => std::env::set_var(DISABLE_VAR, v),
                None => std::env::remove_var(DISABLE_VAR),
            }
        }
        assert!(result.is_none(), "env override must force None");
    }
}