use std::num::NonZeroU8;
use std::sync::Arc;
use async_hid::AsyncHidWrite;
use hidpp::{
channel::HidppChannel,
device::Device,
feature::CreatableFeature,
feature::adjustable_dpi::AdjustableDpiFeature,
feature::smartshift::{SmartShiftFeature, WheelMode},
protocol::v20::{ErrorType, Hidpp20Error},
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::debug;
use crate::route::{DeviceRoute, open_route_channel};
use crate::smartshift::{SmartShiftFeatureV0, SmartShiftMode, SmartShiftStatus};
/// Errors returned by the device read/write helpers in this module.
///
/// Transport and protocol failures are stored as already-rendered strings
/// (see the `Hid` and `Hidpp` variants), not as source errors.
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
pub enum WriteError {
    /// A HID transport failure; the payload is the rendered error message.
    #[error("HID transport error: {0}")]
    Hid(String),
    /// No channel or writer could be opened for the requested route.
    #[error("no connected device matched the route")]
    DeviceNotFound,
    /// Creating the HID++ device handle for `index` failed.
    #[error("device at index {index:#04x} did not respond to HID++")]
    DeviceUnreachable { index: u8 },
    /// The device's root feature lookup did not find `feature_hex`.
    #[error("device does not expose HID++ feature {feature_hex:#06x}")]
    FeatureUnsupported { feature_hex: u16 },
    /// An empty DPI list was handed to `DpiCapabilities::new`.
    #[error("device returned no supported DPI values")]
    EmptyDpiList,
    /// A HID++ call failed; the payload is the error's `Debug` rendering.
    #[error("HID++ protocol error: {0}")]
    Hidpp(String),
}
impl From<async_hid::HidError> for WriteError {
fn from(e: async_hid::HidError) -> Self {
Self::Hid(e.to_string())
}
}
/// The set of DPI values a device reports as supported.
///
/// `DpiCapabilities::new` sorts and de-duplicates the list and rejects an
/// empty one.
///
/// NOTE(review): the derived `Deserialize` fills `values` directly, so data
/// that did not pass through `new` is not checked for those properties.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DpiCapabilities {
    // Supported DPI values; ascending and unique when built via `new`.
    values: Vec<u16>,
}
impl DpiCapabilities {
    /// Builds a capability set from a raw list of DPI values.
    ///
    /// The list is sorted ascending and duplicates are removed.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::EmptyDpiList`] when `values` is empty.
    pub fn new(mut values: Vec<u16>) -> Result<Self, WriteError> {
        if values.is_empty() {
            return Err(WriteError::EmptyDpiList);
        }
        values.sort_unstable();
        values.dedup();
        Ok(Self { values })
    }

    /// All supported values, ascending and without duplicates.
    #[must_use]
    pub fn values(&self) -> &[u16] {
        &self.values
    }

    /// Smallest supported value.
    ///
    /// # Panics
    ///
    /// Panics if the list is empty, which `new` never allows.
    #[must_use]
    pub fn min(&self) -> u16 {
        self.values[0]
    }

    /// Largest supported value.
    ///
    /// # Panics
    ///
    /// Panics if the list is empty, which `new` never allows.
    #[must_use]
    pub fn max(&self) -> u16 {
        self.values[self.values.len() - 1]
    }

    /// Whether `dpi` is exactly one of the supported values.
    #[must_use]
    pub fn contains(&self, dpi: u16) -> bool {
        // Valid because `new` leaves the list sorted.
        self.values.binary_search(&dpi).is_ok()
    }

    /// Supported value closest to `dpi`; on a tie the lower value wins.
    ///
    /// # Panics
    ///
    /// Panics if the list is empty, which `new` never allows.
    #[must_use]
    pub fn nearest(&self, dpi: u32) -> u16 {
        let distance = |candidate: u16| u32::from(candidate).abs_diff(dpi);
        let first = self.values[0];
        // Strict `<` keeps the earlier (lower) value when two are equally close.
        self.values[1..].iter().copied().fold(first, |best, candidate| {
            if distance(candidate) < distance(best) {
                candidate
            } else {
                best
            }
        })
    }

    /// Same as [`Self::nearest`], widened to `u32`.
    #[must_use]
    pub fn snap(&self, dpi: u32) -> u32 {
        u32::from(self.nearest(dpi))
    }

