use std::sync::{Arc, RwLock};
use bitwarden_core::{
NotAuthenticatedError, OrganizationId,
auth::{TokenHandler, login::LoginError},
client::login_method::ServiceAccountLoginMethod,
key_management::KeySlotIds,
};
use bitwarden_crypto::KeyStore;
use bitwarden_state::registry::StateRegistry;
use chrono::Utc;
use super::middleware::{MiddlewareExt, MiddlewareWrapper};
use crate::token_management::middleware::TOKEN_RENEW_MARGIN_SECONDS;
/// Token handler used for Secrets Manager service-account sessions.
///
/// All mutable state lives behind a shared `Arc<RwLock<..>>`, so cloning this handler is cheap
/// and every clone observes (and updates) the same tokens, login method and configuration.
#[derive(Default, Clone)]
pub struct SecretsManagerTokenHandler {
    /// Shared, lock-protected handler state.
    inner: Arc<RwLock<SecretsManagerTokenHandlerInner>>,
}
/// Mutable state behind [`SecretsManagerTokenHandler`]; every field starts out as `None`.
#[derive(Default, Clone)]
struct SecretsManagerTokenHandlerInner {
    /// Most recently issued access token, written by `set_tokens`.
    access_token: Option<String>,
    /// Unix timestamp (seconds) at which `access_token` expires, written by `set_tokens`.
    expires_on: Option<i64>,
    /// Service-account credentials used to obtain a new token when the current one is stale.
    login_method: Option<Arc<ServiceAccountLoginMethod>>,
    /// API configuration handed to the renewal call; set by `initialize_middleware`.
    identity_config: Option<bitwarden_api_api::Configuration>,
    /// Key store handed to the renewal call; set by `initialize_middleware`.
    key_store: Option<KeyStore<KeySlotIds>>,
}
#[async_trait::async_trait]
impl TokenHandler for SecretsManagerTokenHandler {
    /// Stores the renewal dependencies and returns a middleware that attaches tokens to requests.
    ///
    /// The returned middleware wraps a clone of `self`, which shares the same `Arc`-backed state,
    /// so tokens set later through this handler are visible to the middleware.
    /// `_state_registry` is not used by this handler.
    fn initialize_middleware(
        &self,
        _state_registry: &StateRegistry,
        identity_config: bitwarden_api_api::Configuration,
        key_store: KeyStore<KeySlotIds>,
    ) -> Arc<dyn reqwest_middleware::Middleware> {
        {
            // Scope the write guard so it is released before the middleware is built.
            let mut inner = self.inner.write().expect("RwLock is not poisoned");
            inner.identity_config = Some(identity_config);
            inner.key_store = Some(key_store);
        }
        Arc::new(MiddlewareWrapper(self.clone()))
    }

    /// Caches `access_token` and records its absolute expiry as `now + expires_in` seconds.
    ///
    /// The refresh token is ignored: service-account sessions are renewed from the stored
    /// login method instead.
    async fn set_tokens(
        &self,
        access_token: String,
        _refresh_token: Option<String>,
        expires_in: u64,
    ) {
        // `expires_in` may originate from a server response (see `get_token`). A plain `as` cast
        // would wrap values above `i64::MAX` to a negative lifetime (making a fresh token look
        // expired) and the addition could overflow, so clamp and saturate instead.
        let lifetime = i64::try_from(expires_in).unwrap_or(i64::MAX);
        let expires_on = Utc::now().timestamp().saturating_add(lifetime);

        let mut inner = self.inner.write().expect("RwLock is not poisoned");
        inner.access_token = Some(access_token);
        inner.expires_on = Some(expires_on);
    }

