1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#![feature(external_doc)]
#![deny(missing_docs)]
#![doc(include = "../README.md")]

#[macro_use]
extern crate log;

use std::fs;
use std::io;
use std::io::Error as IoError;
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::PathBuf;
use std::sync::{mpsc, Arc, RwLock};
use std::thread;

#[derive(Debug, Clone)]
/// Possible states of the unix domain socket.
///
/// The main loop will continue while the state is [`State::Listening`]
pub enum State {
    /// No [`UnixDomainSocket::bind`] was performed yet.
    Closed = 0x00,
    /// The [`UnixDomainSocket`] is currently listening.
    Listening = 0x01,
    /// The [`UnixDomainSocket`] should interrupt the listening after the next iteration.
    ShouldQuit = 0x02,
}

#[derive(Debug, Clone)]
/// Possible messages that can be sent through the [`mpsc`] channel.
pub enum Message {
    /// Request the [`UnixDomainSocket`] to change its state.
    ChangeState(State),
}

/// Boilerplate for [`UnixListener`].
///
/// Will be converted from an implementation of [`ToString`], or will receive a [`PathBuf`] via the
/// constructor [`UnixDomainSocket::new`].
pub struct UnixDomainSocket {
    /// The OS path for the socket.
    path: PathBuf,
    /// The current state.
    ///
    /// Since its [`Arc`]/[`RwLock`] protected, we don't need a mutable reference to update the
    /// state.
    state: Arc<RwLock<State>>,
    /// [`mpsc`] channel to send messages.
    tx: mpsc::Sender<Message>,
    /// [`mpsc`] channel to receive messages.
    rx: mpsc::Receiver<Message>,
}

impl UnixDomainSocket {
    /// Default constructor. Receives a [`PathBuf`] and never fails, because no
    /// [`UnixListener::bind`] is performed.
    pub fn new(path: PathBuf) -> Self {
        let state = State::Closed;
        let state = RwLock::new(state);
        let state = Arc::new(state);

        let (tx, rx) = mpsc::channel();

        UnixDomainSocket {
            path,
            state,
            tx,
            rx,
        }
    }

    /// Returns the current [`State`] of the socket.
    ///
    /// It's fallible in case the inner [`RwLock`] is poisoned.
    ///
    /// Also, it's blocking, if there is a concurrent read/write for the state.
    pub fn get_state(&self) -> Result<State, IoError> {
        self.state
            .read()
            .map_err(|e| IoError::new(io::ErrorKind::Other, e.to_string()))
            .and_then(|state| Ok(state.clone()))
    }

    /// Set the current [`State`] of the socket.
    ///
    /// Same constraints as [`UnixDomainSocket::get_state`]
    pub fn set_state(&self, state: State) -> Result<(), IoError> {
        self.state
            .write()
            .map_err(|e| IoError::new(io::ErrorKind::Other, e.to_string()))
            .and_then(|mut self_state| Ok(*self_state = state))
    }

    /// Check if the current state is [`State::Listening`].
    ///
    /// Same constraints as [`UnixDomainSocket::get_state`]
    pub fn is_listening(&self) -> Result<bool, IoError> {
        self.get_state().and_then(|state| Ok(state as u8 == 0x01))
    }

    /// Will be called by the [`mpsc`] channel to process any incoming [`Message`].
    pub fn receive_message(&self, message: Message) -> Result<(), IoError> {
        match message {
            Message::ChangeState(state) => self.set_state(state),
        }
    }

    /// Will remove the [`UnixDomainSocket::path`], if it exists, so it cant bind properly to that
    /// location.
    ///
    /// While the state is [`State::Listening`], listen for incoming data.
    ///
    /// Should it receive a [`Message`] with [`State::ShouldQuit`], it will interrupt the loop only after the next
    /// iteration, because the message is handled asynchronously.
    ///
    /// If the socket fails, the loop wont be interrupted, and the error will be reported to the
    /// log facade.
    ///
    /// Minimal example:
    /// ```
    /// use dusk_uds::UnixDomainSocket;
    ///
    /// UnixDomainSocket::from("/dev/null").bind(move |_stream, _sender| {
    ///     // Code goes here
    /// }).unwrap_or(());
    /// ```
    pub fn bind<F>(&self, handler: F) -> Result<(), IoError>
    where
        F: FnOnce(UnixStream, mpsc::Sender<Message>),
        F: Send + 'static,
        F: Copy + 'static,
    {
        if self.path.as_path().exists() {
            fs::remove_file(self.path.as_path())?;
        }

        let path = self
            .path
            .to_str()
            .map(|p| Ok(p))
            .unwrap_or(Err(IoError::new(
                io::ErrorKind::Other,
                "Invalid path returned by the buffer",
            )))?;

        UnixListener::bind(path).and_then(|listener| {
            info!("UnixDomainSocket bound on {}", path);
            self.set_state(State::Listening)?;

            for socket in listener.incoming() {
                let tx = self.tx.clone();

                socket
                    .and_then(|s| Ok(thread::spawn(move || handler(s, tx))))
                    .map(|_| ())
                    .map_err(|e| error!("{}", e))
                    .unwrap_or(());

                for msg in self.rx.try_iter() {
                    self.receive_message(msg)?;
                }

                if !self.is_listening()? {
                    break;
                }

                debug!("Thread spawned");
            }

            Ok(info!("UnixDomainSocket closed"))
        })
    }
}

impl<T: ToString> From<T> for UnixDomainSocket {
    fn from(path: T) -> Self {
        let path = PathBuf::from(path.to_string());
        UnixDomainSocket::new(path)
    }
}