use super::grant::{Value, Extensions, Grant};
use super::{Url, Time};
use super::scope::Scope;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use hmac::{digest::CtOutput, Mac, Hmac};
use rand::{rngs::OsRng, RngCore, thread_rng};
use serde::{Deserialize, Serialize};
use rmp_serde;
/// Produces a string token ("tag") for a grant.
///
/// Implementors receive a caller-supplied counter and the grant itself and
/// either return the generated token or `Err(())` when no token could be
/// produced.
pub trait TagGrant {
    /// Create a token for `grant`.
    ///
    /// `usage` is a counter chosen by the caller; an implementation may fold
    /// it into the token or ignore it entirely.
    fn tag(&mut self, usage: u64, grant: &Grant) -> Result<String, ()>;
}
/// Token generator that emits base64-encoded random bytes.
///
/// The produced tokens carry no information about the grant they were
/// issued for.
pub struct RandomGenerator {
    /// Handle to the operating system's random number source.
    random: OsRng,
    /// Number of raw random bytes drawn for each token (before encoding).
    len: usize,
}
impl RandomGenerator {
    /// Build a generator whose tokens are made from `length` random bytes.
    ///
    /// The resulting token strings are longer than `length` because the
    /// bytes are base64 encoded.
    pub fn new(length: usize) -> RandomGenerator {
        Self {
            random: OsRng {},
            len: length,
        }
    }

    /// Draw `self.len` bytes from the OS random source and base64-encode them.
    ///
    /// # Panics
    ///
    /// Panics if the random source reports an error.
    fn generate(&self) -> String {
        let mut buffer: Vec<u8> = vec![0; self.len];
        // `OsRng` is a copyable handle; a local mutable copy is needed
        // because filling bytes requires `&mut` access.
        let mut source = self.random;
        source
            .try_fill_bytes(&mut buffer)
            .expect("Failed to generate random token");
        STANDARD.encode(buffer)
    }
}
/// Signs grants with a keyed HMAC-SHA256 so that tokens can later be
/// verified (and, for tagged tokens, decoded) using the same key.
pub struct Assertion {
    /// Keyed MAC state; it is cloned for every signature or verification so
    /// the stored instance is never consumed.
    hasher: Hmac<sha2::Sha256>,
}
/// Selects the signature algorithm used by an [`Assertion`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate must not assume
/// the list of variants is complete.
#[non_exhaustive]
pub enum AssertionKind {
    /// HMAC keyed with a secret, using SHA-256 as the hash function.
    HmacSha256,
}
/// Serializable mirror of a `Grant` used as the signed payload.
///
/// Only public extensions are representable. Field order is significant for
/// the serialized form, so it must not be rearranged.
#[derive(Serialize, Deserialize)]
struct SerdeAssertionGrant {
    /// Identifier of the resource owner.
    owner_id: String,

    /// Identifier of the client the grant was issued to.
    client_id: String,

    /// Granted scope, serialized through its string representation.
    #[serde(with = "scope_serde")]
    scope: Scope,

    /// Redirect URI, serialized through its string representation.
    #[serde(with = "url_serde")]
    redirect_uri: Url,

    /// Expiration time, serialized as a whole-second timestamp.
    #[serde(with = "time_serde")]
    until: Time,

    /// Public extension data, keyed by extension name.
    public_extensions: HashMap<String, Option<String>>,
}
/// Wire format of a self-contained token: the serialized payload (field 0)
/// followed by the MAC computed over exactly those payload bytes (field 1).
#[derive(Serialize, Deserialize)]
struct AssertGrant(Vec<u8>, Vec<u8>);
/// Borrowed pairing of an [`Assertion`] with a tag string.
///
/// The tag is embedded into every token signed through this value and is
/// compared against the embedded tag again when a token is extracted.
pub struct TaggedAssertion<'a>(&'a Assertion, &'a str);
impl Assertion {
    /// Construct an assertion that signs with the given secret `key`.
    ///
    /// `kind` selects the signature algorithm; currently the key is always
    /// used to initialise an HMAC-SHA256 instance.
    pub fn new(kind: AssertionKind, key: &[u8]) -> Self {
        match kind {
            AssertionKind::HmacSha256 => Assertion {
                hasher: Hmac::<sha2::Sha256>::new_from_slice(key)
                    .expect("HMAC accepts keys of any length"),
            },
        }
    }

    /// Construct an assertion keyed with 32 freshly generated random bytes.
    ///
    /// The key is not stored anywhere except inside the returned value, so
    /// tokens can only be verified through this instance.
    pub fn ephemeral() -> Self {
        let mut rand_bytes: [u8; 32] = [0; 32];
        thread_rng().fill_bytes(&mut rand_bytes);
        Assertion {
            hasher: Hmac::<sha2::Sha256>::new_from_slice(&rand_bytes)
                .expect("HMAC accepts keys of any length"),
        }
    }

