ax-driver 0.5.15

ArceOS device drivers
Documentation
use core::marker::PhantomData;
#[cfg(virtio_dev)]
use core::ptr::NonNull;

#[cfg(virtio_dev)]
use ax_alloc::{UsageKind, global_allocator};
use ax_driver_base::{BaseDriverOps, DevResult, DeviceType};
#[cfg(virtio_dev)]
use ax_driver_virtio::{BufferDirection, PhysAddr, VirtIoHal};
#[cfg(virtio_dev)]
use ax_hal::mem::{phys_to_virt, virt_to_phys};
use cfg_if::cfg_if;

use crate::{AxDeviceEnum, drivers::DriverProbe};

cfg_if! {
    if #[cfg(bus = "pci")] {
        #[cfg(feature = "bus-pci")]
        use ax_driver_virtio::pci::{ConfigurationAccess, DeviceFunction, DeviceFunctionInfo, PciRoot};

        // Devices discovered on the PCI bus talk through the PCI transport.
        type VirtIoTransport = ax_driver_virtio::PciTransport;
    } else {
        // `bus = "mmio"` and every other bus configuration use the MMIO transport.
        type VirtIoTransport = ax_driver_virtio::MmioTransport;
    }
}

/// Static description of one kind of VirtIO device.
///
/// An implementor names the [`DeviceType`] it handles, the concrete driver type
/// that backs it, and how to build an [`AxDeviceEnum`] from a transport.
pub trait VirtIoDevMeta {
    /// Concrete driver type constructed for this device.
    type Device: BaseDriverOps;
    /// Probe driver for this device; defaults to the generic [`VirtIoDriver`].
    type Driver = VirtIoDriver<Self>;

    /// The VirtIO device type this description matches during probing.
    const DEVICE_TYPE: DeviceType;

    /// Initializes the device over `transport`.
    ///
    /// `irq` is the interrupt number when one is known (PCI probing supplies it,
    /// MMIO probing passes `None`); some implementations ignore it.
    fn try_new(transport: VirtIoTransport, irq: Option<usize>) -> DevResult<AxDeviceEnum>;
}

/// VirtIO network device description, enabled with `net_dev = "virtio-net"`.
#[cfg(net_dev = "virtio-net")]
#[allow(dead_code)]
pub struct VirtIoNet;

#[cfg(net_dev = "virtio-net")]
impl VirtIoDevMeta for VirtIoNet {
    const DEVICE_TYPE: DeviceType = DeviceType::Net;
    // NOTE(review): the `64` const parameter is passed straight to `VirtIoNetDev`;
    // its meaning (presumably a queue/buffer count) should be confirmed there.
    type Device = ax_driver_virtio::VirtIoNetDev<VirtIoHalImpl, VirtIoTransport, 64>;

    /// Builds the network driver, forwarding `irq` to it.
    fn try_new(transport: VirtIoTransport, irq: Option<usize>) -> DevResult<AxDeviceEnum> {
        let net = Self::Device::try_new(transport, irq)?;
        Ok(AxDeviceEnum::from_net(net))
    }
}

/// VirtIO block device description, enabled with `block_dev = "virtio-blk"`.
#[cfg(block_dev = "virtio-blk")]
#[allow(dead_code)]
pub struct VirtIoBlk;

#[cfg(block_dev = "virtio-blk")]
impl VirtIoDevMeta for VirtIoBlk {
    const DEVICE_TYPE: DeviceType = DeviceType::Block;
    type Device = ax_driver_virtio::VirtIoBlkDev<VirtIoHalImpl, VirtIoTransport>;

    /// Builds the block driver; the interrupt number is not used.
    fn try_new(transport: VirtIoTransport, _irq: Option<usize>) -> DevResult<AxDeviceEnum> {
        let blk = Self::Device::try_new(transport)?;
        Ok(AxDeviceEnum::from_block(blk))
    }
}

/// VirtIO GPU (display) device description, enabled with `display_dev = "virtio-gpu"`.
#[cfg(display_dev = "virtio-gpu")]
pub struct VirtIoGpu;

#[cfg(display_dev = "virtio-gpu")]
impl VirtIoDevMeta for VirtIoGpu {
    const DEVICE_TYPE: DeviceType = DeviceType::Display;
    type Device = ax_driver_virtio::VirtIoGpuDev<VirtIoHalImpl, VirtIoTransport>;

