use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, OnceLock};
/// Returns the process-wide mutex used to serialize snapshot operations.
///
/// The mutex is created on first use and the same instance is handed back on
/// every later call.
fn snapshot_lock() -> &'static Mutex<()> {
    static GLOBAL_GUARD: OnceLock<Mutex<()>> = OnceLock::new();
    GLOBAL_GUARD.get_or_init(Mutex::<()>::default)
}
/// When `true`, `assert_snapshot_impl` accepts newly created snapshots and
/// overwrites mismatching ones instead of reporting an error.
static UPDATE_ALL: AtomicBool = AtomicBool::new(false);
/// When `true` (and `UPDATE_ALL` is `false`), a mismatch in
/// `assert_snapshot_impl` prompts on stderr/stdin to accept or reject the new
/// content.
static REVIEW_MODE: AtomicBool = AtomicBool::new(false);
/// Cached snapshot directory. Stays `None` until `snapshot_dir` resolves a
/// default or `set_snapshot_dir` stores an explicit path.
static SNAPSHOT_DIR: Mutex<Option<PathBuf>> = Mutex::new(None);
/// Turns update-all mode on or off.
///
/// The flag is written while the global snapshot lock is held, so the change
/// cannot land in the middle of a snapshot assertion that holds the same lock.
///
/// # Panics
///
/// Panics if the snapshot lock is poisoned.
pub fn set_update_all(enabled: bool) {
    let serialized = snapshot_lock().lock().unwrap();
    UPDATE_ALL.store(enabled, Ordering::SeqCst);
    drop(serialized);
}
/// Returns whether update-all mode is currently enabled.
///
/// Reads the `UPDATE_ALL` flag directly; the snapshot lock is not taken.
pub fn is_update_all() -> bool {
UPDATE_ALL.load(Ordering::SeqCst)
}
/// Turns interactive review mode on or off.
///
/// The flag is written while the global snapshot lock is held, so the change
/// cannot land in the middle of a snapshot assertion that holds the same lock.
///
/// # Panics
///
/// Panics if the snapshot lock is poisoned.
pub fn set_review_mode(enabled: bool) {
    let serialized = snapshot_lock().lock().unwrap();
    REVIEW_MODE.store(enabled, Ordering::SeqCst);
    drop(serialized);
}
/// Returns whether interactive review mode is currently enabled.
///
/// Reads the `REVIEW_MODE` flag directly; the snapshot lock is not taken.
pub fn is_review_mode() -> bool {
REVIEW_MODE.load(Ordering::SeqCst)
}
/// Resolves the directory snapshots are stored in, caching the answer.
///
/// A previously cached directory wins. Otherwise `.snapshots` and then
/// `tests/.snapshots` are probed in that order and the first one that exists
/// is used; if neither exists the result is `.snapshots`. Whatever is chosen
/// is stored in `SNAPSHOT_DIR` so later calls skip the probing.
///
/// # Panics
///
/// Panics if the `SNAPSHOT_DIR` mutex is poisoned.
fn snapshot_dir() -> PathBuf {
    const FALLBACK: &str = ".snapshots";

    let mut cached = SNAPSHOT_DIR.lock().unwrap();
    if let Some(known) = cached.as_ref() {
        return known.clone();
    }

    let resolved = [FALLBACK, "tests/.snapshots"]
        .iter()
        .map(|candidate| PathBuf::from(*candidate))
        .find(|candidate| candidate.exists())
        .unwrap_or_else(|| PathBuf::from(FALLBACK));

    *cached = Some(resolved.clone());
    resolved
}
/// Overrides the directory used by [`assert_snapshot`].
///
/// The directory is created on a best-effort basis (a creation failure is
/// ignored here) and then stored in `SNAPSHOT_DIR`, replacing any cached
/// value. The global snapshot lock is held for the whole operation.
///
/// # Panics
///
/// Panics if the snapshot lock or the `SNAPSHOT_DIR` mutex is poisoned.
pub fn set_snapshot_dir(path: impl Into<PathBuf>) {
    let _serialized = snapshot_lock().lock().unwrap();
    let target: PathBuf = path.into();
    // Deliberately best-effort: the result of creating the directory is discarded.
    let _ = std::fs::create_dir_all(&target);
    *SNAPSHOT_DIR.lock().unwrap() = Some(target);
}
/// Asserts that `value` matches the snapshot called `name` in the directory
/// returned by `snapshot_dir`.
///
/// # Panics
///
/// Panics with the message produced by `assert_snapshot_impl` when that
/// function returns an error.
pub fn assert_snapshot(name: &str, value: &dyn fmt::Display) {
    let dir = snapshot_dir();
    match assert_snapshot_impl(name, value, &dir) {
        Ok(()) => {}
        Err(msg) => panic!("{}", msg),
    }
}
/// Asserts that `value` matches the snapshot called `name` stored in `dir`.
///
/// # Panics
///
/// Panics with the message produced by `assert_snapshot_impl` when that
/// function returns an error.
pub fn assert_snapshot_in(name: &str, value: &dyn fmt::Display, dir: &Path) {
    match assert_snapshot_impl(name, value, dir) {
        Ok(()) => {}
        Err(msg) => panic!("{}", msg),
    }
}
/// Compares the rendered `value` with the snapshot file `<dir>/<name>.snap`.
///
/// Characters in `name` other than alphanumerics, `_` and `-` are replaced by
/// `_` to form the file name.
///
/// Behavior:
/// - Snapshot file missing: the file (and `dir`) is created with the rendered
///   content. Returns `Ok` in update-all mode, otherwise an error asking the
///   user to review and commit the new file.
/// - Content identical: `Ok`.
/// - Content differs, update-all mode: the file is overwritten, `Ok`.
/// - Content differs, review mode: the diff is printed to stderr and the user
///   is prompted; `y`/`yes` (case-insensitive) overwrites the file and returns
///   `Ok`, anything else returns an error.
/// - Otherwise: an error containing the diff.
///
/// # Errors
///
/// Returns a human-readable message for I/O failures and for any unaccepted
/// new or mismatching snapshot.
fn assert_snapshot_impl(name: &str, value: &dyn fmt::Display, dir: &Path) -> Result<(), String> {
    // Render before taking the global lock: `to_string` runs caller-supplied
    // `Display` code and panics if that code reports an error. Panicking while
    // holding the lock would poison it for every later snapshot call.
    let rendered = value.to_string();

    // The lock guards no data (`Mutex<()>`), so a poisoned guard is still
    // perfectly usable for serialization; recover it instead of panicking.
    let _lock = snapshot_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    let safe_name: String = name
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' || c == '-' { c } else { '_' })
        .collect();
    let snap_path = dir.join(format!("{}.snap", safe_name));

    if !snap_path.exists() {
        std::fs::create_dir_all(dir)
            .map_err(|e| format!("mkdir {:?}: {e}", dir))?;
        std::fs::write(&snap_path, &rendered)
            .map_err(|e| format!("write {:?}: {e}", snap_path))?;
        if is_update_all() {
            return Ok(());
        }
        // The file has been written, but the assertion still fails once so the
        // new snapshot gets a human review.
        return Err(format!(
            "snapshot `{}` created at {:?}.\n\
             Review the content and commit the snapshot file.\n\
             Use `--update-all` to auto-accept new snapshots.",
            name, snap_path
        ));
    }

    let existing = std::fs::read_to_string(&snap_path)
        .map_err(|e| format!("read {:?}: {e}", snap_path))?;
    if existing == rendered {
        return Ok(());
    }

    if is_update_all() {
        std::fs::write(&snap_path, &rendered)
            .map_err(|e| format!("write {:?}: {e}", snap_path))?;
        return Ok(());
    }

    let diff = simple_diff(&existing, &rendered, &snap_path);

    if is_review_mode() {
        eprintln!("\n Snapshot `{}` mismatch:", name);
        eprint!("{}", diff);
        eprint!(" Accept new snapshot? [y/N] ");
        // Make sure the prompt is visible before blocking on stdin.
        let _ = std::io::Write::flush(&mut std::io::stderr());

        let mut input = String::new();
        // A failed read (or EOF, which leaves `input` empty) counts as "no".
        let accepted = std::io::stdin().read_line(&mut input).is_ok() && {
            let answer = input.trim();
            answer.eq_ignore_ascii_case("y") || answer.eq_ignore_ascii_case("yes")
        };
        if accepted {
            std::fs::write(&snap_path, &rendered)
                .map_err(|e| format!("write {:?}: {e}", snap_path))?;
            eprintln!(" ✓ Snapshot `{}` updated.", name);
            return Ok(());
        }
        eprintln!(" ✗ Snapshot `{}` kept.", name);
        return Err(format!(
            "snapshot `{}` mismatch (rejected in review)\n{}",
            name, diff
        ));
    }

    Err(format!(
        "snapshot `{}` mismatch!\n\
         expected (snapshot)\n\
         actual (new)\n\
         {}\n\
         Rerun with `--update-all` to accept the new snapshot.",
        name, diff
    ))
}
// Unit tests for `simple_diff`, `snapshot_dir`, and the global mode flags.
//
// NOTE(review): every `simple_diff` call below passes the path `test.snap`,
// and the path is printed on each emitted diff line. `d.contains("a")` is
// therefore satisfied by the path alone ("sn-a-p") whenever any diff line is
// produced; the letters b, c and d do not occur in the path.
#[cfg(test)]
mod tests {
use super::*;
// Byte-identical inputs produce an empty diff string.
#[test]
fn simple_diff_identical() {
let d = simple_diff("hello\nworld", "hello\nworld", Path::new("test.snap"));
assert!(d.is_empty(), "identical content should have no diff");
}
// A single differing line reports both the old and the new text.
#[test]
fn simple_diff_changed_line() {
let d = simple_diff("hello", "world", Path::new("test.snap"));
assert!(d.contains("hello"));
assert!(d.contains("world"));
}
// A line present only in the new text is compared against "" and reported.
#[test]
fn simple_diff_extra_line() {
let d = simple_diff("line1", "line1\nline2", Path::new("test.snap"));
assert!(d.contains("line2"), "should show the added line, got: {d}");
}
// Only the second line differs. The unchanged line "a" is not emitted as a
// diff line; the `contains("a")` check passes because of the path (see the
// note at the top of this module).
#[test]
fn simple_diff_same_line_count_different_content() {
let d = simple_diff("a\nb", "a\nc", Path::new("test.snap"));
assert!(d.contains("a"));
assert!(d.contains("c"));
}
// Every line differs, so all four line texts appear in the output.
#[test]
fn simple_diff_all_different() {
let d = simple_diff("a\nb", "c\nd", Path::new("test.snap"));
assert!(d.contains("a"));
assert!(d.contains("c"));
assert!(d.contains("b"));
assert!(d.contains("d"));
}
// Expects the fallback directory. `snapshot_dir` reads and fills the
// process-wide `SNAPSHOT_DIR` cache and probes the file system relative to the
// current directory, so this result depends on no other directory having been
// cached and on `tests/.snapshots` not being selected.
#[test]
fn snapshot_dir_default() {
let dir = snapshot_dir();
assert_eq!(dir, Path::new(".snapshots"));
}
// Toggles the process-wide update-all flag off, on, and off again, checking
// the getter after the first two writes. The flag is explicitly set to `false`
// first, so the initial default is not what is being observed.
#[test]
fn is_update_all_default_false() {
set_update_all(false);
assert!(!is_update_all());
set_update_all(true);
assert!(is_update_all());
set_update_all(false);
}
// Identical three-line inputs produce an empty diff string.
#[test]
fn simple_diff_same_content_same_length() {
let d = simple_diff("a\nb\nc", "a\nb\nc", Path::new("test.snap"));
assert!(d.is_empty(), "identical content should produce no diff");
}
// The new text has one extra trailing line "c", which must be reported.
#[test]
fn simple_diff_same_prefix_different_length() {
let d = simple_diff("a\nb", "a\nb\nc", Path::new("test.snap"));
assert!(d.contains("c"), "should show the new line 'c' as a diff, got: {d}");
}
// Same inputs and assertions as `simple_diff_changed_line`.
#[test]
fn simple_diff_first_line_differs() {
let d = simple_diff("hello", "world", Path::new("test.snap"));
assert!(d.contains("hello"));
assert!(d.contains("world"));
}
// Toggles the process-wide review-mode flag off, on, and off again, checking
// the getter after the first two writes. The flag is explicitly set to `false`
// first, so the initial default is not what is being observed.
#[test]
fn is_review_mode_default_false() {
set_review_mode(false);
assert!(!is_review_mode());
set_review_mode(true);
assert!(is_review_mode());
set_review_mode(false);
}
// The new text ends with an extra empty line. That line compares equal to the
// "" used for the missing old line, so no per-line diff is emitted and the
// line-count summary (which contains the word "lines") is expected instead.
#[test]
fn simple_diff_line_count_diff() {
let d = simple_diff("a\nb", "a\nb\n\n", Path::new("test.snap"));
assert!(d.contains("lines"), "should show line count difference: {d}");
}
}
/// Produces a simple positional, line-by-line diff between `old` and `new`.
///
/// Lines are compared by index; a line missing on one side is treated as an
/// empty line. For every index whose lines differ, two rows are emitted, each
/// prefixed with `path`: first the old line, then the new one.
///
/// When no per-line difference is found but the inputs are not identical, a
/// one-line summary is emitted instead so the caller never gets an empty diff
/// for differing content:
/// - differing line counts (e.g. extra trailing empty lines) report both counts;
/// - otherwise the texts differ only in line terminators, which `str::lines`
///   strips (a trailing newline, or `\r\n` versus `\n`), and that is reported.
///
/// Returns an empty string only when `old == new`.
fn simple_diff(old: &str, new: &str, path: &Path) -> String {
    if old == new {
        return String::new();
    }

    let old_lines: Vec<&str> = old.lines().collect();
    let new_lines: Vec<&str> = new.lines().collect();
    let mut out = String::new();

    let max = old_lines.len().max(new_lines.len());
    for i in 0..max {
        // A side that has run out of lines is compared as an empty line.
        let old_line = old_lines.get(i).copied().unwrap_or("");
        let new_line = new_lines.get(i).copied().unwrap_or("");
        if old_line != new_line {
            out.push_str(&format!(
                " {} | {}\n {} | {}\n",
                path.display(),
                old_line,
                path.display(),
                new_line,
            ));
        }
    }

    if out.is_empty() {
        if old_lines.len() != new_lines.len() {
            out.push_str(&format!(
                " {}: snapshot has {} lines, actual has {} lines\n",
                path.display(),
                old_lines.len(),
                new_lines.len(),
            ));
        } else {
            // Same lines, same count, yet the strings differ: the difference
            // lies in terminators that `lines()` removed.
            out.push_str(&format!(
                " {}: contents differ only in line endings or a trailing newline\n",
                path.display(),
            ));
        }
    }
    out
}