1use bitflags::bitflags;
8use hidpp_transport::HidapiChannel;
9use tracing::{debug, trace};
10
11use crate::error::{HidppErrorCode, ProtocolError, Result};
12use crate::protocol::{build_long_request, get_error_code, is_error_response};
13
/// Handle for the HID++ reprogrammable-controls feature of a single device.
///
/// The handle only stores the two indices needed to address requests; every request is
/// sent through the [`HidapiChannel`] passed to each method, so the handle itself is a
/// cheap `Copy` value.
#[derive(Debug, Clone, Copy)]
pub struct ReprogControlsFeature {
    /// HID++ device index that requests are addressed to.
    device_index: u8,
    /// Index of this feature in the device's feature table.
    feature_index: u8,
}
19
20#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct ControlInfo {
23 pub cid: u16,
25 pub task_id: u16,
27 pub flags: ControlFlags,
29 pub position: u8,
31 pub group: u8,
33 pub group_mask: u8,
35 pub additional_flags: u8,
37}
38
39bitflags! {
40 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
42 pub struct ControlFlags: u16 {
43 const MOUSE_BUTTON = 0x0001;
45 const FN_KEY = 0x0002;
47 const HOTKEY = 0x0004;
49 const FN_TOGGLE = 0x0008;
51 const REPROGRAMMABLE = 0x0010;
53 const DIVERTABLE = 0x0020;
55 const PERSIST = 0x0040;
57 const VIRTUAL = 0x0080;
59 const RAW_XY = 0x0100;
61 }
62}
63
/// Reporting (diversion) state of a single control, as read or written by the
/// CID-reporting requests.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CidReporting {
    /// Control identifier the state applies to.
    pub cid: u16,
    /// Whether the control's events are diverted to software.
    pub divert: bool,
    /// Whether the diversion is persistent.
    pub persist: bool,
    /// Whether raw XY reporting is diverted.
    pub raw_xy: bool,
}
76
/// A press or release of a diverted control, carrying the control's CID.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ButtonEvent {
    /// The control with this CID went down.
    Press(u16),
    /// The control with this CID went up.
    Release(u16),
}
85
86impl ReprogControlsFeature {
87 #[must_use]
93 pub fn new(device_index: u8, feature_index: u8) -> Self {
94 Self {
95 device_index,
96 feature_index,
97 }
98 }
99
100 #[must_use]
102 pub fn feature_index(&self) -> u8 {
103 self.feature_index
104 }
105
106 pub async fn get_count(&self, channel: &HidapiChannel) -> Result<u8> {
111 let request = build_long_request(self.device_index, self.feature_index, 0x00, &[]);
113
114 trace!("getting control count");
115 let response = channel.request(&request, 5).await?;
116
117 if is_error_response(&response) {
118 let code = get_error_code(&response).unwrap_or(0);
119 return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
120 }
121
122 if response.len() < 5 {
123 return Err(ProtocolError::InvalidResponse(
124 "control count response too short".to_string(),
125 ));
126 }
127
128 let count = response[4];
129 debug!(count, "got control count");
130 Ok(count)
131 }
132
133 pub async fn get_control_info(
142 &self,
143 channel: &HidapiChannel,
144 index: u8,
145 ) -> Result<ControlInfo> {
146 let request = build_long_request(self.device_index, self.feature_index, 0x01, &[index]);
148
149 trace!(index, "getting control info");
150 let response = channel.request(&request, 5).await?;
151
152 if is_error_response(&response) {
153 let code = get_error_code(&response).unwrap_or(0);
154 return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
155 }
156
157 if response.len() < 13 {
158 return Err(ProtocolError::InvalidResponse(
159 "control info response too short".to_string(),
160 ));
161 }
162
163 let cid = u16::from_be_bytes([response[4], response[5]]);
172 let task_id = u16::from_be_bytes([response[6], response[7]]);
173 let flags_raw = u16::from_be_bytes([response[8], response[9]]);
174 let flags = ControlFlags::from_bits_truncate(flags_raw);
175 let position = response[10];
176 let group = response[11];
177 let group_mask = response[12];
178 let additional_flags = response.get(13).copied().unwrap_or(0);
179
180 let info = ControlInfo {
181 cid,
182 task_id,
183 flags,
184 position,
185 group,
186 group_mask,
187 additional_flags,
188 };
189
190 debug!(
191 index,
192 cid = format!("0x{:04X}", info.cid),
193 task_id = format!("0x{:04X}", info.task_id),
194 ?flags,
195 "got control info"
196 );
197
198 Ok(info)
199 }
200
201 pub async fn get_all_controls(&self, channel: &HidapiChannel) -> Result<Vec<ControlInfo>> {
206 let count = self.get_count(channel).await?;
207 let mut controls = Vec::with_capacity(count as usize);
208
209 for i in 0..count {
210 let info = self.get_control_info(channel, i).await?;
211 controls.push(info);
212 }
213
214 debug!(count = controls.len(), "got all controls");
215 Ok(controls)
216 }
217
218 pub async fn get_cid_reporting(
227 &self,
228 channel: &HidapiChannel,
229 cid: u16,
230 ) -> Result<CidReporting> {
231 let cid_bytes = cid.to_be_bytes();
233 let request = build_long_request(
234 self.device_index,
235 self.feature_index,
236 0x02,
237 &[cid_bytes[0], cid_bytes[1]],
238 );
239
240 trace!(cid = format!("0x{cid:04X}"), "getting CID reporting");
241 let response = channel.request(&request, 5).await?;
242
243 if is_error_response(&response) {
244 let code = get_error_code(&response).unwrap_or(0);
245 return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
246 }
247
248 if response.len() < 8 {
249 return Err(ProtocolError::InvalidResponse(
250 "CID reporting response too short".to_string(),
251 ));
252 }
253
254 let response_cid = u16::from_be_bytes([response[4], response[5]]);
261 let flags = response[6];
262
263 let reporting = CidReporting {
264 cid: response_cid,
265 divert: flags & 0x01 != 0,
266 persist: flags & 0x02 != 0,
267 raw_xy: flags & 0x10 != 0,
268 };
269
270 debug!(
271 cid = format!("0x{:04X}", reporting.cid),
272 divert = reporting.divert,
273 persist = reporting.persist,
274 raw_xy = reporting.raw_xy,
275 "got CID reporting"
276 );
277
278 Ok(reporting)
279 }
280
281 pub async fn set_cid_reporting(
290 &self,
291 channel: &HidapiChannel,
292 reporting: &CidReporting,
293 ) -> Result<()> {
294 let cid_bytes = reporting.cid.to_be_bytes();
296
297 let mut flags: u8 = 0;
299 if reporting.divert {
300 flags |= 0x01;
301 }
302 if reporting.persist {
303 flags |= 0x02;
304 }
305 if reporting.raw_xy {
306 flags |= 0x10;
307 }
308
309 let request = build_long_request(
310 self.device_index,
311 self.feature_index,
312 0x03,
313 &[cid_bytes[0], cid_bytes[1], flags],
314 );
315
316 trace!(
317 cid = format!("0x{:04X}", reporting.cid),
318 divert = reporting.divert,
319 persist = reporting.persist,
320 "setting CID reporting"
321 );
322
323 let response = channel.request(&request, 5).await?;
324
325 if is_error_response(&response) {
326 let code = get_error_code(&response).unwrap_or(0);
327 return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
328 }
329
330 debug!(
331 cid = format!("0x{:04X}", reporting.cid),
332 divert = reporting.divert,
333 "set CID reporting"
334 );
335
336 Ok(())
337 }
338
339 #[must_use]
351 pub fn parse_button_event(data: &[u8], feature_index: u8) -> Option<Vec<ButtonEvent>> {
352 if data.len() < 7 {
354 return None;
355 }
356
357 if data[2] != feature_index {
360 return None;
361 }
362
363 let function_id = data[3] >> 4;
365
366 if function_id != 0 {
368 return None;
369 }
370
371 let mut events = Vec::new();
375 let mut i = 4;
376
377 while i + 1 < data.len() {
378 let cid = u16::from_be_bytes([data[i], data[i + 1]]);
379 if cid == 0 {
380 break;
381 }
382
383 events.push(ButtonEvent::Press(cid));
387
388 i += 2;
389 }
390
391 if events.is_empty() {
392 return None;
396 }
397
398 Some(events)
399 }
400
401 #[must_use]
414 pub fn parse_button_event_with_state(
415 data: &[u8],
416 feature_index: u8,
417 previous_pressed: &std::collections::HashSet<u16>,
418 ) -> Option<(Vec<ButtonEvent>, std::collections::HashSet<u16>)> {
419 if data.len() < 7 {
421 return None;
422 }
423
424 if data[2] != feature_index {
426 return None;
427 }
428
429 let function_id = data[3] >> 4;
431
432 if function_id != 0 {
434 return None;
435 }
436
437 let mut current_pressed = std::collections::HashSet::new();
439 let mut i = 4;
440
441 while i + 1 < data.len() {
442 let cid = u16::from_be_bytes([data[i], data[i + 1]]);
443 if cid == 0 {
444 break;
445 }
446 current_pressed.insert(cid);
447 i += 2;
448 }
449
450 let mut events = Vec::new();
452
453 for &cid in ¤t_pressed {
455 if !previous_pressed.contains(&cid) {
456 events.push(ButtonEvent::Press(cid));
457 }
458 }
459
460 for &cid in previous_pressed {
462 if !current_pressed.contains(&cid) {
463 events.push(ButtonEvent::Release(cid));
464 }
465 }
466
467 Some((events, current_pressed))
468 }
469}
470
/// Well-known control identifiers (CIDs).
///
/// NOTE(review): CID assignments can differ between devices; confirm these against the
/// control-info output for a target device.
pub mod cid {
    // Standard mouse buttons.

