use std::path::PathBuf;
use std::sync::Arc;
use ahash::HashMap;
use parking_lot::RwLock;
use serde::Deserialize;
use serde::Serialize;
use mago_interner::StringIdentifier;
use mago_interner::ThreadedInterner;
use crate::error::SourceError;
pub mod error;
/// Category attached to every [`SourceIdentifier`].
///
/// The default category is [`SourceCategory::UserDefined`]. Variant order is significant: it
/// determines the derived ordering (`BuiltIn < External < UserDefined`).
#[derive(Debug, Default, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
pub enum SourceCategory {
    /// A built-in source.
    BuiltIn,
    /// An external source.
    External,
    /// A user-defined source; this is the default category.
    #[default]
    UserDefined,
}
/// Identifies a source by its interned name together with its [`SourceCategory`].
///
/// Both fields take part in equality, hashing and ordering, so the same name under two
/// different categories yields two distinct identifiers.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
#[repr(C)]
pub struct SourceIdentifier(
    /// Interned name of the source.
    pub StringIdentifier,
    /// Category the source was registered under.
    pub SourceCategory,
);
/// A materialized source: its identity, optional file location, interned text and line index.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
pub struct Source {
    /// Identifier (interned name + category) of this source.
    pub identifier: SourceIdentifier,
    /// File the source is read from / written to; `None` for purely in-memory sources.
    pub path: Option<PathBuf>,
    /// Interned full text of the source.
    pub content: StringIdentifier,
    /// Length of the content in bytes.
    pub size: usize,
    /// Ascending byte offsets at which each line starts (computed by `line_starts`, so the
    /// first entry is `0` for sources built in this module).
    pub lines: Vec<usize>,
}
/// Implemented by values that can report the source they are associated with.
pub trait HasSource {
    /// Returns the identifier of the associated source.
    fn source(&self) -> SourceIdentifier;
}
/// Bookkeeping for one registered source inside a `SourceManager`.
#[derive(Debug)]
struct SourceEntry {
    /// File backing the source; `None` when it was registered from in-memory content.
    path: Option<PathBuf>,
    /// Cached `(interned content, byte size, line start offsets)`; `None` until the content is
    /// either supplied directly or loaded from `path`.
    content: Option<(StringIdentifier, usize, Vec<usize>)>,
}
/// Lock-protected state of a `SourceManager`.
#[derive(Debug)]
struct SourceManagerInner {
    /// Every registered source, keyed by its full identifier (name + category).
    sources: HashMap<SourceIdentifier, SourceEntry>,
    /// Maps an interned name to the identifier most recently registered under that name.
    sources_by_name: HashMap<StringIdentifier, SourceIdentifier>,
}
/// Registry of sources; the registry state is shared between clones via `Arc<RwLock<..>>`.
#[derive(Clone, Debug)]
pub struct SourceManager {
    /// Interner used for source names and contents.
    interner: ThreadedInterner,
    /// Registry state; every clone of the manager points at the same instance.
    inner: Arc<RwLock<SourceManagerInner>>,
}
impl SourceCategory {
#[inline(always)]
pub const fn is_built_in(&self) -> bool {
matches!(self, Self::BuiltIn)
}
#[inline(always)]
pub const fn is_external(&self) -> bool {
matches!(self, Self::External)
}
#[inline(always)]
pub const fn is_user_defined(&self) -> bool {
matches!(self, Self::UserDefined)
}
}
impl SourceIdentifier {
#[inline(always)]
pub fn dummy() -> Self {
Self(StringIdentifier::empty(), SourceCategory::UserDefined)
}
#[inline(always)]
pub const fn value(&self) -> StringIdentifier {
self.0
}
#[inline(always)]
pub const fn category(&self) -> SourceCategory {
self.1
}
}
impl Source {
    /// Builds a user-defined [`Source`] directly from in-memory `content`.
    ///
    /// Both `content` and `name` are interned in `interner`. The resulting source has no path.
    #[inline(always)]
    pub fn standalone(interner: &ThreadedInterner, name: &str, content: &str) -> Self {
        let lines: Vec<_> = line_starts(content).collect();
        let size = content.len();
        let content_id = interner.intern(content);

        Self {
            identifier: SourceIdentifier(interner.intern(name), SourceCategory::UserDefined),
            path: None,
            content: content_id,
            size,
            lines,
        }
    }

