1use crate::{AnalogCalibrationConfig, Request, Response};
7use bincode;
8use serde::{de::DeserializeOwned, Serialize};
9
10use std::io;
11use std::path::Path;
12use std::time::Duration;
13use thiserror::Error;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::UnixStream;
16
17use tokio::time::timeout;
18
19#[derive(Error, Debug)]
21pub enum IpcError {
22 #[error("failed to connect to daemon: {0}")]
23 Connect(std::io::Error),
24 #[error("failed to send request: {0}")]
25 Send(std::io::Error),
26 #[error("failed to receive response: {0}")]
27 Receive(std::io::Error),
28 #[error("serialization error: {0}")]
29 Serialize(bincode::Error),
30 #[error("deserialization error: {0}")]
31 Deserialize(bincode::Error),
32 #[error("request timed out")]
33 Timeout,
34
35 #[error("IO error: {0}")]
36 Io(#[from] io::Error),
37
38 #[error("Serialization error: {0}")]
39 Serialization(String),
40
41 #[error("Connection timeout")]
42 ConnectionTimeout,
43
44 #[error("Operation timeout after {0}ms")]
45 OperationTimeout(u64),
46
47 #[error("Daemon not running at {0}")]
48 DaemonNotRunning(String),
49
50 #[error("Invalid response from daemon")]
51 InvalidResponse,
52
53 #[error("Message too large: {0} bytes exceeds maximum of {1} bytes")]
54 MessageTooLarge(usize, usize),
55
56 #[error("Connection closed unexpectedly")]
57 ConnectionClosed,
58
59 #[error("Other error: {0}")]
60 Other(String),
61}
62
/// Socket path used when the caller does not supply one.
pub const DEFAULT_SOCKET_PATH: &str = "/run/aethermap/aethermap.sock";

/// Default timeout, in milliseconds, applied to connect, write and read phases.
pub const DEFAULT_TIMEOUT_MS: u64 = 5_000;

/// Upper bound, in bytes, on a single serialized request or response (1 MiB).
pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

/// Default number of retries after a failed attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 3;

/// Default pause, in milliseconds, between retries.
pub const DEFAULT_RETRY_DELAY_MS: u64 = 1_000;
77
/// Configuration for talking to the daemon over a Unix domain socket.
///
/// The client holds only settings (no open connection), so it derives `Clone`
/// to let a configured instance be duplicated, e.g. one copy per task.
#[derive(Debug, Clone)]
pub struct IpcClient {
    /// Path of the daemon's Unix socket.
    socket_path: String,
    /// Time limit applied to each connect, write and read phase.
    timeout: Duration,
    /// Number of retries performed after a failed attempt.
    max_retries: u32,
    /// Pause between consecutive attempts.
    retry_delay: Duration,
}
87
88impl Default for IpcClient {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl IpcClient {
95 pub fn new() -> Self {
97 Self::with_socket_path(DEFAULT_SOCKET_PATH)
98 }
99
100 pub fn with_socket_path<P: AsRef<Path>>(socket_path: P) -> Self {
102 Self {
103 socket_path: socket_path.as_ref().to_string_lossy().to_string(),
104 timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
105 max_retries: DEFAULT_MAX_RETRIES,
106 retry_delay: Duration::from_millis(DEFAULT_RETRY_DELAY_MS),
107 }
108 }
109
110 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
112 self.timeout = Duration::from_millis(timeout_ms);
113 self
114 }
115
116 pub fn with_retry_params(mut self, max_retries: u32, retry_delay_ms: u64) -> Self {
118 self.max_retries = max_retries;
119 self.retry_delay = Duration::from_millis(retry_delay_ms);
120 self
121 }
122
123 pub async fn is_daemon_running(&self) -> bool {
125 UnixStream::connect(&self.socket_path).await.is_ok()
126 }
127
128 pub async fn connect(&self) -> Result<UnixStream, IpcError> {
130 let mut attempts = 0;
131
132 loop {
133 match timeout(self.timeout, UnixStream::connect(&self.socket_path)).await {
134 Ok(Ok(stream)) => return Ok(stream),
135 Ok(Err(e)) => {
136 if attempts >= self.max_retries {
137 return Err(IpcError::DaemonNotRunning(self.socket_path.clone()));
138 }
139 tracing::warn!(
140 "Connection attempt {} failed: {}, retrying...",
141 attempts + 1,
142 e
143 );
144 tokio::time::sleep(self.retry_delay).await;
145 attempts += 1;
146 }
147 Err(_) => return Err(IpcError::ConnectionTimeout),
148 }
149 }
150 }
151
152 pub async fn send(&self, request: &Request) -> Result<Response, IpcError> {
154 self.send_with_retries(request, self.max_retries).await
155 }
156
157 pub async fn send_with_retries(
159 &self,
160 request: &Request,
161 max_retries: u32,
162 ) -> Result<Response, IpcError> {
163 let mut attempts = 0;
164 let mut last_error = None;
165
166 while attempts <= max_retries {
167 match self.connect().await {
168 Ok(mut stream) => match self.send_with_stream(&mut stream, request).await {
169 Ok(response) => return Ok(response),
170 Err(e) => {
171 last_error = Some(e);
172 if attempts < max_retries {
173 tracing::warn!("Request attempt {} failed, retrying...", attempts + 1);
174 tokio::time::sleep(self.retry_delay).await;
175 }
176 }
177 },
178 Err(e) => {
179 last_error = Some(e);
180 if attempts < max_retries {
181 tracing::warn!("Connection attempt {} failed, retrying...", attempts + 1);
182 tokio::time::sleep(self.retry_delay).await;
183 }
184 }
185 }
186 attempts += 1;
187 }
188
189 Err(last_error.unwrap_or(IpcError::Other("Unknown error".to_string())))
190 }
191
192 pub async fn set_macro_settings(&self, settings: crate::MacroSettings) -> Result<(), IpcError> {
194 let request = Request::SetMacroSettings(settings);
195 match self.send(&request).await? {
196 Response::Ack => Ok(()),
197 Response::Error(msg) => Err(IpcError::Other(msg)),
198 _ => Err(IpcError::InvalidResponse),
199 }
200 }
201
202 pub async fn get_macro_settings(&self) -> Result<crate::MacroSettings, IpcError> {
204 let request = Request::GetMacroSettings;
205 match self.send(&request).await? {
206 Response::MacroSettings(settings) => Ok(settings),
207 Response::Error(msg) => Err(IpcError::Other(msg)),
208 _ => Err(IpcError::InvalidResponse),
209 }
210 }
211
212 async fn send_with_stream(
214 &self,
215 stream: &mut UnixStream,
216 request: &Request,
217 ) -> Result<Response, IpcError> {
218 let serialized =
220 bincode::serialize(request).map_err(|e| IpcError::Serialization(e.to_string()))?;
221
222 if serialized.len() > MAX_MESSAGE_SIZE {
224 return Err(IpcError::MessageTooLarge(
225 serialized.len(),
226 MAX_MESSAGE_SIZE,
227 ));
228 }
229
230 if timeout(self.timeout, async {
232 let len = serialized.len() as u32;
234 stream.write_all(&len.to_le_bytes()).await?;
235
236 stream.write_all(&serialized).await?;
238 stream.flush().await?;
239
240 Ok::<(), io::Error>(())
241 })
242 .await
243 .is_err()
244 {
245 return Err(IpcError::OperationTimeout(self.timeout.as_millis() as u64));
246 }
247
248 let response = timeout(self.timeout, async {
250 let mut len_bytes = [0u8; 4];
252 stream.read_exact(&mut len_bytes).await?;
253 let response_len = u32::from_le_bytes(len_bytes) as usize;
254
255 if response_len > MAX_MESSAGE_SIZE {
257 return Err(IpcError::MessageTooLarge(response_len, MAX_MESSAGE_SIZE));
258 }
259
260 let mut buffer = vec![0u8; response_len];
262 stream.read_exact(&mut buffer).await?;
263
264 bincode::deserialize(&buffer).map_err(|e| IpcError::Serialization(e.to_string()))
266 })
267 .await;
268
269 match response {
270 Ok(Ok(resp)) => Ok(resp),
271 Ok(Err(e)) => Err(e),
272 Err(_) => Err(IpcError::OperationTimeout(self.timeout.as_millis() as u64)),
273 }
274 }
275}
276
277pub async fn send(request: &Request) -> Result<Response, IpcError> {
300 let client = IpcClient::new();
301 client.send(request).await
302}
303
304pub async fn send_request(req: &Request) -> Result<Response, IpcError> {
318 let mut stream = timeout(
320 Duration::from_secs(2),
321 UnixStream::connect(DEFAULT_SOCKET_PATH),
322 )
323 .await
324 .map_err(|_| IpcError::Timeout)?
325 .map_err(IpcError::Connect)?;
326
327 let serialized = bincode::serialize(req).map_err(IpcError::Serialize)?;
329
330 if serialized.len() > MAX_MESSAGE_SIZE {
332 return Err(IpcError::MessageTooLarge(
333 serialized.len(),
334 MAX_MESSAGE_SIZE,
335 ));
336 }
337
338 let len_prefix = (serialized.len() as u32).to_le_bytes();
340 timeout(Duration::from_secs(2), stream.write_all(&len_prefix))
341 .await
342 .map_err(|_| IpcError::Timeout)?
343 .map_err(IpcError::Send)?;
344
345 timeout(Duration::from_secs(2), stream.write_all(&serialized))
346 .await
347 .map_err(|_| IpcError::Timeout)?
348 .map_err(IpcError::Send)?;
349
350 let mut response_len_bytes = [0u8; 4];
352 timeout(
353 Duration::from_secs(2),
354 stream.read_exact(&mut response_len_bytes),
355 )
356 .await
357 .map_err(|_| IpcError::Timeout)?
358 .map_err(IpcError::Receive)?;
359
360 let response_len = u32::from_le_bytes(response_len_bytes) as usize;
361
362 if response_len > MAX_MESSAGE_SIZE {
364 return Err(IpcError::MessageTooLarge(response_len, MAX_MESSAGE_SIZE));
365 }
366
367 let mut response_buffer = vec![0u8; response_len];
369 timeout(
370 Duration::from_secs(2),
371 stream.read_exact(&mut response_buffer),
372 )
373 .await
374 .map_err(|_| IpcError::Timeout)?
375 .map_err(IpcError::Receive)?;
376
377 bincode::deserialize(&response_buffer).map_err(IpcError::Deserialize)
379}
380
381pub async fn send_to_path<P: AsRef<Path>>(
392 request: &Request,
393 socket_path: P,
394) -> Result<Response, IpcError> {
395 let client = IpcClient::with_socket_path(socket_path);
396 client.send(request).await
397}
398
399pub async fn send_with_timeout(request: &Request, timeout_ms: u64) -> Result<Response, IpcError> {
410 let client = IpcClient::new().with_timeout(timeout_ms);
411 client.send(request).await
412}
413
414pub async fn is_daemon_running<P: AsRef<Path>>(socket_path: Option<P>) -> bool {
424 let path = socket_path
425 .map(|p| p.as_ref().to_string_lossy().to_string())
426 .unwrap_or_else(|| DEFAULT_SOCKET_PATH.to_string());
427
428 UnixStream::connect(path).await.is_ok()
429}
430
431pub async fn get_analog_calibration(
456 device_id: &str,
457 layer_id: usize,
458) -> Result<AnalogCalibrationConfig, IpcError> {
459 let request = Request::GetAnalogCalibration {
460 device_id: device_id.to_string(),
461 layer_id,
462 };
463
464 match send(&request).await? {
465 Response::AnalogCalibration {
466 calibration: Some(cal),
467 ..
468 } => Ok(cal),
469 Response::AnalogCalibration {
470 calibration: None, ..
471 } => {
472 Ok(AnalogCalibrationConfig::default())
474 }
475 Response::Error(msg) => Err(IpcError::Other(msg)),
476 _ => Err(IpcError::InvalidResponse),
477 }
478}
479
480pub async fn set_analog_calibration(
519 device_id: &str,
520 layer_id: usize,
521 calibration: AnalogCalibrationConfig,
522) -> Result<(), IpcError> {
523 let request = Request::SetAnalogCalibration {
524 device_id: device_id.to_string(),
525 layer_id,
526 calibration,
527 };
528
529 match send(&request).await? {
530 Response::AnalogCalibrationAck => Ok(()),
531 Response::Error(msg) => Err(IpcError::Other(msg)),
532 _ => Err(IpcError::InvalidResponse),
533 }
534}
535
536pub async fn get_auto_switch_rules() -> Result<Vec<crate::AutoSwitchRule>, IpcError> {
558 let request = Request::GetAutoSwitchRules;
559
560 match send(&request).await? {
561 Response::AutoSwitchRules { rules } => Ok(rules),
562 Response::Error(msg) => Err(IpcError::Other(msg)),
563 _ => Err(IpcError::InvalidResponse),
564 }
565}
566
567pub async fn get_macro_settings() -> Result<crate::MacroSettings, IpcError> {
569 IpcClient::new().get_macro_settings().await
570}
571
572pub async fn set_macro_settings(settings: crate::MacroSettings) -> Result<(), IpcError> {
574 IpcClient::new().set_macro_settings(settings).await
575}
576
577pub fn serialize<T: Serialize>(msg: &T) -> Result<Vec<u8>, IpcError> {
579 bincode::serialize(msg).map_err(|e| IpcError::Serialization(e.to_string()))
580}
581
582pub fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, IpcError> {
584 bincode::deserialize(bytes).map_err(|e| IpcError::Serialization(e.to_string()))
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Action, DeviceInfo, DeviceType, KeyCombo, MacroEntry, Request, Response};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UnixListener;

    /// Minimal stand-in for the daemon: listens on `socket_path` and answers
    /// `GetDevices`, `ListMacros` and `GetStatus` with fixed data, using the
    /// same length-prefixed bincode framing as the client. Any other request
    /// gets a `Response::Error`.
    async fn mock_daemon(
        socket_path: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // A leftover socket file would make `bind` fail.
        if Path::new(socket_path).exists() {
            std::fs::remove_file(socket_path)?;
        }

        let listener = UnixListener::bind(socket_path)?;

        loop {
            match listener.accept().await {
                Ok((mut stream, _)) => {
                    // One task per connection; each handles a single request.
                    tokio::spawn(async move {
                        let mut len_buf = [0u8; 4];
                        if stream.read_exact(&mut len_buf).await.is_err() {
                            return;
                        }

                        let msg_len = u32::from_le_bytes(len_buf) as usize;
                        if msg_len > MAX_MESSAGE_SIZE {
                            return;
                        }

                        let mut msg_buf = vec![0u8; msg_len];
                        if stream.read_exact(&mut msg_buf).await.is_err() {
                            return;
                        }

                        let request: Request = match bincode::deserialize(&msg_buf) {
                            Ok(req) => req,
                            Err(_) => return,
                        };

                        let response = match request {
                            Request::GetDevices => {
                                let devices = vec![DeviceInfo {
                                    name: "Test Device".to_string(),
                                    path: PathBuf::from("/dev/input/event0"),
                                    vendor_id: 0x1234,
                                    product_id: 0x5678,
                                    phys: "usb-0000:00:14.0-1/input0".to_string(),
                                    device_type: DeviceType::Other,
                                }];
                                Response::Devices(devices)
                            }
                            Request::ListMacros => {
                                let macros = vec![MacroEntry {
                                    name: "Test Macro".to_string(),
                                    trigger: KeyCombo {
                                        keys: vec![30],
                                        modifiers: vec![],
                                    },
                                    actions: vec![
                                        Action::KeyPress(31),
                                        Action::Delay(100),
                                        Action::KeyRelease(31),
                                    ],
                                    device_id: None,
                                    enabled: true,
                                    humanize: false,
                                    capture_mouse: false,
                                }];
                                Response::Macros(macros)
                            }
                            Request::GetStatus => Response::Status {
                                version: "0.1.0".to_string(),
                                uptime_seconds: 60,
                                devices_count: 1,
                                macros_count: 1,
                            },
                            _ => Response::Error("Unsupported request in test".to_string()),
                        };

                        let response_bytes = bincode::serialize(&response).unwrap();
                        let len = response_bytes.len() as u32;

                        if stream.write_all(&len.to_le_bytes()).await.is_err() {
                            return;
                        }

                        if stream.write_all(&response_bytes).await.is_err() {
                            return;
                        }

                        let _ = stream.flush().await;
                    });
                }
                Err(e) => {
                    tracing::error!("Failed to accept connection: {}", e);
                    break;
                }
            }
        }

        Ok(())
    }

    /// Constructors and builder methods store the expected configuration.
    #[tokio::test]
    async fn test_ipc_client_creation() {
        let client = IpcClient::new();
        assert_eq!(client.socket_path, DEFAULT_SOCKET_PATH);
        assert_eq!(client.timeout, Duration::from_millis(DEFAULT_TIMEOUT_MS));
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);
        assert_eq!(
            client.retry_delay,
            Duration::from_millis(DEFAULT_RETRY_DELAY_MS)
        );

        let custom_path = "/tmp/test.sock";
        let custom_client = IpcClient::with_socket_path(custom_path)
            .with_timeout(10000)
            .with_retry_params(5, 2000);

        assert_eq!(custom_client.socket_path, custom_path);
        assert_eq!(custom_client.timeout, Duration::from_millis(10000));
        assert_eq!(custom_client.max_retries, 5);
        assert_eq!(custom_client.retry_delay, Duration::from_millis(2000));
    }

    /// The `serialize`/`deserialize` helpers round-trip a request and a macro entry.
    #[tokio::test]
    async fn test_serialization_deserialization() {
        let request = Request::GetDevices;
        let serialized = serialize(&request).unwrap();
        let deserialized: Request = deserialize(&serialized).unwrap();
        assert!(matches!(deserialized, Request::GetDevices));

        let macro_entry = MacroEntry {
            name: "Test Macro".to_string(),
            trigger: KeyCombo {
                keys: vec![30, 40],
                modifiers: vec![29],
            },
            actions: vec![
                Action::KeyPress(30),
                Action::Delay(100),
                Action::KeyRelease(30),
            ],
            device_id: Some("test_device".to_string()),
            enabled: true,
            humanize: false,
            capture_mouse: false,
        };

        let serialized = serialize(&macro_entry).unwrap();
        let deserialized: MacroEntry = deserialize(&serialized).unwrap();
        assert_eq!(deserialized.name, "Test Macro");
        assert_eq!(deserialized.trigger.keys, vec![30, 40]);
    }

    /// End-to-end exchange against the mock daemon on a temporary socket.
    #[tokio::test]
    async fn test_client_server_communication() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let socket_path_str = socket_path.to_string_lossy().to_string();
        let socket_path_clone = socket_path_str.clone();

        tokio::spawn(async move { mock_daemon(&socket_path_clone).await });

        // Give the spawned daemon time to bind its socket.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = IpcClient::with_socket_path(&socket_path_str);

        assert!(client.is_daemon_running().await);

        let response = client.send(&Request::GetDevices).await.unwrap();
        if let Response::Devices(devices) = response {
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0].name, "Test Device");
        } else {
            panic!("Expected Devices response");
        }

        let response = client.send(&Request::ListMacros).await.unwrap();
        if let Response::Macros(macros) = response {
            assert_eq!(macros.len(), 1);
            assert_eq!(macros[0].name, "Test Macro");
        } else {
            panic!("Expected Macros response");
        }

        let response = client.send(&Request::GetStatus).await.unwrap();
        if let Response::Status {
            version,
            uptime_seconds,
            devices_count,
            macros_count,
        } = response
        {
            assert_eq!(version, "0.1.0");
            assert_eq!(uptime_seconds, 60);
            assert_eq!(devices_count, 1);
            assert_eq!(macros_count, 1);
        } else {
            panic!("Expected Status response");
        }

        let response = send_to_path(&Request::GetDevices, &socket_path_str)
            .await
            .unwrap();
        if let Response::Devices(devices) = response {
            assert_eq!(devices.len(), 1);
        } else {
            panic!("Expected Devices response");
        }
    }

    /// Sending to a socket that does not exist fails with a connection error.
    #[tokio::test]
    async fn test_connection_timeout() {
        let client = IpcClient::with_socket_path("/tmp/nonexistent.sock")
            .with_timeout(100)
            .with_retry_params(1, 100);

        match client.send(&Request::GetDevices).await {
            Err(IpcError::DaemonNotRunning(_)) | Err(IpcError::ConnectionTimeout) => {}
            _ => panic!("Expected DaemonNotRunning or ConnectionTimeout error"),
        }
    }

    /// The free-function probe reports `false` for a missing socket.
    #[tokio::test]
    async fn test_is_daemon_running() {
        assert!(!is_daemon_running(Some("/tmp/nonexistent.sock")).await);

        // Result depends on whether a real daemon is running; only checks
        // that the default-path call completes.
        let _running = is_daemon_running(None::<&str>).await;
    }

    /// Raw bincode round-trips for a request and a `Devices` response.
    #[test]
    fn test_serialization_roundtrip() {
        let request = Request::GetDevices;
        let serialized = bincode::serialize(&request)
            .map_err(IpcError::Serialize)
            .unwrap();
        let deserialized: Request = bincode::deserialize(&serialized)
            .map_err(IpcError::Deserialize)
            .unwrap();
        assert!(matches!(deserialized, Request::GetDevices));

        let devices = vec![DeviceInfo {
            name: "Test Device".to_string(),
            path: std::path::PathBuf::from("/dev/input/test"),
            vendor_id: 0x1532,
            product_id: 0x0221,
            phys: "usb-0000:00:14.0-1/input0".to_string(),
            device_type: DeviceType::Other,
        }];
        let response = Response::Devices(devices.clone());
        let serialized = bincode::serialize(&response)
            .map_err(IpcError::Serialize)
            .unwrap();
        let deserialized: Response = bincode::deserialize(&serialized)
            .map_err(IpcError::Deserialize)
            .unwrap();

        if let Response::Devices(deserialized_devices) = deserialized {
            assert_eq!(deserialized_devices.len(), devices.len());
            assert_eq!(deserialized_devices[0].name, devices[0].name);
            assert_eq!(deserialized_devices[0].vendor_id, devices[0].vendor_id);
        } else {
            panic!("Expected Devices response");
        }
    }

    /// Invalid bytes fail to decode and map onto `IpcError::Deserialize`.
    #[test]
    fn test_send_request_error_handling() {
        let request = Request::GetDevices;
        let result = bincode::serialize(&request);
        assert!(result.is_ok());

        let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF];
        let result: Result<Request, bincode::Error> = bincode::deserialize(&invalid_data);
        assert!(result.is_err());

        let error = bincode::deserialize::<Request>(&invalid_data).map_err(IpcError::Deserialize);
        assert!(matches!(error, Err(IpcError::Deserialize(_))));
    }
}