clia_rustls_mod/client/
handy.rs1use alloc::sync::Arc;
2
3use pki_types::ServerName;
4
5use crate::enums::SignatureScheme;
6use crate::error::Error;
7use crate::msgs::handshake::CertificateChain;
8use crate::msgs::persist;
9use crate::{client, sign, NamedGroup};
10
11#[derive(Debug)]
13pub(super) struct NoClientSessionStorage;
14
15impl client::ClientSessionStore for NoClientSessionStorage {
16 fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
17
18 fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
19 None
20 }
21
22 fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
23
24 fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
25 None
26 }
27
28 fn remove_tls12_session(&self, _: &ServerName<'_>) {}
29
30 fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
31
32 fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
33 None
34 }
35}
36
#[cfg(feature = "std")]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;
    use std::sync::Mutex;

    use pki_types::ServerName;

    use crate::msgs::persist;
    use crate::{limited_cache, NamedGroup};

    /// Maximum number of TLS 1.3 tickets retained for any one server.
    ///
    /// Also used by [`ClientSessionMemoryCache::new`] to turn the caller's
    /// entry budget into a number of per-server slots.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything remembered about a single server.
    struct ServerData {
        /// Key-exchange group last recorded for this server, if any.
        kx_hint: Option<NamedGroup>,

        /// Most recently stored TLS 1.2 session, if any.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        /// Stored TLS 1.3 tickets: oldest at the front, newest at the back.
        /// Never holds more than `MAX_TLS13_TICKETS_PER_SERVER` entries.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// An implementer of `ClientSessionStore` that keeps everything in memory.
    ///
    /// Data is grouped per server in a `limited_cache::LimitedCache` whose size
    /// is fixed in [`ClientSessionMemoryCache::new`] (presumably evicting old
    /// servers once full — see `LimitedCache`). Each server additionally keeps
    /// at most `MAX_TLS13_TICKETS_PER_SERVER` TLS 1.3 tickets.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Make a new `ClientSessionMemoryCache`.
        ///
        /// `size` is the total entry budget. It is divided by
        /// `MAX_TLS13_TICKETS_PER_SERVER`, rounding up (and saturating instead of
        /// overflowing for huge values), to get the number of servers the
        /// underlying cache is created with.
        pub fn new(size: usize) -> Self {
            let max_servers = size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1)
                / MAX_TLS13_TICKETS_PER_SERVER;
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
            }
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            // Without the `tls12` feature there is no field to store the session
            // in, so the arguments are simply dropped.
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                // `_server_name` is owned and unused afterwards: move, don't clone.
                .get_or_insert_default_and_edit(_server_name, |data| {
                    data.tls12 = Some(_value)
                });
        }

        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            #[cfg(feature = "tls12")]
            {
                // Clear only the TLS 1.2 session; the server's other data (kx
                // hint, TLS 1.3 tickets) is kept.
                if let Some(data) = self
                    .servers
                    .lock()
                    .unwrap()
                    .get_mut(_server_name)
                {
                    data.tls12 = None;
                }
            }
        }

        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Evict the oldest tickets so at most
                    // `MAX_TLS13_TICKETS_PER_SERVER` remain after the push.
                    // Compare against the constant rather than `capacity()`:
                    // `VecDeque::with_capacity` only promises *at least* the
                    // requested capacity, so `capacity()` is not a reliable bound.
                    while data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        /// Removes and returns the most recently inserted ticket for
        /// `server_name` (tickets are pushed at the back), if any.
        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // `servers` is deliberately omitted: it holds stored session data,
            // which should be treated as sensitive and kept out of logs.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}

#[cfg(feature = "std")]
pub use cache::ClientSessionMemoryCache;
185
186#[derive(Debug)]
187pub(super) struct FailResolveClientCert {}
188
189impl client::ResolvesClientCert for FailResolveClientCert {
190 fn resolve(
191 &self,
192 _root_hint_subjects: &[&[u8]],
193 _sigschemes: &[SignatureScheme],
194 ) -> Option<Arc<sign::CertifiedKey>> {
195 None
196 }
197
198 fn has_certs(&self) -> bool {
199 false
200 }
201}
202
203#[derive(Debug)]
204pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
205
206impl AlwaysResolvesClientCert {
207 pub(super) fn new(
208 private_key: Arc<dyn sign::SigningKey>,
209 chain: CertificateChain<'static>,
210 ) -> Result<Self, Error> {
211 Ok(Self(Arc::new(sign::CertifiedKey::new(
212 chain.0,
213 private_key,
214 ))))
215 }
216}
217
218impl client::ResolvesClientCert for AlwaysResolvesClientCert {
219 fn resolve(
220 &self,
221 _root_hint_subjects: &[&[u8]],
222 _sigschemes: &[SignatureScheme],
223 ) -> Option<Arc<sign::CertifiedKey>> {
224 Some(Arc::clone(&self.0))
225 }
226
227 fn has_certs(&self) -> bool {
228 true
229 }
230}
231
232test_for_each_provider! {
233 use std::prelude::v1::*;
234 use super::NoClientSessionStorage;
235 use crate::client::ClientSessionStore;
236 use crate::msgs::enums::NamedGroup;
237 use crate::msgs::handshake::CertificateChain;
238 #[cfg(feature = "tls12")]
239 use crate::msgs::handshake::SessionId;
240 use crate::msgs::persist::Tls13ClientSessionValue;
241 use crate::suites::SupportedCipherSuite;
242 use provider::cipher_suite;
243
244 use pki_types::{ServerName, UnixTime};
245
246 #[test]
247 fn test_noclientsessionstorage_does_nothing() {
248 let c = NoClientSessionStorage {};
249 let name = ServerName::try_from("example.com").unwrap();
250 let now = UnixTime::now();
251
252 c.set_kx_hint(name.clone(), NamedGroup::X25519);
253 assert_eq!(None, c.kx_hint(&name));
254
255 #[cfg(feature = "tls12")]
256 {
257 use crate::msgs::persist::Tls12ClientSessionValue;
258 let SupportedCipherSuite::Tls12(tls12_suite) =
259 cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
260 else {
261 unreachable!()
262 };
263
264 c.set_tls12_session(
265 name.clone(),
266 Tls12ClientSessionValue::new(
267 tls12_suite,
268 SessionId::empty(),
269 Vec::new(),
270 &[],
271 CertificateChain::default(),
272 now,
273 0,
274 true,
275 ),
276 );
277 assert!(c.tls12_session(&name).is_none());
278 c.remove_tls12_session(&name);
279 }
280
281 #[cfg_attr(not(feature = "tls12"), allow(clippy::infallible_destructuring_match))]
282 let tls13_suite = match cipher_suite::TLS13_AES_256_GCM_SHA384 {
283 SupportedCipherSuite::Tls13(inner) => inner,
284 #[cfg(feature = "tls12")]
285 _ => unreachable!(),
286 };
287 c.insert_tls13_ticket(
288 name.clone(),
289 Tls13ClientSessionValue::new(
290 tls13_suite,
291 Vec::new(),
292 &[],
293 CertificateChain::default(),
294 now,
295 0,
296 0,
297 0,
298 ),
299 );
300 assert!(c.take_tls13_ticket(&name).is_none());
301 }
302}