use std::collections::HashMap;
use std::fmt;
use std::path::PathBuf;
use std::sync::{Arc, LazyLock, RwLock};
use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet};
thread_local! {
    /// Read when an `ErrorReport` is rendered: while `true`, plain (unstyled)
    /// output is used. Set by [`PlainOutputGuard`].
    static PLAIN_OUTPUT: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
    /// Number of live guards created via `PlainOutputGuard::new` on this thread,
    /// so that dropping a nested guard does not clear `PLAIN_OUTPUT` while an
    /// outer guard is still alive.
    static PLAIN_OUTPUT_DEPTH: std::cell::Cell<u32> = const { std::cell::Cell::new(0) };
}

/// RAII guard that forces plain (unstyled) report rendering on the current thread.
///
/// Guards may be nested: plain output stays enabled until the last live guard
/// on this thread is dropped.
#[doc(hidden)]
pub struct PlainOutputGuard;

impl Default for PlainOutputGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl PlainOutputGuard {
    /// Enables plain output on this thread until every live guard has been dropped.
    pub fn new() -> Self {
        PLAIN_OUTPUT_DEPTH.with(|depth| depth.set(depth.get().saturating_add(1)));
        PLAIN_OUTPUT.with(|c| c.set(true));
        PlainOutputGuard
    }
}

impl Drop for PlainOutputGuard {
    fn drop(&mut self) {
        // `saturating_sub` tolerates a guard built from the unit literal
        // `PlainOutputGuard` (bypassing `new`), which never incremented the depth.
        let remaining = PLAIN_OUTPUT_DEPTH.with(|depth| {
            let left = depth.get().saturating_sub(1);
            depth.set(left);
            left
        });
        PLAIN_OUTPUT.with(|c| c.set(remaining > 0));
    }
}
/// Process-wide cache of source file contents, keyed by path.
///
/// Values are `Arc<str>` so repeated lookups of the same file share one
/// allocation instead of re-reading it from disk.
static SOURCE_CACHE: LazyLock<RwLock<HashMap<PathBuf, Arc<str>>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));