    /// Bind this assertion to `tag`, yielding a signer/verifier whose tokens
    /// embed that tag.
    pub fn tag<'a>(&'a self, tag: &'a str) -> TaggedAssertion<'a> {
        TaggedAssertion(self, tag)
    }

    /// Decode a token produced by `generate_tagged` and verify its MAC.
    ///
    /// Returns the contained grant and the tag embedded in the token, or
    /// `Err(())` if the token is not valid base64, is not a well-formed
    /// payload/signature pair, carries a wrong signature, or its payload
    /// cannot be decoded.
    fn extract(&self, token: &str) -> Result<(Grant, String), ()> {
        let decoded = STANDARD.decode(token).map_err(|_| ())?;
        let assertion: AssertGrant = rmp_serde::from_slice(&decoded).map_err(|_| ())?;

        // Verify the MAC over the raw payload bytes *before* interpreting
        // them, so unauthenticated data never reaches the grant decoder.
        let mut hasher = self.hasher.clone();
        hasher.update(&assertion.0);
        hasher.verify_slice(assertion.1.as_slice()).map_err(|_| ())?;

        // The leading counter only serves to vary the signed bytes; it is
        // not needed once the token has been verified.
        let (_, serde_grant, tag): (u64, SerdeAssertionGrant, String) =
            rmp_serde::from_slice(&assertion.0).map_err(|_| ())?;

        Ok((serde_grant.grant(), tag))
    }

    /// Compute the MAC of `data` with a fresh clone of the keyed hasher.
    fn signature(&self, data: &[u8]) -> CtOutput<hmac::Hmac<sha2::Sha256>> {
        let mut hasher = self.hasher.clone();
        hasher.update(data);
        hasher.finalize()
    }

    /// Produce a token consisting solely of the base64-encoded MAC over the
    /// serialized `(grant, counter)` pair.
    ///
    /// Returns `Err(())` if the grant cannot be represented in serializable
    /// form or if serialization fails.
    fn counted_signature(&self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let serde_grant = SerdeAssertionGrant::try_from(grant)?;
        // Report serialization failures through the existing error channel
        // rather than panicking.
        let tosign = rmp_serde::to_vec(&(serde_grant, counter)).map_err(|_| ())?;
        let signature = self.signature(&tosign);
        Ok(STANDARD.encode(signature.into_bytes()))
    }

    /// Produce a self-contained token: the serialized `(counter, grant, tag)`
    /// payload together with its MAC, base64 encoded.
    ///
    /// Returns `Err(())` if the grant cannot be represented in serializable
    /// form or if serialization fails.
    fn generate_tagged(&self, counter: u64, grant: &Grant, tag: &str) -> Result<String, ()> {
        let serde_grant = SerdeAssertionGrant::try_from(grant)?;
        let tosign = rmp_serde::to_vec(&(counter, serde_grant, tag)).map_err(|_| ())?;
        let signature = self.signature(&tosign);
        let assert = AssertGrant(tosign, signature.into_bytes().to_vec());
        let encoded = rmp_serde::to_vec(&assert).map_err(|_| ())?;
        Ok(STANDARD.encode(encoded))
    }
}
impl<'a> TaggedAssertion<'a> {
pub fn sign(&self, counter: u64, grant: &Grant) -> Result<String, ()> {
self.0.generate_tagged(counter, grant, self.1)
}
pub fn extract<'b>(&self, token: &'b str) -> Result<Grant, ()> {
self.0
.extract(token)
.and_then(|(token, tag)| if tag == self.1 { Ok(token) } else { Err(()) })
}
}
/// Boxed generators forward to the generator inside the box.
impl<'a, T: TagGrant + ?Sized + 'a> TagGrant for Box<T> {
    fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let inner: &mut T = &mut **self;
        T::tag(inner, counter, grant)
    }
}
/// Mutable references to a generator forward to the referenced generator.
impl<'a, T: TagGrant + ?Sized + 'a> TagGrant for &'a mut T {
    fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let inner: &mut T = &mut **self;
        T::tag(inner, counter, grant)
    }
}
impl TagGrant for RandomGenerator {
fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
Ok(self.generate())
}
}
/// Shared references can generate tokens too, since generation only needs
/// `&RandomGenerator`; counter and grant are ignored.
impl<'a> TagGrant for &'a RandomGenerator {
    fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
        let generator: &RandomGenerator = self;
        Ok(generator.generate())
    }
}
/// A reference-counted generator delegates to the shared instance; counter
/// and grant are ignored.
impl TagGrant for Rc<RandomGenerator> {
    fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
        let generator: &RandomGenerator = self;
        Ok(generator.generate())
    }
}
/// An atomically reference-counted generator delegates to the shared
/// instance; counter and grant are ignored.
impl TagGrant for Arc<RandomGenerator> {
    fn tag(&mut self, _: u64, _: &Grant) -> Result<String, ()> {
        let generator: &RandomGenerator = self;
        Ok(generator.generate())
    }
}
impl TagGrant for Assertion {
fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
self.counted_signature(counter, grant)
}
}
/// Shared references can sign as well, since signing only needs
/// `&Assertion`.
impl<'a> TagGrant for &'a Assertion {
    fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let assertion: &Assertion = self;
        assertion.counted_signature(counter, grant)
    }
}
/// A reference-counted assertion delegates to the shared instance.
impl TagGrant for Rc<Assertion> {
    fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let assertion: &Assertion = self;
        assertion.counted_signature(counter, grant)
    }
}
/// An atomically reference-counted assertion delegates to the shared
/// instance.
impl TagGrant for Arc<Assertion> {
    fn tag(&mut self, counter: u64, grant: &Grant) -> Result<String, ()> {
        let assertion: &Assertion = self;
        assertion.counted_signature(counter, grant)
    }
}
/// Serde adapter that stores a `Scope` as its string representation.
mod scope_serde {
    use crate::primitives::scope::Scope;
    use serde::ser::{Serializer};
    use serde::de::{Deserialize, Deserializer, Error};

