// libp2prs_websocket/tls.rs

// Copyright 2019 Parity Technologies (UK) Ltd.
// Copyright 2020 Netwarps Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
21
22use async_tls::{TlsAcceptor, TlsConnector};
23use std::{fmt, io, sync::Arc};
24
25/// TLS configuration.
26#[derive(Clone)]
27pub struct Config {
28    pub(crate) client: TlsConnector,
29    pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        f.write_str("Config")
35    }
36}
37
38/// Private key, DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
39#[derive(Clone)]
40pub struct PrivateKey(rustls::PrivateKey);
41
42impl PrivateKey {
43    /// Assert the given bytes are DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
44    pub fn new(bytes: Vec<u8>) -> Self {
45        PrivateKey(rustls::PrivateKey(bytes))
46    }
47}
48
49/// Certificate, DER-encoded X.509 format.
50#[derive(Debug, Clone)]
51pub struct Certificate(rustls::Certificate);
52
53impl Certificate {
54    /// Assert the given bytes are in DER-encoded X.509 format.
55    pub fn new(bytes: Vec<u8>) -> Self {
56        Certificate(rustls::Certificate(bytes))
57    }
58}
59
60impl Config {
61    /// Create a new TLS configuration with the given server key and certificate chain.
62    pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
63    where
64        I: IntoIterator<Item = Certificate>,
65    {
66        let mut builder = Config::builder();
67        builder.server(key, certs)?;
68        Ok(builder.finish())
69    }
70
71    /// Create a client-only configuration.
72    pub fn client() -> Self {
73        Config {
74            client: Arc::new(client_config()).into(),
75            server: None,
76        }
77    }
78
79    /// Create a new TLS configuration builder.
80    pub fn builder() -> Builder {
81        Builder {
82            client: client_config(),
83            server: None,
84        }
85    }
86}
87
88/// Setup the rustls client configuration.
89fn client_config() -> rustls::ClientConfig {
90    let mut client = rustls::ClientConfig::new();
91    client.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
92    client
93}
94
95/// TLS configuration builder.
96#[derive(Clone)]
97pub struct Builder {
98    client: rustls::ClientConfig,
99    server: Option<rustls::ServerConfig>,
100}
101
102impl Builder {
103    /// Set server key and certificate chain.
104    pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
105    where
106        I: IntoIterator<Item = Certificate>,
107    {
108        let mut server = rustls::ServerConfig::new(rustls::NoClientAuth::new());
109        let certs = certs.into_iter().map(|c| c.0).collect();
110        server.set_single_cert(certs, key.0).map_err(|e| Error::Tls(Box::new(e)))?;
111        self.server = Some(server);
112        Ok(self)
113    }
114
115    /// Add an additional trust anchor.
116    pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
117        self.client.root_store.add(&cert.0).map_err(|e| Error::Tls(Box::new(e)))?;
118        Ok(self)
119    }
120
121    /// Finish configuration.
122    pub fn finish(self) -> Config {
123        Config {
124            client: Arc::new(self.client).into(),
125            server: self.server.map(|s| Arc::new(s).into()),
126        }
127    }
128}
129
130pub(crate) fn dns_name_ref(name: &str) -> Result<webpki::DNSNameRef<'_>, Error> {
131    webpki::DNSNameRef::try_from_ascii_str(name).map_err(|_| Error::InvalidDnsName(name.into()))
132}
133
// Error //////////////////////////////////////////////////////////////////////////////////////////
135
/// Errors raised by the TLS configuration helpers in this module.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// A wrapped `std::io::Error`.
    Io(io::Error),
    /// An error reported by the TLS libraries, boxed as a trait object.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// A name that could not be parsed as a DNS name; holds the rejected input.
    InvalidDnsName(String),
}
147
148impl fmt::Display for Error {
149    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
150        match self {
151            Error::Io(e) => write!(f, "i/o error: {}", e),
152            Error::Tls(e) => write!(f, "tls error: {}", e),
153            Error::InvalidDnsName(n) => write!(f, "invalid DNS name: {}", n),
154        }
155    }
156}
157
158impl std::error::Error for Error {
159    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
160        match self {
161            Error::Io(e) => Some(e),
162            Error::Tls(e) => Some(&**e),
163            Error::InvalidDnsName(_) => None,
164        }
165    }
166}
167
168impl From<io::Error> for Error {
169    fn from(e: io::Error) -> Self {
170        Error::Io(e)
171    }
172}