use std::collections::HashMap;
use std::time::Duration;
use serde::Deserialize;
use tracing::{debug, info, instrument};
use url::Url;
use crate::error::CdpError;
/// Timeout used for the `/json/version` discovery request when the caller's
/// options do not specify one (30 seconds).
const DEFAULT_DISCOVERY_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Response body of the DevTools HTTP `GET /json/version` endpoint.
///
/// Chrome spells most keys in hyphenated Title-Case (`Browser`,
/// `Protocol-Version`, `User-Agent`, `V8-Version`, `WebKit-Version`) but the
/// socket URL in camelCase (`webSocketDebuggerUrl`). Serde matches keys
/// case-sensitively, so each field carries its exact wire name. The camelCase
/// spellings this struct previously expected are still accepted as aliases.
/// Every field is an `Option`, so missing keys deserialize as `None` rather
/// than failing.
#[derive(Debug, Deserialize)]
pub struct BrowserVersion {
    /// Browser product/version string (`Browser`).
    #[serde(rename = "Browser", alias = "browser")]
    pub browser: Option<String>,
    /// DevTools protocol version (`Protocol-Version`).
    #[serde(rename = "Protocol-Version", alias = "protocolVersion")]
    pub protocol_version: Option<String>,
    /// Browser user-agent string (`User-Agent`).
    #[serde(rename = "User-Agent", alias = "userAgent")]
    pub user_agent: Option<String>,
    /// V8 engine version (`V8-Version`).
    #[serde(rename = "V8-Version")]
    pub v8_version: Option<String>,
    /// WebKit/Blink version (`WebKit-Version`).
    #[serde(rename = "WebKit-Version", alias = "webkitVersion")]
    pub webkit_version: Option<String>,
    /// Browser-level WebSocket debugger URL (`webSocketDebuggerUrl`).
    #[serde(rename = "webSocketDebuggerUrl")]
    pub web_socket_debugger_url: Option<String>,
}
/// Options used when connecting to a CDP endpoint.
///
/// In this module they control the HTTP discovery request made by
/// `discover_websocket_url`. Create with [`CdpConnectionOptions::new`] (or
/// `Default`) and chain the builder methods. Fields left unset fall back to
/// the discovery defaults.
#[derive(Debug, Clone, Default)]
pub struct CdpConnectionOptions {
    /// Request timeout override; `None` means "use the default".
    pub timeout: Option<Duration>,
    /// Extra HTTP headers to send. Keys are compared with exact case.
    pub headers: HashMap<String, String>,
}

impl CdpConnectionOptions {
    /// Creates options with no timeout override and no extra headers.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the timeout override, replacing any previously set value.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Adds a single header. A later call with the same (exact-case) name
    /// overwrites the earlier value.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, val) = (name.into(), value.into());
        self.headers.insert(key, val);
        self
    }

    /// Merges `headers` into the existing set. Entries from `headers` win when
    /// a name is already present.
    #[must_use]
    pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
        for (key, val) in headers {
            self.headers.insert(key, val);
        }
        self
    }
}
#[instrument(level = "info", skip(options))]
pub async fn discover_websocket_url(
endpoint_url: &str,
options: &CdpConnectionOptions,
) -> Result<String, CdpError> {
let base_url = Url::parse(endpoint_url)
.map_err(|e| CdpError::InvalidEndpointUrl(format!("{endpoint_url}: {e}")))?;
if base_url.scheme() == "ws" || base_url.scheme() == "wss" {
debug!("URL is already a WebSocket URL, returning as-is");
return Ok(endpoint_url.to_string());
}
if base_url.scheme() != "http" && base_url.scheme() != "https" {
return Err(CdpError::InvalidEndpointUrl(format!(
"expected http, https, ws, or wss scheme, got: {}",
base_url.scheme()
)));
}
let version_url = base_url
.join("/json/version")
.map_err(|e| CdpError::InvalidEndpointUrl(format!("failed to build version URL: {e}")))?;
info!(url = %version_url, "Discovering WebSocket URL from HTTP endpoint");
let timeout = options.timeout.unwrap_or(DEFAULT_DISCOVERY_TIMEOUT);
let client = reqwest::Client::builder()
.timeout(timeout)
.build()
.map_err(|e| CdpError::HttpRequestFailed(e.to_string()))?;
let mut request = client.get(version_url.as_str());
for (name, value) in &options.headers {
request = request.header(name, value);
}
let response = request.send().await.map_err(|e| {
if e.is_timeout() {
CdpError::ConnectionTimeout(timeout)
} else if e.is_connect() {
CdpError::ConnectionFailed(format!("failed to connect to {endpoint_url}: {e}"))
} else {
CdpError::HttpRequestFailed(e.to_string())
}
})?;
if !response.status().is_success() {
return Err(CdpError::EndpointDiscoveryFailed {
url: endpoint_url.to_string(),
reason: format!("HTTP status {}", response.status()),
});
}
let version: BrowserVersion =
response
.json()
.await
.map_err(|e| CdpError::EndpointDiscoveryFailed {
url: endpoint_url.to_string(),
reason: format!("failed to parse response: {e}"),
})?;
let ws_url =
version
.web_socket_debugger_url
.ok_or_else(|| CdpError::EndpointDiscoveryFailed {
url: endpoint_url.to_string(),
reason: "response missing webSocketDebuggerUrl field".to_string(),
})?;
info!(ws_url = %ws_url, browser = ?version.browser, "Discovered WebSocket URL");
Ok(ws_url)
}
#[cfg(test)]
mod tests;