clia_rustls_mod/server/handy.rs

1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::fmt::Debug;
4
5use crate::msgs::handshake::CertificateChain;
6use crate::server::ClientHello;
7use crate::{server, sign};
8
9/// Something which never stores sessions.
10#[derive(Debug)]
11pub struct NoServerSessionStorage {}
12
13impl server::StoresServerSessions for NoServerSessionStorage {
14    fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
15        false
16    }
17    fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
18        None
19    }
20    fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
21        None
22    }
23    fn can_cache(&self) -> bool {
24        false
25    }
26}
27
#[cfg(feature = "std")]
mod cache {
    use alloc::sync::Arc;
    use alloc::vec::Vec;
    use core::fmt::{Debug, Formatter};
    use std::sync::Mutex;

    use crate::{limited_cache, server};

    /// An implementer of `StoresServerSessions` that stores everything
    /// in memory.  It enforces a limit on the number of stored sessions
    /// to bound memory usage.
    ///
    /// # Panics
    ///
    /// Every operation locks an internal `std::sync::Mutex` and unwraps the
    /// result, so it panics if that mutex was poisoned by another thread
    /// panicking while holding it.
    pub struct ServerSessionMemoryCache {
        cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
    }

    impl ServerSessionMemoryCache {
        /// Make a new ServerSessionMemoryCache.  `size` is the maximum
        /// number of stored sessions, and may be rounded-up for
        /// efficiency.
        pub fn new(size: usize) -> Arc<Self> {
            Arc::new(Self {
                cache: Mutex::new(limited_cache::LimitedCache::new(size)),
            })
        }
    }

    impl server::StoresServerSessions for ServerSessionMemoryCache {
        /// Stores `value` under `key`, replacing any previous value.
        ///
        /// Always returns `true`; when the cache is full, `LimitedCache`
        /// decides which existing entry is evicted.
        fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
            self.cache
                .lock()
                .unwrap()
                .insert(key, value);
            true
        }

        /// Returns a copy of the value stored under `key`, leaving it cached.
        fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
            self.cache
                .lock()
                .unwrap()
                .get(key)
                .cloned()
        }

        /// Removes the value stored under `key` and returns it.
        fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
            self.cache.lock().unwrap().remove(key)
        }

        fn can_cache(&self) -> bool {
            true
        }
    }

    impl Debug for ServerSessionMemoryCache {
        // The cached session bytes are deliberately not printed.
        fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
            f.debug_struct("ServerSessionMemoryCache")
                .finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use std::vec;

        use super::*;
        use crate::server::StoresServerSessions;

        #[test]
        fn test_serversessionmemorycache_accepts_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
        }

        #[test]
        fn test_serversessionmemorycache_persists_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
        }

        #[test]
        fn test_serversessionmemorycache_overwrites_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert!(c.put(vec![0x01], vec![0x04]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x04]));
        }

        #[test]
        fn test_serversessionmemorycache_take_removes_entry() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert_eq!(c.take(&[0x01]), Some(vec![0x02]));
            assert_eq!(c.take(&[0x01]), None);
            assert_eq!(c.get(&[0x01]), None);
        }

        #[test]
        fn test_serversessionmemorycache_can_cache() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.can_cache());
        }

        #[test]
        fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
            let c = ServerSessionMemoryCache::new(2);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert!(c.put(vec![0x03], vec![0x04]));
            assert!(c.put(vec![0x05], vec![0x06]));
            assert!(c.put(vec![0x07], vec![0x08]));
            assert!(c.put(vec![0x09], vec![0x0a]));

            let count = c.get(&[0x01]).iter().count()
                + c.get(&[0x03]).iter().count()
                + c.get(&[0x05]).iter().count()
                + c.get(&[0x07]).iter().count()
                + c.get(&[0x09]).iter().count();

            assert!(count < 5);
        }
    }
}
136#[cfg(feature = "std")]
137pub use cache::ServerSessionMemoryCache;
138
139/// Something which never produces tickets.
140#[derive(Debug)]
141pub(super) struct NeverProducesTickets {}
142
143impl server::ProducesTickets for NeverProducesTickets {
144    fn enabled(&self) -> bool {
145        false
146    }
147    fn lifetime(&self) -> u32 {
148        0
149    }
150    fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
151        None
152    }
153    fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
154        None
155    }
156}
157
158/// Something which always resolves to the same cert chain.
159#[derive(Debug)]
160pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
161
162impl AlwaysResolvesChain {
163    /// Creates an `AlwaysResolvesChain`, using the supplied key and certificate chain.
164    pub(super) fn new(
165        private_key: Arc<dyn sign::SigningKey>,
166        chain: CertificateChain<'static>,
167    ) -> Self {
168        Self(Arc::new(sign::CertifiedKey::new(chain.0, private_key)))
169    }
170
171    /// Creates an `AlwaysResolvesChain`, using the supplied key, certificate chain and OCSP response.
172    ///
173    /// If non-empty, the given OCSP response is attached.
174    pub(super) fn new_with_extras(
175        private_key: Arc<dyn sign::SigningKey>,
176        chain: CertificateChain<'static>,
177        ocsp: Vec<u8>,
178    ) -> Self {
179        let mut r = Self::new(private_key, chain);
180
181        {
182            let cert = Arc::make_mut(&mut r.0);
183            if !ocsp.is_empty() {
184                cert.ocsp = Some(ocsp);
185            }
186        }
187
188        r
189    }
190}
191
192impl server::ResolvesServerCert for AlwaysResolvesChain {
193    fn resolve(&self, _client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
194        Some(Arc::clone(&self.0))
195    }
196}
197
#[cfg(feature = "std")]
mod sni_resolver {
    use alloc::string::{String, ToString};
    use alloc::sync::Arc;
    use core::fmt::Debug;
    use std::collections::HashMap;