    /// Returns the zero-based index of the line containing byte `offset`.
    ///
    /// Offsets past the last line start map to the last line.
    ///
    /// Expects `self.lines` to be non-empty and to start at `0`, which holds for sources built
    /// by [`Source::standalone`]. An empty `lines` underflows.
    #[inline(always)]
    pub fn line_number(&self, offset: usize) -> usize {
        self.lines.binary_search(&offset).unwrap_or_else(|next_line| next_line - 1)
    }

    /// Returns the byte offset at which zero-based `line` starts, or `None` if there is no such line.
    pub fn get_line_start_offset(&self, line: usize) -> Option<usize> {
        self.lines.get(line).copied()
    }

    /// Returns the end offset of zero-based `line`.
    ///
    /// For every line but the last, this is one less than the next line's start. For the last
    /// line, it is the source size. Returns `None` for out-of-range lines, including when
    /// `lines` is empty.
    pub fn get_line_end_offset(&self, line: usize) -> Option<usize> {
        // `checked_add` avoids overflow for `line == usize::MAX`. Comparing against `len()`
        // (instead of `len() - 1`) avoids underflow when `lines` is empty.
        let next_line = line.checked_add(1)?;

        match self.lines.get(next_line) {
            Some(&next_start) => Some(next_start - 1),
            None if next_line == self.lines.len() => Some(self.size),
            None => None,
        }
    }

    /// Returns the zero-based byte column of `offset`, i.e. its distance from the start of the
    /// line reported by [`Source::line_number`].
    ///
    /// Panics if `self.lines` is empty.
    #[inline(always)]
    pub fn column_number(&self, offset: usize) -> usize {
        // Resolve the line index first. When `offset` is exactly a line start, `binary_search`
        // yields `Ok(index)`, an index into `lines` rather than an offset, so it must not be
        // subtracted from `offset` directly.
        let line_start = self.lines[self.line_number(offset)];

        offset - line_start
    }
}
impl SourceManager {
    /// Creates an empty manager that interns source names and contents with `interner`.
    #[inline(always)]
    pub fn new(interner: ThreadedInterner) -> Self {
        Self {
            interner,
            inner: Arc::new(RwLock::new(SourceManagerInner {
                sources: HashMap::default(),
                sources_by_name: HashMap::default(),
            })),
        }
    }

    /// Registers a source backed by the file at `path` without reading it.
    ///
    /// The content is read later by [`SourceManager::load`]. If a source with the same name
    /// *and* category is already registered, its identifier is returned and `path` is ignored.
    ///
    /// NOTE(review): [`SourceManager::insert_content`] deduplicates by name alone. Here, the
    /// same name under a different category creates a second entry and repoints the name
    /// lookup at it. Confirm this asymmetry is intended.
    #[inline(always)]
    pub fn insert_path(&self, name: impl AsRef<str>, path: PathBuf, category: SourceCategory) -> SourceIdentifier {
        let name_id = self.interner.intern(name.as_ref());
        let source_id = SourceIdentifier(name_id, category);

        // Check under the shared lock first; the guard is dropped at the end of this scope,
        // before the exclusive lock is requested.
        {
            let inner = self.inner.read();
            if inner.sources.contains_key(&source_id) {
                return source_id;
            }
        }

        let mut inner = self.inner.write();
        // Re-check: another thread may have registered the source between the two locks.
        if inner.sources.contains_key(&source_id) {
            return source_id;
        }

        inner.sources.insert(source_id, SourceEntry { path: Some(path), content: None });
        inner.sources_by_name.insert(name_id, source_id);

        source_id
    }

