use std::borrow::Cow;
use std::cell::Cell;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use annotate_snippets::renderer::{AnsiColor, Color, DEFAULT_TERM_WIDTH};
use annotate_snippets::{
Annotation, AnnotationKind, Group, Snippet, renderer,
};
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use yara_x_parser::Span;
use crate::SourceCode;
/// Severity level used by reports, labels, footers and patch sections.
///
/// Alias of `annotate_snippets::Level` with a `'static` lifetime.
pub type Level = annotate_snippets::Level<'static>;
/// Identifies a piece of source code registered with a [`ReportBuilder`].
///
/// IDs are handed out sequentially, starting at 0, by
/// [`ReportBuilder::register_source`].
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)]
pub struct SourceId(u32);
/// A location in source code: a byte span within the source identified by
/// `source_id`.
#[derive(PartialEq, Debug, Clone, Eq, Default)]
pub struct CodeLoc {
/// Source the span refers to. This is `None` when the location was created
/// while no source was current (see [`ReportBuilder::span_to_code_loc`]).
source_id: Option<SourceId>,
/// Byte range within the source code.
span: Span,
}
impl CodeLoc {
pub(crate) fn new(source_id: Option<SourceId>, span: Span) -> Self {
Self { source_id, span }
}
}
/// A suggested change to source code: the code at `code_loc` should be
/// replaced with `replacement`.
pub struct Patch {
/// Shared code cache, used for looking up the origin of the patched source.
code_cache: Arc<CodeCache>,
/// Location of the code being replaced.
code_loc: CodeLoc,
/// Text that replaces the code at `code_loc`.
replacement: String,
}
impl Patch {
    /// Returns the origin of the source code this patch applies to.
    ///
    /// Returns `None` in two cases:
    /// - the source was registered without an origin;
    /// - the patch's [`CodeLoc`] is not associated with any source, which
    ///   happens when it was created while no source was current.
    ///
    /// # Panics
    ///
    /// Panics if the patch refers to a source ID that is not present in the
    /// code cache.
    pub fn origin(&self) -> Option<String> {
        // Without a source ID there is no source, and therefore no origin.
        let source_id = self.code_loc.source_id?;
        self.code_cache
            .read()
            .get(&source_id)
            .expect("patch refers to a source that is not in the code cache")
            .origin
            .clone()
    }

    /// Returns the span of code that this patch replaces.
    pub fn span(&self) -> Span {
        self.code_loc.span.clone()
    }