    /// Smallest positive gap between neighbouring values, or `1` when there is
    /// no such gap (for example a single-value list).
    #[must_use]
    pub fn step_hint(&self) -> u16 {
        self.values
            .windows(2)
            .filter_map(|pair| pair[1].checked_sub(pair[0]))
            .filter(|step| *step > 0)
            .min()
            .unwrap_or(1)
    }

    /// Picks a supported value different from `current`.
    ///
    /// When `current` is in the list the next higher value is preferred, then
    /// the next lower one. When it is not in the list, the first value above
    /// it is chosen, falling back to the largest value. Returns `None` when
    /// fewer than two values are supported.
    #[must_use]
    pub fn adjacent_test_target(&self, current: u16) -> Option<u16> {
        if self.values.len() < 2 {
            return None;
        }
        let candidate = match self.values.binary_search(&current) {
            // Exact hit: neighbour above, otherwise neighbour below.
            Ok(index) => self
                .values
                .get(index + 1)
                .or_else(|| self.values.get(index.checked_sub(1)?)),
            // Miss: `index` is the insertion point, i.e. the first larger value.
            Err(index) => self.values.get(index).or_else(|| self.values.last()),
        };
        candidate.copied().filter(|target| *target != current)
    }
}
/// The current DPI reading together with the values the device supports.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DpiInfo {
    /// DPI value most recently read from the device.
    pub current: u16,
    /// Values the device reported as supported.
    pub capabilities: DpiCapabilities,
}
/// One row of a device's HID++ feature table, as returned by `dump_features`.
#[derive(Debug, Clone, Copy)]
pub struct FeatureEntry {
    /// HID++ feature id.
    pub id: u16,
    /// Version the device reported for the feature.
    pub version: u8,
}
/// Lists the id and version of every entry in the device's feature table.
///
/// Entries are read one by one for indices `0..=count`, where `count` is the
/// value the feature-set feature reports.
///
/// # Errors
///
/// * [`WriteError::DeviceNotFound`] when no channel can be opened for `route`.
/// * [`WriteError::DeviceUnreachable`] when the device handle cannot be created.
/// * [`WriteError::FeatureUnsupported`] when the feature-set feature is absent.
/// * [`WriteError::Hidpp`] when any HID++ call fails.
pub async fn dump_features(route: &DeviceRoute) -> Result<Vec<FeatureEntry>, WriteError> {
    use hidpp::feature::feature_set::FeatureSetFeature;

    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = Device::new(Arc::clone(&channel), index)
            .await
            .map_err(|_| WriteError::DeviceUnreachable { index })?;
        let feature_set_info = device
            .root()
            .get_feature(FeatureSetFeature::ID)
            .await
            .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?
            .ok_or(WriteError::FeatureUnsupported {
                feature_hex: FeatureSetFeature::ID,
            })?;
        let feature_set = device.add_feature::<FeatureSetFeature>(feature_set_info.index);
        let count = feature_set
            .count()
            .await
            .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?;

        // The range below is inclusive, so `count + 1` entries are pushed;
        // reserve for all of them up front.
        // NOTE(review): presumably index 0 is the root feature, which the
        // reported count leaves out — confirm against the HID++ feature-set spec.
        let mut entries = Vec::with_capacity(usize::from(count) + 1);
        for i in 0..=count {
            let info = feature_set
                .get_feature(i)
                .await
                .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?;
            entries.push(FeatureEntry {
                id: info.id,
                version: info.version,
            });
        }
        Ok(entries)
    })
    .await
}
/// Looks up feature `F` through the device's root feature and registers it on
/// `device`.
///
/// # Errors
///
/// * [`WriteError::Hidpp`] when the root lookup fails.
/// * [`WriteError::FeatureUnsupported`] when the lookup reports no such feature.
async fn open_feature<F: CreatableFeature + 'static>(
    device: &mut Device,
) -> Result<Arc<F>, WriteError> {
    let lookup = device.root().get_feature(F::ID).await;
    let found = match lookup {
        Ok(found) => found,
        Err(e) => return Err(WriteError::Hidpp(format!("{e:?}"))),
    };
    match found {
        Some(info) => Ok(device.add_feature::<F>(info.index)),
        None => Err(WriteError::FeatureUnsupported { feature_hex: F::ID }),
    }
}
/// Returns `true` only for a [`WriteError::FeatureUnsupported`] whose feature
/// id is `0x2111`; every other error yields `false`.
fn is_missing_enhanced(err: &WriteError) -> bool {
    match err {
        WriteError::FeatureUnsupported { feature_hex } => *feature_hex == 0x2111,
        _ => false,
    }
}
/// Maps a wheel mode onto a SmartShift mode: `Freespin` becomes `Free`, and
/// anything else becomes `Ratchet`.
fn wheel_mode_to_smartshift(wheel: WheelMode) -> SmartShiftMode {
    match wheel {
        WheelMode::Freespin => SmartShiftMode::Free,
        _ => SmartShiftMode::Ratchet,
    }
}
/// Maps a SmartShift mode onto the corresponding wheel mode.
fn smartshift_to_wheel(mode: SmartShiftMode) -> WheelMode {
    // Kept as an exhaustive match so a new `SmartShiftMode` variant fails to
    // compile here instead of silently falling into a default arm.
    match mode {
        SmartShiftMode::Ratchet => WheelMode::Ratchet,
        SmartShiftMode::Free => WheelMode::Freespin,
    }
}
/// Handle to whichever SmartShift feature the device exposes.
enum SmartShift {
    /// The device exposed the feature opened as `SmartShiftFeatureV0`.
    Enhanced(Arc<SmartShiftFeatureV0>),
    /// The device exposed the feature opened as `SmartShiftFeature`.
    Legacy(Arc<SmartShiftFeature>),
}
impl SmartShift {
    /// Wraps a failed HID++ call in [`WriteError::Hidpp`] using its `Debug` form.
    fn protocol_error<E: std::fmt::Debug>(e: E) -> WriteError {
        WriteError::Hidpp(format!("{e:?}"))
    }

