rama_tls_rustls/
type_conversion.rs1use crate::{RamaFrom, RamaTryFrom};
2use rama_core::error::{ErrorContext, OpaqueError};
3use rama_net::{
4 address::{Domain, Host},
5 tls::{
6 ApplicationProtocol, CipherSuite, DataEncoding, ProtocolVersion, SignatureScheme,
7 client::{ClientHello, ClientHelloExtension},
8 },
9};
10use rustls::pki_types;
11use std::net::IpAddr;
12
13macro_rules! enum_from_rustls {
14 ($t:ty => $($name:ident),+$(,)?) => {
15 $(
16 impl RamaFrom<rustls::$name> for rama_net::tls::$name {
17 fn rama_from(value: ::rustls::$name) -> Self {
18 let n: $t = value.into();
19 n.into()
20 }
21 }
22
23 impl RamaFrom<rama_net::tls::$name> for rustls::$name {
24 fn rama_from(value: rama_net::tls::$name) -> Self {
25 let n: $t = value.into();
26 n.into()
27 }
28 }
29 )+
30 };
31}
32
33enum_from_rustls!(u16 => ProtocolVersion, CipherSuite, SignatureScheme);
34
35impl RamaTryFrom<ProtocolVersion> for &rustls::SupportedProtocolVersion {
36 type Error = ProtocolVersion;
37
38 fn rama_try_from(value: ProtocolVersion) -> Result<Self, Self::Error> {
39 match value {
40 ProtocolVersion::TLSv1_2 => Ok(&rustls::version::TLS12),
41 ProtocolVersion::TLSv1_3 => Ok(&rustls::version::TLS13),
42 other => Err(other),
43 }
44 }
45}
46
47impl<'a> RamaTryFrom<rustls::pki_types::ServerName<'a>> for Host {
48 type Error = OpaqueError;
49
50 fn rama_try_from(value: rustls::pki_types::ServerName<'a>) -> Result<Self, Self::Error> {
51 match value {
52 rustls::pki_types::ServerName::DnsName(name) => {
53 Ok(Domain::try_from(name.as_ref().to_owned())?.into())
54 }
55 rustls::pki_types::ServerName::IpAddress(ip) => Ok(Host::from(IpAddr::from(ip))),
56 _ => Err(OpaqueError::from_display(format!(
57 "urecognised rustls (PKI) server name: {value:?}",
58 ))),
59 }
60 }
61}
62impl RamaTryFrom<Host> for rustls::pki_types::ServerName<'_> {
63 type Error = OpaqueError;
64
65 fn rama_try_from(value: Host) -> Result<Self, Self::Error> {
66 match value {
67 Host::Name(name) => Ok(rustls::pki_types::ServerName::DnsName(
68 rustls::pki_types::DnsName::try_from(name.as_str().to_owned())
69 .context("convert domain to rustls (PKI) ServerName")?,
70 )),
71 Host::Address(ip) => Ok(rustls::pki_types::ServerName::IpAddress(ip.into())),
72 }
73 }
74}
75
76impl<'a> RamaTryFrom<&rustls::pki_types::ServerName<'a>> for Host {
77 type Error = OpaqueError;
78
79 fn rama_try_from(value: &rustls::pki_types::ServerName<'a>) -> Result<Self, Self::Error> {
80 match value {
81 rustls::pki_types::ServerName::DnsName(name) => {
82 Ok(Domain::try_from(name.as_ref().to_owned())?.into())
83 }
84 rustls::pki_types::ServerName::IpAddress(ip) => Ok(Host::from(IpAddr::from(*ip))),
85 _ => Err(OpaqueError::from_display(format!(
86 "urecognised rustls (PKI) server name: {value:?}",
87 ))),
88 }
89 }
90}
91
92impl<'a> RamaTryFrom<&'a Host> for rustls::pki_types::ServerName<'a> {
93 type Error = OpaqueError;
94
95 fn rama_try_from(value: &'a Host) -> Result<Self, Self::Error> {
96 match value {
97 Host::Name(name) => Ok(rustls::pki_types::ServerName::DnsName(
98 rustls::pki_types::DnsName::try_from(name.as_str())
99 .context("convert domain to rustls (PKI) ServerName")?,
100 )),
101 Host::Address(ip) => Ok(rustls::pki_types::ServerName::IpAddress((*ip).into())),
102 }
103 }
104}
105
106impl RamaFrom<&pki_types::CertificateDer<'static>> for DataEncoding {
107 fn rama_from(value: &pki_types::CertificateDer<'static>) -> Self {
108 DataEncoding::Der(value.as_ref().into())
109 }
110}
111
112impl RamaFrom<&[pki_types::CertificateDer<'static>]> for DataEncoding {
113 fn rama_from(value: &[pki_types::CertificateDer<'static>]) -> Self {
114 DataEncoding::DerStack(
115 value
116 .iter()
117 .map(|cert| Into::<Vec<u8>>::into(cert.as_ref()))
118 .collect(),
119 )
120 }
121}
122
123impl<'a> RamaFrom<rustls::server::ClientHello<'a>> for ClientHello {
124 fn rama_from(value: rustls::server::ClientHello<'a>) -> Self {
125 let cipher_suites = value
126 .cipher_suites()
127 .iter()
128 .map(|cs| CipherSuite::rama_from(*cs))
129 .collect();
130
131 let mut extensions = Vec::with_capacity(3);
132
133 extensions.push(ClientHelloExtension::SignatureAlgorithms(
134 value
135 .signature_schemes()
136 .iter()
137 .map(|sc| SignatureScheme::rama_from(*sc))
138 .collect(),
139 ));
140
141 if let Some(domain) = value.server_name().and_then(|d| d.parse().ok()) {
142 extensions.push(ClientHelloExtension::ServerName(Some(domain)));
143 }
144
145 if let Some(alpn) = value.alpn() {
146 extensions.push(ClientHelloExtension::ApplicationLayerProtocolNegotiation(
147 alpn.map(ApplicationProtocol::from).collect(),
148 ));
149 }
150
151 Self::new(
152 ProtocolVersion::Unknown(0),
154 cipher_suites,
155 vec![],
156 extensions,
157 )
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips TLS 1.3 from rustls to the rama type and back again.
    #[test]
    fn test_rustls_to_common_to_rustls() {
        let original = rustls::ProtocolVersion::TLSv1_3;

        let common = ProtocolVersion::rama_from(original);
        assert_eq!(common, ProtocolVersion::TLSv1_3);

        let round_tripped = rustls::ProtocolVersion::rama_from(common);
        assert_eq!(round_tripped, rustls::ProtocolVersion::TLSv1_3);
    }
}