joycon-driver 0.0.0

Driver for Nintendo Switch™
Documentation
use std::{
    fmt,
    sync::atomic::{AtomicU8, Ordering},
};

use hidapi::{DeviceInfo, HidApi, HidDevice};

use crate::{
    error::{JoyConDriverError, JoyConDriverResult},
    imu::{IMUCalibration, IMUOffset},
    report::{InputReport, InputReportMode},
    stick::StickCalibration,
};

/// HID vendor ID shared by all supported Nintendo controllers.
pub const VENDOR_ID: u16 = 0x057e;
/// HID product ID of a left Joy-Con.
pub const PRODUCT_ID_JOYCON_L: u16 = 0x2006;
/// HID product ID of a right Joy-Con.
pub const PRODUCT_ID_JOYCON_R: u16 = 0x2007;
/// HID product ID of a Pro Controller.
pub const PRODUCT_ID_PROCON: u16 = 0x2009;

/// Length of the outgoing sub-command report header: report ID (byte 0),
/// packet counter (byte 1), bytes 2..=9, and the sub-command ID (byte 10).
/// Sub-command arguments start right after it.
pub const SUB_COMMAND_IN_HEADER_BYTES: usize = 11;
/// Length of the incoming sub-command reply header, up to and including the
/// echoed sub-command ID (byte 14). Reply payload starts right after it.
pub const SUB_COMMAND_OUT_HEADER_BYTES: usize = 15;

/// Number of argument bytes sent with an SPI flash read
/// (address low, address high, two zero bytes, length).
pub const SUB_COMMAND_READ_ARGS_BYTES: usize = 5;
/// Offset of the first flash data byte inside an SPI flash read reply.
pub const SUB_COMMAND_READ_HEADER_BYTES: usize =
    SUB_COMMAND_READ_ARGS_BYTES + SUB_COMMAND_OUT_HEADER_BYTES;

/// Number of argument bytes sent with the set-input-report-mode sub-command.
pub const SUB_COMMAND_SET_REPORT_MODE_ARGS_BYTES: usize = 1;

/// Marker bytes associated with user calibration data in SPI flash.
/// NOTE(review): presumably checked by the stick/IMU user-calibration readers
/// to detect whether user calibration is present — confirm there.
pub const USER_CALIBRATION_DATA_MAGIC: [u8; 2] = [0xb2, 0xa1];

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    JoyConL,
    JoyConR,
    ProCon,
}

impl TryFrom<&DeviceInfo> for DeviceType {
    type Error = JoyConDriverError;

    fn try_from(info: &DeviceInfo) -> JoyConDriverResult<Self> {
        Ok(match (info.vendor_id(), info.product_id()) {
            (VENDOR_ID, PRODUCT_ID_JOYCON_L) => DeviceType::JoyConL,
            (VENDOR_ID, PRODUCT_ID_JOYCON_R) => DeviceType::JoyConR,
            (VENDOR_ID, PRODUCT_ID_PROCON) => DeviceType::ProCon,
            (VENDOR_ID, id) => return Err(JoyConDriverError::InvalidProductId(id)),
            (id, _) => return Err(JoyConDriverError::InvalidVendorId(id)),
        })
    }
}

/// Output report IDs this driver sends to the controller.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
    /// Report `0x01`, used to carry a sub-command (see `SubCommand`).
    RumbleAndSubCommand = 0x01,
}

/// Sub-command IDs, written to byte 10 of a `Command::RumbleAndSubCommand`
/// report and echoed back in byte 14 of the reply.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubCommand {
    /// `0x03`: choose which input report format the controller sends.
    SetInputReportMode = 0x03,
    /// `0x10`: read bytes from the controller's SPI flash.
    SPIFlashRead = 0x10,
    /// `0x40`: turn the IMU on or off.
    EnableIMU = 0x40,
}

/// Outcome of a sub-command, decoded from the ACK byte (byte 13) of a reply.
#[derive(Debug, Clone)]
pub enum SubCommandReply {
    /// Acknowledged; holds the ACK byte with its high bit cleared.
    Ack(u8),
    /// Not acknowledged: the ACK byte's high bit was clear.
    Nack,
}

