ic_bn_lib/tls/
sessions.rs1use std::time::Duration;
2
3use ahash::RandomState;
4use moka::sync::Cache;
5use prometheus::{
6 IntCounterVec, IntGauge, Registry, register_int_counter_vec_with_registry,
7 register_int_gauge_with_registry,
8};
9use rustls::server::StoresServerSessions;
10use tokio::time::interval;
11use zeroize::ZeroizeOnDrop;
12
13type Key = Vec<u8>;
14
15#[derive(Debug, PartialEq, Eq, Hash, Clone, ZeroizeOnDrop)]
19struct Val(Vec<u8>);
20
21const fn weigher(k: &Key, v: &Val) -> u32 {
22 (k.len() + v.0.len()) as u32
23}
24
25#[derive(Debug)]
29pub struct Storage {
30 cache: Cache<Key, Val, RandomState>,
31 metrics: Metrics,
32}
33
34impl Storage {
35 pub fn new(capacity: u64, tti: Duration, registry: &Registry) -> Self {
36 let cache = Cache::builder()
37 .max_capacity(capacity)
38 .time_to_idle(tti)
39 .weigher(weigher)
40 .build_with_hasher(RandomState::default());
41
42 let metrics = Metrics::new(registry);
43 Self { cache, metrics }
44 }
45
46 pub async fn metrics_runner(&self) {
47 let mut interval = interval(Duration::from_secs(1));
48 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
49
50 loop {
51 interval.tick().await;
52 self.metrics.size.set(self.cache.weighted_size() as i64);
53 self.metrics.count.set(self.cache.entry_count() as i64);
54 }
55 }
56}
57
58impl StoresServerSessions for Storage {
59 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
60 let v = self.cache.get(key).map(|x| x.0.clone());
61 self.metrics.record("get", v.is_some());
62 v
63 }
64
65 fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
66 self.cache.insert(key, Val(value));
67 self.metrics.record("put", true);
68 true
69 }
70
71 fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
72 let v = self.cache.remove(key).map(|x| x.0.clone());
73 self.metrics.record("take", v.is_some());
74 v
75 }
76
77 fn can_cache(&self) -> bool {
78 true
79 }
80}
81
82#[derive(Debug)]
83pub struct Metrics {
84 count: IntGauge,
85 size: IntGauge,
86 processed: IntCounterVec,
87}
88
89impl Metrics {
90 pub fn new(registry: &Registry) -> Self {
91 Self {
92 count: register_int_gauge_with_registry!(
93 format!("tls_session_cache_count"),
94 format!("Number of TLS sessions in the cache"),
95 registry
96 )
97 .unwrap(),
98
99 size: register_int_gauge_with_registry!(
100 format!("tls_session_cache_size"),
101 format!("Size of TLS sessions in the cache"),
102 registry
103 )
104 .unwrap(),
105
106 processed: register_int_counter_vec_with_registry!(
107 format!("tls_sessions"),
108 format!("Number of TLS sessions that were processed"),
109 &["action", "found"],
110 registry
111 )
112 .unwrap(),
113 }
114 }
115
116 fn record(&self, action: &str, ok: bool) {
117 self.processed
118 .with_label_values(&[action, if ok { "yes" } else { "no" }])
119 .inc();
120 }
121}
122
#[cfg(test)]
mod test {
    use super::*;

    /// Runs moka's pending maintenance, then checks the entry count and total weight.
    fn assert_cache_state(storage: &Storage, entries: u64, weight: u64) {
        storage.cache.run_pending_tasks();
        assert_eq!(storage.cache.entry_count(), entries);
        assert_eq!(storage.cache.weighted_size(), weight);
    }

    #[test]
    fn test_storage() {
        let storage = Storage::new(10_000, Duration::from_secs(3600), &Registry::new());

        let big_a = vec![b'a'; 2500];
        let big_b = vec![b'b'; 2500];
        let tiny = vec![b'b'];

        // Each 2500-byte key stored with an equal-sized value weighs 5000.
        storage.put(big_a.clone(), big_a.clone());
        assert_cache_state(&storage, 1, 5_000);

        storage.put(big_b.clone(), big_b.clone());
        assert_cache_state(&storage, 2, 10_000);

        // The cache is now at capacity, and the tiny entry is not retained.
        storage.put(tiny.clone(), tiny.clone());
        assert_cache_state(&storage, 2, 10_000);
        assert!(storage.get(&tiny).is_none());

        // `take` removes the entry, so a later `get` misses.
        assert!(storage.take(&big_a).is_some());
        assert!(storage.get(&big_a).is_none());
        assert!(storage.take(&big_b).is_some());
        assert!(storage.get(&big_b).is_none());

        assert_cache_state(&storage, 0, 0);
    }
}