use std::sync::Arc;
use tokio::sync::RwLock;
use viewpoint_cdp::CdpConnection;
use viewpoint_cdp::protocol::fetch::{
AuthChallenge, AuthChallengeResponse, AuthRequiredEvent, ContinueWithAuthParams,
};
use crate::error::NetworkError;
/// Username/password pair used to answer HTTP authentication challenges.
///
/// `Debug` is implemented by hand so that the password is never rendered:
/// these values end up inside other `Debug`-printed types and log output.
#[derive(Clone)]
pub struct HttpCredentials {
    /// User name supplied in response to a challenge.
    pub username: String,
    /// Password supplied in response to a challenge.
    pub password: String,
    /// Origin these credentials are restricted to; `None` matches every origin.
    pub origin: Option<String>,
}

impl std::fmt::Debug for HttpCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpCredentials")
            .field("username", &self.username)
            // Never print the secret itself.
            .field("password", &format_args!("<redacted>"))
            .field("origin", &self.origin)
            .finish()
    }
}

impl HttpCredentials {
    /// Creates credentials that apply to any origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: None,
        }
    }

    /// Creates credentials that only apply to `origin` (see [`Self::matches_origin`]).
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: Some(origin.into()),
        }
    }

    /// Reports whether these credentials may be used for `challenge_origin`.
    ///
    /// Returns `true` when no origin restriction is configured, when
    /// `challenge_origin` equals the configured origin exactly, or when it ends
    /// with `.` followed by the configured origin (subdomain-style suffix match).
    /// The comparison is purely textual; no URL parsing is performed.
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        match self.origin.as_deref() {
            None => true,
            // After removing the configured origin from the end, what remains must
            // be empty (exact match) or end in '.' (label boundary). This is the
            // same test as `== origin || ends_with(".{origin}")` without allocating.
            Some(origin) => matches!(
                challenge_origin.strip_suffix(origin),
                Some(prefix) if prefix.is_empty() || prefix.ends_with('.')
            ),
        }
    }
}
/// Answers HTTP authentication challenges for one CDP session by sending
/// `Fetch.continueWithAuth` commands over a shared connection.
#[derive(Debug)]
pub struct AuthHandler {
/// Shared CDP connection used to send the challenge responses.
connection: Arc<CdpConnection>,
/// Session identifier passed along with every command this handler sends.
session_id: String,
/// Credentials offered on challenges; while `None`, challenges get the default response.
credentials: RwLock<Option<HttpCredentials>>,
/// Maximum number of credential submissions per origin before challenges are cancelled.
max_retries: u32,
/// Credential submissions made so far, keyed by challenge origin.
retry_counts: RwLock<std::collections::HashMap<String, u32>>,
}
impl AuthHandler {
pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(None),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub fn with_credentials(
connection: Arc<CdpConnection>,
session_id: String,
credentials: HttpCredentials,
) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(Some(credentials)),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub async fn set_credentials(&self, credentials: HttpCredentials) {
let mut creds = self.credentials.write().await;
*creds = Some(credentials);
}
pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
if let Ok(mut creds) = self.credentials.try_write() {
*creds = Some(credentials);
}
}
pub async fn clear_credentials(&self) {
let mut creds = self.credentials.write().await;
*creds = None;
}
pub async fn handle_auth_challenge(
&self,
event: &AuthRequiredEvent,
) -> Result<bool, NetworkError> {
let creds = self.credentials.read().await;
if let Some(credentials) = &*creds {
if !credentials.matches_origin(&event.auth_challenge.origin) {
tracing::debug!(
origin = %event.auth_challenge.origin,
"No matching credentials for origin"
);
return self.cancel_auth(&event.request_id).await.map(|()| false);
}
{
let mut counts = self.retry_counts.write().await;
let count = counts
.entry(event.auth_challenge.origin.clone())
.or_insert(0);
if *count >= self.max_retries {
tracing::warn!(
origin = %event.auth_challenge.origin,
retries = self.max_retries,
"Max auth retries exceeded, canceling"
);
return self.cancel_auth(&event.request_id).await.map(|()| false);
}
*count += 1;
}
self.provide_credentials(
&event.request_id,
&event.auth_challenge,
&credentials.username,
&credentials.password,
)
.await?;
Ok(true)
} else {
tracing::debug!(
origin = %event.auth_challenge.origin,
scheme = %event.auth_challenge.scheme,
"No credentials available, deferring to default"
);
self.default_auth(&event.request_id).await?;
Ok(false)
}
}
async fn provide_credentials(
&self,
request_id: &str,
challenge: &AuthChallenge,
username: &str,
password: &str,
) -> Result<(), NetworkError> {
tracing::debug!(
origin = %challenge.origin,
scheme = %challenge.scheme,
realm = %challenge.realm,
"Providing credentials for auth challenge"
);
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::provide_credentials(
username, password,
),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
tracing::debug!("Canceling auth challenge");
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::cancel(),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::default_response(),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
pub async fn reset_retries(&self, origin: &str) {
let mut counts = self.retry_counts.write().await;
counts.remove(origin);
}
pub async fn reset_all_retries(&self) {
let mut counts = self.retry_counts.write().await;
counts.clear();
}
}
// Unit tests live in the separate `tests` submodule, which is only compiled for test builds.
#[cfg(test)]
mod tests;