clia_rustls_mod/server/
handy.rs1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::fmt::Debug;
4
5use crate::msgs::handshake::CertificateChain;
6use crate::server::ClientHello;
7use crate::{server, sign};
8
9#[derive(Debug)]
11pub struct NoServerSessionStorage {}
12
13impl server::StoresServerSessions for NoServerSessionStorage {
14 fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
15 false
16 }
17 fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
18 None
19 }
20 fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
21 None
22 }
23 fn can_cache(&self) -> bool {
24 false
25 }
26}
27
#[cfg(feature = "std")]
mod cache {
    use alloc::sync::Arc;
    use alloc::vec::Vec;
    use core::fmt::{Debug, Formatter};
    use std::sync::Mutex;

    use crate::{limited_cache, server};

    /// The bounded byte-string key/value table guarded by the cache's mutex.
    type SessionTable = limited_cache::LimitedCache<Vec<u8>, Vec<u8>>;

    /// A [`server::StoresServerSessions`] implementation that keeps sessions in
    /// process memory.
    ///
    /// Entries live in a [`limited_cache::LimitedCache`] behind a [`Mutex`]. The
    /// cache drops entries to stay within the size limit given to
    /// [`ServerSessionMemoryCache::new`].
    pub struct ServerSessionMemoryCache {
        cache: Mutex<SessionTable>,
    }

    impl ServerSessionMemoryCache {
        /// Creates an empty cache whose size limit is `size`, as interpreted by
        /// `limited_cache::LimitedCache::new`.
        ///
        /// The cache is returned already wrapped in an `Arc`.
        pub fn new(size: usize) -> Arc<Self> {
            let table = limited_cache::LimitedCache::new(size);
            Arc::new(Self {
                cache: Mutex::new(table),
            })
        }

        /// Acquires the table lock.
        ///
        /// # Panics
        ///
        /// Panics if the mutex is poisoned, because another thread panicked while
        /// holding it.
        fn locked(&self) -> std::sync::MutexGuard<'_, SessionTable> {
            self.cache.lock().unwrap()
        }
    }

    impl server::StoresServerSessions for ServerSessionMemoryCache {
        fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
            // Insertion always succeeds; any pre-existing value under `key` is replaced.
            self.locked().insert(key, value);
            true
        }

        fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
            // Hand back a copy so the entry stays in the cache.
            self.locked().get(key).cloned()
        }

        fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
            // Unlike `get`, this removes the entry.
            self.locked().remove(key)
        }

        fn can_cache(&self) -> bool {
            true
        }
    }

    impl Debug for ServerSessionMemoryCache {
        /// Prints only the type name; stored session data is never formatted.
        fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
            let builder = f.debug_struct("ServerSessionMemoryCache");
            builder.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use std::vec;

        use super::*;
        use crate::server::StoresServerSessions;

        #[test]
        fn test_serversessionmemorycache_accepts_put() {
            let cache = ServerSessionMemoryCache::new(4);
            assert!(cache.put(vec![0x01], vec![0x02]));
        }

        #[test]
        fn test_serversessionmemorycache_persists_put() {
            let cache = ServerSessionMemoryCache::new(4);
            assert!(cache.put(vec![0x01], vec![0x02]));
            // `get` must not consume the entry: repeated reads see the same value.
            for _ in 0..2 {
                assert_eq!(cache.get(&[0x01]), Some(vec![0x02]));
            }
        }

        #[test]
        fn test_serversessionmemorycache_overwrites_put() {
            let cache = ServerSessionMemoryCache::new(4);
            for value in [0x02u8, 0x04].iter() {
                assert!(cache.put(vec![0x01], vec![*value]));
            }
            assert_eq!(cache.get(&[0x01]), Some(vec![0x04]));
        }

        #[test]
        fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
            let cache = ServerSessionMemoryCache::new(2);
            let keys = [0x01u8, 0x03, 0x05, 0x07, 0x09];
            for &key in keys.iter() {
                assert!(cache.put(vec![key], vec![key + 1]));
            }

            let retained = keys
                .iter()
                .filter(|&&key| cache.get(&[key]).is_some())
                .count();

            assert!(retained < keys.len());
        }
    }
}
// Expose the in-memory session cache directly from this module.
#[cfg(feature = "std")]
pub use self::cache::ServerSessionMemoryCache;
138
139#[derive(Debug)]
141pub(super) struct NeverProducesTickets {}
142
143impl server::ProducesTickets for NeverProducesTickets {
144 fn enabled(&self) -> bool {
145 false
146 }
147 fn lifetime(&self) -> u32 {
148 0
149 }
150 fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
151 None
152 }
153 fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
154 None
155 }
156}
157
158#[derive(Debug)]
160pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
161
162impl AlwaysResolvesChain {
163 pub(super) fn new(
165 private_key: Arc<dyn sign::SigningKey>,
166 chain: CertificateChain<'static>,
167 ) -> Self {
168 Self(Arc::new(sign::CertifiedKey::new(chain.0, private_key)))
169 }
170
171 pub(super) fn new_with_extras(
175 private_key: Arc<dyn sign::SigningKey>,
176 chain: CertificateChain<'static>,
177 ocsp: Vec<u8>,
178 ) -> Self {
179 let mut r = Self::new(private_key, chain);
180
181 {
182 let cert = Arc::make_mut(&mut r.0);
183 if !ocsp.is_empty() {
184 cert.ocsp = Some(ocsp);
185 }
186 }
187
188 r
189 }
190}
191
192impl server::ResolvesServerCert for AlwaysResolvesChain {
193 fn resolve(&self, _client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
194 Some(Arc::clone(&self.0))
195 }
196}
197
#[cfg(feature = "std")]
mod sni_resolver {
    use alloc::string::{String, ToString};
    use alloc::sync::Arc;
    use core::fmt::Debug;
    use std::collections::HashMap;

