1use crate::error::Result;
38use cdp_protocol::{dom, fetch, network, page as page_cdp, performance, runtime as runtime_cdp};
39use serde::Deserialize;
40use std::sync::Arc;
41use tokio::sync::Mutex;
42
43use crate::session::Session;
44
/// Identifies a domain whose enablement state can be tracked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DomainType {
    Page,
    Runtime,
    Dom,
    Network,
    Fetch,
    Performance,
    Css,
    Debugger,
    Profiler,
    Keyboard,
    Mouse,
}

impl DomainType {
    /// Returns `true` for the domains classed as required: `Page`, `Runtime`,
    /// `Dom` and `Network`.
    pub fn is_required(&self) -> bool {
        // Every variant is listed so that adding a new one forces an explicit
        // decision here instead of silently falling into a catch-all.
        match self {
            Self::Page | Self::Runtime | Self::Dom | Self::Network => true,
            Self::Fetch
            | Self::Performance
            | Self::Css
            | Self::Debugger
            | Self::Profiler
            | Self::Keyboard
            | Self::Mouse => false,
        }
    }

    /// Returns the display name of the domain as used in log messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Page => "Page",
            Self::Runtime => "Runtime",
            Self::Dom => "DOM",
            Self::Network => "Network",
            Self::Fetch => "Fetch",
            Self::Performance => "Performance",
            Self::Css => "CSS",
            Self::Debugger => "Debugger",
            Self::Profiler => "Profiler",
            Self::Keyboard => "Keyboard",
            Self::Mouse => "Mouse",
        }
    }
}
98
/// Lifecycle state recorded for a domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DomainState {
    /// Not enabled. Also what `get_state` reports for a domain that has no
    /// recorded state.
    Disabled,
    /// Recorded before the enable command is sent.
    Enabling,
    /// Recorded once the enable command has succeeded.
    Enabled,
    /// Recorded before a disable is attempted.
    Disabling,
}
111
/// Options used to build the enable commands for the individual domains.
#[derive(Debug, Clone)]
pub struct DomainConfig {
    /// Sent as `enable_file_chooser_opened_event` when enabling the Page domain.
    pub page_enable_file_chooser: bool,
    /// Sent as `include_whitespace` when enabling the DOM domain.
    pub dom_include_whitespace: Option<dom::EnableIncludeWhitespaceOption>,
    /// Sent as `max_total_buffer_size` when enabling the Network domain.
    pub network_max_total_buffer_size: Option<u32>,
    /// Sent as `max_resource_buffer_size` when enabling the Network domain.
    pub network_max_resource_buffer_size: Option<u32>,
    /// Sent as `max_post_data_size` when enabling the Network domain.
    pub network_max_post_data_size: Option<u32>,
    /// Sent as `handle_auth_requests` when enabling the Fetch domain.
    pub fetch_handle_auth_requests: bool,
}
126
127impl Default for DomainConfig {
128 fn default() -> Self {
129 Self {
130 page_enable_file_chooser: true,
131 dom_include_whitespace: Some(dom::EnableIncludeWhitespaceOption::All),
132 network_max_total_buffer_size: None,
133 network_max_resource_buffer_size: None,
134 network_max_post_data_size: None,
135 fetch_handle_auth_requests: false,
136 }
137 }
138}
139
/// State guarded by the manager's mutex.
struct DomainManagerInner {
    /// Last recorded state per domain. A missing entry is reported as
    /// `Disabled` by `get_state`.
    states: std::collections::HashMap<DomainType, DomainState>,
    /// Options read when building enable commands.
    config: DomainConfig,
    /// Session on which enable/disable commands are sent.
    session: Arc<Session>,
}
149
/// Tracks the recorded state of each domain and sends the enable/disable
/// commands for them over a session.
pub struct DomainManager {
    /// Per-domain state, config and session, behind an async mutex.
    inner: Arc<Mutex<DomainManagerInner>>,
}
154
155impl DomainManager {
156 pub(crate) fn new(session: Arc<Session>) -> Self {
159 Self {
160 inner: Arc::new(Mutex::new(DomainManagerInner {
161 states: std::collections::HashMap::new(),
162 config: DomainConfig::default(),
163 session,
164 })),
165 }
166 }
167
168 pub(crate) fn with_config(session: Arc<Session>, config: DomainConfig) -> Self {
170 Self {
171 inner: Arc::new(Mutex::new(DomainManagerInner {
172 states: std::collections::HashMap::new(),
173 config,
174 session,
175 })),
176 }
177 }
178
    /// Enables the Page, Runtime, DOM and Network domains, in that order.
    ///
    /// # Errors
    ///
    /// Returns the first error encountered. Domains later in the sequence are
    /// not attempted, and domains already enabled by this call stay enabled.
    pub async fn enable_required_domains(&self) -> Result<()> {
        tracing::debug!("Enabling required CDP domains...");

        self.enable_page_domain().await?;
        self.enable_runtime_domain().await?;
        self.enable_dom_domain().await?;
        self.enable_network_domain().await?;

        tracing::info!("All required CDP domains are enabled");
        Ok(())
    }
193
194 pub async fn disable_all_domains(&self) -> Result<()> {
199 tracing::debug!("Disabling all enabled CDP domains...");
200
201 let inner = self.inner.lock().await;
202 let enabled_domains: Vec<DomainType> = inner
203 .states
204 .iter()
205 .filter(|(_, state)| **state == DomainState::Enabled)
206 .map(|(domain, _)| *domain)
207 .collect();
208
209 drop(inner); for domain in enabled_domains {
212 if let Err(e) = self.disable_domain_internal(domain).await {
213 tracing::warn!("Failed to disable {} domain: {:?}", domain.name(), e);
214 }
215 }
216
217 tracing::info!("All CDP domains disabled");
218 Ok(())
219 }
220
221 pub async fn is_enabled(&self, domain: DomainType) -> bool {
223 let inner = self.inner.lock().await;
224 matches!(
225 inner.states.get(&domain),
226 Some(DomainState::Enabled) | Some(DomainState::Enabling)
227 )
228 }
229
230 pub async fn get_state(&self, domain: DomainType) -> DomainState {
232 let inner = self.inner.lock().await;
233 inner
234 .states
235 .get(&domain)
236 .copied()
237 .unwrap_or(DomainState::Disabled)
238 }
239
240 pub async fn enable_page_domain(&self) -> Result<()> {
244 self.enable_domain_generic::<_, _, page_cdp::EnableReturnObject>(
245 DomainType::Page,
246 |config| page_cdp::Enable {
247 enable_file_chooser_opened_event: Some(config.page_enable_file_chooser),
248 },
249 )
250 .await
251 }
252
253 pub async fn enable_runtime_domain(&self) -> Result<()> {
255 self.enable_domain_generic::<_, _, runtime_cdp::EnableReturnObject>(
256 DomainType::Runtime,
257 |_| runtime_cdp::Enable(None),
258 )
259 .await
260 }
261
262 pub async fn enable_dom_domain(&self) -> Result<()> {
264 self.enable_domain_generic::<_, _, dom::EnableReturnObject>(DomainType::Dom, |config| {
265 dom::Enable {
266 include_whitespace: config.dom_include_whitespace.clone(),
267 }
268 })
269 .await
270 }
271
272 pub async fn enable_network_domain(&self) -> Result<()> {
274 self.enable_domain_generic::<_, _, network::EnableReturnObject>(
275 DomainType::Network,
276 |config| network::Enable {
277 max_total_buffer_size: config.network_max_total_buffer_size,
278 max_resource_buffer_size: config.network_max_resource_buffer_size,
279 max_post_data_size: config.network_max_post_data_size,
280 report_direct_socket_traffic: None,
281 },
282 )
283 .await
284 }
285
    /// Enables the Fetch domain without specifying any request patterns.
    ///
    /// Equivalent to calling `enable_fetch_domain_with_patterns(None)`.
    pub async fn enable_fetch_domain(&self) -> Result<()> {
        self.enable_fetch_domain_with_patterns(None).await
    }
292
293 pub async fn enable_fetch_domain_with_patterns(
295 &self,
296 patterns: Option<Vec<fetch::RequestPattern>>,
297 ) -> Result<()> {
298 self.enable_domain_generic::<_, _, fetch::EnableReturnObject>(DomainType::Fetch, |config| {
299 fetch::Enable {
300 patterns,
301 handle_auth_requests: Some(config.fetch_handle_auth_requests),
302 }
303 })
304 .await
305 }
306
    /// Disables the Fetch domain via the shared disable routine.
    pub async fn disable_fetch_domain(&self) -> Result<()> {
        self.disable_domain_internal(DomainType::Fetch).await
    }
311
312 pub async fn enable_performance_domain(&self) -> Result<()> {
314 self.enable_domain_generic::<_, _, performance::EnableReturnObject>(
315 DomainType::Performance,
316 |_| performance::Enable { time_domain: None },
317 )
318 .await
319 }
320
    /// Disables the Performance domain via the shared disable routine.
    pub async fn disable_performance_domain(&self) -> Result<()> {
        self.disable_domain_internal(DomainType::Performance).await
    }
325
326 async fn set_state(&self, domain: DomainType, state: DomainState) {
330 let mut inner = self.inner.lock().await;
331 inner.states.insert(domain, state);
332 }
333
334 async fn enable_domain_generic<F, M, R>(
335 &self,
336 domain: DomainType,
337 command_factory: F,
338 ) -> Result<()>
339 where
340 F: FnOnce(&DomainConfig) -> M,
341 M: serde::Serialize + std::fmt::Debug + cdp_protocol::types::Method,
342 R: for<'de> Deserialize<'de>,
343 {
344 if self.is_enabled(domain).await {
345 return Ok(());
346 }
347
348 tracing::debug!("Enabling {} domain...", domain.name());
349 self.set_state(domain, DomainState::Enabling).await;
350
351 let inner = self.inner.lock().await;
352 let command = command_factory(&inner.config);
353 let result = inner.session.send_command::<M, R>(command, None).await;
354 drop(inner);
355
356 match result {
357 Ok(_) => {
358 self.set_state(domain, DomainState::Enabled).await;
359 tracing::debug!("{} domain enabled", domain.name());
360 Ok(())
361 }
362 Err(e) => {
363 self.set_state(domain, DomainState::Disabled).await;
364 Err(e)
365 }
366 }
367 }
368
369 async fn disable_domain_internal(&self, domain: DomainType) -> Result<()> {
371 if !self.is_enabled(domain).await {
372 return Ok(());
373 }
374
375 tracing::debug!("Disabling {} domain...", domain.name());
376 self.set_state(domain, DomainState::Disabling).await;
377
378 let inner = self.inner.lock().await;
379 let result: Result<()> = match domain {
380 DomainType::Fetch => {
381 let disable = fetch::Disable(None);
382 inner
383 .session
384 .send_command::<_, fetch::DisableReturnObject>(disable, None)
385 .await
386 .map(|_| ())
387 }
388 DomainType::Performance => {
389 let disable = performance::Disable(None);
390 inner
391 .session
392 .send_command::<_, performance::DisableReturnObject>(disable, None)
393 .await
394 .map(|_| ())
395 }
396 DomainType::Mouse | DomainType::Keyboard => Ok(()),
397 DomainType::Page | DomainType::Runtime | DomainType::Dom | DomainType::Network => {
398 tracing::warn!(
400 "{} domain is required and should normally stay enabled",
401 domain.name()
402 );
403 Ok(())
404 }
405 DomainType::Css | DomainType::Debugger | DomainType::Profiler => {
406 tracing::warn!("Disabling the {} domain is not supported", domain.name());
407 Ok(())
408 }
409 };
410 drop(inner);
411
412 if result.is_ok() {
413 self.set_state(domain, DomainState::Disabled).await;
414 tracing::debug!("{} domain disabled", domain.name());
415 } else {
416 self.set_state(domain, DomainState::Enabled).await;
418 }
419
420 result
421 }
422}
423
// Dropping the manager only emits a debug log line; no disable commands are
// sent from here.
impl Drop for DomainManager {
    fn drop(&mut self) {
        tracing::debug!("Dropping DomainManager");
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;

    // Page, Runtime, Dom and Network are required; Fetch and Performance are not.
    #[test]
    fn test_domain_type_is_required() {
        assert!(DomainType::Page.is_required());
        assert!(DomainType::Runtime.is_required());
        assert!(DomainType::Dom.is_required());
        assert!(DomainType::Network.is_required());
        assert!(!DomainType::Fetch.is_required());
        assert!(!DomainType::Performance.is_required());
    }

    // The default config turns file-chooser events on and Fetch auth handling off.
    #[test]
    fn test_domain_config_default() {
        let config = DomainConfig::default();
        assert!(config.page_enable_file_chooser);
        assert!(!config.fetch_handle_auth_requests);
    }
}