impl From<u8> for SubCommandReply {
    /// Decodes an ACK byte: a set high bit means ACK, anything else is NACK.
    fn from(ack_byte: u8) -> Self {
        const ACK_FLAG: u8 = 0x80;

        match ack_byte & ACK_FLAG {
            0 => Self::Nack,
            _ => Self::Ack(ack_byte & !ACK_FLAG),
        }
    }
}

/// All calibration records read from a controller.
///
/// NOTE(review): the `Option` fields presumably are `None` when the
/// corresponding record is absent/unsupported on this device — confirm in
/// `StickCalibration` / `IMUCalibration`.
#[derive(Debug, Clone)]
pub struct JoyConCalibrationData {
    /// Factory calibration of the left stick.
    pub left_stick_factory_calib: Option<StickCalibration>,
    /// User calibration of the left stick.
    pub left_stick_user_calib: Option<StickCalibration>,
    /// Factory calibration of the right stick.
    pub right_stick_factory_calib: Option<StickCalibration>,
    /// User calibration of the right stick.
    pub right_stick_user_calib: Option<StickCalibration>,
    /// Factory IMU calibration (always read).
    pub imu_factory_calib: IMUCalibration,
    /// User IMU calibration.
    pub imu_user_calib: Option<IMUCalibration>,
    /// IMU offset data.
    pub imu_offset: IMUOffset,
}

impl JoyConCalibrationData {
    /// Reads every calibration record from `device`, one after another.
    ///
    /// # Errors
    /// Returns the first error produced by any of the underlying reads;
    /// later reads are not attempted.
    pub fn read_data(device: &JoyConDevice) -> JoyConDriverResult<Self> {
        let left_stick_factory_calib = StickCalibration::read_left_factory_data(device)?;
        let left_stick_user_calib = StickCalibration::read_left_user_data(device)?;
        let right_stick_factory_calib = StickCalibration::read_right_factory_data(device)?;
        let right_stick_user_calib = StickCalibration::read_right_user_data(device)?;
        let imu_factory_calib = IMUCalibration::read_factory_data(device)?;
        let imu_user_calib = IMUCalibration::read_user_data(device)?;
        let imu_offset = IMUOffset::read_data(device)?;

        Ok(Self {
            left_stick_factory_calib,
            left_stick_user_calib,
            right_stick_factory_calib,
            right_stick_user_calib,
            imu_factory_calib,
            imu_user_calib,
            imu_offset,
        })
    }
}

/// An opened Nintendo controller plus the state needed to talk to it.
pub struct JoyConDevice {
    /// Controller model, derived from the HID vendor/product IDs.
    pub device_type: DeviceType,
    /// Serial number reported during HID enumeration (empty if none was reported).
    pub serial_number: String,
    /// Open HID handle used for every read and write.
    hid_device: HidDevice,
    /// Rolling packet counter stamped into each outgoing sub-command report.
    count: AtomicU8,
}

impl JoyConDevice {
    /// Opens the controller described by `info`.
    ///
    /// The device type is validated before any handle is opened, so unsupported
    /// devices are rejected without touching the hardware.
    ///
    /// # Errors
    /// - `InvalidVendorId` / `InvalidProductId` if `info` is not a supported controller.
    /// - Any error raised while opening the HID device.
    pub fn new(api: &HidApi, info: &DeviceInfo) -> JoyConDriverResult<Self> {
        let device_type = DeviceType::try_from(info)?;

        // Open via the enumerated device path. Opening by (vid, pid, serial) with
        // a missing serial would fall back to "", which may fail or grab a
        // different controller when several are connected.
        let hid_device = info.open_device(api).map_err(JoyConDriverError::from)?;

        Ok(Self {
            hid_device,
            count: AtomicU8::default(),
            device_type,
            serial_number: info.serial_number().unwrap_or_default().to_string(),
        })
    }

    /// Reads every calibration record from this controller.
    ///
    /// # Errors
    /// Propagates the first error from the underlying flash reads.
    pub fn read_calibration_data(&self) -> JoyConDriverResult<JoyConCalibrationData> {
        JoyConCalibrationData::read_data(self)
    }

    /// Blocks until `InputReport::read` yields a report, and returns it.
    ///
    /// NOTE(review): `Ok(None)` is treated as "nothing usable yet, read again";
    /// presumably it signals a timeout or an unrelated report — confirm in `report`.
    ///
    /// # Errors
    /// Propagates any error from `InputReport::read`.
    pub fn input_report(&self) -> JoyConDriverResult<InputReport> {
        loop {
            if let Some(report) = InputReport::read(self)? {
                return Ok(report);
            }
        }
    }

