use super::{ScsiRequest, ScsiResponse};
use crate::StatusExt;
use crate::mem::{AlignedBuffer, PoolAllocation};
use crate::proto::device_path::PoolDevicePathNode;
use crate::proto::unsafe_protocol;
use core::alloc::LayoutError;
use core::cell::UnsafeCell;
use core::ptr::{self, NonNull};
use uefi_raw::Status;
use uefi_raw::protocol::device_path::DevicePathProtocol;
use uefi_raw::protocol::scsi::{
ExtScsiPassThruMode, ExtScsiPassThruProtocol, SCSI_TARGET_MAX_BYTES,
};
/// Raw SCSI target identifier bytes, exchanged with the firmware by pointer.
pub type ScsiTarget = [u8; SCSI_TARGET_MAX_BYTES];

/// A SCSI target identifier paired with a logical unit number (LUN).
#[derive(Debug, Clone)]
pub struct ScsiTargetLun(ScsiTarget, u64);

impl Default for ScsiTargetLun {
    /// Builds a target whose every byte is `0xFF`, with LUN `0`.
    ///
    /// `ExtScsiPassThru::iter_devices` uses this value as the starting point of device
    /// enumeration (the UEFI spec treats an all-`0xFF` target as "return the first
    /// device" — confirm against the spec if relying on this).
    fn default() -> Self {
        let wildcard: ScsiTarget = [u8::MAX; SCSI_TARGET_MAX_BYTES];
        Self(wildcard, 0)
    }
}
/// Wrapper around the raw `ExtScsiPassThruProtocol`.
///
/// The raw protocol is held in an `UnsafeCell` so that methods taking `&self` can
/// still hand a `*mut ExtScsiPassThruProtocol` to the firmware's function pointers.
/// `#[repr(transparent)]` gives this type the same layout as the wrapped protocol
/// struct (presumably so a firmware-provided protocol pointer can be viewed as this
/// type — confirm with the `unsafe_protocol` machinery).
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(ExtScsiPassThruProtocol::GUID)]
pub struct ExtScsiPassThru(UnsafeCell<ExtScsiPassThruProtocol>);
impl ExtScsiPassThru {
    /// Returns a copy of the protocol's mode data.
    ///
    /// The `io_align` field of the copy is normalized so that a firmware-reported `0`
    /// becomes `1`; every other value is passed through. This lets callers use it
    /// directly as an alignment (the UEFI spec treats 0 and 1 alike as "no alignment
    /// requirement" — presumably the motivation; confirm).
    #[must_use]
    pub fn mode(&self) -> ExtScsiPassThruMode {
        // SAFETY: `passthru_mode` is the mode pointer stored in the firmware-provided
        // protocol instance; assumed valid while the protocol is installed.
        let mut snapshot = unsafe {
            let raw = self.0.get();
            (*(*raw).passthru_mode).clone()
        };
        if snapshot.io_align == 0 {
            snapshot.io_align = 1;
        }
        snapshot
    }

    /// Returns the I/O buffer alignment from [`Self::mode`]; never `0`.
    #[must_use]
    pub fn io_align(&self) -> u32 {
        let mode = self.mode();
        mode.io_align
    }

    /// Allocates an `AlignedBuffer` of `len` bytes aligned to [`Self::io_align`].
    ///
    /// # Errors
    /// Forwards the `LayoutError` produced by `AlignedBuffer::from_size_align` when it
    /// rejects the size/alignment pair.
    pub fn alloc_io_buffer(&self, len: usize) -> Result<AlignedBuffer, LayoutError> {
        let alignment = self.io_align() as usize;
        AlignedBuffer::from_size_align(len, alignment)
    }

