1use crate::{Request, Response, AnalogCalibrationConfig};
7use bincode;
8use serde::{Serialize, de::DeserializeOwned};
9
10use std::io;
11use std::path::Path;
12use std::time::Duration;
13use thiserror::Error;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::UnixStream;
16
17use tokio::time::timeout;
18
19#[derive(Error, Debug)]
21pub enum IpcError {
22 #[error("failed to connect to daemon: {0}")]
23 Connect(std::io::Error),
24 #[error("failed to send request: {0}")]
25 Send(std::io::Error),
26 #[error("failed to receive response: {0}")]
27 Receive(std::io::Error),
28 #[error("serialization error: {0}")]
29 Serialize(bincode::Error),
30 #[error("deserialization error: {0}")]
31 Deserialize(bincode::Error),
32 #[error("request timed out")]
33 Timeout,
34
35 #[error("IO error: {0}")]
36 Io(#[from] io::Error),
37
38 #[error("Serialization error: {0}")]
39 Serialization(String),
40
41 #[error("Connection timeout")]
42 ConnectionTimeout,
43
44 #[error("Operation timeout after {0}ms")]
45 OperationTimeout(u64),
46
47 #[error("Daemon not running at {0}")]
48 DaemonNotRunning(String),
49
50 #[error("Invalid response from daemon")]
51 InvalidResponse,
52
53 #[error("Message too large: {0} bytes exceeds maximum of {1} bytes")]
54 MessageTooLarge(usize, usize),
55
56 #[error("Connection closed unexpectedly")]
57 ConnectionClosed,
58
59 #[error("Other error: {0}")]
60 Other(String),
61}
62
/// Socket path the daemon is contacted on when no other path is supplied.
pub const DEFAULT_SOCKET_PATH: &str = "/run/aethermap/aethermap.sock";

/// Default [`IpcClient`] timeout in milliseconds (5 s).
pub const DEFAULT_TIMEOUT_MS: u64 = 5_000;

/// Largest frame payload accepted in either direction, in bytes (1 MiB).
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;

/// Default number of retries after a failed attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 3;

/// Default pause between attempts, in milliseconds (1 s).
pub const DEFAULT_RETRY_DELAY_MS: u64 = 1_000;
77
78#[derive(Debug)]
80pub struct IpcClient {
81 socket_path: String,
82 timeout: Duration,
83 max_retries: u32,
84 retry_delay: Duration,
85 }
87
88impl IpcClient {
89 pub fn new() -> Self {
91 Self::with_socket_path(DEFAULT_SOCKET_PATH)
92 }
93
94 pub fn with_socket_path<P: AsRef<Path>>(socket_path: P) -> Self {
96 Self {
97 socket_path: socket_path.as_ref().to_string_lossy().to_string(),
98 timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
99 max_retries: DEFAULT_MAX_RETRIES,
100 retry_delay: Duration::from_millis(DEFAULT_RETRY_DELAY_MS),
101 }
102 }
103
104 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
106 self.timeout = Duration::from_millis(timeout_ms);
107 self
108 }
109
110 pub fn with_retry_params(mut self, max_retries: u32, retry_delay_ms: u64) -> Self {
112 self.max_retries = max_retries;
113 self.retry_delay = Duration::from_millis(retry_delay_ms);
114 self
115 }
116
117 pub async fn is_daemon_running(&self) -> bool {
119 match UnixStream::connect(&self.socket_path).await {
120 Ok(_) => true,
121 Err(_) => false,
122 }
123 }
124
125 pub async fn connect(&self) -> Result<UnixStream, IpcError> {
127 let mut attempts = 0;
128
129 loop {
130 match timeout(self.timeout, UnixStream::connect(&self.socket_path)).await {
131 Ok(Ok(stream)) => return Ok(stream),
132 Ok(Err(e)) => {
133 if attempts >= self.max_retries {
134 return Err(IpcError::DaemonNotRunning(self.socket_path.clone()));
135 }
136 tracing::warn!("Connection attempt {} failed: {}, retrying...", attempts + 1, e);
137 tokio::time::sleep(self.retry_delay).await;
138 attempts += 1;
139 }
140 Err(_) => return Err(IpcError::ConnectionTimeout),
141 }
142 }
143 }
144
145 pub async fn send(&self, request: &Request) -> Result<Response, IpcError> {
147 self.send_with_retries(request, self.max_retries).await
148 }
149
150 pub async fn send_with_retries(&self, request: &Request, max_retries: u32) -> Result<Response, IpcError> {
152 let mut attempts = 0;
153 let mut last_error = None;
154
155 while attempts <= max_retries {
156 match self.connect().await {
157 Ok(mut stream) => {
158 match self.send_with_stream(&mut stream, request).await {
159 Ok(response) => return Ok(response),
160 Err(e) => {
161 last_error = Some(e);
162 if attempts < max_retries {
163 tracing::warn!("Request attempt {} failed, retrying...", attempts + 1);
164 tokio::time::sleep(self.retry_delay).await;
165 }
166 }
167 }
168 }
169 Err(e) => {
170 last_error = Some(e);
171 if attempts < max_retries {
172 tracing::warn!("Connection attempt {} failed, retrying...", attempts + 1);
173 tokio::time::sleep(self.retry_delay).await;
174 }
175 }
176 }
177 attempts += 1;
178 }
179
180 Err(last_error.unwrap_or(IpcError::Other("Unknown error".to_string())))
181 }
182
183 pub async fn set_macro_settings(&self, settings: crate::MacroSettings) -> Result<(), IpcError> {
185 let request = Request::SetMacroSettings(settings);
186 match self.send(&request).await? {
187 Response::Ack => Ok(()),
188 Response::Error(msg) => Err(IpcError::Other(msg)),
189 _ => Err(IpcError::InvalidResponse),
190 }
191 }
192
193 pub async fn get_macro_settings(&self) -> Result<crate::MacroSettings, IpcError> {
195 let request = Request::GetMacroSettings;
196 match self.send(&request).await? {
197 Response::MacroSettings(settings) => Ok(settings),
198 Response::Error(msg) => Err(IpcError::Other(msg)),
199 _ => Err(IpcError::InvalidResponse),
200 }
201 }
202
203 async fn send_with_stream(&self, stream: &mut UnixStream, request: &Request) -> Result<Response, IpcError> {
205 let serialized = bincode::serialize(request)
207 .map_err(|e| IpcError::Serialization(e.to_string()))?;
208
209 if serialized.len() > MAX_MESSAGE_SIZE {
211 return Err(IpcError::MessageTooLarge(serialized.len(), MAX_MESSAGE_SIZE));
212 }
213
214 if let Err(_) = timeout(self.timeout, async {
216 let len = serialized.len() as u32;
218 stream.write_all(&len.to_le_bytes()).await?;
219
220 stream.write_all(&serialized).await?;
222 stream.flush().await?;
223
224 Ok::<(), io::Error>(())
225 }).await {
226 return Err(IpcError::OperationTimeout(self.timeout.as_millis() as u64));
227 }
228
229 let response = timeout(self.timeout, async {
231 let mut len_bytes = [0u8; 4];
233 stream.read_exact(&mut len_bytes).await?;
234 let response_len = u32::from_le_bytes(len_bytes) as usize;
235
236 if response_len > MAX_MESSAGE_SIZE {
238 return Err(IpcError::MessageTooLarge(response_len, MAX_MESSAGE_SIZE));
239 }
240
241 let mut buffer = vec![0u8; response_len];
243 stream.read_exact(&mut buffer).await?;
244
245 bincode::deserialize(&buffer)
247 .map_err(|e| IpcError::Serialization(e.to_string()))
248 }).await;
249
250 match response {
251 Ok(Ok(resp)) => Ok(resp),
252 Ok(Err(e)) => Err(e),
253 Err(_) => Err(IpcError::OperationTimeout(self.timeout.as_millis() as u64)),
254 }
255 }
256}
257
258pub async fn send(request: &Request) -> Result<Response, IpcError> {
281 let client = IpcClient::new();
282 client.send(request).await
283}
284
285pub async fn send_request(req: &Request) -> Result<Response, IpcError> {
299 let mut stream = timeout(
301 Duration::from_secs(2),
302 UnixStream::connect(DEFAULT_SOCKET_PATH)
303 )
304 .await
305 .map_err(|_| IpcError::Timeout)?
306 .map_err(IpcError::Connect)?;
307
308 let serialized = bincode::serialize(req).map_err(IpcError::Serialize)?;
310
311 if serialized.len() > MAX_MESSAGE_SIZE {
313 return Err(IpcError::MessageTooLarge(serialized.len(), MAX_MESSAGE_SIZE));
314 }
315
316 let len_prefix = (serialized.len() as u32).to_le_bytes();
318 timeout(
319 Duration::from_secs(2),
320 stream.write_all(&len_prefix)
321 )
322 .await
323 .map_err(|_| IpcError::Timeout)?
324 .map_err(IpcError::Send)?;
325
326 timeout(
327 Duration::from_secs(2),
328 stream.write_all(&serialized)
329 )
330 .await
331 .map_err(|_| IpcError::Timeout)?
332 .map_err(IpcError::Send)?;
333
334 let mut response_len_bytes = [0u8; 4];
336 timeout(
337 Duration::from_secs(2),
338 stream.read_exact(&mut response_len_bytes)
339 )
340 .await
341 .map_err(|_| IpcError::Timeout)?
342 .map_err(IpcError::Receive)?;
343
344 let response_len = u32::from_le_bytes(response_len_bytes) as usize;
345
346 if response_len > MAX_MESSAGE_SIZE {
348 return Err(IpcError::MessageTooLarge(response_len, MAX_MESSAGE_SIZE));
349 }
350
351 let mut response_buffer = vec![0u8; response_len];
353 timeout(
354 Duration::from_secs(2),
355 stream.read_exact(&mut response_buffer)
356 )
357 .await
358 .map_err(|_| IpcError::Timeout)?
359 .map_err(IpcError::Receive)?;
360
361 bincode::deserialize(&response_buffer).map_err(IpcError::Deserialize)
363}
364
365pub async fn send_to_path<P: AsRef<Path>>(request: &Request, socket_path: P) -> Result<Response, IpcError> {
376 let client = IpcClient::with_socket_path(socket_path);
377 client.send(request).await
378}
379
380pub async fn send_with_timeout(request: &Request, timeout_ms: u64) -> Result<Response, IpcError> {
391 let client = IpcClient::new().with_timeout(timeout_ms);
392 client.send(request).await
393}
394
395pub async fn is_daemon_running<P: AsRef<Path>>(socket_path: Option<P>) -> bool {
405 let path = socket_path.map(|p| p.as_ref().to_string_lossy().to_string())
406 .unwrap_or_else(|| DEFAULT_SOCKET_PATH.to_string());
407
408 match UnixStream::connect(path).await {
409 Ok(_) => true,
410 Err(_) => false,
411 }
412}
413
414pub async fn get_analog_calibration(
439 device_id: &str,
440 layer_id: usize,
441) -> Result<AnalogCalibrationConfig, IpcError> {
442 let request = Request::GetAnalogCalibration {
443 device_id: device_id.to_string(),
444 layer_id,
445 };
446
447 match send(&request).await? {
448 Response::AnalogCalibration { calibration: Some(cal), .. } => Ok(cal),
449 Response::AnalogCalibration { calibration: None, .. } => {
450 Ok(AnalogCalibrationConfig::default())
452 }
453 Response::Error(msg) => Err(IpcError::Other(msg)),
454 _ => Err(IpcError::InvalidResponse),
455 }
456}
457
458pub async fn set_analog_calibration(
497 device_id: &str,
498 layer_id: usize,
499 calibration: AnalogCalibrationConfig,
500) -> Result<(), IpcError> {
501 let request = Request::SetAnalogCalibration {
502 device_id: device_id.to_string(),
503 layer_id,
504 calibration,
505 };
506
507 match send(&request).await? {
508 Response::AnalogCalibrationAck => Ok(()),
509 Response::Error(msg) => Err(IpcError::Other(msg)),
510 _ => Err(IpcError::InvalidResponse),
511 }
512}
513
514pub async fn get_auto_switch_rules() -> Result<Vec<crate::AutoSwitchRule>, IpcError> {
536 let request = Request::GetAutoSwitchRules;
537
538 match send(&request).await? {
539 Response::AutoSwitchRules { rules } => Ok(rules),
540 Response::Error(msg) => Err(IpcError::Other(msg)),
541 _ => Err(IpcError::InvalidResponse),
542 }
543}
544
545pub async fn get_macro_settings() -> Result<crate::MacroSettings, IpcError> {
547 IpcClient::new().get_macro_settings().await
548}
549
550pub async fn set_macro_settings(settings: crate::MacroSettings) -> Result<(), IpcError> {
552 IpcClient::new().set_macro_settings(settings).await
553}
554
555pub fn serialize<T: Serialize>(msg: &T) -> Result<Vec<u8>, IpcError> {
557 bincode::serialize(msg)
558 .map_err(|e| IpcError::Serialization(e.to_string()))
559}
560
561pub fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, IpcError> {
563 bincode::deserialize(bytes)
564 .map_err(|e| IpcError::Serialization(e.to_string()))
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Request, Response, DeviceInfo, DeviceType, Action, KeyCombo, MacroEntry};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Accepts connections on an already-bound `listener` until accept fails,
    /// handling each connection on its own task.
    ///
    /// The caller binds the listener before spawning this, so the socket exists
    /// before any client connects (no sleep-based startup race).
    async fn mock_daemon(listener: UnixListener) {
        loop {
            match listener.accept().await {
                Ok((stream, _)) => {
                    tokio::spawn(handle_connection(stream));
                }
                Err(e) => {
                    tracing::error!("Failed to accept connection: {}", e);
                    break;
                }
            }
        }
    }

    /// Reads one length-prefixed request from `stream` and writes back the canned
    /// reply. Any I/O or decode failure simply drops the connection.
    async fn handle_connection(mut stream: UnixStream) {
        let mut len_buf = [0u8; 4];
        if stream.read_exact(&mut len_buf).await.is_err() {
            return;
        }

        let msg_len = u32::from_le_bytes(len_buf) as usize;
        if msg_len > MAX_MESSAGE_SIZE {
            return;
        }

        let mut msg_buf = vec![0u8; msg_len];
        if stream.read_exact(&mut msg_buf).await.is_err() {
            return;
        }

        let request: Request = match bincode::deserialize(&msg_buf) {
            Ok(req) => req,
            Err(_) => return,
        };

        let response_bytes = bincode::serialize(&mock_response(request)).unwrap();
        let len = response_bytes.len() as u32;

        if stream.write_all(&len.to_le_bytes()).await.is_err() {
            return;
        }
        if stream.write_all(&response_bytes).await.is_err() {
            return;
        }
        let _ = stream.flush().await;
    }

    /// Canned daemon replies for the requests these tests exercise.
    fn mock_response(request: Request) -> Response {
        match request {
            Request::GetDevices => Response::Devices(vec![DeviceInfo {
                name: "Test Device".to_string(),
                path: PathBuf::from("/dev/input/event0"),
                vendor_id: 0x1234,
                product_id: 0x5678,
                phys: "usb-0000:00:14.0-1/input0".to_string(),
                device_type: DeviceType::Other,
            }]),
            Request::ListMacros => Response::Macros(vec![MacroEntry {
                name: "Test Macro".to_string(),
                trigger: KeyCombo {
                    keys: vec![30],
                    modifiers: vec![],
                },
                actions: vec![
                    Action::KeyPress(31),
                    Action::Delay(100),
                    Action::KeyRelease(31),
                ],
                device_id: None,
                enabled: true,
                humanize: false,
                capture_mouse: false,
            }]),
            Request::GetStatus => Response::Status {
                version: "0.1.0".to_string(),
                uptime_seconds: 60,
                devices_count: 1,
                macros_count: 1,
            },
            _ => Response::Error("Unsupported request in test".to_string()),
        }
    }

    #[tokio::test]
    async fn test_ipc_client_creation() {
        let client = IpcClient::new();
        assert_eq!(client.socket_path, DEFAULT_SOCKET_PATH);
        assert_eq!(client.timeout, Duration::from_millis(DEFAULT_TIMEOUT_MS));
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);
        assert_eq!(client.retry_delay, Duration::from_millis(DEFAULT_RETRY_DELAY_MS));

        let custom_path = "/tmp/test.sock";
        let custom_client = IpcClient::with_socket_path(custom_path)
            .with_timeout(10000)
            .with_retry_params(5, 2000);

        assert_eq!(custom_client.socket_path, custom_path);
        assert_eq!(custom_client.timeout, Duration::from_millis(10000));
        assert_eq!(custom_client.max_retries, 5);
        assert_eq!(custom_client.retry_delay, Duration::from_millis(2000));
    }

    #[tokio::test]
    async fn test_serialization_deserialization() {
        let request = Request::GetDevices;
        let serialized = serialize(&request).unwrap();
        let deserialized: Request = deserialize(&serialized).unwrap();
        assert!(matches!(deserialized, Request::GetDevices));

        let macro_entry = MacroEntry {
            name: "Test Macro".to_string(),
            trigger: KeyCombo {
                keys: vec![30, 40],
                modifiers: vec![29],
            },
            actions: vec![
                Action::KeyPress(30),
                Action::Delay(100),
                Action::KeyRelease(30),
            ],
            device_id: Some("test_device".to_string()),
            enabled: true,
            humanize: false,
            capture_mouse: false,
        };

        let serialized = serialize(&macro_entry).unwrap();
        let deserialized: MacroEntry = deserialize(&serialized).unwrap();
        assert_eq!(deserialized.name, "Test Macro");
        assert_eq!(deserialized.trigger.keys, vec![30, 40]);
    }

    #[tokio::test]
    async fn test_client_server_communication() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let socket_path_str = socket_path.to_string_lossy().to_string();

        // Bind before spawning the accept loop so the socket is guaranteed to
        // exist when the client connects.
        let listener = UnixListener::bind(&socket_path).unwrap();
        tokio::spawn(mock_daemon(listener));

        let client = IpcClient::with_socket_path(&socket_path_str);
        assert!(client.is_daemon_running().await);

        let response = client.send(&Request::GetDevices).await.unwrap();
        if let Response::Devices(devices) = response {
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0].name, "Test Device");
        } else {
            panic!("Expected Devices response");
        }

        let response = client.send(&Request::ListMacros).await.unwrap();
        if let Response::Macros(macros) = response {
            assert_eq!(macros.len(), 1);
            assert_eq!(macros[0].name, "Test Macro");
        } else {
            panic!("Expected Macros response");
        }

        let response = client.send(&Request::GetStatus).await.unwrap();
        if let Response::Status { version, uptime_seconds, devices_count, macros_count } = response {
            assert_eq!(version, "0.1.0");
            assert_eq!(uptime_seconds, 60);
            assert_eq!(devices_count, 1);
            assert_eq!(macros_count, 1);
        } else {
            panic!("Expected Status response");
        }

        let response = send_to_path(&Request::GetDevices, &socket_path_str).await.unwrap();
        if let Response::Devices(devices) = response {
            assert_eq!(devices.len(), 1);
        } else {
            panic!("Expected Devices response");
        }
    }

    #[tokio::test]
    async fn test_connection_timeout() {
        // A path inside a fresh temp dir is guaranteed not to exist, unlike a
        // fixed name under /tmp.
        let temp_dir = TempDir::new().unwrap();
        let client = IpcClient::with_socket_path(temp_dir.path().join("missing.sock"))
            .with_timeout(100)
            .with_retry_params(1, 100);

        match client.send(&Request::GetDevices).await {
            Err(IpcError::DaemonNotRunning(_)) | Err(IpcError::ConnectionTimeout) => {}
            _ => panic!("Expected DaemonNotRunning or ConnectionTimeout error"),
        }
    }

    #[tokio::test]
    async fn test_is_daemon_running() {
        let temp_dir = TempDir::new().unwrap();

        let missing = temp_dir.path().join("missing.sock");
        assert!(!is_daemon_running(Some(&missing)).await);

        let live = temp_dir.path().join("live.sock");
        let _listener = UnixListener::bind(&live).unwrap();
        assert!(is_daemon_running(Some(&live)).await);

        // The default socket may legitimately exist on a machine running the
        // daemon, so only exercise the `None` branch without asserting its result.
        let _ = is_daemon_running(None::<&str>).await;
    }

    #[test]
    fn test_serialization_roundtrip() {
        let request = Request::GetDevices;
        let serialized = bincode::serialize(&request).map_err(IpcError::Serialize).unwrap();
        let deserialized: Request = bincode::deserialize(&serialized).map_err(IpcError::Deserialize).unwrap();
        assert!(matches!(deserialized, Request::GetDevices));

        let devices = vec![DeviceInfo {
            name: "Test Device".to_string(),
            path: std::path::PathBuf::from("/dev/input/test"),
            vendor_id: 0x1532,
            product_id: 0x0221,
            phys: "usb-0000:00:14.0-1/input0".to_string(),
            device_type: DeviceType::Other,
        }];
        let response = Response::Devices(devices.clone());
        let serialized = bincode::serialize(&response).map_err(IpcError::Serialize).unwrap();
        let deserialized: Response = bincode::deserialize(&serialized).map_err(IpcError::Deserialize).unwrap();

        if let Response::Devices(deserialized_devices) = deserialized {
            assert_eq!(deserialized_devices.len(), devices.len());
            assert_eq!(deserialized_devices[0].name, devices[0].name);
            assert_eq!(deserialized_devices[0].vendor_id, devices[0].vendor_id);
        } else {
            panic!("Expected Devices response");
        }
    }

    #[test]
    fn test_send_request_error_handling() {
        let request = Request::GetDevices;
        assert!(bincode::serialize(&request).is_ok());

        let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF];
        let result: Result<Request, bincode::Error> = bincode::deserialize(&invalid_data);
        assert!(result.is_err());

        let error = bincode::deserialize::<Request>(&invalid_data).map_err(IpcError::Deserialize);
        assert!(matches!(error, Err(IpcError::Deserialize(_))));
    }
}