use std::sync::Arc;
use async_hid::AsyncHidWrite;
use hidpp::{
channel::HidppChannel,
device::Device,
feature::CreatableFeature,
feature::adjustable_dpi::AdjustableDpiFeature,
protocol::v20::{ErrorType, Hidpp20Error},
};
use thiserror::Error;
use tracing::debug;
use crate::route::{DeviceRoute, open_route_channel};
use crate::smartshift::{SmartShiftFeatureV0, SmartShiftMode, SmartShiftStatus};
/// Errors returned by the HID++ read and write helpers in this module.
#[derive(Debug, Error)]
pub enum WriteError {
/// The HID layer reported an error; converted from `async_hid::HidError` via `#[from]`.
#[error("HID transport error")]
Hid(#[from] async_hid::HidError),
/// Opening the route yielded no channel (or no writer) for the device.
#[error("no connected device matched the route")]
DeviceNotFound,
/// Constructing the HID++ `Device` for `index` failed.
#[error("device at index {index:#04x} did not respond to HID++")]
DeviceUnreachable { index: u8 },
/// The root feature lookup found no entry for `feature_hex`, or the device
/// rejected a call on that feature as unsupported.
#[error("device does not expose HID++ feature {feature_hex:#06x}")]
FeatureUnsupported { feature_hex: u16 },
/// `DpiCapabilities::new` was given an empty list of DPI values.
#[error("device returned no supported DPI values")]
EmptyDpiList,
/// Any other HID++ failure, carried as the `Debug` rendering of the original error.
#[error("HID++ protocol error: {0}")]
Hidpp(String),
}
/// Set of DPI values a sensor supports.
///
/// Instances built through `DpiCapabilities::new` hold a sorted, de-duplicated,
/// non-empty list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DpiCapabilities {
// Sorted ascending with duplicates removed by `new`; never empty when built via `new`.
values: Vec<u16>,
}
impl DpiCapabilities {
    /// Builds a capability set from a raw list of DPI values.
    ///
    /// The values are sorted ascending and duplicates are removed.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::EmptyDpiList`] when `values` is empty.
    pub fn new(mut values: Vec<u16>) -> Result<Self, WriteError> {
        values.sort_unstable();
        values.dedup();
        if values.is_empty() {
            return Err(WriteError::EmptyDpiList);
        }
        Ok(Self { values })
    }

    /// Supported DPI values in ascending order, without duplicates.
    #[must_use]
    pub fn values(&self) -> &[u16] {
        &self.values
    }

    /// Smallest supported DPI value.
    #[must_use]
    pub fn min(&self) -> u16 {
        // `new` rejects empty lists, so index 0 always exists.
        self.values[0]
    }

    /// Largest supported DPI value.
    #[must_use]
    pub fn max(&self) -> u16 {
        // Non-empty by construction, so `len() - 1` cannot underflow.
        self.values[self.values.len() - 1]
    }

    /// Returns `true` when `dpi` is exactly one of the supported values.
    #[must_use]
    pub fn contains(&self, dpi: u16) -> bool {
        self.values.binary_search(&dpi).is_ok()
    }

    /// Returns the supported value closest to `dpi`.
    ///
    /// When two supported values are equally close, the lower one wins because
    /// a candidate only replaces the current best on a strictly smaller delta.
    #[must_use]
    pub fn nearest(&self, dpi: u32) -> u16 {
        let mut nearest = self.values[0];
        let mut best_delta = u32::from(nearest).abs_diff(dpi);
        for &candidate in &self.values[1..] {
            let delta = u32::from(candidate).abs_diff(dpi);
            if delta < best_delta {
                nearest = candidate;
                best_delta = delta;
            }
        }
        nearest
    }

    /// Same as [`Self::nearest`], widened to `u32`.
    #[must_use]
    pub fn snap(&self, dpi: u32) -> u32 {
        u32::from(self.nearest(dpi))
    }

