use mio::net::{UnixListener, UnixStream};
use mio::{Events, Interest, Poll, Token};
use serde::{Deserialize, Serialize};
use std::fs::{remove_file, set_permissions, Permissions};
use std::io::{self, Read, Write};
use std::marker::PhantomData;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::time::{Duration, Instant};
10
11pub trait IpcServerCommand: Serialize + for<'a> Deserialize<'a> + std::fmt::Debug {
12 type Response: Serialize + for<'a> Deserialize<'a> + std::fmt::Debug;
13 type Context<'a>;
14
15 fn process<'a, 'b>(self, context: &'b mut Self::Context<'a>) -> Self::Response;
16}
17
/// Non-blocking IPC server that receives `C` commands over a Unix-domain socket.
///
/// Each accepted connection is expected to carry one bincode-encoded `C`. The
/// server runs it and answers on the same connection with the bincode-encoded
/// `C::Response`.
pub struct IpcServer<C: IpcServerCommand> {
    // Listening socket, registered with `poll` under `Token(0)`.
    // NOTE: field order is drop order; the listener is closed first.
    listener: UnixListener,
    // mio poller used to check the listener for pending connections.
    poll: Poll,
    // Event buffer filled by each `poll` call and reused across calls.
    events: Events,
    // Ties the server to its command type without storing a `C`.
    _command: PhantomData<C>,
}
24
25impl<C: IpcServerCommand> IpcServer<C> {
26 pub fn new(socket_path: &str) -> io::Result<IpcServer<C>> {
30 if Path::new(socket_path).exists() {
31 remove_file(socket_path)?;
32 }
33
34 let mut listener = UnixListener::bind(socket_path)?;
35 set_permissions(socket_path, Permissions::from_mode(0o600))?;
37
38 let poll = Poll::new()?;
39 let events = Events::with_capacity(128);
40
41 poll.registry()
42 .register(&mut listener, Token(0), Interest::READABLE)?;
43
44 Ok(IpcServer::<C> {
45 listener,
46 poll,
47 events,
48 _command: Default::default(),
49 })
50 }
51
52 pub fn handle_new_messages<'a>(&mut self, mut context: C::Context<'a>) -> io::Result<()> {
54 self.poll
55 .poll(&mut self.events, Some(Duration::from_nanos(10)))?;
56
57 for event in self.events.iter() {
58 match event.token() {
59 Token(0) => loop {
60 match self.listener.accept() {
61 Ok((mut stream, _)) => {
62 let mut buffer = [0; 1024];
63 match stream.read(&mut buffer) {
64 Ok(bytes_read) => {
65 let command = bincode::deserialize::<C>(&buffer[..bytes_read])
66 .map_err(|e| {
67 io::Error::new(io::ErrorKind::InvalidData, e)
68 })?;
69 self.process_command(command, &mut context, &mut stream)?;
70 }
71 Err(err) => {
72 eprintln!("Failed to read from connection: {}", err);
73 break;
74 }
75 }
76 }
77 Err(ref err) if would_block(err) => break,
78 Err(err) => {
79 eprintln!("Failed to accept connection: {}", err);
80 break;
81 }
82 }
83 },
84 _ => unreachable!(),
85 }
86 }
87
88 Ok(())
89 }
90
91 #[inline(always)]
92 fn process_command<'a, 'b>(
93 &self,
94 command: C,
95 context: &'b mut C::Context<'a>,
96 stream: &mut UnixStream,
97 ) -> io::Result<()> {
98 let response = command.process(context);
99 loop {
100 match bincode::serialize_into(&mut *stream, &response)
101 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
102 {
103 Ok(()) => return Ok(()),
104 Err(ref err) if would_block(err) => {
105 std::hint::spin_loop();
108 continue;
109 }
110 e => return e,
111 }
112 }
113 }
114}
115
/// Reports whether `err` is the non-blocking "not ready yet, try again" condition.
fn would_block(err: &std::io::Error) -> bool {
    matches!(err.kind(), std::io::ErrorKind::WouldBlock)
}
119
120pub fn client_send<C: IpcServerCommand>(command: &C, socket_path: &str) {
124 let mut stream = UnixStream::connect(socket_path).unwrap();
125 bincode::serialize_into(&mut stream, command).unwrap();
126 println!("sent command: {:?}", command);
127
128 loop {
129 let mut buffer = [0; 1024];
130 match stream.read(&mut buffer) {
131 Ok(bytes_read) => {
132 if let Ok(response) = bincode::deserialize::<C::Response>(&buffer[..bytes_read]) {
133 println!("received response: {:?}", response);
134 } else {
135 eprintln!("failed to parse response: {:?}", &buffer[..bytes_read]);
136 }
137 return;
138 }
139 Err(ref err) if would_block(&err) => {
140 #[allow(deprecated)]
141 std::thread::sleep_ms(1);
142 continue;
143 }
144 Err(err) => {
145 eprintln!("failed to read response: {} {}", err, err.kind());
146 return;
147 }
148 }
149 }
150}