    /// Stores the service-account login method used for later token renewals.
    async fn set_sm_login_method(&self, login_method: ServiceAccountLoginMethod) {
        let mut inner = self.inner.write().expect("RwLock is not poisoned");
        inner.login_method = Some(Arc::new(login_method));
    }
}
impl SecretsManagerTokenHandler {
    /// Returns the organization id carried by the configured access-token login method.
    ///
    /// Yields `None` when no login method has been set yet, or when the state lock is poisoned.
    pub fn get_access_token_organization(&self) -> Option<OrganizationId> {
        let guard = self.inner.read().ok()?;
        guard
            .login_method
            .as_deref()
            .map(|method| match method {
                ServiceAccountLoginMethod::AccessToken {
                    organization_id, ..
                } => *organization_id,
            })
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl MiddlewareExt for SecretsManagerTokenHandler {
    /// Returns a bearer token for the next request, renewing it first if needed.
    ///
    /// The cached token is returned while the current time is more than
    /// `TOKEN_RENEW_MARGIN_SECONDS` before its expiry. Otherwise a new token is requested via
    /// `renew_sm_token_sdk_managed` and stored with `set_tokens`.
    ///
    /// # Errors
    /// Returns `NotAuthenticatedError` (converted into `LoginError`) if the login method,
    /// identity configuration or key store has not been set. Also propagates any error from the
    /// renewal call.
    ///
    /// NOTE(review): the lock is released before the renewal `.await`, so concurrent callers that
    /// all observe a stale token may each perform their own renewal.
    async fn get_token(&self) -> Result<Option<String>, LoginError> {
        // Read under a scoped guard: the std `RwLock` guard must be dropped before any `.await`.
        // Only the pieces actually needed are cloned, rather than the whole inner state.
        let (login_method, identity_config, key_store) = {
            let inner = self.inner.read().expect("RwLock is not poisoned");
            if let Some(expires) = inner.expires_on
                && Utc::now().timestamp() < expires - TOKEN_RENEW_MARGIN_SECONDS
            {
                return Ok(inner.access_token.clone());
            }
            (
                inner.login_method.clone(),
                inner.identity_config.clone(),
                inner.key_store.clone(),
            )
        };

        let login_method = login_method.ok_or(NotAuthenticatedError)?;
        let identity_config = identity_config.ok_or(NotAuthenticatedError)?;
        let key_store = key_store.ok_or(NotAuthenticatedError)?;

        let (access_token, refresh_token, expires_in) =
            bitwarden_core::auth::renew::renew_sm_token_sdk_managed(
                login_method.as_ref(),
                identity_config,
                key_store,
            )
            .await?;

        self.set_tokens(access_token.clone(), refresh_token, expires_in)
            .await;
        Ok(Some(access_token))
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bitwarden_core::{
        auth::{AccessToken, TokenHandler},
        client::login_method::ServiceAccountLoginMethod,
        key_management::KeySlotIds,
    };
    use bitwarden_crypto::KeyStore;
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Builds an access-token login method from a fixed test token and organization id.
    fn service_account_login_method() -> ServiceAccountLoginMethod {
        let access_token = AccessToken::from_str(
            "0.ec2c1d46-6a4b-4751-a310-af9601317f2d.C2IgxjjLF7qSshsbwe8JGcbM075YXw:X8vbvA0bduihIDe/qrzIQQ==",
        )
        .unwrap();
        ServiceAccountLoginMethod::AccessToken {
            access_token,
            organization_id: "00000000-0000-0000-0000-000000000001".parse().unwrap(),
            state_file: None,
        }
    }

    /// A token that is still well within its lifetime is attached unchanged, and the identity
    /// server is never contacted.
    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;
        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("original-token".to_string(), None, 3600)
            .await;
        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));
        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// A token stored with a zero lifetime is renewed once against the identity server, and the
    /// renewed token is attached to the outgoing request.
    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;
        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("expired-token".to_string(), None, 0)
            .await;
        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));
        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// The organization id is only reported once a login method has been configured.
    #[tokio::test]
    async fn access_token_organization_requires_login_method() {
        let handler = SecretsManagerTokenHandler::default();
        assert!(handler.get_access_token_organization().is_none());
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        assert!(handler.get_access_token_organization().is_some());
    }
}