use std::borrow::Cow;
use std::cell::Cell;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use annotate_snippets::renderer::{AnsiColor, Color, DEFAULT_TERM_WIDTH};
use annotate_snippets::{AnnotationKind, Group, Snippet, renderer};
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use yara_x_parser::Span;
use crate::SourceCode;
/// Severity level used by reports, labels, footers and patch sections.
///
/// This is [`annotate_snippets::Level`] fixed to the `'static` lifetime.
pub type Level = annotate_snippets::Level<'static>;
/// Identifier of a source code registered with a [`ReportBuilder`].
///
/// IDs are handed out sequentially by [`ReportBuilder::register_source`],
/// starting at 0, and are used as keys into the builder's code cache.
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)]
pub struct SourceId(u32);
/// A location in source code: a [`Span`] together with the ID of the source
/// that the span refers to.
#[derive(PartialEq, Debug, Clone, Eq, Default)]
pub struct CodeLoc {
/// Source that contains `span`. It is `None` when no source was current at
/// the time the location was created (see
/// [`ReportBuilder::span_to_code_loc`]). The rendering code in this file
/// unwraps this value, so a location without a source ID cannot be rendered.
source_id: Option<SourceId>,
/// Span of the code this location points to. Its start and end are used as
/// byte offsets into the cached source code.
span: Span,
}
impl CodeLoc {
/// Creates a location from an optional source ID and a span.
pub(crate) fn new(source_id: Option<SourceId>, span: Span) -> Self {
Self { source_id, span }
}
}
/// A suggested code change: replace the code at some location with a new
/// text. Values of this type are produced by `Report::patches`.
pub struct Patch {
/// Cache of registered sources, used for looking up the origin of the
/// source that the patch applies to.
code_cache: Arc<CodeCache>,
/// Location of the code that should be replaced.
code_loc: CodeLoc,
/// Text that should take the place of the code at `code_loc`.
replacement: String,
}
impl Patch {
    /// Returns the origin of the source code this patch applies to, if the
    /// source was registered with one.
    ///
    /// # Panics
    ///
    /// Panics if the patch location has no source ID, or if that source is
    /// not present in the code cache.
    pub fn origin(&self) -> Option<String> {
        let cache = self.code_cache.read();
        let source_id = self.code_loc.source_id.unwrap();
        let entry = cache.get(&source_id).unwrap();
        entry.origin.clone()
    }

    /// Returns a copy of the span that should be replaced.
    pub fn span(&self) -> Span {
        let span = &self.code_loc.span;
        span.clone()
    }

    /// Returns the text that should replace the code at [`Patch::span`].
    pub fn replacement(&self) -> &str {
        self.replacement.as_str()
    }
}
/// A diagnostic report (error, warning, ...) that can be rendered as text
/// through its `Display` implementation or serialized with `serde`.
#[derive(Clone)]
pub(crate) struct Report {
/// Cache of registered sources; the report reads source text and origins
/// from it when it is rendered or serialized.
code_cache: Arc<CodeCache>,
/// Whether the rendered text uses the styled (colored) renderer.
with_colors: bool,
/// Terminal width passed to the renderer.
max_width: usize,
/// Level of the report as a whole.
level: Level,
/// Code identifying the kind of report; rendered as the title's ID.
code: &'static str,
/// Title of the report.
title: String,
/// Labels attached to the report: level, code location and text.
labels: Vec<(Level, CodeLoc, String)>,
/// Footers shown after the code snippets: level and text.
footers: Vec<(Level, String)>,
/// Sections with suggested patches, rendered after the main report.
sections: Vec<Section>,
}
/// A titled group of patches. When a report is rendered, every section that
/// has at least one patch becomes an additional group after the main one.
#[derive(Clone)]
pub(crate) struct Section {
/// Level used for the section's title.
level: Level,
/// Title of the section.
title: String,
/// Patches in this section: the location to replace and the replacement.
patches: Vec<(CodeLoc, String)>,
}
impl Report {
#[inline]
pub(crate) fn title(&self) -> &str {
self.title.as_str()
}
pub(crate) fn labels(&self) -> impl Iterator<Item = Label<'_>> {
self.labels.iter().map(|(level, code_loc, text)| {
let source_id =
code_loc.source_id.expect("CodeLoc without source ID");
let code_cache = self.code_cache.read();
let cache_entry = code_cache.get(&source_id).unwrap();
let span = code_loc.span.clone();
let (line, column) = match cache_entry
.byte_offset_to_line_col(span.start())
{
Some((line, column)) => (line, column),
None => panic!(
"can't find line and column for span {span} in code:\n{}",
&cache_entry.code
),
};
Label {
level: level_as_text(level),
code_origin: cache_entry.origin.clone(),
line,
column,
span,
text,
}
})
}
#[inline]
pub(crate) fn footers(&self) -> impl Iterator<Item = Footer<'_>> {
self.footers
.iter()
.map(|(level, text)| Footer { level: level_as_text(level), text })
}
pub(crate) fn patches(&self) -> impl Iterator<Item = Patch> + use<'_> {
self.sections.iter().flat_map(|section| {
section.patches.iter().map(|(code_loc, replacement)| Patch {
code_cache: self.code_cache.clone(),
code_loc: code_loc.clone(),
replacement: replacement.clone(),
})
})
}
pub(crate) fn new_section<T: Into<String>>(
&mut self,
level: Level,
title: T,
) -> &mut Self {
self.sections.push(Section {
level,
title: title.into(),
patches: vec![],
});
self
}
pub(crate) fn patch<R: Into<String>>(
&mut self,
code_loc: CodeLoc,
replacement: R,
) -> &mut Self {
if self.sections.is_empty() {
self.new_section(Level::HELP, "consider the following change");
};
self.sections
.last_mut()
.unwrap()
.patches
.push((code_loc, replacement.into()));
self
}
}
impl Serialize for Report {
    /// Serializes the report as a struct named `report` with the fields
    /// `code`, `title`, `labels`, `footers` and `text` (the rendered report).
    ///
    /// When some label has the same level as the report itself, the first
    /// such label also contributes the `line` and `column` fields.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `Report::labels`, and when the
    /// report is rendered to produce the `text` field.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let labels = self.labels().collect::<Vec<_>>();
        let footers = self.footers().collect::<Vec<_>>();

