use std::fmt::Display;
use std::io::ErrorKind;
use std::path::PathBuf;
use image::ImageError;
use crate::{Harness, config::config};
/// Outcome of a single snapshot comparison: `Ok(())` on success, otherwise the
/// [`SnapshotError`] describing why the comparison failed.
pub type SnapshotResult = Result<(), SnapshotError>;
/// Options controlling how a rendered image is compared against its stored snapshot.
///
/// [`Default`] takes every value from the crate configuration (`config()`); use the
/// builder methods (`threshold`, `output_path`, `failed_pixel_count_threshold`) to
/// override individual fields.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct SnapshotOptions {
    /// Per-pixel difference threshold handed to `dify` when diffing the images.
    pub threshold: f32,
    /// Maximum number of differing pixels that still counts as a pass.
    pub failed_pixel_count_threshold: usize,
    /// Directory in which `{name}.png` snapshots (and the `.new`, `.diff` and `.old`
    /// companion files) are stored.
    pub output_path: PathBuf,
}
/// A value that may differ per operating system.
///
/// `OsThreshold::threshold` picks the field matching the OS the code was compiled for
/// (`windows`, `macos` or `linux`) and `fallback` for any other target.
/// Create one with `OsThreshold::new` (same value everywhere) and override individual
/// platforms with the per-OS builder methods.
#[derive(Debug, Clone, Copy)]
pub struct OsThreshold<T> {
    /// Value used when compiled for Windows.
    pub windows: T,
    /// Value used when compiled for macOS.
    pub macos: T,
    /// Value used when compiled for Linux.
    pub linux: T,
    /// Value used on every other target OS.
    pub fallback: T,
}
impl Default for OsThreshold<usize> {
    /// The per-OS failed-pixel-count threshold from the crate configuration
    /// (`config().os_failed_pixel_count_threshold()`).
    fn default() -> Self {
        config().os_failed_pixel_count_threshold()
    }
}

impl Default for OsThreshold<f32> {
    /// The per-OS difference threshold from the crate configuration
    /// (`config().os_threshold()`).
    fn default() -> Self {
        config().os_threshold()
    }
}
impl From<usize> for OsThreshold<usize> {
fn from(value: usize) -> Self {
Self::new(value)
}
}
impl From<f32> for OsThreshold<f32> {
fn from(value: f32) -> Self {
Self::new(value)
}
}
impl<T> OsThreshold<T>
where
    T: Copy,
{
    /// Use `same` for every operating system, including the fallback.
    pub fn new(same: T) -> Self {
        Self {
            windows: same,
            macos: same,
            linux: same,
            fallback: same,
        }
    }

    /// Override the value used when compiled for Windows.
    #[inline]
    pub fn windows(mut self, threshold: T) -> Self {
        self.windows = threshold;
        self
    }

    /// Override the value used when compiled for macOS.
    #[inline]
    pub fn macos(mut self, threshold: T) -> Self {
        self.macos = threshold;
        self
    }

    /// Override the value used when compiled for Linux.
    #[inline]
    pub fn linux(mut self, threshold: T) -> Self {
        self.linux = threshold;
        self
    }

    /// Override the value used on targets other than Windows, macOS and Linux.
    #[inline]
    pub fn fallback(mut self, threshold: T) -> Self {
        self.fallback = threshold;
        self
    }

    /// The value for the operating system this code was compiled for.
    ///
    /// The choice is made at compile time via `cfg!(target_os = ...)`. Targets other than
    /// Windows, macOS and Linux get `fallback`.
    pub fn threshold(&self) -> T {
        if cfg!(target_os = "windows") {
            self.windows
        } else if cfg!(target_os = "macos") {
            self.macos
        } else if cfg!(target_os = "linux") {
            self.linux
        } else {
            self.fallback
        }
    }
}
impl From<OsThreshold<Self>> for usize {
fn from(threshold: OsThreshold<Self>) -> Self {
threshold.threshold()
}
}
impl From<OsThreshold<Self>> for f32 {
fn from(threshold: OsThreshold<Self>) -> Self {
threshold.threshold()
}
}
impl Default for SnapshotOptions {
    /// Every field comes from the crate configuration (`config()`).
    fn default() -> Self {
        let threshold = config().threshold();
        let output_path = config().output_path();
        let failed_pixel_count_threshold = config().failed_pixel_count_threshold();
        Self {
            threshold,
            failed_pixel_count_threshold,
            output_path,
        }
    }
}
impl SnapshotOptions {
    /// Options populated from the crate configuration; identical to `Self::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the per-pixel difference threshold used when diffing.
    ///
    /// Accepts anything convertible into `f32`. That includes an `OsThreshold<f32>`, which
    /// is resolved for the current target OS.
    #[inline]
    pub fn threshold(self, threshold: impl Into<f32>) -> Self {
        Self {
            threshold: threshold.into(),
            ..self
        }
    }

