pueue_lib/network/socket/
unix.rs

1use std::convert::TryFrom;
2
3use async_trait::async_trait;
4use rustls::pki_types::ServerName;
5use tokio::net::{TcpStream, UnixListener, UnixStream};
6
7use super::{ConnectionSettings, GenericStream, Listener, Stream, get_tls_connector};
8use crate::error::Error;
9
10#[async_trait]
11impl Listener for UnixListener {
12    async fn accept<'a>(&'a self) -> Result<GenericStream, Error> {
13        let (stream, _) = self
14            .accept()
15            .await
16            .map_err(|err| Error::IoError("accepting new unix connection.".to_string(), err))?;
17        Ok(Box::new(stream))
18    }
19}
20
/// Implement [`Stream`] for [`UnixStream`]. \
/// `Stream` is the trait used to represent both Unix- and TLS-encrypted TCP streams,
/// which is what allows generic functions to work with either type. \
/// The impl body is empty: no methods are defined here.
impl Stream for UnixStream {}
24
25/// Get a new stream for the client. \
26/// This can either be a UnixStream or a Tls encrypted TCPStream, depending on the parameters.
27pub async fn get_client_stream(settings: ConnectionSettings<'_>) -> Result<GenericStream, Error> {
28    match settings {
29        // Create a unix socket
30        ConnectionSettings::UnixSocket { path } => {
31            let stream = UnixStream::connect(&path).await.map_err(|err| {
32                Error::IoPathError(path, "connecting to daemon. Did you start it?", err)
33            })?;
34
35            Ok(Box::new(stream))
36        }
37        // Connect to the daemon via TCP
38        ConnectionSettings::TlsTcpSocket {
39            host,
40            port,
41            certificate,
42        } => {
43            let address = format!("{host}:{port}");
44            let tcp_stream = TcpStream::connect(&address).await.map_err(|_| {
45                Error::Connection(format!(
46                    "Failed to connect to the daemon on {address}. Did you start it?"
47                ))
48            })?;
49
50            // Get the configured rustls TlsConnector
51            let tls_connector = get_tls_connector(certificate).await.map_err(|err| {
52                Error::Connection(format!("Failed to initialize tls connector:\n{err}."))
53            })?;
54
55            // Initialize the TLS layer
56            let stream = tls_connector
57                .connect(ServerName::try_from("pueue.local").unwrap(), tcp_stream)
58                .await
59                .map_err(|err| Error::Connection(format!("Failed to initialize tls:\n{err}.")))?;
60
61            Ok(Box::new(stream))
62        }
63    }
64}