        // The first label whose level matches the report's level provides
        // the report's line and column.
        let primary = labels
            .iter()
            .find(|label| label.level == level_as_text(&self.level));

        // `serialize_struct` must be told how many fields will follow:
        // five are always emitted, plus `line` and `column` when there is
        // a primary label.
        let num_fields = if primary.is_some() { 7 } else { 5 };

        let mut s = serializer.serialize_struct("report", num_fields)?;
        s.serialize_field("code", &self.code)?;
        s.serialize_field("title", &self.title)?;
        if let Some(label) = primary {
            s.serialize_field("line", &label.line)?;
            s.serialize_field("column", &label.column)?;
        }
        s.serialize_field("labels", &labels)?;
        s.serialize_field("footers", &footers)?;
        s.serialize_field("text", &self.to_string())?;
        s.end()
    }
}
impl PartialEq for Report {
    /// Two reports are equal when their level, code, title, labels and
    /// footers are equal.
    ///
    /// The code cache, rendering options (colors and width) and patch
    /// sections do not take part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        self.level == other.level
            && self.code == other.code
            && self.title == other.title
            && self.labels == other.labels
            && self.footers == other.footers
    }
}

impl Eq for Report {}
impl Debug for Report {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
impl Display for Report {
    /// Renders the report as text with `annotate_snippets`.
    ///
    /// The output consists of a main group (the title, one annotated code
    /// snippet per source referenced by the labels, and the footers),
    /// followed by one extra group for every section that has at least one
    /// patch.
    ///
    /// # Panics
    ///
    /// Panics if a label, or the first patch of a section, has no source ID
    /// or refers to a source that is not in the code cache.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let code_cache = self.code_cache.read();

        let mut group = Group::with_title(
            self.level.clone().primary_title(&self.title).id(self.code),
        );

        // Distinct sources referenced by the labels, in first-seen order.
        // Each of them gets its own snippet.
        let mut source_ids = Vec::new();
        for (_, label_ref, _) in &self.labels {
            let sid = label_ref.source_id.unwrap();
            if !source_ids.contains(&sid) {
                source_ids.push(sid);
            }
        }

