use super::{NvmeRequest, NvmeResponse};
use crate::StatusExt;
use crate::mem::{AlignedBuffer, PoolAllocation};
use crate::proto::device_path::PoolDevicePathNode;
use core::alloc::LayoutError;
use core::cell::UnsafeCell;
use core::ptr::{self, NonNull};
use uefi_macros::unsafe_protocol;
use uefi_raw::Status;
use uefi_raw::protocol::device_path::DevicePathProtocol;
use uefi_raw::protocol::nvme::{NvmExpressCompletion, NvmExpressPassThruProtocol};
/// Mode information reported by the NVMe pass-through protocol.
///
/// This is an alias for the raw `uefi_raw` definition; it is returned by value
/// (as a copy) from `NvmePassThru::mode`.
pub type NvmePassThruMode = uefi_raw::protocol::nvme::NvmExpressPassThruMode;
/// Identifier of an NVMe namespace as passed to the pass-through protocol.
///
/// In this module, `0` is used to address the controller itself and
/// `0xFFFFFFFF` is used both for broadcast and as the starting value for
/// namespace iteration.
pub type NvmeNamespaceId = u32;
/// Wrapper around the UEFI NVM Express pass-through protocol
/// ([`NvmExpressPassThruProtocol`]).
///
/// `#[repr(transparent)]` keeps the layout identical to the raw protocol
/// struct. The `UnsafeCell` lets `&self` methods hand a `*mut` pointer to the
/// protocol back to the firmware functions it contains.
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(NvmExpressPassThruProtocol::GUID)]
pub struct NvmePassThru(UnsafeCell<NvmExpressPassThruProtocol>);
impl NvmePassThru {
#[must_use]
pub fn mode(&self) -> NvmePassThruMode {
let mut mode = unsafe { (*(*self.0.get()).mode).clone() };
mode.io_align = mode.io_align.max(1); mode
}
#[must_use]
pub fn io_align(&self) -> u32 {
self.mode().io_align
}
pub fn alloc_io_buffer(&self, len: usize) -> Result<AlignedBuffer, LayoutError> {
AlignedBuffer::from_size_align(len, self.io_align() as usize)
}
#[must_use]
pub const fn iter_namespaces(&self) -> NvmeNamespaceIterator<'_> {
NvmeNamespaceIterator {
proto: &self.0,
prev: 0xFFFFFFFF,
}
}
#[must_use]
pub const fn controller(&self) -> NvmeNamespace<'_> {
NvmeNamespace {
proto: &self.0,
namespace_id: 0,
}
}
#[must_use]
pub const fn broadcast(&self) -> NvmeNamespace<'_> {
NvmeNamespace {
proto: &self.0,
namespace_id: 0xffffffff,
}
}
}
/// Handle for sending commands through an [`NvmePassThru`] protocol to one
/// namespace ID.
///
/// Obtained from `NvmePassThru::controller`, `NvmePassThru::broadcast`, or by
/// iterating `NvmePassThru::iter_namespaces`. It borrows the protocol for `'a`.
#[derive(Debug)]
pub struct NvmeNamespace<'a> {
/// Protocol instance used for all firmware calls made through this handle.
proto: &'a UnsafeCell<NvmExpressPassThruProtocol>,
/// Namespace ID passed to the firmware. `NvmePassThru` uses `0` for the
/// controller and `0xFFFFFFFF` for broadcast.
namespace_id: NvmeNamespaceId,
}
impl NvmeNamespace<'_> {
    /// Returns the namespace ID this handle sends commands to.
    #[must_use]
    pub const fn namespace_id(&self) -> NvmeNamespaceId {
        self.namespace_id
    }

    /// Asks the firmware to build a device path node describing this namespace.
    ///
    /// The returned node owns the firmware-allocated memory through a
    /// [`PoolAllocation`].
    ///
    /// # Errors
    ///
    /// - Any non-success status from the firmware's `build_device_path` is
    ///   returned as an error.
    /// - `Status::OUT_OF_RESOURCES` is returned if the firmware reports success
    ///   but leaves the output pointer null.
    pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
        let proto = self.proto.get();
        let mut node_ptr: *const DevicePathProtocol = ptr::null();

        // SAFETY: `proto` points at the live protocol instance this handle
        // borrows, and `node_ptr` is a valid out-parameter for the call.
        let status =
            unsafe { ((*proto).build_device_path)(proto, self.namespace_id, &mut node_ptr) };
        status.to_result()?;

        match NonNull::new(node_ptr.cast_mut()) {
            // SAFETY: the firmware returned this pointer as a freshly built
            // device path, which the pool allocation wrapper takes ownership of.
            Some(node) => Ok(PoolDevicePathNode(unsafe { PoolAllocation::new(node.cast()) })),
            None => Err(Status::OUT_OF_RESOURCES.into()),
        }
    }

    /// Sends `req` to this handle's namespace through the protocol's
    /// `pass_thru` function.
    ///
    /// Before dispatch, `req.cmd.nsid` is overwritten with this handle's
    /// namespace ID. The packet is then pointed at the command and at a local
    /// completion record. A null event is passed, which the UEFI specification
    /// defines as a blocking call.
    ///
    /// On success, `req` and the filled-in completion record are returned
    /// together as an [`NvmeResponse`].
    ///
    /// # Errors
    ///
    /// Returns the firmware's status as an error when `pass_thru` does not
    /// report success.
    pub fn execute_command<'req>(
        &mut self,
        mut req: NvmeRequest<'req>,
    ) -> crate::Result<NvmeResponse<'req>> {
        let nsid = self.namespace_id;
        let proto = self.proto.get();
        let mut completion = NvmExpressCompletion::default();

        // Wire the command and the completion slot into the packet. NOTE(review):
        // once `req` and `completion` are moved into the response below, these
        // two packet pointers refer to the old locations. They are stale and
        // must not be dereferenced afterwards.
        req.cmd.nsid = nsid;
        req.packet.nvme_cmd = &req.cmd;
        req.packet.nvme_completion = &mut completion;

        // SAFETY: `proto` points at the live protocol instance, and the packet's
        // command/completion pointers stay valid for the duration of the
        // (blocking) call.
        let status =
            unsafe { ((*proto).pass_thru)(proto, nsid, &mut req.packet, ptr::null_mut()) };
        status.to_result_with_val(|| NvmeResponse { req, completion })
    }
}
/// Iterator over the namespaces reported by the controller.
///
/// Created by `NvmePassThru::iter_namespaces`; each step asks the firmware's
/// `get_next_namespace` for the namespace that follows `prev`.
#[derive(Debug)]
pub struct NvmeNamespaceIterator<'a> {
/// Protocol instance being queried.
proto: &'a UnsafeCell<NvmExpressPassThruProtocol>,
/// Namespace ID most recently handed to/updated by `get_next_namespace`;
/// starts at `0xFFFFFFFF` and is overwritten in place by the firmware.
prev: NvmeNamespaceId,
}
impl<'a> Iterator for NvmeNamespaceIterator<'a> {
    type Item = NvmeNamespace<'a>;

    /// Advances to the next namespace reported by the firmware.
    ///
    /// `get_next_namespace` overwrites `self.prev` in place with the next ID.
    /// `SUCCESS` yields a handle for that ID, and `NOT_FOUND` ends iteration.
    ///
    /// # Panics
    ///
    /// Panics if the firmware returns any other status. Per the UEFI
    /// specification this cannot happen, because `prev` is always either the
    /// `0xFFFFFFFF` start value or an ID the firmware itself returned.
    fn next(&mut self) -> Option<Self::Item> {
        let proto = self.proto.get();
        // SAFETY: `proto` points at the live protocol instance this iterator
        // borrows, and `self.prev` is a valid in/out parameter.
        let status = unsafe { ((*proto).get_next_namespace)(proto, &mut self.prev) };

        if status == Status::SUCCESS {
            Some(NvmeNamespace {
                proto: self.proto,
                namespace_id: self.prev,
            })
        } else if status == Status::NOT_FOUND {
            None
        } else {
            panic!("Must not happen according to spec!")
        }
    }
}