1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use crate::ssl::{SslConfig, Stream};
use crate::Error;
use rustls::{
    Certificate, Error as TLSError, PrivateKey, ServerConfig, ServerConnection, StreamOwned,
};
use std::fs;
use std::io::BufReader;
use std::net::TcpStream;
use std::sync::Arc;

/// Rustls wrapper.
///
/// Holds the server-side TLS configuration behind an `Arc`, so cloning an `SslImpl`
/// shares the same configuration instead of copying it.
#[derive(Clone)]
pub struct SslImpl {
    /// Shared rustls server configuration handed to every connection created by `accept`.
    tls_config: Arc<ServerConfig>,
}

// Lets the rustls-owned TLS stream be returned as the opaque `impl Stream` produced by
// `SslImpl::accept`. The body is empty, so `Stream` needs no methods implemented here.
impl Stream for StreamOwned<ServerConnection, TcpStream> {}

impl From<TLSError> for Error {
    fn from(error: TLSError) -> Self {
        let msg = error.to_string();
        Error::with_source(msg, error)
    }
}

impl SslImpl {
    /// Builds a TLS acceptor from `ssl_config`.
    ///
    /// * `SslConfig::Trusted` loads the certificates from `cert_path`, appends the ones
    ///   from `chain_path`, and pairs them with the key from `key_path`.
    /// * `SslConfig::SelfSigned` loads the certificates from `cert_path` and the key from
    ///   `key_path`.
    /// * Any other variant yields `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns an error if a certificate or key file cannot be opened or parsed, or if
    /// rustls rejects the certificate/key pair.
    pub fn setup(ssl_config: SslConfig) -> Result<Option<Self>, Error> {
        // First gather the certificate list and private key (if this variant has any);
        // the rustls configuration is then built in a single place below.
        let material = match ssl_config {
            SslConfig::Trusted {
                cert_path,
                key_path,
                chain_path,
            } => {
                let mut full_chain = load_certs(&cert_path)?;
                full_chain.extend(load_certs(&chain_path)?);
                Some((full_chain, load_key(&key_path)?))
            }
            SslConfig::SelfSigned {
                cert_path,
                key_path,
            } => Some((load_certs(&cert_path)?, load_key(&key_path)?)),
            _ => None,
        };

        match material {
            None => Ok(None),
            Some((certs, key)) => {
                let server_config = ServerConfig::builder()
                    .with_safe_defaults()
                    .with_no_client_auth()
                    .with_single_cert(certs, key)?;
                Ok(Some(SslImpl {
                    tls_config: Arc::new(server_config),
                }))
            }
        }
    }

    /// Wraps an accepted TCP connection in a server-side TLS stream.
    ///
    /// A new `ServerConnection` is created from the shared configuration and combined
    /// with `stream` into a `StreamOwned`.
    ///
    /// # Errors
    ///
    /// Returns an error if rustls cannot create the server connection.
    pub fn accept(&self, stream: TcpStream) -> Result<impl Stream, Error> {
        let connection = ServerConnection::new(Arc::clone(&self.tls_config))?;
        Ok(StreamOwned::new(connection, stream))
    }
}

fn load_certs(filename: &str) -> Result<Vec<Certificate>, Error> {
    let certfile = fs::File::open(filename)?;
    let mut reader = BufReader::new(certfile);
    let ret: Vec<Certificate> = rustls_pemfile::certs(&mut reader)
        .map_err(|_| Error::new("Unparseable certificates"))?
        .into_iter()
        .map(Certificate)
        .collect();
    Ok(ret)
}

/// Loads the private key stored in the PEM file `filename`.
///
/// The file is read into memory once and scanned for both RSA (PKCS#1) and PKCS#8
/// private keys. The first PKCS#8 key is preferred; the first RSA key is used only
/// when no PKCS#8 key is present.
///
/// # Errors
///
/// Returns an error if the file cannot be read, if either PEM scan fails
/// ("Unparseable RSA key" / "Unparseable PKCS8 key"), or if the file contains neither
/// kind of key ("No RSA or PKCS8 keys found").
fn load_key(filename: &str) -> Result<PrivateKey, Error> {
    // Read the file a single time so both scans see exactly the same bytes and the
    // file does not have to be opened twice.
    let pem = fs::read(filename)?;

    // `&[u8]` implements `BufRead`; each scan gets its own cursor over the buffer.
    let rsa_keys = rustls_pemfile::rsa_private_keys(&mut pem.as_slice())
        .map_err(|_| Error::new("Unparseable RSA key"))?;
    let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut pem.as_slice())
        .map_err(|_| Error::new("Unparseable PKCS8 key"))?;

    // Prefer to load pkcs8 keys. Moving the key out of its vector avoids making an
    // extra copy of the private key bytes.
    pkcs8_keys
        .into_iter()
        .next()
        .or_else(|| rsa_keys.into_iter().next())
        .map(PrivateKey)
        .ok_or_else(|| Error::new("No RSA or PKCS8 keys found"))
}