    /// Returns the text that replaces the code at [`Patch::span`].
    pub fn replacement(&self) -> &str {
        &self.replacement
    }
}
/// A diagnostic report.
///
/// A report has a level, an identifying code, a title, labels pointing into
/// source code, footers and optional sections of suggested patches. It is
/// rendered as text through `Display` and converted to structured data
/// through `Serialize`.
#[derive(Clone)]
pub(crate) struct Report {
/// Shared cache holding the source code that labels and patches refer to.
code_cache: Arc<CodeCache>,
/// When `true` the report is rendered with `Renderer::styled()`, otherwise
/// with `Renderer::plain()`.
with_colors: bool,
/// Terminal width passed to the renderer.
max_width: usize,
/// Level of the report itself.
level: Level,
/// Identifier of the report. It is rendered as the title's id and
/// serialized as the `code` field.
code: &'static str,
/// Title of the report.
title: String,
/// Labels as `(level, location, text)`. The first label determines the
/// source code shown first when rendering.
labels: Vec<(Level, CodeLoc, String)>,
/// Footers as `(level, text)`.
footers: Vec<(Level, String)>,
/// Sections of suggested patches, rendered after the main group.
sections: Vec<Section>,
}
/// A titled group of patches inside a [`Report`].
#[derive(Clone)]
pub(crate) struct Section {
/// Level used for the section's title.
level: Level,
/// Title of the section.
title: String,
/// Patches as `(location, replacement text)`.
patches: Vec<(CodeLoc, String)>,
}
impl Report {
#[inline]
pub(crate) fn title(&self) -> &str {
self.title.as_str()
}
pub(crate) fn labels(&self) -> impl Iterator<Item = Label<'_>> {
self.labels.iter().map(|(level, code_loc, text)| {
let source_id =
code_loc.source_id.expect("CodeLoc without source ID");
let code_cache = self.code_cache.read();
let cache_entry = code_cache.get(&source_id).unwrap();
let code = &cache_entry.code;
let code_origin = cache_entry.origin.clone();
let span = code_loc.span.clone();
let (line, column) =
match byte_offset_to_line_col(code, span.start()) {
Some((line, column)) => (line, column),
None => panic!(
"can't find line and column for span {span} in code:\n{code}",
),
};
Label {
level: level_as_text(level),
code_origin,
line,
column,
span,
text,
}
})
}
#[inline]
pub(crate) fn footers(&self) -> impl Iterator<Item = Footer<'_>> {
self.footers
.iter()
.map(|(level, text)| Footer { level: level_as_text(level), text })
}
pub(crate) fn patches(&self) -> impl Iterator<Item = Patch> + use<'_> {
self.sections.iter().flat_map(|section| {
section.patches.iter().map(|(code_loc, replacement)| Patch {
code_cache: self.code_cache.clone(),
code_loc: code_loc.clone(),
replacement: replacement.clone(),
})
})
}
pub(crate) fn new_section<T: Into<String>>(
&mut self,
level: Level,
title: T,
) -> &mut Self {
self.sections.push(Section {
level,
title: title.into(),
patches: vec![],
});
self
}
pub(crate) fn patch<R: Into<String>>(
&mut self,
code_loc: CodeLoc,
replacement: R,
) -> &mut Self {
if self.sections.is_empty() {
self.new_section(Level::HELP, "consider the following change");
};
self.sections
.last_mut()
.unwrap()
.patches
.push((code_loc, replacement.into()));
self
}
}
impl Serialize for Report {
    /// Serializes the report as a struct named `report` with these fields:
    /// - `code` and `title`;
    /// - `line` and `column`, taken from the first label whose level matches
    ///   the report's level (omitted if there is no such label);
    /// - `labels` and `footers`;
    /// - `text`, the rendered report.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let labels = self.labels().collect::<Vec<_>>();
        let footers = self.footers().collect::<Vec<_>>();

        let primary = labels
            .iter()
            .find(|label| label.level == level_as_text(&self.level));

        // `serialize_struct` must receive the exact number of fields that
        // will be serialized. Five are always written: code, title, labels,
        // footers and text. `line` and `column` are added only when there is
        // a primary label. Some serde formats rely on this count being exact.
        let num_fields = if primary.is_some() { 7 } else { 5 };

        let mut s = serializer.serialize_struct("report", num_fields)?;
        s.serialize_field("code", &self.code)?;
        s.serialize_field("title", &self.title)?;
        if let Some(label) = primary {
            s.serialize_field("line", &label.line)?;
            s.serialize_field("column", &label.column)?;
        }
        s.serialize_field("labels", &labels)?;
        s.serialize_field("footers", &footers)?;
        s.serialize_field("text", &self.to_string())?;
        s.end()
    }
}
impl PartialEq for Report {
fn eq(&self, other: &Self) -> bool {
self.level.eq(&other.level)
&& self.code.eq(other.code)
&& self.title.eq(&other.title)
&& self.labels.eq(&other.labels)
&& self.footers.eq(&other.footers)
}
}
impl Eq for Report {}
impl Debug for Report {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
impl Display for Report {
    /// Renders the report using `annotate_snippets`.
    ///
    /// Rendering works as follows:
    /// - Labels become annotations over the source code they refer to.
    /// - Consecutive labels from the same source share a single snippet, and
    ///   a new snippet is started every time the source changes.
    /// - Footers are appended after the snippets.
    /// - Each patch section is rendered as an extra group that applies its
    ///   patches to the source code of the first label.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the report has no labels;
    /// - a label has no source ID;
    /// - a label's source is not in the code cache.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let code_cache = self.code_cache.read();

        let mut group = Group::with_title(
            self.level.clone().primary_title(&self.title).id(self.code),
        );

        // The first label determines the source code of the first snippet.
        let first_source_id =
            self.labels.first().and_then(|label| label.1.source_id).unwrap();
        let first_entry = code_cache.get(&first_source_id).unwrap();
        let src = first_entry.code.as_str();