    /// Smallest positive gap between neighbouring supported values.
    ///
    /// Falls back to `1` when there is only a single supported value.
    #[must_use]
    pub fn step_hint(&self) -> u16 {
        self.values
            .windows(2)
            .filter_map(|pair| pair[1].checked_sub(pair[0]))
            .filter(|step| *step > 0)
            .min()
            .unwrap_or(1)
    }

    /// Picks a supported value different from `current`.
    ///
    /// * `current` is supported: the next higher value, or the next lower one
    ///   when `current` is already the maximum.
    /// * `current` is not supported: the first supported value above it, or the
    ///   maximum when `current` exceeds every supported value.
    ///
    /// Returns `None` when fewer than two values are supported.
    #[must_use]
    pub fn adjacent_test_target(&self, current: u16) -> Option<u16> {
        if self.values.len() < 2 {
            return None;
        }
        // `values` is sorted, so binary search yields either the exact position
        // (`Ok`) or the insertion point (`Err`).
        match self.values.binary_search(&current) {
            Ok(index) if index + 1 < self.values.len() => Some(self.values[index + 1]),
            Ok(index) if index > 0 => Some(self.values[index - 1]),
            Ok(_) => None,
            Err(index) if index < self.values.len() => Some(self.values[index]),
            Err(_) => self.values.last().copied(),
        }
        .filter(|target| *target != current)
    }
}
/// Current DPI reading together with the DPI values the sensor supports.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DpiInfo {
/// DPI value the device currently reports.
pub current: u16,
/// Supported DPI values for the same sensor.
pub capabilities: DpiCapabilities,
}
/// One entry of the feature listing produced by `dump_features`.
#[derive(Debug, Clone, Copy)]
pub struct FeatureEntry {
/// Feature ID as reported by the device's feature set.
pub id: u16,
/// Feature version as reported by the device's feature set.
pub version: u8,
}
/// Lists the features the device at `route` reports through its HID++ feature set.
///
/// Looks up the feature-set feature via the root feature, asks it for the
/// feature count and then queries every index from `0` through `count`
/// inclusive.
///
/// # Errors
///
/// * [`WriteError::DeviceNotFound`] when the route cannot be opened.
/// * [`WriteError::DeviceUnreachable`] when the HID++ device cannot be created.
/// * [`WriteError::FeatureUnsupported`] when the device has no feature-set feature.
/// * [`WriteError::Hidpp`] for any other HID++ failure.
pub async fn dump_features(route: &DeviceRoute) -> Result<Vec<FeatureEntry>, WriteError> {
    use hidpp::feature::feature_set::FeatureSetFeature;

    let index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = Device::new(Arc::clone(&channel), index)
            .await
            .map_err(|_| WriteError::DeviceUnreachable { index })?;
        let feature_set_info = device
            .root()
            .get_feature(FeatureSetFeature::ID)
            .await
            .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?
            .ok_or(WriteError::FeatureUnsupported {
                feature_hex: FeatureSetFeature::ID,
            })?;
        let feature_set = device.add_feature::<FeatureSetFeature>(feature_set_info.index);
        let count = feature_set
            .count()
            .await
            .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?;

        // The loop below is inclusive, so it yields `count + 1` entries; size
        // the vector accordingly to avoid a reallocation on the last push.
        // NOTE(review): the inclusive bound presumably relies on `count()`
        // excluding index 0 (the root feature) — confirm against the hidpp crate.
        let mut entries = Vec::with_capacity(usize::from(count) + 1);
        for i in 0..=count {
            let info = feature_set
                .get_feature(i)
                .await
                .map_err(|e| WriteError::Hidpp(format!("{e:?}")))?;
            entries.push(FeatureEntry {
                id: info.id,
                version: info.version,
            });
        }
        Ok(entries)
    })
    .await
}
/// Resolves feature `F` on `device` through the root feature and registers it.
///
/// # Errors
///
/// * [`WriteError::Hidpp`] when the root lookup itself fails.
/// * [`WriteError::FeatureUnsupported`] when the device does not list `F::ID`.
async fn open_feature<F>(device: &mut Device) -> Result<Arc<F>, WriteError>
where
    F: CreatableFeature + 'static,
{
    let root_lookup = device.root().get_feature(F::ID).await;
    let maybe_info = root_lookup.map_err(|err| WriteError::Hidpp(format!("{err:?}")))?;
    let Some(info) = maybe_info else {
        return Err(WriteError::FeatureUnsupported { feature_hex: F::ID });
    };
    Ok(device.add_feature::<F>(info.index))
}
/// Reads the current DPI of sensor 0 from the device at `route`.
///
/// # Errors
///
/// Fails with [`WriteError::DeviceNotFound`], [`WriteError::DeviceUnreachable`],
/// [`WriteError::FeatureUnsupported`] or [`WriteError::Hidpp`] depending on
/// which step of the lookup or read goes wrong.
pub async fn get_dpi(route: &DeviceRoute) -> Result<u16, WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let connected = Device::new(Arc::clone(&channel), device_index).await;
        let mut device = connected.map_err(|_| WriteError::DeviceUnreachable {
            index: device_index,
        })?;
        let sensor = open_feature::<AdjustableDpiFeature>(&mut device).await?;
        match sensor.get_sensor_dpi(0).await {
            Ok(dpi) => Ok(dpi),
            Err(err) => Err(WriteError::Hidpp(format!("{err:?}"))),
        }
    })
    .await
}
fn classify_dpi_error(error: Hidpp20Error) -> WriteError {
match error {
Hidpp20Error::Feature(ErrorType::Unsupported | ErrorType::InvalidFunctionId)
| Hidpp20Error::UnsupportedResponse => WriteError::FeatureUnsupported {
feature_hex: AdjustableDpiFeature::ID,
},
other => WriteError::Hidpp(format!("{other:?}")),
}
}
/// Reads the current DPI and the list of supported DPI values for sensor 0.
///
/// A device reporting zero sensors is treated as not supporting the
/// adjustable-DPI feature.
///
/// # Errors
///
/// Feature-call failures go through `classify_dpi_error`; an empty DPI list
/// yields [`WriteError::EmptyDpiList`].
pub async fn get_dpi_info(route: &DeviceRoute) -> Result<DpiInfo, WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let mut device = Device::new(Arc::clone(&channel), device_index)
            .await
            .map_err(|_| WriteError::DeviceUnreachable {
                index: device_index,
            })?;
        let dpi_feature = open_feature::<AdjustableDpiFeature>(&mut device).await?;

        let sensors = dpi_feature.get_sensor_count().await;
        if sensors.map_err(classify_dpi_error)? == 0 {
            return Err(WriteError::FeatureUnsupported {
                feature_hex: AdjustableDpiFeature::ID,
            });
        }

        let current = dpi_feature
            .get_sensor_dpi(0)
            .await
            .map_err(classify_dpi_error)?;
        let supported = dpi_feature
            .get_sensor_dpi_list(0)
            .await
            .map_err(classify_dpi_error)?;
        let capabilities = DpiCapabilities::new(supported)?;

        Ok(DpiInfo {
            current,
            capabilities,
        })
    })
    .await
}
/// Reads the SmartShift status from the device at `route`.
///
/// # Errors
///
/// Fails with [`WriteError::DeviceNotFound`], [`WriteError::DeviceUnreachable`],
/// [`WriteError::FeatureUnsupported`] or [`WriteError::Hidpp`] depending on
/// which step goes wrong.
pub async fn get_smartshift_status(route: &DeviceRoute) -> Result<SmartShiftStatus, WriteError> {
    let device_index = route.device_index();
    with_route(route, move |channel| async move {
        let connected = Device::new(Arc::clone(&channel), device_index).await;
        let mut device = connected.map_err(|_| WriteError::DeviceUnreachable {
            index: device_index,
        })?;
        let smartshift = open_feature::<SmartShiftFeatureV0>(&mut device).await?;
        match smartshift.get_status().await {
            Ok(status) => Ok(status),
            Err(err) => Err(WriteError::Hidpp(format!("{err:?}"))),
        }
    })
    .await
}
pub async fn set_dpi(route: &DeviceRoute, dpi: u16) -> Result<(), WriteError> {
let index = route.device_index();
with_route(route, move |channel| async move {
set_dpi_on_channel(&channel, index, dpi).await
})
.await
}
/// HID++ feature ID that `set_keyboard_color` looks up on the device
/// (per-key lighting, going by the constant's name).
const PER_KEY_LIGHTING_FEATURE: u16 = 0x8080;
/// Paints every key ID from `0x00` through `0xe8` with the colour `(r, g, b)`.
///
/// Works in two phases: the per-key lighting feature index is resolved over a
/// HID++ channel, then raw output reports are sent through a route writer —
/// 64-byte reports carrying up to 14 `(key, r, g, b)` entries each, followed
/// by one 20-byte commit report.
///
/// # Errors
///
/// * [`WriteError::DeviceNotFound`] when the channel or the writer cannot be opened.
/// * [`WriteError::DeviceUnreachable`] when the HID++ device cannot be created.
/// * [`WriteError::FeatureUnsupported`] when the lighting feature is missing.
/// * [`WriteError::Hidpp`] / [`WriteError::Hid`] for protocol or transport failures.
pub async fn set_keyboard_color(
    route: &DeviceRoute,
    r: u8,
    g: u8,
    b: u8,
) -> Result<(), WriteError> {
    let device_index = route.device_index();

    // Phase 1: ask the root feature where the lighting feature lives.
    let feature_index = with_route(route, move |channel| async move {
        let device = Device::new(Arc::clone(&channel), device_index)
            .await
            .map_err(|_| WriteError::DeviceUnreachable {
                index: device_index,
            })?;
        let lookup = device.root().get_feature(PER_KEY_LIGHTING_FEATURE).await;
        let found = lookup.map_err(|err| WriteError::Hidpp(format!("{err:?}")))?;
        match found {
            Some(info) => Ok(info.index),
            None => Err(WriteError::FeatureUnsupported {
                feature_hex: PER_KEY_LIGHTING_FEATURE,
            }),
        }
    })
    .await?;

    // Phase 2: stream the raw reports.
    let mut writer = match crate::transport::open_route_writer(route).await? {
        Some(writer) => writer,
        None => return Err(WriteError::DeviceNotFound),
    };

    let all_keys: Vec<u8> = (0x00u8..=0xe8).collect();
    for group in all_keys.chunks(14) {
        let mut report = vec![0u8; 64];
        report[..4].copy_from_slice(&[0x12, device_index, feature_index, 0x3a]);
        report[5] = 0x01;
        report[7] = 0x0e;
        // Entries start at byte 8 and occupy four bytes each; slots beyond the
        // keys in this group stay zeroed.
        for (slot, &key) in report[8..].chunks_mut(4).zip(group) {
            slot.copy_from_slice(&[key, r, g, b]);
        }
        writer
            .write_output_report(&report)
            .await
            .map_err(WriteError::Hid)?;
    }

    let mut commit = vec![0u8; 20];
    commit[..4].copy_from_slice(&[0x11, device_index, feature_index, 0x5a]);
    writer
        .write_output_report(&commit)
        .await
        .map_err(WriteError::Hid)?;

    debug!(
        device_index,
        feature_index, r, g, b, "wrote keyboard colour"
    );
    Ok(())
}
/// Writes `dpi` to sensor 0 of the device at `index` on an already open channel.
///
/// After the write the value is read back on a best-effort basis: a mismatch
/// is logged as a warning and a failed read-back is only logged at debug
/// level — neither turns the call into an error.
///
/// # Errors
///
/// Fails when the device cannot be created, the adjustable-DPI feature is
/// missing, or the write itself is rejected.
async fn set_dpi_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
    dpi: u16,
) -> Result<(), WriteError> {
    let connected = Device::new(Arc::clone(channel), index).await;
    let mut device = connected.map_err(|_| WriteError::DeviceUnreachable { index })?;
    let feature = open_feature::<AdjustableDpiFeature>(&mut device).await?;

    let written = feature.set_sensor_dpi(0, dpi).await;
    written.map_err(|err| WriteError::Hidpp(format!("{err:?}")))?;

    match feature.get_sensor_dpi(0).await {
        Ok(actual) if actual == dpi => {
            debug!(index, dpi, "wrote DPI (verified)");
        }
        Ok(actual) => {
            tracing::warn!(
                index,
                requested = dpi,
                actual,
                "DPI write accepted but device reports a different value — \
                 likely out of the device's supported range"
            );
        }
        Err(_) => {
            debug!(index, dpi, "wrote DPI (read-back skipped)");
        }
    }
    Ok(())
}
pub async fn toggle_smartshift(route: &DeviceRoute) -> Result<SmartShiftMode, WriteError> {
let index = route.device_index();
with_route(route, move |channel| async move {
toggle_smartshift_on_channel(&channel, index).await
})
.await
}
/// Flips the SmartShift mode of the device at `index` on an open channel.
///
/// Reads the current status, writes it back with the mode flipped and the
/// other two settings unchanged, and returns the new mode.
///
/// # Errors
///
/// Fails when the device cannot be created, the SmartShift feature is missing,
/// or either the read or the write is rejected.
async fn toggle_smartshift_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
) -> Result<SmartShiftMode, WriteError> {
    let connected = Device::new(Arc::clone(channel), index).await;
    let mut device = connected.map_err(|_| WriteError::DeviceUnreachable { index })?;
    let feature = open_feature::<SmartShiftFeatureV0>(&mut device).await?;

    let status = match feature.get_status().await {
        Ok(status) => status,
        Err(err) => return Err(WriteError::Hidpp(format!("{err:?}"))),
    };
    let next = status.mode.flipped();

    if let Err(err) = feature
        .set_status(next, status.auto_disengage, status.tunable_torque)
        .await
    {
        return Err(WriteError::Hidpp(format!("{err:?}")));
    }

    debug!(index, ?next, "wrote SmartShift mode");
    Ok(next)
}
pub async fn set_smartshift(
route: &DeviceRoute,
mode: SmartShiftMode,
auto_disengage: u8,
tunable_torque: u8,
) -> Result<(), WriteError> {
let index = route.device_index();
with_route(route, move |channel| async move {
set_smartshift_on_channel(&channel, index, mode, auto_disengage, tunable_torque).await
})
.await
}
/// Writes a SmartShift configuration to the device at `index` on an open channel.
///
/// # Errors
///
/// Fails when the device cannot be created, the SmartShift feature is missing,
/// or the write is rejected.
async fn set_smartshift_on_channel(
    channel: &Arc<HidppChannel>,
    index: u8,
    mode: SmartShiftMode,
    auto_disengage: u8,
    tunable_torque: u8,
) -> Result<(), WriteError> {
    let connected = Device::new(Arc::clone(channel), index).await;
    let mut device = connected.map_err(|_| WriteError::DeviceUnreachable { index })?;
    let feature = open_feature::<SmartShiftFeatureV0>(&mut device).await?;

    let written = feature
        .set_status(mode, auto_disengage, tunable_torque)
        .await;
    if let Err(err) = written {
        return Err(WriteError::Hidpp(format!("{err:?}")));
    }

    debug!(
        index,
        ?mode,
        auto_disengage,
        tunable_torque,
        "wrote SmartShift config"
    );
    Ok(())
}
/// An already open HID++ channel paired with the route it was opened for.
///
/// Lets the `*_on` functions reuse a channel instead of opening one per call.
#[derive(Clone)]
pub struct SharedChannel {
// Reference-counted channel handle handed to the `*_on_channel` helpers.
channel: Arc<HidppChannel>,
// Route this channel belongs to; also supplies the device index.
route: DeviceRoute,
}
impl SharedChannel {
#[must_use]
pub(crate) fn new(channel: Arc<HidppChannel>, route: DeviceRoute) -> Self {
Self { channel, route }
}
#[must_use]
pub fn matches(&self, route: &DeviceRoute) -> bool {
self.route == *route
}
}
pub async fn set_dpi_on(shared: &SharedChannel, dpi: u16) -> Result<(), WriteError> {
set_dpi_on_channel(&shared.channel, shared.route.device_index(), dpi).await
}
pub async fn toggle_smartshift_on(shared: &SharedChannel) -> Result<SmartShiftMode, WriteError> {
toggle_smartshift_on_channel(&shared.channel, shared.route.device_index()).await
}
pub async fn set_smartshift_on(
shared: &SharedChannel,
mode: SmartShiftMode,
auto_disengage: u8,
tunable_torque: u8,
) -> Result<(), WriteError> {
set_smartshift_on_channel(
&shared.channel,
shared.route.device_index(),
mode,
auto_disengage,
tunable_torque,
)
.await
}
/// Opens a HID++ channel for `route` and runs `f` on it.
///
/// # Errors
///
/// Propagates any error from opening the channel, returns
/// [`WriteError::DeviceNotFound`] when no channel is available, and otherwise
/// yields whatever `f` resolves to.
async fn with_route<F, Fut, T>(route: &DeviceRoute, f: F) -> Result<T, WriteError>
where
    F: FnOnce(Arc<HidppChannel>) -> Fut,
    Fut: std::future::Future<Output = Result<T, WriteError>>,
{
    let Some(channel) = open_route_channel(route).await? else {
        return Err(WriteError::DeviceNotFound);
    };
    f(channel).await
}
#[cfg(test)]
mod tests {
    use super::{DpiCapabilities, WriteError};