    use pki_types::{DnsName, ServerName};

    use crate::error::Error;
    use crate::server::ClientHello;
    use crate::webpki::{verify_server_name, ParsedCertificate};
    use crate::{server, sign};

    /// A [`server::ResolvesServerCert`] that selects a certified key based on the
    /// server name the client sent via SNI.
    ///
    /// Names are lowercased when registered with [`ResolvesServerCertUsingSni::add`].
    /// Lookups in `resolve` use `ClientHello::server_name()` verbatim, which makes
    /// them case-sensitive.
    // NOTE(review): this relies on the handshake code lowercasing the SNI value
    // before building `ClientHello`; confirm.
    #[derive(Debug, Default)]
    pub struct ResolvesServerCertUsingSni {
        /// Lowercased DNS name -> key to present for that name.
        by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    }

    impl ResolvesServerCertUsingSni {
        /// Creates an empty resolver that knows no certificates.
        ///
        /// Equivalent to [`Default::default`].
        pub fn new() -> Self {
            Self::default()
        }

        /// Registers `ck` as the key to present when a client asks for `name`.
        ///
        /// `name` is lowercased before storage. Registering the same name again
        /// replaces the earlier entry.
        ///
        /// # Errors
        ///
        /// Returns `Error::General("Bad DNS name")` if `name` is not a valid DNS
        /// name. Also propagates any error from `ck.end_entity_cert()`, from parsing
        /// that certificate, or from `verify_server_name` checking it against `name`.
        pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
            let dns_name = DnsName::try_from(name)
                .map_err(|_| Error::General("Bad DNS name".into()))?
                .to_lowercase_owned();
            // Take the map key now, because `dns_name` is moved into `server_name`.
            // This avoids re-matching `server_name` afterwards with a branch that
            // can never fail.
            let key = dns_name.as_ref().to_string();
            let server_name = ServerName::DnsName(dns_name);

            // Sanity-check the chain: there must be an end-entity certificate, it
            // must parse, and it must be valid for `name`. These checks catch
            // server-side misconfiguration; they are not a client-facing security
            // boundary.
            ck.end_entity_cert()
                .and_then(ParsedCertificate::try_from)
                .and_then(|cert| verify_server_name(&cert, &server_name))?;

            self.by_name.insert(key, Arc::new(ck));
            Ok(())
        }
    }

    impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
        fn resolve(&self, client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
            // Without SNI there is nothing to select on, so no key is offered.
            client_hello
                .server_name()
                .and_then(|name| self.by_name.get(name))
                .cloned()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::server::ResolvesServerCert;

        #[test]
        fn test_resolvesservercertusingsni_requires_sni() {
            let rscsni = ResolvesServerCertUsingSni::new();
            assert!(rscsni
                .resolve(ClientHello::new(&None, &[], None, &[]))
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_handles_unknown_name() {
            let rscsni = ResolvesServerCertUsingSni::new();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello::new(&Some(name), &[], None, &[]))
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_default_is_empty() {
            let rscsni = ResolvesServerCertUsingSni::default();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello::new(&Some(name), &[], None, &[]))
                .is_none());
        }
    }
}
297
// Expose the SNI-based certificate resolver directly from this module.
#[cfg(feature = "std")]
pub use self::sni_resolver::ResolvesServerCertUsingSni;
300
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::server::{ProducesTickets, StoresServerSessions};

    /// Keys probed by the lookup tests: empty, the key that gets `put`, and an
    /// unrelated one.
    const PROBE_KEYS: [&[u8]; 3] = [&[], &[0x01], &[0x02]];

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        assert!(!storage.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even straight after a put, nothing can be read back.
        storage.put(vec![0x01], vec![0x02]);
        for &key in PROBE_KEYS.iter() {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        for &key in PROBE_KEYS.iter() {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let tickets = NeverProducesTickets {};
        assert!(!tickets.enabled());
        assert_eq!(tickets.lifetime(), 0);
        assert_eq!(tickets.encrypt(&[]), None);
        assert_eq!(tickets.decrypt(&[]), None);
    }
}