        for source_id in source_ids {
            let cache_entry = code_cache.get(&source_id).unwrap();

            // Range of offsets covered by the labels that belong to this
            // source. Only the lines touching this range are passed to the
            // snippet. The unwraps can't fail: every source ID in
            // `source_ids` comes from at least one label.
            let min_offset = self
                .labels
                .iter()
                .filter(|l| l.1.source_id.unwrap() == source_id)
                .map(|l| l.1.span.start())
                .min()
                .unwrap();

            let max_offset = self
                .labels
                .iter()
                .filter(|l| l.1.source_id.unwrap() == source_id)
                .map(|l| l.1.span.end())
                .max()
                .unwrap();

            let (sliced_src, line_start, slice_start) =
                get_source_slice(cache_entry, min_offset, max_offset);

            let mut snippet = Snippet::source(sliced_src)
                .line_start(line_start)
                .path(cache_entry.origin.as_deref().unwrap_or("line"));

            for (level, label_ref, label) in &self.labels {
                if label_ref.source_id.unwrap() == source_id {
                    let annotation_kind = if matches!(level, &Level::ERROR) {
                        AnnotationKind::Primary
                    } else {
                        AnnotationKind::Context
                    };

                    // Spans are relative to the whole source, annotations
                    // are relative to the slice handed to the snippet.
                    let span_start =
                        label_ref.span.start().saturating_sub(slice_start);
                    let span_end =
                        label_ref.span.end().saturating_sub(slice_start);

                    snippet = snippet.annotation(
                        annotation_kind
                            .span(span_start..span_end)
                            .label(label),
                    );
                }
            }

            group = group.element(snippet);
        }

        for (level, text) in &self.footers {
            group = group.element(level.clone().message(text.as_str()));
        }

        let renderer = if self.with_colors {
            annotate_snippets::Renderer::styled()
        } else {
            annotate_snippets::Renderer::plain()
        };

        let renderer = renderer.term_width(self.max_width);

        let mut groups = vec![group];

        for section in &self.sections {
            if section.patches.is_empty() {
                continue;
            }

            // The snippet for a section is taken from the source of its
            // first patch.
            // NOTE(review): patches in one section are presumably expected
            // to share a source; verify against callers.
            let sid = section.patches[0].0.source_id.unwrap();
            let cache_entry = code_cache.get(&sid).unwrap();

            let min_offset = section
                .patches
                .iter()
                .map(|(loc, _)| loc.span.start())
                .min()
                .unwrap();

            let max_offset = section
                .patches
                .iter()
                .map(|(loc, _)| loc.span.end())
                .max()
                .unwrap();

            let (sliced_src, line_start, slice_start) =
                get_source_slice(cache_entry, min_offset, max_offset);

            let mut snippet = Snippet::source(sliced_src)
                .line_start(line_start)
                .path(cache_entry.origin.as_deref().unwrap_or("line"));

            for (code_loc, replacement) in &section.patches {
                let span_start =
                    code_loc.span.start().saturating_sub(slice_start);
                let span_end = code_loc.span.end().saturating_sub(slice_start);
                snippet = snippet.patch(annotate_snippets::Patch::new(
                    span_start..span_end,
                    replacement,
                ))
            }

            groups.push(
                section
                    .level
                    .clone()
                    .secondary_title(&section.title)
                    .element(snippet),
            );
        }

