use std::fmt;
use serde::{Deserialize, Serialize};
use crate::error::LocatorError;
/// A node of an accessibility-tree snapshot together with its subtree.
///
/// Unset optional fields are left out when serializing, and field names are
/// serialized in camelCase (with `node_ref` serialized under the key `ref`).
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AriaSnapshot {
    /// ARIA role, e.g. `button`; printed first on the node's YAML line.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Accessible name. On an expected snapshot this may be a `/regex/flags` pattern.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Accessible description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Disabled state; only `Some(true)` is rendered (as `[disabled]`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disabled: Option<bool>,
    /// Expanded state; only `Some(true)` is rendered (as `[expanded]`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expanded: Option<bool>,
    /// Selected state; only `Some(true)` is rendered (as `[selected]`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selected: Option<bool>,
    /// Tri-state checked value; rendered as `[checked]` or `[mixed]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub checked: Option<AriaCheckedState>,
    /// Pressed state. NOTE(review): not currently rendered in the YAML text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pressed: Option<bool>,
    /// Hierarchical level; rendered as `[level=N]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub level: Option<u32>,
    /// Current numeric value (presumably `aria-valuenow` — confirm against producer).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub value_now: Option<f64>,
    /// Minimum numeric value (presumably `aria-valuemin`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub value_min: Option<f64>,
    /// Maximum numeric value (presumably `aria-valuemax`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub value_max: Option<f64>,
    /// Textual value (presumably `aria-valuetext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub value_text: Option<String>,
    /// `Some(true)` marks a frame boundary; rendered as `[frame-boundary]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_frame: Option<bool>,
    /// Frame URL; only rendered on frame-boundary nodes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frame_url: Option<String>,
    /// Frame name; only rendered on frame-boundary nodes, and only when non-empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frame_name: Option<String>,
    /// Refs related to iframes; empty when absent from the input.
    /// NOTE(review): exact meaning is not visible in this file.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub iframe_refs: Vec<String>,
    /// Element reference, serialized under the key `ref` and rendered as `[ref=...]`.
    #[serde(rename = "ref", skip_serializing_if = "Option::is_none")]
    pub node_ref: Option<String>,
    /// Crate-internal element index, serialized as `elementIndex` when set.
    /// NOTE(review): what it indexes into is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) element_index: Option<usize>,
    /// Child nodes, in the order they are rendered and matched.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub children: Vec<AriaSnapshot>,
}
/// Tri-state checked value of a node, serialized in lowercase
/// (`"false"`, `"true"`, `"mixed"`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum AriaCheckedState {
    /// Not checked; produces no marker in the YAML text.
    False,
    /// Checked; rendered as `[checked]`.
    True,
    /// Partially checked; rendered as `[mixed]`.
    Mixed,
}
impl AriaSnapshot {
pub fn new() -> Self {
Self::default()
}
pub fn with_role(role: impl Into<String>) -> Self {
Self {
role: Some(role.into()),
..Self::default()
}
}
#[must_use]
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
#[must_use]
pub fn child(mut self, child: AriaSnapshot) -> Self {
self.children.push(child);
self
}
pub fn to_yaml(&self) -> String {
let mut output = String::new();
self.write_yaml(&mut output, 0);
output
}
fn write_yaml(&self, output: &mut String, indent: usize) {
let prefix = " ".repeat(indent);
if let Some(ref role) = self.role {
output.push_str(&prefix);
output.push_str("- ");
output.push_str(role);
if let Some(ref name) = self.name {
output.push_str(" \"");
output.push_str(&name.replace('"', "\\\""));
output.push('"');
}
if let Some(disabled) = self.disabled {
if disabled {
output.push_str(" [disabled]");
}
}
if let Some(ref checked) = self.checked {
match checked {
AriaCheckedState::True => output.push_str(" [checked]"),
AriaCheckedState::Mixed => output.push_str(" [mixed]"),
AriaCheckedState::False => {}
}
}
if let Some(selected) = self.selected {
if selected {
output.push_str(" [selected]");
}
}
if let Some(expanded) = self.expanded {
if expanded {
output.push_str(" [expanded]");
}
}
if let Some(level) = self.level {
output.push_str(&format!(" [level={level}]"));
}
if self.is_frame == Some(true) {
output.push_str(" [frame-boundary]");
if let Some(ref url) = self.frame_url {
output.push_str(&format!(" [frame-url=\"{}\"]", url.replace('"', "\\\"")));
}
if let Some(ref name) = self.frame_name {
if !name.is_empty() {
output.push_str(&format!(" [frame-name=\"{}\"]", name.replace('"', "\\\"")));
}
}
}
if let Some(ref node_ref) = self.node_ref {
output.push_str(&format!(" [ref={}]", node_ref));
}
output.push('\n');
for child in &self.children {
child.write_yaml(output, indent + 1);
}
}
}
pub fn from_yaml(yaml: &str) -> Result<Self, LocatorError> {
let mut root = AriaSnapshot::new();
root.role = Some("root".to_string());
let mut stack: Vec<(usize, AriaSnapshot)> = vec![(0, root)];
for line in yaml.lines() {
if line.trim().is_empty() {
continue;
}
let indent = line.chars().take_while(|c| *c == ' ').count() / 2;
let trimmed = line.trim();
if !trimmed.starts_with('-') {
continue;
}
let content = trimmed[1..].trim();
let (role, name, attrs) = parse_aria_line(content)?;
let mut node = AriaSnapshot::with_role(role);
if let Some(n) = name {
node.name = Some(n);
}
for attr in attrs {
match attr.as_str() {
"disabled" => node.disabled = Some(true),
"checked" => node.checked = Some(AriaCheckedState::True),
"mixed" => node.checked = Some(AriaCheckedState::Mixed),
"selected" => node.selected = Some(true),
"expanded" => node.expanded = Some(true),
"frame-boundary" => node.is_frame = Some(true),
s if s.starts_with("level=") => {
if let Ok(level) = s[6..].parse() {
node.level = Some(level);
}
}
s if s.starts_with("frame-url=\"") && s.ends_with('"') => {
let url = &s[11..s.len() - 1];
node.frame_url = Some(url.replace("\\\"", "\""));
}
s if s.starts_with("frame-name=\"") && s.ends_with('"') => {
let name = &s[12..s.len() - 1];
node.frame_name = Some(name.replace("\\\"", "\""));
}
s if s.starts_with("ref=") => {
node.node_ref = Some(s[4..].to_string());
}
_ => {}
}
}
while stack.len() > 1 && stack.last().is_some_and(|(i, _)| *i >= indent) {
let (_, child) = stack.pop().unwrap();
if let Some((_, parent)) = stack.last_mut() {
parent.children.push(child);
}
}
stack.push((indent, node));
}
while stack.len() > 1 {
let (_, child) = stack.pop().unwrap();
if let Some((_, parent)) = stack.last_mut() {
parent.children.push(child);
}
}
Ok(stack.pop().map(|(_, s)| s).unwrap_or_default())
}
pub fn matches(&self, expected: &AriaSnapshot) -> bool {
if expected.role.is_some() && self.role != expected.role {
return false;
}
if let Some(ref expected_name) = expected.name {
match &self.name {
Some(actual_name) => {
if !matches_name(expected_name, actual_name) {
return false;
}
}
None => return false,
}
}
if expected.disabled.is_some() && self.disabled != expected.disabled {
return false;
}
if expected.checked.is_some() && self.checked != expected.checked {
return false;
}
if expected.selected.is_some() && self.selected != expected.selected {
return false;
}
if expected.expanded.is_some() && self.expanded != expected.expanded {
return false;
}
if expected.level.is_some() && self.level != expected.level {
return false;
}
if expected.children.len() > self.children.len() {
return false;
}
for (i, expected_child) in expected.children.iter().enumerate() {
if !self
.children
.get(i)
.is_some_and(|c| c.matches(expected_child))
{
return false;
}
}
true
}
pub fn diff(&self, expected: &AriaSnapshot) -> String {
let actual_yaml = self.to_yaml();
let expected_yaml = expected.to_yaml();
if actual_yaml == expected_yaml {
return String::new();
}
let mut diff = String::new();
diff.push_str("Expected:\n");
for line in expected_yaml.lines() {
diff.push_str(" ");
diff.push_str(line);
diff.push('\n');
}
diff.push_str("\nActual:\n");
for line in actual_yaml.lines() {
diff.push_str(" ");
diff.push_str(line);
diff.push('\n');
}
diff
}
}
impl fmt::Display for AriaSnapshot {
    /// Writes the snapshot's YAML rendering verbatim (formatter width/fill are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_yaml();
        f.write_str(&rendered)
    }
}
/// Parses the body of one snapshot line (the text after the leading `-`).
///
/// The expected shape is `role ["name"] [attr] [attr] ...`, as written by
/// `AriaSnapshot::to_yaml`. Returns the role, the name with `\"` unescaped (only when a
/// quoted string immediately follows the role), and the raw text of each bracketed
/// attribute in order. Attributes are only collected after the name, and a `]` inside a
/// quoted attribute value (e.g. `[frame-url="http://[::1]/"]`) does not end the attribute.
///
/// # Errors
///
/// Returns `LocatorError::EvaluationError` when the role is empty.
fn parse_aria_line(content: &str) -> Result<(String, Option<String>, Vec<String>), LocatorError> {
    let mut parts = content.splitn(2, ' ');
    let role = parts.next().unwrap_or("").to_string();
    if role.is_empty() {
        return Err(LocatorError::EvaluationError(
            "Empty role in aria snapshot".to_string(),
        ));
    }
    let rest = parts.next().unwrap_or("").trim_start();
    // Only a quote directly after the role opens a name; a quote that first appears inside
    // an attribute such as `[frame-url="..."]` must not be mistaken for one.
    let (name, attr_text) = match rest.strip_prefix('"').and_then(split_quoted_name) {
        Some((raw, tail)) => (Some(raw.replace("\\\"", "\"")), tail),
        None => (None, rest),
    };
    Ok((role, name, split_bracketed_attrs(attr_text)))
}

