gistit_ipc/
lib.rs

1//
2//   ________.__          __  .__  __
3//  /  _____/|__| _______/  |_|__|/  |_
4// /   \  ___|  |/  ___/\   __\  \   __\
5// \    \_\  \  |\___ \  |  | |  ||  |
6//  \______  /__/____  > |__| |__||__|
7//         \/        \/
8//
9#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
10#![allow(clippy::module_name_repetitions)]
11#![cfg_attr(
12    test,
13    allow(
14        unused,
15        clippy::all,
16        clippy::pedantic,
17        clippy::nursery,
18        clippy::dbg_macro,
19        clippy::unwrap_used,
20        clippy::missing_docs_in_private_items,
21    )
22)]
//! Inter-process communication (IPC) between gistit-daemon and gistit-cli over Unix datagram
//! sockets.
//!
//! TODO: a TCP socket transport is not implemented yet.
25
26use std::fs::{metadata, remove_file};
27use std::marker::PhantomData;
28use std::path::{Path, PathBuf};
29use std::time::Instant;
30use tokio::net::UnixDatagram;
31
32use gistit_proto::bytes::BytesMut;
33use gistit_proto::prost::{self, Message};
34use gistit_proto::Instruction;
35
36pub type Result<T> = std::result::Result<T, Error>;
37
38const NAMED_SOCKET_0: &str = "gistit-0";
39const NAMED_SOCKET_1: &str = "gistit-1";
40
41const READBUF_SIZE: usize = 60_000; // A bit bigger than 50kb because encoding
42const CONNECT_TIMEOUT_SECS: u64 = 3;
43
/// Marker trait implemented by the two type-level tags that select a [`Bridge`] end.
pub trait SockEnd {}

/// Tag for the end created by `server`, which binds the `gistit-0` socket.
#[derive(Debug)]
pub struct Server;

/// Tag for the end created by `client`, which binds the `gistit-1` socket.
#[derive(Debug)]
pub struct Client;