    /// Opens the enhanced feature, falling back to the legacy one only when
    /// `is_missing_enhanced` accepts the first error.
    ///
    /// Any other error from the first attempt is returned unchanged, as is any
    /// error from opening the legacy feature.
    async fn open(device: &mut Device) -> Result<Self, WriteError> {
        let enhanced = open_feature::<SmartShiftFeatureV0>(device).await;
        let err = match enhanced {
            Ok(feature) => return Ok(Self::Enhanced(feature)),
            Err(err) => err,
        };
        if !is_missing_enhanced(&err) {
            return Err(err);
        }
        open_feature::<SmartShiftFeature>(device)
            .await
            .map(Self::Legacy)
    }

    /// Reads the current status from the device.
    ///
    /// The legacy variant has no torque reading here, so `tunable_torque` is
    /// always `0` for it.
    async fn status(&self) -> Result<SmartShiftStatus, WriteError> {
        let legacy = match self {
            Self::Enhanced(feature) => {
                return feature.get_status().await.map_err(Self::protocol_error);
            }
            Self::Legacy(feature) => feature,
        };
        let rcm = legacy
            .get_ratchet_control_mode()
            .await
            .map_err(Self::protocol_error)?;
        Ok(SmartShiftStatus {
            mode: wheel_mode_to_smartshift(rcm.wheel_mode),
            auto_disengage: rcm.auto_disengage,
            tunable_torque: 0,
        })
    }

    /// Writes `status` to the device.
    ///
    /// The legacy variant sends only the mode and `auto_disengage`; its third
    /// argument is passed as `None` and `tunable_torque` is not sent.
    async fn set_status(&self, status: SmartShiftStatus) -> Result<(), WriteError> {
        match self {
            Self::Enhanced(feature) => feature
                .set_status(status.mode, status.auto_disengage, status.tunable_torque)
                .await
                .map_err(Self::protocol_error),
            Self::Legacy(feature) => {
                let wheel = smartshift_to_wheel(status.mode);
                feature
                    .set_ratchet_control_mode(Some(wheel), Some(status.auto_disengage), None)
                    .await
                    .map_err(Self::protocol_error)
            }
        }
    }

