opentalk-client 0.0.1

A client library to interact with OpenTalk
Documentation
// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
//
// SPDX-License-Identifier: EUPL-1.2

use std::time::Duration;

use anyhow::Result;
use chrono::Utc;
use oauth2::{
    AuthUrl, ClientId, RefreshToken, ResourceOwnerPassword, ResourceOwnerUsername, Scope,
    TokenResponse as _, TokenUrl, basic::BasicClient,
};
use opentalk_client_data_persistence::{AccountTokens, DataManager};
use secrecy::{ExposeSecret, SecretString};

use crate::{Authorization, oidc::OidcEndpoints, oidc_authorization::REFRESH_BEFORE_EXPIRY};

/// [`Authorization`] provider backed by the OIDC resource owner password
/// ("direct access") grant.
///
/// An instance is created with
/// [`OidcDirectAccessGrant::create_with_direct_access_grant`], which performs the
/// password grant and stores the resulting tokens. Afterwards the access token is
/// loaded from the [`DataManager`] and refreshed via the stored refresh token when
/// it is close to expiry.
#[derive(Debug)]
pub struct OidcDirectAccessGrant {
    /// Storage backend from which account tokens are loaded and to which they are persisted.
    data_manager: Box<dyn DataManager>,
    /// OIDC provider endpoints; the authorization and token endpoints are used to
    /// configure the OAuth2 client.
    oidc_endpoints: OidcEndpoints,
    /// OAuth2 client ID sent with token requests.
    oidc_client_id: String,
}

#[async_trait::async_trait(?Send)]
impl Authorization for OidcDirectAccessGrant {
    async fn get_access_token(&self) -> Result<String> {
        self.get_token_and_refresh_if_needed(REFRESH_BEFORE_EXPIRY)
            .await
    }
}

#[async_trait::async_trait(?Send)]
impl Authorization for &OidcDirectAccessGrant {
    /// Forwards to the implementation on the borrowed [`OidcDirectAccessGrant`],
    /// so a shared reference can be used wherever an [`Authorization`] is expected.
    async fn get_access_token(&self) -> Result<String> {
        let inner: &OidcDirectAccessGrant = *self;
        <OidcDirectAccessGrant as Authorization>::get_access_token(inner).await
    }
}

impl OidcDirectAccessGrant {
    /// Loads accesss token and calls refresh if needed
    pub async fn get_token_and_refresh_if_needed(
        &self,
        refresh_before_expiry: Duration,
    ) -> Result<String> {
        let AccountTokens {
            access_token_expiry,
            access_token,
            ..
        } = self.data_manager.load_account_tokens()?;

        let now = Utc::now();
        if now + refresh_before_expiry > access_token_expiry {
            Ok(self.refresh_token().await?)
        } else {
            Ok(access_token)
        }
    }

    /// Performs token refresh
    pub async fn refresh_token(&self) -> Result<String> {
        let AccountTokens { refresh_token, .. } = self.data_manager.load_account_tokens()?;

        let client = BasicClient::new(ClientId::new(self.oidc_client_id.clone()))
            .set_auth_uri(
                AuthUrl::new(self.oidc_endpoints.authorization_endpoint.to_string()).unwrap(),
            )
            .set_token_uri(TokenUrl::new(self.oidc_endpoints.token_endpoint.to_string()).unwrap());

        let builder = reqwest::ClientBuilder::new();

        #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
        let builder = {
            // Following redirects opens the client up to SSRF vulnerabilities.
            builder.redirect(reqwest::redirect::Policy::none())
        };

        let http_client = builder.build().expect("Client should build");
        let http_client = super::ClientWrapper(http_client);

        let response = client
            .exchange_refresh_token(&RefreshToken::new(refresh_token))
            .request_async(&http_client)
            .await
            .unwrap();

        let now = Utc::now();
        let account_tokens = AccountTokens {
            access_token_expiry: now + response.expires_in().unwrap_or_default(),
            access_token: response.access_token().secret().clone(),
            refresh_token: response.refresh_token().unwrap().secret().clone(),
        };
        let _ = self
            .data_manager
            .store_account_tokens(account_tokens.clone());

        Ok(account_tokens.access_token)
    }

    /// perform oidc direct access grand authorization
    pub async fn create_with_direct_access_grant(
        data_manager: Box<dyn DataManager>,
        oidc_endpoints: OidcEndpoints,
        oidc_client_id: String,
        oidc_user: String,
        oidc_password: SecretString,
    ) -> Result<Self> {
        let oidc_client = BasicClient::new(ClientId::new(oidc_client_id.clone()))
            .set_auth_uri(AuthUrl::new(oidc_endpoints.authorization_endpoint.to_string()).unwrap())
            .set_token_uri(TokenUrl::new(oidc_endpoints.token_endpoint.to_string()).unwrap());

        let builder = reqwest::ClientBuilder::new();

        #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
        let builder = {
            // Following redirects opens the client up to SSRF vulnerabilities.
            builder.redirect(reqwest::redirect::Policy::none())
        };

        let http_client = builder.build().expect("Client should build");
        let http_client = super::ClientWrapper(http_client);

        let token_result = oidc_client
            .exchange_password(
                &ResourceOwnerUsername::new(oidc_user.clone()),
                &ResourceOwnerPassword::new(oidc_password.expose_secret().to_string()),
            )
            .add_scope(Scope::new("openid".to_string()))
            .request_async(&http_client)
            .await
            .unwrap();

        let now = Utc::now();

        let account_tokens = AccountTokens {
            access_token_expiry: now + token_result.expires_in().unwrap_or_default(),
            access_token: token_result.access_token().clone().into_secret(),
            refresh_token: token_result
                .refresh_token()
                .expect("Refresh token should be exist")
                .clone()
                .into_secret(),
        };

        data_manager.store_account_tokens(account_tokens.clone())?;

        println!("{:?}", token_result);

        Ok(Self {
            data_manager,
            oidc_endpoints,
            oidc_client_id,
        })
    }
}