1use crate::error::Result;
8use crate::protocol::{Locator, Page};
9use std::path::Path;
10use std::time::Duration;
11
/// How long an assertion keeps retrying before it gives up (5 seconds).
const DEFAULT_ASSERTION_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Pause between two consecutive checks while an assertion is retrying (100 ms).
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
17
18pub fn expect(locator: Locator) -> Expectation {
88 Expectation::new(locator)
89}
90
91pub struct Expectation {
93 locator: Locator,
94 timeout: Duration,
95 poll_interval: Duration,
96 negate: bool,
97}
98
99#[allow(clippy::wrong_self_convention)]
102impl Expectation {
103 pub(crate) fn new(locator: Locator) -> Self {
105 Self {
106 locator,
107 timeout: DEFAULT_ASSERTION_TIMEOUT,
108 poll_interval: DEFAULT_POLL_INTERVAL,
109 negate: false,
110 }
111 }
112
113 pub fn with_timeout(mut self, timeout: Duration) -> Self {
116 self.timeout = timeout;
117 self
118 }
119
120 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
124 self.poll_interval = interval;
125 self
126 }
127
128 #[allow(clippy::should_implement_trait)]
133 pub fn not(mut self) -> Self {
134 self.negate = true;
135 self
136 }
137
138 pub async fn to_be_visible(self) -> Result<()> {
144 let start = std::time::Instant::now();
145 let selector = self.locator.selector().to_string();
146
147 loop {
148 let is_visible = self.locator.is_visible().await?;
149
150 let matches = if self.negate { !is_visible } else { is_visible };
152
153 if matches {
154 return Ok(());
155 }
156
157 if start.elapsed() >= self.timeout {
159 let message = if self.negate {
160 format!(
161 "Expected element '{}' NOT to be visible, but it was visible after {:?}",
162 selector, self.timeout
163 )
164 } else {
165 format!(
166 "Expected element '{}' to be visible, but it was not visible after {:?}",
167 selector, self.timeout
168 )
169 };
170 return Err(crate::error::Error::AssertionTimeout(message));
171 }
172
173 tokio::time::sleep(self.poll_interval).await;
175 }
176 }
177
178 pub async fn to_be_hidden(self) -> Result<()> {
184 let negated = Expectation {
187 negate: !self.negate, ..self
189 };
190 negated.to_be_visible().await
191 }
192
193 pub async fn to_have_text(self, expected: &str) -> Result<()> {
200 let start = std::time::Instant::now();
201 let selector = self.locator.selector().to_string();
202 let expected = expected.trim();
203
204 loop {
205 let actual_text = self.locator.inner_text().await?;
207 let actual = actual_text.trim();
208
209 let matches = if self.negate {
211 actual != expected
212 } else {
213 actual == expected
214 };
215
216 if matches {
217 return Ok(());
218 }
219
220 if start.elapsed() >= self.timeout {
222 let message = if self.negate {
223 format!(
224 "Expected element '{}' NOT to have text '{}', but it did after {:?}",
225 selector, expected, self.timeout
226 )
227 } else {
228 format!(
229 "Expected element '{}' to have text '{}', but had '{}' after {:?}",
230 selector, expected, actual, self.timeout
231 )
232 };
233 return Err(crate::error::Error::AssertionTimeout(message));
234 }
235
236 tokio::time::sleep(self.poll_interval).await;
238 }
239 }
240
241 pub async fn to_have_text_regex(self, pattern: &str) -> Result<()> {
245 let start = std::time::Instant::now();
246 let selector = self.locator.selector().to_string();
247 let re = regex::Regex::new(pattern)
248 .map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
249
250 loop {
251 let actual_text = self.locator.inner_text().await?;
252 let actual = actual_text.trim();
253
254 let matches = if self.negate {
256 !re.is_match(actual)
257 } else {
258 re.is_match(actual)
259 };
260
261 if matches {
262 return Ok(());
263 }
264
265 if start.elapsed() >= self.timeout {
267 let message = if self.negate {
268 format!(
269 "Expected element '{}' NOT to match pattern '{}', but it did after {:?}",
270 selector, pattern, self.timeout
271 )
272 } else {
273 format!(
274 "Expected element '{}' to match pattern '{}', but had '{}' after {:?}",
275 selector, pattern, actual, self.timeout
276 )
277 };
278 return Err(crate::error::Error::AssertionTimeout(message));
279 }
280
281 tokio::time::sleep(self.poll_interval).await;
283 }
284 }
285
286 pub async fn to_contain_text(self, expected: &str) -> Result<()> {
292 let start = std::time::Instant::now();
293 let selector = self.locator.selector().to_string();
294
295 loop {
296 let actual_text = self.locator.inner_text().await?;
297 let actual = actual_text.trim();
298
299 let matches = if self.negate {
301 !actual.contains(expected)
302 } else {
303 actual.contains(expected)
304 };
305
306 if matches {
307 return Ok(());
308 }
309
310 if start.elapsed() >= self.timeout {
312 let message = if self.negate {
313 format!(
314 "Expected element '{}' NOT to contain text '{}', but it did after {:?}",
315 selector, expected, self.timeout
316 )
317 } else {
318 format!(
319 "Expected element '{}' to contain text '{}', but had '{}' after {:?}",
320 selector, expected, actual, self.timeout
321 )
322 };
323 return Err(crate::error::Error::AssertionTimeout(message));
324 }
325
326 tokio::time::sleep(self.poll_interval).await;
328 }
329 }
330
331 pub async fn to_contain_text_regex(self, pattern: &str) -> Result<()> {
335 let start = std::time::Instant::now();
336 let selector = self.locator.selector().to_string();
337 let re = regex::Regex::new(pattern)
338 .map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
339
340 loop {
341 let actual_text = self.locator.inner_text().await?;
342 let actual = actual_text.trim();
343
344 let matches = if self.negate {
346 !re.is_match(actual)
347 } else {
348 re.is_match(actual)
349 };
350
351 if matches {
352 return Ok(());
353 }
354
355 if start.elapsed() >= self.timeout {
357 let message = if self.negate {
358 format!(
359 "Expected element '{}' NOT to contain pattern '{}', but it did after {:?}",
360 selector, pattern, self.timeout
361 )
362 } else {
363 format!(
364 "Expected element '{}' to contain pattern '{}', but had '{}' after {:?}",
365 selector, pattern, actual, self.timeout
366 )
367 };
368 return Err(crate::error::Error::AssertionTimeout(message));
369 }
370
371 tokio::time::sleep(self.poll_interval).await;
373 }
374 }
375
376 pub async fn to_have_value(self, expected: &str) -> Result<()> {
382 let start = std::time::Instant::now();
383 let selector = self.locator.selector().to_string();
384
385 loop {
386 let actual = self.locator.input_value(None).await?;
387
388 let matches = if self.negate {
390 actual != expected
391 } else {
392 actual == expected
393 };
394
395 if matches {
396 return Ok(());
397 }
398
399 if start.elapsed() >= self.timeout {
401 let message = if self.negate {
402 format!(
403 "Expected input '{}' NOT to have value '{}', but it did after {:?}",
404 selector, expected, self.timeout
405 )
406 } else {
407 format!(
408 "Expected input '{}' to have value '{}', but had '{}' after {:?}",
409 selector, expected, actual, self.timeout
410 )
411 };
412 return Err(crate::error::Error::AssertionTimeout(message));
413 }
414
415 tokio::time::sleep(self.poll_interval).await;
417 }
418 }
419
420 pub async fn to_have_value_regex(self, pattern: &str) -> Result<()> {
424 let start = std::time::Instant::now();
425 let selector = self.locator.selector().to_string();
426 let re = regex::Regex::new(pattern)
427 .map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
428
429 loop {
430 let actual = self.locator.input_value(None).await?;
431
432 let matches = if self.negate {
434 !re.is_match(&actual)
435 } else {
436 re.is_match(&actual)
437 };
438
439 if matches {
440 return Ok(());
441 }
442
443 if start.elapsed() >= self.timeout {
445 let message = if self.negate {
446 format!(
447 "Expected input '{}' NOT to match pattern '{}', but it did after {:?}",
448 selector, pattern, self.timeout
449 )
450 } else {
451 format!(
452 "Expected input '{}' to match pattern '{}', but had '{}' after {:?}",
453 selector, pattern, actual, self.timeout
454 )
455 };
456 return Err(crate::error::Error::AssertionTimeout(message));
457 }
458
459 tokio::time::sleep(self.poll_interval).await;
461 }
462 }
463
464 pub async fn to_be_enabled(self) -> Result<()> {
471 let start = std::time::Instant::now();
472 let selector = self.locator.selector().to_string();
473
474 loop {
475 let is_enabled = self.locator.is_enabled().await?;
476
477 let matches = if self.negate { !is_enabled } else { is_enabled };
479
480 if matches {
481 return Ok(());
482 }
483
484 if start.elapsed() >= self.timeout {
486 let message = if self.negate {
487 format!(
488 "Expected element '{}' NOT to be enabled, but it was enabled after {:?}",
489 selector, self.timeout
490 )
491 } else {
492 format!(
493 "Expected element '{}' to be enabled, but it was not enabled after {:?}",
494 selector, self.timeout
495 )
496 };
497 return Err(crate::error::Error::AssertionTimeout(message));
498 }
499
500 tokio::time::sleep(self.poll_interval).await;
502 }
503 }
504
505 pub async fn to_be_disabled(self) -> Result<()> {
512 let negated = Expectation {
515 negate: !self.negate, ..self
517 };
518 negated.to_be_enabled().await
519 }
520
521 pub async fn to_be_checked(self) -> Result<()> {
527 let start = std::time::Instant::now();
528 let selector = self.locator.selector().to_string();
529
530 loop {
531 let is_checked = self.locator.is_checked().await?;
532
533 let matches = if self.negate { !is_checked } else { is_checked };
535
536 if matches {
537 return Ok(());
538 }
539
540 if start.elapsed() >= self.timeout {
542 let message = if self.negate {
543 format!(
544 "Expected element '{}' NOT to be checked, but it was checked after {:?}",
545 selector, self.timeout
546 )
547 } else {
548 format!(
549 "Expected element '{}' to be checked, but it was not checked after {:?}",
550 selector, self.timeout
551 )
552 };
553 return Err(crate::error::Error::AssertionTimeout(message));
554 }
555
556 tokio::time::sleep(self.poll_interval).await;
558 }
559 }
560
561 pub async fn to_be_unchecked(self) -> Result<()> {
567 let negated = Expectation {
570 negate: !self.negate, ..self
572 };
573 negated.to_be_checked().await
574 }
575
576 pub async fn to_be_editable(self) -> Result<()> {
583 let start = std::time::Instant::now();
584 let selector = self.locator.selector().to_string();
585
586 loop {
587 let is_editable = self.locator.is_editable().await?;
588
589 let matches = if self.negate {
591 !is_editable
592 } else {
593 is_editable
594 };
595
596 if matches {
597 return Ok(());
598 }
599
600 if start.elapsed() >= self.timeout {
602 let message = if self.negate {
603 format!(
604 "Expected element '{}' NOT to be editable, but it was editable after {:?}",
605 selector, self.timeout
606 )
607 } else {
608 format!(
609 "Expected element '{}' to be editable, but it was not editable after {:?}",
610 selector, self.timeout
611 )
612 };
613 return Err(crate::error::Error::AssertionTimeout(message));
614 }
615
616 tokio::time::sleep(self.poll_interval).await;
618 }
619 }
620
621 pub async fn to_be_focused(self) -> Result<()> {
627 let start = std::time::Instant::now();
628 let selector = self.locator.selector().to_string();
629
630 loop {
631 let is_focused = self.locator.is_focused().await?;
632
633 let matches = if self.negate { !is_focused } else { is_focused };
635
636 if matches {
637 return Ok(());
638 }
639
640 if start.elapsed() >= self.timeout {
642 let message = if self.negate {
643 format!(
644 "Expected element '{}' NOT to be focused, but it was focused after {:?}",
645 selector, self.timeout
646 )
647 } else {
648 format!(
649 "Expected element '{}' to be focused, but it was not focused after {:?}",
650 selector, self.timeout
651 )
652 };
653 return Err(crate::error::Error::AssertionTimeout(message));
654 }
655
656 tokio::time::sleep(self.poll_interval).await;
658 }
659 }
660
661 pub async fn to_match_aria_snapshot(self, expected: &str) -> Result<()> {
676 use crate::protocol::serialize_argument;
677
678 let selector = self.locator.selector().to_string();
679 let timeout_ms = self.timeout.as_millis() as f64;
680 let expected_value = serialize_argument(&serde_json::Value::String(expected.to_string()));
681
682 self.locator
683 .frame()
684 .frame_expect(
685 &selector,
686 "to.match.aria",
687 expected_value,
688 self.negate,
689 timeout_ms,
690 )
691 .await
692 }
693
694 pub async fn to_have_screenshot(
701 self,
702 baseline_path: impl AsRef<Path>,
703 options: Option<ScreenshotAssertionOptions>,
704 ) -> Result<()> {
705 let opts = options.unwrap_or_default();
706 let baseline_path = baseline_path.as_ref();
707
708 if opts.animations == Some(Animations::Disabled) {
710 let _ = self
711 .locator
712 .evaluate_js(DISABLE_ANIMATIONS_JS, None::<&()>)
713 .await;
714 }
715
716 let screenshot_opts = if let Some(ref mask_locators) = opts.mask {
718 let mask_js = build_mask_js(mask_locators);
720 let _ = self.locator.evaluate_js(&mask_js, None::<&()>).await;
721 None
722 } else {
723 None
724 };
725
726 compare_screenshot(
727 &opts,
728 baseline_path,
729 self.timeout,
730 self.poll_interval,
731 self.negate,
732 || async { self.locator.screenshot(screenshot_opts.clone()).await },
733 )
734 .await
735 }
736}
737
/// JavaScript that injects a `<style data-playwright-no-animations>` element forcing zero
/// animation/transition durations and delays on every element and pseudo-element.
///
/// The script is idempotent: it returns early when that style element already exists,
/// so repeated assertions do not stack duplicate copies. It falls back to
/// `document.documentElement` when the document has no `<head>`.
const DISABLE_ANIMATIONS_JS: &str = r#"
(() => {
  if (document.querySelector('style[data-playwright-no-animations]')) return;
  const style = document.createElement('style');
  style.textContent = '*, *::before, *::after { animation-duration: 0s !important; animation-delay: 0s !important; transition-duration: 0s !important; transition-delay: 0s !important; }';
  style.setAttribute('data-playwright-no-animations', '');
  (document.head || document.documentElement).appendChild(style);
})()
"#;
747
748fn build_mask_js(locators: &[Locator]) -> String {
750 let selectors: Vec<String> = locators
751 .iter()
752 .map(|l| {
753 let sel = l.selector().replace('\'', "\\'");
754 format!(
755 r#"
756 (function() {{
757 var els = document.querySelectorAll('{}');
758 els.forEach(function(el) {{
759 var rect = el.getBoundingClientRect();
760 var overlay = document.createElement('div');
761 overlay.setAttribute('data-playwright-mask', '');
762 overlay.style.cssText = 'position:fixed;z-index:2147483647;background:#FF00FF;pointer-events:none;'
763 + 'left:' + rect.left + 'px;top:' + rect.top + 'px;width:' + rect.width + 'px;height:' + rect.height + 'px;';
764 document.body.appendChild(overlay);
765 }});
766 }})();
767 "#,
768 sel
769 )
770 })
771 .collect();
772 selectors.join("\n")
773}
774
/// Whether CSS animations and transitions keep running during a screenshot assertion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Animations {
    /// Leave animations and transitions untouched.
    Allow,
    /// Inject `DISABLE_ANIMATIONS_JS` before capturing, zeroing animation and transition
    /// durations and delays.
    Disabled,
}
785
786#[derive(Debug, Clone, Default)]
790pub struct ScreenshotAssertionOptions {
791 pub max_diff_pixels: Option<u32>,
793 pub max_diff_pixel_ratio: Option<f64>,
795 pub threshold: Option<f64>,
797 pub animations: Option<Animations>,
799 pub mask: Option<Vec<Locator>>,
801 pub update_snapshots: Option<bool>,
803}
804
805impl ScreenshotAssertionOptions {
806 pub fn builder() -> ScreenshotAssertionOptionsBuilder {
808 ScreenshotAssertionOptionsBuilder::default()
809 }
810}
811
812#[derive(Debug, Clone, Default)]
814pub struct ScreenshotAssertionOptionsBuilder {
815 max_diff_pixels: Option<u32>,
816 max_diff_pixel_ratio: Option<f64>,
817 threshold: Option<f64>,
818 animations: Option<Animations>,
819 mask: Option<Vec<Locator>>,
820 update_snapshots: Option<bool>,
821}
822
823impl ScreenshotAssertionOptionsBuilder {
824 pub fn max_diff_pixels(mut self, pixels: u32) -> Self {
826 self.max_diff_pixels = Some(pixels);
827 self
828 }
829
830 pub fn max_diff_pixel_ratio(mut self, ratio: f64) -> Self {
832 self.max_diff_pixel_ratio = Some(ratio);
833 self
834 }
835
836 pub fn threshold(mut self, threshold: f64) -> Self {
838 self.threshold = Some(threshold);
839 self
840 }
841
842 pub fn animations(mut self, animations: Animations) -> Self {
844 self.animations = Some(animations);
845 self
846 }
847
848 pub fn mask(mut self, locators: Vec<Locator>) -> Self {
850 self.mask = Some(locators);
851 self
852 }
853
854 pub fn update_snapshots(mut self, update: bool) -> Self {
856 self.update_snapshots = Some(update);
857 self
858 }
859
860 pub fn build(self) -> ScreenshotAssertionOptions {
862 ScreenshotAssertionOptions {
863 max_diff_pixels: self.max_diff_pixels,
864 max_diff_pixel_ratio: self.max_diff_pixel_ratio,
865 threshold: self.threshold,
866 animations: self.animations,
867 mask: self.mask,
868 update_snapshots: self.update_snapshots,
869 }
870 }
871}
872
873pub fn expect_page(page: &Page) -> PageExpectation {
877 PageExpectation::new(page.clone())
878}
879
880#[allow(clippy::wrong_self_convention)]
882pub struct PageExpectation {
883 page: Page,
884 timeout: Duration,
885 poll_interval: Duration,
886 negate: bool,
887}
888
889impl PageExpectation {
890 fn new(page: Page) -> Self {
891 Self {
892 page,
893 timeout: DEFAULT_ASSERTION_TIMEOUT,
894 poll_interval: DEFAULT_POLL_INTERVAL,
895 negate: false,
896 }
897 }
898
899 pub fn with_timeout(mut self, timeout: Duration) -> Self {
901 self.timeout = timeout;
902 self
903 }
904
905 #[allow(clippy::should_implement_trait)]
907 pub fn not(mut self) -> Self {
908 self.negate = true;
909 self
910 }
911
912 pub async fn to_have_title(self, expected: &str) -> Result<()> {
918 let start = std::time::Instant::now();
919 let expected = expected.trim();
920
921 loop {
922 let actual = self.page.title().await?;
923 let actual = actual.trim();
924
925 let matches = if self.negate {
926 actual != expected
927 } else {
928 actual == expected
929 };
930
931 if matches {
932 return Ok(());
933 }
934
935 if start.elapsed() >= self.timeout {
936 let message = if self.negate {
937 format!(
938 "Expected page NOT to have title '{}', but it did after {:?}",
939 expected, self.timeout,
940 )
941 } else {
942 format!(
943 "Expected page to have title '{}', but got '{}' after {:?}",
944 expected, actual, self.timeout,
945 )
946 };
947 return Err(crate::error::Error::AssertionTimeout(message));
948 }
949
950 tokio::time::sleep(self.poll_interval).await;
951 }
952 }
953
954 pub async fn to_have_title_regex(self, pattern: &str) -> Result<()> {
960 let start = std::time::Instant::now();
961 let re = regex::Regex::new(pattern)
962 .map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
963
964 loop {
965 let actual = self.page.title().await?;
966
967 let matches = if self.negate {
968 !re.is_match(&actual)
969 } else {
970 re.is_match(&actual)
971 };
972
973 if matches {
974 return Ok(());
975 }
976
977 if start.elapsed() >= self.timeout {
978 let message = if self.negate {
979 format!(
980 "Expected page title NOT to match '{}', but '{}' matched after {:?}",
981 pattern, actual, self.timeout,
982 )
983 } else {
984 format!(
985 "Expected page title to match '{}', but got '{}' after {:?}",
986 pattern, actual, self.timeout,
987 )
988 };
989 return Err(crate::error::Error::AssertionTimeout(message));
990 }
991
992 tokio::time::sleep(self.poll_interval).await;
993 }
994 }
995
996 pub async fn to_have_url(self, expected: &str) -> Result<()> {
1002 let start = std::time::Instant::now();
1003
1004 loop {
1005 let actual = self.page.url();
1006
1007 let matches = if self.negate {
1008 actual != expected
1009 } else {
1010 actual == expected
1011 };
1012
1013 if matches {
1014 return Ok(());
1015 }
1016
1017 if start.elapsed() >= self.timeout {
1018 let message = if self.negate {
1019 format!(
1020 "Expected page NOT to have URL '{}', but it did after {:?}",
1021 expected, self.timeout,
1022 )
1023 } else {
1024 format!(
1025 "Expected page to have URL '{}', but got '{}' after {:?}",
1026 expected, actual, self.timeout,
1027 )
1028 };
1029 return Err(crate::error::Error::AssertionTimeout(message));
1030 }
1031
1032 tokio::time::sleep(self.poll_interval).await;
1033 }
1034 }
1035
1036 pub async fn to_have_url_regex(self, pattern: &str) -> Result<()> {
1042 let start = std::time::Instant::now();
1043 let re = regex::Regex::new(pattern)
1044 .map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
1045
1046 loop {
1047 let actual = self.page.url();
1048
1049 let matches = if self.negate {
1050 !re.is_match(&actual)
1051 } else {
1052 re.is_match(&actual)
1053 };
1054
1055 if matches {
1056 return Ok(());
1057 }
1058
1059 if start.elapsed() >= self.timeout {
1060 let message = if self.negate {
1061 format!(
1062 "Expected page URL NOT to match '{}', but '{}' matched after {:?}",
1063 pattern, actual, self.timeout,
1064 )
1065 } else {
1066 format!(
1067 "Expected page URL to match '{}', but got '{}' after {:?}",
1068 pattern, actual, self.timeout,
1069 )
1070 };
1071 return Err(crate::error::Error::AssertionTimeout(message));
1072 }
1073
1074 tokio::time::sleep(self.poll_interval).await;
1075 }
1076 }
1077
1078 pub async fn to_have_screenshot(
1082 self,
1083 baseline_path: impl AsRef<Path>,
1084 options: Option<ScreenshotAssertionOptions>,
1085 ) -> Result<()> {
1086 let opts = options.unwrap_or_default();
1087 let baseline_path = baseline_path.as_ref();
1088
1089 if opts.animations == Some(Animations::Disabled) {
1091 let _ = self.page.evaluate_expression(DISABLE_ANIMATIONS_JS).await;
1092 }
1093
1094 if let Some(ref mask_locators) = opts.mask {
1096 let mask_js = build_mask_js(mask_locators);
1097 let _ = self.page.evaluate_expression(&mask_js).await;
1098 }
1099
1100 compare_screenshot(
1101 &opts,
1102 baseline_path,
1103 self.timeout,
1104 self.poll_interval,
1105 self.negate,
1106 || async { self.page.screenshot(None).await },
1107 )
1108 .await
1109 }
1110}
1111
1112async fn compare_screenshot<F, Fut>(
1114 opts: &ScreenshotAssertionOptions,
1115 baseline_path: &Path,
1116 timeout: Duration,
1117 poll_interval: Duration,
1118 negate: bool,
1119 take_screenshot: F,
1120) -> Result<()>
1121where
1122 F: Fn() -> Fut,
1123 Fut: std::future::Future<Output = Result<Vec<u8>>>,
1124{
1125 let threshold = opts.threshold.unwrap_or(0.2);
1126 let max_diff_pixels = opts.max_diff_pixels;
1127 let max_diff_pixel_ratio = opts.max_diff_pixel_ratio;
1128 let update_snapshots = opts.update_snapshots.unwrap_or(false);
1129
1130 let actual_bytes = take_screenshot().await?;
1132
1133 if !baseline_path.exists() || update_snapshots {
1135 if let Some(parent) = baseline_path.parent() {
1136 tokio::fs::create_dir_all(parent).await.map_err(|e| {
1137 crate::error::Error::ProtocolError(format!(
1138 "Failed to create baseline directory: {}",
1139 e
1140 ))
1141 })?;
1142 }
1143 tokio::fs::write(baseline_path, &actual_bytes)
1144 .await
1145 .map_err(|e| {
1146 crate::error::Error::ProtocolError(format!(
1147 "Failed to write baseline screenshot: {}",
1148 e
1149 ))
1150 })?;
1151 return Ok(());
1152 }
1153
1154 let baseline_bytes = tokio::fs::read(baseline_path).await.map_err(|e| {
1156 crate::error::Error::ProtocolError(format!("Failed to read baseline screenshot: {}", e))
1157 })?;
1158
1159 let start = std::time::Instant::now();
1160
1161 loop {
1162 let screenshot_bytes = if start.elapsed().is_zero() {
1163 actual_bytes.clone()
1164 } else {
1165 take_screenshot().await?
1166 };
1167
1168 let comparison = compare_images(&baseline_bytes, &screenshot_bytes, threshold)?;
1169
1170 let within_tolerance =
1171 is_within_tolerance(&comparison, max_diff_pixels, max_diff_pixel_ratio);
1172
1173 let matches = if negate {
1174 !within_tolerance
1175 } else {
1176 within_tolerance
1177 };
1178
1179 if matches {
1180 return Ok(());
1181 }
1182
1183 if start.elapsed() >= timeout {
1184 if negate {
1185 return Err(crate::error::Error::AssertionTimeout(format!(
1186 "Expected screenshots NOT to match, but they matched after {:?}",
1187 timeout
1188 )));
1189 }
1190
1191 let baseline_stem = baseline_path
1193 .file_stem()
1194 .and_then(|s| s.to_str())
1195 .unwrap_or("screenshot");
1196 let baseline_ext = baseline_path
1197 .extension()
1198 .and_then(|s| s.to_str())
1199 .unwrap_or("png");
1200 let baseline_dir = baseline_path.parent().unwrap_or(Path::new("."));
1201
1202 let actual_path =
1203 baseline_dir.join(format!("{}-actual.{}", baseline_stem, baseline_ext));
1204 let diff_path = baseline_dir.join(format!("{}-diff.{}", baseline_stem, baseline_ext));
1205
1206 let _ = tokio::fs::write(&actual_path, &screenshot_bytes).await;
1207
1208 if let Ok(diff_bytes) =
1209 generate_diff_image(&baseline_bytes, &screenshot_bytes, threshold)
1210 {
1211 let _ = tokio::fs::write(&diff_path, diff_bytes).await;
1212 }
1213
1214 return Err(crate::error::Error::AssertionTimeout(format!(
1215 "Screenshot mismatch: {} pixels differ ({:.2}% of total). \
1216 Max allowed: {}. Threshold: {:.2}. \
1217 Actual saved to: {}. Diff saved to: {}. \
1218 Timed out after {:?}",
1219 comparison.diff_count,
1220 comparison.diff_ratio * 100.0,
1221 max_diff_pixels
1222 .map(|p| p.to_string())
1223 .or_else(|| max_diff_pixel_ratio.map(|r| format!("{:.2}%", r * 100.0)))
1224 .unwrap_or_else(|| "0".to_string()),
1225 threshold,
1226 actual_path.display(),
1227 diff_path.display(),
1228 timeout,
1229 )));
1230 }
1231
1232 tokio::time::sleep(poll_interval).await;
1233 }
1234}
1235
/// Outcome of a pixel-by-pixel image comparison.
struct ImageComparison {
    /// Number of pixels whose colour distance exceeded the threshold.
    diff_count: u32,
    /// `diff_count` as a fraction of all pixels, in `[0.0, 1.0]`.
    diff_ratio: f64,
}

