use crate::error::{NomlError, Result};
use crate::value::Value;
use std::collections::BTreeMap;
use std::fmt;
/// A parsed NOML document: the root AST node plus, optionally, the path and
/// full text it was parsed from.
#[derive(Debug, Clone, PartialEq)]
pub struct Document {
/// Top-level node of the syntax tree.
pub root: AstNode,
/// Path of the source the document came from, if the caller supplied one.
pub source_path: Option<String>,
/// Full source text, if supplied; `source_text_for_span` slices it by span byte offsets.
pub source_text: Option<String>,
}
/// A single node of the NOML syntax tree: its value, where it sits in the
/// source, the comments attached to it, and recorded layout details.
#[derive(Debug, Clone, PartialEq)]
pub struct AstNode {
/// The syntactic payload (scalar, container, call, reference, ...).
pub value: AstValue,
/// Location of the node in the source text.
pub span: Span,
/// Comments attached to this node.
pub comments: Comments,
/// Layout details; not read in this file — presumably consumed by a formatter (NOTE(review): confirm).
pub format: FormatMetadata,
}
/// A region of the source text.
///
/// `start..end` is a half-open range of byte offsets (`end` is exclusive).
/// Line and column numbers appear to be 1-based, since `Span::default()`
/// starts at line 1, column 1. Whether columns count bytes or characters
/// is up to the parser — NOTE(review): confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
/// Byte offset of the first byte covered.
pub start: usize,
/// Byte offset just past the last byte covered.
pub end: usize,
/// Line on which the span begins.
pub start_line: usize,
/// Column on `start_line` at which the span begins.
pub start_column: usize,
/// Line on which the span ends.
pub end_line: usize,
/// Column on `end_line` at which the span ends.
pub end_column: usize,
}
impl Default for Span {
fn default() -> Self {
Span {
start: 0,
end: 0,
start_line: 1,
start_column: 1,
end_line: 1,
end_column: 1,
}
}
}
/// Comments attached to a syntax element.
///
/// Comment collection reports them in the order `before`, `inline`, `after`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Comments {
/// Comments placed before the element.
pub before: Vec<Comment>,
/// At most one inline comment (presumably trailing on the element's line — confirm with the parser).
pub inline: Option<Comment>,
/// Comments placed after the element.
pub after: Vec<Comment>,
}
/// Original-layout details recorded for a node.
///
/// Nothing in this module reads these fields; presumably a formatter or
/// serializer uses them to reproduce the source layout — NOTE(review): confirm.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct FormatMetadata {
/// Whitespace text recorded before the node.
pub leading_whitespace: String,
/// Whitespace text recorded after the node.
pub trailing_whitespace: String,
/// Indentation settings for the node.
pub indentation: Indentation,
/// Line-ending convention (defaults to `LineEnding::Unix`).
pub line_ending: LineEnding,
/// Kind-specific layout (defaults to `FormatStyle::Default`).
pub format_style: FormatStyle,
}
/// How a node is indented. The `Default` impl gives spaces, size 2, level 0.
#[derive(Debug, Clone, PartialEq)]
pub struct Indentation {
/// Indent with tab characters instead of spaces.
pub use_tabs: bool,
/// Width of one indentation step (presumably in spaces when `use_tabs` is false).
pub size: usize,
/// Nesting level (presumably counted in indentation steps).
pub level: usize,
}
impl Default for Indentation {
fn default() -> Self {
Self {
use_tabs: false,
size: 2,
level: 0,
}
}
}
/// Line-ending convention of the source. The byte sequences noted per
/// variant are inferred from the names — NOTE(review): confirm.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LineEnding {
/// Presumably `\n`; the default.
#[default]
Unix,
/// Presumably `\r\n`.
Windows,
/// Presumably `\r` (classic Mac OS).
Mac,
}
/// Layout details specific to the kind of node being formatted.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum FormatStyle {
/// No kind-specific layout recorded.
#[default]
Default,
/// Layout of an array.
Array {
/// Elements are written on separate lines.
multiline: bool,
/// A comma follows the last element.
trailing_comma: bool,
/// Whitespace just inside the brackets.
bracket_spacing: BracketSpacing,
},
/// Layout of a table.
Table {
/// Written in inline form rather than as a section (exact syntax decided by the parser).
inline: bool,
/// Whitespace around `=` in the table's entries.
equals_spacing: EqualsSpacing,
/// Keys are written quoted.
quoted_keys: bool,
},
/// Layout of a single key/value pair.
KeyValue {
/// Whitespace around `=`.
equals_spacing: EqualsSpacing,
/// The key is written quoted.
quoted_key: bool,
},
}
/// Whitespace recorded just inside an array's brackets (both empty by default).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct BracketSpacing {
/// Text between the opening bracket and the first element.
pub after_open: String,
/// Text between the last element and the closing bracket.
pub before_close: String,
}
/// Whitespace on either side of the `=` in a key/value pair.
/// The `Default` impl uses a single space on each side.
#[derive(Debug, Clone, PartialEq)]
pub struct EqualsSpacing {
/// Text before the `=`.
pub before: String,
/// Text after the `=`.
pub after: String,
}
/// One space on each side of `=`, i.e. `key = value`.
impl Default for EqualsSpacing {
    fn default() -> Self {
        let single_space = String::from(" ");
        Self {
            after: single_space.clone(),
            before: single_space,
        }
    }
}
/// A single comment from the source.
#[derive(Debug, Clone, PartialEq)]
pub struct Comment {
/// Comment content; whether comment delimiters are included is up to the parser — NOTE(review): confirm.
pub text: String,
/// Location of the comment in the source.
pub span: Span,
/// Line or block form.
pub style: CommentStyle,
}
/// Syntactic form of a comment.
#[derive(Debug, Clone, PartialEq)]
pub enum CommentStyle {
/// Presumably a comment that runs to the end of its line.
Line,
/// Presumably a delimited comment that may span several lines.
Block,
}
/// The syntactic payload of an `AstNode`.
#[derive(Debug, Clone, PartialEq)]
pub enum AstValue {
/// Null value.
Null,
/// Boolean literal.
Bool(bool),
/// Integer literal.
Integer {
/// Parsed value.
value: i64,
/// Source spelling of the literal (e.g. `"42"` for 42).
raw: String,
},
/// Floating-point literal.
Float {
/// Parsed value.
value: f64,
/// Source spelling of the literal.
raw: String,
},
/// String literal.
String {
/// String contents (presumably with escapes already processed — confirm with the lexer).
value: String,
/// Quoting style used in the source.
style: StringStyle,
/// Whether the source literal contained escape sequences.
has_escapes: bool,
},
/// Array of element nodes; converted element by element, in order.
Array {
elements: Vec<AstNode>,
/// Elements were written on separate lines.
multiline: bool,
/// A comma followed the last element.
trailing_comma: bool,
},
/// Table of key/value entries; `to_value` inserts them in vector order via `Value::set`.
Table {
entries: Vec<TableEntry>,
/// Written in inline form rather than as a section.
inline: bool,
},
/// Built-in function call; `to_value` evaluates `env`, `size` and `duration`
/// and rejects any other name.
FunctionCall {
name: String,
args: Vec<AstNode>,
},
/// Reference to another value by `path`. `to_value` returns an error for it,
/// so it must be resolved before conversion (presumably by a separate pass).
Interpolation {
path: String,
},
/// Include directive for `path`. Like `Interpolation`, `to_value` rejects it,
/// so it must be resolved before conversion.
Include {
path: String,
},
/// Native typed value (reported as `@type_name` in errors); `to_value`
/// supports `size` and `duration`.
Native {
type_name: String,
args: Vec<AstNode>,
},
}
/// One key/value entry of a table.
#[derive(Debug, Clone, PartialEq)]
pub struct TableEntry {
/// The (possibly dotted) key.
pub key: Key,
/// The entry's value node.
pub value: AstNode,
/// Comments attached to the entry as a whole, separate from `value.comments`.
pub comments: Comments,
}
/// A table key made of one or more segments, e.g. `server.host`.
///
/// `PartialEq` is implemented by hand and compares only `segments`, so the
/// same key written at different source locations compares equal.
#[derive(Debug, Clone)]
pub struct Key {
/// Key segments in the order they are written.
pub segments: Vec<KeySegment>,
/// Location of the whole key in the source.
pub span: Span,
}
/// One component of a (possibly dotted) key.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySegment {
/// Segment text (presumably without surrounding quotes, since `Key`'s `Display` adds them).
pub name: String,
/// The segment was written quoted.
pub quoted: bool,
/// Quote style used when quoted; `Key::simple` sets it to `None`.
pub quote_style: Option<StringStyle>,
}
/// Quoting style of a string literal or quoted key.
/// The delimiters noted per variant are inferred from the names — NOTE(review): confirm with the lexer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StringStyle {
/// `"..."`
Double,
/// `'...'`
Single,
/// `"""..."""`
TripleDouble,
/// `'''...'''`
TripleSingle,
/// Raw string whose delimiter uses `hashes` `#` characters (presumably `r#"..."#`-style).
Raw {
/// Number of `#` characters in the delimiter.
hashes: usize,
},
}
impl Document {
pub fn new(root: AstNode) -> Self {
Self {
root,
source_path: None,
source_text: None,
}
}
pub fn with_source(root: AstNode, path: Option<String>, text: Option<String>) -> Self {
Self {
root,
source_path: path,
source_text: text,
}
}
pub fn to_value(&self) -> Result<Value> {
self.root.to_value()
}
pub fn source_text_for_span(&self, span: &Span) -> Option<&str> {
self.source_text
.as_ref()
.and_then(|text| text.get(span.start..span.end))
}
pub fn node_at_offset(&self, offset: usize) -> Option<&AstNode> {
self.root.find_node_at_offset(offset)
}
pub fn all_comments(&self) -> Vec<&Comment> {
let mut comments = Vec::new();
self.root.collect_comments(&mut comments);
comments
}
}
impl AstNode {
pub fn new(value: AstValue, span: Span) -> Self {
Self {
value,
span,
comments: Comments::default(),
format: FormatMetadata::default(),
}
}
pub fn with_comments(value: AstValue, span: Span, comments: Comments) -> Self {
Self {
value,
span,
comments,
format: FormatMetadata::default(),
}
}
pub fn with_metadata(
value: AstValue,
span: Span,
comments: Comments,
format: FormatMetadata,
) -> Self {
Self {
value,
span,
comments,
format,
}
}
pub fn to_value(&self) -> Result<Value> {
match &self.value {
AstValue::Null => Ok(Value::Null),
AstValue::Bool(b) => Ok(Value::Bool(*b)),
AstValue::Integer { value, .. } => Ok(Value::Integer(*value)),
AstValue::Float { value, .. } => Ok(Value::Float(*value)),
AstValue::String { value, .. } => Ok(Value::String(value.clone())),
AstValue::Array { elements, .. } => {
let values = elements
.iter()
.map(|elem| elem.to_value())
.collect::<Result<Vec<_>>>()?;
Ok(Value::Array(values))
}
AstValue::Table { entries, .. } => {
let mut value = Value::Table(BTreeMap::new());
for entry in entries {
let key = entry.key.to_string();
let entry_value = entry.value.to_value()?;
value.set(&key, entry_value)?;
}
Ok(value)
}
AstValue::FunctionCall { name, args } => {
match name.as_str() {
"env" => self.handle_env_function(args),
"size" => self.handle_size_function(args),
"duration" => self.handle_duration_function(args),
_ => Err(NomlError::validation(format!("Unknown function: {name}"))),
}
}
AstValue::Interpolation { path } => {
Err(NomlError::interpolation(
"Unresolved interpolation",
path.clone(),
))
}
AstValue::Include { path } => {
Err(NomlError::import(
path.clone(),
"Unresolved include directive",
))
}
AstValue::Native { type_name, args } => self.handle_native_type(type_name, args),
}
}
pub fn find_node_at_offset(&self, offset: usize) -> Option<&AstNode> {
if offset < self.span.start || offset >= self.span.end {
return None;
}
match &self.value {
AstValue::Array { elements, .. } => {
for element in elements {
if let Some(node) = element.find_node_at_offset(offset) {
return Some(node);
}
}
}
AstValue::Table { entries, .. } => {
for entry in entries {
if let Some(node) = entry.value.find_node_at_offset(offset) {
return Some(node);
}
}
}
AstValue::FunctionCall { args, .. } => {
for arg in args {
if let Some(node) = arg.find_node_at_offset(offset) {
return Some(node);
}
}
}
AstValue::Native { args, .. } => {
for arg in args {
if let Some(node) = arg.find_node_at_offset(offset) {
return Some(node);
}
}
}
_ => {}
}
Some(self)
}
pub fn collect_comments<'a>(&'a self, comments: &mut Vec<&'a Comment>) {
comments.extend(self.comments.before.iter());
if let Some(ref inline) = self.comments.inline {
comments.push(inline);
}
comments.extend(self.comments.after.iter());
match &self.value {
AstValue::Array { elements, .. } => {
for element in elements {
element.collect_comments(comments);
}
}
AstValue::Table { entries, .. } => {
for entry in entries {
comments.extend(entry.comments.before.iter());
if let Some(ref inline) = entry.comments.inline {
comments.push(inline);
}
comments.extend(entry.comments.after.iter());
entry.value.collect_comments(comments);
}
}
AstValue::FunctionCall { args, .. } => {
for arg in args {
arg.collect_comments(comments);
}
}
AstValue::Native { args, .. } => {
for arg in args {
arg.collect_comments(comments);
}
}
_ => {}
}
}
fn handle_env_function(&self, args: &[AstNode]) -> Result<Value> {
if args.is_empty() || args.len() > 2 {
return Err(NomlError::validation(
"env() function requires 1 or 2 arguments",
));
}
let var_name = match args[0].to_value()? {
Value::String(name) => name,
_ => {
return Err(NomlError::validation(
"env() first argument must be a string",
));
}
};
match std::env::var(&var_name) {
Ok(value) => Ok(Value::String(value)),
Err(_) => {
if args.len() == 2 {
args[1].to_value()
} else {
Err(NomlError::env_var(var_name, false))
}
}
}
}
fn handle_size_function(&self, args: &[AstNode]) -> Result<Value> {
if args.len() != 1 {
return Err(NomlError::validation(
"size() function requires exactly 1 argument",
));
}
let size_str = match args[0].to_value()? {
Value::String(s) => s,
_ => return Err(NomlError::validation("size() argument must be a string")),
};
parse_size(&size_str)
.map(Value::Size)
.ok_or_else(|| NomlError::validation(format!("Invalid size format: {size_str}")))
}
fn handle_duration_function(&self, args: &[AstNode]) -> Result<Value> {
if args.len() != 1 {
return Err(NomlError::validation(
"duration() function requires exactly 1 argument",
));
}
let duration_str = match args[0].to_value()? {
Value::String(s) => s,
_ => {
return Err(NomlError::validation(
"duration() argument must be a string",
));
}
};
parse_duration(&duration_str)
.map(Value::Duration)
.ok_or_else(|| {
NomlError::validation(format!("Invalid duration format: {duration_str}"))
})
}
fn handle_native_type(&self, type_name: &str, args: &[AstNode]) -> Result<Value> {
match type_name {
"size" => self.handle_size_function(args),
"duration" => self.handle_duration_function(args),
_ => Err(NomlError::validation(format!(
"Unknown native type: @{type_name}"
))),
}
}
}
impl Span {
    /// Builds a span from its byte range and its start/end line and column.
    pub fn new(
        start: usize,
        end: usize,
        start_line: usize,
        start_column: usize,
        end_line: usize,
        end_column: usize,
    ) -> Self {
        Span {
            start,
            end,
            start_line,
            start_column,
            end_line,
            end_column,
        }
    }

