rama_tls_rustls/
type_conversion.rs

1use crate::{RamaFrom, RamaTryFrom};
2use rama_core::error::{ErrorContext, OpaqueError};
3use rama_net::{
4    address::{Domain, Host},
5    tls::{
6        ApplicationProtocol, CipherSuite, DataEncoding, ProtocolVersion, SignatureScheme,
7        client::{ClientHello, ClientHelloExtension},
8    },
9};
10use rustls::pki_types;
11use std::net::IpAddr;
12
13macro_rules! enum_from_rustls {
14    ($t:ty => $($name:ident),+$(,)?) => {
15        $(
16            impl RamaFrom<rustls::$name> for rama_net::tls::$name {
17                fn rama_from(value: ::rustls::$name) -> Self {
18                    let n: $t = value.into();
19                    n.into()
20                }
21            }
22
23            impl RamaFrom<rama_net::tls::$name> for rustls::$name {
24                fn rama_from(value: rama_net::tls::$name) -> Self {
25                    let n: $t = value.into();
26                    n.into()
27                }
28            }
29        )+
30    };
31}
32
33enum_from_rustls!(u16 => ProtocolVersion, CipherSuite, SignatureScheme);
34
35impl RamaTryFrom<ProtocolVersion> for &rustls::SupportedProtocolVersion {
36    type Error = ProtocolVersion;
37
38    fn rama_try_from(value: ProtocolVersion) -> Result<Self, Self::Error> {
39        match value {
40            ProtocolVersion::TLSv1_2 => Ok(&rustls::version::TLS12),
41            ProtocolVersion::TLSv1_3 => Ok(&rustls::version::TLS13),
42            other => Err(other),
43        }
44    }
45}
46
47impl<'a> RamaTryFrom<rustls::pki_types::ServerName<'a>> for Host {
48    type Error = OpaqueError;
49
50    fn rama_try_from(value: rustls::pki_types::ServerName<'a>) -> Result<Self, Self::Error> {
51        match value {
52            rustls::pki_types::ServerName::DnsName(name) => {
53                Ok(Domain::try_from(name.as_ref().to_owned())?.into())
54            }
55            rustls::pki_types::ServerName::IpAddress(ip) => Ok(Host::from(IpAddr::from(ip))),
56            _ => Err(OpaqueError::from_display(format!(
57                "urecognised rustls (PKI) server name: {value:?}",
58            ))),
59        }
60    }
61}
62impl RamaTryFrom<Host> for rustls::pki_types::ServerName<'_> {
63    type Error = OpaqueError;
64
65    fn rama_try_from(value: Host) -> Result<Self, Self::Error> {
66        match value {
67            Host::Name(name) => Ok(rustls::pki_types::ServerName::DnsName(
68                rustls::pki_types::DnsName::try_from(name.as_str().to_owned())
69                    .context("convert domain to rustls (PKI) ServerName")?,
70            )),
71            Host::Address(ip) => Ok(rustls::pki_types::ServerName::IpAddress(ip.into())),
72        }
73    }
74}
75
76impl<'a> RamaTryFrom<&rustls::pki_types::ServerName<'a>> for Host {
77    type Error = OpaqueError;
78
79    fn rama_try_from(value: &rustls::pki_types::ServerName<'a>) -> Result<Self, Self::Error> {
80        match value {
81            rustls::pki_types::ServerName::DnsName(name) => {
82                Ok(Domain::try_from(name.as_ref().to_owned())?.into())
83            }
84            rustls::pki_types::ServerName::IpAddress(ip) => Ok(Host::from(IpAddr::from(*ip))),
85            _ => Err(OpaqueError::from_display(format!(
86                "urecognised rustls (PKI) server name: {value:?}",
87            ))),
88        }
89    }
90}
91
92impl<'a> RamaTryFrom<&'a Host> for rustls::pki_types::ServerName<'a> {
93    type Error = OpaqueError;
94
95    fn rama_try_from(value: &'a Host) -> Result<Self, Self::Error> {
96        match value {
97            Host::Name(name) => Ok(rustls::pki_types::ServerName::DnsName(
98                rustls::pki_types::DnsName::try_from(name.as_str())
99                    .context("convert domain to rustls (PKI) ServerName")?,
100            )),
101            Host::Address(ip) => Ok(rustls::pki_types::ServerName::IpAddress((*ip).into())),
102        }
103    }
104}
105
106impl RamaFrom<&pki_types::CertificateDer<'static>> for DataEncoding {
107    fn rama_from(value: &pki_types::CertificateDer<'static>) -> Self {
108        DataEncoding::Der(value.as_ref().into())
109    }
110}
111
112impl RamaFrom<&[pki_types::CertificateDer<'static>]> for DataEncoding {
113    fn rama_from(value: &[pki_types::CertificateDer<'static>]) -> Self {
114        DataEncoding::DerStack(
115            value
116                .iter()
117                .map(|cert| Into::<Vec<u8>>::into(cert.as_ref()))
118                .collect(),
119        )
120    }
121}
122
123impl<'a> RamaFrom<rustls::server::ClientHello<'a>> for ClientHello {
124    fn rama_from(value: rustls::server::ClientHello<'a>) -> Self {
125        let cipher_suites = value
126            .cipher_suites()
127            .iter()
128            .map(|cs| CipherSuite::rama_from(*cs))
129            .collect();
130
131        let mut extensions = Vec::with_capacity(3);
132
133        extensions.push(ClientHelloExtension::SignatureAlgorithms(
134            value
135                .signature_schemes()
136                .iter()
137                .map(|sc| SignatureScheme::rama_from(*sc))
138                .collect(),
139        ));
140
141        if let Some(domain) = value.server_name().and_then(|d| d.parse().ok()) {
142            extensions.push(ClientHelloExtension::ServerName(Some(domain)));
143        }
144
145        if let Some(alpn) = value.alpn() {
146            extensions.push(ClientHelloExtension::ApplicationLayerProtocolNegotiation(
147                alpn.map(ApplicationProtocol::from).collect(),
148            ));
149        }
150
151        Self::new(
152            // TODO: support if rustls can handle this
153            ProtocolVersion::Unknown(0),
154            cipher_suites,
155            vec![],
156            extensions,
157        )
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// TLS 1.3 must survive a rustls -> rama -> rustls round trip unchanged.
    #[test]
    fn test_rustls_to_common_to_rustls() {
        let original = rustls::ProtocolVersion::TLSv1_3;

        let common = ProtocolVersion::rama_from(original);
        assert_eq!(common, ProtocolVersion::TLSv1_3);

        let round_tripped = rustls::ProtocolVersion::rama_from(common);
        assert_eq!(round_tripped, rustls::ProtocolVersion::TLSv1_3);
    }
}
173}