use super::{AtaRequest, AtaResponse};
use crate::StatusExt;
use crate::mem::{AlignedBuffer, PoolAllocation};
use crate::proto::device_path::PoolDevicePathNode;
use core::alloc::LayoutError;
use core::cell::UnsafeCell;
use core::ptr::{self, NonNull};
use uefi_macros::unsafe_protocol;
use uefi_raw::Status;
use uefi_raw::protocol::ata::AtaPassThruProtocol;
use uefi_raw::protocol::device_path::DevicePathProtocol;
/// Mode data of an ATA Pass Thru protocol instance.
///
/// This is a type alias for the raw `uefi_raw` structure; values are obtained
/// through [`AtaPassThru::mode`].
pub type AtaPassThruMode = uefi_raw::protocol::ata::AtaPassThruMode;
/// Wrapper around the firmware's ATA Pass Thru protocol.
///
/// `#[repr(transparent)]` gives this type the same layout as the raw
/// [`AtaPassThruProtocol`] struct. The struct is held in an [`UnsafeCell`]
/// because the protocol's function pointers are invoked with a `*mut` pointer to
/// it. The cell is also shared by reference with the device iterator and the
/// devices it yields.
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(AtaPassThruProtocol::GUID)]
pub struct AtaPassThru(UnsafeCell<AtaPassThruProtocol>);
impl AtaPassThru {
    /// Returns a copy of this protocol instance's mode data.
    ///
    /// The returned `io_align` is never `0`. A reported alignment of `0` is
    /// presented as `1`, so the value can be handed straight to
    /// [`Self::alloc_io_buffer`] as a buffer alignment.
    #[must_use]
    pub fn mode(&self) -> AtaPassThruMode {
        let raw = self.0.get();
        // SAFETY: assumes `raw` points to a live protocol struct whose `mode`
        // pointer was set up by firmware and stays valid while the protocol is
        // open. NOTE(review): this invariant is not checked here.
        let mut snapshot = unsafe { (*(*raw).mode).clone() };
        if snapshot.io_align == 0 {
            snapshot.io_align = 1;
        }
        snapshot
    }

    /// Returns the I/O buffer alignment required by this protocol instance.
    ///
    /// The value is always at least `1`; see [`Self::mode`].
    #[must_use]
    pub fn io_align(&self) -> u32 {
        let mode = self.mode();
        mode.io_align
    }

    /// Allocates a `len`-byte buffer aligned to [`Self::io_align`].
    ///
    /// # Errors
    ///
    /// Returns the [`LayoutError`] produced by `AlignedBuffer::from_size_align`
    /// if it rejects the size/alignment pair.
    pub fn alloc_io_buffer(&self, len: usize) -> Result<AlignedBuffer, LayoutError> {
        let alignment = self.io_align() as usize;
        AlignedBuffer::from_size_align(len, alignment)
    }