    /// Returns a span covering both `self` and `other`.
    ///
    /// Offsets and lines combine with min/max. Each column comes from the span
    /// that owns the chosen start (or end) line; when both spans share that
    /// line, the columns combine with min (start) or max (end).
    pub fn merge(&self, other: &Span) -> Span {
        use std::cmp::Ordering;

        let start_column = match self.start_line.cmp(&other.start_line) {
            Ordering::Less => self.start_column,
            Ordering::Greater => other.start_column,
            Ordering::Equal => self.start_column.min(other.start_column),
        };
        let end_column = match self.end_line.cmp(&other.end_line) {
            Ordering::Greater => self.end_column,
            Ordering::Less => other.end_column,
            Ordering::Equal => self.end_column.max(other.end_column),
        };
        Span {
            start: self.start.min(other.start),
            end: self.end.max(other.end),
            start_line: self.start_line.min(other.start_line),
            start_column,
            end_line: self.end_line.max(other.end_line),
            end_column,
        }
    }

    /// Whether byte `offset` lies within `start..end` (end exclusive).
    pub fn contains(&self, offset: usize) -> bool {
        (self.start..self.end).contains(&offset)
    }
}
impl Key {
pub fn simple(name: String, span: Span) -> Self {
Self {
segments: vec![KeySegment {
name,
quoted: false,
quote_style: None,
}],
span,
}
}
pub fn dotted(segments: Vec<KeySegment>, span: Span) -> Self {
Self { segments, span }
}
}
/// Renders the key as its segments joined by `.`.
///
/// Quoted segments are wrapped in double quotes regardless of their recorded
/// `quote_style`, e.g. `server."host name"`.
///
/// NOTE(review): the quoted name is written verbatim, so an embedded `"` or `\`
/// is not escaped. Confirm whether consumers of this text need an escaped form.
impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut separator = "";
        for segment in &self.segments {
            f.write_str(separator)?;
            separator = ".";
            if segment.quoted {
                write!(f, "\"{}\"", segment.name)?;
            } else {
                f.write_str(&segment.name)?;
            }
        }
        Ok(())
    }
}
/// Keys compare by their segments only; `span` is deliberately ignored, so the
/// same key written at two different locations is equal.
impl PartialEq for Key {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.segments, &other.segments);
        mine.len() == theirs.len() && mine.iter().zip(theirs.iter()).all(|(a, b)| a == b)
    }
}
// Segments hold only strings, booleans and `StringStyle` values (no floats),
// so segment-wise equality is a full equivalence relation.
impl Eq for Key {}
impl Comments {
pub fn new() -> Self {
Self::default()
}
pub fn is_empty(&self) -> bool {
self.before.is_empty() && self.inline.is_none() && self.after.is_empty()
}
pub fn add_before(&mut self, comment: Comment) {
self.before.push(comment);
}
pub fn set_inline(&mut self, comment: Comment) {
self.inline = Some(comment);
}
pub fn add_after(&mut self, comment: Comment) {
self.after.push(comment);
}
}
/// Parses a human-readable byte size such as `"1024"`, `"1KB"` or `"512 MB"`.
///
/// The input is trimmed and matched case-insensitively. The number may be
/// fractional and may be separated from the unit by whitespace. The unit starts
/// at the first alphabetic character. Units are binary: `k`/`kb`/`kib` all mean
/// 1024 bytes, up to `p`/`pb`/`pib` = 1024^5. No unit means bytes. Fractional
/// byte counts are truncated toward zero.
///
/// Returns `None` for empty input, an unparsable, negative or non-finite
/// number, an unknown unit, or a result that does not fit in `u64`.
fn parse_size(s: &str) -> Option<u64> {
    let s = s.trim().to_lowercase();
    if s.is_empty() {
        return None;
    }
    let (number_part, unit_part) = match s.find(char::is_alphabetic) {
        Some(pos) => s.split_at(pos),
        None => (s.as_str(), ""),
    };
    // Trim so that "512 MB" (space before the unit) is accepted.
    let number: f64 = number_part.trim().parse().ok()?;
    // A very long digit string parses to infinity; reject it along with negatives.
    if !number.is_finite() || number < 0.0 {
        return None;
    }
    let multiplier: u64 = match unit_part {
        "" | "b" | "byte" | "bytes" => 1,
        "k" | "kb" | "kib" => 1024,
        "m" | "mb" | "mib" => 1024 * 1024,
        "g" | "gb" | "gib" => 1024 * 1024 * 1024,
        "t" | "tb" | "tib" => 1024_u64.pow(4),
        "p" | "pb" | "pib" => 1024_u64.pow(5),
        _ => return None,
    };
    let bytes = number * multiplier as f64;
    // `as u64` would silently saturate out-of-range values to u64::MAX.
    // `u64::MAX as f64` rounds up to exactly 2^64, the first value that does not fit.
    if bytes >= u64::MAX as f64 {
        return None;
    }
    Some(bytes as u64)
}
/// Parses a human-readable duration such as `"30"`, `"1.5m"` or `"2 h"` into seconds.
///
/// The input is trimmed and matched case-insensitively. The number may be
/// fractional and may be separated from the unit by whitespace. A bare number
/// means seconds. Supported units:
/// - `ms`: milliseconds
/// - `s`, `sec`, `second(s)`: seconds
/// - `m`, `min`, `minute(s)`: minutes
/// - `h`, `hr`, `hour(s)`: hours
/// - `d`, `day(s)`: days
/// - `w`, `week(s)`: weeks
///
/// Returns `None` for empty input, an unparsable, negative or non-finite
/// number, an unknown unit, or a result that overflows to infinity.
fn parse_duration(s: &str) -> Option<f64> {
    let s = s.trim().to_lowercase();
    if s.is_empty() {
        return None;
    }
    // The unit starts at the first alphabetic character; without one, assume seconds.
    let (number_part, unit_part) = match s.find(char::is_alphabetic) {
        Some(pos) => s.split_at(pos),
        None => (s.as_str(), "s"),
    };
    // Trim so that "30 s" (space before the unit) is accepted.
    let number: f64 = number_part.trim().parse().ok()?;
    // A very long digit string parses to infinity; reject it along with negatives.
    if !number.is_finite() || number < 0.0 {
        return None;
    }
    let multiplier = match unit_part {
        "ms" | "millisecond" | "milliseconds" => 0.001,
        "s" | "sec" | "second" | "seconds" => 1.0,
        "m" | "min" | "minute" | "minutes" => 60.0,
        "h" | "hr" | "hour" | "hours" => 3600.0,
        "d" | "day" | "days" => 86400.0,
        "w" | "week" | "weeks" => 604800.0,
        _ => return None,
    };
    let seconds = number * multiplier;
    // Scaling a huge finite number can still overflow to infinity.
    if !seconds.is_finite() {
        return None;
    }
    Some(seconds)
}
// Unit tests: span merging, key rendering, the size/duration parsers and
// scalar AST-to-value conversion.
#[cfg(test)]
mod tests {
    use super::*;