    /// Builds the display driver; the interrupt number is not used.
    fn try_new(transport: VirtIoTransport, _irq: Option<usize>) -> DevResult<AxDeviceEnum> {
        let gpu = Self::Device::try_new(transport)?;
        Ok(AxDeviceEnum::from_display(gpu))
    }
}

/// VirtIO input device description, enabled with `input_dev = "virtio-input"`.
#[cfg(input_dev = "virtio-input")]
pub struct VirtIoInput;

#[cfg(input_dev = "virtio-input")]
impl VirtIoDevMeta for VirtIoInput {
    const DEVICE_TYPE: DeviceType = DeviceType::Input;
    type Device = ax_driver_virtio::VirtIoInputDev<VirtIoHalImpl, VirtIoTransport>;

    /// Builds the input driver, forwarding `irq` to it.
    fn try_new(transport: VirtIoTransport, irq: Option<usize>) -> DevResult<AxDeviceEnum> {
        let input = Self::Device::try_new(transport, irq)?;
        Ok(AxDeviceEnum::from_input(input))
    }
}

/// VirtIO socket (vsock) device description, enabled with `vsock_dev = "virtio-socket"`.
#[cfg(vsock_dev = "virtio-socket")]
pub struct VirtIoSocket;

#[cfg(vsock_dev = "virtio-socket")]
impl VirtIoDevMeta for VirtIoSocket {
    const DEVICE_TYPE: DeviceType = DeviceType::Vsock;
    type Device = ax_driver_virtio::VirtIoSocketDev<VirtIoHalImpl, VirtIoTransport>;

    /// Builds the vsock driver; the interrupt number is not used.
    fn try_new(transport: VirtIoTransport, _irq: Option<usize>) -> DevResult<AxDeviceEnum> {
        let vsock = Self::Device::try_new(transport)?;
        Ok(AxDeviceEnum::from_vsock(vsock))
    }
}

/// A common driver for all VirtIO devices that implements [`DriverProbe`].
///
/// `D` selects the single VirtIO device type this driver accepts; any other
/// device found while probing is rejected with `None`.
pub struct VirtIoDriver<D: VirtIoDevMeta + ?Sized>(PhantomData<D>);

impl<D: VirtIoDevMeta> DriverProbe for VirtIoDriver<D> {
    /// Probes the physical MMIO window `[mmio_base, mmio_base + mmio_size)`.
    ///
    /// Returns the initialized device when a VirtIO device of `D::DEVICE_TYPE`
    /// is found there; initialization failures are logged and yield `None`.
    #[cfg(bus = "mmio")]
    fn probe_mmio(mmio_base: usize, mmio_size: usize) -> Option<AxDeviceEnum> {
        let window = phys_to_virt(mmio_base.into());
        let (found, transport) =
            ax_driver_virtio::probe_mmio_device(window.as_mut_ptr(), mmio_size)?;
        if found != D::DEVICE_TYPE {
            return None;
        }
        D::try_new(transport, None)
            .inspect_err(|e| {
                warn!(
                    "failed to initialize MMIO device at [PA:{:#x}, PA:{:#x}): {:?}",
                    mmio_base,
                    mmio_base + mmio_size,
                    e
                );
            })
            .ok()
    }

    /// Probes the PCI function `bdf` described by `dev_info`.
    ///
    /// The vendor/device IDs are filtered first so that only functions that can
    /// be a `D::DEVICE_TYPE` VirtIO device are handed to the transport probe.
    /// Initialization failures are logged and yield `None`.
    #[cfg(all(bus = "pci", feature = "bus-pci"))]
    fn probe_pci<C: ConfigurationAccess>(
        root: &mut PciRoot<C>,
        bdf: DeviceFunction,
        dev_info: &DeviceFunctionInfo,
    ) -> Option<AxDeviceEnum> {
        // 0x1af4 is the PCI vendor ID assigned to VirtIO devices.
        if dev_info.vendor_id != 0x1af4 {
            return None;
        }
        // Transitional IDs (0x1000..) and modern IDs (0x1040 + VirtIO device ID).
        let id_matches = matches!(
            (D::DEVICE_TYPE, dev_info.device_id),
            (DeviceType::Net, 0x1000 | 0x1041)
                | (DeviceType::Block, 0x1001 | 0x1042)
                | (DeviceType::Input, 0x1052)
                | (DeviceType::Display, 0x1050)
                | (DeviceType::Vsock, 0x1053)
        );
        if !id_matches {
            return None;
        }

        let (found, transport) =
            ax_driver_virtio::probe_pci_device::<VirtIoHalImpl, C>(root, bdf, dev_info)?;
        if found != D::DEVICE_TYPE {
            return None;
        }
        let irq = pci_irq_vector(bdf);
        D::try_new(transport, Some(irq))
            .inspect_err(|e| warn!("failed to initialize PCI device at {bdf}({dev_info}): {e:?}"))
            .ok()
    }
}

