ax-driver 0.6.1

ArceOS rdrive driver registration and rdif binding collection
Documentation
extern crate alloc;

use alloc::{borrow::ToOwned, format, string::String};

use rdif_input::{AbsInfo, EventType, InputDeviceId, InputError, InputEvent};
use rdrive::{DriverGeneric, PlatformDevice, probe::OnProbeError};
#[cfg(any(plat_static, plat_dyn))]
use virtio_drivers::transport::DeviceType;
use virtio_drivers::{
    Error as VirtIoError,
    device::input::{InputConfigSelect, VirtIOInput},
    transport::Transport,
};

use crate::{input::PlatformDeviceInput, virtio::VirtIoHalImpl};

#[cfg(any(plat_static, plat_dyn))]
crate::model_register!(
    name: "VirtIO Input",
    level: ProbeLevel::PostKernel,
    priority: ProbePriority::DEFAULT,
    probe_kinds: &[ProbeKind::Pci {
        on_probe: probe_pci,
    }],
);

#[cfg(any(plat_static, plat_dyn))]
fn probe_pci(
    endpoint: &mut rdrive::probe::pci::EndpointRc,
    plat_dev: PlatformDevice,
) -> Result<(), OnProbeError> {
    let transport = crate::pci::take_virtio_transport(endpoint, DeviceType::Input)?;
    register_transport(plat_dev, transport)
}

pub fn register_transport<T: Transport + 'static>(
    plat_dev: PlatformDevice,
    transport: T,
) -> Result<(), OnProbeError> {
    let dev = VirtIoInputDevice::new(transport).map_err(|err| {
        OnProbeError::other(format!("failed to initialize virtio-input: {err:?}"))
    })?;
    plat_dev.register_input(dev);
    log::info!("registered virtio input device");
    Ok(())
}

struct VirtIoInputDevice<T: Transport + 'static> {
    raw: VirtIOInput<VirtIoHalImpl, T>,
    device_id: InputDeviceId,
    name: String,
    physical_location: String,
    unique_id: String,
}

unsafe impl<T: Transport + 'static> Send for VirtIoInputDevice<T> {}

impl<T: Transport + 'static> VirtIoInputDevice<T> {
    fn new(transport: T) -> Result<Self, VirtIoError> {
        let mut raw = VirtIOInput::new(transport)?;
        let name = raw.name().unwrap_or_else(|_| "<unknown>".to_owned());
        let id = raw.ids()?;
        let device_id = InputDeviceId {
            bus_type: id.bustype,
            vendor: id.vendor,
            product: id.product,
            version: id.version,
        };
        let physical_location = format!(
            "virtio-input/{:04x}:{:04x}:{:04x}:{:04x}",
            device_id.bus_type, device_id.vendor, device_id.product, device_id.version
        );
        let unique_id = format!(
            "virtio-{:04x}-{:04x}-{:04x}-{:04x}-{}",
            device_id.bus_type, device_id.vendor, device_id.product, device_id.version, name
        );
        Ok(Self {
            raw,
            device_id,
            name,
            physical_location,
            unique_id,
        })
    }
}

impl<T: Transport + 'static> DriverGeneric for VirtIoInputDevice<T> {
    fn name(&self) -> &str {
        &self.name
    }
}

impl<T: Transport + 'static> rdif_input::Interface for VirtIoInputDevice<T> {
    fn device_id(&self) -> InputDeviceId {
        self.device_id
    }

    fn physical_location(&self) -> &str {
        &self.physical_location
    }

    fn unique_id(&self) -> &str {
        &self.unique_id
    }

    fn get_event_bits(&mut self, ty: EventType, out: &mut [u8]) -> Result<bool, InputError> {
        self.raw
            .query_config_select(InputConfigSelect::EvBits, ty as u8, out)
            .map(|read| read != 0)
            .map_err(map_input_err)
    }

    fn read_event(&mut self) -> Result<InputEvent, InputError> {
        self.raw.ack_interrupt();
        let event = self.raw.pop_pending_event().ok_or(InputError::Again)?;
        Ok(InputEvent {
            event_type: event.event_type,
            code: event.code,
            value: event.value as i32,
        })
    }

    fn get_prop_bits(&mut self, out: &mut [u8]) -> Result<usize, InputError> {
        let bits = self.raw.prop_bits().map_err(map_input_err)?;
        let len = bits.len().min(out.len());
        out[..len].copy_from_slice(&bits[..len]);
        Ok(len)
    }

    fn get_abs_info(&mut self, axis: u8) -> Result<AbsInfo, InputError> {
        let info = self.raw.abs_info(axis).map_err(map_input_err)?;
        Ok(AbsInfo {
            min: info.min as i32,
            max: info.max as i32,
            fuzz: info.fuzz as i32,
            flat: info.flat as i32,
            res: info.res as i32,
        })
    }
}

fn map_input_err(err: VirtIoError) -> InputError {
    match err {
        VirtIoError::Unsupported => InputError::NotSupported,
        VirtIoError::NotReady => InputError::Again,
        _ => InputError::Other(alloc::boxed::Box::new(err)),
    }
}