use std::sync::Arc;
use tokio::sync::RwLock;
use viewpoint_cdp::CdpConnection;
use viewpoint_cdp::protocol::fetch::{
AuthChallenge, AuthChallengeResponse, AuthChallengeSource, AuthRequiredEvent,
ContinueWithAuthParams,
};
use crate::error::NetworkError;
/// Username/password pair used to answer proxy authentication challenges.
#[derive(Debug, Clone)]
pub struct ProxyCredentials {
    /// Proxy account user name.
    pub username: String,
    /// Proxy account password.
    pub password: String,
}

impl ProxyCredentials {
    /// Builds proxy credentials from any pair of values convertible into `String`.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        let (username, password) = (username.into(), password.into());
        Self { username, password }
    }
}
/// Username/password pair for HTTP (server) authentication, optionally
/// restricted to a single origin and names ending in `.` + that origin.
#[derive(Debug, Clone)]
pub struct HttpCredentials {
    /// Account user name.
    pub username: String,
    /// Account password.
    pub password: String,
    /// When `Some`, restricts which challenge origins receive these credentials
    /// (see [`HttpCredentials::matches_origin`]); `None` means any origin.
    pub origin: Option<String>,
}

impl HttpCredentials {
    /// Creates credentials that are offered to every origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        let (username, password) = (username.into(), password.into());
        Self {
            username,
            password,
            origin: None,
        }
    }

    /// Creates credentials that are only offered to `origin` (and names ending
    /// in `.` + `origin`).
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            origin: Some(origin.into()),
            ..Self::new(username, password)
        }
    }

    /// Returns whether these credentials may be sent to `challenge_origin`.
    ///
    /// Unrestricted credentials (`origin == None`) match everything. Otherwise the
    /// challenge origin must equal the configured origin exactly, or end with `.`
    /// followed by it. This is a plain, case-sensitive string comparison.
    ///
    /// NOTE(review): if challenge origins arrive as `scheme://host[:port]`, a bare
    /// configured host such as `example.com` matches subdomains but not the exact
    /// host, while a configured `https://example.com` never matches subdomains —
    /// confirm which form callers are expected to configure.
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        match self.origin.as_deref() {
            None => true,
            Some(allowed) => {
                challenge_origin == allowed
                    || matches!(
                        challenge_origin.strip_suffix(allowed),
                        Some(head) if head.ends_with('.')
                    )
            }
        }
    }
}
/// Answers `AuthRequiredEvent` challenges for one CDP session using the
/// configured HTTP and/or proxy credentials, with a per-key attempt limit.
#[derive(Debug)]
pub struct AuthHandler {
    /// Connection used to send `Fetch.continueWithAuth` commands.
    connection: Arc<CdpConnection>,
    /// Session the `Fetch.continueWithAuth` commands are sent to.
    session_id: String,
    /// Credentials for server (non-proxy) challenges; `None` means challenges
    /// are answered with the default response.
    credentials: RwLock<Option<HttpCredentials>>,
    /// Credentials for proxy challenges; `None` means challenges are answered
    /// with the default response.
    proxy_credentials: RwLock<Option<ProxyCredentials>>,
    /// Maximum number of credential attempts per retry key before challenges
    /// for that key are canceled.
    max_retries: u32,
    /// Attempts recorded so far, keyed by the challenge origin for server
    /// challenges and by `proxy:{origin}` for proxy challenges.
    retry_counts: RwLock<std::collections::HashMap<String, u32>>,
}
impl AuthHandler {
pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(None),
proxy_credentials: RwLock::new(None),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub fn with_credentials(
connection: Arc<CdpConnection>,
session_id: String,
credentials: HttpCredentials,
) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(Some(credentials)),
proxy_credentials: RwLock::new(None),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub fn with_proxy_credentials(
connection: Arc<CdpConnection>,
session_id: String,
proxy_credentials: ProxyCredentials,
) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(None),
proxy_credentials: RwLock::new(Some(proxy_credentials)),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub fn with_all_credentials(
connection: Arc<CdpConnection>,
session_id: String,
http_credentials: Option<HttpCredentials>,
proxy_credentials: Option<ProxyCredentials>,
) -> Self {
Self {
connection,
session_id,
credentials: RwLock::new(http_credentials),
proxy_credentials: RwLock::new(proxy_credentials),
max_retries: 3,
retry_counts: RwLock::new(std::collections::HashMap::new()),
}
}
pub async fn set_credentials(&self, credentials: HttpCredentials) {
let mut creds = self.credentials.write().await;
*creds = Some(credentials);
}
pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
if let Ok(mut creds) = self.credentials.try_write() {
*creds = Some(credentials);
}
}
pub async fn clear_credentials(&self) {
let mut creds = self.credentials.write().await;
*creds = None;
}
pub async fn set_proxy_credentials(&self, credentials: ProxyCredentials) {
let mut creds = self.proxy_credentials.write().await;
*creds = Some(credentials);
}
pub fn set_proxy_credentials_sync(&self, credentials: ProxyCredentials) {
if let Ok(mut creds) = self.proxy_credentials.try_write() {
*creds = Some(credentials);
}
}
pub async fn clear_proxy_credentials(&self) {
let mut creds = self.proxy_credentials.write().await;
*creds = None;
}
pub async fn handle_auth_challenge(
&self,
event: &AuthRequiredEvent,
) -> Result<bool, NetworkError> {
if event.auth_challenge.source == AuthChallengeSource::Proxy {
return self.handle_proxy_auth(event).await;
}
let creds = self.credentials.read().await;
if let Some(credentials) = &*creds {
if !credentials.matches_origin(&event.auth_challenge.origin) {
tracing::debug!(
origin = %event.auth_challenge.origin,
"No matching credentials for origin"
);
return self.cancel_auth(&event.request_id).await.map(|()| false);
}
{
let mut counts = self.retry_counts.write().await;
let count = counts
.entry(event.auth_challenge.origin.clone())
.or_insert(0);
if *count >= self.max_retries {
tracing::warn!(
origin = %event.auth_challenge.origin,
retries = self.max_retries,
"Max auth retries exceeded, canceling"
);
return self.cancel_auth(&event.request_id).await.map(|()| false);
}
*count += 1;
}
self.provide_credentials(
&event.request_id,
&event.auth_challenge,
&credentials.username,
&credentials.password,
)
.await?;
Ok(true)
} else {
tracing::debug!(
origin = %event.auth_challenge.origin,
scheme = %event.auth_challenge.scheme,
"No credentials available, deferring to default"
);
self.default_auth(&event.request_id).await?;
Ok(false)
}
}
async fn handle_proxy_auth(&self, event: &AuthRequiredEvent) -> Result<bool, NetworkError> {
let proxy_creds = self.proxy_credentials.read().await;
if let Some(credentials) = &*proxy_creds {
let retry_key = format!("proxy:{}", event.auth_challenge.origin);
{
let mut counts = self.retry_counts.write().await;
let count = counts.entry(retry_key.clone()).or_insert(0);
if *count >= self.max_retries {
tracing::warn!(
origin = %event.auth_challenge.origin,
retries = self.max_retries,
"Max proxy auth retries exceeded, canceling"
);
return self.cancel_auth(&event.request_id).await.map(|()| false);
}
*count += 1;
}
tracing::debug!(
origin = %event.auth_challenge.origin,
scheme = %event.auth_challenge.scheme,
"Providing proxy credentials"
);
self.provide_credentials(
&event.request_id,
&event.auth_challenge,
&credentials.username,
&credentials.password,
)
.await?;
Ok(true)
} else {
tracing::debug!(
origin = %event.auth_challenge.origin,
scheme = %event.auth_challenge.scheme,
"No proxy credentials available, deferring to default"
);
self.default_auth(&event.request_id).await?;
Ok(false)
}
}
async fn provide_credentials(
&self,
request_id: &str,
challenge: &AuthChallenge,
username: &str,
password: &str,
) -> Result<(), NetworkError> {
tracing::debug!(
origin = %challenge.origin,
scheme = %challenge.scheme,
realm = %challenge.realm,
"Providing credentials for auth challenge"
);
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::provide_credentials(
username, password,
),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
tracing::debug!("Canceling auth challenge");
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::cancel(),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
self.connection
.send_command::<_, serde_json::Value>(
"Fetch.continueWithAuth",
Some(ContinueWithAuthParams {
request_id: request_id.to_string(),
auth_challenge_response: AuthChallengeResponse::default_response(),
}),
Some(&self.session_id),
)
.await?;
Ok(())
}
pub async fn reset_retries(&self, origin: &str) {
let mut counts = self.retry_counts.write().await;
counts.remove(origin);
}
pub async fn reset_all_retries(&self) {
let mut counts = self.retry_counts.write().await;
counts.clear();
}
}
#[cfg(test)]
mod tests;