        // Source of the snippet currently being built. It must be updated
        // whenever a new snippet is started. Otherwise every later label from
        // that source would open yet another snippet. A label that switches
        // back to the first source would also be annotated over the wrong
        // code.
        let mut current_source_id = first_source_id;
        let mut snippet: Snippet<'_, Annotation> = Snippet::source(src)
            .path(first_entry.origin.as_deref().unwrap_or("line"));

        for (level, label_ref, label) in &self.labels {
            let label_source_id = label_ref.source_id.unwrap();

            // The label belongs to a different source than the current
            // snippet: close the current snippet and start a new one.
            if label_source_id != current_source_id {
                let cache_entry = code_cache.get(&label_source_id).unwrap();
                group = group.element(snippet);
                snippet = Snippet::source(cache_entry.code.as_str())
                    .path(cache_entry.origin.as_deref().unwrap_or("line"));
                current_source_id = label_source_id;
            }

            // Only error labels are rendered as primary annotations.
            let annotation_kind = if matches!(level, &Level::ERROR) {
                AnnotationKind::Primary
            } else {
                AnnotationKind::Context
            };

            let span_start = label_ref.span.start();
            let span_end = label_ref.span.end();

            snippet = snippet.annotation(
                annotation_kind.span(span_start..span_end).label(label),
            );
        }

        group = group.element(snippet);

        for (level, text) in &self.footers {
            group = group.element(level.clone().message(text.as_str()));
        }

        let renderer = if self.with_colors {
            annotate_snippets::Renderer::styled()
        } else {
            annotate_snippets::Renderer::plain()
        };
        let renderer = renderer.term_width(self.max_width);

        let mut groups = vec![group];

        for section in &self.sections {
            let mut snippet = Snippet::source(src);
            for (code_loc, replacement) in &section.patches {
                snippet = snippet.patch(annotate_snippets::Patch::new(
                    code_loc.span.range(),
                    replacement,
                ));
            }
            groups.push(
                section
                    .level
                    .clone()
                    .secondary_title(&section.title)
                    .element(snippet),
            );
        }

        let text = renderer.render(&groups);
        write!(f, "{text}")
    }
}
/// A report label resolved to its position in the source code, as yielded
/// by `Report::labels`. Fields are serialized in declaration order.
#[derive(Serialize)]
pub struct Label<'a> {
/// Level name: one of "error", "warning", "info", "note" or "help".
level: &'a str,
/// Origin of the source code the label points to, if it has one.
code_origin: Option<String>,
/// 1-based line where the label's span starts.
line: usize,
/// 1-based column, counted in characters, where the label's span starts.
column: usize,
/// Byte span the label covers.
span: Span,
/// Text of the label.
text: &'a str,
}
impl Label<'_> {
    /// Origin of the labelled source code, or `None` if it has none.
    #[inline]
    pub fn origin(&self) -> Option<&str> {
        Option::as_deref(&self.code_origin)
    }

    /// Byte span covered by this label.
    #[inline]
    pub fn span(&self) -> &Span {
        let span = &self.span;
        span
    }

    /// The label's message.
    #[inline]
    pub fn text(&self) -> &str {
        let text = self.text;
        text
    }
}
/// A report footer as yielded by `Report::footers`.
#[derive(Serialize)]
pub struct Footer<'a> {
/// Level name: one of "error", "warning", "info", "note" or "help".
level: &'a str,
/// Text of the footer.
text: &'a str,
}
/// Creates [`Report`]s and keeps track of the source code they refer to.
///
/// `current_source_id` and `next_source_id` are `Cell`s so that
/// `register_source` can update them through `&self`.
pub struct ReportBuilder {
/// Whether created reports are rendered with colors.
with_colors: bool,
/// Terminal width used by created reports.
max_width: usize,
/// Source used by `span_to_code_loc` and `get_snippet`. It is set by
/// `register_source` and `set_current_source_id`.
current_source_id: Cell<Option<SourceId>>,
/// ID that the next call to `register_source` will assign.
next_source_id: Cell<SourceId>,
/// Cache of registered source code, shared with every created report.
code_cache: Arc<CodeCache>,
}
/// Source code registered in a `ReportBuilder`, keyed by `SourceId`.
///
/// The map sits behind an `RwLock` because the builder, its reports and
/// their patches all share it through an `Arc`.
struct CodeCache {
data: RwLock<HashMap<SourceId, CodeCacheEntry>>,
}
impl CodeCache {
    /// Creates a cache with no entries.
    fn new() -> Self {
        let data = RwLock::new(HashMap::new());
        Self { data }
    }