/// Returns the contents of `path`, reading it from disk at most once per process.
///
/// Returns `None` when the file cannot be read as UTF-8 text. Failed reads are
/// not cached, so a later call retries.
///
/// The cache only ever holds immutable `Arc<str>` values, so a lock poisoned by
/// a panic elsewhere cannot leave it inconsistent. The guard is recovered rather
/// than bypassing the cache.
fn cached_source(path: &std::path::Path) -> Option<Arc<str>> {
    {
        let cache = SOURCE_CACHE
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(content) = cache.get(path) {
            return Some(Arc::clone(content));
        }
    }
    // Read outside any lock so slow I/O does not block other readers.
    let content: Arc<str> = std::fs::read_to_string(path).ok()?.into();
    let mut cache = SOURCE_CACHE
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // Another thread may have inserted this path while we were reading.
    // Keep and return the first entry so all callers share one allocation.
    Some(Arc::clone(cache.entry(path.to_path_buf()).or_insert(content)))
}
/// Joins `file_path` onto `manifest_dir`, removing any trailing components of
/// `manifest_dir` that `file_path` already starts with.
///
/// Example: `("/ws/crates/foo", "crates/foo/src/lib.rs")` resolves to
/// `/ws/crates/foo/src/lib.rs`. The largest such overlap is removed; with no
/// overlap the file path is simply joined onto the manifest directory. An
/// absolute `file_path` is returned unchanged, per `Path::join`.
///
/// NOTE(review): presumably `file_path` comes from `file!()` (workspace-relative)
/// and `manifest_dir` from `CARGO_MANIFEST_DIR` — confirm against callers.
fn absolute_source_path(manifest_dir: &str, file_path: &str) -> PathBuf {
    use std::path::{Component, Path};

    let file = Path::new(file_path);
    let dir_parts: Vec<Component<'_>> = Path::new(manifest_dir).components().collect();
    let file_parts: Vec<Component<'_>> = file.components().collect();

    // Largest k such that the last k components of the manifest dir equal the
    // first k components of the file path (0 when nothing lines up).
    let widest = dir_parts.len().min(file_parts.len());
    let shared = (1..=widest)
        .rev()
        .find(|&k| dir_parts[dir_parts.len() - k..] == file_parts[..k])
        .unwrap_or(0);

    let root: PathBuf = dir_parts[..dir_parts.len() - shared].iter().collect();
    root.join(file)
}
/// Converts a `(line, col)` position in `source` to a byte offset.
///
/// `line` is 1-based, and `line == 0` yields offset 0. Lines are split on
/// `'\n'` only, so a `'\r'` before the newline counts toward that line's length.
/// `col` is added to the line's starting byte offset as-is.
///
/// NOTE(review): if callers pass proc-macro columns (counted in chars, not
/// bytes), the offset drifts on lines containing non-ASCII text — confirm.
///
/// The result is clamped to `source.len()` and rounded down to the nearest
/// UTF-8 character boundary, so slicing `source` at it never panics.
fn byte_offset_of(source: &str, line: u32, col: u32) -> usize {
    if line == 0 {
        return 0;
    }
    // Each preceding line contributes its length plus the '\n' separator.
    let line_start: usize = source
        .split('\n')
        .take((line - 1) as usize)
        .map(|l| l.len() + 1)
        .sum();
    let mut offset = line_start.saturating_add(col as usize).min(source.len());
    // Never return an offset inside a multi-byte character; 0 is always a boundary.
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }
    offset
}
/// One recorded mismatch: the rendered actual value, an optional rendered
/// expected value, and the pattern node whose check failed.
#[derive(Debug, Clone)]
// NOTE(review): every field is read by `error_label` in this file; the reason for
// this allow is not visible here — confirm whether it is still needed.
#[allow(dead_code)]
struct ErrorContext {
    /// Rendered actual value; appears as the "got ..." part of the error label.
    actual_value: String,
    /// Rendered expected value. `error_label` only uses it for `==` comparisons,
    /// printing `?` when it is `None`.
    expected_value: Option<String>,
    /// Pattern node that failed; its line/column span positions the annotation.
    error_node: &'static PatternNode,
}
/// Mismatches collected for one assertion.
///
/// Its `Display` impl renders them as an annotated source snippet, or as a
/// plain `path:line` list when the source file cannot be read.
#[derive(Debug)]
pub struct ErrorReport {
    /// Recorded mismatches, in the order they were pushed.
    errors: Vec<ErrorContext>,
    /// Resolved path used to read the source file for the snippet.
    abs_path: PathBuf,
    /// Path exactly as passed to `ErrorReport::new`; shown in the rendered output.
    rel_path: String,
}
/// One node of a pattern tree, together with the source span it occupies.
///
/// NOTE(review): all references are `'static`, so these are presumably emitted
/// as static data by the `assert_struct!` macro expansion — confirm.
pub struct PatternNode {
    /// What this node matches, plus the text used to describe it.
    pub kind: NodeKind,
    /// Enclosing pattern node; presumably `None` for the outermost pattern.
    /// The `Debug` impl prints only a placeholder for it.
    pub parent: Option<&'static PatternNode>,
    /// Line where the pattern starts. `byte_offset_of` treats lines as 1-based
    /// (0 maps to offset 0).
    pub line_start: u32,
    /// Column on `line_start` where the pattern starts. `byte_offset_of` adds it
    /// to the line's byte offset directly.
    /// NOTE(review): proc-macro spans count columns in chars — confirm the units.
    pub col_start: u32,
    /// Line where the pattern ends (same numbering as `line_start`).
    pub line_end: u32,
    /// Column on `line_end` where the pattern ends (same units as `col_start`).
    pub col_end: u32,
}
/// What a [`PatternNode`] matches, plus the pattern source text needed to
/// describe it in error messages.
#[derive(Debug)]
pub enum NodeKind {
    /// Slice pattern, displayed as `[...]`. When `rest` is `false`, the error
    /// label reports the exact number of `items`.
    /// NOTE(review): `rest` presumably marks a `..` in the pattern — confirm.
    Slice {
        items: &'static [&'static PatternNode],
        rest: bool,
    },
    /// Set pattern, displayed as `#(...)`. When `rest` is `false`, the error
    /// label calls the mismatch "(exact)".
    Set {
        items: &'static [&'static PatternNode],
        rest: bool,
    },
    /// Tuple pattern with one node per element.
    Tuple {
        items: &'static [&'static PatternNode],
    },
    /// Map pattern, displayed as `#{ N entries }`. Each entry pairs a string
    /// (presumably the key's source text) with the pattern for its value.
    /// NOTE(review): `rest` presumably allows unlisted keys — confirm.
    Map {
        entries: &'static [(&'static str, &'static PatternNode)],
        rest: bool,
    },
    /// Struct pattern, displayed as `name { ... }`, with (field name, pattern) pairs.
    Struct {
        name: &'static str,
        fields: &'static [(&'static str, &'static PatternNode)],
        rest: bool,
    },
    /// Enum variant pattern, displayed as `path`, or as `path(...)` when
    /// `args` is `Some`.
    EnumVariant {
        path: &'static str,
        args: Option<&'static [&'static PatternNode]>,
    },
    /// Plain value pattern, displayed verbatim.
    Simple {
        value: &'static str,
    },
    /// Comparison pattern, displayed as `<op> <value>` (e.g. `>= 5`).
    Comparison {
        op: ComparisonOp,
        value: &'static str,
    },
    /// Range pattern, displayed verbatim.
    Range {
        pattern: &'static str,
    },
    /// Regex pattern, displayed as `=~ <pattern>`.
    Regex {
        pattern: &'static str,
    },
    /// `=~` pattern with an expression on the right, displayed as `=~ <expr>`.
    Like {
        expr: &'static str,
    },
    /// The `_` wildcard.
    Wildcard,
    /// Closure predicate, displayed as the closure's source text.
    Closure {
        closure: &'static str,
    },
}
/// Operator of a comparison pattern such as `< 5` or `!= 0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComparisonOp {
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
}

