use crate::error::Result;
use crate::protocol::{Locator, Page};
use std::path::Path;
use std::time::Duration;
/// Default time budget for an assertion before it fails: 5 seconds.
const DEFAULT_ASSERTION_TIMEOUT: Duration = Duration::from_secs(5);
/// Default pause between two consecutive condition checks: 100 milliseconds.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Starts an assertion chain for `locator`.
///
/// The returned [`Expectation`] uses the default timeout and poll interval
/// and is not negated.
pub fn expect(locator: Locator) -> Expectation {
Expectation::new(locator)
}
/// A pending, auto-retrying assertion about the element(s) a [`Locator`] points at.
///
/// Assertion methods consume the value and poll the locator until the
/// condition holds or `timeout` has elapsed.
pub struct Expectation {
// Locator whose state is inspected on every poll.
locator: Locator,
// Maximum time to keep retrying before the assertion fails.
timeout: Duration,
// Pause between two consecutive checks.
poll_interval: Duration,
// When true, the assertion passes once the condition does NOT hold.
negate: bool,
}
#[allow(clippy::wrong_self_convention)]
impl Expectation {
/// Creates a non-negated expectation for `locator` with the default
/// timeout and poll interval.
pub(crate) fn new(locator: Locator) -> Self {
    let timeout = DEFAULT_ASSERTION_TIMEOUT;
    let poll_interval = DEFAULT_POLL_INTERVAL;
    Self {
        locator,
        timeout,
        poll_interval,
        negate: false,
    }
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_poll_interval(mut self, interval: Duration) -> Self {
self.poll_interval = interval;
self
}
#[allow(clippy::should_implement_trait)]
pub fn not(mut self) -> Self {
self.negate = true;
self
}
/// Asserts that the element is visible (or, when negated, that it is not).
///
/// Polls `Locator::is_visible` every `poll_interval` until the condition
/// holds. At least one check is always performed.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed without the
/// condition holding; any error from the locator is propagated immediately.
pub async fn to_be_visible(self) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let visible = self.locator.is_visible().await?;
        // `state != negate` is true exactly when the (possibly negated) condition holds.
        if visible != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to be visible, but it was visible after {:?}",
                selector, self.timeout
            ),
            false => format!(
                "Expected element '{}' to be visible, but it was not visible after {:?}",
                selector, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_be_hidden(self) -> Result<()> {
let negated = Expectation {
negate: !self.negate, ..self
};
negated.to_be_visible().await
}
/// Asserts that the element's inner text equals `expected`.
///
/// Both the expected and the actual text are trimmed before comparison.
/// The locator is polled every `poll_interval` until the (possibly negated)
/// condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_have_text(self, expected: &str) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    let expected = expected.trim();
    loop {
        let raw_text = self.locator.inner_text().await?;
        let actual = raw_text.trim();
        if (actual == expected) != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to have text '{}', but it did after {:?}",
                selector, expected, self.timeout
            ),
            false => format!(
                "Expected element '{}' to have text '{}', but had '{}' after {:?}",
                selector, expected, actual, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_have_text_regex(self, pattern: &str) -> Result<()> {
let start = std::time::Instant::now();
let selector = self.locator.selector().to_string();
let re = regex::Regex::new(pattern)
.map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
loop {
let actual_text = self.locator.inner_text().await?;
let actual = actual_text.trim();
let matches = if self.negate {
!re.is_match(actual)
} else {
re.is_match(actual)
};
if matches {
return Ok(());
}
if start.elapsed() >= self.timeout {
let message = if self.negate {
format!(
"Expected element '{}' NOT to match pattern '{}', but it did after {:?}",
selector, pattern, self.timeout
)
} else {
format!(
"Expected element '{}' to match pattern '{}', but had '{}' after {:?}",
selector, pattern, actual, self.timeout
)
};
return Err(crate::error::Error::AssertionTimeout(message));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Asserts that the element's trimmed inner text contains `expected` as a
/// substring.
///
/// `expected` itself is used as given (it is not trimmed).
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_contain_text(self, expected: &str) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let raw_text = self.locator.inner_text().await?;
        let actual = raw_text.trim();
        if actual.contains(expected) != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to contain text '{}', but it did after {:?}",
                selector, expected, self.timeout
            ),
            false => format!(
                "Expected element '{}' to contain text '{}', but had '{}' after {:?}",
                selector, expected, actual, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_contain_text_regex(self, pattern: &str) -> Result<()> {
let start = std::time::Instant::now();
let selector = self.locator.selector().to_string();
let re = regex::Regex::new(pattern)
.map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
loop {
let actual_text = self.locator.inner_text().await?;
let actual = actual_text.trim();
let matches = if self.negate {
!re.is_match(actual)
} else {
re.is_match(actual)
};
if matches {
return Ok(());
}
if start.elapsed() >= self.timeout {
let message = if self.negate {
format!(
"Expected element '{}' NOT to contain pattern '{}', but it did after {:?}",
selector, pattern, self.timeout
)
} else {
format!(
"Expected element '{}' to contain pattern '{}', but had '{}' after {:?}",
selector, pattern, actual, self.timeout
)
};
return Err(crate::error::Error::AssertionTimeout(message));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Asserts that the input's current value equals `expected` exactly
/// (no trimming is applied).
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_have_value(self, expected: &str) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let actual = self.locator.input_value(None).await?;
        if (actual == expected) != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected input '{}' NOT to have value '{}', but it did after {:?}",
                selector, expected, self.timeout
            ),
            false => format!(
                "Expected input '{}' to have value '{}', but had '{}' after {:?}",
                selector, expected, actual, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_have_value_regex(self, pattern: &str) -> Result<()> {
let start = std::time::Instant::now();
let selector = self.locator.selector().to_string();
let re = regex::Regex::new(pattern)
.map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
loop {
let actual = self.locator.input_value(None).await?;
let matches = if self.negate {
!re.is_match(&actual)
} else {
re.is_match(&actual)
};
if matches {
return Ok(());
}
if start.elapsed() >= self.timeout {
let message = if self.negate {
format!(
"Expected input '{}' NOT to match pattern '{}', but it did after {:?}",
selector, pattern, self.timeout
)
} else {
format!(
"Expected input '{}' to match pattern '{}', but had '{}' after {:?}",
selector, pattern, actual, self.timeout
)
};
return Err(crate::error::Error::AssertionTimeout(message));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Asserts that the element is enabled (or, when negated, that it is not).
///
/// Polls `Locator::is_enabled` every `poll_interval` until the condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_be_enabled(self) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let enabled = self.locator.is_enabled().await?;
        if enabled != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to be enabled, but it was enabled after {:?}",
                selector, self.timeout
            ),
            false => format!(
                "Expected element '{}' to be enabled, but it was not enabled after {:?}",
                selector, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_be_disabled(self) -> Result<()> {
let negated = Expectation {
negate: !self.negate, ..self
};
negated.to_be_enabled().await
}
/// Asserts that the element is checked (or, when negated, that it is not).
///
/// Polls `Locator::is_checked` every `poll_interval` until the condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_be_checked(self) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let checked = self.locator.is_checked().await?;
        if checked != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to be checked, but it was checked after {:?}",
                selector, self.timeout
            ),
            false => format!(
                "Expected element '{}' to be checked, but it was not checked after {:?}",
                selector, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_be_unchecked(self) -> Result<()> {
let negated = Expectation {
negate: !self.negate, ..self
};
negated.to_be_checked().await
}
/// Asserts that the element is editable (or, when negated, that it is not).
///
/// Polls `Locator::is_editable` every `poll_interval` until the condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_be_editable(self) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let editable = self.locator.is_editable().await?;
        if editable != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to be editable, but it was editable after {:?}",
                selector, self.timeout
            ),
            false => format!(
                "Expected element '{}' to be editable, but it was not editable after {:?}",
                selector, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
/// Asserts that the element is focused (or, when negated, that it is not).
///
/// Polls `Locator::is_focused` every `poll_interval` until the condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from the locator is propagated immediately.
pub async fn to_be_focused(self) -> Result<()> {
    let started = std::time::Instant::now();
    let selector = self.locator.selector().to_string();
    loop {
        let focused = self.locator.is_focused().await?;
        if focused != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected element '{}' NOT to be focused, but it was focused after {:?}",
                selector, self.timeout
            ),
            false => format!(
                "Expected element '{}' to be focused, but it was not focused after {:?}",
                selector, self.timeout
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_match_aria_snapshot(self, expected: &str) -> Result<()> {
use crate::protocol::serialize_argument;
let selector = self.locator.selector().to_string();
let timeout_ms = self.timeout.as_millis() as f64;
let expected_value = serialize_argument(&serde_json::Value::String(expected.to_string()));
self.locator
.frame()
.frame_expect(
&selector,
"to.match.aria",
expected_value,
self.negate,
timeout_ms,
)
.await
}
/// Asserts that a screenshot of the element matches the baseline image at
/// `baseline_path`.
///
/// Before capturing, the page is optionally prepared:
/// * `animations == Some(Animations::Disabled)` injects a style sheet that
///   zeroes animation and transition durations;
/// * `mask` covers every element matched by the given locators with an
///   opaque overlay.
///
/// Failures of these preparation scripts are ignored, so the comparison
/// still runs. The comparison itself (baseline creation/update, retries and
/// tolerance handling) is performed by `compare_screenshot`.
///
/// # Errors
///
/// Propagates any error returned by `compare_screenshot`, including
/// `Error::AssertionTimeout` when the images do not match in time.
pub async fn to_have_screenshot(
    self,
    baseline_path: impl AsRef<Path>,
    options: Option<ScreenshotAssertionOptions>,
) -> Result<()> {
    let opts = options.unwrap_or_default();
    let baseline_path = baseline_path.as_ref();
    if opts.animations == Some(Animations::Disabled) {
        // Preparation only: an injection failure must not abort the assertion.
        let _ = self
            .locator
            .evaluate_js(DISABLE_ANIMATIONS_JS, None::<&()>)
            .await;
    }
    if let Some(mask_locators) = &opts.mask {
        let mask_js = build_mask_js(mask_locators);
        let _ = self.locator.evaluate_js(&mask_js, None::<&()>).await;
    }
    // The previous `screenshot_opts` value was `None` on both branches of its
    // `if let`, so the screenshot is requested with `None` directly instead of
    // cloning a dead variable on every poll.
    compare_screenshot(
        &opts,
        baseline_path,
        self.timeout,
        self.poll_interval,
        self.negate,
        || async { self.locator.screenshot(None).await },
    )
    .await
}
}
/// JavaScript snippet that appends a `<style>` element to `document.head`
/// forcing animation and transition durations and delays to `0s` for every
/// element and its `::before`/`::after` pseudo-elements.
///
/// The style element is tagged with the `data-playwright-no-animations`
/// attribute.
const DISABLE_ANIMATIONS_JS: &str = r#"
(() => {
const style = document.createElement('style');
style.textContent = '*, *::before, *::after { animation-duration: 0s !important; animation-delay: 0s !important; transition-duration: 0s !important; transition-delay: 0s !important; }';
style.setAttribute('data-playwright-no-animations', '');
document.head.appendChild(style);
})()
"#;
/// Builds a JavaScript program that masks the elements matched by `locators`.
///
/// For every locator one self-invoking function is generated. It looks the
/// selector up with `document.querySelectorAll` and, for each match, appends
/// a fixed-position magenta (`#FF00FF`) overlay of the element's bounding
/// rectangle to `document.body`. Overlays are tagged with the
/// `data-playwright-mask` attribute. The per-locator snippets are joined
/// with newlines; an empty slice yields an empty string.
///
/// The selector is embedded in a single-quoted JavaScript string literal, so
/// backslashes, single quotes and line breaks are escaped. Backslashes are
/// escaped first so that the escapes added afterwards are not doubled.
fn build_mask_js(locators: &[Locator]) -> String {
    let scripts: Vec<String> = locators
        .iter()
        .map(|l| {
            let sel = l
                .selector()
                .replace('\\', "\\\\")
                .replace('\'', "\\'")
                .replace('\n', "\\n")
                .replace('\r', "\\r");
            format!(
                r#"
(function() {{
var els = document.querySelectorAll('{}');
els.forEach(function(el) {{
var rect = el.getBoundingClientRect();
var overlay = document.createElement('div');
overlay.setAttribute('data-playwright-mask', '');
overlay.style.cssText = 'position:fixed;z-index:2147483647;background:#FF00FF;pointer-events:none;'
+ 'left:' + rect.left + 'px;top:' + rect.top + 'px;width:' + rect.width + 'px;height:' + rect.height + 'px;';
document.body.appendChild(overlay);
}});
}})();
"#,
                sel
            )
        })
        .collect();
    scripts.join("\n")
}
/// How page animations are treated while taking a screenshot for comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Animations {
/// Leave animations untouched; no script is injected for this variant.
Allow,
/// Inject a style sheet that zeroes animation and transition durations
/// before the screenshot is taken.
Disabled,
}
/// Options controlling a screenshot assertion.
///
/// Every field is optional; `Default` leaves all of them unset.
#[derive(Debug, Clone, Default)]
pub struct ScreenshotAssertionOptions {
/// Maximum number of differing pixels that is still accepted.
pub max_diff_pixels: Option<u32>,
/// Maximum accepted fraction of differing pixels (differing / total).
pub max_diff_pixel_ratio: Option<f64>,
/// Per-pixel colour distance above which a pixel counts as different.
/// The comparison falls back to `0.2` when unset.
pub threshold: Option<f64>,
/// Animation handling; only `Some(Animations::Disabled)` triggers an action.
pub animations: Option<Animations>,
/// Locators whose matching elements are covered with an overlay before capture.
pub mask: Option<Vec<Locator>>,
/// When `Some(true)`, the baseline file is overwritten with the current
/// screenshot and the assertion passes. Treated as `false` when unset.
pub update_snapshots: Option<bool>,
}
impl ScreenshotAssertionOptions {
/// Returns a builder with every option unset.
pub fn builder() -> ScreenshotAssertionOptionsBuilder {
ScreenshotAssertionOptionsBuilder::default()
}
}
/// Builder for [`ScreenshotAssertionOptions`].
///
/// Holds the same optional fields as the options struct; each setter fills
/// one of them and `build` copies them over unchanged.
#[derive(Debug, Clone, Default)]
pub struct ScreenshotAssertionOptionsBuilder {
// Mirrors `ScreenshotAssertionOptions::max_diff_pixels`.
max_diff_pixels: Option<u32>,
// Mirrors `ScreenshotAssertionOptions::max_diff_pixel_ratio`.
max_diff_pixel_ratio: Option<f64>,
// Mirrors `ScreenshotAssertionOptions::threshold`.
threshold: Option<f64>,
// Mirrors `ScreenshotAssertionOptions::animations`.
animations: Option<Animations>,
// Mirrors `ScreenshotAssertionOptions::mask`.
mask: Option<Vec<Locator>>,
// Mirrors `ScreenshotAssertionOptions::update_snapshots`.
update_snapshots: Option<bool>,
}
impl ScreenshotAssertionOptionsBuilder {
    /// Sets the maximum number of differing pixels that is still accepted.
    pub fn max_diff_pixels(self, pixels: u32) -> Self {
        Self {
            max_diff_pixels: Some(pixels),
            ..self
        }
    }

    /// Sets the maximum accepted fraction of differing pixels.
    pub fn max_diff_pixel_ratio(self, ratio: f64) -> Self {
        Self {
            max_diff_pixel_ratio: Some(ratio),
            ..self
        }
    }

    /// Sets the per-pixel colour distance threshold.
    pub fn threshold(self, threshold: f64) -> Self {
        Self {
            threshold: Some(threshold),
            ..self
        }
    }

    /// Sets how animations are handled before the screenshot is taken.
    pub fn animations(self, animations: Animations) -> Self {
        Self {
            animations: Some(animations),
            ..self
        }
    }

    /// Sets the locators whose elements are masked before the screenshot is taken.
    pub fn mask(self, locators: Vec<Locator>) -> Self {
        Self {
            mask: Some(locators),
            ..self
        }
    }

    /// Sets whether the baseline should be overwritten with the current screenshot.
    pub fn update_snapshots(self, update: bool) -> Self {
        Self {
            update_snapshots: Some(update),
            ..self
        }
    }

    /// Consumes the builder and produces the options, field by field.
    pub fn build(self) -> ScreenshotAssertionOptions {
        let Self {
            max_diff_pixels,
            max_diff_pixel_ratio,
            threshold,
            animations,
            mask,
            update_snapshots,
        } = self;
        ScreenshotAssertionOptions {
            max_diff_pixels,
            max_diff_pixel_ratio,
            threshold,
            animations,
            mask,
            update_snapshots,
        }
    }
}
/// Starts an assertion chain for `page`.
///
/// The page handle is cloned into the returned [`PageExpectation`], which
/// uses the default timeout and poll interval and is not negated.
pub fn expect_page(page: &Page) -> PageExpectation {
PageExpectation::new(page.clone())
}
/// A pending, auto-retrying assertion about a [`Page`] (title, URL, screenshot).
// NOTE(review): `wrong_self_convention` is a lint about method signatures; on
// `Expectation` the same attribute is attached to the `impl` block. It looks
// like it was meant for `impl PageExpectation` — confirm and move if so.
#[allow(clippy::wrong_self_convention)]
pub struct PageExpectation {
// Page whose state is inspected on every poll.
page: Page,
// Maximum time to keep retrying before the assertion fails.
timeout: Duration,
// Pause between two consecutive checks.
poll_interval: Duration,
// When true, the assertion passes once the condition does NOT hold.
negate: bool,
}
impl PageExpectation {
/// Creates a non-negated page expectation with the default timeout and
/// poll interval.
fn new(page: Page) -> Self {
    let timeout = DEFAULT_ASSERTION_TIMEOUT;
    let poll_interval = DEFAULT_POLL_INTERVAL;
    Self {
        page,
        timeout,
        poll_interval,
        negate: false,
    }
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
#[allow(clippy::should_implement_trait)]
pub fn not(mut self) -> Self {
self.negate = true;
self
}
/// Asserts that the page title equals `expected`.
///
/// Both the expected and the actual title are trimmed before comparison.
/// The title is re-read every `poll_interval` until the (possibly negated)
/// condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed; any error
/// from reading the title is propagated immediately.
pub async fn to_have_title(self, expected: &str) -> Result<()> {
    let started = std::time::Instant::now();
    let expected = expected.trim();
    loop {
        let raw_title = self.page.title().await?;
        let actual = raw_title.trim();
        if (actual == expected) != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected page NOT to have title '{}', but it did after {:?}",
                expected, self.timeout,
            ),
            false => format!(
                "Expected page to have title '{}', but got '{}' after {:?}",
                expected, actual, self.timeout,
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_have_title_regex(self, pattern: &str) -> Result<()> {
let start = std::time::Instant::now();
let re = regex::Regex::new(pattern)
.map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
loop {
let actual = self.page.title().await?;
let matches = if self.negate {
!re.is_match(&actual)
} else {
re.is_match(&actual)
};
if matches {
return Ok(());
}
if start.elapsed() >= self.timeout {
let message = if self.negate {
format!(
"Expected page title NOT to match '{}', but '{}' matched after {:?}",
pattern, actual, self.timeout,
)
} else {
format!(
"Expected page title to match '{}', but got '{}' after {:?}",
pattern, actual, self.timeout,
)
};
return Err(crate::error::Error::AssertionTimeout(message));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Asserts that the page URL equals `expected` exactly.
///
/// The URL is re-read every `poll_interval` until the (possibly negated)
/// condition holds.
///
/// # Errors
///
/// Returns `Error::AssertionTimeout` once `timeout` has elapsed.
pub async fn to_have_url(self, expected: &str) -> Result<()> {
    let started = std::time::Instant::now();
    loop {
        let actual = self.page.url();
        if (actual == expected) != self.negate {
            return Ok(());
        }
        if started.elapsed() < self.timeout {
            tokio::time::sleep(self.poll_interval).await;
            continue;
        }
        let message = match self.negate {
            true => format!(
                "Expected page NOT to have URL '{}', but it did after {:?}",
                expected, self.timeout,
            ),
            false => format!(
                "Expected page to have URL '{}', but got '{}' after {:?}",
                expected, actual, self.timeout,
            ),
        };
        return Err(crate::error::Error::AssertionTimeout(message));
    }
}
pub async fn to_have_url_regex(self, pattern: &str) -> Result<()> {
let start = std::time::Instant::now();
let re = regex::Regex::new(pattern)
.map_err(|e| crate::error::Error::InvalidArgument(format!("Invalid regex: {}", e)))?;
loop {
let actual = self.page.url();
let matches = if self.negate {
!re.is_match(&actual)
} else {
re.is_match(&actual)
};
if matches {
return Ok(());
}
if start.elapsed() >= self.timeout {
let message = if self.negate {
format!(
"Expected page URL NOT to match '{}', but '{}' matched after {:?}",
pattern, actual, self.timeout,
)
} else {
format!(
"Expected page URL to match '{}', but got '{}' after {:?}",
pattern, actual, self.timeout,
)
};
return Err(crate::error::Error::AssertionTimeout(message));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Asserts that a screenshot of the page matches the baseline image at
/// `baseline_path`.
///
/// When requested through `options`, animations are disabled and masked
/// elements are covered before capturing; failures of those preparation
/// scripts are ignored. Baseline handling, retries and tolerance checks are
/// performed by `compare_screenshot`.
///
/// # Errors
///
/// Propagates any error returned by `compare_screenshot`, including
/// `Error::AssertionTimeout` when the images do not match in time.
pub async fn to_have_screenshot(
    self,
    baseline_path: impl AsRef<Path>,
    options: Option<ScreenshotAssertionOptions>,
) -> Result<()> {
    let opts = options.unwrap_or_default();
    let baseline_path = baseline_path.as_ref();

    let disable_animations = matches!(opts.animations, Some(Animations::Disabled));
    if disable_animations {
        // Preparation only: an injection failure must not abort the assertion.
        let _ = self.page.evaluate_expression(DISABLE_ANIMATIONS_JS).await;
    }

    if let Some(mask_locators) = opts.mask.as_ref() {
        let mask_js = build_mask_js(mask_locators);
        let _ = self.page.evaluate_expression(&mask_js).await;
    }

    let capture = || async { self.page.screenshot(None).await };
    compare_screenshot(
        &opts,
        baseline_path,
        self.timeout,
        self.poll_interval,
        self.negate,
        capture,
    )
    .await
}
}
async fn compare_screenshot<F, Fut>(
opts: &ScreenshotAssertionOptions,
baseline_path: &Path,
timeout: Duration,
poll_interval: Duration,
negate: bool,
take_screenshot: F,
) -> Result<()>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<Vec<u8>>>,
{
let threshold = opts.threshold.unwrap_or(0.2);
let max_diff_pixels = opts.max_diff_pixels;
let max_diff_pixel_ratio = opts.max_diff_pixel_ratio;
let update_snapshots = opts.update_snapshots.unwrap_or(false);
let actual_bytes = take_screenshot().await?;
if !baseline_path.exists() || update_snapshots {
if let Some(parent) = baseline_path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
crate::error::Error::ProtocolError(format!(
"Failed to create baseline directory: {}",
e
))
})?;
}
tokio::fs::write(baseline_path, &actual_bytes)
.await
.map_err(|e| {
crate::error::Error::ProtocolError(format!(
"Failed to write baseline screenshot: {}",
e
))
})?;
return Ok(());
}
let baseline_bytes = tokio::fs::read(baseline_path).await.map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to read baseline screenshot: {}", e))
})?;
let start = std::time::Instant::now();
loop {
let screenshot_bytes = if start.elapsed().is_zero() {
actual_bytes.clone()
} else {
take_screenshot().await?
};
let comparison = compare_images(&baseline_bytes, &screenshot_bytes, threshold)?;
let within_tolerance =
is_within_tolerance(&comparison, max_diff_pixels, max_diff_pixel_ratio);
let matches = if negate {
!within_tolerance
} else {
within_tolerance
};
if matches {
return Ok(());
}
if start.elapsed() >= timeout {
if negate {
return Err(crate::error::Error::AssertionTimeout(format!(
"Expected screenshots NOT to match, but they matched after {:?}",
timeout
)));
}
let baseline_stem = baseline_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("screenshot");
let baseline_ext = baseline_path
.extension()
.and_then(|s| s.to_str())
.unwrap_or("png");
let baseline_dir = baseline_path.parent().unwrap_or(Path::new("."));
let actual_path =
baseline_dir.join(format!("{}-actual.{}", baseline_stem, baseline_ext));
let diff_path = baseline_dir.join(format!("{}-diff.{}", baseline_stem, baseline_ext));
let _ = tokio::fs::write(&actual_path, &screenshot_bytes).await;
if let Ok(diff_bytes) =
generate_diff_image(&baseline_bytes, &screenshot_bytes, threshold)
{
let _ = tokio::fs::write(&diff_path, diff_bytes).await;
}
return Err(crate::error::Error::AssertionTimeout(format!(
"Screenshot mismatch: {} pixels differ ({:.2}% of total). \
Max allowed: {}. Threshold: {:.2}. \
Actual saved to: {}. Diff saved to: {}. \
Timed out after {:?}",
comparison.diff_count,
comparison.diff_ratio * 100.0,
max_diff_pixels
.map(|p| p.to_string())
.or_else(|| max_diff_pixel_ratio.map(|r| format!("{:.2}%", r * 100.0)))
.unwrap_or_else(|| "0".to_string()),
threshold,
actual_path.display(),
diff_path.display(),
timeout,
)));
}
tokio::time::sleep(poll_interval).await;
}
}
/// Result of comparing two images pixel by pixel.
struct ImageComparison {
    /// Number of pixels considered different.
    diff_count: u32,
    /// `diff_count` divided by the total number of pixels.
    diff_ratio: f64,
}

/// Decides whether `comparison` is acceptable under the configured limits.
///
/// * With no limit configured, the images must not differ at all.
/// * Otherwise every configured limit must hold: `diff_count` must not
///   exceed `max_diff_pixels` and `diff_ratio` must not exceed
///   `max_diff_pixel_ratio`. A limit that is `None` is not checked.
///
/// Previously the ratio limit was silently ignored whenever a pixel limit
/// was also given.
fn is_within_tolerance(
    comparison: &ImageComparison,
    max_diff_pixels: Option<u32>,
    max_diff_pixel_ratio: Option<f64>,
) -> bool {
    if max_diff_pixels.is_none() && max_diff_pixel_ratio.is_none() {
        return comparison.diff_count == 0;
    }
    let exceeds_pixels = max_diff_pixels.map_or(false, |limit| comparison.diff_count > limit);
    let exceeds_ratio =
        max_diff_pixel_ratio.map_or(false, |limit| comparison.diff_ratio > limit);
    !(exceeds_pixels || exceeds_ratio)
}
fn compare_images(
baseline_bytes: &[u8],
actual_bytes: &[u8],
threshold: f64,
) -> Result<ImageComparison> {
use image::GenericImageView;
let baseline_img = image::load_from_memory(baseline_bytes).map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to decode baseline image: {}", e))
})?;
let actual_img = image::load_from_memory(actual_bytes).map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to decode actual image: {}", e))
})?;
let (bw, bh) = baseline_img.dimensions();
let (aw, ah) = actual_img.dimensions();
if bw != aw || bh != ah {
let total = bw.max(aw) * bh.max(ah);
return Ok(ImageComparison {
diff_count: total,
diff_ratio: 1.0,
});
}
let total_pixels = bw * bh;
if total_pixels == 0 {
return Ok(ImageComparison {
diff_count: 0,
diff_ratio: 0.0,
});
}
let threshold_sq = threshold * threshold;
let mut diff_count: u32 = 0;
for y in 0..bh {
for x in 0..bw {
let bp = baseline_img.get_pixel(x, y);
let ap = actual_img.get_pixel(x, y);
let dr = (bp[0] as f64 - ap[0] as f64) / 255.0;
let dg = (bp[1] as f64 - ap[1] as f64) / 255.0;
let db = (bp[2] as f64 - ap[2] as f64) / 255.0;
let da = (bp[3] as f64 - ap[3] as f64) / 255.0;
let dist_sq = (dr * dr + dg * dg + db * db + da * da) / 4.0;
if dist_sq > threshold_sq {
diff_count += 1;
}
}
}
Ok(ImageComparison {
diff_count,
diff_ratio: diff_count as f64 / total_pixels as f64,
})
}
fn generate_diff_image(
baseline_bytes: &[u8],
actual_bytes: &[u8],
threshold: f64,
) -> Result<Vec<u8>> {
use image::{GenericImageView, ImageBuffer, Rgba};
let baseline_img = image::load_from_memory(baseline_bytes).map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to decode baseline image: {}", e))
})?;
let actual_img = image::load_from_memory(actual_bytes).map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to decode actual image: {}", e))
})?;
let (bw, bh) = baseline_img.dimensions();
let (aw, ah) = actual_img.dimensions();
let width = bw.max(aw);
let height = bh.max(ah);
let threshold_sq = threshold * threshold;
let mut diff_img: ImageBuffer<Rgba<u8>, Vec<u8>> = ImageBuffer::new(width, height);
for y in 0..height {
for x in 0..width {
if x >= bw || y >= bh || x >= aw || y >= ah {
diff_img.put_pixel(x, y, Rgba([255, 0, 0, 255]));
continue;
}
let bp = baseline_img.get_pixel(x, y);
let ap = actual_img.get_pixel(x, y);
let dr = (bp[0] as f64 - ap[0] as f64) / 255.0;
let dg = (bp[1] as f64 - ap[1] as f64) / 255.0;
let db = (bp[2] as f64 - ap[2] as f64) / 255.0;
let da = (bp[3] as f64 - ap[3] as f64) / 255.0;
let dist_sq = (dr * dr + dg * dg + db * db + da * da) / 4.0;
if dist_sq > threshold_sq {
diff_img.put_pixel(x, y, Rgba([255, 0, 0, 255]));
} else {
let gray = ((ap[0] as u16 + ap[1] as u16 + ap[2] as u16) / 3) as u8;
diff_img.put_pixel(x, y, Rgba([gray, gray, gray, 100]));
}
}
}
let mut output = std::io::Cursor::new(Vec::new());
diff_img
.write_to(&mut output, image::ImageFormat::Png)
.map_err(|e| {
crate::error::Error::ProtocolError(format!("Failed to encode diff image: {}", e))
})?;
Ok(output.into_inner())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default timing constants so an accidental change is noticed.
    #[test]
    fn test_expectation_defaults() {
        let expected_timeout = Duration::from_secs(5);
        let expected_interval = Duration::from_millis(100);
        assert_eq!(DEFAULT_ASSERTION_TIMEOUT, expected_timeout);
        assert_eq!(DEFAULT_POLL_INTERVAL, expected_interval);
    }
}