    /// Returns an iterator over the target/LUN pairs reported by this controller.
    ///
    /// Enumeration starts from `ScsiTargetLun::default()`.
    #[must_use]
    pub fn iter_devices(&self) -> ScsiTargetLunIterator<'_> {
        let proto = &self.0;
        ScsiTargetLunIterator {
            proto,
            prev: ScsiTargetLun::default(),
        }
    }

    /// Resets the SCSI channel through the protocol's `reset_channel` function.
    ///
    /// # Errors
    /// The firmware's returned status is converted with `StatusExt::to_result`.
    pub fn reset_channel(&mut self) -> crate::Result<()> {
        let raw = self.0.get();
        // SAFETY: `raw` points at the protocol instance owned by `self`, and it is
        // passed as the `this` argument to its own function pointer.
        let status = unsafe { ((*raw).reset_channel)(raw) };
        status.to_result()
    }
}
/// One SCSI device (a target plus a LUN) reachable through an `ExtScsiPassThru`
/// controller.
///
/// Only a shared reference to the controller's raw protocol is stored, so several
/// `ScsiDevice` values (e.g. clones, or items yielded by `ScsiTargetLunIterator`)
/// can refer to the same controller at once.
#[derive(Clone, Debug)]
pub struct ScsiDevice<'a> {
    // Raw protocol instance of the controller this device belongs to.
    proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
    // Target identifier and LUN that address this device on the controller.
    target_lun: ScsiTargetLun,
}
impl ScsiDevice<'_> {
#[must_use]
pub const fn target(&self) -> &ScsiTarget {
&self.target_lun.0
}
#[must_use]
pub const fn lun(&self) -> u64 {
self.target_lun.1
}
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
unsafe {
let mut path_ptr: *const DevicePathProtocol = ptr::null();
((*self.proto.get()).build_device_path)(
self.proto.get(),
self.target().as_ptr(),
self.lun(),
&mut path_ptr,
)
.to_result()?;
NonNull::new(path_ptr.cast_mut())
.map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast())))
.ok_or_else(|| Status::OUT_OF_RESOURCES.into())
}
}
pub fn reset(&mut self) -> crate::Result<()> {
unsafe {
((*self.proto.get()).reset_target_lun)(
self.proto.get(),
self.target_lun.0.as_ptr(),
self.lun(),
)
.to_result()
}
}
pub fn execute_command<'req>(
&mut self,
mut scsi_req: ScsiRequest<'req>,
) -> crate::Result<ScsiResponse<'req>> {
unsafe {
((*self.proto.get()).pass_thru)(
self.proto.get(),
self.target_lun.0.as_ptr(),
self.target_lun.1,
&mut scsi_req.packet,
ptr::null_mut(),
)
.to_result_with_val(|| ScsiResponse(scsi_req))
}
}
}
/// Iterator over the devices of an `ExtScsiPassThru` controller, returned by
/// `ExtScsiPassThru::iter_devices`.
///
/// Each step passes `prev` to the firmware's `get_next_target_lun` to obtain the
/// following target/LUN pair.
#[derive(Debug)]
pub struct ScsiTargetLunIterator<'a> {
    // Raw protocol instance of the controller being enumerated.
    proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
    // Last target/LUN handed to the firmware; starts as `ScsiTargetLun::default()`.
    prev: ScsiTargetLun,
}
impl<'a> Iterator for ScsiTargetLunIterator<'a> {
    type Item = ScsiDevice<'a>;

    /// Advances to the next target/LUN reported by the firmware.
    ///
    /// The previous target/LUN (initially all-`0xFF` target, LUN `0`) is passed to the
    /// protocol's `get_next_target_lun`. That call may update the LUN in place and may
    /// either fill our target buffer or redirect the `target` pointer to a buffer of
    /// its own. In the latter case the bytes are copied back into `self.prev`.
    ///
    /// Returns `None` once the firmware reports `NOT_FOUND`.
    ///
    /// # Panics
    /// Panics if the firmware returns any status other than `SUCCESS` or `NOT_FOUND`.
    /// The offending status is included in the message.
    fn next(&mut self) -> Option<Self::Item> {
        let mut target: *mut u8 = self.prev.0.as_mut_ptr();
        // SAFETY: `target` points at the `SCSI_TARGET_MAX_BYTES`-byte buffer inside
        // `self.prev`, `self.prev.1` is a writable `u64`, and the protocol pointer comes
        // from the controller this iterator borrows.
        let result = unsafe {
            ((*self.proto.get()).get_next_target_lun)(
                self.proto.get(),
                &mut target,
                &mut self.prev.1,
            )
        };
        match result {
            Status::SUCCESS => {}
            Status::NOT_FOUND => return None,
            other => panic!(
                "Must not happen according to spec! (get_next_target_lun returned {other:?})"
            ),
        }
        // Only read through the firmware-written pointer after a successful call. On
        // failure paths its contents are not guaranteed meaningful, so copying from it
        // there could read arbitrary memory.
        if target != self.prev.0.as_mut_ptr() {
            // SAFETY: on SUCCESS the firmware-provided `target` is assumed to point at
            // `SCSI_TARGET_MAX_BYTES` readable bytes — TODO confirm against the UEFI spec.
            // `copy_to` tolerates overlap, matching the original behavior.
            unsafe {
                target.copy_to(self.prev.0.as_mut_ptr(), SCSI_TARGET_MAX_BYTES);
            }
        }
        Some(ScsiDevice {
            proto: self.proto,
            target_lun: self.prev.clone(),
        })
    }
}