1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
use crate::base64_decode; use crate::error::Error; use crate::key_provider::GoogleKeyProvider; use crate::key_provider::KeyProvider; use crate::token::IdPayload; use crate::token::RequiredClaims; use crate::token::Token; use serde::Deserialize; use serde_derive::{Deserialize, Serialize}; use serde_json; use std::sync::{Arc, Mutex}; use std::time::SystemTime; use std::time::UNIX_EPOCH; pub struct ClientBuilder { client_id: String, key_provider: Arc<Mutex<dyn KeyProvider + Send>>, check_expiration: bool, } impl ClientBuilder { pub fn new(client_id: &str) -> ClientBuilder { ClientBuilder { client_id: client_id.to_owned(), key_provider: Arc::new(Mutex::new(GoogleKeyProvider::new())), check_expiration: true, } } pub fn custom_key_provider<T: KeyProvider + Send + 'static>(mut self, provider: T) -> Self { self.key_provider = Arc::new(Mutex::new(provider)); self } pub fn unsafe_ignore_expiration(mut self) -> Self { self.check_expiration = false; self } pub fn build(self) -> Client { Client { client_id: self.client_id, key_provider: self.key_provider, check_expiration: self.check_expiration, } } } pub struct Client { client_id: String, key_provider: Arc<Mutex<dyn KeyProvider + Send>>, check_expiration: bool, } impl Client { pub fn builder(client_id: &str) -> ClientBuilder { ClientBuilder::new(client_id) } pub fn new(client_id: &str) -> Client { ClientBuilder::new(client_id).build() } pub fn verify_token_with_payload<P>(&self, token_string: &str) -> Result<Token<P>, Error> where for<'a> P: Deserialize<'a>, { let mut segments = token_string.split('.'); let encoded_header = segments.next().ok_or(Error::InvalidToken)?; let encoded_payload = segments.next().ok_or(Error::InvalidToken)?; let encoded_signature = segments.next().ok_or(Error::InvalidToken)?; let header: Header = serde_json::from_slice(&base64_decode(&encoded_header)?)?; let key = match self.key_provider.lock().unwrap().get_key(&header.key_id) { Ok(Some(key)) => key, Ok(None) => return Err(Error::InvalidToken), Err(_) => return Err(Error::RetrieveKeyFailure), }; let signed_body = format!("{}.{}", encoded_header, encoded_payload); let signature = base64_decode(&encoded_signature)?; key.verify(signed_body.as_bytes(), &signature)?; let payload = base64_decode(&encoded_payload)?; let claims: RequiredClaims = serde_json::from_slice(&payload)?; if claims.get_audience() != self.client_id { return Err(Error::InvalidToken); } let issuer = claims.get_issuer(); if issuer != "https://accounts.google.com" && issuer != "accounts.google.com" { return Err(Error::InvalidToken); } let current_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); if self.check_expiration { if claims.get_expires_at() < current_timestamp { return Err(Error::Expired); } } if claims.get_issued_at() > claims.get_expires_at() { return Err(Error::InvalidToken); } let decoded_payload: P = serde_json::from_slice(&payload)?; Ok(Token::new(claims, decoded_payload)) } pub fn verify_token(&self, token_string: &str) -> Result<Token<()>, Error> { self.verify_token_with_payload::<()>(token_string) } pub fn verify_id_token(&self, token_string: &str) -> Result<Token<IdPayload>, Error> { self.verify_token_with_payload(token_string) } } #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] pub struct Header { #[serde(rename = "kid")] key_id: String, }