use crate::error::{Error, Result};
use crate::server::channel::Channel;
use parking_lot::Mutex;
use std::collections::HashSet;
use std::sync::Arc;
/// A custom selector engine that has been registered with the `Selectors` registry.
///
/// A copy is retained so the engine can be re-sent to contexts attached later.
#[derive(Debug, Clone)]
struct SelectorEngine {
    /// Engine name; `Selectors::register` rejects a second engine with the same name.
    name: String,
    /// Engine source text, sent to the server in the `source` field.
    script: String,
    /// Sent as the `contentScript` flag.
    content_script: bool,
}
/// All mutable state of a `Selectors` registry, kept behind a single mutex.
struct SelectorsInner {
    /// Channels of the attached contexts; registrations and test-id updates are sent to each.
    contexts: Vec<Channel>,
    /// Registered engines in registration order; replayed onto newly attached contexts.
    engines: Vec<SelectorEngine>,
    /// Names of the registered engines, used for duplicate detection.
    engine_names: HashSet<String>,
    /// Test-id attribute override; `None` means no override has been set.
    test_id_attribute: Option<String>,
}
/// Registry of custom selector engines and the test-id attribute, shared across contexts.
///
/// Cloning is cheap: every clone points at the same state through an `Arc`.
#[derive(Clone)]
pub struct Selectors {
    /// Shared state; the lock is only taken briefly and never held across an `.await`.
    inner: Arc<Mutex<SelectorsInner>>,
}
impl Selectors {
    /// Creates an empty registry: no custom engines, no test-id override, no contexts.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(SelectorsInner {
                engines: Vec::new(),
                engine_names: HashSet::new(),
                test_id_attribute: None,
                contexts: Vec::new(),
            })),
        }
    }

    /// Builds the `registerSelectorEngine` payload for a single engine.
    ///
    /// Both `add_context` (replaying known engines) and `register` (sending a new
    /// engine) call this, so the wire shape is defined in exactly one place.
    fn register_engine_params(
        name: &str,
        source: &str,
        content_script: bool,
    ) -> serde_json::Value {
        serde_json::json!({
            "selectorEngine": {
                "name": name,
                "source": source,
                "contentScript": content_script,
            }
        })
    }

    /// Attaches a context channel and replays the current registry state onto it.
    ///
    /// Every registered engine is sent with `registerSelectorEngine`, in
    /// registration order. If a test-id attribute override is set, it is then
    /// sent with `setTestIdAttributeName`.
    ///
    /// # Errors
    /// Returns the first error from sending to `channel`. The channel has already
    /// been recorded at that point and stays attached.
    pub async fn add_context(&self, channel: Channel) -> Result<()> {
        // Record the channel and snapshot the state under one lock, so a concurrent
        // `register`/`set_test_id_attribute` cannot miss this channel. Such a call
        // either includes it in its own snapshot or has already updated the state
        // replayed below. The guard is dropped at the end of this block, before any
        // `.await`.
        let (engines, test_id_attribute) = {
            let mut inner = self.inner.lock();
            inner.contexts.push(channel.clone());
            (inner.engines.clone(), inner.test_id_attribute.clone())
        };
        for engine in &engines {
            let params = Self::register_engine_params(
                &engine.name,
                &engine.script,
                engine.content_script,
            );
            channel
                .send_no_result("registerSelectorEngine", params)
                .await?;
        }
        if let Some(attr) = test_id_attribute {
            channel
                .send_no_result(
                    "setTestIdAttributeName",
                    serde_json::json!({ "testIdAttributeName": attr }),
                )
                .await?;
        }
        Ok(())
    }

    /// Detaches every tracked context whose `guid()` matches `channel`'s.
    ///
    /// Only local bookkeeping changes; nothing is sent over the channel.
    pub fn remove_context(&self, channel: &Channel) {
        let mut inner = self.inner.lock();
        inner.contexts.retain(|c| c.guid() != channel.guid());
    }

    /// Registers a custom selector engine and sends it to every attached context.
    ///
    /// * `name` - engine name; must not already be registered.
    /// * `script` - engine source, sent as `source`.
    /// * `content_script` - sent as `contentScript`; `None` means `false`.
    ///
    /// # Errors
    /// * `Error::ProtocolError` if an engine named `name` is already registered.
    /// * The first error from sending to an attached context.
    ///
    /// NOTE(review): the engine is recorded locally before it is sent. If a send
    /// fails, it stays recorded, and a retry with the same name reports a duplicate.
    /// Confirm this is the intended behavior.
    pub async fn register(
        &self,
        name: &str,
        script: &str,
        content_script: Option<bool>,
    ) -> Result<()> {
        let content_script = content_script.unwrap_or(false);
        // Reserve the name, store the engine and snapshot the contexts under one lock.
        // A context attached afterwards then gets this engine through `add_context`.
        let channels = {
            let mut inner = self.inner.lock();
            // `insert` returns false if the name was already present. This single
            // hash lookup is both the duplicate check and the reservation.
            if !inner.engine_names.insert(name.to_string()) {
                return Err(Error::ProtocolError(format!(
                    "Selector engine '{name}' is already registered"
                )));
            }
            inner.engines.push(SelectorEngine {
                name: name.to_string(),
                script: script.to_string(),
                content_script,
            });
            inner.contexts.clone()
        };
        let params = Self::register_engine_params(name, script, content_script);
        for channel in &channels {
            channel
                .send_no_result("registerSelectorEngine", params.clone())
                .await?;
        }
        Ok(())
    }

    /// Returns the configured test-id attribute, or `"data-testid"` if none is set.
    pub fn test_id_attribute(&self) -> String {
        self.inner
            .lock()
            .test_id_attribute
            .clone()
            .unwrap_or_else(|| "data-testid".to_string())
    }

    /// Sets the test-id attribute name and sends it to every attached context.
    ///
    /// The value is stored before sending, so contexts attached later also receive it.
    ///
    /// # Errors
    /// Returns the first error from sending to an attached context.
    pub async fn set_test_id_attribute(&self, attribute: &str) -> Result<()> {
        let channels = {
            let mut inner = self.inner.lock();
            inner.test_id_attribute = Some(attribute.to_string());
            inner.contexts.clone()
        };
        let params = serde_json::json!({ "testIdAttributeName": attribute });
        for channel in &channels {
            channel
                .send_no_result("setTestIdAttributeName", params.clone())
                .await?;
        }
        Ok(())
    }
}
impl Default for Selectors {
fn default() -> Self {
Self::new()
}
}
/// Shows the registered engines and the test-id override.
/// The attached context channels are deliberately left out.
impl std::fmt::Debug for Selectors {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.inner.lock();
        let mut builder = f.debug_struct("Selectors");
        builder.field("engines", &state.engines);
        builder.field("test_id_attribute", &state.test_id_attribute);
        builder.finish()
    }
}