    /// Registers an in-memory source with the given `content`.
    ///
    /// If any source is already registered under `name` (regardless of its category), that
    /// source's identifier is returned, and `content` and `category` are ignored.
    #[inline(always)]
    pub fn insert_content(
        &self,
        name: impl AsRef<str>,
        content: impl AsRef<str>,
        category: SourceCategory,
    ) -> SourceIdentifier {
        let name_str = name.as_ref();
        let content_str = content.as_ref();
        let name_id = self.interner.intern(name_str);

        {
            let inner = self.inner.read();
            if let Some(&source_id) = inner.sources_by_name.get(&name_id) {
                return source_id;
            }
        }

        // Index and intern outside the exclusive lock to keep the critical section short.
        let lines: Vec<_> = line_starts(content_str).collect();
        let size = content_str.len();
        let content_id = self.interner.intern(content_str);
        let source_id = SourceIdentifier(name_id, category);

        let mut inner = self.inner.write();
        // Re-check: another thread may have registered the name between the two locks.
        if let Some(&existing) = inner.sources_by_name.get(&name_id) {
            return existing;
        }

        inner.sources.insert(source_id, SourceEntry { path: None, content: Some((content_id, size, lines)) });
        inner.sources_by_name.insert(name_id, source_id);

        source_id
    }

    /// Returns `true` if `source_id` is registered.
    #[inline(always)]
    pub fn contains(&self, source_id: &SourceIdentifier) -> bool {
        let inner = self.inner.read();
        inner.sources.contains_key(source_id)
    }

    /// Returns the identifiers of all registered sources, in no particular order.
    #[inline(always)]
    pub fn source_ids(&self) -> Vec<SourceIdentifier> {
        let inner = self.inner.read();
        inner.sources.keys().copied().collect()
    }

    /// Returns the identifiers of all registered sources in `category`.
    #[inline(always)]
    pub fn source_ids_for_category(&self, category: SourceCategory) -> Vec<SourceIdentifier> {
        let inner = self.inner.read();
        inner.sources.keys().filter(|id| id.category() == category).copied().collect()
    }

    /// Returns the identifiers of all registered sources *not* in `category`.
    #[inline(always)]
    pub fn source_ids_except_category(&self, category: SourceCategory) -> Vec<SourceIdentifier> {
        let inner = self.inner.read();
        inner.sources.keys().filter(|id| id.category() != category).copied().collect()
    }

    /// Returns the full [`Source`] for `source_id`, reading it from disk on first access.
    ///
    /// Cached content is returned as is. Otherwise the entry's file is read, and the result is
    /// cached for later calls. Invalid UTF-8 is replaced lossily (and logged) rather than
    /// rejected. If another thread populates the cache while the file is being read, the
    /// cached content wins, so every caller observes the same text.
    ///
    /// # Errors
    ///
    /// - [`SourceError::UnavailableSource`] if `source_id` is not registered, or has neither
    ///   cached content nor a path.
    /// - [`SourceError::IOError`] if reading the file fails.
    #[inline(always)]
    pub fn load(&self, source_id: &SourceIdentifier) -> Result<Source, SourceError> {
        let path = {
            let inner = self.inner.read();
            let entry = inner.sources.get(source_id).ok_or(SourceError::UnavailableSource(*source_id))?;

            if let Some((content, size, ref lines)) = entry.content {
                return Ok(Source {
                    identifier: *source_id,
                    path: entry.path.clone(),
                    content,
                    size,
                    lines: lines.clone(),
                });
            }

            entry.path.clone().ok_or(SourceError::UnavailableSource(*source_id))?
        };

        // The file is read without holding any lock.
        let bytes = std::fs::read(&path).map_err(SourceError::IOError)?;
        let content_str = match String::from_utf8(bytes) {
            Ok(s) => s,
            Err(err) => {
                let s = err.into_bytes();
                let s = String::from_utf8_lossy(&s).into_owned();

                if source_id.category().is_user_defined() {
                    tracing::debug!(
                        "Source '{}' contains invalid UTF-8 sequence; behavior is undefined.",
                        path.display()
                    );
                } else {
                    tracing::info!(
                        "Source '{}' contains invalid UTF-8 sequence; behavior is undefined.",
                        path.display()
                    );
                }

                s
            }
        };

        let lines: Vec<_> = line_starts(&content_str).collect();
        let size = content_str.len();
        let content_id = self.interner.intern(&content_str);

        let mut inner = self.inner.write();
        let entry = inner.sources.get_mut(source_id).ok_or(SourceError::UnavailableSource(*source_id))?;

        // Only fill the cache if it is still empty. If another `load` or a `write` got there
        // first, return what is cached instead of the freshly read (possibly stale) text.
        let (content, size, lines) = entry.content.get_or_insert((content_id, size, lines)).clone();

        Ok(Source { identifier: *source_id, path: entry.path.clone(), content, size, lines })
    }