impl ComparisonOp {
    /// Returns the operator as it is spelled in Rust source, e.g. `"<="`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Less => "<",
            Self::LessEqual => "<=",
            Self::Greater => ">",
            Self::GreaterEqual => ">=",
            Self::Equal => "==",
            Self::NotEqual => "!=",
        }
    }
}
impl ErrorReport {
    /// Creates an empty report for an assertion located in `file_path`.
    ///
    /// `file_path` is kept verbatim for display, and is resolved against
    /// `manifest_dir` (via `absolute_source_path`) to find the source file
    /// whose snippet is rendered.
    pub fn new(manifest_dir: &str, file_path: &str) -> Self {
        let abs_path = absolute_source_path(manifest_dir, file_path);
        Self {
            errors: Vec::new(),
            abs_path,
            rel_path: file_path.to_owned(),
        }
    }

    /// Creates an empty report with no associated source file.
    ///
    /// NOTE(review): presumably used to check whether a pattern matches without
    /// ever rendering the report — confirm against callers.
    pub fn new_probe() -> Self {
        Self {
            errors: Vec::new(),
            abs_path: PathBuf::new(),
            rel_path: String::new(),
        }
    }

    /// Returns `true` when no mismatch has been recorded.
    pub fn is_empty(&self) -> bool {
        self.errors.is_empty()
    }

