use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use tracing::error;
use crate::{
code::Code,
config::{ClientID, ClientSecret, Config, RedirectURI, TokenEndPoint},
error::Error,
nonce::Nonce,
refresh_token::RefreshToken,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IDToken {
pub iss: String,
pub aud: String,
pub sub: String,
pub azp: Option<String>,
pub email: Option<String>,
pub email_verified: Option<bool>,
pub given_name: Option<String>,
pub family_name: Option<String>,
pub name: Option<String>,
pub picture: Option<String>,
pub at_hash: Option<String>,
pub iat: u32,
pub exp: u32,
pub nonce: Option<Nonce>,
}
impl IDToken {
pub fn from_id_token_raw(id_token: &IDTokenRaw) -> Result<Self, Error> {
let split: Vec<_> = id_token.0.split(".").collect();
if split.len() != 3 {
return Err(Error::Decode);
}
let bytes = BASE64_URL_SAFE_NO_PAD.decode(split[1]).map_err(|e| {
error!("Failed to decode IDToken: {}", e);
Error::Decode
})?;
let id_token = serde_json::from_slice::<IDToken>(&bytes).map_err(|e| {
error!("Failed to deserialize IDToken: {}", e);
Error::Deserialize
})?;
Ok(id_token)
}
}
/// Parameters for exchanging an authorization code at the token endpoint.
///
/// Everything except the authorization code itself is borrowed from the
/// [`Config`] the request was built from.
#[derive(Debug, Clone)]
pub struct IDTokenRequest<'a> {
    /// Endpoint the exchange is sent to.
    token_endpoint: &'a TokenEndPoint,
    /// Authorization code being exchanged (owned by the request).
    code: Code,
    client_id: &'a ClientID,
    client_secret: &'a ClientSecret,
    redirect_uri: &'a RedirectURI,
    /// Grant type sent with the request; [`IDTokenRequest::new`] sets it to
    /// `"authorization_code"`.
    grant_type: &'a str,
}

impl<'a> IDTokenRequest<'a> {
    /// Grant type used for an authorization-code exchange.
    const AUTHORIZATION_CODE_GRANT: &'static str = "authorization_code";

    /// Builds an authorization-code exchange for `code` using the client
    /// settings held by `config`.
    pub fn new(config: &'a Config, code: Code) -> Self {
        let token_endpoint = config.token_endpoint();
        let client_id = config.client_id();
        let client_secret = config.client_secret();
        let redirect_uri = config.redirect_uri();
        Self {
            token_endpoint,
            code,
            client_id,
            client_secret,
            redirect_uri,
            grant_type: Self::AUTHORIZATION_CODE_GRANT,
        }
    }

    /// Token endpoint the request targets.
    pub fn token_endpoint(&self) -> &TokenEndPoint {
        self.token_endpoint
    }

    /// Authorization code being exchanged.
    pub fn code(&self) -> &Code {
        &self.code
    }

    /// Client identifier taken from the configuration.
    pub fn client_id(&self) -> &ClientID {
        self.client_id
    }

    /// Client secret taken from the configuration.
    pub fn client_secret(&self) -> &ClientSecret {
        self.client_secret
    }

    /// Redirect URI taken from the configuration.
    pub fn redirect_uri(&self) -> &RedirectURI {
        self.redirect_uri
    }

    /// Grant type sent with the request.
    pub fn grant_type(&self) -> &str {
        self.grant_type
    }
}
/// JSON body returned by the token endpoint for a successful code exchange.
///
/// Field names match the JSON keys; `refresh_token` deserializes to `None`
/// when the response omits it.
#[derive(Debug, Clone, Deserialize)]
pub struct IDTokenResponse {
    access_token: AccessToken,
    expires_in: u32,
    /// Still-encoded ID token; decode it with `IDToken::from_id_token_raw`.
    id_token: IDTokenRaw,
    scope: String,
    token_type: String,
    refresh_token: Option<RefreshToken>,
}

impl IDTokenResponse {
    /// Access token issued by the endpoint.
    pub fn access_token(&self) -> &AccessToken {
        &self.access_token
    }

    /// The `expires_in` value from the response, returned as received.
    pub fn expires_in(&self) -> u32 {
        self.expires_in
    }

    /// Encoded ID token, as received.
    pub fn id_token(&self) -> &IDTokenRaw {
        &self.id_token
    }

    /// The `scope` string from the response, returned as received.
    pub fn scope(&self) -> &str {
        self.scope.as_str()
    }

    /// The `token_type` string from the response, returned as received.
    pub fn token_type(&self) -> &str {
        self.token_type.as_str()
    }

    /// Refresh token, if the endpoint issued one.
    pub fn refresh_token(&self) -> &Option<RefreshToken> {
        &self.refresh_token
    }
}
/// Opaque access token string issued by the token endpoint.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AccessToken(pub(crate) String);

impl AccessToken {
    /// Returns an owned copy of the token string.
    pub fn value(&self) -> String {
        self.0.to_owned()
    }
}