/// Resolves the interrupt vector for PCI function `bdf` on x86_64.
///
/// Reads the Interrupt Line / Interrupt Pin bytes from the function's ECAM
/// configuration space. When the device reports a pin and a line in `1..0x20`,
/// the vector is `0x20 + line`; otherwise the slot-based fallback from
/// `pci_legacy_irq_fallback` is used.
#[cfg(all(bus = "pci", feature = "bus-pci", target_arch = "x86_64"))]
fn pci_irq_vector(bdf: DeviceFunction) -> usize {
    // Config-space offset of the dword holding Interrupt Line (byte 0) and Interrupt Pin (byte 1).
    const PCI_INTERRUPT_REG: usize = 0x3c;
    // Vector number that legacy IRQ line 0 maps to.
    const IO_APIC_VECTOR_OFFSET: usize = 0x20;

    // ECAM layout: bus << 20 | device << 15 | function << 12 | register offset.
    let ecam_offset = PCI_INTERRUPT_REG
        | ((bdf.function as usize) << 12)
        | ((bdf.device as usize) << 15)
        | ((bdf.bus as usize) << 20);
    let ecam_addr = ax_config::devices::PCI_ECAM_BASE + ecam_offset;

    // SAFETY: NOTE(review): assumes the ECAM window at `PCI_ECAM_BASE` is mapped and
    // reachable through `phys_to_virt`, so this aligned 32-bit read is valid.
    let int_reg = unsafe {
        let reg = phys_to_virt(ecam_addr.into()).as_usize() as *const u32;
        reg.read_volatile()
    };
    let [line, pin, ..] = int_reg.to_le_bytes().map(usize::from);

    if pin != 0 && (1..0x20).contains(&line) {
        let vector = IO_APIC_VECTOR_OFFSET + line;
        debug!("PCI {bdf} INTx line {line}, pin {pin} -> vector {vector:#x}");
        return vector;
    }

    let fallback = pci_legacy_irq_fallback(bdf);
    debug!("PCI {bdf} has INTx line {line}, pin {pin}; using fallback vector {fallback:#x}");
    fallback
}

/// Resolves the interrupt number for PCI function `bdf` on non-x86_64 targets.
///
/// Configuration-space INTx routing is not read here; the slot-based value from
/// `pci_legacy_irq_fallback` is always returned.
#[cfg(all(bus = "pci", feature = "bus-pci", not(target_arch = "x86_64")))]
fn pci_irq_vector(bdf: DeviceFunction) -> usize {
    pci_legacy_irq_fallback(bdf)
}

/// Returns the first interrupt number used for legacy PCI INTx routing on the
/// current target architecture.
///
/// NOTE(review): these bases look tied to a specific platform's interrupt map
/// (e.g. QEMU `virt`-style machines) — confirm against the board configuration.
///
/// Unsupported architectures fail to compile with an explicit error instead of
/// a confusing type mismatch from an empty function body.
#[cfg(all(bus = "pci", feature = "bus-pci"))]
const fn pci_irq_base() -> usize {
    #[cfg(any(target_arch = "x86_64", target_arch = "riscv64"))]
    {
        0x20
    }
    #[cfg(target_arch = "loongarch64")]
    {
        0x10
    }
    #[cfg(target_arch = "aarch64")]
    {
        0x23
    }
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "riscv64",
        target_arch = "loongarch64",
        target_arch = "aarch64"
    )))]
    {
        compile_error!("pci_irq_base: no legacy PCI IRQ base defined for this target architecture")
    }
}

/// Fallback interrupt number for PCI function `bdf`.
///
/// Maps the device (slot) number onto one of four consecutive interrupt
/// numbers starting at `pci_irq_base()`; the function number and INTx pin are
/// not taken into account.
#[cfg(all(bus = "pci", feature = "bus-pci"))]
const fn pci_legacy_irq_fallback(bdf: DeviceFunction) -> usize {
    let lane = (bdf.device % 4) as usize;
    pci_irq_base() + lane
}

