1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Implementation of Unix domain socket transport (Unix only)

use crate::transport::{Request, Subscribe, TransportError};
use log::{debug, error, trace};
use std::io::{BufRead, BufReader, Write};
use std::net::Shutdown;
use std::os::unix::net::UnixStream;
use std::str;
use thiserror::Error;

/// An interprocess connection using a unix domain socket (Unix only)
pub struct Uds {
    pub path: String,
    read_stream: BufReader<UnixStream>,
    write_stream: UnixStream,
}

impl Uds {
    pub(crate) fn new(path: String) -> Result<Self, UdsError> {
        debug!("Opening connection to unix domain socket: {}", &path);
        let write_stream = UnixStream::connect(path.clone()).map_err(UdsError::UdsInit)?;
        let read_stream = write_stream.try_clone().map_err(UdsError::UdsInit)?;
        Ok(Self {
            path,
            read_stream: BufReader::new(read_stream),
            write_stream,
        })
    }

    fn read_json(&mut self) -> Result<String, UdsError> {
        let mut buffer = Vec::<u8>::new();
        loop {
            let _read_bytes = self
                .read_stream
                .read_until(b'}', &mut buffer)
                .map_err(UdsError::Read)?;
            let utf8_slice = str::from_utf8(&buffer).map_err(UdsError::Utf8)?;
            if utf8_slice.matches('{').count() == utf8_slice.matches('}').count() {
                trace!("Reading from Unix domain socket: {}", utf8_slice);
                break Ok(utf8_slice.to_string());
            }
        }
    }

    fn write(&mut self, message: String) -> Result<(), UdsError> {
        trace!("Writing to Unix domain socket: {}", &message);
        let _write = self
            .write_stream
            .write_all(message.as_bytes())
            .map_err(UdsError::Write)?;
        let _flush = self.write_stream.flush().map_err(UdsError::Write)?;
        Ok(())
    }
}

impl Request for Uds {
    fn request(&mut self, cmd: String) -> Result<String, TransportError> {
        let _write = self.write(cmd)?;
        self.read_json().map_err(TransportError::UdsError)
    }
}

impl Subscribe for Uds {
    fn read_next(&mut self) -> Result<String, TransportError> {
        self.read_json().map_err(TransportError::UdsError)
    }

    fn fork(&self) -> Result<Self, TransportError>
    where
        Self: Sized,
    {
        Self::new(self.path.clone()).map_err(TransportError::from)
    }
}

impl Drop for Uds {
    fn drop(&mut self) {
        debug!("Closing unix domain socket connection");
        let close = self.write_stream.shutdown(Shutdown::Both);
        if let Err(err) = close {
            error!("{}", err);
        }
    }
}

/// An error type collecting what can go wrong with unix domain sockets
#[derive(Debug, Error)]
pub enum UdsError {
    #[error("Unix Domain Socket Init Error: {0}")]
    UdsInit(std::io::Error),
    #[error("Unix Domain Socket Close Error: {0}")]
    UdsClose(std::io::Error),
    #[error("Unix Domain Socket Read Error: {0}")]
    Read(std::io::Error),
    #[error("Unix Domain Socket Utf8 Error: {0}")]
    Utf8(std::str::Utf8Error),
    #[error("Unix Domain Socket Write Error: {0}")]
    Write(std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::os::unix::net::UnixListener;

    const TEST_IPC: &str = "/tmp/ethane_test.ipc";

    fn spawn_test_uds_server() {
        let unix_listener = UnixListener::bind(TEST_IPC).unwrap();
        std::thread::spawn(move || {
            for incoming in unix_listener.incoming() {
                match incoming {
                    Ok(mut stream) => {
                        let mut buffer = Vec::<u8>::new();
                        let mut reader = BufReader::new(&mut stream);

                        let _read = reader.read_until(b'}', &mut buffer).unwrap();
                        let _write = (&mut stream).write_all(buffer.as_slice()).unwrap();
                        let _flush = (&mut stream).flush().unwrap();
                    }
                    Err(err) => panic!(err),
                }
            }
        });
    }

    #[test]
    fn test_uds() {
        spawn_test_uds_server();
        let message = "{\"test\": true}";
        let mut uds = Uds::new(TEST_IPC.to_string()).unwrap();
        let _write = uds.write(String::from(message)).unwrap();

        let _delete_socket = std::fs::remove_file(TEST_IPC).unwrap();
        assert_eq!(uds.read_json().unwrap(), message);
    }
}