    /// Reads the current status, replaces `auto_disengage` with `value`, and
    /// writes the result back; all other fields keep their read values.
    async fn set_sensitivity(&self, value: NonZeroU8) -> Result<(), WriteError> {
        let mut desired = self.status().await?;
        desired.auto_disengage = value.get();
        self.set_status(desired).await
    }
}
/// Reads the current DPI of sensor `0` on the device behind `route`.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, [`WriteError::Hidpp`] when the read fails, or whatever error
/// `with_route` / `open_feature` produce.
pub async fn get_dpi(route: &DeviceRoute) -> Result<u16, WriteError> {
    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = match Device::new(Arc::clone(&channel), index).await {
            Ok(device) => device,
            Err(_) => return Err(WriteError::DeviceUnreachable { index }),
        };
        let dpi_feature = open_feature::<AdjustableDpiFeature>(&mut device).await?;
        match dpi_feature.get_sensor_dpi(0).await {
            Ok(dpi) => Ok(dpi),
            Err(e) => Err(WriteError::Hidpp(format!("{e:?}"))),
        }
    })
    .await
}
fn classify_dpi_error(error: Hidpp20Error) -> WriteError {
match error {
Hidpp20Error::Feature(ErrorType::Unsupported | ErrorType::InvalidFunctionId)
| Hidpp20Error::UnsupportedResponse => WriteError::FeatureUnsupported {
feature_hex: AdjustableDpiFeature::ID,
},
other => WriteError::Hidpp(format!("{other:?}")),
}
}
/// Reads the current DPI and the supported DPI list of sensor `0`.
///
/// A reported sensor count of zero is treated as the adjustable-DPI feature
/// being unsupported. HID++ failures are mapped through `classify_dpi_error`.
///
/// # Errors
///
/// Besides the mapped HID++ failures, returns [`WriteError::DeviceUnreachable`]
/// when the device handle cannot be created and [`WriteError::EmptyDpiList`]
/// when the device reports no DPI values.
pub async fn get_dpi_info(route: &DeviceRoute) -> Result<DpiInfo, WriteError> {
    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = match Device::new(Arc::clone(&channel), index).await {
            Ok(device) => device,
            Err(_) => return Err(WriteError::DeviceUnreachable { index }),
        };
        let dpi_feature = open_feature::<AdjustableDpiFeature>(&mut device).await?;

        let sensors = dpi_feature
            .get_sensor_count()
            .await
            .map_err(classify_dpi_error)?;
        if sensors == 0 {
            return Err(WriteError::FeatureUnsupported {
                feature_hex: AdjustableDpiFeature::ID,
            });
        }

        let current = dpi_feature
            .get_sensor_dpi(0)
            .await
            .map_err(classify_dpi_error)?;
        let supported = dpi_feature
            .get_sensor_dpi_list(0)
            .await
            .map_err(classify_dpi_error)?;
        let capabilities = DpiCapabilities::new(supported)?;
        Ok(DpiInfo {
            current,
            capabilities,
        })
    })
    .await
}
/// Opens the device behind `route` and reads its SmartShift status.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, or any error from opening the SmartShift feature or reading it.
pub async fn get_smartshift_status(route: &DeviceRoute) -> Result<SmartShiftStatus, WriteError> {
    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = match Device::new(Arc::clone(&channel), index).await {
            Ok(device) => device,
            Err(_) => return Err(WriteError::DeviceUnreachable { index }),
        };
        let handle = SmartShift::open(&mut device).await?;
        handle.status().await
    })
    .await
}
/// Sets the SmartShift `auto_disengage` value to `value` and returns the
/// status read back from the device afterwards.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, or any error from opening, writing, or re-reading the feature.
pub async fn set_smartshift_sensitivity(
    route: &DeviceRoute,
    value: NonZeroU8,
) -> Result<SmartShiftStatus, WriteError> {
    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = match Device::new(Arc::clone(&channel), index).await {
            Ok(device) => device,
            Err(_) => return Err(WriteError::DeviceUnreachable { index }),
        };
        let handle = SmartShift::open(&mut device).await?;
        handle.set_sensitivity(value).await?;
        // Report what the device now says, not what was requested.
        let after = handle.status().await;
        after
    })
    .await
}
/// Opens a channel for `route` and writes `dpi` through `set_dpi_on_channel`.
///
/// # Errors
///
/// Propagates whatever `with_route` or `set_dpi_on_channel` return.
pub async fn set_dpi(route: &DeviceRoute, dpi: u16) -> Result<(), WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let write = set_dpi_on_channel(&channel, device_index, dpi);
        write.await
    })
    .await
}
/// HID++ feature id that `set_keyboard_color` looks up before writing colours.
const PER_KEY_LIGHTING_FEATURE: u16 = 0x8080;

