1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use mio::net::{UnixListener, UnixStream};
use mio::{Events, Interest, Poll, Token};
use serde::{Deserialize, Serialize};
use std::fs::{remove_file, set_permissions, Permissions};
use std::io::{self, Read, Write};
use std::marker::PhantomData;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::time::Duration;

/// A command that a client can send to an [`IpcServer`] for execution.
///
/// Commands and their responses cross the socket as JSON (via `serde_json`),
/// so both must be serializable and deserializable. The `Debug` bound is used
/// by `client_send` to print what was sent and what came back.
pub trait IpcServerCommand: Serialize + for<'a> Deserialize<'a> + std::fmt::Debug {
    /// Value produced by [`process`](Self::process) and written back to the
    /// client that sent the command.
    type Response: Serialize + for<'a> Deserialize<'a> + std::fmt::Debug;
    /// Server-side state made available to the command while it runs. It is
    /// supplied by whoever calls `IpcServer::handle_new_messages`; the
    /// lifetime parameter lets the context borrow from that caller.
    type Context<'a>;

    /// Consumes the command, executes it against `context`, and returns the
    /// response to send back to the client.
    fn process<'a, 'b>(self, context: &'b mut Self::Context<'a>) -> Self::Response;
}

/// Poll-driven IPC server listening on a Unix domain socket for commands of
/// type `C`.
///
/// The server owns no thread: work only happens when the owner calls
/// `handle_new_messages`.
pub struct IpcServer<C: IpcServerCommand> {
    /// Listening socket; registered with `poll` under `Token(0)`.
    listener: UnixListener,
    /// Readiness poller used to detect incoming connections.
    poll: Poll,
    /// Event buffer reused across calls to `poll`.
    events: Events,
    /// Zero-sized marker tying the server to its command type.
    _command: PhantomData<C>,
}

impl<C: IpcServerCommand> IpcServer<C> {
    /// Initialize a new IpcServer. Recall that there is no dedicated server
    /// thread. You must call `handle_new_messages` to poll for and process
    /// new messages
    pub fn new(socket_path: &str) -> io::Result<IpcServer<C>> {
        if Path::new(socket_path).exists() {
            remove_file(socket_path)?;
        }

        let mut listener = UnixListener::bind(socket_path)?;
        // Restrict permissions to owner read/write only
        set_permissions(socket_path, Permissions::from_mode(0o600))?;

        let poll = Poll::new()?;
        let events = Events::with_capacity(128);

        poll.registry()
            .register(&mut listener, Token(0), Interest::READABLE)?;

        Ok(IpcServer::<C> {
            listener,
            poll,
            events,
            _command: Default::default(),
        })
    }

    /// Polls for new messages from any clients, and processes and responds.
    pub fn handle_new_messages<'a>(&mut self, mut context: C::Context<'a>) -> io::Result<()> {
        self.poll
            .poll(&mut self.events, Some(Duration::from_millis(1)))?;

        for event in self.events.iter() {
            match event.token() {
                Token(0) => loop {
                    match self.listener.accept() {
                        Ok((mut stream, _)) => {
                            let mut buffer = [0; 1024];
                            match stream.read(&mut buffer) {
                                Ok(bytes_read) => {
                                    let payload = String::from_utf8_lossy(&buffer[..bytes_read]);
                                    let command =
                                        serde_json::from_str::<C>(&payload).map_err(|e| {
                                            io::Error::new(io::ErrorKind::InvalidData, e)
                                        })?;
                                    self.process_command(command, &mut context, &mut stream)?;
                                }
                                Err(err) => {
                                    eprintln!("Failed to read from connection: {}", err);
                                    break;
                                }
                            }
                        }
                        Err(ref err) if would_block(err) => break,
                        Err(err) => {
                            eprintln!("Failed to accept connection: {}", err);
                            break;
                        }
                    }
                },
                _ => unreachable!(),
            }
        }

        Ok(())
    }

    #[inline(always)]
    fn process_command<'a, 'b>(
        &self,
        command: C,
        context: &'b mut C::Context<'a>,
        stream: &mut UnixStream,
    ) -> io::Result<()> {
        let response = command.process(context);
        loop {
            match serde_json::to_writer(&mut *stream, &response)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            {
                Ok(()) => return Ok(()),
                Err(ref err) if would_block(err) => {
                    // Spin loop is okay here.
                    // IPC server is not intended for large payloads or high volumes.
                    std::hint::spin_loop();
                    continue;
                }
                e => return e,
            }
        }
    }
}

/// Reports whether `err` has kind `WouldBlock`, i.e. a non-blocking operation
/// could not complete immediately and should be retried later.
fn would_block(err: &std::io::Error) -> bool {
    matches!(err.kind(), std::io::ErrorKind::WouldBlock)
}

/// Serialize and write the `command` provided to the `UnixStream` at the
/// `socket_path` provided. If there is an active `IpcServer`, it will receive
/// and process this command upon polling.
pub fn client_send<C: IpcServerCommand>(command: &C, socket_path: &str) {
    let mut stream = UnixStream::connect(socket_path).unwrap();
    let payload = serde_json::to_string(command).unwrap();
    stream.write_all(payload.as_bytes()).unwrap();
    println!("sent command: {:?}", command);

    loop {
        let mut buffer = [0; 1024];
        match stream.read(&mut buffer) {
            Ok(bytes_read) => {
                let response_str = String::from_utf8_lossy(&buffer[..bytes_read]);
                if let Ok(response) = serde_json::from_str::<C::Response>(&response_str) {
                    println!("received response: {:?}", response);
                } else {
                    eprintln!("failed to parse response: {}", response_str);
                }
                return;
            }
            Err(ref err) if would_block(&err) => {
                #[allow(deprecated)]
                std::thread::sleep_ms(1);
                continue;
            }
            Err(err) => {
                eprintln!("failed to read response: {} {}", err, err.kind());
                return;
            }
        }
    }
}