gauth 0.10.2

HTTP Client for Google OAuth2
Documentation
//! OAuth2 for installed apps (three-legged consent flow).
//!
//! See [`Auth`] for the entry point. Construct via [`Auth::from_file`] with
//! a Google API Console credentials JSON and a list of scopes, then call
//! [`Auth::access_token`] to retrieve a bearer token. The first call walks
//! the consent flow (either via the supplied [`AuthHandlerFn`] or the
//! default `stdin` handler); subsequent calls reuse a cached refresh
//! token. The token is cached on disk as `access_token.json` under
//! `$HOME/.{app_name}/` (or under `$GAUTH_TOKEN_DIR` if that variable is
//! set), and every write replaces the file atomically.

use std::path::PathBuf;
use std::result::Result as StdResult;
use std::{env, path};

use anyhow::Error as AnyError;
use chrono::Utc;
use reqwest::Client as HttpClient;
use serde_derive::{Deserialize, Serialize};

mod credentials;
mod errors;

use errors::{AuthError, Result};

/// Environment variable that, when set, overrides the token cache directory.
const TOKEN_DIR: &str = "GAUTH_TOKEN_DIR";
/// App name used for the cache directory (`$HOME/.{app_name}`) unless overridden.
const DEFAULT_APP_NAME: &str = "gauth_app";
/// `grant_type` form value sent when exchanging an authorization code.
const GRANT_TYPE: &str = "authorization_code";
/// Base URL of the host serving the `tokeninfo` validation endpoint.
const GOOGLE_VALIDATE_HOST: &str = "https://www.googleapis.com";

/// Boxed, type-erased consent handler as stored on [`Auth`].
type AuthHandler = Box<dyn AuthHandlerFn>;

/// A consent handler: receives the consent URI and returns the authorization
/// code obtained from it.
///
/// Blanket-implemented below for every `Fn(String) -> Result<String, AnyError>`
/// that is also `Send + Sync + 'static`, so plain closures and functions can
/// be passed to [`Auth::handler`] directly.
pub trait AuthHandlerFn: Fn(String) -> StdResult<String, AnyError> + Send + Sync + 'static {}

impl<F> AuthHandlerFn for F where
    F: Fn(String) -> StdResult<String, AnyError> + Send + Sync + 'static
{
}

/// OAuth2 client for the installed-app consent flow.
///
/// Build one with [`Auth::from_file`], optionally customise it with
/// [`Auth::app_name`] / [`Auth::handler`], then call [`Auth::access_token`].
pub struct Auth {
    /// Names the cache directory (`$HOME/.{app_name}`) used when
    /// `GAUTH_TOKEN_DIR` is unset.
    app_name: String,
    /// Caller-supplied consent handler; `None` selects the stdin prompt.
    auth_handler: Option<AuthHandler>,
    /// Installed-app client credentials loaded from the key file.
    oauth_creds: credentials::OauthCredentials,
    /// Consent URL handed to the auth handler.
    consent_uri: String,
    /// Base URL for `tokeninfo` validation requests (the tests point it at a
    /// mock server).
    token_validate_host: String,
    /// HTTP client shared by token-endpoint and validation requests.
    http_client: HttpClient,
}

/// OAuth2 token as decoded from the token endpoint's JSON reply; the same
/// shape is serialised to, and read back from, the on-disk cache file.
#[derive(Debug, Deserialize, Serialize)]
struct Token {
    /// The access token itself; combined with `token_type` by `bearer_token`.
    access_token: String,
    /// Lifetime reported by the token endpoint; added to the current Unix
    /// time in seconds by `set_expires_at`.
    expires_in: u64,
    /// Refresh token, if one was issued; may be missing from a response.
    refresh_token: Option<String>,
    /// Granted scope string, if the server reports one.
    scope: Option<String>,
    /// Prefix used when formatting the bearer string (e.g. `Bearer`).
    token_type: String,

    /// Absolute expiry as a Unix timestamp in seconds, filled in locally by
    /// `set_expires_at`; `None` is treated as already expired.
    expires_at: Option<u64>,
}

impl Token {
    fn bearer_token(&self) -> String {
        format!("{} {}", self.token_type, self.access_token)
    }

    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(expires_at) => expires_at < Utc::now().timestamp() as u64,
            None => true,
        }
    }

    fn set_expires_at(&mut self) {
        self.expires_at = Some(Utc::now().timestamp() as u64 + self.expires_in);
    }
}

