// clia_rustls_mod/client/handy.rs

use alloc::sync::Arc;

use pki_types::ServerName;

use crate::enums::SignatureScheme;
use crate::error::Error;
use crate::msgs::handshake::CertificateChain;
use crate::msgs::persist;
use crate::{client, sign, NamedGroup};

11/// An implementer of `ClientSessionStore` which does nothing.
12#[derive(Debug)]
13pub(super) struct NoClientSessionStorage;
14
15impl client::ClientSessionStore for NoClientSessionStorage {
16    fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
17
18    fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
19        None
20    }
21
22    fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
23
24    fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
25        None
26    }
27
28    fn remove_tls12_session(&self, _: &ServerName<'_>) {}
29
30    fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
31
32    fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
33        None
34    }
35}
36
#[cfg(feature = "std")]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;
    use std::sync::Mutex;

    use pki_types::ServerName;

    use crate::msgs::persist;
    use crate::{limited_cache, NamedGroup};

    /// Maximum number of TLS1.3 tickets kept for any single server.
    ///
    /// `ClientSessionMemoryCache::new` also uses this to convert a session
    /// budget into a number of servers.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything remembered about one server.
    struct ServerData {
        /// Group most recently recorded by `set_kx_hint`, if any.
        kx_hint: Option<NamedGroup>,

        // Zero or one TLS1.2 sessions.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// An implementer of `ClientSessionStore` that stores everything
    /// in memory.
    ///
    /// It enforces a limit on the number of entries to bound memory usage.
    ///
    /// # Panics
    ///
    /// Every operation locks an internal `Mutex` and unwraps the result, so
    /// it panics if that mutex has been poisoned by a panic in another thread.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Make a new ClientSessionMemoryCache.  `size` is the
        /// maximum number of stored sessions.
        ///
        /// The limit is applied per server. The underlying cache is created with
        /// a limit of `ceil(size / MAX_TLS13_TICKETS_PER_SERVER)` servers, and
        /// each server keeps at most `MAX_TLS13_TICKETS_PER_SERVER` TLS1.3 tickets.
        pub fn new(size: usize) -> Self {
            // Ceiling division; `saturating_add` avoids overflow when `size`
            // is close to `usize::MAX`.
            let max_servers = size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1)
                / MAX_TLS13_TICKETS_PER_SERVER;
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
            }
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            // Without the `tls12` feature there is no slot to keep the session
            // in, so it is dropped. `_server_name` is already owned and is
            // moved into the cache rather than cloned.
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(_server_name, |data| data.tls12 = Some(_value));
        }

        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            // Clears the stored TLS1.2 session (if any) but keeps the server's
            // other data (kx hint, TLS1.3 tickets).
            #[cfg(feature = "tls12")]
            if let Some(data) = self
                .servers
                .lock()
                .unwrap()
                .get_mut(_server_name)
            {
                data.tls12 = None;
            }
        }

        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Evict the oldest ticket once the per-server limit is
                    // reached. Compare against the constant rather than
                    // `VecDeque::capacity()`, which `with_capacity` only
                    // guarantees to be *at least* the requested size.
                    if data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            // Tickets are appended at the back, so this removes and returns
            // the most recently inserted one.
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Note: we omit self.servers as it may contain sensitive data.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}

#[cfg(feature = "std")]
pub use cache::ClientSessionMemoryCache;

186#[derive(Debug)]
187pub(super) struct FailResolveClientCert {}
188
189impl client::ResolvesClientCert for FailResolveClientCert {
190    fn resolve(
191        &self,
192        _root_hint_subjects: &[&[u8]],
193        _sigschemes: &[SignatureScheme],
194    ) -> Option<Arc<sign::CertifiedKey>> {
195        None
196    }
197
198    fn has_certs(&self) -> bool {
199        false
200    }
201}
202
203#[derive(Debug)]
204pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
205
206impl AlwaysResolvesClientCert {
207    pub(super) fn new(
208        private_key: Arc<dyn sign::SigningKey>,
209        chain: CertificateChain<'static>,
210    ) -> Result<Self, Error> {
211        Ok(Self(Arc::new(sign::CertifiedKey::new(
212            chain.0,
213            private_key,
214        ))))
215    }
216}
217
218impl client::ResolvesClientCert for AlwaysResolvesClientCert {
219    fn resolve(
220        &self,
221        _root_hint_subjects: &[&[u8]],
222        _sigschemes: &[SignatureScheme],
223    ) -> Option<Arc<sign::CertifiedKey>> {
224        Some(Arc::clone(&self.0))
225    }
226
227    fn has_certs(&self) -> bool {
228        true
229    }
230}
231
232test_for_each_provider! {
233    use std::prelude::v1::*;
234    use super::NoClientSessionStorage;
235    use crate::client::ClientSessionStore;
236    use crate::msgs::enums::NamedGroup;
237    use crate::msgs::handshake::CertificateChain;
238    #[cfg(feature = "tls12")]
239    use crate::msgs::handshake::SessionId;
240    use crate::msgs::persist::Tls13ClientSessionValue;
241    use crate::suites::SupportedCipherSuite;
242    use provider::cipher_suite;
243
244    use pki_types::{ServerName, UnixTime};
245
246    #[test]
247    fn test_noclientsessionstorage_does_nothing() {
248        let c = NoClientSessionStorage {};
249        let name = ServerName::try_from("example.com").unwrap();
250        let now = UnixTime::now();
251
252        c.set_kx_hint(name.clone(), NamedGroup::X25519);
253        assert_eq!(None, c.kx_hint(&name));
254
255        #[cfg(feature = "tls12")]
256        {
257            use crate::msgs::persist::Tls12ClientSessionValue;
258            let SupportedCipherSuite::Tls12(tls12_suite) =
259                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
260            else {
261                unreachable!()
262            };
263
264            c.set_tls12_session(
265                name.clone(),
266                Tls12ClientSessionValue::new(
267                    tls12_suite,
268                    SessionId::empty(),
269                    Vec::new(),
270                    &[],
271                    CertificateChain::default(),
272                    now,
273                    0,
274                    true,
275                ),
276            );
277            assert!(c.tls12_session(&name).is_none());
278            c.remove_tls12_session(&name);
279        }
280
281        #[cfg_attr(not(feature = "tls12"), allow(clippy::infallible_destructuring_match))]
282        let tls13_suite = match cipher_suite::TLS13_AES_256_GCM_SHA384 {
283            SupportedCipherSuite::Tls13(inner) => inner,
284            #[cfg(feature = "tls12")]
285            _ => unreachable!(),
286        };
287        c.insert_tls13_ticket(
288            name.clone(),
289            Tls13ClientSessionValue::new(
290                tls13_suite,
291                Vec::new(),
292                &[],
293                CertificateChain::default(),
294                now,
295                0,
296                0,
297                0,
298            ),
299        );
300        assert!(c.take_tls13_ticket(&name).is_none());
301    }
302}