pueue_lib/network/socket/
mod.rs

//! Socket handling is platform-specific code.
//!
//! The submodules of this module contain the different implementations for
//! each supported platform.
//! Depending on the compilation target, the matching platform implementation is compiled
//! and re-exported into this module's scope.
6
7#[cfg(not(target_os = "windows"))]
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use rustls::{ClientConfig, RootCertStore, pki_types::CertificateDer};
13use tokio::{
14    io::{AsyncRead, AsyncWrite},
15    net::TcpStream,
16};
17use tokio_rustls::TlsConnector;
18
19use crate::error::Error;
20#[cfg(feature = "settings")]
21use crate::{settings::Shared, tls::load_ca};
22
23/// Shared socket logic
24#[cfg_attr(not(target_os = "windows"), path = "unix.rs")]
25#[cfg_attr(target_os = "windows", path = "windows.rs")]
26mod platform;
27pub use platform::*;
28
29/// A new trait, which can be used to represent Unix- and TcpListeners. \
30/// This is necessary to easily write generic functions where both types can be used.
31#[async_trait]
32pub trait Listener: Sync + Send {
33    async fn accept<'a>(&'a self) -> Result<GenericStream, Error>;
34}
35
36/// Convenience type, so we don't have type write `Box<dyn Listener>` all the time.
37pub type GenericListener = Box<dyn Listener>;
38/// Convenience type, so we don't have type write `Box<dyn Stream>` all the time. \
39/// This also prevents name collisions, since `Stream` is imported in many preludes.
40pub type GenericStream = Box<dyn Stream>;
41
42/// Describe how a client should connect to the daemon.
43pub enum ConnectionSettings<'a> {
44    #[cfg(not(target_os = "windows"))]
45    UnixSocket { path: PathBuf },
46    TlsTcpSocket {
47        host: String,
48        port: String,
49        certificate: CertificateDer<'a>,
50    },
51}
52
/// Derive the client's [ConnectionSettings] from the [Shared] settings.
///
/// On non-Windows targets, a set `use_unix_socket` flag selects a Unix socket at
/// `unix_socket_path()`. Otherwise a TLS-over-TCP connection to `host`/`port` is described,
/// and the daemon certificate is loaded via `load_ca`.
///
/// # Errors
///
/// Propagates the error from `load_ca` if the daemon certificate cannot be loaded.
#[cfg(feature = "settings")]
impl TryFrom<Shared> for ConnectionSettings<'_> {
    type Error = crate::error::Error;

    fn try_from(value: Shared) -> Result<Self, Self::Error> {
        // `UnixSocket` does not exist on Windows, so this branch must be removed at compile
        // time there; a runtime `cfg!` check would not type-check.
        #[cfg(not(target_os = "windows"))]
        {
            if value.use_unix_socket {
                let path = value.unix_socket_path();
                return Ok(Self::UnixSocket { path });
            }
        }

        // Load the certificate before moving fields out of `value`.
        let certificate = load_ca(&value.daemon_cert())?;
        let Shared { host, port, .. } = value;

        Ok(Self::TlsTcpSocket {
            host,
            port,
            certificate,
        })
    }
}
77
78pub trait Stream: AsyncRead + AsyncWrite + Unpin + Send {}
79impl Stream for tokio_rustls::server::TlsStream<TcpStream> {}
80impl Stream for tokio_rustls::client::TlsStream<TcpStream> {}
81
82/// Initialize our client [TlsConnector]. \
83/// 1. Trust our own CA. ONLY our own CA.
84/// 2. Set the client certificate and key
85pub async fn get_tls_connector(cert: CertificateDer<'_>) -> Result<TlsConnector, Error> {
86    // Only trust server-certificates signed with our own CA.
87    let mut cert_store = RootCertStore::empty();
88    cert_store.add(cert).map_err(|err| {
89        Error::CertificateFailure(format!("Failed to build RootCertStore: {err}"))
90    })?;
91
92    let config: ClientConfig = ClientConfig::builder()
93        .with_root_certificates(cert_store)
94        .with_no_client_auth();
95
96    Ok(TlsConnector::from(Arc::new(config)))
97}