impl Auth {
    /// Creates a new auth instance from a key file and scopes.
    ///
    /// Reads the `installed` client credentials from the JSON file at
    /// `key_path` and precomputes the consent URI for `scopes` (joined with
    /// single spaces). The app name, handler and HTTP client start at their
    /// defaults; see [`Auth::app_name`] and [`Auth::handler`].
    ///
    /// # Errors
    /// Fails if the credentials file cannot be read or parsed, or if the
    /// consent URI cannot be built.
    pub fn from_file(key_path: &str, scopes: Vec<&str>) -> Result<Self> {
        let kp = path::Path::new(key_path);
        let oauth_creds = credentials::read_oauth_config(kp)?.installed;

        let scope = scopes.join(" ");
        let consent_uri = credentials::auth_code_uri_str(&oauth_creds, &scope)?;

        Ok(Self {
            app_name: DEFAULT_APP_NAME.to_owned(),
            auth_handler: None,
            oauth_creds,
            consent_uri,
            token_validate_host: GOOGLE_VALIDATE_HOST.to_owned(),
            http_client: HttpClient::new(),
        })
    }

    /// Overrides the default app name (`gauth_app`).
    ///
    /// The app name selects the token cache directory `$HOME/.{app_name}`
    /// when `GAUTH_TOKEN_DIR` is not set.
    pub fn app_name(mut self, app_name: &str) -> Self {
        self.app_name = app_name.to_owned();
        self
    }

    /// Overrides the default stdin-based auth handler.
    ///
    /// The handler receives the consent URI and must return the
    /// authorization code obtained from it.
    pub fn handler<H: AuthHandlerFn>(mut self, handler: H) -> Self {
        self.auth_handler = Some(Box::new(handler));
        self
    }

    /// Runs the consent flow, exchanges the resulting code for a token and
    /// caches it.
    async fn generate_new_token(&self) -> Result<Token> {
        let auth_code = match self.auth_handler.as_ref() {
            Some(h) => (h)(self.consent_uri.clone()),
            None => default_auth_handler(self.consent_uri.clone()),
        }?;

        // Codes typed or pasted into a terminal usually carry a trailing
        // newline (and possibly stray spaces) which would otherwise be sent
        // verbatim in the `code` form field.
        let auth_code = auth_code.trim().to_owned();

        let token = self.exchange_auth_code(auth_code).await?;
        self.cache_token(token)
    }

    /// Returns a bearer token string (`"{token_type} {access_token}"`).
    ///
    /// Uses the cached token when one can be loaded; otherwise runs the
    /// consent flow (via the configured handler) to obtain and cache a new
    /// one. The token is then checked for local expiry and against the
    /// `tokeninfo` endpoint. If either check fails, it is refreshed with the
    /// stored refresh token and the refreshed token is cached.
    ///
    /// # Errors
    /// Propagates consent-handler, HTTP, (de)serialisation and cache I/O
    /// errors. A token that needs refreshing but carries no refresh token
    /// yields `AuthError::RefreshTokenValue`.
    pub async fn access_token(&self) -> Result<String> {
        let token = match self.cached_token() {
            Ok(token) => token,
            Err(_) => self.generate_new_token().await?,
        };

        if self.is_token_valid(&token).await {
            return Ok(token.bearer_token());
        }

        let refreshed = self.refresh_token(token).await?;
        Ok(self.cache_token(refreshed)?.bearer_token())
    }

    /// Synchronous wrapper around [`Self::access_token`], available under the
    /// `app-blocking` feature. Drives the async call to completion via
    /// `futures::executor::block_on` — useful when wiring gauth into a
    /// synchronous integration point (e.g. a tonic interceptor or a sync
    /// trait impl). Must not be called from inside a tokio runtime; for that
    /// case, await [`Self::access_token`] directly.
    ///
    /// NOTE(review): reqwest's async client normally relies on an active
    /// Tokio reactor for its I/O, which `futures::executor::block_on` does
    /// not provide — confirm this wrapper works (rather than panicking) when
    /// no runtime is present.
    #[cfg(feature = "app-blocking")]
    pub fn access_token_blocking(&self) -> Result<String> {
        futures::executor::block_on(self.access_token())
    }