/// Splits `s` (text just after an opening quote) at its closing quote, returning the raw
/// quoted text and everything after the closing quote.
///
/// A backslash escapes the following character, so `\"` does not close the string. The
/// writer does not escape backslashes, so a name ending in `\` looks unterminated; in that
/// case the last quote in `s` is used as the terminator. Returns `None` if `s` has no quote.
fn split_quoted_name(s: &str) -> Option<(&str, &str)> {
    let bytes = s.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            // Skipping a byte can land inside a multi-byte char, but continuation bytes never
            // equal b'"' or b'\\', so every slice below is taken at an ASCII quote boundary.
            b'\\' => i += 2,
            b'"' => return Some((&s[..i], &s[i + 1..])),
            _ => i += 1,
        }
    }
    s.rfind('"').map(|end| (&s[..end], &s[end + 1..]))
}

/// Returns the inner text of every `[...]` group in `text`, in order.
///
/// Brackets inside double-quoted values are not treated as delimiters. If a group never
/// closes under that rule, the first `]` is used instead so malformed input degrades
/// gracefully; an unclosed `[` ends the scan.
fn split_bracketed_attrs(text: &str) -> Vec<String> {
    let mut attrs = Vec::new();
    let mut remaining = text;
    while let Some(open) = remaining.find('[') {
        let body = &remaining[open + 1..];
        let mut in_quotes = false;
        let mut escaped = false;
        let mut close = None;
        for (i, c) in body.char_indices() {
            if escaped {
                escaped = false;
            } else if c == '\\' && in_quotes {
                escaped = true;
            } else if c == '"' {
                in_quotes = !in_quotes;
            } else if c == ']' && !in_quotes {
                close = Some(i);
                break;
            }
        }
        let Some(end) = close.or_else(|| body.find(']')) else {
            break;
        };
        attrs.push(body[..end].to_string());
        remaining = &body[end + 1..];
    }
    attrs
}
fn matches_name(pattern: &str, actual: &str) -> bool {
if pattern.starts_with('/') {
let flags_end = pattern.rfind('/');
if let Some(end) = flags_end {
if end > 0 {
let regex_str = &pattern[1..end];
let flags = &pattern[end + 1..];
let case_insensitive = flags.contains('i');
let regex_result = if case_insensitive {
regex::RegexBuilder::new(regex_str)
.case_insensitive(true)
.build()
} else {
regex::Regex::new(regex_str)
};
if let Ok(re) = regex_result {
return re.is_match(actual);
}
}
}
}
pattern == actual
}
pub use super::aria_js::{aria_snapshot_js, aria_snapshot_with_refs_js};
#[cfg(test)]
mod tests;