        let text = renderer.render(&groups);
        write!(f, "{text}")
    }
}
/// Returns the portion of the cached source made of the whole lines that
/// contain the offsets `min_offset` and `max_offset`, and everything in
/// between.
///
/// The result is a tuple with:
///
/// * the slice of source code, starting at the beginning of a line,
/// * the 1-based number of the first line in the slice,
/// * the byte offset of the slice within the whole source.
fn get_source_slice(
    cache_entry: &CodeCacheEntry,
    min_offset: usize,
    max_offset: usize,
) -> (&str, usize, usize) {
    let starts = cache_entry.line_starts.as_slice();

    // Index of the line containing `offset`: the last line whose start is
    // not greater than the offset.
    let line_index = |offset: usize| {
        starts.partition_point(|&start| start <= offset).saturating_sub(1)
    };

    let first_line = line_index(min_offset);
    let last_line = line_index(max_offset);

    let begin = starts[first_line];
    // The slice ends where the line after the last one starts, or at the
    // end of the code when the last line is the final one.
    let end = starts
        .get(last_line + 1)
        .copied()
        .unwrap_or(cache_entry.code.len());

    (&cache_entry.code[begin..end], first_line + 1, begin)
}
/// A label of a report, as returned by `Report::labels`, with its position
/// already resolved to a line and column.
#[derive(Serialize)]
pub struct Label<'a> {
/// Level of the label in textual form ("error", "warning", ...).
level: &'a str,
/// Origin of the source code that contains the label, if the source was
/// registered with one.
code_origin: Option<String>,
/// 1-based line number where the label's span starts.
line: usize,
/// 1-based column, counted in characters, where the label's span starts.
column: usize,
/// Span of the labelled code.
span: Span,
/// Text of the label.
text: &'a str,
}
impl Label<'_> {
/// Returns the origin of the source code that contains the label, if any.
#[inline]
pub fn origin(&self) -> Option<&str> {
self.code_origin.as_deref()
}
/// Returns the span of the labelled code.
#[inline]
pub fn span(&self) -> &Span {
&self.span
}
/// Returns the text of the label.
#[inline]
pub fn text(&self) -> &str {
self.text
}
}
/// A footer of a report, as returned by `Report::footers`.
#[derive(Serialize)]
pub struct Footer<'a> {
/// Level of the footer in textual form ("note", "help", ...).
level: &'a str,
/// Text of the footer.
text: &'a str,
}
/// Creates [`Report`] values and keeps track of the source codes they refer
/// to. All reports created by a builder share the builder's code cache.
pub struct ReportBuilder {
/// Whether reports created by this builder are rendered with colors.
with_colors: bool,
/// Terminal width used when rendering reports created by this builder.
max_width: usize,
/// ID of the source that new code locations are associated with. Kept in a
/// `Cell` because `register_source` updates it through `&self`.
current_source_id: Cell<Option<SourceId>>,
/// ID that will be assigned to the next registered source.
next_source_id: Cell<SourceId>,
/// Registered source codes, indexed by source ID.
code_cache: Arc<CodeCache>,
}
/// Map from [`SourceId`] to the cached data of each registered source,
/// protected by a read-write lock.
struct CodeCache {
/// The cached entries, behind the lock.
data: RwLock<HashMap<SourceId, CodeCacheEntry>>,
}
impl CodeCache {
/// Creates an empty cache.
fn new() -> Self {
Self { data: RwLock::new(HashMap::new()) }
}
/// Locks the cache for reading and returns the guard.
///
/// Panics if the lock is poisoned.
pub fn read(
&self,
) -> RwLockReadGuard<'_, HashMap<SourceId, CodeCacheEntry>> {
self.data.read().unwrap()
}
/// Locks the cache for writing and returns the guard.
///
/// Panics if the lock is poisoned.
pub fn write(
&self,
) -> RwLockWriteGuard<'_, HashMap<SourceId, CodeCacheEntry>> {
self.data.write().unwrap()
}
}
/// Cached data for a registered source code.
struct CodeCacheEntry {
    /// The source code itself.
    code: String,
    /// Byte offset where each line of `code` starts, in ascending order.
    /// The first element is the start of the first line.
    line_starts: Vec<usize>,
    /// Origin of the source code (as provided when it was registered), if
    /// any.
    origin: Option<String>,
}