    /// Three-value fixture shared by the lookup tests.
    fn sample() -> Result<DpiCapabilities, WriteError> {
        DpiCapabilities::new(vec![400, 800, 1600])
    }

    #[test]
    fn capabilities_sort_and_deduplicate_values() -> Result<(), WriteError> {
        let unordered = vec![1600, 400, 800, 800];
        let caps = DpiCapabilities::new(unordered)?;
        assert_eq!(caps.values(), [400, 800, 1600]);
        assert_eq!(caps.min(), 400);
        assert_eq!(caps.max(), 1600);
        Ok(())
    }

    #[test]
    fn capabilities_reject_empty_list() {
        let outcome = DpiCapabilities::new(Vec::new());
        assert!(matches!(outcome, Err(WriteError::EmptyDpiList)));
    }

    #[test]
    fn nearest_returns_closest_supported_value() -> Result<(), WriteError> {
        let caps = sample()?;
        for (requested, expected) in [(390, 400), (1000, 800), (2000, 1600)] {
            assert_eq!(caps.nearest(requested), expected);
        }
        Ok(())
    }

    #[test]
    fn step_hint_returns_smallest_positive_gap() -> Result<(), WriteError> {
        let caps = DpiCapabilities::new(vec![400, 800, 1200, 2000])?;
        assert_eq!(caps.step_hint(), 400);
        Ok(())
    }

    #[test]
    fn adjacent_test_target_prefers_next_then_previous_value() -> Result<(), WriteError> {
        let caps = sample()?;
        for (current, expected) in [(400, 800), (800, 1600), (1600, 800)] {
            assert_eq!(caps.adjacent_test_target(current), Some(expected));
        }
        Ok(())
    }

    #[test]
    fn adjacent_test_target_handles_current_outside_list() -> Result<(), WriteError> {
        let caps = sample()?;
        for (current, expected) in [(1000, 1600), (2000, 1600)] {
            assert_eq!(caps.adjacent_test_target(current), Some(expected));
        }
        Ok(())
    }
}