impl SockEnd for Server {}
impl SockEnd for Client {}
53
54#[derive(Debug)]
55pub struct Bridge<T: SockEnd> {
56    pub sock_0: UnixDatagram,
57    pub sock_1: UnixDatagram,
58    base: PathBuf,
59    __marker_t: PhantomData<T>,
60}
61
62/// Recv from [`NAMED_SOCKET_0`] and send to [`NAMED_SOCKET_1`]
63/// The owner of `sock_0`
64///
65/// # Errors
66///
67/// Fails if can't spawn a named socket
68pub fn server(base: &Path) -> Result<Bridge<Server>> {
69    let sockpath_0 = &base.join(NAMED_SOCKET_0);
70
71    if metadata(sockpath_0).is_ok() {
72        remove_file(sockpath_0)?;
73    }
74
75    log::trace!("Bind sock_0 (server) at {:?}", sockpath_0);
76    let sock_0 = UnixDatagram::bind(sockpath_0)?;
77
78    Ok(Bridge {
79        sock_0,
80        sock_1: UnixDatagram::unbound()?,
81        base: base.to_path_buf(),
82        __marker_t: PhantomData,
83    })
84}
85
86/// Recv from [`NAMED_SOCKET_1`] and send to [`NAMED_SOCKET_0`]
87/// The owner of `sock_1`
88///
89/// # Errors
90///
91/// Fails if can't spawn a named socket
92pub fn client(base: &Path) -> Result<Bridge<Client>> {
93    let sockpath_1 = &base.join(NAMED_SOCKET_1);
94
95    if metadata(sockpath_1).is_ok() {
96        remove_file(sockpath_1)?;
97    }
98
99    log::trace!("Bind sock_1 (client) at {:?}", sockpath_1);
100    let sock_1 = UnixDatagram::bind(sockpath_1)?;
101
102    Ok(Bridge {
103        sock_0: UnixDatagram::unbound()?,
104        sock_1,
105        base: base.to_path_buf(),
106        __marker_t: PhantomData,
107    })
108}
109
110fn __alive(base: &Path, dgram: &UnixDatagram, sock_name: &str) -> bool {
111    !matches!(dgram.connect(base.join(sock_name)), Err(_))
112}
113
114fn __connect_blocking(base: &Path, dgram: &UnixDatagram, sock_name: &str) -> Result<()> {
115    let earlier = Instant::now();
116    while let Err(err) = dgram.connect(base.join(sock_name)) {
117        if Instant::now().duration_since(earlier).as_secs() > CONNECT_TIMEOUT_SECS {
118            return Err(err.into());
119        }
120    }
121
122    log::trace!("Connecting to {:?}", sock_name);
123    Ok(())
124}
125
126impl Bridge<Server> {
127    pub fn alive(&self) -> bool {
128        __alive(&self.base, &self.sock_1, NAMED_SOCKET_1)
129    }
130
131    /// Connect to the other end
132    ///
133    /// # Errors
134    ///
135    /// Inherits errors of [`__connect_blocking`]
136    pub fn connect_blocking(&mut self) -> Result<()> {
137        __connect_blocking(&self.base, &self.sock_1, NAMED_SOCKET_1)
138    }
139
140    /// Send bincode serialized data through the pipe
141    ///
142    /// # Errors
143    ///
144    /// Fails if the socket is not alive
145    pub async fn send(&self, instruction: Instruction) -> Result<()> {
146        let mut buf = BytesMut::with_capacity(READBUF_SIZE);
147        instruction.encode(&mut buf)?;
148        log::trace!("Sending to client {} bytes", buf.len());
149        self.sock_1.send(&buf).await?;
150        Ok(())
151    }
152
153    /// Attempts to receive serialized data from the pipe
154    ///
155    /// # Errors
156    ///
157    /// Fails if the socket is not alive
158    pub async fn recv(&self) -> Result<Instruction> {
159        let mut buf = vec![0u8; READBUF_SIZE];
160        let read = self.sock_0.recv(&mut buf).await?;
161        buf.truncate(read);
162        let target = Instruction::decode(&*buf)?;
163        Ok(target)
164    }
165}
166
167impl Bridge<Client> {
168    pub fn alive(&self) -> bool {
169        __alive(&self.base, &self.sock_0, NAMED_SOCKET_0)
170    }
171
172    /// Connect to the other end
173    ///
174    /// # Errors
175    ///
176    /// Inherits errors of [`__connect_blocking`]
177    pub fn connect_blocking(&mut self) -> Result<()> {
178        __connect_blocking(&self.base, &self.sock_0, NAMED_SOCKET_0)
179    }
180
181    /// Send bincode serialized data through the pipe
182    ///
183    /// # Errors
184    ///
185    /// Fails if the socket is not alive
186    pub async fn send(&self, instruction: Instruction) -> Result<()> {
187        let mut buf = BytesMut::with_capacity(READBUF_SIZE);
188        instruction.encode(&mut buf)?;
189        log::trace!("Sending to server {} bytes", buf.len());
190        self.sock_0.send(&*buf).await?;
191        Ok(())
192    }
193
194    /// Attempts to receive serialized data from the pipe
195    ///
196    /// # Errors
197    ///
198    /// Fails if the socket is not alive
199    pub async fn recv(&self) -> Result<Instruction> {
200        let mut buf = vec![0u8; READBUF_SIZE];
201        let read = self.sock_1.recv(&mut buf).await?;
202        buf.truncate(read);
203        let target = Instruction::decode(&*buf)?;
204        Ok(target)
205    }
206}
207
208#[derive(thiserror::Error, Debug)]
209pub enum Error {
210    #[error("io error {0}")]
211    IO(#[from] std::io::Error),
212
213    #[error("decode error {0}")]
214    Decode(#[from] prost::DecodeError),
215
216    #[error("encode error {0}")]
217    Encode(#[from] prost::EncodeError),
218}
219
// Tests that bind real named sockets inside a fresh temporary directory per test.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::prelude::*;
    use std::sync::Arc;

    // Two distinct instructions, so ordering mistakes show up in the assertions.
    pub fn test_instruction_1() -> Instruction {
        Instruction::request_status()
    }

    pub fn test_instruction_2() -> Instruction {
        Instruction::request_shutdown()
    }

    // Creating each end must leave its named socket file in the base directory.
    #[tokio::test]
    async fn ipc_named_socket_spawn() {
        let tmp = assert_fs::TempDir::new().unwrap();
        // `let _` drops each `Bridge` immediately; the asserts below rely on the socket files
        // presumably staying on disk after the sockets are dropped.
        let _ = server(&tmp).unwrap();
        let _ = client(&tmp).unwrap();

        assert!(tmp.child("gistit-0").exists());
        assert!(tmp.child("gistit-1").exists());
    }

    // With both ends bound, each side's `alive` probe (which also connects) must succeed.
    #[tokio::test]
    async fn ipc_socket_spawn_is_alive() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let server = server(&tmp).unwrap();
        let client = client(&tmp).unwrap();

        assert!(server.alive());
        assert!(client.alive());
    }

    // Client -> server: datagrams must arrive in the order they were sent.
    #[tokio::test]
    async fn ipc_socket_server_recv_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();

        client.connect_blocking().unwrap();

        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();

        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    // Server -> client: datagrams must arrive in the order they were sent.
    #[tokio::test]
    async fn ipc_socket_client_recv_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let client = client(&tmp).unwrap();

        server.connect_blocking().unwrap();

        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();

        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
    }

    // Both directions in use at once: each direction keeps its own queue and ordering.
    #[tokio::test]
    async fn ipc_socket_alternate_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();

        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();

        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();

        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();

        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    // Same as above, run twice on the same connected pair, to check the sockets stay usable.
    #[tokio::test]
    async fn ipc_socket_alternate_traffic_rerun() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();

        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();

        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();

        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();

        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());

        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();

        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();

        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    // Shares both ends across spawned tasks that keep sending while this task receives.
    // NOTE(review): each spawned task loops forever and its JoinHandle is discarded, so the
    // tasks only stop when the test runtime shuts down. With up to 8 such tasks writing to the
    // same sockets, the strict 1-then-2 ordering asserted below depends on how the scheduler
    // interleaves them. TODO confirm this cannot flake, or bound the sends per task and
    // join/abort the tasks.
    #[tokio::test]
    async fn ipc_socket_traffic_under_load() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();

        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();

        let server = Arc::new(server);
        let client = Arc::new(client);

        for _ in 0..8 {
            let s = server.clone();
            let c = client.clone();

            tokio::spawn(async move {
                loop {
                    c.send(test_instruction_1()).await.unwrap();
                    c.send(test_instruction_2()).await.unwrap();

                    s.send(test_instruction_1()).await.unwrap();
                    s.send(test_instruction_2()).await.unwrap();
                }
            });

            assert_eq!(client.recv().await.unwrap(), test_instruction_1());
            assert_eq!(server.recv().await.unwrap(), test_instruction_1());
            assert_eq!(client.recv().await.unwrap(), test_instruction_2());
            assert_eq!(server.recv().await.unwrap(), test_instruction_2());
        }
    }
}