1use crate::error::{IpcError, Result};
37use crate::graceful::{GracefulChannel, ShutdownState};
38use crate::local_socket::{LocalSocketListener, LocalSocketStream};
39use parking_lot::RwLock;
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::io::{Read, Write};
43use std::sync::atomic::{AtomicU64, Ordering};
44use std::sync::Arc;
45use std::thread::JoinHandle;
46use std::time::{Duration, SystemTime};
47
/// Identifier assigned to each connection (server side: monotonically increasing from 1;
/// client side: always 0).
pub type ConnectionId = u64;

/// Settings used by [`SocketServer::new`].
#[derive(Clone, Debug)]
pub struct SocketServerConfig {
    /// Filesystem path (Unix) or pipe name (Windows) the listener binds to.
    pub path: String,
    /// Intended connection cap.
    /// NOTE(review): not consulted by `SocketServer` in this module — confirm whether it is enforced elsewhere.
    pub max_connections: usize,
    /// Intended per-connection timeout.
    /// NOTE(review): not applied to streams in this module — confirm.
    pub connection_timeout: Duration,
    /// On Unix, remove a leftover socket at `path` before binding.
    pub cleanup_on_start: bool,
    /// Intended I/O buffer size.
    /// NOTE(review): `Connection` uses its own fixed capacity; this field is not read here.
    pub buffer_size: usize,
}
65
66impl Default for SocketServerConfig {
67 fn default() -> Self {
68 Self {
69 path: default_socket_path(),
70 max_connections: 100,
71 connection_timeout: Duration::from_secs(30),
72 cleanup_on_start: true,
73 buffer_size: 8192,
74 }
75 }
76}
77
78impl SocketServerConfig {
79 pub fn with_path(path: &str) -> Self {
81 Self {
82 path: path.to_string(),
83 ..Default::default()
84 }
85 }
86}
87
/// Returns the platform's default endpoint for the ipckit server.
///
/// * Unix: `$XDG_RUNTIME_DIR/ipckit.sock`, falling back to `/tmp/ipckit.sock`
///   when the variable is unset, empty, or not valid UTF-8.
/// * Windows: the named pipe `\\.\pipe\ipckit`.
pub fn default_socket_path() -> String {
    #[cfg(unix)]
    {
        // An empty XDG_RUNTIME_DIR would otherwise produce "/ipckit.sock" at the
        // filesystem root; treat it the same as an unset variable.
        let runtime_dir = std::env::var("XDG_RUNTIME_DIR")
            .ok()
            .filter(|dir| !dir.is_empty())
            .unwrap_or_else(|| "/tmp".to_string());
        format!("{}/ipckit.sock", runtime_dir)
    }
    #[cfg(windows)]
    {
        r"\\.\pipe\ipckit".to_string()
    }
}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ConnectionMetadata {
104 #[serde(with = "system_time_serde")]
106 pub connected_at: SystemTime,
107 pub client_pid: Option<u32>,
109 pub client_info: Option<String>,
111}
112
113mod system_time_serde {
114 use serde::{Deserialize, Deserializer, Serialize, Serializer};
115 use std::time::{Duration, SystemTime, UNIX_EPOCH};
116
117 pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
118 where
119 S: Serializer,
120 {
121 let duration = time.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO);
122 duration.as_secs_f64().serialize(serializer)
123 }
124
125 pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
126 where
127 D: Deserializer<'de>,
128 {
129 let secs = f64::deserialize(deserializer)?;
130 Ok(UNIX_EPOCH + Duration::from_secs_f64(secs))
131 }
132}
133
134impl Default for ConnectionMetadata {
135 fn default() -> Self {
136 Self {
137 connected_at: SystemTime::now(),
138 client_pid: None,
139 client_info: None,
140 }
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct Message {
147 pub msg_type: MessageType,
149 pub payload: serde_json::Value,
151}
152
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
155#[serde(rename_all = "lowercase")]
156pub enum MessageType {
157 Text,
159 Binary,
161 Request,
163 Response,
165 Error,
167 Ping,
169 Pong,
171}
172
173impl Message {
174 pub fn text(content: &str) -> Self {
176 Self {
177 msg_type: MessageType::Text,
178 payload: serde_json::json!({ "content": content }),
179 }
180 }
181
182 pub fn request(method: &str, params: serde_json::Value) -> Self {
184 Self {
185 msg_type: MessageType::Request,
186 payload: serde_json::json!({
187 "method": method,
188 "params": params
189 }),
190 }
191 }
192
193 pub fn response(result: serde_json::Value) -> Self {
195 Self {
196 msg_type: MessageType::Response,
197 payload: serde_json::json!({ "result": result }),
198 }
199 }
200
201 pub fn error(code: i32, message: &str) -> Self {
203 Self {
204 msg_type: MessageType::Error,
205 payload: serde_json::json!({
206 "code": code,
207 "message": message
208 }),
209 }
210 }
211
212 pub fn ping() -> Self {
214 Self {
215 msg_type: MessageType::Ping,
216 payload: serde_json::json!({}),
217 }
218 }
219
220 pub fn pong() -> Self {
222 Self {
223 msg_type: MessageType::Pong,
224 payload: serde_json::json!({}),
225 }
226 }
227
228 pub fn json(value: serde_json::Value) -> Self {
230 Self {
231 msg_type: MessageType::Text,
232 payload: value,
233 }
234 }
235
236 pub fn binary(data: Vec<u8>) -> Self {
238 Self {
239 msg_type: MessageType::Binary,
240 payload: serde_json::json!({
241 "data": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &data)
242 }),
243 }
244 }
245
246 pub fn as_binary(&self) -> Option<Vec<u8>> {
248 self.payload
249 .get("data")
250 .and_then(|v| v.as_str())
251 .and_then(|s| {
252 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s).ok()
253 })
254 }
255
256 pub fn as_text(&self) -> Option<&str> {
258 self.payload.get("content").and_then(|v| v.as_str())
259 }
260
261 pub fn method(&self) -> Option<&str> {
263 self.payload.get("method").and_then(|v| v.as_str())
264 }
265
266 pub fn params(&self) -> Option<&serde_json::Value> {
268 self.payload.get("params")
269 }
270
271 pub fn result(&self) -> Option<&serde_json::Value> {
273 self.payload.get("result")
274 }
275}
276
277pub struct Connection {
279 id: ConnectionId,
280 stream: LocalSocketStream,
281 metadata: ConnectionMetadata,
282 buffer: Vec<u8>,
283}
284
285impl Connection {
286 fn new(id: ConnectionId, stream: LocalSocketStream) -> Self {
288 Self {
289 id,
290 stream,
291 metadata: ConnectionMetadata::default(),
292 buffer: Vec::with_capacity(8192),
293 }
294 }
295
296 pub fn id(&self) -> ConnectionId {
298 self.id
299 }
300
301 pub fn metadata(&self) -> &ConnectionMetadata {
303 &self.metadata
304 }
305
306 pub fn set_client_info(&mut self, info: &str) {
308 self.metadata.client_info = Some(info.to_string());
309 }
310
311 pub fn send(&mut self, msg: &Message) -> Result<()> {
313 let data = serde_json::to_vec(msg).map_err(|e| IpcError::serialization(e.to_string()))?;
314
315 let len = data.len() as u32;
317 self.stream.write_all(&len.to_le_bytes())?;
318
319 self.stream.write_all(&data)?;
321 self.stream.flush()?;
322
323 Ok(())
324 }
325
326 pub fn recv(&mut self) -> Result<Message> {
328 let mut len_buf = [0u8; 4];
330 self.stream.read_exact(&mut len_buf)?;
331 let len = u32::from_le_bytes(len_buf) as usize;
332
333 if len > 16 * 1024 * 1024 {
335 return Err(IpcError::BufferTooSmall {
336 needed: len,
337 got: 16 * 1024 * 1024,
338 });
339 }
340
341 self.buffer.resize(len, 0);
343 self.stream.read_exact(&mut self.buffer)?;
344
345 serde_json::from_slice(&self.buffer).map_err(|e| IpcError::deserialization(e.to_string()))
347 }
348
349 pub fn try_recv(&mut self) -> Result<Option<Message>> {
354 Err(IpcError::WouldBlock)
358 }
359
360 pub fn request(
362 &mut self,
363 method: &str,
364 params: serde_json::Value,
365 ) -> Result<serde_json::Value> {
366 self.send(&Message::request(method, params))?;
367 let response = self.recv()?;
368
369 match response.msg_type {
370 MessageType::Response => response
371 .result()
372 .cloned()
373 .ok_or_else(|| IpcError::deserialization("Missing result in response".to_string())),
374 MessageType::Error => {
375 let msg = response
376 .payload
377 .get("message")
378 .and_then(|v| v.as_str())
379 .unwrap_or("Unknown error");
380 Err(IpcError::Other(msg.to_string()))
381 }
382 _ => Err(IpcError::deserialization(
383 "Unexpected message type".to_string(),
384 )),
385 }
386 }
387}
388
389pub trait ConnectionHandler: Clone + Send + 'static {
391 fn on_connect(&self, conn: &mut Connection) -> Result<()> {
393 let _ = conn;
394 Ok(())
395 }
396
397 fn on_message(&self, conn: &mut Connection, msg: Message) -> Result<Option<Message>>;
399
400 fn on_disconnect(&self, conn_id: ConnectionId) {
402 let _ = conn_id;
403 }
404}
405
406#[derive(Clone)]
408pub struct FnHandler<F>
409where
410 F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
411{
412 handler: F,
413}
414
415impl<F> FnHandler<F>
416where
417 F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
418{
419 pub fn new(handler: F) -> Self {
421 Self { handler }
422 }
423}
424
425impl<F> ConnectionHandler for FnHandler<F>
426where
427 F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
428{
429 fn on_message(&self, conn: &mut Connection, msg: Message) -> Result<Option<Message>> {
430 (self.handler)(conn, msg)
431 }
432}
433
434pub struct SocketServer {
436 config: SocketServerConfig,
437 listener: LocalSocketListener,
438 connections: Arc<RwLock<HashMap<ConnectionId, Arc<RwLock<Connection>>>>>,
439 shutdown: Arc<ShutdownState>,
440 next_id: AtomicU64,
441}
442
443impl SocketServer {
444 pub fn new(config: SocketServerConfig) -> Result<Self> {
446 #[cfg(unix)]
448 if config.cleanup_on_start && !config.path.starts_with(r"\\.\pipe\") {
449 let _ = std::fs::remove_file(&config.path);
450 }
451
452 let listener = LocalSocketListener::bind(&config.path)?;
453
454 Ok(Self {
455 config,
456 listener,
457 connections: Arc::new(RwLock::new(HashMap::new())),
458 shutdown: Arc::new(ShutdownState::new()),
459 next_id: AtomicU64::new(1),
460 })
461 }
462
463 pub fn with_defaults() -> Result<Self> {
465 Self::new(SocketServerConfig::default())
466 }
467
468 pub fn at(path: &str) -> Result<Self> {
470 Self::new(SocketServerConfig::with_path(path))
471 }
472
473 pub fn socket_path(&self) -> &str {
475 &self.config.path
476 }
477
478 pub fn connection_count(&self) -> usize {
480 self.connections.read().len()
481 }
482
483 pub fn accept(&self) -> Result<Connection> {
485 if self.shutdown.is_shutdown() {
486 return Err(IpcError::Closed);
487 }
488
489 let stream = self.listener.accept()?;
490 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
491 let conn = Connection::new(id, stream);
492
493 self.connections
494 .write()
495 .insert(id, Arc::new(RwLock::new(conn)));
496
497 let stream = self.listener.accept()?;
499 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
500
501 Ok(Connection::new(id, stream))
502 }
503
504 pub fn incoming(&self) -> impl Iterator<Item = Result<Connection>> + '_ {
506 std::iter::from_fn(move || {
507 if self.shutdown.is_shutdown() {
508 return None;
509 }
510
511 match self.listener.accept() {
512 Ok(stream) => {
513 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
514 Some(Ok(Connection::new(id, stream)))
515 }
516 Err(e) => Some(Err(e)),
517 }
518 })
519 }
520
521 pub fn run<H: ConnectionHandler>(&self, handler: H) -> Result<()> {
523 for conn_result in self.incoming() {
524 if self.shutdown.is_shutdown() {
525 break;
526 }
527
528 match conn_result {
529 Ok(mut conn) => {
530 let handler = handler.clone();
531 let shutdown = Arc::clone(&self.shutdown);
532
533 std::thread::spawn(move || {
534 if let Err(e) = handler.on_connect(&mut conn) {
535 tracing::error!("Connection error: {}", e);
536 return;
537 }
538
539 loop {
540 if shutdown.is_shutdown() {
541 break;
542 }
543
544 match conn.recv() {
545 Ok(msg) => match handler.on_message(&mut conn, msg) {
546 Ok(Some(response)) => {
547 if let Err(e) = conn.send(&response) {
548 tracing::error!("Send error: {}", e);
549 break;
550 }
551 }
552 Ok(None) => {}
553 Err(e) => {
554 tracing::error!("Handler error: {}", e);
555 let _ = conn.send(&Message::error(-1, &e.to_string()));
556 }
557 },
558 Err(IpcError::Io(ref e))
559 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
560 {
561 break;
562 }
563 Err(e) => {
564 tracing::error!("Receive error: {}", e);
565 break;
566 }
567 }
568 }
569
570 handler.on_disconnect(conn.id());
571 });
572 }
573 Err(e) => {
574 tracing::error!("Accept error: {}", e);
575 }
576 }
577 }
578
579 Ok(())
580 }
581
582 pub fn spawn<H: ConnectionHandler>(self, handler: H) -> JoinHandle<Result<()>> {
584 std::thread::spawn(move || self.run(handler))
585 }
586
587 pub fn shutdown(&self) {
589 self.shutdown.shutdown();
590 }
591
592 pub fn is_shutdown(&self) -> bool {
594 self.shutdown.is_shutdown()
595 }
596}
597
598impl GracefulChannel for SocketServer {
599 fn shutdown(&self) {
600 self.shutdown.shutdown();
601 }
602
603 fn is_shutdown(&self) -> bool {
604 self.shutdown.is_shutdown()
605 }
606
607 fn drain(&self) -> Result<()> {
608 self.shutdown.wait_for_drain(None)
609 }
610
611 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
612 self.shutdown();
613 self.shutdown.wait_for_drain(Some(timeout))
614 }
615}
616
617pub struct SocketClient {
619 connection: Connection,
620}
621
622impl SocketClient {
623 pub fn connect(path: &str) -> Result<Self> {
625 let stream = LocalSocketStream::connect(path)?;
626 let connection = Connection::new(0, stream);
627
628 Ok(Self { connection })
629 }
630
631 pub fn connect_timeout(path: &str, timeout: Duration) -> Result<Self> {
637 use std::sync::mpsc;
638 use std::thread;
639
640 let path_owned = path.to_string();
641 let (tx, rx) = mpsc::channel();
642
643 thread::spawn(move || {
645 let result = LocalSocketStream::connect(&path_owned);
646 let _ = tx.send(result);
647 });
648
649 match rx.recv_timeout(timeout) {
651 Ok(Ok(stream)) => {
652 let connection = Connection::new(0, stream);
653 Ok(Self { connection })
654 }
655 Ok(Err(e)) => Err(e),
656 Err(_) => Err(IpcError::Timeout),
657 }
658 }
659
660 pub fn connect_default() -> Result<Self> {
662 Self::connect(&default_socket_path())
663 }
664
665 pub fn connect_default_timeout(timeout: Duration) -> Result<Self> {
667 Self::connect_timeout(&default_socket_path(), timeout)
668 }
669
670 pub fn send(&mut self, msg: &Message) -> Result<()> {
672 self.connection.send(msg)
673 }
674
675 pub fn recv(&mut self) -> Result<Message> {
677 self.connection.recv()
678 }
679
680 pub fn request(
682 &mut self,
683 method: &str,
684 params: serde_json::Value,
685 ) -> Result<serde_json::Value> {
686 self.connection.request(method, params)
687 }
688
689 pub fn connection(&mut self) -> &mut Connection {
691 &mut self.connection
692 }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_message_creation() {
        let text = Message::text("Hello");
        assert_eq!(text.msg_type, MessageType::Text);
        assert_eq!(text.as_text(), Some("Hello"));

        let request = Message::request("ping", serde_json::json!({}));
        assert_eq!(request.msg_type, MessageType::Request);
        assert_eq!(request.method(), Some("ping"));
        assert_eq!(request.params(), Some(&serde_json::json!({})));

        let response = Message::response(serde_json::json!({"pong": true}));
        assert_eq!(response.msg_type, MessageType::Response);
        assert_eq!(response.result(), Some(&serde_json::json!({"pong": true})));

        let error = Message::error(404, "Not found");
        assert_eq!(error.msg_type, MessageType::Error);
        assert_eq!(error.payload["code"], 404);
        assert_eq!(error.payload["message"], "Not found");
    }

    #[test]
    fn test_ping_pong_and_json() {
        assert_eq!(Message::ping().msg_type, MessageType::Ping);
        assert_eq!(Message::pong().msg_type, MessageType::Pong);

        let raw = Message::json(serde_json::json!({"k": 1}));
        assert_eq!(raw.msg_type, MessageType::Text);
        assert_eq!(raw.as_text(), None);
    }

    #[test]
    fn test_binary_roundtrip() {
        let bytes = vec![0u8, 1, 127, 128, 255];
        let msg = Message::binary(bytes.clone());
        assert_eq!(msg.msg_type, MessageType::Binary);
        assert_eq!(msg.as_binary(), Some(bytes));

        // No "data" field -> no bytes.
        assert_eq!(Message::text("x").as_binary(), None);
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::request("test", serde_json::json!({"key": "value"}));
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.msg_type, msg.msg_type);
        assert_eq!(deserialized.method(), msg.method());
    }

    #[test]
    fn test_message_type_wire_names() {
        assert_eq!(serde_json::to_string(&MessageType::Ping).unwrap(), "\"ping\"");
        let parsed: MessageType = serde_json::from_str("\"response\"").unwrap();
        assert_eq!(parsed, MessageType::Response);
    }

    #[test]
    fn test_socket_server_config() {
        let config = SocketServerConfig::default();
        assert_eq!(config.max_connections, 100);
        assert!(config.cleanup_on_start);

        let custom = SocketServerConfig::with_path("/tmp/test.sock");
        assert_eq!(custom.path, "/tmp/test.sock");
    }

    #[test]
    fn test_connection_metadata() {
        let metadata = ConnectionMetadata::default();
        assert!(metadata.client_pid.is_none());
        assert!(metadata.client_info.is_none());
    }

    #[test]
    fn test_connection_metadata_serde_roundtrip() {
        let original = ConnectionMetadata {
            client_pid: Some(42),
            client_info: Some("cli".to_string()),
            ..Default::default()
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ConnectionMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.client_pid, Some(42));
        assert_eq!(restored.client_info.as_deref(), Some("cli"));

        // Timestamps travel as f64 seconds, so allow a tiny rounding drift.
        let drift = match restored.connected_at.duration_since(original.connected_at) {
            Ok(d) => d,
            Err(e) => e.duration(),
        };
        assert!(drift < Duration::from_millis(1));
    }

    #[test]
    fn test_fn_handler() {
        let handler = FnHandler::new(|_conn, msg| {
            if msg.method() == Some("ping") {
                Ok(Some(Message::response(serde_json::json!({"pong": true}))))
            } else {
                Ok(None)
            }
        });

        let _handler2 = handler.clone();
    }

    // Ignored by default: binds and connects to a real local socket endpoint.
    // Run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_socket_client_server() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let socket_name = format!("test_socket_server_{}", std::process::id());
        let server_ready = Arc::new(AtomicBool::new(false));
        let server_ready_clone = server_ready.clone();

        let socket_name_clone = socket_name.clone();
        let server_handle = thread::spawn(move || {
            let config = SocketServerConfig::with_path(&socket_name_clone);
            let server = match SocketServer::new(config) {
                Ok(s) => s,
                Err(e) => {
                    eprintln!("Failed to create server: {}", e);
                    return;
                }
            };

            server_ready_clone.store(true, Ordering::SeqCst);

            if let Ok(mut conn) = server.accept() {
                if let Ok(msg) = conn.recv() {
                    if msg.method() == Some("ping") {
                        conn.send(&Message::response(serde_json::json!({"pong": true})))
                            .ok();
                    }
                }
            }
        });

        let start = std::time::Instant::now();
        while !server_ready.load(Ordering::SeqCst) {
            if start.elapsed() > Duration::from_secs(5) {
                panic!("Server failed to start within timeout");
            }
            thread::sleep(Duration::from_millis(10));
        }

        thread::sleep(Duration::from_millis(100));

        let mut client = None;
        for _ in 0..10 {
            match SocketClient::connect(&socket_name) {
                Ok(c) => {
                    client = Some(c);
                    break;
                }
                Err(_) => {
                    thread::sleep(Duration::from_millis(50));
                }
            }
        }

        let mut client = client.expect("Failed to connect to server");
        let result = client.request("ping", serde_json::json!({})).unwrap();

        assert_eq!(result["pong"], true);

        server_handle.join().unwrap();
    }
}