1use std::future::Future;
8use std::time::Duration;
9
10use ferridriver::Locator;
11use ferridriver_expect::{Expect, ExpectContext, MatchError, poll_until as expect_poll_until};
12
13use super::ScreenshotMatcherOptions;
14use crate::model::TestFailure;
15
16fn locator_ctx(locator: &Locator, method: &'static str, is_not: bool) -> ExpectContext {
17 ExpectContext {
18 method,
19 subject: format!("locator('{}')", locator.selector()),
20 is_not,
21 }
22}
23
24async fn poll_until_test<F, Fut>(timeout: Duration, ctx: ExpectContext, check: F) -> Result<(), TestFailure>
25where
26 F: FnMut() -> Fut,
27 Fut: Future<Output = Result<(), MatchError>>,
28{
29 expect_poll_until(timeout, ctx, check).await.map_err(Into::into)
30}
31
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned futures to be `Send`; accepted here.
#[allow(async_fn_in_trait)]
/// Snapshot-style assertions on a locator: text snapshots, screenshot
/// comparisons and ARIA-tree templates.
pub trait LocatorSnapshotMatchers {
  /// Asserts that the subject's text content matches the text snapshot
  /// stored under `name`.
  async fn to_match_snapshot(&self, name: &str) -> Result<(), TestFailure>;

  /// Asserts that a screenshot of the subject matches the stored image
  /// `name`, using default [`ScreenshotMatcherOptions`].
  async fn to_have_screenshot(&self, name: &str) -> Result<(), TestFailure>;

  /// Like [`Self::to_have_screenshot`], but with explicit capture/compare
  /// `options` (animations, caret, style file, masks, clip, …).
  async fn to_have_screenshot_with(&self, name: &str, options: ScreenshotMatcherOptions) -> Result<(), TestFailure>;

