use std::collections::HashMap;
use std::fmt::Formatter;
use std::future::{Future, IntoFuture};
use std::net::{IpAddr, Ipv4Addr};
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;
use serde_json::Value;
use tokio::sync::Mutex;
use url::Url;
use crate::Capabilities;
use crate::common::capabilities::desiredcapabilities::CapabilitiesHelper;
use crate::common::config::WebDriverConfig;
use crate::error::{WebDriverError, WebDriverResult};
use crate::session::DriverGuard;
use crate::session::create::start_session;
use crate::session::handle::SessionHandle;
use crate::session::http::create_reqwest_client;
use crate::web_driver::WebDriver;
use super::browser::{BrowserKind, detect_local_version};
use super::download::{DownloadConfig, Mirror, ensure_driver, resolve_version};
use super::error::ManagerError;
use super::process::{ManagedDriverProcess, SpawnConfig, SpawnContext, StdioMode};
use super::status::{
DriverId, DriverLogCallback, DriverLogLine, DriverLogSubscription, Emitter, LogSubscribers,
Status, StatusCallback, Subscription,
};
use super::version::DriverVersion;
/// Value used for the builder's `download_timeout` when none is set (60 seconds).
const DEFAULT_DOWNLOAD_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Value used for the builder's `ready_timeout` when none is set (30 seconds).
const DEFAULT_READY_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Directory used for downloaded drivers when the builder sets none:
/// `<user cache dir>/thirtyfour/drivers`, falling back to
/// `<temp dir>/thirtyfour/drivers` when no user cache directory is known.
fn default_cache_dir() -> PathBuf {
    let base = match dirs::cache_dir() {
        Some(dir) => dir,
        None => std::env::temp_dir(),
    };
    base.join("thirtyfour").join("drivers")
}
/// Resolves, downloads and spawns WebDriver server processes, and starts
/// browser sessions on them.
///
/// Create one with `WebDriverManager::builder()`. Spawned drivers are tracked
/// by weak reference, so this registry alone never keeps a driver process
/// referenced; the sessions that use a driver hold the strong references.
pub struct WebDriverManager {
    /// Effective configuration, with builder defaults already applied.
    pub(crate) cfg: ResolvedConfig,
    /// HTTP client passed to version resolution and driver download.
    download_client: reqwest::Client,
    /// Drivers started by this manager, keyed by browser/version/host.
    /// Values are `Weak` so an entry can only be reused while something else
    /// still holds the process.
    drivers: Mutex<HashMap<DriverKey, Weak<ManagedDriverProcess>>>,
    /// Fan-out point for `Status` events.
    pub(crate) emitter: Emitter,
    /// Callbacks that receive driver log lines.
    pub(crate) log_subscribers: LogSubscribers,
    /// Counter from which each spawned driver's `DriverId` is minted.
    next_driver_id: Arc<AtomicU64>,
}
impl std::fmt::Debug for WebDriverManager {
    /// Prints only the resolved configuration; the HTTP client, driver
    /// registry, subscribers and id counter are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WebDriverManager");
        out.field("cfg", &self.cfg);
        out.finish_non_exhaustive()
    }
}
/// Manager configuration after `WebDriverManagerBuilder::build` has filled in
/// defaults for every option the user left unset.
#[derive(Clone)]
pub(crate) struct ResolvedConfig {
    /// Driver version strategy, taken unchanged from the builder.
    pub version: DriverVersion,
    /// Directory handed to `DownloadConfig` for cached drivers.
    pub cache_dir: PathBuf,
    /// Host passed to spawned drivers and included in the reuse key
    /// (`build` defaults it to 127.0.0.1).
    pub host: IpAddr,
    /// Timeout handed to `DownloadConfig` (`build` defaults it to 60s).
    pub download_timeout: Duration,
    /// Readiness timeout handed to the driver spawn (`build` defaults it to 30s).
    pub ready_timeout: Duration,
    /// Offline flag handed to `DownloadConfig` (`build` defaults it to false).
    pub offline: bool,
    /// Base URLs for driver metadata and downloads.
    pub mirror: Mirror,
    /// How a spawned driver's stdio is configured.
    pub stdio: StdioMode,
    /// User-supplied driver binaries per browser; when present for a browser,
    /// version resolution and download are skipped.
    pub driver_paths: HashMap<BrowserKind, PathBuf>,
}
impl std::fmt::Debug for ResolvedConfig {
    /// Prints every field except `mirror`, and marks the output as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            version,
            cache_dir,
            host,
            download_timeout,
            ready_timeout,
            offline,
            stdio,
            driver_paths,
            ..
        } = self;
        f.debug_struct("ResolvedConfig")
            .field("version", version)
            .field("cache_dir", cache_dir)
            .field("host", host)
            .field("download_timeout", download_timeout)
            .field("ready_timeout", ready_timeout)
            .field("offline", offline)
            .field("stdio", stdio)
            .field("driver_paths", driver_paths)
            .finish_non_exhaustive()
    }
}
/// Identity under which a spawned driver process is registered for reuse.
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub(super) struct DriverKey {
    /// Browser the driver serves.
    pub(super) browser: BrowserKind,
    /// Resolved driver version, or `manual:<path>` for a user-supplied binary.
    pub(super) version: String,
    /// Host the driver was spawned with.
    pub(super) host: IpAddr,
}
// A driver process can itself serve as a `DriverGuard`; the empty body means
// it uses all of the trait's provided method implementations.
impl DriverGuard for ManagedDriverProcess {}
/// Guard attached to a session handle created by `WebDriverManager::launch`.
///
/// It holds a strong reference to the driver process, so the manager's weak
/// registry entry stays upgradable while the guard exists. It also emits
/// `Status::SessionEnded` when dropped.
pub(crate) struct SessionGuard {
    /// Driver process serving this session.
    pub(crate) driver: Arc<ManagedDriverProcess>,
    /// Used to report `SessionEnded` on drop.
    emitter: Emitter,
    /// Browser the session was started for.
    browser: BrowserKind,
    /// Session id returned by the driver, as a string.
    session_id: String,
}
impl std::fmt::Debug for SessionGuard {
    /// Prints browser, session id and driver; the emitter is elided.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SessionGuard");
        out.field("browser", &self.browser);
        out.field("session_id", &self.session_id);
        out.field("driver", &self.driver);
        out.finish_non_exhaustive()
    }
}
impl Drop for SessionGuard {
    /// Reports the end of this session to status subscribers.
    fn drop(&mut self) {
        // The guard is being destroyed and nothing reads `session_id` after
        // this point, so move the id out instead of allocating a copy.
        let session_id = std::mem::take(&mut self.session_id);
        self.emitter.emit(Status::SessionEnded {
            browser: self.browser,
            session_id,
        });
    }
}
impl DriverGuard for SessionGuard {
    /// Exposes this guard as `dyn Any` so holders can downcast it back to
    /// `SessionGuard`.
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }
}
/// Builder for `WebDriverManager`.
///
/// Every setter consumes the builder and returns the updated one, so the
/// result must be used. Options left unset get defaults in
/// [`build`](Self::build). The builder can also be `.await`ed directly; this
/// builds a manager and launches a session from preloaded capabilities.
#[must_use = "builder methods consume the builder and return the updated one; \
              dropping the result discards the configuration"]
