use std::sync::Arc;
use bitwarden_crypto::KeyStore;
use bitwarden_state::registry::StateRegistry;
#[cfg(feature = "secrets")]
use crate::client::login_method::ServiceAccountLoginMethod;
use crate::key_management::KeySlotIds;
/// Strategy for authenticating the SDK's outgoing HTTP requests.
///
/// An implementation supplies a `reqwest` middleware and can be handed tokens through
/// [`TokenHandler::set_tokens`].
#[async_trait::async_trait]
pub trait TokenHandler: 'static + Send + Sync {
    /// Creates the `reqwest` middleware that applies this handler's authentication.
    ///
    /// The state registry, identity configuration and key store are passed in for
    /// implementations that need them. `ClientManagedTokenHandler` and `NoopTokenHandler`
    /// ignore all three.
    fn initialize_middleware(
        &self,
        state_registry: &StateRegistry,
        identity_config: bitwarden_api_base::Configuration,
        key_store: KeyStore<KeySlotIds>,
    ) -> Arc<dyn reqwest_middleware::Middleware>;

    /// Hands the handler an access token, an optional refresh token and an expiry.
    ///
    /// NOTE(review): `expires_in` is presumably a duration in seconds; confirm against callers.
    async fn set_tokens(&self, token: String, refresh_token: Option<String>, expires_in: u64);

    /// Supplies the service-account login method (presumably for Secrets Manager, given the
    /// `secrets` feature gate; confirm).
    ///
    /// The default implementation ignores the value.
    #[cfg(feature = "secrets")]
    async fn set_sm_login_method(&self, _login_method: ServiceAccountLoginMethod) {}
}
/// Provider of access tokens that are managed outside the SDK.
///
/// With the `uniffi` feature enabled, the trait is exported with `with_foreign`, so it can be
/// implemented by foreign-language code.
#[cfg_attr(feature = "uniffi", uniffi::export(with_foreign))]
#[async_trait::async_trait]
pub trait ClientManagedTokens: std::fmt::Debug + Send + Sync {
    /// Returns the current access token, or `None` if no token is available.
    ///
    /// When `None` is returned, `ClientManagedTokenHandler` sends the request without an
    /// `Authorization` header.
    async fn get_access_token(&self) -> Option<String>;
}
/// [`TokenHandler`] that delegates token management to a [`ClientManagedTokens`] provider.
///
/// The handler is also its own middleware. It adds a bearer token fetched from the provider to
/// requests that are marked as requiring authentication.
#[derive(Clone)]
pub struct ClientManagedTokenHandler {
    // Provider queried for an access token on every request that requires authentication.
    tokens: Arc<dyn ClientManagedTokens>,
}
impl ClientManagedTokenHandler {
    /// Builds a handler around the given token provider and returns it in an [`Arc`].
    pub fn new(tokens: Arc<dyn ClientManagedTokens>) -> Arc<Self> {
        let handler = Self { tokens };
        Arc::new(handler)
    }
}
#[async_trait::async_trait]
impl TokenHandler for ClientManagedTokenHandler {
    /// Returns a clone of this handler as the request middleware.
    ///
    /// The state registry, identity configuration and key store are not needed. Tokens are
    /// fetched from the [`ClientManagedTokens`] provider while each request is handled.
    fn initialize_middleware(
        &self,
        _state_registry: &StateRegistry,
        _identity_config: bitwarden_api_base::Configuration,
        _key_store: KeyStore<KeySlotIds>,
    ) -> Arc<dyn reqwest_middleware::Middleware> {
        Arc::new(self.clone())
    }