/// Writes the colour (`r`, `g`, `b`) for key ids `0x00..=0xe8`, then sends a
/// final 20-byte report.
///
/// The feature index is resolved over a HID++ channel first; the reports
/// themselves are written through a separately opened route writer. Keys are
/// sent in groups of up to 14 per 64-byte report, each key taking four bytes
/// (`id, r, g, b`) starting at offset 8.
///
/// # Errors
///
/// * [`WriteError::DeviceNotFound`] when no channel or writer matches `route`.
/// * [`WriteError::DeviceUnreachable`] when the device handle cannot be created.
/// * [`WriteError::FeatureUnsupported`] when the lighting feature is absent.
/// * [`WriteError::Hidpp`] / [`WriteError::Hid`] for protocol and transport
///   failures.
pub async fn set_keyboard_color(
    route: &DeviceRoute,
    r: u8,
    g: u8,
    b: u8,
) -> Result<(), WriteError> {
    let device_index = route.device_index();
    let feature_index = with_route(route, move |channel| async move {
        let device = match Device::new(Arc::clone(&channel), device_index).await {
            Ok(device) => device,
            Err(_) => {
                return Err(WriteError::DeviceUnreachable {
                    index: device_index,
                });
            }
        };
        let lookup = device.root().get_feature(PER_KEY_LIGHTING_FEATURE).await;
        match lookup {
            Ok(Some(info)) => Ok(info.index),
            Ok(None) => Err(WriteError::FeatureUnsupported {
                feature_hex: PER_KEY_LIGHTING_FEATURE,
            }),
            Err(e) => Err(WriteError::Hidpp(format!("{e:?}"))),
        }
    })
    .await?;

    let mut writer = match crate::transport::open_route_writer(route).await? {
        Some(writer) => writer,
        None => return Err(WriteError::DeviceNotFound),
    };

    let key_ids: Vec<u8> = (0x00u8..=0xe8).collect();
    for group in key_ids.chunks(14) {
        let mut report = vec![0u8; 64];
        report[..4].copy_from_slice(&[0x12, device_index, feature_index, 0x3a]);
        report[5] = 0x01;
        // NOTE(review): 0x0e (14) is sent for every report, including the final
        // shorter group whose unused slots stay zeroed. Presumably this is the
        // per-report key count — confirm the padding is harmless on the device.
        report[7] = 0x0e;
        // 56 payload bytes split into 14 four-byte slots; zip stops at the
        // group's length.
        for (slot, &key) in report[8..].chunks_exact_mut(4).zip(group) {
            slot.copy_from_slice(&[key, r, g, b]);
        }
        writer.write_output_report(&report).await?;
    }

    let mut commit = vec![0u8; 20];
    commit[..4].copy_from_slice(&[0x11, device_index, feature_index, 0x5a]);
    writer.write_output_report(&commit).await?;

    debug!(
        device_index,
        feature_index, r, g, b, "wrote keyboard colour"
    );
    Ok(())
}
/// Writes `dpi` to sensor `0` of the device at `index`, then reads it back for
/// logging purposes.
///
/// The read-back is best effort: a mismatch only logs a warning and a failed
/// read-back only logs at debug level. Neither makes the call fail.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, [`WriteError::Hidpp`] when the write fails, or any error from
/// `open_feature`.
async fn set_dpi_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
    dpi: u16,
) -> Result<(), WriteError> {
    let mut device = match Device::new(Arc::clone(channel), index).await {
        Ok(device) => device,
        Err(_) => return Err(WriteError::DeviceUnreachable { index }),
    };
    let dpi_feature = open_feature::<AdjustableDpiFeature>(&mut device).await?;
    if let Err(e) = dpi_feature.set_sensor_dpi(0, dpi).await {
        return Err(WriteError::Hidpp(format!("{e:?}")));
    }

    match dpi_feature.get_sensor_dpi(0).await {
        Ok(actual) if actual == dpi => {
            debug!(index, dpi, "wrote DPI (verified)");
        }
        Ok(actual) => {
            tracing::warn!(
                index,
                requested = dpi,
                actual,
                "DPI write accepted but device reports a different value — \
                 likely out of the device's supported range"
            );
        }
        Err(_) => {
            debug!(index, dpi, "wrote DPI (read-back skipped)");
        }
    }
    Ok(())
}
/// Opens a channel for `route` and flips the SmartShift mode, returning the
/// mode that was written.
///
/// # Errors
///
/// Propagates whatever `with_route` or `toggle_smartshift_on_channel` return.
pub async fn toggle_smartshift(route: &DeviceRoute) -> Result<SmartShiftMode, WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let toggle = toggle_smartshift_on_channel(&channel, device_index);
        toggle.await
    })
    .await
}
/// Reads the SmartShift status of the device at `index`, writes it back with
/// the mode flipped, and returns the new mode.
///
/// All status fields other than `mode` are written back as they were read.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, or any error from opening, reading, or writing the feature.
async fn toggle_smartshift_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
) -> Result<SmartShiftMode, WriteError> {
    let mut device = match Device::new(Arc::clone(channel), index).await {
        Ok(device) => device,
        Err(_) => return Err(WriteError::DeviceUnreachable { index }),
    };
    let handle = SmartShift::open(&mut device).await?;

    let mut updated = handle.status().await?;
    let next = updated.mode.flipped();
    updated.mode = next;
    handle.set_status(updated).await?;

    debug!(index, ?next, "wrote SmartShift mode");
    Ok(next)
}
/// Opens a channel for `route` and writes the given SmartShift configuration.
///
/// # Errors
///
/// Propagates whatever `with_route` or `set_smartshift_on_channel` return.
pub async fn set_smartshift(
    route: &DeviceRoute,
    mode: SmartShiftMode,
    auto_disengage: u8,
    tunable_torque: u8,
) -> Result<(), WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let write = set_smartshift_on_channel(
            &channel,
            device_index,
            mode,
            auto_disengage,
            tunable_torque,
        );
        write.await
    })
    .await
}
/// Writes a complete SmartShift configuration to the device at `index`.
///
/// # Errors
///
/// Returns [`WriteError::DeviceUnreachable`] when the device handle cannot be
/// created, or any error from opening or writing the SmartShift feature.
async fn set_smartshift_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
    mode: SmartShiftMode,
    auto_disengage: u8,
    tunable_torque: u8,
) -> Result<(), WriteError> {
    let mut device = match Device::new(Arc::clone(channel), index).await {
        Ok(device) => device,
        Err(_) => return Err(WriteError::DeviceUnreachable { index }),
    };
    let requested = SmartShiftStatus {
        mode,
        auto_disengage,
        tunable_torque,
    };
    let handle = SmartShift::open(&mut device).await?;
    handle.set_status(requested).await?;
    debug!(
        index,
        ?mode,
        auto_disengage,
        tunable_torque,
        "wrote SmartShift config"
    );
    Ok(())
}
/// An already-open HID++ channel paired with the route it was opened for.
///
/// Cloning copies the `Arc` and the route; it does not open a new channel.
#[derive(Clone)]
pub struct SharedChannel {
    // Channel handed to the `*_on_channel` helpers.
    channel: Arc<HidppChannel>,
    // Route this channel belongs to; also supplies the device index.
    route: DeviceRoute,
}

