Skip to main content

rustls_graviola/
ticketer.rs

1use core::sync::atomic::{AtomicUsize, Ordering};
2use std::fmt;
3use std::sync::Arc;
4
5use graviola::{aead, random};
6use rustls::crypto::GetRandomFailed;
7use rustls::server::ProducesTickets;
8use rustls::{Error, TicketRotator};
9
10/// The default ticketer.
11pub struct Ticketer;
12
13impl Ticketer {
14    /// Make a new ticketer.
15    ///
16    /// Tickets are encrypted with XChaCha20Poly1305.
17    /// Ticket keys are rotated every 6 hours.
18    #[allow(clippy::new_ret_no_self)]
19    pub fn new() -> Result<Arc<dyn ProducesTickets>, Error> {
20        Ok(Arc::new(TicketRotator::new(
21            ONE_TICKET_LIFETIME_SECS,
22            make_ticket_generator,
23        )?))
24    }
25}
26
27fn make_ticket_generator() -> Result<Box<dyn ProducesTickets>, GetRandomFailed> {
28    Ok(Box::new(XChaCha20Ticketer::new()?))
29}
30
31struct XChaCha20Ticketer {
32    key: aead::XChaCha20Poly1305,
33    key_name: [u8; 16],
34    lifetime: u32,
35    maximum_ciphertext_len: AtomicUsize,
36}
37
38impl XChaCha20Ticketer {
39    fn new() -> Result<Self, GetRandomFailed> {
40        let mut key = [0u8; 32];
41        let mut key_name = [0u8; 16];
42
43        random::fill(&mut key).map_err(|_| GetRandomFailed)?;
44        random::fill(&mut key_name).map_err(|_| GetRandomFailed)?;
45
46        let key = aead::XChaCha20Poly1305::new(key);
47
48        Ok(Self {
49            key,
50            key_name,
51            lifetime: ONE_TICKET_LIFETIME_SECS,
52            maximum_ciphertext_len: AtomicUsize::new(0),
53        })
54    }
55}
56
57impl ProducesTickets for XChaCha20Ticketer {
58    fn enabled(&self) -> bool {
59        true
60    }
61
62    fn lifetime(&self) -> u32 {
63        self.lifetime
64    }
65
66    fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
67        let mut nonce = [0u8; 24];
68        random::fill(&mut nonce).ok()?;
69
70        // wire format is:
71        // - key_name [u8; 16]
72        // - nonce [u8; 24]
73        // - ciphertext [u8; n]
74        // - tag [u8; 16]
75        //
76        // aad is key_name
77
78        let mut tag = [0u8; 16];
79        let mut res =
80            Vec::with_capacity(self.key_name.len() + nonce.len() + message.len() + tag.len());
81        res.extend(&self.key_name);
82        res.extend(&nonce);
83        res.extend(message);
84
85        self.key.encrypt(
86            &nonce,
87            &self.key_name,
88            &mut res[self.key_name.len() + nonce.len()..],
89            &mut tag,
90        );
91        res.extend(tag);
92
93        self.maximum_ciphertext_len
94            .fetch_max(res.len(), Ordering::SeqCst);
95
96        Some(res)
97    }
98
99    fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
100        if ciphertext.len() > self.maximum_ciphertext_len.load(Ordering::SeqCst) {
101            return None;
102        }
103
104        let plain_len = ciphertext
105            .len()
106            .saturating_sub(self.key_name.len() + 24 + 16);
107
108        if plain_len == 0 {
109            return None;
110        }
111
112        let (alleged_key_name, rest) = ciphertext.split_at(self.key_name.len());
113
114        // nb. key_name is public data
115        if alleged_key_name != self.key_name {
116            return None;
117        }
118
119        let (nonce, rest) = rest.split_at(24);
120        let nonce = nonce.try_into().unwrap();
121        let (plain, alleged_tag) = rest.split_at(plain_len);
122        let mut plain = plain.to_vec();
123
124        self.key
125            .decrypt(&nonce, alleged_key_name, &mut plain, alleged_tag)
126            .ok()?;
127        Some(plain)
128    }
129}
130
131impl fmt::Debug for XChaCha20Ticketer {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        f.debug_struct("XChaCha20Ticketer")
134            .field("lifetime", &self.lifetime)
135            .finish_non_exhaustive()
136    }
137}
138
/// Seconds each ticket key is used before the rotator replaces it (six hours).
const ONE_TICKET_LIFETIME_SECS: u32 = 6 * 3600;
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() {
        let ticketer = Ticketer::new().unwrap();
        let sealed = ticketer.encrypt(b"hello").unwrap();
        assert_eq!(ticketer.decrypt(&sealed).unwrap(), b"hello");

        assert!(ticketer.enabled());
        // The rotator advertises twice the per-key lifetime.
        assert_eq!(ticketer.lifetime(), ONE_TICKET_LIFETIME_SECS * 2);
        println!("{ticketer:?}");
    }

    #[test]
    fn make_generator() {
        let generator = make_ticket_generator().unwrap();
        assert!(generator.enabled());
        assert_eq!(generator.lifetime(), ONE_TICKET_LIFETIME_SECS);
        println!("{generator:?}");
    }

    #[test]
    fn length_checks() {
        let ticketer = Ticketer::new().unwrap();
        // Inputs that are empty or far too short must be refused outright.
        assert_eq!(ticketer.decrypt(b""), None);
        assert_eq!(ticketer.decrypt(b"a"), None);

        let sealed = ticketer.encrypt(b"a").unwrap();
        assert_eq!(ticketer.decrypt(&sealed).unwrap(), b"a");
        // Dropping the final byte breaks the tag, so decryption must fail.
        let truncated = &sealed[..sealed.len() - 1];
        assert_eq!(ticketer.decrypt(truncated), None);
    }

    #[test]
    fn non_malleable() {
        let ticketer = Ticketer::new().unwrap();
        let sealed = ticketer.encrypt(b"hello").unwrap();

        // Flipping a single bit anywhere in the ticket must cause rejection.
        for idx in 0..sealed.len() {
            let mut tampered = sealed.clone();
            tampered[idx] ^= 1;
            assert_eq!(ticketer.decrypt(&tampered), None);
        }
    }
}