    /// Not supported for client-managed tokens.
    ///
    /// # Panics
    ///
    /// Always panics, because the tokens are owned by the client and cannot be set by the SDK.
    async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_in: u64) {
        panic!("Client-managed tokens cannot be set by the SDK");
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl reqwest_middleware::Middleware for ClientManagedTokenHandler {
    /// Adds `Authorization: Bearer <token>` to requests that require authentication.
    ///
    /// The header is added only when the request carries a
    /// [`bitwarden_api_base::AuthRequired`] extension and the provider returns a token. A token
    /// that cannot be encoded as a header value is logged and skipped. The request is then
    /// forwarded without the header.
    async fn handle(
        &self,
        mut req: reqwest::Request,
        ext: &mut http::Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> Result<reqwest::Response, reqwest_middleware::Error> {
        if ext.get::<bitwarden_api_base::AuthRequired>().is_some()
            && let Some(token) = self.tokens.get_access_token().await
        {
            match http::HeaderValue::from_str(&format!("Bearer {token}")) {
                Ok(mut header_value) => {
                    // Mark the credential as sensitive so the `Debug` output of the request
                    // or its headers (e.g. in logs) masks the token.
                    header_value.set_sensitive(true);
                    req.headers_mut()
                        .insert(http::header::AUTHORIZATION, header_value);
                }
                Err(e) => {
                    tracing::warn!("Failed to parse auth token for header: {e}");
                }
            }
        }

        next.run(req, ext).await
    }
}
/// [`TokenHandler`] that performs no authentication.
///
/// Its middleware forwards requests unchanged, and attempting to set tokens panics.
#[derive(Clone, Copy)]
pub struct NoopTokenHandler;
#[async_trait::async_trait]
impl TokenHandler for NoopTokenHandler {
    /// Returns a copy of this handler as the pass-through middleware.
    ///
    /// None of the provided arguments are used.
    fn initialize_middleware(
        &self,
        _state_registry: &StateRegistry,
        _identity_config: bitwarden_api_base::Configuration,
        _key_store: KeyStore<KeySlotIds>,
    ) -> Arc<dyn reqwest_middleware::Middleware> {
        Arc::new(*self)
    }

    /// Not supported by the no-op handler.
    ///
    /// # Panics
    ///
    /// Always panics, because this handler holds no tokens.
    async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_in: u64) {
        panic!("Cannot set tokens on NoopTokenHandler");
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl reqwest_middleware::Middleware for NoopTokenHandler {
    /// Hands the request to the next middleware without touching its headers or extensions.
    async fn handle(
        &self,
        req: reqwest::Request,
        ext: &mut http::Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> Result<reqwest::Response, reqwest_middleware::Error> {
        reqwest_middleware::Next::run(next, req, ext).await
    }
}
#[cfg(test)]
mod tests {
    use wiremock::MockServer;

    use super::*;

    /// Test double that always returns the same (optional) access token.
    #[derive(Debug)]
    struct MockTokenProvider {
        token: Option<String>,
    }

    #[async_trait::async_trait]
    impl ClientManagedTokens for MockTokenProvider {
        async fn get_access_token(&self) -> Option<String> {
            self.token.clone()
        }
    }

    /// Builds a client whose middleware is a [`ClientManagedTokenHandler`] backed by `token`.
    /// Also starts a mock server that answers every request with `200 OK`.
    async fn test_setup(
        token: Option<String>,
    ) -> (reqwest_middleware::ClientWithMiddleware, MockServer) {
        let provider = Arc::new(MockTokenProvider { token });
        let handler = ClientManagedTokenHandler::new(provider);
        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with((*handler).clone())
            .build();
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        (client, server)
    }

    #[tokio::test]
    async fn attaches_bearer_token_when_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;
        client
            .get(format!("{}/test", server.uri()))
            .with_extension(bitwarden_api_base::AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(
            requests[0]
                .headers
                .get("Authorization")
                .map(|v| v.to_str().unwrap()),
            Some("Bearer test-token")
        );
    }

    #[tokio::test]
    async fn does_not_attach_token_without_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;
        client
            .get(format!("{}/test", server.uri()))
            .send()
            .await
            .unwrap();
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].headers.get("Authorization"), None);
    }

    #[tokio::test]
    async fn does_not_attach_token_when_provider_returns_none() {
        let (client, server) = test_setup(None).await;
        client
            .get(format!("{}/test", server.uri()))
            .with_extension(bitwarden_api_base::AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].headers.get("Authorization"), None);
    }

    /// A token containing a control character cannot be encoded as a header value. The
    /// middleware must skip the header and still forward the request.
    #[tokio::test]
    async fn skips_token_that_is_not_a_valid_header_value() {
        let (client, server) = test_setup(Some("bad\ntoken".to_string())).await;
        client
            .get(format!("{}/test", server.uri()))
            .with_extension(bitwarden_api_base::AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].headers.get("Authorization"), None);
    }
}