    /// Returns an iterator over the ATA devices reachable through this protocol.
    #[must_use]
    pub const fn iter_devices(&self) -> AtaDeviceIterator<'_> {
        // Per the UEFI spec, passing 0xFFFF to `GetNextPort` / `GetNextDevice`
        // requests the first port / the first device on a port.
        const FIRST: u16 = 0xFFFF;
        AtaDeviceIterator {
            proto: &self.0,
            prev_port: FIRST,
            prev_pmp: FIRST,
            end_of_port: true,
        }
    }
}
/// An ATA device addressed through an [`AtaPassThru`] protocol instance.
///
/// A device is identified by its port and port-multiplier port. In this file,
/// devices are produced by the iterator returned from
/// [`AtaPassThru::iter_devices`], and they borrow the protocol for `'a`.
#[derive(Debug)]
pub struct AtaDevice<'a> {
    // Shared protocol cell; its function pointers are called through `get()`.
    proto: &'a UnsafeCell<AtaPassThruProtocol>,
    // Port number on the ATA controller.
    port: u16,
    // Port-multiplier port. The iterator uses 0xFFFF when the device is
    // addressed without a port-multiplier port.
    pmp: u16,
}
impl AtaDevice<'_> {
#[must_use]
pub const fn port(&self) -> u16 {
self.port
}
#[must_use]
pub const fn port_multiplier_port(&self) -> u16 {
self.pmp
}
pub fn reset(&mut self) -> crate::Result<()> {
unsafe {
((*self.proto.get()).reset_device)(self.proto.get(), self.port, self.pmp).to_result()
}
}
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
unsafe {
let mut path_ptr: *const DevicePathProtocol = ptr::null();
((*self.proto.get()).build_device_path)(
self.proto.get(),
self.port,
self.pmp,
&mut path_ptr,
)
.to_result()?;
NonNull::new(path_ptr.cast_mut())
.map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast())))
.ok_or_else(|| Status::OUT_OF_RESOURCES.into())
}
}
#[allow(clippy::result_large_err)]
pub fn execute_command<'req>(
&mut self,
mut req: AtaRequest<'req>,
) -> crate::Result<AtaResponse<'req>, AtaResponse<'req>> {
req.packet.acb = &req.acb;
let result = unsafe {
((*self.proto.get()).pass_thru)(
self.proto.get(),
self.port,
self.pmp,
&mut req.packet,
ptr::null_mut(),
)
.to_result()
};
match result {
Ok(_) => Ok(AtaResponse { req }),
Err(s) => Err(crate::Error::new(s.status(), AtaResponse { req })),
}
}
}
/// Iterator over the ATA devices of an [`AtaPassThru`] protocol instance.
///
/// Created by [`AtaPassThru::iter_devices`]; yields [`AtaDevice`]s.
#[derive(Debug)]
pub struct AtaDeviceIterator<'a> {
    // Protocol being enumerated.
    proto: &'a UnsafeCell<AtaPassThruProtocol>,
    // True when the current port is exhausted, so `get_next_port` must be
    // called before querying devices again. Starts as true.
    end_of_port: bool,
    // Port cursor passed to `get_next_port`. Starts at 0xFFFF.
    prev_port: u16,
    // Port-multiplier-port cursor passed to `get_next_device`. Starts at 0xFFFF
    // and is reset to 0xFFFF when a port is exhausted.
    prev_pmp: u16,
}
impl<'a> Iterator for AtaDeviceIterator<'a> {
    type Item = AtaDevice<'a>;

    /// Yields the next device reported by firmware.
    ///
    /// Ports are walked with `get_next_port`. For each port, devices are walked
    /// with `get_next_device`. Iteration ends (`None`) when `get_next_port`
    /// reports `NOT_FOUND`.
    ///
    /// # Panics
    ///
    /// Panics if either firmware call returns a status other than `SUCCESS` or
    /// `NOT_FOUND`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Advance to the next port once the current one is exhausted. This is
            // also the initial state, with `prev_port == 0xFFFF`.
            if self.end_of_port {
                let result = unsafe {
                    ((*self.proto.get()).get_next_port)(self.proto.get(), &mut self.prev_port)
                };
                match result {
                    Status::SUCCESS => self.end_of_port = false,
                    // No more ports: iteration is over.
                    Status::NOT_FOUND => return None,
                    _ => panic!("Must not happen according to spec!"),
                }
            }
            // Capture this BEFORE the call below. `get_next_device` receives
            // `&mut self.prev_pmp` and may overwrite it.
            let was_first = self.prev_pmp == 0xFFFF;
            let result = unsafe {
                ((*self.proto.get()).get_next_device)(
                    self.proto.get(),
                    self.prev_port,
                    &mut self.prev_pmp,
                )
            };
            match result {
                Status::SUCCESS => {
                    // A device reported with pmp 0xFFFF is treated as the only
                    // device on this port, so the next call moves on to the next
                    // port. `prev_pmp` is already 0xFFFF for that port.
                    if self.prev_pmp == 0xFFFF {
                        self.end_of_port = true;
                    }
                    return Some(AtaDevice {
                        proto: self.proto,
                        port: self.prev_port,
                        pmp: self.prev_pmp,
                    });
                }
                Status::NOT_FOUND => {
                    // No (more) devices on this port: reset the device cursor for
                    // the next port.
                    self.end_of_port = true;
                    self.prev_pmp = 0xFFFF;
                    // If the very first device query on this port reported
                    // NOT_FOUND, a device is still yielded, with pmp 0xFFFF.
                    // NOTE(review): presumably this covers firmware that reports a
                    // directly attached device (no port multiplier) this way.
                    // Confirm against the UEFI spec's GetNextDevice description.
                    // As written, this yields a device for every port whose first
                    // query returns NOT_FOUND.
                    if was_first {
                        return Some(AtaDevice {
                            proto: self.proto,
                            port: self.prev_port,
                            pmp: 0xFFFF,
                        });
                    }
                    // Otherwise loop around and advance to the next port.
                }
                _ => panic!("Must not happen according to spec!"),
            }
        }
    }
}