    /// Records a mismatch at `error_node`, with the rendered `actual` value and
    /// an optional rendered `expected` value.
    pub fn push(
        &mut self,
        error_node: &'static PatternNode,
        actual: String,
        expected: Option<String>,
    ) {
        let context = ErrorContext {
            error_node,
            actual_value: actual,
            expected_value: expected,
        };
        self.errors.push(context);
    }
}
/// Builds the one-line annotation text for a recorded mismatch.
///
/// The wording depends on the failing node's kind. Every label ends with
/// `got <actual value>`.
fn error_label(error: &ErrorContext) -> String {
    let actual = &error.actual_value;
    match &error.error_node.kind {
        NodeKind::Comparison {
            op: ComparisonOp::Equal,
            ..
        } => {
            let expected = error.expected_value.as_deref().unwrap_or("?");
            format!("expected {expected}, got {actual}")
        }
        NodeKind::EnumVariant { .. } => {
            format!("expected variant {}, got {actual}", error.error_node)
        }
        NodeKind::Slice { rest: true, .. } => {
            format!("slice pattern mismatch, got {actual}")
        }
        NodeKind::Slice { items, rest: false } => {
            let count = items.len();
            let noun = match count {
                1 => "element",
                _ => "elements",
            };
            format!("expected slice with {count} {noun}, got {actual}")
        }
        NodeKind::Set { rest: true, .. } => {
            format!("set pattern mismatch, got {actual}")
        }
        NodeKind::Set { rest: false, .. } => {
            format!("set pattern mismatch (exact), got {actual}")
        }
        NodeKind::Closure { .. } => {
            format!("closure condition not satisfied, got {actual}")
        }
        // Listed explicitly (no `_`) so adding a NodeKind variant forces a
        // decision about its label.
        NodeKind::Struct { .. }
        | NodeKind::Tuple { .. }
        | NodeKind::Map { .. }
        | NodeKind::Simple { .. }
        | NodeKind::Comparison { .. }
        | NodeKind::Range { .. }
        | NodeKind::Regex { .. }
        | NodeKind::Like { .. }
        | NodeKind::Wildcard => format!("got {actual}"),
    }
}
impl fmt::Display for ErrorReport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.errors.is_empty() {
return Ok(());
}
let source_content = cached_source(&self.abs_path);
let labels: Vec<String> = self.errors.iter().map(error_label).collect();
let renderer = if PLAIN_OUTPUT.with(|c| c.get())
|| std::env::var_os("NO_COLOR").is_some()
|| !std::io::IsTerminal::is_terminal(&std::io::stderr())
{
Renderer::plain()
} else {
Renderer::styled()
};
if let Some(source) = &source_content {
let annotations: Vec<_> = self
.errors
.iter()
.zip(labels.iter())
.map(|(error, label)| {
let start = byte_offset_of(
source,
error.error_node.line_start,
error.error_node.col_start,
);
let end =
byte_offset_of(source, error.error_node.line_end, error.error_node.col_end)
.max(start + 1);
AnnotationKind::Primary
.span(start..end)
.label(label.as_str())
})
.collect();
let snippet = Snippet::source(&**source)
.line_start(1)
.path(&self.rel_path)
.annotations(annotations);
let report = Level::ERROR
.primary_title("assert_struct! failed")
.element(snippet);
write!(f, "{}", renderer.render(&[report]))?;
} else {
write!(f, "assert_struct! failed:")?;
for (error, label) in self.errors.iter().zip(labels.iter()) {
write!(
f,
"\n --> {}:{}\n {label}",
self.rel_path, error.error_node.line_start
)?;
}
}
Ok(())
}
}
impl fmt::Debug for PatternNode {
    /// Formats the node like a derived `Debug`, except that `parent` is shown
    /// only as `Some("<parent>")` / `None`.
    ///
    /// The parent's `kind` presumably contains this node among its children,
    /// so printing the parent in full would recurse indefinitely.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            kind,
            parent,
            line_start,
            col_start,
            line_end,
            col_end,
        } = self;
        let parent_marker: Option<&str> = parent.as_ref().map(|_| "<parent>");
        let mut builder = f.debug_struct("PatternNode");
        builder.field("kind", kind);
        builder.field("parent", &parent_marker);
        builder.field("line_start", line_start);
        builder.field("col_start", col_start);
        builder.field("line_end", line_end);
        builder.field("col_end", col_end);
        builder.finish()
    }
}
/// Renders a short, human-readable summary of the pattern at this node.
///
/// Container patterns are abbreviated (`[...]`, `#(...)`, `Name { ... }`,
/// `#{ N entries }`) rather than rendered recursively. Leaf patterns print
/// their source text.
impl fmt::Display for PatternNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            NodeKind::Struct { name, .. } => write!(f, "{name} {{ ... }}"),
            NodeKind::Slice { .. } => f.write_str("[...]"),
            NodeKind::Set { .. } => f.write_str("#(...)"),
            NodeKind::Tuple { items } => {
                // One `..` per element, comma-separated without a dangling
                // separator. A 1-tuple keeps Rust's trailing comma: `(..,)`.
                let body = vec![".."; items.len()].join(", ");
                let trailing = if items.len() == 1 { "," } else { "" };
                write!(f, "({body}{trailing})")
            }
            NodeKind::Map { entries, .. } => write!(f, "#{{ {} entries }}", entries.len()),
            NodeKind::EnumVariant {
                path,
                args: Some(_),
            } => write!(f, "{path}(...)"),
            NodeKind::EnumVariant { path, args: None } => f.write_str(path),
            NodeKind::Simple { value } => f.write_str(value),
            NodeKind::Comparison { op, value } => write!(f, "{} {value}", op.as_str()),
            NodeKind::Range { pattern } => f.write_str(pattern),
            NodeKind::Regex { pattern } => write!(f, "=~ {pattern}"),
            NodeKind::Like { expr } => write!(f, "=~ {expr}"),
            NodeKind::Wildcard => f.write_str("_"),
            NodeKind::Closure { closure } => f.write_str(closure),
        }
    }
}