Skip to main content

ic_bn_lib/tls/
sessions.rs

1use std::time::Duration;
2
3use ahash::RandomState;
4use moka::sync::Cache;
5use prometheus::{
6    IntCounterVec, IntGauge, Registry, register_int_counter_vec_with_registry,
7    register_int_gauge_with_registry,
8};
9use rustls::server::StoresServerSessions;
10use tokio::time::interval;
11use zeroize::ZeroizeOnDrop;
12
/// Cache key: the raw session key bytes received through `StoresServerSessions`.
type Key = Vec<u8>;
14
15/// Sessions are considered highly sensitive data, so wipe the memory when
16/// they're removed from storage. We can't do anything with the returned Vec<u8>,
17/// but it's better than nothing.
18#[derive(Debug, PartialEq, Eq, Hash, Clone, ZeroizeOnDrop)]
19struct Val(Vec<u8>);
20
21const fn weigher(k: &Key, v: &Val) -> u32 {
22    (k.len() + v.0.len()) as u32
23}
24
/// TLS server session store for rustls, backed by a size-bounded `moka` cache.
///
/// Stores TLS sessions for TLSv1.2 only.
/// Entries expire after a configurable idle time, and the cache is bounded by
/// the total byte length of its keys and values.
/// `SipHash` is replaced with ~10x faster aHash,
/// see <https://github.com/tkaitchuck/aHash/blob/master/compare/readme.md>
#[derive(Debug)]
pub struct Storage {
    /// Session key -> session value, weighted by byte length.
    cache: Cache<Key, Val, RandomState>,
    /// Prometheus metrics describing cache usage and operations.
    metrics: Metrics,
}
33
34impl Storage {
35    pub fn new(capacity: u64, tti: Duration, registry: &Registry) -> Self {
36        let cache = Cache::builder()
37            .max_capacity(capacity)
38            .time_to_idle(tti)
39            .weigher(weigher)
40            .build_with_hasher(RandomState::default());
41
42        let metrics = Metrics::new(registry);
43        Self { cache, metrics }
44    }
45
46    pub async fn metrics_runner(&self) {
47        let mut interval = interval(Duration::from_secs(1));
48        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
49
50        loop {
51            interval.tick().await;
52            self.metrics.size.set(self.cache.weighted_size() as i64);
53            self.metrics.count.set(self.cache.entry_count() as i64);
54        }
55    }
56}
57
58impl StoresServerSessions for Storage {
59    fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
60        let v = self.cache.get(key).map(|x| x.0.clone());
61        self.metrics.record("get", v.is_some());
62        v
63    }
64
65    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
66        self.cache.insert(key, Val(value));
67        self.metrics.record("put", true);
68        true
69    }
70
71    fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
72        let v = self.cache.remove(key).map(|x| x.0.clone());
73        self.metrics.record("take", v.is_some());
74        v
75    }
76
77    fn can_cache(&self) -> bool {
78        true
79    }
80}
81
/// Prometheus metrics for the TLS session store.
#[derive(Debug)]
pub struct Metrics {
    /// Number of entries in the cache (`tls_session_cache_count`).
    count: IntGauge,
    /// Total weight of the cache, i.e. summed key + value byte lengths
    /// (`tls_session_cache_size`).
    size: IntGauge,
    /// Store operations, labelled by `action` and `found` (`tls_sessions`).
    processed: IntCounterVec,
}
88
89impl Metrics {
90    pub fn new(registry: &Registry) -> Self {
91        Self {
92            count: register_int_gauge_with_registry!(
93                format!("tls_session_cache_count"),
94                format!("Number of TLS sessions in the cache"),
95                registry
96            )
97            .unwrap(),
98
99            size: register_int_gauge_with_registry!(
100                format!("tls_session_cache_size"),
101                format!("Size of TLS sessions in the cache"),
102                registry
103            )
104            .unwrap(),
105
106            processed: register_int_counter_vec_with_registry!(
107                format!("tls_sessions"),
108                format!("Number of TLS sessions that were processed"),
109                &["action", "found"],
110                registry
111            )
112            .unwrap(),
113        }
114    }
115
116    fn record(&self, action: &str, ok: bool) {
117        self.processed
118            .with_label_values(&[action, if ok { "yes" } else { "no" }])
119            .inc();
120    }
121}
122
#[cfg(test)]
mod test {
    use super::*;

    /// Flushes moka's pending maintenance so its counters are exact, then
    /// checks the number of cached entries and their total weight.
    fn assert_usage(storage: &Storage, entries: u64, weight: u64) {
        storage.cache.run_pending_tasks();
        assert_eq!(storage.cache.entry_count(), entries);
        assert_eq!(storage.cache.weighted_size(), weight);
    }

    #[test]
    fn test_storage() {
        let registry = Registry::new();
        let storage = Storage::new(10_000, Duration::from_secs(3600), &registry);

        let big_a = vec![b'a'; 2500];
        let big_b = vec![b'b'; 2500];
        let tiny_b = vec![b'b'];

        // Two 5000-byte entries exactly fill the 10000-byte capacity.
        storage.put(big_a.clone(), big_a.clone());
        assert_usage(&storage, 1, 5000);
        storage.put(big_b.clone(), big_b.clone());
        assert_usage(&storage, 2, 10_000);

        // A third entry is not admitted while the cache is full.
        storage.put(tiny_b.clone(), tiny_b.clone());
        assert_usage(&storage, 2, 10_000);
        assert!(storage.get(&tiny_b).is_none());

        // `take` removes the entry, so a following `get` misses.
        for key in [&big_a, &big_b] {
            assert!(storage.take(key).is_some());
            assert!(storage.get(key).is_none());
        }

        // Nothing is left behind.
        assert_usage(&storage, 0, 0);
    }
}
163}