1use crate::error::Result;
38use cdp_protocol::{dom, fetch, network, page as page_cdp, performance, runtime as runtime_cdp};
39use serde::Deserialize;
40use std::sync::Arc;
41use tokio::sync::Mutex;
42
43use crate::session::Session;
44
/// Identifies a Chrome DevTools Protocol (CDP) domain whose lifecycle is
/// tracked by the domain manager.
///
/// The type is `Copy` and hashable, so it can be used directly as the key of
/// the per-domain state map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DomainType {
    /// The `Page` domain (part of the required set).
    Page,
    /// The `Runtime` domain (part of the required set).
    Runtime,
    /// The `DOM` domain (part of the required set).
    Dom,
    /// The `Network` domain (part of the required set).
    Network,
    /// The `Fetch` domain; optional, can be enabled and disabled on demand.
    Fetch,
    /// The `Performance` domain; optional, can be enabled and disabled on demand.
    Performance,
    /// The `CSS` domain.
    Css,
    /// The `Debugger` domain.
    Debugger,
    /// The `Profiler` domain.
    Profiler,
    /// `Keyboard`; disabling it sends no CDP command.
    Keyboard,
    /// `Mouse`; disabling it sends no CDP command.
    Mouse,
}
71
72impl DomainType {
73 pub fn is_required(&self) -> bool {
75 matches!(
76 self,
77 DomainType::Page | DomainType::Runtime | DomainType::Dom | DomainType::Network
78 )
79 }
80
81 pub fn name(&self) -> &'static str {
83 match self {
84 DomainType::Page => "Page",
85 DomainType::Runtime => "Runtime",
86 DomainType::Dom => "DOM",
87 DomainType::Network => "Network",
88 DomainType::Fetch => "Fetch",
89 DomainType::Performance => "Performance",
90 DomainType::Css => "CSS",
91 DomainType::Debugger => "Debugger",
92 DomainType::Profiler => "Profiler",
93 DomainType::Keyboard => "Keyboard",
94 DomainType::Mouse => "Mouse",
95 }
96 }
97}
98
/// Lifecycle state recorded for a single CDP domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DomainState {
    /// The domain is not enabled. Also the state reported for a domain that
    /// has no recorded entry yet.
    Disabled,
    /// An enable command has been issued and its outcome is still pending.
    Enabling,
    /// The enable command completed successfully.
    Enabled,
    /// A disable request is in progress.
    Disabling,
}
111
112#[derive(Debug, Clone)]
114pub struct DomainConfig {
115 pub page_enable_file_chooser: bool,
117 pub dom_include_whitespace: Option<dom::EnableIncludeWhitespaceOption>,
119 pub network_max_total_buffer_size: Option<u32>,
121 pub network_max_resource_buffer_size: Option<u32>,
122 pub network_max_post_data_size: Option<u32>,
123 pub fetch_handle_auth_requests: bool,
125}
126
127impl Default for DomainConfig {
128 fn default() -> Self {
129 Self {
130 page_enable_file_chooser: true,
131 dom_include_whitespace: Some(dom::EnableIncludeWhitespaceOption::All),
132 network_max_total_buffer_size: None,
133 network_max_resource_buffer_size: None,
134 network_max_post_data_size: None,
135 fetch_handle_auth_requests: false,
136 }
137 }
138}
139
140struct DomainManagerInner {
142 states: std::collections::HashMap<DomainType, DomainState>,
144 config: DomainConfig,
146 session: Arc<Session>,
148}
149
150pub struct DomainManager {
152 inner: Arc<Mutex<DomainManagerInner>>,
153}
154
155impl DomainManager {
156 pub(crate) fn new(session: Arc<Session>) -> Self {
159 Self {
160 inner: Arc::new(Mutex::new(DomainManagerInner {
161 states: std::collections::HashMap::new(),
162 config: DomainConfig::default(),
163 session,
164 })),
165 }
166 }
167
168 pub(crate) fn with_config(session: Arc<Session>, config: DomainConfig) -> Self {
170 Self {
171 inner: Arc::new(Mutex::new(DomainManagerInner {
172 states: std::collections::HashMap::new(),
173 config,
174 session,
175 })),
176 }
177 }
178
179 pub async fn enable_required_domains(&self) -> Result<()> {
183 tracing::debug!("Enabling required CDP domains...");
184
185 self.enable_page_domain().await?;
186 self.enable_runtime_domain().await?;
187 self.enable_dom_domain().await?;
188 self.enable_network_domain().await?;
189
190 tracing::info!("All required CDP domains are enabled");
191 Ok(())
192 }
193
194 pub async fn disable_all_domains(&self) -> Result<()> {
199 tracing::debug!("Disabling all enabled CDP domains...");
200
201 let inner = self.inner.lock().await;
202 let enabled_domains: Vec<DomainType> = inner
203 .states
204 .iter()
205 .filter(|(_, state)| **state == DomainState::Enabled)
206 .map(|(domain, _)| *domain)
207 .collect();
208
209 drop(inner); for domain in enabled_domains {
212 if let Err(e) = self.disable_domain_internal(domain).await {
213 tracing::warn!("Failed to disable {} domain: {:?}", domain.name(), e);
214 }
215 }
216
217 tracing::info!("All CDP domains disabled");
218 Ok(())
219 }
220
221 pub async fn is_enabled(&self, domain: DomainType) -> bool {
223 let inner = self.inner.lock().await;
224 matches!(
225 inner.states.get(&domain),
226 Some(DomainState::Enabled) | Some(DomainState::Enabling)
227 )
228 }
229
230 pub async fn get_state(&self, domain: DomainType) -> DomainState {
232 let inner = self.inner.lock().await;
233 inner
234 .states
235 .get(&domain)
236 .copied()
237 .unwrap_or(DomainState::Disabled)
238 }
239
240 pub async fn enable_page_domain(&self) -> Result<()> {
244 self.enable_domain_generic::<_, _, page_cdp::EnableReturnObject>(
245 DomainType::Page,
246 |config| page_cdp::Enable {
247 enable_file_chooser_opened_event: Some(config.page_enable_file_chooser),
248 },
249 )
250 .await
251 }
252
253 pub async fn enable_runtime_domain(&self) -> Result<()> {
255 self.enable_domain_generic::<_, _, runtime_cdp::EnableReturnObject>(
256 DomainType::Runtime,
257 |_| runtime_cdp::Enable(None),
258 )
259 .await
260 }
261
262 pub async fn enable_dom_domain(&self) -> Result<()> {
264 self.enable_domain_generic::<_, _, dom::EnableReturnObject>(DomainType::Dom, |config| {
265 dom::Enable {
266 include_whitespace: config.dom_include_whitespace.clone(),
267 }
268 })
269 .await
270 }
271
272 pub async fn enable_network_domain(&self) -> Result<()> {
274 self.enable_domain_generic::<_, _, network::EnableReturnObject>(
275 DomainType::Network,
276 |config| network::Enable {
277 max_total_buffer_size: config.network_max_total_buffer_size,
278 max_resource_buffer_size: config.network_max_resource_buffer_size,
279 max_post_data_size: config.network_max_post_data_size,
280 report_direct_socket_traffic: None,
281 enable_durable_messages: None,
282 },
283 )
284 .await
285 }
286
287 pub async fn enable_fetch_domain(&self) -> Result<()> {
291 self.enable_fetch_domain_with_patterns(None).await
292 }
293
294 pub async fn enable_fetch_domain_with_patterns(
296 &self,
297 patterns: Option<Vec<fetch::RequestPattern>>,
298 ) -> Result<()> {
299 self.enable_domain_generic::<_, _, fetch::EnableReturnObject>(DomainType::Fetch, |config| {
300 fetch::Enable {
301 patterns,
302 handle_auth_requests: Some(config.fetch_handle_auth_requests),
303 }
304 })
305 .await
306 }
307
308 pub async fn disable_fetch_domain(&self) -> Result<()> {
310 self.disable_domain_internal(DomainType::Fetch).await
311 }
312
313 pub async fn enable_performance_domain(&self) -> Result<()> {
315 self.enable_domain_generic::<_, _, performance::EnableReturnObject>(
316 DomainType::Performance,
317 |_| performance::Enable { time_domain: None },
318 )
319 .await
320 }
321
322 pub async fn disable_performance_domain(&self) -> Result<()> {
324 self.disable_domain_internal(DomainType::Performance).await
325 }
326
327 async fn set_state(&self, domain: DomainType, state: DomainState) {
331 let mut inner = self.inner.lock().await;
332 inner.states.insert(domain, state);
333 }
334
335 async fn enable_domain_generic<F, M, R>(
336 &self,
337 domain: DomainType,
338 command_factory: F,
339 ) -> Result<()>
340 where
341 F: FnOnce(&DomainConfig) -> M,
342 M: serde::Serialize + std::fmt::Debug + cdp_protocol::types::Method,
343 R: for<'de> Deserialize<'de>,
344 {
345 if self.is_enabled(domain).await {
346 return Ok(());
347 }
348
349 tracing::debug!("Enabling {} domain...", domain.name());
350 self.set_state(domain, DomainState::Enabling).await;
351
352 let inner = self.inner.lock().await;
353 let command = command_factory(&inner.config);
354 let result = inner.session.send_command::<M, R>(command, None).await;
355 drop(inner);
356
357 match result {
358 Ok(_) => {
359 self.set_state(domain, DomainState::Enabled).await;
360 tracing::debug!("{} domain enabled", domain.name());
361 Ok(())
362 }
363 Err(e) => {
364 self.set_state(domain, DomainState::Disabled).await;
365 Err(e)
366 }
367 }
368 }
369
370 async fn disable_domain_internal(&self, domain: DomainType) -> Result<()> {
372 if !self.is_enabled(domain).await {
373 return Ok(());
374 }
375
376 tracing::debug!("Disabling {} domain...", domain.name());
377 self.set_state(domain, DomainState::Disabling).await;
378
379 let inner = self.inner.lock().await;
380 let result: Result<()> = match domain {
381 DomainType::Fetch => {
382 let disable = fetch::Disable(None);
383 inner
384 .session
385 .send_command::<_, fetch::DisableReturnObject>(disable, None)
386 .await
387 .map(|_| ())
388 }
389 DomainType::Performance => {
390 let disable = performance::Disable(None);
391 inner
392 .session
393 .send_command::<_, performance::DisableReturnObject>(disable, None)
394 .await
395 .map(|_| ())
396 }
397 DomainType::Mouse | DomainType::Keyboard => Ok(()),
398 DomainType::Page | DomainType::Runtime | DomainType::Dom | DomainType::Network => {
399 tracing::warn!(
401 "{} domain is required and should normally stay enabled",
402 domain.name()
403 );
404 Ok(())
405 }
406 DomainType::Css | DomainType::Debugger | DomainType::Profiler => {
407 tracing::warn!("Disabling the {} domain is not supported", domain.name());
408 Ok(())
409 }
410 };
411 drop(inner);
412
413 if result.is_ok() {
414 self.set_state(domain, DomainState::Disabled).await;
415 tracing::debug!("{} domain disabled", domain.name());
416 } else {
417 self.set_state(domain, DomainState::Enabled).await;
419 }
420
421 result
422 }
423}
424
425impl Drop for DomainManager {
426 fn drop(&mut self) {
427 tracing::debug!("Dropping DomainManager");
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Page/Runtime/DOM/Network are required; Fetch and Performance are not.
    #[test]
    fn test_domain_type_is_required() {
        let required = [
            DomainType::Page,
            DomainType::Runtime,
            DomainType::Dom,
            DomainType::Network,
        ];
        let optional = [DomainType::Fetch, DomainType::Performance];

        assert!(required.iter().all(|domain| domain.is_required()));
        assert!(!optional.iter().any(|domain| domain.is_required()));
    }

    /// The default config turns file-chooser events on and Fetch auth
    /// handling off.
    #[test]
    fn test_domain_config_default() {
        let DomainConfig {
            page_enable_file_chooser,
            fetch_handle_auth_requests,
            ..
        } = DomainConfig::default();

        assert!(page_enable_file_chooser);
        assert!(!fetch_handle_auth_requests);
    }
}