    /// Exchanges an authorization code for a token (`authorization_code` grant).
    async fn exchange_auth_code(&self, auth_code: String) -> Result<Token> {
        let redirect_uri = self.oauth_creds.redirect_uri()?;

        self.post_token_request(&[
            ("code", auth_code.as_str()),
            ("client_id", self.oauth_creds.client_id.as_str()),
            ("client_secret", self.oauth_creds.client_secret.as_str()),
            ("redirect_uri", redirect_uri.as_str()),
            ("grant_type", GRANT_TYPE),
        ])
        .await
    }

    /// POSTs `form` to the credentials' token endpoint and decodes the JSON
    /// reply as a [`Token`].
    ///
    /// Non-2xx replies are turned into errors via `error_for_status` before
    /// decoding. An OAuth error body (e.g. `invalid_grant`) therefore surfaces
    /// as an HTTP status error rather than as a confusing JSON-shape mismatch.
    async fn post_token_request(&self, form: &[(&str, &str)]) -> Result<Token> {
        let response = self
            .http_client
            .post(self.oauth_creds.token_uri.as_str())
            .form(form)
            .send()
            .await
            .and_then(|res| res.error_for_status())
            .map_err(AuthError::ReqwestError)?;

        response
            .json::<Token>()
            .await
            .map_err(AuthError::ReqwestError)
    }

    /// Uses `token`'s refresh token to obtain a new access token.
    ///
    /// # Errors
    /// Returns `AuthError::RefreshTokenValue` when `token` has no refresh
    /// token, or an HTTP/decoding error from the token endpoint.
    async fn refresh_token(&self, token: Token) -> Result<Token> {
        let previous_refresh = token
            .refresh_token
            .ok_or(AuthError::RefreshTokenValue)?;

        let mut refreshed = self
            .post_token_request(&[
                ("refresh_token", previous_refresh.as_str()),
                ("client_id", self.oauth_creds.client_id.as_str()),
                ("client_secret", self.oauth_creds.client_secret.as_str()),
                ("grant_type", "refresh_token"),
            ])
            .await?;

        // The refresh reply usually omits the refresh token, so carry the old
        // one forward. If the server *does* issue a new one (rotation), keep
        // it: RFC 6749 §6 requires discarding the old refresh token then.
        if refreshed.refresh_token.is_none() {
            refreshed.refresh_token = Some(previous_refresh);
        }
        Ok(refreshed)
    }

    /// Loads the token cached at `<token_dir>/access_token.json`.
    fn cached_token(&self) -> Result<Token> {
        let token_dir = self.token_dir()?;
        let b = std::fs::read(token_dir.join("access_token.json"))?;
        Ok(serde_json::from_slice::<Token>(&b)?)
    }

    /// Stamps `expires_at` on `token` and persists it to
    /// `<token_dir>/access_token.json`, creating the directory if needed.
    /// Returns the stamped token.
    fn cache_token(&self, mut token: Token) -> Result<Token> {
        let token_dir = self.token_dir()?;
        // `create_dir_all` already succeeds when the directory exists, so a
        // separate (racy) `exists()` check adds nothing.
        std::fs::create_dir_all(&token_dir)?;

        token.set_expires_at();

        let token_path = token_dir.join("access_token.json");
        let b = serde_json::to_vec(&token)?;

        // Atomic write: serialise to a sibling tmp file, fsync, then rename.
        // `rename` on the same filesystem is atomic, so concurrent refreshers
        // can't observe a half-written `access_token.json`. PID + a process-
        // local counter keeps the tmp path unique across both processes and
        // threads.
        let suffix = TMP_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tmp_path = token_dir.join(format!(
            "access_token.json.{}.{}.tmp",
            std::process::id(),
            suffix,
        ));
        write_then_rename(&tmp_path, &token_path, &b)?;

        Ok(token)
    }

    /// Resolves the cache directory: `$GAUTH_TOKEN_DIR` if set, otherwise
    /// `$HOME/.{app_name}`.
    ///
    /// # Errors
    /// `AuthError::HomeDirError` when the variable is unset and no home
    /// directory can be determined.
    fn token_dir(&self) -> Result<PathBuf> {
        if let Ok(token_dir) = env::var(TOKEN_DIR) {
            Ok(PathBuf::from(token_dir))
        } else {
            match dirs::home_dir() {
                Some(d) => Ok(d.join(format!(".{}", self.app_name))),
                None => Err(AuthError::HomeDirError),
            }
        }
    }