    /// An unquoted key segment with the given name.
    fn plain_segment(name: &str) -> KeySegment {
        KeySegment {
            name: name.to_string(),
            quoted: false,
            quote_style: None,
        }
    }

    #[test]
    fn span_operations() {
        let merged = Span::new(10, 20, 1, 10, 1, 20).merge(&Span::new(15, 25, 1, 15, 1, 25));
        assert_eq!((merged.start, merged.end), (10, 25));
        assert!(merged.contains(12), "offset inside the merged range");
        assert!(!merged.contains(5), "offset before the merged range");
    }

    #[test]
    fn key_display() {
        let single = Key::simple("server".to_string(), Span::new(0, 6, 1, 1, 1, 6));
        assert_eq!(single.to_string(), "server");

        let dotted = Key::dotted(
            vec![plain_segment("server"), plain_segment("host")],
            Span::new(0, 11, 1, 1, 1, 11),
        );
        assert_eq!(dotted.to_string(), "server.host");
    }

    #[test]
    fn parse_sizes() {
        let cases = [
            ("1024", Some(1024)),
            ("1KB", Some(1024)),
            ("1.5MB", Some(1_572_864)),
            ("1GB", Some(1_073_741_824)),
            ("invalid", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(parse_size(input), expected, "parse_size({input:?})");
        }
    }

    #[test]
    fn parse_durations() {
        let cases = [
            ("30", Some(30.0)),
            ("30s", Some(30.0)),
            ("1.5m", Some(90.0)),
            ("2h", Some(7200.0)),
            ("invalid", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(parse_duration(input), expected, "parse_duration({input:?})");
        }
    }

    #[test]
    fn ast_to_value_conversion() {
        let span = Span::new(0, 4, 1, 1, 1, 4);
        let check = |value: AstValue, expected: Value| {
            assert_eq!(AstNode::new(value, span).to_value().unwrap(), expected);
        };

        check(AstValue::Bool(true), Value::Bool(true));
        check(
            AstValue::Integer {
                value: 42,
                raw: "42".to_string(),
            },
            Value::Integer(42),
        );
        check(
            AstValue::String {
                value: "hello".to_string(),
                style: StringStyle::Double,
                has_escapes: false,
            },
            Value::String("hello".to_string()),
        );
    }
}