    /// Set the directory in which snapshots are stored.
    #[inline]
    pub fn output_path(self, output_path: impl Into<PathBuf>) -> Self {
        Self {
            output_path: output_path.into(),
            ..self
        }
    }

    /// Set how many pixels may differ before the comparison fails.
    ///
    /// Accepts a plain `usize` (same on every OS) or an `OsThreshold<usize>`. The value is
    /// resolved for the current target OS immediately.
    #[inline]
    pub fn failed_pixel_count_threshold(
        self,
        failed_pixel_count_threshold: impl Into<OsThreshold<usize>>,
    ) -> Self {
        let per_os: OsThreshold<usize> = failed_pixel_count_threshold.into();
        Self {
            failed_pixel_count_threshold: per_os.threshold(),
            ..self
        }
    }
}
/// Why a snapshot comparison (or rendering/writing the image for it) failed.
#[derive(Debug)]
pub enum SnapshotError {
    /// The image differs from the stored snapshot by more pixels than allowed.
    Diff {
        /// Name of the snapshot.
        name: String,
        /// Number of differing pixels, as reported by `dify`.
        diff: i32,
        /// Path of the diff image written next to the snapshot.
        diff_path: PathBuf,
    },
    /// The stored snapshot could not be opened or decoded (this includes a missing file).
    OpenSnapshot {
        /// Path of the snapshot that failed to open.
        path: PathBuf,
        /// Underlying error from the `image` crate.
        err: ImageError,
    },
    /// The rendered image and the stored snapshot have different dimensions.
    SizeMismatch {
        /// Name of the snapshot.
        name: String,
        /// `dimensions()` of the stored snapshot.
        expected: (u32, u32),
        /// `dimensions()` of the newly rendered image.
        actual: (u32, u32),
    },
    /// Writing a PNG (snapshot, `.new.png`, `.diff.png` or a debug temp file) failed.
    WriteSnapshot {
        /// Path that could not be written.
        path: PathBuf,
        /// Underlying error from the `image` crate.
        err: ImageError,
    },
    /// The harness failed to render an image.
    RenderError {
        /// Error message returned by the renderer.
        err: String,
    },
}
/// Hint appended to failure messages explaining how to regenerate snapshots
/// (setting `UPDATE_SNAPSHOTS`, which is read by `Mode::from_env`).
const HOW_TO_UPDATE_SCREENSHOTS: &str =
    "Run `UPDATE_SNAPSHOTS=1 cargo test --all-features` to update the snapshots.";
impl Display for SnapshotError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Diff {
name,
diff,
diff_path,
} => {
let diff_path =
std::path::absolute(diff_path).unwrap_or_else(|_| diff_path.clone());
write!(
f,
"'{name}' Image did not match snapshot. Diff: {diff}, {}. {HOW_TO_UPDATE_SCREENSHOTS}",
diff_path.display()
)
}
Self::OpenSnapshot { path, err } => {
let path = std::path::absolute(path).unwrap_or_else(|_| path.clone());
match err {
ImageError::IoError(io) => match io.kind() {
ErrorKind::NotFound => {
write!(
f,
"Missing snapshot: {}. {HOW_TO_UPDATE_SCREENSHOTS}",
path.display()
)
}
err => {
write!(
f,
"Error reading snapshot: {err}\nAt: {}. {HOW_TO_UPDATE_SCREENSHOTS}",
path.display()
)
}
},
err => {
write!(
f,
"Error decoding snapshot: {err}\nAt: {}. Make sure git-lfs is setup correctly. Read the instructions here: https://github.com/emilk/egui/blob/main/CONTRIBUTING.md#making-a-pr",
path.display()
)
}
}
}
Self::SizeMismatch {
name,
expected,
actual,
} => {
write!(
f,
"'{name}' Image size did not match snapshot. Expected: {expected:?}, Actual: {actual:?}. {HOW_TO_UPDATE_SCREENSHOTS}"
)
}
Self::WriteSnapshot { path, err } => {
let path = std::path::absolute(path).unwrap_or_else(|_| path.clone());
write!(f, "Error writing snapshot: {err}\nAt: {}", path.display())
}
Self::RenderError { err } => {
write!(f, "Error rendering image: {err}")
}
}
}
}
/// What to do with stored snapshots, chosen via the `UPDATE_SNAPSHOTS` environment variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    /// Only compare. The stored snapshot is never overwritten; failures produce
    /// `.new.png` (and for pixel differences `.diff.png`) files instead.
    Test,
    /// Overwrite the snapshot when it is missing, has a different size, or differs by
    /// more than the allowed number of pixels.
    UpdateFailing,
    /// Overwrite the snapshot whenever `dify` reports any difference (the per-pixel
    /// threshold is forced to `0.0`).
    UpdateAll,
}
impl Mode {
    /// Read the mode from the `UPDATE_SNAPSHOTS` environment variable.
    ///
    /// An unset (or non-Unicode) variable and `false`/`0`/`no`/`off` select `Mode::Test`.
    /// `true`/`1`/`yes`/`on` select `Mode::UpdateFailing`, and `force` selects `Mode::UpdateAll`.
    ///
    /// # Panics
    /// For any other value.
    fn from_env() -> Self {
        match std::env::var("UPDATE_SNAPSHOTS").as_deref() {
            Err(_) | Ok("false" | "0" | "no" | "off") => Self::Test,
            Ok("true" | "1" | "yes" | "on") => Self::UpdateFailing,
            Ok("force") => Self::UpdateAll,
            Ok(unknown) => panic!("Unsupported value for UPDATE_SNAPSHOTS: {unknown:?}"),
        }
    }