/// The [`VirtIoHal`] implementation shared by every VirtIO driver in this module.
///
/// DMA memory is taken from the global page allocator, and addresses are
/// translated with `ax_hal::mem::{phys_to_virt, virt_to_phys}`.
///
/// Every item of this HAL is gated on `virtio_dev`, matching the gating of the
/// imports it depends on (`NonNull`, `PhysAddr`, `VirtIoHal`, `global_allocator`, ...).
#[cfg(virtio_dev)]
pub struct VirtIoHalImpl;

/// A raw address produced by the HAL translation helpers below.
#[cfg(virtio_dev)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct HalAddress {
    addr: usize,
}

#[cfg(virtio_dev)]
impl HalAddress {
    /// Wraps the raw address `addr`.
    const fn new(addr: usize) -> Self {
        Self { addr }
    }

    /// Returns the wrapped raw address.
    const fn addr(self) -> usize {
        self.addr
    }
}

/// Converts a raw address into a `NonNull<u8>`.
///
/// # Panics
///
/// Panics if `addr` is zero; `context` names the operation in the panic message.
#[cfg(virtio_dev)]
#[inline]
fn nonnull_from_addr(addr: usize, context: &str) -> NonNull<u8> {
    assert_ne!(addr, 0, "{context} returned a null address");
    // SAFETY: The assertion above guarantees the pointer is non-null.
    unsafe { NonNull::new_unchecked(addr as *mut u8) }
}

/// [`nonnull_from_addr`] for a [`HalAddress`].
#[cfg(virtio_dev)]
#[inline]
fn nonnull_from_hal_address(addr: HalAddress, context: &str) -> NonNull<u8> {
    nonnull_from_addr(addr.addr(), context)
}

/// Translates a device-visible physical address into a kernel virtual address.
#[cfg(virtio_dev)]
#[inline]
fn hal_phys_to_virt_addr(paddr: PhysAddr) -> HalAddress {
    HalAddress::new(phys_to_virt((paddr as usize).into()).as_mut_ptr() as usize)
}

/// Translates a kernel virtual address into the physical address given to the device.
#[cfg(virtio_dev)]
#[inline]
fn hal_virt_to_phys_addr(vaddr: usize) -> PhysAddr {
    virt_to_phys(vaddr.into()).as_usize() as PhysAddr
}

// SAFETY: `dma_alloc` returns page-aligned, zeroed pages owned exclusively by the caller
// until `dma_dealloc`; address translation uses the HAL's `phys_to_virt`/`virt_to_phys`.
#[cfg(virtio_dev)]
unsafe impl VirtIoHal for VirtIoHalImpl {
    /// Allocates `pages` contiguous, zeroed DMA pages.
    ///
    /// Returns `(0, NonNull::dangling())` when the allocator fails; presumably the
    /// VirtIO driver treats a zero physical address as allocation failure — confirm
    /// against the `VirtIoHal` contract.
    fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
        const PAGE_SIZE: usize = 0x1000;

        let Ok(vaddr) = global_allocator().alloc_pages(pages, PAGE_SIZE, UsageKind::Dma) else {
            return (0, NonNull::dangling());
        };
        let ptr = nonnull_from_hal_address(HalAddress::new(vaddr), "dma_alloc");
        // The VirtIO HAL contract requires DMA pages to be zeroed (e.g. virtqueue ring
        // indices must start at 0); the page allocator gives no such guarantee here.
        // SAFETY: `ptr` is the start of `pages` freshly allocated pages that nothing
        // else references yet, so zeroing `pages * PAGE_SIZE` bytes stays in bounds.
        unsafe { ptr.as_ptr().write_bytes(0, pages * PAGE_SIZE) };
        let paddr = hal_virt_to_phys_addr(vaddr);
        (paddr, ptr)
    }

    /// Returns `pages` DMA pages starting at `vaddr` to the global allocator.
    ///
    /// # Safety
    ///
    /// `vaddr`/`pages` must describe an allocation previously returned by `dma_alloc`.
    unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
        global_allocator().dealloc_pages(vaddr.as_ptr() as usize, pages, UsageKind::Dma);
        0
    }

    /// Maps a device MMIO physical address to a virtual pointer via `phys_to_virt`.
    #[inline]
    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
        let vaddr = hal_phys_to_virt_addr(paddr);
        nonnull_from_hal_address(vaddr, "mmio_phys_to_virt")
    }

    /// Returns the physical address of `buffer` for the device; no copy or bounce buffer is used.
    #[inline]
    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
        let vaddr = buffer.as_ptr() as *mut u8 as usize;
        hal_virt_to_phys_addr(vaddr)
    }

    /// Nothing to undo: `share` hands the device the buffer's own memory.
    #[inline]
    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {}
}