logiops_core/features/
reprog_controls.rs

//! `Reprogrammable Controls` feature (0x1B04) - Button remapping.
//!
//! This feature allows querying and diverting button events. When a button
//! is "diverted", the device sends HID++ events instead of standard HID
//! reports, allowing the daemon to intercept and remap button presses.
6
7use bitflags::bitflags;
8use hidpp_transport::HidapiChannel;
9use tracing::{debug, trace};
10
11use crate::error::{HidppErrorCode, ProtocolError, Result};
12use crate::protocol::{build_long_request, get_error_code, is_error_response};
13
/// `Reprogrammable Controls` feature implementation.
///
/// Holds only the addressing information needed to build requests; the HID
/// channel is supplied to each method call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReprogControlsFeature {
    /// HID++ device index requests are addressed to (0xFF for direct).
    device_index: u8,
    /// Index of this feature in the device's feature table.
    feature_index: u8,
}
19
20/// Information about a control (button).
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct ControlInfo {
23    /// Control ID (e.g., 0x00c3 for thumb button).
24    pub cid: u16,
25    /// Default task ID.
26    pub task_id: u16,
27    /// Capability flags.
28    pub flags: ControlFlags,
29    /// Physical position on device.
30    pub position: u8,
31    /// Button group.
32    pub group: u8,
33    /// Group membership mask.
34    pub group_mask: u8,
35    /// Additional flags (for v4).
36    pub additional_flags: u8,
37}
38
39bitflags! {
40    /// Control capability flags.
41    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
42    pub struct ControlFlags: u16 {
43        /// Control is a mouse button.
44        const MOUSE_BUTTON   = 0x0001;
45        /// Control is an Fn key.
46        const FN_KEY         = 0x0002;
47        /// Control is a hotkey.
48        const HOTKEY         = 0x0004;
49        /// Control toggles Fn mode.
50        const FN_TOGGLE      = 0x0008;
51        /// Control can be reprogrammed.
52        const REPROGRAMMABLE = 0x0010;
53        /// Control can be diverted.
54        const DIVERTABLE     = 0x0020;
55        /// Divert setting persists across power cycles.
56        const PERSIST        = 0x0040;
57        /// Control is virtual (software-generated).
58        const VIRTUAL        = 0x0080;
59        /// Control supports raw XY reporting.
60        const RAW_XY         = 0x0100;
61    }
62}
63
/// Reporting configuration of one control, as read or written via
/// `getCidReporting` / `setCidReporting`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CidReporting {
    /// Control ID this configuration applies to.
    pub cid: u16,
    /// `true` when the control's events are diverted to HID++ notifications.
    pub divert: bool,
    /// `true` when the divert setting is persistent.
    pub persist: bool,
    /// `true` when raw XY reporting is enabled.
    pub raw_xy: bool,
}
76
/// State change of a diverted control, carrying the control's CID.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ButtonEvent {
    /// The control identified by the CID went down.
    Press(u16),
    /// The control identified by the CID went up.
    Release(u16),
}
85
86impl ReprogControlsFeature {
87    /// Creates a new reprogrammable controls feature accessor.
88    ///
89    /// # Arguments
90    /// * `device_index` - Device index (0xFF for direct)
91    /// * `feature_index` - Feature index from root feature discovery
92    #[must_use]
93    pub fn new(device_index: u8, feature_index: u8) -> Self {
94        Self {
95            device_index,
96            feature_index,
97        }
98    }
99
100    /// Returns the feature index.
101    #[must_use]
102    pub fn feature_index(&self) -> u8 {
103        self.feature_index
104    }
105
106    /// Gets the number of controls on the device.
107    ///
108    /// # Errors
109    /// Returns an error if HID++ communication fails.
110    pub async fn get_count(&self, channel: &HidapiChannel) -> Result<u8> {
111        // getCount: function_id=0x00
112        let request = build_long_request(self.device_index, self.feature_index, 0x00, &[]);
113
114        trace!("getting control count");
115        let response = channel.request(&request, 5).await?;
116
117        if is_error_response(&response) {
118            let code = get_error_code(&response).unwrap_or(0);
119            return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
120        }
121
122        if response.len() < 5 {
123            return Err(ProtocolError::InvalidResponse(
124                "control count response too short".to_string(),
125            ));
126        }
127
128        let count = response[4];
129        debug!(count, "got control count");
130        Ok(count)
131    }
132
133    /// Gets information about a control by index.
134    ///
135    /// # Arguments
136    /// * `channel` - HID channel
137    /// * `index` - Control index (0 to count-1)
138    ///
139    /// # Errors
140    /// Returns an error if HID++ communication fails.
141    pub async fn get_control_info(
142        &self,
143        channel: &HidapiChannel,
144        index: u8,
145    ) -> Result<ControlInfo> {
146        // getControlInfo: function_id=0x01
147        let request = build_long_request(self.device_index, self.feature_index, 0x01, &[index]);
148
149        trace!(index, "getting control info");
150        let response = channel.request(&request, 5).await?;
151
152        if is_error_response(&response) {
153            let code = get_error_code(&response).unwrap_or(0);
154            return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
155        }
156
157        if response.len() < 13 {
158            return Err(ProtocolError::InvalidResponse(
159                "control info response too short".to_string(),
160            ));
161        }
162
163        // Response format:
164        // [4:5] = CID (big-endian)
165        // [6:7] = Task ID (big-endian)
166        // [8:9] = Flags (big-endian)
167        // [10] = Position
168        // [11] = Group
169        // [12] = Group mask
170        // [13] = Additional flags (v4+)
171        let cid = u16::from_be_bytes([response[4], response[5]]);
172        let task_id = u16::from_be_bytes([response[6], response[7]]);
173        let flags_raw = u16::from_be_bytes([response[8], response[9]]);
174        let flags = ControlFlags::from_bits_truncate(flags_raw);
175        let position = response[10];
176        let group = response[11];
177        let group_mask = response[12];
178        let additional_flags = response.get(13).copied().unwrap_or(0);
179
180        let info = ControlInfo {
181            cid,
182            task_id,
183            flags,
184            position,
185            group,
186            group_mask,
187            additional_flags,
188        };
189
190        debug!(
191            index,
192            cid = format!("0x{:04X}", info.cid),
193            task_id = format!("0x{:04X}", info.task_id),
194            ?flags,
195            "got control info"
196        );
197
198        Ok(info)
199    }
200
201    /// Gets information about all controls on the device.
202    ///
203    /// # Errors
204    /// Returns an error if HID++ communication fails.
205    pub async fn get_all_controls(&self, channel: &HidapiChannel) -> Result<Vec<ControlInfo>> {
206        let count = self.get_count(channel).await?;
207        let mut controls = Vec::with_capacity(count as usize);
208
209        for i in 0..count {
210            let info = self.get_control_info(channel, i).await?;
211            controls.push(info);
212        }
213
214        debug!(count = controls.len(), "got all controls");
215        Ok(controls)
216    }
217
218    /// Gets the current reporting configuration for a control.
219    ///
220    /// # Arguments
221    /// * `channel` - HID channel
222    /// * `cid` - Control ID
223    ///
224    /// # Errors
225    /// Returns an error if HID++ communication fails.
226    pub async fn get_cid_reporting(
227        &self,
228        channel: &HidapiChannel,
229        cid: u16,
230    ) -> Result<CidReporting> {
231        // getCidReporting: function_id=0x02
232        let cid_bytes = cid.to_be_bytes();
233        let request = build_long_request(
234            self.device_index,
235            self.feature_index,
236            0x02,
237            &[cid_bytes[0], cid_bytes[1]],
238        );
239
240        trace!(cid = format!("0x{cid:04X}"), "getting CID reporting");
241        let response = channel.request(&request, 5).await?;
242
243        if is_error_response(&response) {
244            let code = get_error_code(&response).unwrap_or(0);
245            return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
246        }
247
248        if response.len() < 8 {
249            return Err(ProtocolError::InvalidResponse(
250                "CID reporting response too short".to_string(),
251            ));
252        }
253
254        // Response format:
255        // [4:5] = CID (big-endian)
256        // [6] = Reporting flags
257        // Bit 0: Divert
258        // Bit 1: Persist
259        // Bit 4: Raw XY
260        let response_cid = u16::from_be_bytes([response[4], response[5]]);
261        let flags = response[6];
262
263        let reporting = CidReporting {
264            cid: response_cid,
265            divert: flags & 0x01 != 0,
266            persist: flags & 0x02 != 0,
267            raw_xy: flags & 0x10 != 0,
268        };
269
270        debug!(
271            cid = format!("0x{:04X}", reporting.cid),
272            divert = reporting.divert,
273            persist = reporting.persist,
274            raw_xy = reporting.raw_xy,
275            "got CID reporting"
276        );
277
278        Ok(reporting)
279    }
280
281    /// Sets the reporting configuration for a control.
282    ///
283    /// # Arguments
284    /// * `channel` - HID channel
285    /// * `reporting` - Reporting configuration
286    ///
287    /// # Errors
288    /// Returns an error if HID++ communication fails.
289    pub async fn set_cid_reporting(
290        &self,
291        channel: &HidapiChannel,
292        reporting: &CidReporting,
293    ) -> Result<()> {
294        // setCidReporting: function_id=0x03
295        let cid_bytes = reporting.cid.to_be_bytes();
296
297        // Build reporting flags
298        let mut flags: u8 = 0;
299        if reporting.divert {
300            flags |= 0x01;
301        }
302        if reporting.persist {
303            flags |= 0x02;
304        }
305        if reporting.raw_xy {
306            flags |= 0x10;
307        }
308
309        let request = build_long_request(
310            self.device_index,
311            self.feature_index,
312            0x03,
313            &[cid_bytes[0], cid_bytes[1], flags],
314        );
315
316        trace!(
317            cid = format!("0x{:04X}", reporting.cid),
318            divert = reporting.divert,
319            persist = reporting.persist,
320            "setting CID reporting"
321        );
322
323        let response = channel.request(&request, 5).await?;
324
325        if is_error_response(&response) {
326            let code = get_error_code(&response).unwrap_or(0);
327            return Err(ProtocolError::HidppError(HidppErrorCode::from_byte(code)));
328        }
329
330        debug!(
331            cid = format!("0x{:04X}", reporting.cid),
332            divert = reporting.divert,
333            "set CID reporting"
334        );
335
336        Ok(())
337    }
338
339    /// Parses button events from a HID++ notification.
340    ///
341    /// When buttons are diverted, the device sends HID++ notifications
342    /// containing button state changes. This function parses such notifications.
343    ///
344    /// # Arguments
345    /// * `data` - Raw HID++ report data
346    /// * `feature_index` - Expected feature index
347    ///
348    /// # Returns
349    /// A vector of button events, or None if this is not a button event notification.
350    #[must_use]
351    pub fn parse_button_event(data: &[u8], feature_index: u8) -> Option<Vec<ButtonEvent>> {
352        // Check minimum length and report type
353        if data.len() < 7 {
354            return None;
355        }
356
357        // Check if this is for our feature
358        // Report format: [report_id, device_idx, feature_idx, func_sw_id, ...]
359        if data[2] != feature_index {
360            return None;
361        }
362
363        // Function ID is in high nibble of byte 3
364        let function_id = data[3] >> 4;
365
366        // Function 0 is the divert event notification
367        if function_id != 0 {
368            return None;
369        }
370
371        // Parse button states from the notification
372        // Format varies but typically: [CID hi, CID lo, state, ...]
373        // Multiple buttons can be reported in one notification
374        let mut events = Vec::new();
375        let mut i = 4;
376
377        while i + 1 < data.len() {
378            let cid = u16::from_be_bytes([data[i], data[i + 1]]);
379            if cid == 0 {
380                break;
381            }
382
383            // Check if there's a state byte
384            // In most notifications, presence in the report means pressed
385            // and absence means released
386            events.push(ButtonEvent::Press(cid));
387
388            i += 2;
389        }
390
391        if events.is_empty() {
392            // If all CIDs are 0, this is a release notification
393            // We need to track previous state to determine which buttons were released
394            // For now, return empty - the caller should track state
395            return None;
396        }
397
398        Some(events)
399    }
400
401    /// Parses button events from a HID++ notification, tracking state changes.
402    ///
403    /// This version compares against the previous state to generate both
404    /// press and release events.
405    ///
406    /// # Arguments
407    /// * `data` - Raw HID++ report data
408    /// * `feature_index` - Expected feature index
409    /// * `previous_pressed` - Set of CIDs that were pressed in the previous state
410    ///
411    /// # Returns
412    /// A tuple of (events, `new_pressed_set`), or None if not a button event.
413    #[must_use]
414    pub fn parse_button_event_with_state(
415        data: &[u8],
416        feature_index: u8,
417        previous_pressed: &std::collections::HashSet<u16>,
418    ) -> Option<(Vec<ButtonEvent>, std::collections::HashSet<u16>)> {
419        // Check minimum length and report type
420        if data.len() < 7 {
421            return None;
422        }
423
424        // Check if this is for our feature
425        if data[2] != feature_index {
426            return None;
427        }
428
429        // Function ID is in high nibble of byte 3
430        let function_id = data[3] >> 4;
431
432        // Function 0 is the divert event notification
433        if function_id != 0 {
434            return None;
435        }
436
437        // Parse currently pressed buttons
438        let mut current_pressed = std::collections::HashSet::new();
439        let mut i = 4;
440
441        while i + 1 < data.len() {
442            let cid = u16::from_be_bytes([data[i], data[i + 1]]);
443            if cid == 0 {
444                break;
445            }
446            current_pressed.insert(cid);
447            i += 2;
448        }
449
450        // Generate events by comparing states
451        let mut events = Vec::new();
452
453        // New presses: in current but not in previous
454        for &cid in &current_pressed {
455            if !previous_pressed.contains(&cid) {
456                events.push(ButtonEvent::Press(cid));
457            }
458        }
459
460        // Releases: in previous but not in current
461        for &cid in previous_pressed {
462            if !current_pressed.contains(&cid) {
463                events.push(ButtonEvent::Release(cid));
464            }
465        }
466
467        Some((events, current_pressed))
468    }
469}
470
/// Control IDs (CIDs) commonly found on Logitech mice.
///
/// NOTE(review): CID assignments vary between models; confirm these against
/// the device's `get_all_controls` output before relying on them.
pub mod cid {
    /// Primary (left) mouse button.
    pub const LEFT_CLICK: u16 = 0x0050;
    /// Secondary (right) mouse button.
    pub const RIGHT_CLICK: u16 = 0x0051;
    /// Wheel click (middle button).
    pub const MIDDLE_CLICK: u16 = 0x0052;
    /// Thumb "back" button.
    pub const BACK: u16 = 0x0053;
    /// Thumb "forward" button.
    pub const FORWARD: u16 = 0x0056;
    /// Thumb / gesture button found on MX mice.
    pub const THUMB_BUTTON: u16 = 0x00C3;
    /// Top button next to the DPI button.
    pub const TOP_BUTTON: u16 = 0x00C4;
    /// Wheel tilted left.
    /// NOTE(review): value unverified; compare with device-reported CIDs.
    pub const SCROLL_LEFT: u16 = 0x00D7;
    /// Wheel tilted right.
    /// NOTE(review): value unverified; compare with device-reported CIDs.
    pub const SCROLL_RIGHT: u16 = 0x00D0;
}
492
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a 20-byte long-report divert notification (function 0, software
    /// ID 0) for `feature_index`, listing `cids` in the first parameter slots.
    fn divert_notification(feature_index: u8, cids: &[u16]) -> [u8; 20] {
        let mut report = [0u8; 20];
        report[..4].copy_from_slice(&[0x11, 0xFF, feature_index, 0x00]);
        for (slot, cid) in cids.iter().enumerate() {
            let start = 4 + 2 * slot;
            report[start..start + 2].copy_from_slice(&cid.to_be_bytes());
        }
        report
    }

    #[test]
    fn test_control_flags() {
        let flags = ControlFlags::DIVERTABLE | ControlFlags::REPROGRAMMABLE;
        assert!(flags.contains(ControlFlags::DIVERTABLE));
        assert!(flags.contains(ControlFlags::REPROGRAMMABLE));
        assert!(!flags.contains(ControlFlags::MOUSE_BUTTON));
    }

    #[test]
    fn test_cid_reporting() {
        let reporting = CidReporting {
            cid: 0x00C3,
            divert: true,
            persist: false,
            raw_xy: false,
        };
        assert!(reporting.divert);
        assert!(!reporting.persist);
    }

    #[test]
    fn test_button_event() {
        let press = ButtonEvent::Press(0x00C3);
        let release = ButtonEvent::Release(0x00C3);

        match press {
            ButtonEvent::Press(cid) => assert_eq!(cid, 0x00C3),
            ButtonEvent::Release(_) => panic!("expected press"),
        }

        match release {
            ButtonEvent::Release(cid) => assert_eq!(cid, 0x00C3),
            ButtonEvent::Press(_) => panic!("expected release"),
        }
    }

    #[test]
    fn test_parse_button_event_with_state() {
        // Simulate a button press notification
        // Format: [report_id, device_idx, feature_idx, func_sw_id, cid_hi, cid_lo, 0, 0, ...]
        let feature_index = 0x05;
        let notification = [
            0x11, // Long report
            0xFF, // Device index
            feature_index,
            0x00, // Function 0, SW ID 0
            0x00,
            0xC3, // CID 0x00C3 (thumb button)
            0x00,
            0x00,
        ];

        let previous: HashSet<u16> = HashSet::new();
        let result = ReprogControlsFeature::parse_button_event_with_state(
            &notification,
            feature_index,
            &previous,
        );

        let (events, new_pressed) = result.expect("divert notification should parse");
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ButtonEvent::Press(0x00C3)));
        assert!(new_pressed.contains(&0x00C3));
    }

    #[test]
    fn test_parse_button_event_reports_presses_in_order() {
        let data = divert_notification(0x05, &[0x00C3, 0x0053]);
        assert_eq!(
            ReprogControlsFeature::parse_button_event(&data, 0x05),
            Some(vec![ButtonEvent::Press(0x00C3), ButtonEvent::Press(0x0053)])
        );
    }

    #[test]
    fn test_parse_button_event_all_released_is_none() {
        let data = divert_notification(0x05, &[]);
        assert_eq!(ReprogControlsFeature::parse_button_event(&data, 0x05), None);
    }

    #[test]
    fn test_parse_button_event_with_state_release() {
        let previous = HashSet::from([0x00C3]);
        let data = divert_notification(0x05, &[]);
        let (events, now) =
            ReprogControlsFeature::parse_button_event_with_state(&data, 0x05, &previous)
                .expect("divert notification should parse");
        assert_eq!(events, vec![ButtonEvent::Release(0x00C3)]);
        assert!(now.is_empty());
    }

    #[test]
    fn test_parse_button_event_with_state_held_button() {
        let previous = HashSet::from([0x00C3]);
        let data = divert_notification(0x05, &[0x00C3, 0x0053]);
        let (events, now) =
            ReprogControlsFeature::parse_button_event_with_state(&data, 0x05, &previous)
                .expect("divert notification should parse");
        // The already-held button produces no new event.
        assert_eq!(events, vec![ButtonEvent::Press(0x0053)]);
        assert_eq!(now, HashSet::from([0x00C3, 0x0053]));
    }

    #[test]
    fn test_parse_rejects_foreign_or_short_reports() {
        let data = divert_notification(0x05, &[0x00C3]);
        let nothing_pressed: HashSet<u16> = HashSet::new();

        // Different feature index.
        assert_eq!(ReprogControlsFeature::parse_button_event(&data, 0x06), None);
        assert!(
            ReprogControlsFeature::parse_button_event_with_state(&data, 0x06, &nothing_pressed)
                .is_none()
        );

        // Non-zero function ID.
        let mut other_function = data;
        other_function[3] = 0x10;
        assert_eq!(ReprogControlsFeature::parse_button_event(&other_function, 0x05), None);

        // Too short to hold a CID.
        assert_eq!(ReprogControlsFeature::parse_button_event(&data[..6], 0x05), None);
    }
}