1use super::grant::{Value, Extensions, Grant};
14use super::{Url, Time};
15use super::scope::Scope;
16
17use std::collections::HashMap;
18use std::rc::Rc;
19use std::sync::Arc;
20
21use base64::Engine;
22use base64::engine::general_purpose::STANDARD;
23use hmac::{digest::CtOutput, Mac, Hmac};
24use rand::{rngs::OsRng, RngCore, thread_rng};
25use serde::{Deserialize, Serialize};
26use rmp_serde;
27
28pub trait TagGrant {
45 fn tag(&mut self, usage: u64, grant: &Grant) -> Result<String, ()>;
47}
48
49pub struct RandomGenerator {
54 random: OsRng,
55 len: usize,
56}
57
58impl RandomGenerator {
59 pub fn new(length: usize) -> RandomGenerator {
61 RandomGenerator {
62 random: OsRng {},
63 len: length,
64 }
65 }
66
67 fn generate(&self) -> String {
68 let mut result = vec![0; self.len];
69 let mut rnd = self.random;
70 rnd.try_fill_bytes(result.as_mut_slice())
71 .expect("Failed to generate random token");
72
73 STANDARD.encode(result)
74 }
75}
76
77pub struct Assertion {
87 hasher: Hmac<sha2::Sha256>,
88}
89
/// Selects the algorithm an `Assertion` signs with.
///
/// Marked `#[non_exhaustive]`, so further kinds may be added later.
#[non_exhaustive]
pub enum AssertionKind {
    /// HMAC over SHA-256, keyed with the secret passed to `Assertion::new`.
    HmacSha256,
}
99
100#[derive(Serialize, Deserialize)]
101struct SerdeAssertionGrant {
102 owner_id: String,
104
105 client_id: String,
107
108 #[serde(with = "scope_serde")]
110 scope: Scope,
111
112 #[serde(with = "url_serde")]
115 redirect_uri: Url,
116
117 #[serde(with = "time_serde")]
119 until: Time,
120
121 public_extensions: HashMap<String, Option<String>>,
123}
124
125#[derive(Serialize, Deserialize)]
126struct AssertGrant(Vec<u8>, Vec<u8>);
127
128pub struct TaggedAssertion<'a>(&'a Assertion, &'a str);
130
131impl Assertion {
132 pub fn new(kind: AssertionKind, key: &[u8]) -> Self {
144 match kind {
145 AssertionKind::HmacSha256 => Assertion {
146 hasher: Hmac::<sha2::Sha256>::new_from_slice(key).unwrap(),
147 },
148 }
149 }
150
151 pub fn ephemeral() -> Self {
153 let mut rand_bytes: [u8; 32] = [0; 32];
155 thread_rng().fill_bytes(&mut rand_bytes);
156 Assertion {
157 hasher: Hmac::<sha2::Sha256>::new_from_slice(&rand_bytes).unwrap(),
158 }
159 }
160
161 pub fn tag<'a>(&'a self, tag: &'a str) -> TaggedAssertion<'a> {
163 TaggedAssertion(self, tag)
164 }
165
166 fn extract<'a>(&self, token: &'a str) -> Result<(Grant, String), ()> {
167 let decoded = STANDARD.decode(token).map_err(|_| ())?;
168 let assertion: AssertGrant = rmp_serde::from_slice(&decoded).map_err(|_| ())?;
169
170 let mut hasher = self.hasher.clone();
171 hasher.update(&assertion.0);
172 hasher.verify_slice(assertion.1.as_slice()).map_err(|_| ())?;
173
174 let (_, serde_grant, tag): (u64, SerdeAssertionGrant, String) =
175 rmp_serde::from_slice(&assertion.0).map_err(|_| ())?;
176
177 Ok((serde_grant.grant(), tag))
178 }
179
180 fn signature(&self, data: &[u8]) -> CtOutput<hmac::Hmac<sha2::Sha256>> {
181 let mut hasher = self.hasher.clone();
182 hasher.update(data);
183 hasher.finalize()
184 }
185
186 fn counted_signature(&self, counter: u64, grant: &Grant) -> Result<String, ()> {
187 let serde_grant = SerdeAssertionGrant::try_from(grant)?;
188 let tosign = rmp_serde::to_vec(&(serde_grant, counter)).unwrap();
189 let signature = self.signature(&tosign);
190 Ok(STANDARD.encode(signature.into_bytes()))
191 }
192
193 fn generate_tagged(&self, counter: u64, grant: &Grant, tag: &str) -> Result<String, ()> {
194 let serde_grant = SerdeAssertionGrant::try_from(grant)?;
195 let tosign = rmp_serde::to_vec(&(counter, serde_grant, tag)).unwrap();
196 let signature = self.signature(&tosign);
197 let assert = AssertGrant(tosign, signature.into_bytes().to_vec());
198
199 Ok(STANDARD.encode(rmp_serde::to_vec(&assert).unwrap()))
200 }
201}
202
203impl<'a> TaggedAssertion<'a> {
204 pub fn sign(&self, counter: u64, grant: &Grant) -> Result<String, ()> {
212 self.0.generate_tagged(counter, grant, self.1)
213 }
214
215 pub fn extract<'b>(&self, token: &'b str) -> Result<Grant, ()> {
220 self.0
221 .extract(token)
222 .and_then(|(token, tag)| if tag == self.1 { Ok(token) } else { Err(()) })
223 }
224}
225
226impl<'a, T: TagGrant + ?Sized + 'a> TagGrant for Box<T> {
227 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
228 (&mut **self).tag(counter, grant)
229 }
230}
231
232impl<'a, T: TagGrant + ?Sized + 'a> TagGrant for &'a mut T {
233 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
234 (&mut **self).tag(counter, grant)
235 }
236}
237
238impl TagGrant for RandomGenerator {
239 fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
240 Ok(self.generate())
241 }
242}
243
244impl<'a> TagGrant for &'a RandomGenerator {
245 fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
246 Ok(self.generate())
247 }
248}
249
250impl TagGrant for Rc<RandomGenerator> {
251 fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
252 Ok(self.generate())
253 }
254}
255
256impl TagGrant for Arc<RandomGenerator> {
257 fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
258 Ok(self.generate())
259 }
260}
261
262impl TagGrant for Assertion {
263 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
264 self.counted_signature(counter, grant)
265 }
266}
267
268impl<'a> TagGrant for &'a Assertion {
269 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
270 self.counted_signature(counter, grant)
271 }
272}
273
274impl TagGrant for Rc<Assertion> {
275 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
276 self.counted_signature(counter, grant)
277 }
278}
279
280impl TagGrant for Arc<Assertion> {
281 fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
282 self.counted_signature(counter, grant)
283 }
284}
285
286mod scope_serde {
287 use crate::primitives::scope::Scope;
288
289 use serde::ser::{Serializer};
290 use serde::de::{Deserialize, Deserializer, Error};
291
292 pub fn serialize<S: Serializer>(scope: &Scope, serializer: S) -> Result<S::Ok, S::Error> {
293 serializer.serialize_str(&scope.to_string())
294 }
295
296 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Scope, D::Error> {
297 let as_string: &str = <&str>::deserialize(deserializer)?;
298 as_string.parse().map_err(Error::custom)
299 }
300}
301
302mod url_serde {
303 use super::Url;
304
305 use serde::ser::{Serializer};
306 use serde::de::{Deserialize, Deserializer, Error};
307
308 pub fn serialize<S: Serializer>(url: &Url, serializer: S) -> Result<S::Ok, S::Error> {
309 serializer.serialize_str(&url.to_string())
310 }
311
312 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
313 let as_string: &str = <&str>::deserialize(deserializer)?;
314 as_string.parse().map_err(Error::custom)
315 }
316}
317
318mod time_serde {
319 use super::Time;
320 use chrono::{TimeZone, Utc};
321
322 use serde::ser::{Serializer};
323 use serde::de::{Deserialize, Deserializer};
324
325 pub fn serialize<S: Serializer>(time: &Time, serializer: S) -> Result<S::Ok, S::Error> {
326 serializer.serialize_i64(time.timestamp())
327 }
328
329 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Time, D::Error> {
330 let as_timestamp: i64 = <i64>::deserialize(deserializer)?;
331 Ok(Utc.timestamp_opt(as_timestamp, 0).unwrap())
332 }
333}
334
335impl SerdeAssertionGrant {
336 fn try_from(grant: &Grant) -> Result<Self, ()> {
337 let mut public_extensions: HashMap<String, Option<String>> = HashMap::new();
338
339 if grant.extensions.private().any(|_| true) {
340 return Err(());
341 }
342
343 for (name, content) in grant.extensions.public() {
344 public_extensions.insert(name.to_string(), content.map(str::to_string));
345 }
346
347 Ok(SerdeAssertionGrant {
348 owner_id: grant.owner_id.clone(),
349 client_id: grant.client_id.clone(),
350 scope: grant.scope.clone(),
351 redirect_uri: grant.redirect_uri.clone(),
352 until: grant.until,
353 public_extensions,
354 })
355 }
356
357 fn grant(self) -> Grant {
358 let mut extensions = Extensions::new();
359 for (name, content) in self.public_extensions.into_iter() {
360 extensions.set_raw(name, Value::public(content))
361 }
362 Grant {
363 owner_id: self.owner_id,
364 client_id: self.client_id,
365 scope: self.scope,
366 redirect_uri: self.redirect_uri,
367 until: self.until,
368 extensions,
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that both generators satisfy `Send + Sync + 'static`.
    #[test]
    #[allow(dead_code, unused)]
    fn assert_send_sync_static() {
        // Only the bounds matter here; the argument itself is never inspected.
        fn uses<T>(arg: T)
        where
            T: Send + Sync + 'static,
        {
        }

        let generator = RandomGenerator::new(16);
        let _ = uses(generator);

        let fake_key = [0u8; 16];
        let assertion = Assertion::new(AssertionKind::HmacSha256, &fake_key);
        let _ = uses(assertion);
    }
}