    /// Write the scope as a string.
    pub fn serialize<S: Serializer>(scope: &Scope, serializer: S) -> Result<S::Ok, S::Error> {
        let rendered: String = scope.to_string();
        serializer.serialize_str(&rendered)
    }

    /// Read a borrowed string and parse it back into a scope; a parse
    /// failure is reported as a custom deserialization error.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Scope, D::Error> {
        <&str>::deserialize(deserializer)
            .and_then(|text| text.parse::<Scope>().map_err(Error::custom))
    }
}
/// Serde adapter that stores a `Url` as its string representation.
mod url_serde {
    use super::Url;
    use serde::ser::{Serializer};
    use serde::de::{Deserialize, Deserializer, Error};

    /// Write the URL as a string.
    pub fn serialize<S: Serializer>(url: &Url, serializer: S) -> Result<S::Ok, S::Error> {
        let rendered: String = url.to_string();
        serializer.serialize_str(&rendered)
    }

    /// Read a borrowed string and parse it back into a URL; a parse failure
    /// is reported as a custom deserialization error.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
        <&str>::deserialize(deserializer)
            .and_then(|text| text.parse::<Url>().map_err(Error::custom))
    }
}
mod time_serde {
use super::Time;
use chrono::{TimeZone, Utc};
use serde::ser::{Serializer};
use serde::de::{Deserialize, Deserializer};
pub fn serialize<S: Serializer>(time: &Time, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_i64(time.timestamp())
}
pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Time, D::Error> {
let as_timestamp: i64 = <i64>::deserialize(deserializer)?;
Ok(Utc.timestamp_opt(as_timestamp, 0).unwrap())
}
}
impl SerdeAssertionGrant {
    /// Build the serializable form of `grant`.
    ///
    /// Fails with `Err(())` if the grant carries any private extension,
    /// because only public extensions are copied into the payload.
    fn try_from(grant: &Grant) -> Result<Self, ()> {
        // Any private extension at all makes the grant unrepresentable.
        if grant.extensions.private().next().is_some() {
            return Err(());
        }

        let public_extensions: HashMap<String, Option<String>> = grant
            .extensions
            .public()
            .into_iter()
            .map(|(name, content)| (name.to_string(), content.map(str::to_string)))
            .collect();

        Ok(SerdeAssertionGrant {
            owner_id: grant.owner_id.clone(),
            client_id: grant.client_id.clone(),
            scope: grant.scope.clone(),
            redirect_uri: grant.redirect_uri.clone(),
            until: grant.until,
            public_extensions,
        })
    }

    /// Convert back into a `Grant`, restoring every stored extension as a
    /// public value.
    fn grant(self) -> Grant {
        let SerdeAssertionGrant {
            owner_id,
            client_id,
            scope,
            redirect_uri,
            until,
            public_extensions,
        } = self;

        let mut extensions = Extensions::new();
        public_extensions
            .into_iter()
            .for_each(|(name, content)| extensions.set_raw(name, Value::public(content)));

        Grant {
            owner_id,
            client_id,
            scope,
            redirect_uri,
            until,
            extensions,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the generators are `Send + Sync + 'static`.
    #[test]
    #[allow(dead_code, unused)]
    fn assert_send_sync_static() {
        // The bound on `T` is the actual assertion; the body does nothing.
        fn uses<T>(arg: T)
        where
            T: Send + Sync + 'static,
        {
        }

        let generator = RandomGenerator::new(16);
        uses(generator);

        let fake_key = [0u8; 16];
        let assertion = Assertion::new(AssertionKind::HmacSha256, &fake_key);
        uses(assertion);
    }
}