/// Compact-serialized ID token (`header.payload.signature`) exactly as the
/// token endpoint returned it, before any decoding.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct IDTokenRaw(String);
pub async fn send_id_token_req(req: &IDTokenRequest<'_>) -> Result<IDTokenResponse, Error> {
use reqwest::Client;
use std::collections::HashMap;
use url::Url;
let url = Url::parse(req.token_endpoint().value()).map_err(|e| {
error!("Failed to parse url: {:?}", e);
Error::ParseURL
})?;
let mut params = HashMap::new();
params.insert("code", req.code().0.as_str());
params.insert("client_id", req.client_id().value());
params.insert("client_secret", req.client_secret().value());
params.insert("redirect_uri", req.redirect_uri().value());
params.insert("grant_type", req.grant_type());
let client = Client::new();
let res = client
.post(url)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(¶ms)
.send()
.await
.map_err(|e| {
error!("Failed to send request: {:?}", e);
Error::Send
})?;
if !res.status().is_success() {
return Err(Error::SendStatus(res.status()));
}
let res_json = res.json::<IDTokenResponse>().await.map_err(|e| {
error!("Failed to deserialize JSON: {:?}", e);
Error::DeserializeJson
})?;
Ok(res_json)
}
#[cfg(test)]
mod tests {
    use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
    use crate::{
        code::Code,
        config::ConfigBuilder,
        error::Error,
        id_token::{AccessToken, IDToken, IDTokenRaw, IDTokenRequest, IDTokenResponse},
        refresh_token::RefreshToken,
    };

    /// Wraps `payload` in dummy header/signature segments so that only the
    /// payload decoding is exercised by `IDToken::from_id_token_raw`.
    fn jwt_with_payload(payload: &str) -> IDTokenRaw {
        IDTokenRaw(format!("header.{}.signature", payload))
    }

    #[test]
    fn test_access_token_value() {
        let token = AccessToken("test_token".to_string());
        assert_eq!(token.value(), "test_token");
    }

    #[test]
    fn test_id_token_decode_success() {
        let id_token_json = r#"{
            "iss": "https://accounts.google.com",
            "aud": "my_aud",
            "sub": "my_sub",
            "azp": "my_azp",
            "email": "email@gmail.com",
            "email_verified": true,
            "given_name": "my_given_name",
            "family_name": "my_family_name",
            "name": "my_name",
            "picture": "https://picture.example.com",
            "at_hash": "my_at_hash",
            "iat": 1742189616,
            "exp": 1742193216,
            "nonce": "my_nonce"
        }"#;
        let encoded = BASE64_URL_SAFE_NO_PAD.encode(id_token_json);
        let decoded = match IDToken::from_id_token_raw(&jwt_with_payload(&encoded)) {
            Ok(token) => token,
            Err(_) => panic!("a well-formed ID token should decode"),
        };
        assert_eq!(decoded.iss, "https://accounts.google.com");
        assert_eq!(decoded.aud, "my_aud");
        assert_eq!(decoded.sub, "my_sub");
        assert_eq!(decoded.azp.as_deref(), Some("my_azp"));
        assert_eq!(decoded.email.as_deref(), Some("email@gmail.com"));
        assert_eq!(decoded.email_verified, Some(true));
        assert_eq!(decoded.name.as_deref(), Some("my_name"));
        assert_eq!(decoded.iat, 1742189616);
        assert_eq!(decoded.exp, 1742193216);
        assert!(decoded.nonce.is_some());
    }

    #[test]
    fn test_id_token_decode_wrong_segment_count() {
        for raw in &["no_dots", "only.two", "too.many.seg.ments"] {
            let decoded = IDToken::from_id_token_raw(&IDTokenRaw(raw.to_string()));
            assert!(matches!(decoded, Err(Error::Decode)), "input: {}", raw);
        }
    }

    #[test]
    fn test_id_token_decode_invalid_base64() {
        // Correct segment count, but the payload has bytes outside the URL-safe alphabet.
        let decoded = IDToken::from_id_token_raw(&jwt_with_payload("not*base64!"));
        assert!(matches!(decoded, Err(Error::Decode)));
    }

    #[test]
    fn test_id_token_decode_invalid_json() {
        // Correct segment count and valid base64, but the payload is not JSON.
        let payload = BASE64_URL_SAFE_NO_PAD.encode("not a valid json");
        let decoded = IDToken::from_id_token_raw(&jwt_with_payload(&payload));
        assert!(matches!(decoded, Err(Error::Deserialize)));
    }

    #[test]
    fn test_id_token_request_new() {
        let config = ConfigBuilder::new()
            .token_endpoint("https://token.example.com")
            .client_id("client_id")
            .client_secret("secret")
            .redirect_uri("https://redirect.example.com")
            .build();
        let code = Code("auth_code".to_string());
        let request = IDTokenRequest::new(&config, code.clone());
        assert_eq!(request.token_endpoint.0, "https://token.example.com");
        assert_eq!(request.client_id.0, "client_id");
        assert_eq!(request.client_secret.0, "secret");
        assert_eq!(request.redirect_uri.0, "https://redirect.example.com");
        assert_eq!(request.code, code);
        assert_eq!(request.grant_type(), "authorization_code");
    }

    #[test]
    fn test_id_token_response_getters() {
        let access_token = AccessToken("access_token_value".to_string());
        let id_token_raw = IDTokenRaw("id_token_value".to_string());
        let refresh_token = Some(RefreshToken("refresh_token_value".to_string()));
        let response = IDTokenResponse {
            access_token: access_token.clone(),
            expires_in: 3600,
            id_token: id_token_raw.clone(),
            scope: "openid email".to_string(),
            token_type: "Bearer".to_string(),
            refresh_token: refresh_token.clone(),
        };
        assert_eq!(response.access_token(), &access_token);
        assert_eq!(response.expires_in(), 3600);
        assert_eq!(response.id_token(), &id_token_raw);
        assert_eq!(response.scope(), "openid email");
        assert_eq!(response.token_type(), "Bearer");
        assert_eq!(response.refresh_token(), &refresh_token);
    }
}