1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
use aes_gcm::{
    aead::{Aead, OsRng},
    AeadCore, Aes256Gcm, Key, KeyInit,
};
use base64ct::{Base64, Base64Url, Encoding};

use crate::{
    filter::Filter,
    ordering::Ordering,
    query::{
        error::{QueryError, QueryResult},
        page_token::utility::get_page_filter,
    },
    schema::SchemaMapped,
};

use super::{utility::make_page_key, FilterPageToken, PageTokenBuilder};

/// Length in bytes of the AES-GCM nonce that prefixes every encrypted page token (96 bits).
const NONCE_LENGTH: usize = 96 / 8;

/// AES-256-GCM page token builder.
///
/// Each page token is encrypted with AES-256-GCM under a key derived from the query's
/// filter and ordering. This is useful for ensuring that a page token is only accepted
/// for the same paging rules it was generated for. The encrypted bytes are Base64-encoded,
/// using the URL-safe alphabet when `url_safe` is set.
#[derive(Debug, Clone, Copy)]
pub struct Aes256PageTokenBuilder {
    /// Encode/decode tokens with the URL-safe Base64 alphabet instead of the standard one.
    url_safe: bool,
}

impl Aes256PageTokenBuilder {
    pub fn new(url_safe: bool) -> Self {
        Self { url_safe }
    }
}

/// Builds the AES-256-GCM cipher for a query from a 32-byte key derived from its
/// `filter` and `ordering`.
///
/// `parse` and `build_next` both go through this helper so they always agree on the key.
/// NOTE(review): rejecting tokens built for other queries relies on `make_page_key`
/// producing distinct keys for distinct filter/ordering pairs.
fn page_cipher(filter: &Filter, ordering: &Ordering) -> Aes256Gcm {
    let key = make_page_key::<32>(filter, ordering);
    let key: &Key<Aes256Gcm> = (&key).into();
    Aes256Gcm::new(key)
}

/// Decodes a Base64 page token, using the URL-safe alphabet when `url_safe` is set.
///
/// Any decoding failure is reported as [`QueryError::InvalidPageToken`].
fn decode_page_token(url_safe: bool, page_token: &str) -> QueryResult<Vec<u8>> {
    let decoded = if url_safe {
        Base64Url::decode_vec(page_token)
    } else {
        Base64::decode_vec(page_token)
    };
    decoded.map_err(|_| QueryError::InvalidPageToken)
}

/// Encodes raw token bytes as Base64, using the URL-safe alphabet when `url_safe` is set.
fn encode_page_token(url_safe: bool, bytes: &[u8]) -> String {
    if url_safe {
        Base64Url::encode_string(bytes)
    } else {
        Base64::encode_string(bytes)
    }
}

impl PageTokenBuilder for Aes256PageTokenBuilder {
    type PageToken = FilterPageToken;

    /// Decrypts `page_token` back into the page filter it was built from.
    ///
    /// The token must be Base64 (URL-safe alphabet when `url_safe` is set) of
    /// `nonce || ciphertext`, as produced by `build_next` for the same `filter` and
    /// `ordering`.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::InvalidPageToken`] if the token is not valid Base64, is too
    /// short to hold a nonce plus ciphertext, fails AES-GCM decryption (for example because
    /// it was built for a different filter or ordering), is not UTF-8, or does not parse as
    /// a [`Filter`].
    fn parse(
        &self,
        filter: &Filter,
        ordering: &Ordering,
        page_token: &str,
    ) -> QueryResult<Self::PageToken> {
        let decoded = decode_page_token(self.url_safe, page_token)?;

        // Cheap structural check before deriving the key: a full nonce must be followed by
        // at least one byte of ciphertext.
        if decoded.len() <= NONCE_LENGTH {
            return Err(QueryError::InvalidPageToken);
        }
        let (nonce_buf, encrypted) = decoded.split_at(NONCE_LENGTH);
        let nonce = nonce_buf
            .try_into()
            .map_err(|_| QueryError::InvalidPageToken)?;

        let plaintext = page_cipher(filter, ordering)
            .decrypt(nonce, encrypted)
            .map_err(|_| QueryError::InvalidPageToken)?;
        let plaintext = String::from_utf8(plaintext).map_err(|_| QueryError::InvalidPageToken)?;
        let page_filter = Filter::parse(&plaintext).map_err(|_| QueryError::InvalidPageToken)?;

        Ok(Self::PageToken {
            filter: page_filter,
        })
    }