    /// Replaces the content of `source_id` with `new_content` and writes it to the source's
    /// file, if it has one.
    ///
    /// Does nothing if the cached content is already identical. If writing the file fails,
    /// the previously cached content is restored (unless another writer replaced it in the
    /// meantime). This keeps the cache from claiming content that never reached disk, and
    /// ensures a retry of the same write is not skipped as a no-op.
    ///
    /// # Errors
    ///
    /// - [`SourceError::UnavailableSource`] if `source_id` is not registered.
    /// - [`SourceError::IOError`] if writing the file fails.
    #[inline(always)]
    pub fn write(&self, source_id: SourceIdentifier, new_content: impl AsRef<str>) -> Result<(), SourceError> {
        let new_content_str = new_content.as_ref();
        let new_content_id = self.interner.intern(new_content_str);
        let new_lines: Vec<_> = line_starts(new_content_str).collect();
        let new_size = new_content_str.len();

        let (path, previous) = {
            let mut inner = self.inner.write();
            let entry = inner.sources.get_mut(&source_id).ok_or(SourceError::UnavailableSource(source_id))?;

            if let Some((old_content, _, _)) = entry.content
                && old_content == new_content_id
            {
                return Ok(());
            }

            let previous = entry.content.replace((new_content_id, new_size, new_lines));
            (entry.path.clone(), previous)
        };

        let Some(path) = path else {
            return Ok(());
        };

        // The file is written without holding the lock; the caller's string is written
        // directly instead of looking the interned copy back up.
        if let Err(error) = std::fs::write(&path, new_content_str) {
            let mut inner = self.inner.write();
            if let Some(entry) = inner.sources.get_mut(&source_id)
                && matches!(entry.content, Some((current, _, _)) if current == new_content_id)
            {
                entry.content = previous;
            }

            return Err(SourceError::IOError(error));
        }

        Ok(())
    }

    /// Returns the number of registered sources.
    #[inline(always)]
    pub fn len(&self) -> usize {
        let inner = self.inner.read();
        inner.sources.len()
    }

    /// Returns `true` if no source is registered.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        let inner = self.inner.read();
        inner.sources.is_empty()
    }
}
/// A boxed value reports the same source as the value it owns.
impl<T> HasSource for Box<T>
where
    T: HasSource,
{
    #[inline(always)]
    fn source(&self) -> SourceIdentifier {
        let inner: &T = self;
        inner.source()
    }
}
/// Returns an iterator over the byte offsets at which each line of `source` starts.
///
/// The first line always starts at `0`, and every following line starts one byte after a
/// `\n`. A `\r` before the `\n` (CRLF endings) stays part of the line it terminates, so line
/// starts, and therefore column numbers, agree between LF and CRLF files. A lone `\r` is not
/// treated as a line break. A trailing `\n` yields a final (empty) line starting at
/// `source.len()`.
#[inline(always)]
fn line_starts(source: &str) -> impl Iterator<Item = usize> + '_ {
    // `\n` is ASCII and never occurs inside a multi-byte UTF-8 sequence, so `newline + 1` is
    // always a char boundary. The previous CRLF special case mapped the start to the `\n`
    // itself, which shifted every column on CRLF lines by one.
    std::iter::once(0).chain(memchr::memchr_iter(b'\n', source.as_bytes()).map(|newline| newline + 1))
}