    /// Whether this mode may overwrite stored snapshots.
    fn is_update(&self) -> bool {
        match *self {
            Self::UpdateFailing | Self::UpdateAll => true,
            Self::Test => false,
        }
    }
}
/// Compare `new` against the stored snapshot called `name`, using `options`.
///
/// `name` is relative to `options.output_path` and may contain sub-directories. The
/// snapshot file is `{name}.png`. The `UPDATE_SNAPSHOTS` environment variable decides
/// whether failing snapshots are overwritten.
///
/// # Errors
/// A [`SnapshotError`] if the snapshot is missing, unreadable, has a different size,
/// differs by too many pixels, or a file could not be written.
pub fn try_image_snapshot_options(
    new: &image::RgbaImage,
    name: impl Into<String>,
    options: &SnapshotOptions,
) -> SnapshotResult {
    // Convert once here so the large implementation is presumably not monomorphized
    // for every `impl Into<String>` type.
    let name: String = name.into();
    try_image_snapshot_options_impl(new, name, options)
}
/// Implementation of `try_image_snapshot_options` with an owned snapshot name.
///
/// Files live in `options.output_path`: the snapshot `{name}.png`, plus `{name}.new.png`
/// (rendered image written on failure), `{name}.diff.png` (diff image written when too
/// many pixels differ) and `{name}.old.png` (previous snapshot kept on update). Stale
/// `.new`/`.diff`/`.old` files are deleted first.
///
/// Behavior depends on [`Mode::from_env`]:
/// - `Test`: any failure writes `{name}.new.png` and returns an error.
/// - `UpdateFailing`: a missing, differently-sized or failing snapshot is overwritten.
/// - `UpdateAll`: diffs with a `0.0` threshold and overwrites whenever `dify` reports
///   a difference.
///
/// # Errors
/// [`SnapshotError::OpenSnapshot`], [`SnapshotError::SizeMismatch`] or
/// [`SnapshotError::Diff`] for failed comparisons in test mode, and
/// [`SnapshotError::WriteSnapshot`] whenever a PNG cannot be written.
fn try_image_snapshot_options_impl(
    new: &image::RgbaImage,
    name: String,
    options: &SnapshotOptions,
) -> SnapshotResult {
    #![expect(clippy::print_stdout)]

    let mode = Mode::from_env();

    let SnapshotOptions {
        threshold,
        output_path,
        failed_pixel_count_threshold,
    } = options;

    // `name` may contain sub-directories; create them. Failures are ignored here and
    // would surface later as a `WriteSnapshot` error when saving.
    let parent_path = if let Some(parent) = std::path::Path::new(&name).parent() {
        output_path.join(parent)
    } else {
        output_path.clone()
    };
    std::fs::create_dir_all(parent_path).ok();

    let snapshot_path = output_path.join(format!("{name}.png"));
    let diff_path = output_path.join(format!("{name}.diff.png"));
    let old_backup_path = output_path.join(format!("{name}.old.png"));
    let new_path = output_path.join(format!("{name}.new.png"));

    // Remove artifacts from a previous run so they can't be mistaken for this one.
    std::fs::remove_file(&diff_path).ok();
    std::fs::remove_file(&old_backup_path).ok();
    std::fs::remove_file(&new_path).ok();

    // Replace the stored snapshot with `new`, keeping the previous one as `{name}.old.png`.
    let update_snapshot = || {
        std::fs::rename(&snapshot_path, &old_backup_path).ok();
        new.save(&snapshot_path)
            .map_err(|err| SnapshotError::WriteSnapshot {
                err,
                path: snapshot_path.clone(),
            })?;
        std::fs::remove_file(&new_path).ok();
        println!("Updated snapshot: {}", snapshot_path.display());
        Ok(())
    };

    // Save `new` as `{name}.new.png` so a failing comparison can be inspected.
    let write_new_png = || {
        new.save(&new_path)
            .map_err(|err| SnapshotError::WriteSnapshot {
                err,
                path: new_path.clone(),
            })?;
        Ok(())
    };

    let previous = match image::open(&snapshot_path) {
        Ok(image) => image.to_rgba8(),
        Err(err) => {
            if mode.is_update() {
                return update_snapshot();
            } else {
                write_new_png()?;
                return Err(SnapshotError::OpenSnapshot {
                    path: snapshot_path.clone(),
                    err,
                });
            }
        }
    };

    if previous.dimensions() != new.dimensions() {
        if mode.is_update() {
            return update_snapshot();
        } else {
            write_new_png()?;
            return Err(SnapshotError::SizeMismatch {
                name,
                expected: previous.dimensions(),
                actual: new.dimensions(),
            });
        }
    }

    // When force-updating, diff with a zero threshold so any difference triggers a rewrite.
    let threshold = if mode == Mode::UpdateAll {
        0.0
    } else {
        *threshold
    };

    let result =
        dify::diff::get_results(previous, new.clone(), threshold, true, None, &None, &None);

    // `None` is treated as a pass (presumably `dify` found no differing pixels).
    let Some((num_wrong_pixels, diff_image)) = result else {
        return Ok(());
    };

    // Compare in `usize` space: casting the threshold to `i64` would wrap values above
    // `i64::MAX` (such as `usize::MAX`, "never fail") to negative numbers and fail every
    // diff. A negative pixel count (unexpected) can never exceed the threshold, so map it to 0.
    let below_threshold =
        usize::try_from(num_wrong_pixels).unwrap_or(0) <= *failed_pixel_count_threshold;

    if !below_threshold {
        diff_image
            .save(diff_path.clone())
            .map_err(|err| SnapshotError::WriteSnapshot {
                path: diff_path.clone(),
                err,
            })?;
    }

    match mode {
        Mode::Test => {
            if below_threshold {
                Ok(())
            } else {
                write_new_png()?;
                Err(SnapshotError::Diff {
                    name,
                    diff: num_wrong_pixels,
                    diff_path,
                })
            }
        }
        Mode::UpdateFailing => {
            if below_threshold {
                Ok(())
            } else {
                update_snapshot()
            }
        }
        Mode::UpdateAll => update_snapshot(),
    }
}
/// Compare `current` against the stored snapshot `name` using [`SnapshotOptions::default`].
///
/// # Errors
/// Same as `try_image_snapshot_options`.
pub fn try_image_snapshot(current: &image::RgbaImage, name: impl Into<String>) -> SnapshotResult {
    let defaults = SnapshotOptions::default();
    try_image_snapshot_options(current, name, &defaults)
}
/// Like `try_image_snapshot_options`, but panics with the error message on failure.
///
/// # Panics
/// If the snapshot comparison fails.
#[track_caller]
pub fn image_snapshot_options(
    current: &image::RgbaImage,
    name: impl Into<String>,
    options: &SnapshotOptions,
) {
    if let Err(err) = try_image_snapshot_options(current, name, options) {
        panic!("{err}");
    }
}
/// Like `try_image_snapshot`, but panics with the error message on failure.
///
/// # Panics
/// If the snapshot comparison fails.
#[track_caller]
pub fn image_snapshot(current: &image::RgbaImage, name: impl Into<String>) {
    if let Err(err) = try_image_snapshot(current, name) {
        panic!("{err}");
    }
}
#[cfg(any(feature = "wgpu", feature = "snapshot"))]
impl<State> Harness<'_, State> {
    /// The default [`SnapshotOptions`] used by `snapshot` and `try_snapshot`.
    pub fn options(&self) -> &SnapshotOptions {
        &self.default_snapshot_options
    }

    /// Render the current frame and compare it with the snapshot `name` using `options`.
    ///
    /// # Errors
    /// [`SnapshotError::RenderError`] if rendering fails, otherwise any error from
    /// `try_image_snapshot_options`.
    pub fn try_snapshot_options(
        &mut self,
        name: impl Into<String>,
        options: &SnapshotOptions,
    ) -> SnapshotResult {
        let rendered = match self.render() {
            Ok(image) => image,
            Err(err) => return Err(SnapshotError::RenderError { err }),
        };
        let name: String = name.into();
        try_image_snapshot_options(&rendered, name, options)
    }

    /// Like `try_snapshot_options`, using this harness's default snapshot options.
    ///
    /// # Errors
    /// [`SnapshotError::RenderError`] if rendering fails, otherwise any error from
    /// `try_image_snapshot_options`.
    pub fn try_snapshot(&mut self, name: impl Into<String>) -> SnapshotResult {
        let rendered = match self.render() {
            Ok(image) => image,
            Err(err) => return Err(SnapshotError::RenderError { err }),
        };
        let name: String = name.into();
        try_image_snapshot_options(&rendered, name, &self.default_snapshot_options)
    }

    /// Render and compare with the snapshot `name` using `options`.
    ///
    /// The outcome is recorded in this harness's [`SnapshotResults`] instead of being
    /// returned. Collected errors cause a panic when those results are dropped.
    #[track_caller]
    pub fn snapshot_options(&mut self, name: impl Into<String>, options: &SnapshotOptions) {
        let outcome = self.try_snapshot_options(name, options);
        self.snapshot_results.add(outcome);
    }

    /// Like `snapshot_options`, using this harness's default snapshot options.
    #[track_caller]
    pub fn snapshot(&mut self, name: impl Into<String>) {
        let outcome = self.try_snapshot(name);
        self.snapshot_results.add(outcome);
    }

    /// Render the current frame, save it to a temporary PNG and try to open it in the
    /// default image viewer.
    ///
    /// The temp file is created with `disable_cleanup(true)`, so it is left on disk.
    ///
    /// # Panics
    /// If rendering, creating the temp file or writing the image fails.
    #[deprecated = "Only for debugging, don't commit this."]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn debug_open_snapshot(&mut self) {
        let image = self
            .render()
            .map_err(|err| SnapshotError::RenderError { err })
            .unwrap();
        let temp_file = tempfile::Builder::new()
            .disable_cleanup(true)
            .prefix("kittest-snapshot")
            .suffix(".png")
            .tempfile()
            .expect("Failed to create temp file");
        let temp_path = temp_file.path();
        image
            .save(temp_path)
            .map_err(|err| SnapshotError::WriteSnapshot {
                err,
                path: temp_path.to_path_buf(),
            })
            .unwrap();
        let path = temp_file.into_temp_path();
        #[expect(clippy::print_stdout)]
        {
            println!("Wrote debug snapshot to: {}", path.display());
        }
        if let Err(err) = open::that(&path) {
            #[expect(clippy::print_stderr)]
            {
                eprintln!(
                    "Failed to open image {} in default image viewer: {err}",
                    path.display()
                );
            }
        }
    }

    /// Take the snapshot results collected so far, leaving an empty [`SnapshotResults`]
    /// in the harness.
    ///
    /// Merge them into another instance with `SnapshotResults::extend` when a test uses
    /// several harnesses.
    pub fn take_snapshot_results(&mut self) -> SnapshotResults {
        std::mem::take(&mut self.snapshot_results)
    }
}
/// Accumulates the outcomes of several snapshot comparisons so they can be reported together.
///
/// Dropping a value that still holds errors panics with all of them (unless the thread is
/// already panicking). Use `SnapshotResults::into_inner` to take the errors out instead.
#[derive(Debug)]
pub struct SnapshotResults {
    /// Errors from failed comparisons, in the order they were added.
    errors: Vec<SnapshotError>,
    /// `true` initially and after `into_inner`; set to `false` by `add` and `extend`.
    /// Checked when the value is dropped.
    handled: bool,
    /// Where this value was created; included in the "multiple unhandled results" panic.
    location: std::panic::Location<'static>,
}
impl Default for SnapshotResults {
    /// An empty, handled result set that remembers where it was created.
    ///
    /// Thanks to `#[track_caller]`, the recorded location is the caller of `default()`.
    /// If called through another `#[track_caller]` function, it is that function's caller.
    #[track_caller]
    fn default() -> Self {
        let location = *std::panic::Location::caller();
        Self {
            location,
            handled: true,
            errors: Vec::new(),
        }
    }
}
impl Display for SnapshotResults {
    /// Either "All snapshots passed" or a header followed by one line per error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.errors.is_empty() {
            return write!(f, "All snapshots passed");
        }
        writeln!(f, "Snapshot errors:")?;
        self.errors
            .iter()
            .try_for_each(|error| writeln!(f, " {error}"))
    }
}
impl SnapshotResults {
    /// An empty, handled result set; same as `Self::default()`.
    ///
    /// `#[track_caller]` makes the recorded creation location the caller of `new`.
    #[track_caller]
    pub fn new() -> Self {
        Self::default()
    }