    use pki_types::{DnsName, ServerName};

    use crate::error::Error;
    use crate::server::ClientHello;
    use crate::webpki::{verify_server_name, ParsedCertificate};
    use crate::{server, sign};

    /// Something that resolves to different cert chains/keys based
    /// on client-supplied server name (via SNI).
    ///
    /// Names registered with [`ResolvesServerCertUsingSni::add`] are stored
    /// lowercased.  `resolve` looks up the name returned by
    /// `ClientHello::server_name` without further normalisation.
    #[derive(Debug, Default)]
    pub struct ResolvesServerCertUsingSni {
        by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    }

    impl ResolvesServerCertUsingSni {
        /// Create a new and empty (i.e., knows no certificates) resolver.
        pub fn new() -> Self {
            Self {
                by_name: HashMap::new(),
            }
        }

        /// Add a new `sign::CertifiedKey` to be used for the given SNI `name`.
        ///
        /// `name` is lowercased before it is stored; adding the same name
        /// again replaces the previous key.
        ///
        /// # Errors
        ///
        /// This function fails if `name` is not a valid DNS name, or if
        /// it's not valid for the supplied certificate, or if the certificate
        /// chain is syntactically faulty.
        pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
            let dns_name = DnsName::try_from(name)
                .map_err(|_| Error::General("Bad DNS name".into()))?
                .to_lowercase_owned();
            // The normalised spelling is both the map key and the name the
            // certificate is checked against, so capture it before wrapping.
            let key = dns_name.as_ref().to_string();
            let server_name = ServerName::DnsName(dns_name);

            // Check the certificate chain for validity:
            // - it should be non-empty list
            // - the first certificate should be parsable as a x509v3,
            // - the first certificate should quote the given server name
            //   (if provided)
            //
            // These checks are not security-sensitive.  They are the
            // *server* attempting to detect accidental misconfiguration.

            ck.end_entity_cert()
                .and_then(ParsedCertificate::try_from)
                .and_then(|cert| verify_server_name(&cert, &server_name))?;

            self.by_name.insert(key, Arc::new(ck));
            Ok(())
        }
    }

    impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
        fn resolve(&self, client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
            // This kind of resolver requires SNI: with no server name there
            // is nothing to look up.
            let name = client_hello.server_name()?;
            self.by_name.get(name).cloned()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::server::ResolvesServerCert;

        #[test]
        fn test_resolvesservercertusingsni_requires_sni() {
            let rscsni = ResolvesServerCertUsingSni::new();
            assert!(rscsni
                .resolve(ClientHello::new(&None, &[], None, &[]))
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_handles_unknown_name() {
            let rscsni = ResolvesServerCertUsingSni::new();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello::new(&Some(name), &[], None, &[]))
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_default_is_empty() {
            let rscsni = ResolvesServerCertUsingSni::default();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello::new(&Some(name), &[], None, &[]))
                .is_none());
        }
    }
}
297
298#[cfg(feature = "std")]
299pub use sni_resolver::ResolvesServerCertUsingSni;
300
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::server::{ProducesTickets, StoresServerSessions};

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let stored = storage.put(vec![0x01], vec![0x02]);
        assert!(!stored);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even after an attempted put, nothing is retrievable.
        storage.put(vec![0x01], vec![0x02]);
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let producer = NeverProducesTickets {};
        assert!(!producer.enabled());
        assert_eq!(producer.lifetime(), 0);
        assert_eq!(producer.encrypt(&[]), None);
        assert_eq!(producer.decrypt(&[]), None);
    }
}