    /// Sends `sub_command` with `args` and waits for the controller's reply.
    ///
    /// On success `buf` holds the reply report; its payload starts at
    /// `buf[SUB_COMMAND_OUT_HEADER_BYTES..]`.
    ///
    /// # Errors
    /// - `BufferIsTooSmall` if `buf` is shorter than the reply header.
    /// - `SubCommandFailed` if the controller NACKs the sub-command.
    /// - Any HID write/read error.
    pub(crate) fn send_sub_command(
        &self,
        sub_command: SubCommand,
        args: &[u8],
        buf: &mut [u8],
    ) -> JoyConDriverResult<()> {
        // Input report ID used for sub-command replies.
        const REPLY_REPORT_ID: u8 = 0x21;
        // Reply byte carrying the ACK/NACK flag.
        const ACK_BYTE_INDEX: usize = 13;
        // Reply byte echoing the sub-command ID.
        const REPLY_SUB_COMMAND_INDEX: usize = 14;
        // Timeout for each individual read, in milliseconds.
        const READ_TIMEOUT_MS: i32 = 20;

        // The reply header is inspected up to byte 14; reject buffers that can
        // never hold it instead of panicking on an out-of-bounds index.
        if buf.len() < SUB_COMMAND_OUT_HEADER_BYTES {
            return Err(JoyConDriverError::BufferIsTooSmall(buf.len()));
        }

        let mut data = vec![0u8; SUB_COMMAND_IN_HEADER_BYTES + args.len()];
        data[0] = Command::RumbleAndSubCommand as u8;
        // The protocol's global packet counter loops in 0x0..=0xF. 256 is a
        // multiple of 16, so masking the wrapping u8 keeps the sequence continuous.
        data[1] = self.count.fetch_add(1, Ordering::Relaxed) & 0x0F;
        // Bytes 2..=9 stay zeroed.
        data[SUB_COMMAND_IN_HEADER_BYTES - 1] = sub_command as u8;
        data[SUB_COMMAND_IN_HEADER_BYTES..].copy_from_slice(args);
        self.hid_device.write(&data)?;

        // Skip timeouts and unrelated input reports until the matching reply arrives.
        // NOTE(review): this waits indefinitely if the controller never replies
        // (HID errors still propagate).
        loop {
            let len = self.hid_device.read_timeout(buf, READ_TIMEOUT_MS)?;

            // Only trust bytes written by *this* read: a timeout (len == 0) or a
            // short report leaves stale bytes from an earlier report in `buf`.
            if len > REPLY_SUB_COMMAND_INDEX
                && buf[0] == REPLY_REPORT_ID
                && buf[REPLY_SUB_COMMAND_INDEX] == sub_command as u8
            {
                return match SubCommandReply::from(buf[ACK_BYTE_INDEX]) {
                    SubCommandReply::Ack(_) => Ok(()),
                    SubCommandReply::Nack => Err(JoyConDriverError::SubCommandFailed),
                };
            }
        }
    }

    /// Reads one report into `buf`, waiting at most `timeout` milliseconds.
    ///
    /// # Errors
    /// Propagates any HID read error.
    pub(crate) fn read_timeout(&self, buf: &mut [u8], timeout: i32) -> JoyConDriverResult<()> {
        self.hid_device.read_timeout(buf, timeout)?;
        Ok(())
    }

    /// Selects the input report format the controller sends.
    ///
    /// # Errors
    /// See `send_sub_command`.
    pub fn set_input_report_mode(&self, report_mode: InputReportMode) -> JoyConDriverResult<()> {
        let mut buf = [0u8; SUB_COMMAND_OUT_HEADER_BYTES];
        self.send_sub_command(
            SubCommand::SetInputReportMode,
            &[report_mode as u8],
            &mut buf[..],
        )
    }

    /// Enables (`true`) or disables (`false`) the IMU.
    ///
    /// # Errors
    /// See `send_sub_command`.
    pub fn imu_feature(&self, enable: bool) -> JoyConDriverResult<()> {
        let mut buf = [0u8; SUB_COMMAND_OUT_HEADER_BYTES];
        self.send_sub_command(SubCommand::EnableIMU, &[u8::from(enable)], &mut buf[..])
    }