    /// Left mouse button.
    pub const LEFT_CLICK: u16 = 0x0050;
    /// Right mouse button.
    pub const RIGHT_CLICK: u16 = 0x0051;
    /// Middle mouse button.
    pub const MIDDLE_CLICK: u16 = 0x0052;
    /// "Back" side button.
    pub const BACK: u16 = 0x0053;
    /// "Forward" side button.
    pub const FORWARD: u16 = 0x0056;

    // Device-specific extra controls.

    /// Thumb button.
    pub const THUMB_BUTTON: u16 = 0x00C3;
    /// Top button.
    pub const TOP_BUTTON: u16 = 0x00C4;
    /// Scroll-left control.
    ///
    /// NOTE(review): unlike the rest of this list, this value (0x00D7) is higher than
    /// `SCROLL_RIGHT` (0x00D0); confirm both.
    pub const SCROLL_LEFT: u16 = 0x00D7;
    /// Scroll-right control.
    pub const SCROLL_RIGHT: u16 = 0x00D0;
}
492
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a 20-byte diverted-buttons notification for `feature_index` listing `cids`
    /// (function 0, software id 0, zero-padded).
    fn diverted_report(feature_index: u8, cids: &[u16]) -> Vec<u8> {
        let mut data = vec![0x11, 0xFF, feature_index, 0x00];
        for cid in cids {
            data.extend_from_slice(&cid.to_be_bytes());
        }
        data.resize(20, 0);
        data
    }

    #[test]
    fn test_control_flags() {
        let flags = ControlFlags::DIVERTABLE | ControlFlags::REPROGRAMMABLE;
        assert!(flags.contains(ControlFlags::DIVERTABLE));
        assert!(flags.contains(ControlFlags::REPROGRAMMABLE));
        assert!(!flags.contains(ControlFlags::MOUSE_BUTTON));
    }

    #[test]
    fn test_cid_reporting() {
        let reporting = CidReporting {
            cid: 0x00C3,
            divert: true,
            persist: false,
            raw_xy: false,
        };
        assert!(reporting.divert);
        assert!(!reporting.persist);
    }

    #[test]
    fn test_button_event() {
        let press = ButtonEvent::Press(0x00C3);
        let release = ButtonEvent::Release(0x00C3);

        match press {
            ButtonEvent::Press(cid) => assert_eq!(cid, 0x00C3),
            ButtonEvent::Release(_) => panic!("expected press"),
        }

        match release {
            ButtonEvent::Release(cid) => assert_eq!(cid, 0x00C3),
            ButtonEvent::Press(_) => panic!("expected release"),
        }
    }

    #[test]
    fn test_parse_button_event_with_state() {
        let feature_index = 0x05;
        // report id, device index, feature index, function/sw id, then CID 0x00C3.
        let notification = [
            0x11, 0xFF, feature_index,
            0x00, 0x00,
            0xC3, 0x00,
            0x00,
        ];

        let previous: HashSet<u16> = HashSet::new();
        let result = ReprogControlsFeature::parse_button_event_with_state(
            &notification,
            feature_index,
            &previous,
        );

        assert!(result.is_some());
        let (events, new_pressed) = result.unwrap();
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ButtonEvent::Press(0x00C3)));
        assert!(new_pressed.contains(&0x00C3));
    }

    #[test]
    fn test_release_reported_when_cid_disappears() {
        let previous: HashSet<u16> = [0x00C3].into_iter().collect();
        let data = diverted_report(0x05, &[]);

        let (events, pressed) =
            ReprogControlsFeature::parse_button_event_with_state(&data, 0x05, &previous)
                .expect("valid notification");

        assert_eq!(events, vec![ButtonEvent::Release(0x00C3)]);
        assert!(pressed.is_empty());
    }

    #[test]
    fn test_held_button_produces_no_events() {
        let previous: HashSet<u16> = [0x00C3].into_iter().collect();
        let data = diverted_report(0x05, &[0x00C3]);

        let (events, pressed) =
            ReprogControlsFeature::parse_button_event_with_state(&data, 0x05, &previous)
                .expect("valid notification");

        assert!(events.is_empty());
        assert_eq!(pressed, previous);
    }

    #[test]
    fn test_stateless_parse_keeps_report_order() {
        let data = diverted_report(0x05, &[0x00C3, 0x00C4]);
        assert_eq!(
            ReprogControlsFeature::parse_button_event(&data, 0x05),
            Some(vec![ButtonEvent::Press(0x00C3), ButtonEvent::Press(0x00C4)])
        );
    }

    #[test]
    fn test_stateless_parse_returns_none_when_nothing_pressed() {
        let data = diverted_report(0x05, &[]);
        assert_eq!(ReprogControlsFeature::parse_button_event(&data, 0x05), None);
    }

    #[test]
    fn test_parsers_reject_foreign_short_and_non_event_packets() {
        let empty = HashSet::new();

        let other_feature = diverted_report(0x06, &[0x00C3]);
        assert_eq!(ReprogControlsFeature::parse_button_event(&other_feature, 0x05), None);
        assert!(ReprogControlsFeature::parse_button_event_with_state(
            &other_feature,
            0x05,
            &empty
        )
        .is_none());

        let short = [0x11, 0xFF, 0x05, 0x00, 0x00, 0xC3];
        assert_eq!(ReprogControlsFeature::parse_button_event(&short, 0x05), None);

        let mut other_function = diverted_report(0x05, &[0x00C3]);
        other_function[3] = 0x10;
        assert_eq!(ReprogControlsFeature::parse_button_event(&other_function, 0x05), None);
        assert!(ReprogControlsFeature::parse_button_event_with_state(
            &other_function,
            0x05,
            &empty
        )
        .is_none());
    }
}