impl SharedChannel {
    /// Pairs an open channel with the route it serves.
    #[must_use]
    pub(crate) fn new(channel: Arc<HidppChannel>, route: DeviceRoute) -> Self {
        SharedChannel { channel, route }
    }

    /// Whether this channel was opened for exactly `route`.
    #[must_use]
    pub fn matches(&self, route: &DeviceRoute) -> bool {
        &self.route == route
    }
}
/// Writes `dpi` over an already-open shared channel instead of opening one.
///
/// # Errors
///
/// Propagates whatever `set_dpi_on_channel` returns.
pub async fn set_dpi_on(shared: &SharedChannel, dpi: u16) -> Result<(), WriteError> {
    let device_index = shared.route.device_index();
    set_dpi_on_channel(&shared.channel, device_index, dpi).await
}
/// Flips the SmartShift mode over an already-open shared channel and returns
/// the mode that was written.
///
/// # Errors
///
/// Propagates whatever `toggle_smartshift_on_channel` returns.
pub async fn toggle_smartshift_on(shared: &SharedChannel) -> Result<SmartShiftMode, WriteError> {
    let device_index = shared.route.device_index();
    toggle_smartshift_on_channel(&shared.channel, device_index).await
}
pub async fn set_smartshift_on(
shared: &SharedChannel,
mode: SmartShiftMode,
auto_disengage: u8,
tunable_torque: u8,
) -> Result<(), WriteError> {
set_smartshift_on_channel(
&shared.channel,
shared.route.device_index(),
mode,
auto_disengage,
tunable_torque,
)
.await
}
/// Opens the HID++ channel for `route` and runs `f` with it.
///
/// # Errors
///
/// Returns [`WriteError::DeviceNotFound`] when no channel is available for the
/// route. A failure to open the channel is converted with `?`, and whatever
/// `f` returns is passed through unchanged.
async fn with_route<F, Fut, T>(route: &DeviceRoute, f: F) -> Result<T, WriteError>
where
    F: FnOnce(Arc<HidppChannel>) -> Fut,
    Fut: std::future::Future<Output = Result<T, WriteError>>,
{
    let Some(channel) = open_route_channel(route).await? else {
        return Err(WriteError::DeviceNotFound);
    };
    f(channel).await
}
// Unit tests for the pure helpers in this module. Nothing here opens a device.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` must sort ascending and drop duplicates.
    #[test]
    fn capabilities_sort_and_deduplicate_values() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![1600, 400, 800, 800])?;
        assert_eq!(caps.values(), [400, 800, 1600]);
        assert_eq!(caps.min(), 400);
        assert_eq!(caps.max(), 1600);
        Ok(())
    }

    // An empty list is rejected rather than producing an unusable value.
    #[test]
    fn capabilities_reject_empty_list() {
        assert!(matches!(
            DpiCapabilities::new(Vec::new()),
            Err(WriteError::EmptyDpiList)
        ));
    }

    // Covers a request below the list, between two values, and above the list.
    #[test]
    fn nearest_returns_closest_supported_value() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![400, 800, 1600])?;
        assert_eq!(caps.nearest(390), 400);
        assert_eq!(caps.nearest(1000), 800);
        assert_eq!(caps.nearest(2000), 1600);
        Ok(())
    }

    // Gaps here are 400, 400 and 800; the smallest one is the hint.
    #[test]
    fn step_hint_returns_smallest_positive_gap() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![400, 800, 1200, 2000])?;
        assert_eq!(caps.step_hint(), 400);
        Ok(())
    }

    // The next higher value is preferred; the top value falls back to the one below.
    #[test]
    fn adjacent_test_target_prefers_next_then_previous_value() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![400, 800, 1600])?;
        assert_eq!(caps.adjacent_test_target(400), Some(800));
        assert_eq!(caps.adjacent_test_target(800), Some(1600));
        assert_eq!(caps.adjacent_test_target(1600), Some(800));
        Ok(())
    }

    // A current value that is not in the list maps to the first larger value,
    // or to the largest value when nothing is larger.
    #[test]
    fn adjacent_test_target_handles_current_outside_list() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![400, 800, 1600])?;
        assert_eq!(caps.adjacent_test_target(1000), Some(1600));
        assert_eq!(caps.adjacent_test_target(2000), Some(1600));
        Ok(())
    }

    // The two mode enums must agree on their `u8` encoding.
    #[test]
    fn smartshift_and_wheel_mode_byte_encodings_match() {
        assert_eq!(
            u8::from(SmartShiftMode::Free),
            u8::from(WheelMode::Freespin)
        );
        assert_eq!(
            u8::from(SmartShiftMode::Ratchet),
            u8::from(WheelMode::Ratchet)
        );
    }

    #[test]
    fn wheel_mode_maps_to_smartshift_mode() {
        assert_eq!(
            wheel_mode_to_smartshift(WheelMode::Freespin),
            SmartShiftMode::Free
        );
        assert_eq!(
            wheel_mode_to_smartshift(WheelMode::Ratchet),
            SmartShiftMode::Ratchet
        );
    }

    // Converting to a wheel mode and back must give the original mode.
    #[test]
    fn smartshift_to_wheel_round_trips() {
        for mode in [SmartShiftMode::Free, SmartShiftMode::Ratchet] {
            assert_eq!(wheel_mode_to_smartshift(smartshift_to_wheel(mode)), mode);
        }
    }

    // Only a missing 0x2111 feature may trigger the legacy fallback.
    #[test]
    fn missing_enhanced_triggers_fallback() {
        assert!(is_missing_enhanced(&WriteError::FeatureUnsupported {
            feature_hex: 0x2111,
        }));
    }

    #[test]
    fn missing_legacy_does_not_trigger_fallback() {
        assert!(!is_missing_enhanced(&WriteError::FeatureUnsupported {
            feature_hex: 0x2110,
        }));
    }

    // Other error kinds must be surfaced, not hidden behind a fallback.
    #[test]
    fn transport_errors_do_not_trigger_fallback() {
        assert!(!is_missing_enhanced(&WriteError::DeviceUnreachable {
            index: 0xff,
        }));
        assert!(!is_missing_enhanced(&WriteError::Hidpp("boom".into())));
    }
}