    // https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_subcommands_notes.md#subcommand-0x10-spi-flash-read
    /// Reads SPI flash starting at `address` into `buf`.
    ///
    /// The number of bytes requested is `buf.len() - SUB_COMMAND_READ_HEADER_BYTES`,
    /// and on success the flash contents start at `buf[SUB_COMMAND_READ_HEADER_BYTES..]`.
    ///
    /// # Errors
    /// - `BufferIsTooSmall` if `buf` leaves no room for any data.
    /// - `BufferIsTooLarge` if more than 0x1D data bytes would be requested,
    ///   the per-read maximum of this sub-command.
    /// - Any error from `send_sub_command`.
    pub(crate) fn read(&self, buf: &mut [u8], address: u16) -> JoyConDriverResult<()> {
        // Maximum payload of a single SPI flash read (see the notes linked above).
        const SPI_FLASH_READ_MAX_BYTES: usize = 0x1D;

        if buf.len() <= SUB_COMMAND_READ_HEADER_BYTES {
            return Err(JoyConDriverError::BufferIsTooSmall(buf.len()));
        }

        let data_len = buf.len() - SUB_COMMAND_READ_HEADER_BYTES;
        if data_len > SPI_FLASH_READ_MAX_BYTES {
            return Err(JoyConDriverError::BufferIsTooLarge(buf.len()));
        }

        // The address argument is a little-endian u32; only the low 16 bits are
        // reachable through this API, so the upper two bytes are zero.
        let [address_lower, address_upper] = address.to_le_bytes();
        self.send_sub_command(
            SubCommand::SPIFlashRead,
            // `data_len <= 0x1D`, so the cast to u8 cannot truncate.
            &[address_lower, address_upper, 0x00, 0x00, data_len as u8],
            buf,
        )
    }

    /// Opens every connected supported controller, optionally only those whose
    /// product ID equals `product_id`.
    ///
    /// Other HID devices, including unsupported Nintendo products, are skipped
    /// rather than failing the whole listing.
    ///
    /// # Errors
    /// Returns the first error encountered while opening a matching device.
    fn list_devices_with_product_id(
        api: &HidApi,
        product_id: Option<u16>,
    ) -> JoyConDriverResult<Vec<Self>> {
        api.device_list()
            .filter(|info| {
                DeviceType::try_from(*info).is_ok()
                    && (product_id.is_none() || Some(info.product_id()) == product_id)
            })
            .map(|info| Self::new(api, info))
            .collect()
    }

    /// Opens every connected supported controller.
    ///
    /// # Errors
    /// See `list_devices_with_product_id`.
    pub fn list_devices(api: &HidApi) -> JoyConDriverResult<Vec<Self>> {
        Self::list_devices_with_product_id(api, None)
    }

    /// Opens every connected left Joy-Con.
    ///
    /// # Errors
    /// See `list_devices_with_product_id`.
    pub fn list_joycon_l_devices(api: &HidApi) -> JoyConDriverResult<Vec<Self>> {
        Self::list_devices_with_product_id(api, Some(PRODUCT_ID_JOYCON_L))
    }

    /// Opens every connected right Joy-Con.
    ///
    /// # Errors
    /// See `list_devices_with_product_id`.
    pub fn list_joycon_r_devices(api: &HidApi) -> JoyConDriverResult<Vec<Self>> {
        Self::list_devices_with_product_id(api, Some(PRODUCT_ID_JOYCON_R))
    }

    /// Opens every connected Pro Controller.
    ///
    /// # Errors
    /// See `list_devices_with_product_id`.
    pub fn list_procon_devices(api: &HidApi) -> JoyConDriverResult<Vec<Self>> {
        Self::list_devices_with_product_id(api, Some(PRODUCT_ID_PROCON))
    }
}

impl fmt::Debug for JoyConDevice {
    /// Formats the identifying fields only; the HID handle and packet counter
    /// are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            device_type,
            serial_number,
            ..
        } = self;

        let mut builder = f.debug_struct("JoyConDevice");
        builder.field("device_type", device_type);
        builder.field("serial_number", serial_number);
        builder.finish()
    }
}