    /// Acquires a shared read lock on the cached code.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn read(
        &self,
    ) -> RwLockReadGuard<'_, HashMap<SourceId, CodeCacheEntry>> {
        RwLock::read(&self.data).unwrap()
    }

    /// Acquires an exclusive write lock on the cached code.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn write(
        &self,
    ) -> RwLockWriteGuard<'_, HashMap<SourceId, CodeCacheEntry>> {
        RwLock::write(&self.data).unwrap()
    }
}
/// A registered piece of source code.
struct CodeCacheEntry {
/// The source code with tab characters replaced by single spaces (see
/// `ReportBuilder::register_source`).
code: String,
/// Origin of the source code, copied from the registered `SourceCode`.
origin: Option<String>,
}
impl Default for ReportBuilder {
fn default() -> Self {
Self::new()
}
}
impl ReportBuilder {
    /// Creates a builder with no registered source code.
    ///
    /// Colors are disabled and the terminal width is `DEFAULT_TERM_WIDTH`.
    /// Use [`ReportBuilder::with_colors`] and [`ReportBuilder::max_width`] to
    /// change them.
    pub fn new() -> Self {
        let code_cache = Arc::new(CodeCache::new());
        Self {
            with_colors: false,
            max_width: DEFAULT_TERM_WIDTH,
            current_source_id: Cell::new(None),
            next_source_id: Cell::new(SourceId(0)),
            code_cache,
        }
    }

    /// Enables or disables colors in the reports created from now on.
    pub fn with_colors(&mut self, yes: bool) -> &mut Self {
        self.with_colors = yes;
        self
    }

    /// Sets the terminal width used by the reports created from now on.
    pub fn max_width(&mut self, width: usize) -> &mut Self {
        self.max_width = width;
        self
    }

    /// Returns the ID of the current source, if any.
    pub fn get_current_source_id(&self) -> Option<SourceId> {
        self.current_source_id.get()
    }

    /// Makes `source_id` the current source.
    pub fn set_current_source_id(&mut self, source_id: SourceId) {
        self.current_source_id.set(Some(source_id));
    }

    /// Wraps `span` in a [`CodeLoc`] that refers to the current source.
    pub fn span_to_code_loc(&self, span: Span) -> CodeLoc {
        let current = self.get_current_source_id();
        CodeLoc::new(current, span)
    }

    /// Returns a bright green foreground style when colors are enabled, or
    /// an unstyled `Style` otherwise.
    pub fn green_style(&self) -> renderer::Style {
        let plain = renderer::Style::new();
        if !self.with_colors {
            return plain;
        }
        plain.fg_color(Some(Color::Ansi(AnsiColor::BrightGreen)))
    }

    /// Registers `src` in the code cache, gives it a fresh [`SourceId`] and
    /// makes it the current source.
    ///
    /// How the stored code is derived:
    /// - `src.valid` is used when present; otherwise the raw bytes are
    ///   decoded with `String::from_utf8_lossy`.
    /// - Tab characters are then replaced with single spaces. Both are one
    ///   byte long, so this replacement keeps byte offsets unchanged.
    pub fn register_source(&self, src: &SourceCode) -> SourceId {
        let id = self.next_source_id.get();
        self.next_source_id.set(SourceId(id.0 + 1));
        self.current_source_id.set(Some(id));

        self.code_cache.write().entry(id).or_insert_with(|| {
            let text: Cow<'_, str> = match src.valid {
                Some(valid) => Cow::Borrowed(valid),
                None => String::from_utf8_lossy(src.raw.as_ref()),
            };
            CodeCacheEntry {
                code: text.replace('\t', " "),
                origin: src.origin.clone(),
            }
        });