    /// Record one snapshot outcome, keeping the error if it failed, and mark the set unhandled.
    pub fn add(&mut self, result: SnapshotResult) {
        self.handled = false;
        self.errors.extend(result.err());
    }

    /// Move all errors from `other` into `self` and mark `self` unhandled.
    pub fn extend(&mut self, other: Self) {
        let incoming = other.into_inner();
        self.handled = false;
        self.errors.extend(incoming);
    }

    /// Take the results out of `harness` and merge them into `self`.
    pub fn extend_harness<T>(&mut self, harness: &mut Harness<'_, T>) {
        let taken = harness.take_snapshot_results();
        self.extend(taken);
    }

    /// Whether any recorded comparison failed.
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }

    /// Convert into a `Result`.
    ///
    /// # Errors
    /// Returns `Err(self)` if any comparison failed. Dropping that value without calling
    /// `into_inner` still panics with the collected errors.
    pub fn into_result(self) -> Result<(), Self> {
        if self.errors.is_empty() {
            Ok(())
        } else {
            Err(self)
        }
    }

    /// Take the collected errors out, marking the set handled so dropping it won't panic.
    pub fn into_inner(mut self) -> Vec<SnapshotError> {
        self.handled = true;
        std::mem::take(&mut self.errors)
    }

    /// Consume the results. The `Drop` impl panics if any errors were collected.
    #[expect(clippy::unused_self)]
    pub fn unwrap(self) {}
}
impl From<SnapshotResults> for Vec<SnapshotError> {
fn from(results: SnapshotResults) -> Self {
results.into_inner()
}
}
impl Drop for SnapshotResults {
    /// Turns collected snapshot errors into a test failure when the results go out of scope.
    ///
    /// - If the thread is already panicking, returns immediately. Panicking again while
    ///   unwinding would abort the process and hide the original failure.
    /// - If any errors were collected, panics with all of them (via the `Display` impl).
    /// - Otherwise, if this value is unhandled (`add`/`extend` was called and `into_inner`
    ///   was not), increments a per-thread counter. It panics once the counter reaches 2,
    ///   i.e. when a second unhandled `SnapshotResults` is dropped on the same thread.
    #[track_caller]
    fn drop(&mut self) {
        // Don't mask an in-flight panic (the test is probably already failing).
        if std::thread::panicking() {
            return;
        }
        #[expect(clippy::manual_assert)]
        if self.has_errors() {
            panic!("{}", self);
        }
        // Counts unhandled `SnapshotResults` dropped on this thread.
        // NOTE(review): the counter is never reset, so this relies on each test running on its
        // own thread; confirm the behavior when tests share a thread (e.g. `--test-threads=1`).
        thread_local! {
            static UNHANDLED_SNAPSHOT_RESULTS_COUNTER: std::cell::RefCell<usize> = const { std::cell::RefCell::new(0) };
        }
        if !self.handled {
            let count = UNHANDLED_SNAPSHOT_RESULTS_COUNTER.with(|counter| {
                let mut count = counter.borrow_mut();
                *count += 1;
                *count
            });
            // One unhandled result set is tolerated; a second means results were split
            // across instances instead of being merged into one.
            #[expect(clippy::manual_assert)]
            if count >= 2 {
                panic!(
                    r#"
Multiple SnapshotResults were dropped without being handled.
In order to allow consistent snapshot updates, all snapshot results within a test should be merged in a single SnapshotResults instance.
Usually this is handled internally in a harness. If you have multiple harnesses, you can merge the results using `Harness::take_snapshot_results` and `SnapshotResults::extend`.
The SnapshotResult was constructed at {}
"#,
                    self.location
                );
            }
        }
    }
}