1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use crate::base64_decode;
use crate::error::Error;
use crate::key_provider::GoogleKeyProvider;
use crate::key_provider::KeyProvider;
use crate::token::IdPayload;
use crate::token::RequiredClaims;
use crate::token::Token;
use serde::Deserialize;
use serde_derive::{Deserialize, Serialize};
use serde_json;
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
use std::time::UNIX_EPOCH;

/// Configures and creates a [`Client`].
///
/// Obtain one with `ClientBuilder::new` or `Client::builder`, adjust it with the
/// chained setter methods, then call `build`.
pub struct ClientBuilder {
    /// Whether token expiry is enforced; cleared by `unsafe_ignore_expiration`.
    check_expiration: bool,
    /// Client ID that a token's audience must equal.
    client_id: String,
    /// Source of signature-verification keys, shared behind a mutex.
    key_provider: Arc<Mutex<dyn KeyProvider + Send>>,
}

impl ClientBuilder {
    /// Starts a builder for `client_id` that uses [`GoogleKeyProvider`] and has
    /// expiration checking enabled.
    #[must_use]
    pub fn new(client_id: &str) -> ClientBuilder {
        ClientBuilder {
            client_id: client_id.to_owned(),
            key_provider: Arc::new(Mutex::new(GoogleKeyProvider::new())),
            check_expiration: true,
        }
    }

    /// Replaces the key provider used to look up signature-verification keys.
    ///
    /// The provider is moved behind an `Arc<Mutex<_>>`, hence the `Send + 'static`
    /// bounds.
    // `#[must_use]`: the builder is consumed and returned; dropping the result
    // would silently discard the configuration.
    #[must_use]
    pub fn custom_key_provider<T: KeyProvider + Send + 'static>(mut self, provider: T) -> Self {
        self.key_provider = Arc::new(Mutex::new(provider));
        self
    }

    /// Disables the expiration check on verified tokens.
    ///
    /// As the name warns, this makes the client accept expired tokens; only the
    /// expiry comparison is skipped.
    #[must_use]
    pub fn unsafe_ignore_expiration(mut self) -> Self {
        self.check_expiration = false;
        self
    }

    /// Finishes configuration and returns the [`Client`].
    #[must_use]
    pub fn build(self) -> Client {
        Client {
            client_id: self.client_id,
            key_provider: self.key_provider,
            check_expiration: self.check_expiration,
        }
    }
}

/// Verifies Google-issued JWTs for a single client ID.
///
/// Construct with `Client::new` or through a `ClientBuilder`.
pub struct Client {
    /// Whether token expiry is enforced during verification.
    check_expiration: bool,
    /// Client ID that a token's audience must equal.
    client_id: String,
    /// Source of signature-verification keys, shared behind a mutex.
    key_provider: Arc<Mutex<dyn KeyProvider + Send>>,
}

impl Client {
    pub fn builder(client_id: &str) -> ClientBuilder {
        ClientBuilder::new(client_id)
    }

    pub fn new(client_id: &str) -> Client {
        ClientBuilder::new(client_id).build()
    }

    pub fn verify_token_with_payload<P>(&self, token_string: &str) -> Result<Token<P>, Error>
    where
        for<'a> P: Deserialize<'a>,
    {
        let mut segments = token_string.split('.');
        let encoded_header = segments.next().ok_or(Error::InvalidToken)?;
        let encoded_payload = segments.next().ok_or(Error::InvalidToken)?;
        let encoded_signature = segments.next().ok_or(Error::InvalidToken)?;

        let header: Header = serde_json::from_slice(&base64_decode(&encoded_header)?)?;

        let key = match self.key_provider.lock().unwrap().get_key(&header.key_id) {
            Ok(Some(key)) => key,
            Ok(None) => return Err(Error::InvalidToken),
            Err(_) => return Err(Error::RetrieveKeyFailure),
        };
        let signed_body = format!("{}.{}", encoded_header, encoded_payload);
        let signature = base64_decode(&encoded_signature)?;
        key.verify(signed_body.as_bytes(), &signature)?;
        let payload = base64_decode(&encoded_payload)?;
        let claims: RequiredClaims = serde_json::from_slice(&payload)?;

        if claims.get_audience() != self.client_id {
            return Err(Error::InvalidToken);
        }
        let issuer = claims.get_issuer();
        if issuer != "https://accounts.google.com" && issuer != "accounts.google.com" {
            return Err(Error::InvalidToken);
        }
        let current_timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        if self.check_expiration {
            if claims.get_expires_at() < current_timestamp {
                return Err(Error::Expired);
            }
        }
        if claims.get_issued_at() > claims.get_expires_at() {
            return Err(Error::InvalidToken);
        }
        let decoded_payload: P = serde_json::from_slice(&payload)?;
        Ok(Token::new(claims, decoded_payload))
    }

    pub fn verify_token(&self, token_string: &str) -> Result<Token<()>, Error> {
        self.verify_token_with_payload::<()>(token_string)
    }

    pub fn verify_id_token(&self, token_string: &str) -> Result<Token<IdPayload>, Error> {
        self.verify_token_with_payload(token_string)
    }
}

/// The part of a JWT header this module reads: the signing key's identifier.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Header {
    /// Signing key identifier, (de)serialized under the JSON name `kid`.
    #[serde(rename = "kid")] key_id: String,
}