        id
    }

    /// Returns the text covered by `span` in the current source.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - there is no current source;
    /// - the current source is not in the cache;
    /// - `span` is out of bounds or does not fall on character boundaries.
    pub fn get_snippet(&self, span: Span) -> String {
        let id = self.get_current_source_id().unwrap();
        let cache = self.code_cache.read();
        let entry = cache.get(&id).unwrap();
        entry.code.as_str()[span.range()].to_owned()
    }

    /// Creates a report sharing this builder's code cache and settings.
    ///
    /// Footers whose text is `None` are dropped.
    ///
    /// # Panics
    ///
    /// Panics if `labels` is empty.
    pub fn create_report(
        &self,
        level: Level,
        code: &'static str,
        title: String,
        labels: Vec<(Level, CodeLoc, String)>,
        footers: Vec<(Level, Option<String>)>,
    ) -> Report {
        assert!(!labels.is_empty());

        let footers: Vec<(Level, String)> = footers
            .into_iter()
            .filter_map(|(level, text)| Some((level, text?)))
            .collect();

        Report {
            code_cache: Arc::clone(&self.code_cache),
            with_colors: self.with_colors,
            max_width: self.max_width,
            level,
            code,
            title,
            labels,
            footers,
            sections: Vec::new(),
        }
    }
}
fn level_as_text(level: &Level) -> &'static str {
match *level {
Level::ERROR => "error",
Level::WARNING => "warning",
Level::INFO => "info",
Level::NOTE => "note",
Level::HELP => "help",
_ => panic!("unsupported level {level:?}"),
}
}
/// Converts a byte offset within `text` into a 1-based `(line, column)`.
///
/// Lines are separated by `'\n'`, and columns are counted in characters, not
/// bytes. An offset equal to `text.len()` maps to the position just past the
/// last character.
///
/// Returns `None` if `byte_offset` is greater than `text.len()` or does not
/// fall on a character boundary.
fn byte_offset_to_line_col(
    text: &str,
    byte_offset: usize,
) -> Option<(usize, usize)> {
    // `str::get` fails both for out-of-range offsets and for offsets that
    // land in the middle of a multi-byte character.
    let before = text.get(..byte_offset)?;

    let line = 1 + before.bytes().filter(|&b| b == b'\n').count();

    // The current line starts right after the last newline, if there is one.
    // '\n' is a single byte, so `nl + 1` is always a character boundary.
    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    let column = 1 + before[line_start..].chars().count();

    Some((line, column))
}
#[cfg(test)]
mod tests {
    use super::byte_offset_to_line_col;

    #[test]
    fn byte_offset_to_line_col_single_line() {
        let text = "Hello, World!";
        assert_eq!(byte_offset_to_line_col(text, 0), Some((1, 1)));
        assert_eq!(byte_offset_to_line_col(text, 7), Some((1, 8)));
        assert_eq!(byte_offset_to_line_col(text, 12), Some((1, 13)));
    }

    #[test]
    fn byte_offset_to_line_col_multiline() {
        let text = "Hello\nRust\nWorld!";
        // The newline itself belongs to the line it terminates.
        assert_eq!(byte_offset_to_line_col(text, 0), Some((1, 1)));
        assert_eq!(byte_offset_to_line_col(text, 5), Some((1, 6)));
        assert_eq!(byte_offset_to_line_col(text, 6), Some((2, 1)));
        assert_eq!(byte_offset_to_line_col(text, 9), Some((2, 4)));
        assert_eq!(byte_offset_to_line_col(text, 11), Some((3, 1)));
    }

    #[test]
    fn byte_offset_to_line_col_empty_string() {
        assert_eq!(byte_offset_to_line_col("", 0), Some((1, 1)));
    }

    #[test]
    fn byte_offset_to_line_col_out_of_bounds() {
        let text = "Hello, World!";
        assert_eq!(byte_offset_to_line_col(text, text.len() + 1), None);
    }

    #[test]
    fn byte_offset_to_line_col_end_of_string() {
        let text = "Hello, World!";
        // One past the last character is a valid position.
        assert_eq!(byte_offset_to_line_col(text, text.len()), Some((1, 14)));
    }

    #[test]
    fn byte_offset_to_line_col_multibyte_characters() {
        // Each of the two CJK characters is 3 bytes long but counts as a
        // single column.
        let text = "Hello, 你好!";
        assert_eq!(byte_offset_to_line_col(text, 7), Some((1, 8)));
        // Offset 8 is inside the first CJK character.
        assert_eq!(byte_offset_to_line_col(text, 8), None);
        assert_eq!(byte_offset_to_line_col(text, 10), Some((1, 9)));
        assert_eq!(byte_offset_to_line_col(text, 13), Some((1, 10)));
    }
}