#[derive(Default)]
pub struct WebDriverManagerBuilder {
    /// Driver version strategy.
    pub(crate) version: DriverVersion,
    /// Override for the driver cache directory.
    pub(crate) cache_dir: Option<PathBuf>,
    /// Override for the host drivers are spawned with.
    pub(crate) host: Option<IpAddr>,
    /// Override for the download timeout.
    pub(crate) download_timeout: Option<Duration>,
    /// Override for the driver readiness timeout.
    pub(crate) ready_timeout: Option<Duration>,
    /// Override for offline mode.
    pub(crate) offline: Option<bool>,
    /// Mirror URLs; created on first use by a mirror setter.
    pub(crate) mirror: Option<Mirror>,
    /// Override for driver stdio handling.
    pub(crate) stdio: Option<StdioMode>,
    /// User-supplied driver binaries per browser.
    pub(crate) driver_paths: HashMap<BrowserKind, PathBuf>,
    /// Status callbacks registered on the manager at build time.
    pub(crate) status_subscribers: Vec<StatusCallback>,
    /// Driver-log callbacks registered on the manager at build time.
    pub(crate) log_subscribers: Vec<DriverLogCallback>,
    /// Capabilities used when the builder is `.await`ed directly.
    pub(crate) preloaded_caps: Option<Capabilities>,
}
impl std::fmt::Debug for WebDriverManagerBuilder {
    /// Prints the configured options and the number of registered callbacks.
    /// `mirror` and `preloaded_caps` are elided.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            version,
            cache_dir,
            host,
            download_timeout,
            ready_timeout,
            offline,
            stdio,
            driver_paths,
            status_subscribers,
            log_subscribers,
            ..
        } = self;
        let status_count = status_subscribers.len();
        let log_count = log_subscribers.len();
        f.debug_struct("WebDriverManagerBuilder")
            .field("version", version)
            .field("cache_dir", cache_dir)
            .field("host", host)
            .field("download_timeout", download_timeout)
            .field("ready_timeout", ready_timeout)
            .field("offline", offline)
            .field("stdio", stdio)
            .field("driver_paths", driver_paths)
            .field("status_subscribers", &status_count)
            .field("log_subscribers", &log_count)
            .finish_non_exhaustive()
    }
}
impl Clone for WebDriverManagerBuilder {
fn clone(&self) -> Self {
Self {
version: self.version.clone(),
cache_dir: self.cache_dir.clone(),
host: self.host,
download_timeout: self.download_timeout,
ready_timeout: self.ready_timeout,
offline: self.offline,
mirror: self.mirror.clone(),
stdio: self.stdio,
driver_paths: self.driver_paths.clone(),
status_subscribers: self.status_subscribers.iter().map(Arc::clone).collect(),
log_subscribers: self.log_subscribers.iter().map(Arc::clone).collect(),
preloaded_caps: self.preloaded_caps.clone(),
}
}
}
impl WebDriverManagerBuilder {
    /// Sets the driver version strategy, replacing any earlier choice.
    pub fn version(self, v: DriverVersion) -> Self {
        Self { version: v, ..self }
    }

