use std::{
fs::File,
io::{Read, Seek},
path::Path,
time::Instant,
};
use super::{LoadStats, Loader, LoaderConfig, utils};
use crate::{Result, error};
use crate::{
StatementData,
ast::Statement,
syn::{Parser, Token},
};
use indexmap::IndexMap;
use logos::Logos;
use snafu::ensure;
/// Loader that parses source files into named module statements.
///
/// Modules are kept in insertion order (`IndexMap`); loading into a module
/// name that already exists merges the new statements into the existing one.
pub struct StandardLoader {
    /// Installed modules keyed by module name; `"main"` is the one `read` returns.
    modules: IndexMap<String, Statement>,
    /// Loader switches (collision policy, validation, macro resolution, ...).
    config: LoaderConfig,
    /// Counters updated as files are parsed and modules are created.
    stats: LoadStats,
    /// Parsed files keyed by the path they were loaded from, exactly as given
    /// (paths are not canonicalized).
    file_cache: IndexMap<std::path::PathBuf, Statement>,
}
impl Default for StandardLoader {
fn default() -> Self {
Self::new(LoaderConfig::default())
}
}
impl StandardLoader {
pub fn new(config: LoaderConfig) -> Self {
Self {
modules: IndexMap::new(),
config,
stats: LoadStats::new(),
file_cache: IndexMap::new(),
}
}
pub fn builder() -> StandardLoaderBuilder {
StandardLoaderBuilder::new()
}
pub fn stats(&self) -> &LoadStats {
&self.stats
}
pub fn clear_cache(&mut self) {
self.file_cache.clear();
}
pub fn cache_size(&self) -> usize {
self.file_cache.len()
}
fn merge_statements(
left: &mut Statement,
right: &Statement,
allow_collisions: bool,
) -> Result<()> {
match &right.data {
StatementData::Group(right_stmts) | StatementData::Labeled(_, right_stmts) => {
match &mut left.data {
StatementData::Group(left_stmts) | StatementData::Labeled(_, left_stmts) => {
let additional_capacity =
right_stmts.len().saturating_sub(left_stmts.len());
if additional_capacity > 0 {
left_stmts.reserve(additional_capacity);
}
for (key, value) in right_stmts {
if let Some(target) = left_stmts.get_mut(key) {
Self::merge_statements(target, value, allow_collisions)?;
} else {
left_stmts.insert(key.clone(), value.clone());
}
}
}
_ => {
ensure!(
allow_collisions,
error::CollisionSnafu {
left_id: left.id.clone(),
left_location: left.meta.location.clone(),
right_id: right.id.clone(),
right_location: right.meta.location.clone()
}
);
*left = right.clone();
}
}
}
StatementData::Single(_) => {
ensure!(
allow_collisions,
error::CollisionSnafu {
left_id: left.id.clone(),
left_location: left.meta.location.clone(),
right_id: right.id.clone(),
right_location: right.meta.location.clone()
}
);
*left = right.clone();
}
}
Ok(())
}
fn parse_file<R>(
&mut self,
name: &str,
code: &mut R,
filename: Option<String>,
) -> Result<Statement>
where
R: Read + Seek,
{
let start_time = Instant::now();
let filename = filename.unwrap_or_else(|| name.to_string());
let mut module_code = String::new();
code.read_to_string(&mut module_code)
.map_err(|e| error::Error::Io {
reason: format!("Failed to read file '{}': {}", filename, e),
})?;
if module_code.trim().is_empty() {
return Err(error::Error::Io {
reason: format!("File '{}' is empty or contains only whitespace", filename),
});
}
let lexer = Token::lexer(&module_code);
let mut parser = Parser::new(&filename, lexer);
let module = parser.parse().map_err(|e| {
error::Error::Io {
reason: format!("Failed to parse file '{}': {}", filename, e),
}
})?;
self.stats.files_processed += 1;
self.stats.processing_time_ms += start_time.elapsed().as_millis() as u64;
if self.config.validate_on_load {
module.validate().map_err(|e| error::Error::Io {
reason: format!("Validation failed for file '{}': {}", filename, e),
})?;
}
Ok(module)
}
pub fn add_module<R>(
&mut self,
name: &str,
code: &mut R,
filename: Option<String>,
) -> Result<&mut Self>
where
R: Read + Seek,
{
let module = self.parse_file(name, code, filename)?;
if let Some(existing) = self.modules.get_mut(name) {
Self::merge_statements(existing, &module, self.config.allow_collisions)?;
} else {
self.modules.insert(name.to_string(), module);
self.stats.modules_created += 1;
}
Ok(self)
}
pub fn import<P>(&mut self, path: P) -> Result<&mut Self>
where
P: AsRef<Path>,
{
let path = path.as_ref();
utils::validate_path(path)?;
let name = utils::basename(path)?;
if let Some(cached_module) = self.file_cache.get(path) {
if let Some(existing) = self.modules.get_mut(&name) {
Self::merge_statements(existing, cached_module, self.config.allow_collisions)?;
} else {
self.modules.insert(name, cached_module.clone());
self.stats.modules_created += 1;
}
return Ok(self);
}
let mut file = File::open(path).map_err(|e| error::Error::Io {
reason: format!("Failed to open file '{}': {}", path.display(), e),
})?;
let module = self.parse_file(&name, &mut file, Some(name.clone()))?;
self.file_cache.insert(path.to_path_buf(), module.clone());
if let Some(existing) = self.modules.get_mut(&name) {
Self::merge_statements(existing, &module, self.config.allow_collisions)?;
} else {
self.modules.insert(name, module);
self.stats.modules_created += 1;
}
Ok(self)
}
pub fn add_file<P>(&mut self, path: P) -> Result<&mut Self>
where
P: AsRef<Path>,
{
let path = path.as_ref();
utils::validate_path(path)?;
let name = utils::basename(path)?;
if let Some(cached_module) = self.file_cache.get(path) {
if let Some(existing) = self.modules.get_mut("main") {
Self::merge_statements(existing, cached_module, self.config.allow_collisions)?;
} else {
self.modules
.insert("main".to_string(), cached_module.clone());
self.stats.modules_created += 1;
}
return Ok(self);
}
let mut file = File::open(path).map_err(|e| error::Error::Io {
reason: format!("Failed to open file '{}': {}", path.display(), e),
})?;
let module = self.parse_file("main", &mut file, Some(name))?;
self.file_cache.insert(path.to_path_buf(), module.clone());
if let Some(existing) = self.modules.get_mut("main") {
Self::merge_statements(existing, &module, self.config.allow_collisions)?;
} else {
self.modules.insert("main".to_string(), module);
self.stats.modules_created += 1;
}
Ok(self)
}
pub fn import_dir<P>(&mut self, path: P) -> Result<&mut Self>
where
P: AsRef<Path>,
{
let path = path.as_ref();
utils::validate_path(path)?;
let files = utils::discover_files(path)?;
for file in &files {
self.import(file)?;
}
Ok(self)
}
pub fn add_dir<P>(&mut self, path: P) -> Result<&mut Self>
where
P: AsRef<Path>,
{
let path = path.as_ref();
utils::validate_path(path)?;
let files = utils::discover_files(path)?;
for file in &files {
self.add_file(file)?;
}
Ok(self)
}
pub fn main<P>(&mut self, name: &str, search_paths: Vec<P>) -> Result<&mut Self>
where
P: AsRef<Path>,
{
for path in &search_paths {
let base_path = path.as_ref();
let file_path = base_path.join(name).with_extension("bml");
if file_path.exists() && file_path.is_file() {
return self.add_file(file_path);
}
let dir_path = base_path.join(name).with_extension("d");
if dir_path.exists() && dir_path.is_dir() {
return self.add_dir(dir_path);
}
}
Err(error::Error::Search {
name: name.to_string(),
search_paths: search_paths
.iter()
.map(|x| x.as_ref().to_path_buf())
.collect(),
})
}
pub fn module_names(&self) -> Vec<&String> {
self.modules.keys().collect()
}
pub fn get_module(&self, name: &str) -> Option<&Statement> {
self.modules.get(name)
}
pub fn has_module(&self, name: &str) -> bool {
self.modules.contains_key(name)
}
pub fn remove_module(&mut self, name: &str) -> Option<Statement> {
self.modules.shift_remove(name)
}
}
impl Loader for StandardLoader {
fn is_resolution_enabled(&self) -> bool {
self.config.resolve_macros
}
fn skip_macro_resolution(&mut self) -> Result<&mut Self> {
self.config.resolve_macros = false;
Ok(self)
}
fn read(&self) -> Result<Statement> {
self.modules
.get("main")
.cloned()
.ok_or(error::Error::NoMain)
}
}
/// Fluent builder that accumulates a `LoaderConfig` for a [`StandardLoader`].
pub struct StandardLoaderBuilder {
    /// Configuration handed to `StandardLoader::new` by `build`.
    config: LoaderConfig,
}
impl StandardLoaderBuilder {
    /// Starts a builder from `LoaderConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: LoaderConfig::default(),
        }
    }

    /// Sets `resolve_macros`, the flag reported by `Loader::is_resolution_enabled`.
    #[must_use]
    pub fn resolve_macros(mut self, resolve: bool) -> Self {
        self.config.resolve_macros = resolve;
        self
    }

    /// Sets whether clashing definitions are overwritten by the later one
    /// (`true`) or rejected with a collision error (`false`) when merging modules.
    #[must_use]
    pub fn allow_collisions(mut self, allow: bool) -> Self {
        self.config.allow_collisions = allow;
        self
    }

    /// Sets `max_recursion_depth`.
    ///
    /// NOTE(review): not read by `StandardLoader` itself; presumably consumed
    /// by macro resolution elsewhere — confirm.
    #[must_use]
    pub fn max_recursion_depth(mut self, depth: usize) -> Self {
        self.config.max_recursion_depth = depth;
        self
    }

    /// When `true`, every parsed file is validated immediately after parsing.
    #[must_use]
    pub fn validate_on_load(mut self, validate: bool) -> Self {
        self.config.validate_on_load = validate;
        self
    }

    /// Appends `path` to the configured search paths.
    ///
    /// NOTE(review): `StandardLoader::main` takes its own search paths, so this
    /// list is presumably consumed elsewhere — confirm.
    #[must_use]
    pub fn add_search_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.config.search_paths.push(path.as_ref().to_path_buf());
        self
    }

    /// Consumes the builder and creates an empty loader with the accumulated config.
    #[must_use]
    pub fn build(self) -> StandardLoader {
        StandardLoader::new(self.config)
    }
}
impl Default for StandardLoaderBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Location, Metadata, Statement, Value};
    use indexmap::IndexMap;
    use semver::Version;

    /// The builder forwards each flag into the loader's config.
    #[test]
    fn test_builder_pattern() {
        let loader = StandardLoader::builder()
            .resolve_macros(false)
            .allow_collisions(true)
            .validate_on_load(true)
            .build();
        assert!(!loader.config.resolve_macros);
        assert!(loader.config.allow_collisions);
        assert!(loader.config.validate_on_load);
    }

    /// Modules can be inserted, queried and removed by name.
    #[test]
    fn test_module_management() {
        let mut loader = StandardLoader::default();
        assert_eq!(loader.module_names().len(), 0);
        assert!(!loader.has_module("test"));
        let module = Statement::new_module("test", IndexMap::new(), Metadata::default());
        loader.modules.insert("test".to_string(), module);
        assert_eq!(loader.module_names().len(), 1);
        assert!(loader.has_module("test"));
        assert!(loader.get_module("test").is_some());
        let removed = loader.remove_module("test");
        assert!(removed.is_some());
        assert!(!loader.has_module("test"));
    }

    /// A fresh loader starts with zeroed statistics.
    #[test]
    fn test_statistics() {
        let loader = StandardLoader::default();
        let stats = loader.stats();
        assert_eq!(stats.files_processed, 0);
        assert_eq!(stats.modules_created, 0);
        assert_eq!(stats.macros_resolved, 0);
    }

    /// `cache_size` tracks the file cache and `clear_cache` empties it.
    #[test]
    fn test_cache_management() {
        let mut loader = StandardLoader::default();
        assert_eq!(loader.cache_size(), 0);
        let module = Statement::new_module("test", IndexMap::new(), Metadata::default());
        loader.file_cache.insert("test.bml".into(), module);
        assert_eq!(loader.cache_size(), 1);
        loader.clear_cache();
        assert_eq!(loader.cache_size(), 0);
    }

    /// Parsing `examples/simple.bml` as `main` yields the expected module tree.
    #[test]
    // `3.14` below is literal fixture data mirrored from `examples/simple.bml`,
    // not an approximation of PI; without this, clippy's deny-by-default
    // `approx_constant` lint fails `cargo clippy --tests`.
    #[allow(clippy::approx_constant)]
    fn load_single() {
        let expected = Statement::new_module(
            ".",
            IndexMap::from([
                (
                    "tire".into(),
                    Statement::new_control(
                        "tire",
                        None,
                        Value::new_version(
                            Version::new(1, 0, 0),
                            Metadata {
                                location: Location::default(),
                                comment: None,
                                label: Some("Test".into()),
                            },
                        ),
                        Metadata::default(),
                    )
                    .unwrap(),
                ),
                (
                    "section-1".into(),
                    Statement::new_section(
                        "section-1",
                        IndexMap::from([
                            (
                                "number".into(),
                                Statement::new_assign(
                                    "number",
                                    None,
                                    Value::new_int(4, Metadata::default()),
                                    Metadata {
                                        location: Location::default(),
                                        comment: Some("Documentation".into()),
                                        label: None,
                                    },
                                )
                                .unwrap(),
                            ),
                            (
                                "floating".into(),
                                Statement::new_assign(
                                    "floating",
                                    Some(crate::ValueType::F32),
                                    Value::new_f32(3.14, Metadata::default()),
                                    Metadata::default(),
                                )
                                .unwrap(),
                            ),
                            (
                                "versioning".into(),
                                Statement::new_assign(
                                    "versioning",
                                    None,
                                    Value::new_version(
                                        Version::parse("1.2.3-beta.6").unwrap(),
                                        Metadata::default(),
                                    ),
                                    Metadata::default(),
                                )
                                .unwrap(),
                            ),
                            (
                                "requires".into(),
                                Statement::new_assign(
                                    "requires",
                                    None,
                                    Value::new_require(
                                        semver::VersionReq::parse("^1.3.3").unwrap(),
                                        Metadata::default(),
                                    ),
                                    Metadata::default(),
                                )
                                .unwrap(),
                            ),
                            (
                                "strings".into(),
                                Statement::new_assign(
                                    "strings",
                                    None,
                                    Value::new_string("hello world".into(), Metadata::default()),
                                    Metadata::default(),
                                )
                                .unwrap(),
                            ),
                        ]),
                        Metadata::default(),
                    ),
                ),
            ]),
            Metadata::default(),
        );
        let mut loader = StandardLoader::default();
        let result = loader
            .add_module(
                "main",
                &mut std::io::Cursor::new(include_str!("../../examples/simple.bml")),
                None,
            )
            .unwrap()
            .load()
            .unwrap();
        assert_eq!(result, expected);
    }

    /// Two sources added under `main` are merged into one module.
    #[test]
    fn load_multiple() {
        let mut loader = StandardLoader::default();
        loader
            .add_module(
                "main",
                &mut std::io::Cursor::new(include_str!("../../examples/append.d/00-first.bml")),
                None,
            )
            .unwrap()
            .add_module(
                "main",
                &mut std::io::Cursor::new(include_str!("../../examples/append.d/01-second.bml")),
                None,
            )
            .unwrap();
        let result = loader.load().unwrap();
        assert!(result.find_by_path("section-1.number").is_some());
        assert!(result.find_by_path("section-2.number").is_some());
    }
}