  /// Asserts that the subject's accessibility tree matches the YAML-like
  /// template `expected_yaml` (or, for a negated expectation, does not).
  async fn to_match_aria_snapshot(&self, expected_yaml: &str) -> Result<(), TestFailure>;
}
53
54impl LocatorSnapshotMatchers for Expect<'_, Locator> {
55 async fn to_match_snapshot(&self, name: &str) -> Result<(), TestFailure> {
56 let locator = self.subject;
57 let actual = locator.text_content().await.unwrap_or(None).unwrap_or_default();
58 let snap_dir = std::path::PathBuf::from("__snapshots__");
59 let update = std::env::var("UPDATE_SNAPSHOTS").is_ok();
60 let info = crate::model::TestInfo {
61 test_id: crate::model::TestId {
62 file: String::new(),
63 suite: None,
64 name: name.to_string(),
65 line: None,
66 },
67 title_path: vec![name.to_string()],
68 retry: 0,
69 worker_index: 0,
70 parallel_index: 0,
71 repeat_each_index: 0,
72 output_dir: std::path::PathBuf::from("test-results"),
73 snapshot_dir: snap_dir,
74 snapshot_path_template: None,
75 update_snapshots: crate::config::UpdateSnapshotsMode::default(),
76 ignore_snapshots: false,
77 attachments: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
78 steps: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
79 soft_errors: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
80 errors: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
81 snapshot_suffix: std::sync::Arc::new(tokio::sync::Mutex::new(String::new())),
82 column: None,
83 project: None,
84 config_snapshot: None,
85 timeout: self.timeout,
86 tags: Vec::new(),
87 start_time: std::time::Instant::now(),
88 event_bus: None,
89 annotations: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
90 };
91 crate::snapshot::assert_snapshot(&info, &actual, name, update)
92 }
93
94 async fn to_have_screenshot(&self, name: &str) -> Result<(), TestFailure> {
95 self
96 .to_have_screenshot_with(name, ScreenshotMatcherOptions::default())
97 .await
98 }
99
100 async fn to_have_screenshot_with(&self, name: &str, options: ScreenshotMatcherOptions) -> Result<(), TestFailure> {
101 let locator = self.subject;
102 let actual_png = capture_with_options(locator, &options).await?;
103 crate::snapshot::compare_screenshot_png_with(&actual_png, name, &options)
104 }
105
106 async fn to_match_aria_snapshot(&self, expected_yaml: &str) -> Result<(), TestFailure> {
107 let locator = self.subject;
108 let is_not = self.is_not;
109 poll_until_test(
110 self.timeout,
111 locator_ctx(locator, "toMatchAriaSnapshot", is_not),
112 || {
113 let expected_yaml = expected_yaml.to_string();
114 async move {
115 let aria_tree = locator
116 .evaluate(
117 "el => { \
118 if (!el) return 'EMPTY'; \
119 const inj = window.__fd && window.__fd._injected; \
120 if (inj && typeof inj.ariaSnapshot === 'function') { \
121 try { return inj.ariaSnapshot(el, { mode: 'default' }); } catch (_) {} \
122 } \
123 function walk(node, indent) { \
124 let role = node.getAttribute('role') || node.tagName.toLowerCase(); \
125 let name = node.getAttribute('aria-label') || node.textContent?.trim()?.substring(0, 50) || ''; \
126 let line = indent + '- ' + role; \
127 if (name) line += ' \"' + name + '\"'; \
128 let lines = [line]; \
129 for (const child of node.children) { \
130 lines.push(...walk(child, indent + ' ')); \
131 } \
132 return lines; \
133 } \
134 return walk(el, '').join('\\n'); \
135 }",
136 ferridriver::protocol::SerializedArgument::default(),
137 None,
138 None,
139 )
140 .await
141 .ok()
142 .and_then(|v| v.as_str().map(String::from))
143 .unwrap_or_else(|| "EMPTY".into());
144
145 let expected_nodes = parse_aria_yaml(&expected_yaml);
146 let actual_nodes = parse_aria_yaml(&aria_tree);
147 let lines_match = match_aria_template(&expected_nodes, &actual_nodes);
148
149 if lines_match == is_not {
150 Err(MatchError::new(
151 format!("{}\n{expected_yaml}", if is_not { "not matching" } else { "matching" }),
152 aria_tree,
153 ))
154 } else {
155 Ok(())
156 }
157 }
158 },
159 )
160 .await
161 }
162}
163
164async fn capture_with_options(locator: &Locator, options: &ScreenshotMatcherOptions) -> Result<Vec<u8>, TestFailure> {
167 let page = locator.page();
168
169 let mut style_blocks: Vec<String> = Vec::new();
170
171 if options.animations.as_deref() == Some("disabled") {
172 style_blocks.push(
173 "*, *::before, *::after { \
174 animation-duration: 0s !important; \
175 animation-delay: 0s !important; \
176 animation-iteration-count: 1 !important; \
177 transition-duration: 0s !important; \
178 transition-delay: 0s !important; \
179 }"
180 .to_string(),
181 );
182 }
183
184 if options.caret.as_deref() == Some("hide") {
185 style_blocks.push("html, body, * { caret-color: transparent !important; }".to_string());
186 }
187
188 if let Some(ref style_path) = options.style_path {
189 match std::fs::read_to_string(style_path) {
190 Ok(content) => style_blocks.push(content),
191 Err(e) => {
192 return Err(TestFailure {
193 message: format!("toHaveScreenshot stylePath {} unreadable: {e}", style_path.display()),
194 stack: None,
195 diff: None,
196 screenshot: None,
197 });
198 },
199 }
200 }
201
202 let mask_color = options.mask_color.as_deref().unwrap_or("#FF00FF");
203 if !options.mask.is_empty() {
204 let mut mask_css = String::new();
205 for selector in &options.mask {
206 mask_css.push_str(selector);
207 mask_css.push_str(" { background: ");
208 mask_css.push_str(mask_color);
209 mask_css.push_str(" !important; color: ");
210 mask_css.push_str(mask_color);
211 mask_css.push_str(" !important; }\n");
212 }
213 style_blocks.push(mask_css);
214 }
215
216 let token = "ferridriver-screenshot-capture";
217
218 if !style_blocks.is_empty() {
219 let combined = style_blocks.join("\n");
220 let escaped = serde_json::to_string(&combined).unwrap_or_else(|_| "\"\"".to_string());
221 let inject_script = format!(
222 "(function() {{ \
223 const s = document.createElement('style'); \
224 s.setAttribute('data-{TOK}', '1'); \
225 s.textContent = {ESC}; \
226 document.head.appendChild(s); \
227 return true; \
228 }})()",
229 TOK = token,
230 ESC = escaped,
231 );
232 let _ = page
233 .evaluate(
234 &inject_script,
235 ferridriver::protocol::SerializedArgument::default(),
236 None,
237 )
238 .await
239 .map_err(|e| TestFailure {
240 message: format!("screenshot capture-options inject failed: {e}"),
241 stack: None,
242 diff: None,
243 screenshot: None,
244 })?;
245 }
246
247 let raw_png = locator.screenshot().await.map_err(|e| TestFailure {
248 message: format!("screenshot failed: {e}"),
249 stack: None,
250 diff: None,
251 screenshot: None,
252 });
253
254 if !style_blocks.is_empty() {
255 let cleanup = format!(
256 "(function() {{ \
257 document.querySelectorAll('style[data-{TOK}]').forEach(function(n) {{ n.remove(); }}); \
258 return true; \
259 }})()",
260 TOK = token,
261 );
262 let _ = page
263 .evaluate(&cleanup, ferridriver::protocol::SerializedArgument::default(), None)
264 .await;
265 }
266
267 let png = raw_png?;
268
269 if let Some(clip) = options.clip {
270 Ok(crop_png_to_clip(&png, &clip)?)
271 } else {
272 Ok(png)
273 }
274}
275
276fn crop_png_to_clip(png: &[u8], clip: &super::ScreenshotClip) -> Result<Vec<u8>, TestFailure> {
277 use image::GenericImageView;
278
279 let img = image::load_from_memory_with_format(png, image::ImageFormat::Png).map_err(|e| TestFailure {
280 message: format!("toHaveScreenshot clip: failed to decode capture: {e}"),
281 stack: None,
282 diff: None,
283 screenshot: None,
284 })?;
285 let (img_w, img_h) = img.dimensions();
286 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
287 let x = (clip.x.max(0.0).min(f64::from(img_w))) as u32;
288 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
289 let y = (clip.y.max(0.0).min(f64::from(img_h))) as u32;
290 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
291 let w = (clip.width.max(0.0).min(f64::from(img_w.saturating_sub(x)))) as u32;
292 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
293 let h = (clip.height.max(0.0).min(f64::from(img_h.saturating_sub(y)))) as u32;
294 if w == 0 || h == 0 {
295 return Err(TestFailure {
296 message: format!(
297 "toHaveScreenshot clip: empty rect after clamping (x={x} y={y} w={w} h={h}) against {img_w}x{img_h} capture"
298 ),
299 stack: None,
300 diff: None,
301 screenshot: None,
302 });
303 }
304 let cropped = img.crop_imm(x, y, w, h);
305 let mut out = Vec::new();
306 cropped
307 .write_to(&mut std::io::Cursor::new(&mut out), image::ImageFormat::Png)
308 .map_err(|e| TestFailure {
309 message: format!("toHaveScreenshot clip: re-encode failed: {e}"),
310 stack: None,
311 diff: None,
312 screenshot: None,
313 })?;
314 Ok(out)
315}
316
/// One node of a parsed ARIA snapshot line tree: a role, an optional
/// accessible name, bracketed attributes, and nested child nodes.
#[derive(Debug, Clone, Default)]
struct AriaNode {
  /// Leading role token of the line (e.g. `button`). When empty, template
  /// matching ignores the role.
  role: String,
  /// Accessible name. It is a literal for quoted or bare text, or a regex for
  /// `/pattern/` syntax. It is `None` when absent or when the regex failed to
  /// compile.
  name: Option<AriaName>,
  /// Trimmed contents of `[...]` groups such as `disabled` or `level=2`.
  attrs: Vec<String>,
  /// Nodes parsed from subsequent, more-indented lines.
  children: Vec<AriaNode>,
}
326
/// Accessible name of an [`AriaNode`], either as written in a template or as
/// read from the actual tree.
#[derive(Debug, Clone)]
enum AriaName {
  /// Plain text. As an expectation, it matches any name that contains it.
  Literal(String),
  /// Compiled from `/pattern/` syntax. As an expectation, it matches via
  /// `Regex::is_match`.
  Regex(regex::Regex),
}
332
333impl AriaName {
334 fn matches(&self, s: &str) -> bool {
335 match self {
336 Self::Literal(expected) => s.contains(expected),
337 Self::Regex(re) => re.is_match(s),
338 }
339 }
340}
341
342fn parse_aria_yaml(input: &str) -> Vec<AriaNode> {
343 let mut roots: Vec<AriaNode> = Vec::new();
344 let mut path: Vec<(usize, Vec<usize>)> = Vec::new();
345
346 for raw in input.lines() {
347 let trimmed = raw.trim_end();
348 let indent = trimmed.chars().take_while(|c| *c == ' ').count();
349 let line = trimmed.trim_start();
350 if line.is_empty() || !line.starts_with('-') {
351 continue;
352 }
353 let body = line.trim_start_matches('-').trim_start();
354 if body.starts_with("text:") {
355 continue;
356 }
357 let node = parse_aria_line_body(body);
358 while path.last().is_some_and(|(prev_indent, _)| *prev_indent >= indent) {
359 path.pop();
360 }
361 let path_indices = if let Some((_, parent_path)) = path.last() {
362 parent_path.clone()
363 } else {
364 Vec::new()
365 };
366 let mut children_holder: &mut Vec<AriaNode> = &mut roots;
367 for &i in &path_indices {
368 children_holder = &mut children_holder[i].children;
369 }
370 let new_index = children_holder.len();
371 children_holder.push(node);
372 let mut new_path = path_indices.clone();
373 new_path.push(new_index);
374 path.push((indent, new_path));
375 }
376 roots
377}
378
379fn parse_aria_line_body(body: &str) -> AriaNode {
380 let mut body = body.trim_end_matches(':').trim_end();
381 while body.ends_with(':') {
382 body = body[..body.len() - 1].trim_end();
383 }
384 let mut node = AriaNode::default();
385 let mut role_end = 0;
386 for (i, c) in body.char_indices() {
387 if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
388 role_end = i + c.len_utf8();
389 } else {
390 break;
391 }
392 }
393 node.role = body[..role_end].to_string();
394 let rest = body[role_end..].trim_start();
395
396 let mut rest_owned: String = rest.to_string();
397 while let Some(open) = rest_owned.find('[') {
398 let Some(close_rel) = rest_owned[open..].find(']') else {
399 break;
400 };
401 let close = open + close_rel;
402 let attr = rest_owned[open + 1..close].trim().to_string();
403 if !attr.is_empty() {
404 node.attrs.push(attr);
405 }
406 rest_owned = format!("{}{}", &rest_owned[..open], &rest_owned[close + 1..]);
407 }
408
409 let rest = rest_owned.trim();
410 if let Some(stripped) = rest.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
411 node.name = Some(AriaName::Literal(stripped.to_string()));
412 } else if let Some(stripped) = rest.strip_prefix('/').and_then(|s| {
413 let last_slash = s.rfind('/')?;
414 Some((&s[..last_slash], &s[last_slash + 1..]))
415 }) {
416 let (pattern, _flags) = stripped;
417 if let Ok(re) = regex::Regex::new(pattern) {
418 node.name = Some(AriaName::Regex(re));
419 }
420 } else if !rest.is_empty() && rest != ":" {
421 node.name = Some(AriaName::Literal(rest.to_string()));
422 }
423 node
424}
425
426fn match_aria_template(expected: &[AriaNode], actual: &[AriaNode]) -> bool {
427 let flat_actual = flatten_dfs(actual);
428 let mut cursor = 0usize;
429 for exp in expected {
430 let mut matched = false;
431 while cursor < flat_actual.len() {
432 if matches_aria_node(exp, flat_actual[cursor]) {
433 cursor += 1;
434 matched = true;
435 break;
436 }
437 cursor += 1;
438 }
439 if !matched {
440 return false;
441 }
442 }
443 true
444}
445
446fn flatten_dfs(roots: &[AriaNode]) -> Vec<&AriaNode> {
447 let mut out: Vec<&AriaNode> = Vec::new();
448 fn walk<'b>(node: &'b AriaNode, out: &mut Vec<&'b AriaNode>) {
449 out.push(node);
450 for child in &node.children {
451 walk(child, out);
452 }
453 }
454 for r in roots {
455 walk(r, &mut out);
456 }
457 out
458}
459
460fn matches_aria_node(expected: &AriaNode, actual: &AriaNode) -> bool {
461 if !expected.role.is_empty() && expected.role != actual.role {
462 return false;
463 }
464 if let Some(ref name) = expected.name {
465 let actual_name = match &actual.name {
466 Some(AriaName::Literal(s)) => s.clone(),
467 Some(AriaName::Regex(_)) | None => String::new(),
468 };
469 if !name.matches(&actual_name) {
470 return false;
471 }
472 }
473 for attr in &expected.attrs {
474 if !actual.attrs.iter().any(|a| a == attr) {
475 return false;
476 }
477 }
478 if !expected.children.is_empty() && !match_aria_template(&expected.children, &actual.children) {
479 return false;
480 }
481 true
482}
483
484#[cfg(test)]
485mod aria_tests {
486 use super::*;
487
488 #[test]
489 fn parses_simple_role_name_pairs() {
490 let nodes = parse_aria_yaml(
491 "- main\n - heading \"Title\"\n - button \"Click\"\n - list:\n - listitem \"One\"\n - listitem \"Two\"\n",
492 );
493 assert_eq!(nodes.len(), 1);
494 let main = &nodes[0];
495 assert_eq!(main.role, "main");
496 assert_eq!(main.children.len(), 3);
497 assert_eq!(main.children[0].role, "heading");
498 assert!(matches!(main.children[0].name, Some(AriaName::Literal(ref s)) if s == "Title"));
499 let list = &main.children[2];
500 assert_eq!(list.role, "list");
501 assert_eq!(list.children.len(), 2);
502 assert_eq!(list.children[0].role, "listitem");
503 }
504
505 #[test]
506 fn parses_state_brackets() {
507 let nodes = parse_aria_yaml("- button [disabled] \"Save\"");
508 assert_eq!(nodes.len(), 1);
509 assert_eq!(nodes[0].role, "button");
510 assert_eq!(nodes[0].attrs, vec!["disabled".to_string()]);
511 assert!(matches!(nodes[0].name, Some(AriaName::Literal(ref s)) if s == "Save"));
512 }
513
514 #[test]
515 fn enforces_ancestor_relationships() {
516 let actual = parse_aria_yaml("- main\n - toolbar\n - button \"Cut\"\n - list\n - listitem \"Item\"\n");
517 let expected = parse_aria_yaml("- main\n - list\n - button \"Cut\"\n");
518 assert!(!match_aria_template(&expected, &actual));
519 }
520
521 #[test]
522 fn accepts_descendant_under_correct_parent() {
523 let actual = parse_aria_yaml("- main\n - list\n - listitem \"One\"\n - listitem \"Two\"\n");
524 let expected = parse_aria_yaml("- main\n - list\n - listitem \"Two\"\n");
525 assert!(match_aria_template(&expected, &actual));
526 }
527
528 #[test]
529 fn requires_state_to_be_present_on_actual() {
530 let actual = parse_aria_yaml("- button \"Save\"");
531 let expected = parse_aria_yaml("- button [disabled] \"Save\"");
532 assert!(!match_aria_template(&expected, &actual));
533 }
534
535 #[test]
536 fn matches_template_against_subtree_of_actual() {
537 let actual = parse_aria_yaml("- main\n - button \"Save\" [disabled]\n");
538 let expected = parse_aria_yaml("- button [disabled] \"Save\"");
539 assert!(match_aria_template(&expected, &actual));
540 }
541}