    /// Shorthand for `version(DriverVersion::Latest)`.
    pub fn latest(self) -> Self {
        self.version(DriverVersion::Latest)
    }

    /// Shorthand for `version(DriverVersion::MatchLocalBrowser)`: the version
    /// is derived from the locally detected browser.
    pub fn match_local(self) -> Self {
        self.version(DriverVersion::MatchLocalBrowser)
    }

    /// Shorthand for `version(DriverVersion::FromCapabilities)`.
    pub fn from_caps(self) -> Self {
        self.version(DriverVersion::FromCapabilities)
    }

    /// Pins the driver to the exact version string `v`.
    pub fn exact(self, v: impl Into<String>) -> Self {
        self.version(DriverVersion::Exact(v.into()))
    }

    /// Overrides the directory where drivers are cached.
    pub fn cache_dir(self, p: PathBuf) -> Self {
        Self { cache_dir: Some(p), ..self }
    }

    /// Overrides the host drivers are spawned with (default 127.0.0.1).
    pub fn host(self, h: IpAddr) -> Self {
        Self { host: Some(h), ..self }
    }

    /// Overrides the download timeout (default 60s).
    pub fn download_timeout(self, d: Duration) -> Self {
        Self { download_timeout: Some(d), ..self }
    }

    /// Overrides the driver readiness timeout (default 30s).
    pub fn ready_timeout(self, d: Duration) -> Self {
        Self { ready_timeout: Some(d), ..self }
    }

    /// Shorthand for `offline_mode(true)`.
    pub fn offline(self) -> Self {
        self.offline_mode(true)
    }

    /// Shorthand for `offline_mode(false)`.
    pub fn online(self) -> Self {
        self.offline_mode(false)
    }

    /// Explicitly sets offline mode (default: off).
    pub fn offline_mode(self, yes: bool) -> Self {
        Self { offline: Some(yes), ..self }
    }

    /// Sets the Chrome metadata base URL, starting from default mirrors if
    /// none were configured yet.
    pub fn chrome_metadata_mirror(mut self, base: Url) -> Self {
        self.mirror.get_or_insert_with(Mirror::default).chrome_metadata = base;
        self
    }

    /// Sets the geckodriver download base URL, starting from default mirrors
    /// if none were configured yet.
    pub fn geckodriver_downloads_mirror(mut self, base: Url) -> Self {
        self.mirror.get_or_insert_with(Mirror::default).geckodriver_downloads = base;
        self
    }

    /// Sets the Edge driver download base URL, starting from default mirrors
    /// if none were configured yet.
    pub fn edge_downloads_mirror(mut self, base: Url) -> Self {
        self.mirror.get_or_insert_with(Mirror::default).edge_downloads = base;
        self
    }