    /// Builds an encrypted page token holding the page filter derived from `ordering` and
    /// `next_item`.
    ///
    /// The page filter is serialized to text and encrypted with a fresh random nonce under
    /// the key for `filter`/`ordering`. The result is Base64-encoded as `nonce || ciphertext`.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::PageTokenFailure`] if the derived page filter is empty or if
    /// encryption fails.
    fn build_next<T: SchemaMapped>(
        &self,
        filter: &Filter,
        ordering: &Ordering,
        next_item: &T,
    ) -> QueryResult<String> {
        let page_filter = get_page_filter(ordering, next_item);
        if page_filter.is_empty() {
            return Err(QueryError::PageTokenFailure);
        }
        let plaintext = page_filter.to_string();

        // 96-bit nonce from the OS RNG; unique per message.
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        // Report failure through the function's error type instead of panicking.
        let encrypted = page_cipher(filter, ordering)
            .encrypt(&nonce, plaintext.as_bytes())
            .map_err(|_| QueryError::PageTokenFailure)?;

        // Lay the token out as `nonce || ciphertext` (the layout `parse` splits apart)
        // in one pre-sized buffer instead of shifting the ciphertext to prepend the nonce.
        let mut token = Vec::with_capacity(NONCE_LENGTH + encrypted.len());
        token.extend_from_slice(&nonce);
        token.extend_from_slice(&encrypted);

        Ok(encode_page_token(self.url_safe, &token))
    }
}

#[cfg(test)]
mod tests {
    use crate::testing::schema::UserItem;

    use super::*;

    /// Sample row used as the "next item" when building tokens.
    fn john() -> UserItem {
        UserItem {
            id: "1337".into(),
            display_name: "John".into(),
            age: 14000,
        }
    }

    #[test]
    fn it_works() {
        let b = Aes256PageTokenBuilder::new(true);
        let filter = Filter::parse(r#"displayName = "John""#).unwrap();
        let ordering = Ordering::parse("id desc, age desc").unwrap();
        let page_token = b.build_next(&filter, &ordering, &john()).unwrap();
        assert!(page_token.trim().len() > NONCE_LENGTH);
        // The page filter must not appear in the token as plaintext.
        assert!(!page_token.contains(r#"id <= "1337""#));
        let parsed = b.parse(&filter, &ordering, &page_token).unwrap();
        assert_eq!(
            parsed.filter.to_string(),
            r#"id <= "1337" AND age <= 14000"#
        );
    }

    #[test]
    fn round_trips_with_standard_alphabet() {
        let b = Aes256PageTokenBuilder::new(false);
        let filter = Filter::parse(r#"displayName = "John""#).unwrap();
        let ordering = Ordering::parse("id desc, age desc").unwrap();
        let page_token = b.build_next(&filter, &ordering, &john()).unwrap();
        let parsed = b.parse(&filter, &ordering, &page_token).unwrap();
        assert_eq!(
            parsed.filter.to_string(),
            r#"id <= "1337" AND age <= 14000"#
        );
    }

    #[test]
    fn errors() {
        let b = Aes256PageTokenBuilder::new(true);

        // Build a token for one set of query parameters...
        let filter = Filter::parse("id=1").unwrap();
        let ordering = Ordering::parse("age desc").unwrap();
        let page_token = b.build_next(&filter, &ordering, &john()).unwrap();
        let parsed = b.parse(&filter, &ordering, &page_token).unwrap();
        assert_eq!(parsed.filter.to_string(), "age <= 14000");

        // ...and check it is rejected when parsed with a different filter or ordering.
        assert_eq!(
            b.parse(
                &Filter::parse("id=2").unwrap(),
                &Ordering::parse("age desc").unwrap(),
                &page_token
            )
            .unwrap_err(),
            QueryError::InvalidPageToken
        );
        assert_eq!(
            b.parse(
                &Filter::parse("id=1").unwrap(),
                &Ordering::parse("age asc").unwrap(),
                &page_token
            )
            .unwrap_err(),
            QueryError::InvalidPageToken
        );
    }

    #[test]
    fn rejects_malformed_tokens() {
        let b = Aes256PageTokenBuilder::new(true);
        let filter = Filter::parse("id=1").unwrap();
        let ordering = Ordering::parse("age desc").unwrap();

        // Empty input decodes to zero bytes, which cannot hold a nonce.
        assert_eq!(
            b.parse(&filter, &ordering, "").unwrap_err(),
            QueryError::InvalidPageToken
        );
        // Not Base64 at all.
        assert_eq!(
            b.parse(&filter, &ordering, "not a token!").unwrap_err(),
            QueryError::InvalidPageToken
        );
        // Valid Base64, but only a nonce with no ciphertext after it.
        let short = Base64Url::encode_string(&[0u8; NONCE_LENGTH]);
        assert_eq!(
            b.parse(&filter, &ordering, &short).unwrap_err(),
            QueryError::InvalidPageToken
        );

        // A genuine token with one bit of its last byte flipped.
        let page_token = b.build_next(&filter, &ordering, &john()).unwrap();
        let mut raw = Base64Url::decode_vec(&page_token).unwrap();
        *raw.last_mut().unwrap() ^= 0x01;
        let tampered = Base64Url::encode_string(&raw);
        assert_eq!(
            b.parse(&filter, &ordering, &tampered).unwrap_err(),
            QueryError::InvalidPageToken
        );
    }
}