/// Returns whether `comparison` stays within the configured tolerances.
///
/// Every limit that is set must hold: `diff_count <= max_diff_pixels` and
/// `diff_ratio <= max_diff_pixel_ratio`. Setting both means both are enforced, instead of
/// the pixel limit silently overriding the ratio. With no limit set, any differing pixel
/// is out of tolerance.
fn is_within_tolerance(
    comparison: &ImageComparison,
    max_diff_pixels: Option<u32>,
    max_diff_pixel_ratio: Option<f64>,
) -> bool {
    if max_diff_pixels.is_none() && max_diff_pixel_ratio.is_none() {
        return comparison.diff_count == 0;
    }

    if let Some(max_pixels) = max_diff_pixels {
        if comparison.diff_count > max_pixels {
            return false;
        }
    }

    if let Some(max_ratio) = max_diff_pixel_ratio {
        if comparison.diff_ratio > max_ratio {
            return false;
        }
    }

    true
}
1263
1264fn compare_images(
1266 baseline_bytes: &[u8],
1267 actual_bytes: &[u8],
1268 threshold: f64,
1269) -> Result<ImageComparison> {
1270 use image::GenericImageView;
1271
1272 let baseline_img = image::load_from_memory(baseline_bytes).map_err(|e| {
1273 crate::error::Error::ProtocolError(format!("Failed to decode baseline image: {}", e))
1274 })?;
1275 let actual_img = image::load_from_memory(actual_bytes).map_err(|e| {
1276 crate::error::Error::ProtocolError(format!("Failed to decode actual image: {}", e))
1277 })?;
1278
1279 let (bw, bh) = baseline_img.dimensions();
1280 let (aw, ah) = actual_img.dimensions();
1281
1282 if bw != aw || bh != ah {
1284 let total = bw.max(aw) * bh.max(ah);
1285 return Ok(ImageComparison {
1286 diff_count: total,
1287 diff_ratio: 1.0,
1288 });
1289 }
1290
1291 let total_pixels = bw * bh;
1292 if total_pixels == 0 {
1293 return Ok(ImageComparison {
1294 diff_count: 0,
1295 diff_ratio: 0.0,
1296 });
1297 }
1298
1299 let threshold_sq = threshold * threshold;
1300 let mut diff_count: u32 = 0;
1301
1302 for y in 0..bh {
1303 for x in 0..bw {
1304 let bp = baseline_img.get_pixel(x, y);
1305 let ap = actual_img.get_pixel(x, y);
1306
1307 let dr = (bp[0] as f64 - ap[0] as f64) / 255.0;
1309 let dg = (bp[1] as f64 - ap[1] as f64) / 255.0;
1310 let db = (bp[2] as f64 - ap[2] as f64) / 255.0;
1311 let da = (bp[3] as f64 - ap[3] as f64) / 255.0;
1312
1313 let dist_sq = (dr * dr + dg * dg + db * db + da * da) / 4.0;
1314
1315 if dist_sq > threshold_sq {
1316 diff_count += 1;
1317 }
1318 }
1319 }
1320
1321 Ok(ImageComparison {
1322 diff_count,
1323 diff_ratio: diff_count as f64 / total_pixels as f64,
1324 })
1325}
1326
1327fn generate_diff_image(
1329 baseline_bytes: &[u8],
1330 actual_bytes: &[u8],
1331 threshold: f64,
1332) -> Result<Vec<u8>> {
1333 use image::{GenericImageView, ImageBuffer, Rgba};
1334
1335 let baseline_img = image::load_from_memory(baseline_bytes).map_err(|e| {
1336 crate::error::Error::ProtocolError(format!("Failed to decode baseline image: {}", e))
1337 })?;
1338 let actual_img = image::load_from_memory(actual_bytes).map_err(|e| {
1339 crate::error::Error::ProtocolError(format!("Failed to decode actual image: {}", e))
1340 })?;
1341
1342 let (bw, bh) = baseline_img.dimensions();
1343 let (aw, ah) = actual_img.dimensions();
1344 let width = bw.max(aw);
1345 let height = bh.max(ah);
1346
1347 let threshold_sq = threshold * threshold;
1348
1349 let mut diff_img: ImageBuffer<Rgba<u8>, Vec<u8>> = ImageBuffer::new(width, height);
1350
1351 for y in 0..height {
1352 for x in 0..width {
1353 if x >= bw || y >= bh || x >= aw || y >= ah {
1354 diff_img.put_pixel(x, y, Rgba([255, 0, 0, 255]));
1356 continue;
1357 }
1358
1359 let bp = baseline_img.get_pixel(x, y);
1360 let ap = actual_img.get_pixel(x, y);
1361
1362 let dr = (bp[0] as f64 - ap[0] as f64) / 255.0;
1363 let dg = (bp[1] as f64 - ap[1] as f64) / 255.0;
1364 let db = (bp[2] as f64 - ap[2] as f64) / 255.0;
1365 let da = (bp[3] as f64 - ap[3] as f64) / 255.0;
1366
1367 let dist_sq = (dr * dr + dg * dg + db * db + da * da) / 4.0;
1368
1369 if dist_sq > threshold_sq {
1370 diff_img.put_pixel(x, y, Rgba([255, 0, 0, 255]));
1372 } else {
1373 let gray = ((ap[0] as u16 + ap[1] as u16 + ap[2] as u16) / 3) as u8;
1375 diff_img.put_pixel(x, y, Rgba([gray, gray, gray, 100]));
1376 }
1377 }
1378 }
1379
1380 let mut output = std::io::Cursor::new(Vec::new());
1381 diff_img
1382 .write_to(&mut output, image::ImageFormat::Png)
1383 .map_err(|e| {
1384 crate::error::Error::ProtocolError(format!("Failed to encode diff image: {}", e))
1385 })?;
1386
1387 Ok(output.into_inner())
1388}
1389
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default retry budget and polling cadence used by every assertion.
    #[test]
    fn test_expectation_defaults() {
        assert_eq!(
            DEFAULT_ASSERTION_TIMEOUT,
            Duration::from_millis(5_000),
            "default assertion timeout"
        );
        assert_eq!(
            DEFAULT_POLL_INTERVAL,
            Duration::from_millis(100),
            "default poll interval"
        );
    }
}