    /// Overrides how spawned drivers' stdio is handled.
    pub fn stdio(self, mode: StdioMode) -> Self {
        Self { stdio: Some(mode), ..self }
    }

    /// Uses `path` as the driver binary for `browser`. This skips version
    /// resolution and download for that browser.
    pub fn driver_binary(mut self, browser: BrowserKind, path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        self.driver_paths.insert(browser, path);
        self
    }

    /// Adds a status callback that is registered on the manager at build time.
    pub fn on_status<F>(mut self, f: F) -> Self
    where
        F: Fn(&Status) + Send + Sync + 'static,
    {
        self.status_subscribers.push(Arc::new(f));
        self
    }

    /// Adds a driver-log callback that is registered on the manager at build time.
    pub fn on_driver_log<F>(mut self, f: F) -> Self
    where
        F: Fn(&DriverLogLine) + Send + Sync + 'static,
    {
        self.log_subscribers.push(Arc::new(f));
        self
    }

    /// Applies defaults to every unset option and creates the manager.
    ///
    /// # Panics
    /// Panics if the default `reqwest` client cannot be built.
    pub fn build(self) -> Arc<WebDriverManager> {
        let Self {
            version,
            cache_dir,
            host,
            download_timeout,
            ready_timeout,
            offline,
            mirror,
            stdio,
            driver_paths,
            status_subscribers,
            log_subscribers: log_callbacks,
            ..
        } = self;

        let cfg = ResolvedConfig {
            version,
            cache_dir: cache_dir.unwrap_or_else(default_cache_dir),
            host: host.unwrap_or(IpAddr::from(Ipv4Addr::LOCALHOST)),
            download_timeout: download_timeout.unwrap_or(DEFAULT_DOWNLOAD_TIMEOUT),
            ready_timeout: ready_timeout.unwrap_or(DEFAULT_READY_TIMEOUT),
            offline: offline.unwrap_or_default(),
            mirror: mirror.unwrap_or_default(),
            stdio: stdio.unwrap_or_default(),
            driver_paths,
        };

        // Subscription handles are forgotten rather than dropped. Presumably
        // dropping one would unregister its callback, and builder-time
        // callbacks are meant to last as long as the manager (TODO confirm
        // against `Emitter`/`LogSubscribers`).
        let emitter = Emitter::new();
        status_subscribers
            .into_iter()
            .for_each(|cb| std::mem::forget(emitter.add_arc(cb)));

        let log_subscribers = LogSubscribers::new();
        log_callbacks
            .into_iter()
            .for_each(|cb| std::mem::forget(log_subscribers.add_arc(cb)));

        let download_client = reqwest::Client::builder()
            .build()
            .expect("default reqwest client should always build");

        Arc::new(WebDriverManager {
            cfg,
            download_client,
            drivers: Mutex::new(HashMap::new()),
            emitter,
            log_subscribers,
            next_driver_id: Arc::new(AtomicU64::new(0)),
        })
    }
}
impl IntoFuture for WebDriverManagerBuilder {
    type Output = WebDriverResult<WebDriver>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    /// Builds a manager and launches a session with the capabilities preloaded
    /// on this builder.
    ///
    /// # Errors
    /// Resolves to `ManagerError::NoCapabilities` (converted into
    /// `WebDriverError`) when no capabilities were preloaded. Otherwise it
    /// resolves to whatever `WebDriverManager::launch` returns.
    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            // The builder is owned here and `build` never reads
            // `preloaded_caps`, so move the capabilities out instead of
            // deep-cloning them.
            let caps = self
                .preloaded_caps
                .take()
                .ok_or_else(|| WebDriverError::from(ManagerError::NoCapabilities))?;
            self.build().launch(caps).await
        })
    }
}
impl WebDriverManager {
pub fn builder() -> WebDriverManagerBuilder {
WebDriverManagerBuilder::default()
}
pub fn subscribe<F>(&self, f: F) -> Subscription
where
F: Fn(&Status) + Send + Sync + 'static,
{
self.emitter.add(f)
}
pub fn on_driver_log<F>(&self, f: F) -> DriverLogSubscription
where
F: Fn(&DriverLogLine) + Send + Sync + 'static,
{
self.log_subscribers.add(f)
}
pub(crate) fn mint_driver_id(&self) -> DriverId {
DriverId::from_raw(self.next_driver_id.fetch_add(1, Ordering::Relaxed))
}
pub async fn launch(
self: &Arc<Self>,
capabilities: impl Into<Capabilities>,
) -> WebDriverResult<WebDriver> {
let caps: Capabilities = capabilities.into();
let driver = self.ensure_driver(&caps).await.map_err(WebDriverError::from)?;
let browser = driver.browser;
let server_url: Url = driver
.url()
.parse()
.map_err(|e| WebDriverError::ParseError(format!("invalid driver url: {e}")))?;
let config = WebDriverConfig::default();
let client = create_reqwest_client(config.reqwest_timeout);
let client_arc: Arc<dyn crate::session::http::HttpClient> = Arc::new(client);
self.emitter.emit(Status::SessionStarting {
browser,
url: server_url.to_string(),
});
let session_id = start_session(client_arc.as_ref(), &server_url, &config, caps).await?;
self.emitter.emit(Status::SessionStarted {
browser,
session_id: session_id.to_string(),
url: server_url.to_string(),
});
let guard: Arc<dyn DriverGuard> = Arc::new(SessionGuard {
driver,
emitter: self.emitter.clone(),
browser,
session_id: session_id.to_string(),
});
let handle = SessionHandle::new_with_config_and_guard(
client_arc,
server_url,
session_id,
config,
Some(guard),
)?;
Ok(WebDriver {
handle: Arc::new(handle),
})
}
async fn ensure_driver(
&self,
caps: &Capabilities,
) -> Result<Arc<ManagedDriverProcess>, ManagerError> {
let browser = BrowserKind::from_capabilities(caps)?;
self.emitter.emit(Status::BrowserKindResolved {
browser,
});
if let Some(binary) = self.cfg.driver_paths.get(&browser).cloned() {
let version = format!("manual:{}", binary.display());
return self.spawn_or_reuse(browser, version, &binary).await;
}
let local = match self.cfg.version {
DriverVersion::MatchLocalBrowser => {
let custom = browser.binary_from_caps(caps);
Some(detect_local_version(browser, custom.as_deref(), &self.emitter)?)
}
_ => None,
};
let caps_version = caps._get("browserVersion").and_then(Value::as_str);
let download_cfg = DownloadConfig {
cache_dir: self.cfg.cache_dir.clone(),
mirror: self.cfg.mirror.clone(),
download_timeout: self.cfg.download_timeout,
offline: self.cfg.offline,
};
let resolved = resolve_version(
&self.download_client,
&download_cfg,
browser,
&self.cfg.version,
local.as_deref(),
caps_version,
&self.emitter,
)
.await?;
{
let key = DriverKey {
browser,
version: resolved.clone(),
host: self.cfg.host,
};
let map = self.drivers.lock().await;
if let Some(existing) = map.get(&key).and_then(Weak::upgrade) {
self.emitter.emit(Status::DriverReused {
browser,
version: resolved,
url: existing.url(),
});
return Ok(existing);
}
}
let driver_path =
ensure_driver(&self.download_client, &download_cfg, browser, &resolved, &self.emitter)
.await?;
self.spawn_or_reuse(browser, resolved, &driver_path.binary).await
}
async fn spawn_or_reuse(
&self,
browser: BrowserKind,
version: String,
binary: &Path,
) -> Result<Arc<ManagedDriverProcess>, ManagerError> {
let key = DriverKey {
browser,
version: version.clone(),
host: self.cfg.host,
};
{
let map = self.drivers.lock().await;
if let Some(existing) = map.get(&key).and_then(Weak::upgrade) {
self.emitter.emit(Status::DriverReused {
browser,
version: version.clone(),
url: existing.url(),
});
return Ok(existing);
}
}
let process = ManagedDriverProcess::spawn(
binary,
browser,
&SpawnConfig {
host: self.cfg.host,
ready_timeout: self.cfg.ready_timeout,
stdio: self.cfg.stdio,
},
SpawnContext {
driver_id: self.mint_driver_id(),
version: &version,
emitter: &self.emitter,
manager_log_subscribers: self.log_subscribers.clone(),
},
)
.await?;
let arc = Arc::new(process);
let mut map = self.drivers.lock().await;
if let Some(existing) = map.get(&key).and_then(Weak::upgrade) {
self.emitter.emit(Status::DriverReused {
browser,
version,
url: existing.url(),
});
return Ok(existing);
}
map.insert(key, Arc::downgrade(&arc));
Ok(arc)
}
}