impl CodeCacheEntry {
    /// Converts a byte offset within the code into a 1-based line number
    /// and a 1-based column number, where columns are counted in
    /// characters, not bytes.
    ///
    /// Returns `None` if the offset is past the end of the code or falls in
    /// the middle of a multi-byte character. The offset equal to the length
    /// of the code is accepted.
    fn byte_offset_to_line_col(
        &self,
        byte_offset: usize,
    ) -> Option<(usize, usize)> {
        // `str::get` fails exactly when the offset is out of bounds or is
        // not a character boundary.
        let prefix = self.code.get(..byte_offset)?;

        // Number of lines starting at or before the offset, which is also
        // the 1-based number of the line containing it.
        let line = self
            .line_starts
            .partition_point(|&start| start <= byte_offset);

        let line_start = self.line_starts[line - 1];
        let column = prefix[line_start..].chars().count() + 1;

        Some((line, column))
    }
}
impl Default for ReportBuilder {
fn default() -> Self {
Self::new()
}
}
impl ReportBuilder {
pub fn new() -> Self {
Self {
with_colors: false,
max_width: DEFAULT_TERM_WIDTH,
current_source_id: Cell::new(None),
next_source_id: Cell::new(SourceId(0)),
code_cache: Arc::new(CodeCache::new()),
}
}
pub fn with_colors(&mut self, yes: bool) -> &mut Self {
self.with_colors = yes;
self
}
pub fn max_width(&mut self, width: usize) -> &mut Self {
self.max_width = width;
self
}
pub fn get_current_source_id(&self) -> Option<SourceId> {
self.current_source_id.get()
}
pub fn set_current_source_id(&mut self, source_id: SourceId) {
self.current_source_id.set(Some(source_id));
}
pub fn span_to_code_loc(&self, span: Span) -> CodeLoc {
CodeLoc::new(self.get_current_source_id(), span)
}
pub fn green_style(&self) -> renderer::Style {
if self.with_colors {
renderer::Style::new()
.fg_color(Some(Color::Ansi(AnsiColor::BrightGreen)))
} else {
renderer::Style::new()
}
}
pub fn register_source(&self, src: &SourceCode) -> SourceId {
let source_id = self.next_source_id.get();
self.next_source_id.set(SourceId(source_id.0 + 1));
self.current_source_id.set(Some(source_id));
self.code_cache.write().entry(source_id).or_insert_with(|| {
let s = if let Some(s) = src.valid {
Cow::Borrowed(s)
} else {
String::from_utf8_lossy(src.raw.as_ref())
};
let code = s.replace('\t', " ");
let line_starts = compute_line_starts(&code);
CodeCacheEntry {
code,
line_starts,
origin: src.origin.clone(),
}
});
source_id
}
pub fn get_snippet(&self, span: Span) -> String {
let source_id = self.get_current_source_id().unwrap();
let code_cache = self.code_cache.read();
let cache_entry = code_cache.get(&source_id).unwrap();
let src = cache_entry.code.as_str();
src[span.range()].to_string()
}
pub fn create_report(
&self,
level: Level,
code: &'static str,
title: String,
labels: Vec<(Level, CodeLoc, String)>,
footers: Vec<(Level, Option<String>)>,
) -> Report {
assert!(!labels.is_empty());
let footers = footers
.into_iter()
.filter_map(|(level, text)| text.map(|text| (level, text)))
.collect();
Report {
code_cache: self.code_cache.clone(),
with_colors: self.with_colors,
max_width: self.max_width,
level,
code,
title,
labels,
footers,
sections: Vec::new(),
}
}
}
fn level_as_text(level: &Level) -> &'static str {
match *level {
Level::ERROR => "error",
Level::WARNING => "warning",
Level::INFO => "info",
Level::NOTE => "note",
Level::HELP => "help",
_ => panic!("unsupported level {level:?}"),
}
}
/// Returns the byte offsets at which the lines of `text` start.
///
/// The first line always starts at offset 0, and every `\n` starts a new
/// line at the offset right after it, so the result is never empty and a
/// trailing newline yields a final offset equal to `text.len()`.
fn compute_line_starts(text: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(text.match_indices('\n').map(|(newline_pos, _)| newline_pos + 1))
        .collect()
}
#[cfg(test)]
mod tests {
    use crate::compiler::report::{CodeCacheEntry, compute_line_starts};

    /// Builds a cache entry for `text` and converts `offset` into a
    /// `(line, column)` pair with `byte_offset_to_line_col`.
    fn helper(text: &str, offset: usize) -> Option<(usize, usize)> {
        let entry = CodeCacheEntry {
            line_starts: compute_line_starts(text),
            code: String::from(text),
            origin: None,
        };
        entry.byte_offset_to_line_col(offset)
    }

    /// Checks that every offset in `cases` maps to the expected result.
    fn check(text: &str, cases: &[(usize, Option<(usize, usize)>)]) {
        for &(offset, expected) in cases {
            assert_eq!(helper(text, offset), expected, "offset {offset}");
        }
    }

    #[test]
    fn byte_offset_to_line_col_single_line() {
        check(
            "Hello, World!",
            &[(0, Some((1, 1))), (7, Some((1, 8))), (12, Some((1, 13)))],
        );
    }

    #[test]
    fn byte_offset_to_line_col_multiline() {
        check(
            "Hello\nRust\nWorld!",
            &[
                (0, Some((1, 1))),
                // The newline itself still belongs to the first line.
                (5, Some((1, 6))),
                (6, Some((2, 1))),
                (9, Some((2, 4))),
                (11, Some((3, 1))),
            ],
        );
    }

    #[test]
    fn byte_offset_to_line_col_empty_string() {
        check("", &[(0, Some((1, 1)))]);
    }

    #[test]
    fn byte_offset_to_line_col_out_of_bounds() {
        let text = "Hello, World!";
        check(text, &[(text.len() + 1, None)]);
    }

    #[test]
    fn byte_offset_to_line_col_end_of_string() {
        let text = "Hello, World!";
        check(text, &[(text.len(), Some((1, 14)))]);
    }

    #[test]
    fn byte_offset_to_line_col_multibyte_characters() {
        check(
            "Hello, 你好!",
            &[
                (7, Some((1, 8))),
                // Offset 8 is in the middle of a three-byte character.
                (8, None),
                (10, Some((1, 9))),
                (13, Some((1, 10))),
            ],
        );
    }
}