    /// Returns `true` only if the token is not locally expired and the
    /// `tokeninfo` endpoint answers with a 2xx status. Any request failure
    /// counts as invalid.
    async fn is_token_valid(&self, token: &Token) -> bool {
        if token.is_expired() {
            return false;
        }

        let url = format!(
            "{}/oauth2/v3/tokeninfo?access_token={}",
            self.token_validate_host, token.access_token
        );

        match self.http_client.get(url.as_str()).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }
}

/// Process-wide sequence number mixed into temporary cache-file names so
/// that concurrent writers inside one process never share a tmp path.
static TMP_COUNTER: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0);

/// Write `bytes` to `tmp_path`, fsync, then atomically rename it to
/// `final_path`.
///
/// `rename` on the same filesystem is atomic, so a concurrent reader either
/// sees the previous file contents or the new ones — never a half-written
/// blob. The `sync_all` call before the rename also defends against a
/// power-loss / crash between write and rename: on filesystems with weak
/// data-ordering (e.g. ext4 `data=writeback`) the rename's metadata can
/// otherwise hit disk before the new file's data blocks, leaving
/// `access_token.json` referencing a file with zero or partial data after
/// reboot. The tmp file is best-effort removed on any failure so we don't
/// leak `.tmp` files into the cache dir.
///
/// On Unix the tmp file is created with mode `0600`. The cached token
/// includes a refresh token, so it should not inherit a umask default that
/// may be readable by other users. The rename carries that mode over to
/// `final_path`.
fn write_then_rename(tmp_path: &path::Path, final_path: &path::Path, bytes: &[u8]) -> Result<()> {
    use std::io::Write;

    let write_and_sync = || -> std::io::Result<()> {
        // Same flags as `File::create`, plus a restrictive mode on Unix.
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }

        let mut file = options.open(tmp_path)?;
        file.write_all(bytes)?;
        file.sync_all()?;
        Ok(())
    };

    if let Err(err) = write_and_sync() {
        let _ = std::fs::remove_file(tmp_path);
        return Err(err.into());
    }
    if let Err(err) = std::fs::rename(tmp_path, final_path) {
        let _ = std::fs::remove_file(tmp_path);
        return Err(err.into());
    }
    Ok(())
}

fn default_auth_handler(consent_uri: String) -> StdResult<String, AnyError> {
    println!("> open the link in browser\n\n{}\n", consent_uri);
    println!("> enter the auth. code\n");

    let mut auth_code = String::new();
    std::io::stdin().read_line(&mut auth_code)?;

    Ok(auth_code)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[tokio::test]
    async fn test_access_token_success() {
        let mut server = mockito::Server::new_async().await;
        let host = server.url();

        // Token endpoint used for the auth-code exchange and any refresh.
        server
            .mock("POST", "/token")
            .with_status(200)
            .with_body(r#"{"access_token":"access_token","expires_in":3599,"refresh_token":"refresh_token","scope":"https://www.googleapis.com/auth/drive","token_type":"Bearer"}"#)
            .create_async()
            .await;

        let consent_uri = format!(
            "{}/o/oauth2/auth?client_id=client_id&response_type=code&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&include_granted_scopes=true&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&access_type=offline&state=pass-through+value",
            host
        );

        // The handler must be handed exactly the consent URI configured below.
        let want_uri = consent_uri.clone();
        let handler = move |received_uri: String| -> StdResult<String, AnyError> {
            assert_eq!(received_uri, want_uri);
            Ok("auth_code".to_owned())
        };

        // SAFETY: single-threaded test, no other code depends on this var
        unsafe { env::set_var(TOKEN_DIR, "./tmp/gauth_app") };

        let creds = credentials::OauthCredentials {
            client_id: "client_id".to_owned(),
            project_id: "project_id".to_owned(),
            auth_uri: format!("{}/o/oauth2/auth", host),
            token_uri: format!("{}/token", host),
            auth_provider_x509_cert_url: "auth_provider_x509_cert_url".to_owned(),
            client_secret: "client_secret".to_owned(),
            redirect_uris: vec!["urn:ietf:wg:oauth:2.0:oob".to_owned()],
        };

        let auth = Auth {
            app_name: "gauth_app".to_owned(),
            auth_handler: None,
            consent_uri,
            oauth_creds: creds,
            token_validate_host: host.clone(),
            http_client: HttpClient::new(),
        }
        .handler(handler);

        let bearer = auth.access_token().await.unwrap();
        assert_eq!(bearer, "Bearer access_token");

        // SAFETY: single-threaded test, no other code depends on this var
        unsafe { env::remove_var(TOKEN_DIR) };
    }
}