use super::{Result, error};
use crate::ast::{Statement, Value, ValueType};
use indexmap::IndexSet;
use snafu::OptionExt;
use std::fmt;
#[derive(Clone)]
pub enum Walk<'source> {
Statement(&'source Statement),
Value(&'source Value),
}
impl<'source> Walk<'source> {
pub fn new(module: &'source Statement) -> Self {
Self::Statement(module)
}
pub fn from_value(value: &'source Value) -> Self {
Self::Value(value)
}
pub fn location(&self) -> &crate::ast::Location {
match self {
Self::Statement(stmt) => &stmt.meta.location,
Self::Value(value) => &value.meta.location,
}
}
pub fn node_type(&self) -> NodeType {
match self {
Self::Statement(stmt) => NodeType::Statement(stmt.type_.clone()),
Self::Value(value) => NodeType::Value(value.type_of()),
}
}
pub const fn is_statement(&self) -> bool {
matches!(self, Self::Statement(_))
}
pub const fn is_value(&self) -> bool {
matches!(self, Self::Value(_))
}
pub fn get_child(&self, field: &str) -> Result<&Statement> {
match self {
Self::Statement(stmt) => {
let children = stmt.get_grouped().context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
children.get(field).context(error::NoFieldSnafu {
field: field.to_string(),
location: stmt.meta.location.clone(),
})
}
Self::Value(value) => error::NotScopeSnafu {
location: value.meta.location.clone(),
}
.fail(),
}
}
pub fn get<T>(&self, field: &str) -> Result<T>
where
T: TryFrom<&'source Value, Error = error::Error>,
{
if field.contains('.') {
return self.get_by_path(field);
}
match self {
Self::Statement(stmt) => {
let children = stmt
.get_grouped()
.or_else(|| stmt.get_labeled().map(|x| x.1))
.context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
let target_stmt = children.get(field).context(error::NoFieldSnafu {
location: stmt.meta.location.clone(),
field: field.to_string(),
})?;
let value = target_stmt.get_value().context(error::NoValueSnafu {
location: stmt.meta.location.clone(),
field: field.to_string(),
})?;
value.try_into()
}
Self::Value(value) => {
let table = value.as_table().context(error::NotScopeSnafu {
location: value.meta.location.clone(),
})?;
let target_value = table.get(field).context(error::NoFieldSnafu {
location: value.meta.location.clone(),
field: field.to_string(),
})?;
target_value.try_into()
}
}
}
pub fn get_by_path<T>(&self, path: &str) -> Result<T>
where
T: TryFrom<&'source Value, Error = error::Error>,
{
let parts: Vec<&str> = path.split('.').collect();
let mut current = self.clone();
for (i, part) in parts.iter().enumerate() {
if i == parts.len() - 1 {
return current.get(part);
} else {
current = current.walk(part)?;
}
}
unreachable!("Path navigation should have returned or errored")
}
pub fn get_id(&self) -> Option<&str> {
match self {
Self::Statement(stmt) => Some(&stmt.id),
Self::Value(_) => None,
}
}
pub fn get_label<T>(&self, index: usize) -> Result<T>
where
T: TryFrom<&'source Value, Error = error::Error>,
{
match self {
Self::Statement(stmt) => {
let (labels, _) = stmt.get_labeled().context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
let label = labels.get(index).context(error::NoElementSnafu {
location: stmt.meta.location.clone(),
index,
})?;
label.try_into()
}
Self::Value(value) => error::NotScopeSnafu {
location: value.meta.location.clone(),
}
.fail(),
}
}
pub fn get_sections(&self) -> Result<IndexSet<String>> {
match self {
Self::Statement(stmt) => {
let children = stmt.get_grouped().context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
Ok(children
.iter()
.filter_map(|(k, s)| {
if matches!(s.type_, crate::StatementType::Section(..)) {
Some(k.clone())
} else {
None
}
})
.collect())
}
Self::Value(value) => error::NotScopeSnafu {
location: value.meta.location.clone(),
}
.fail(),
}
}
pub fn get_all_blocks(&self) -> Result<IndexSet<String>> {
match self {
Self::Statement(stmt) => {
let children = stmt.get_grouped().context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
Ok(children
.iter()
.filter_map(|(k, s)| {
if matches!(s.type_, crate::StatementType::Block { .. }) {
Some(k.clone())
} else {
None
}
})
.collect())
}
Self::Value(value) => error::NotScopeSnafu {
location: value.meta.location.clone(),
}
.fail(),
}
}
pub fn get_blocks(&self, field: &str) -> Result<IndexSet<String>> {
match self {
Self::Statement(stmt) => {
let children = stmt
.get_grouped()
.or_else(|| stmt.get_labeled().map(|x| x.1))
.context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
Ok(children
.iter()
.filter_map(|(k, s)| {
if s.get_labeled().is_some() && s.id == field {
Some(k.clone())
} else {
None
}
})
.collect())
}
Self::Value(value) => error::NotScopeSnafu {
location: value.meta.location.clone(),
}
.fail(),
}
}
pub fn walk(&self, field: &str) -> Result<Self> {
match self {
Self::Statement(stmt) => {
if let Some(children) = stmt
.get_grouped()
.or_else(|| stmt.get_labeled().map(|x| x.1))
{
let target = children.get(field).context(error::NoFieldSnafu {
location: stmt.meta.location.clone(),
field: field.to_string(),
})?;
Ok(Self::Statement(target))
} else if let Some(value) = stmt.get_value() {
Ok(Self::Value(value))
} else {
error::NotScopeSnafu {
location: stmt.meta.location.clone(),
}
.fail()
}
}
Self::Value(value) => {
let table = value.as_table().context(error::NotScopeSnafu {
location: value.meta.location.clone(),
})?;
let target = table.get(field).context(error::NoFieldSnafu {
location: value.meta.location.clone(),
field: field.to_string(),
})?;
Ok(Self::Value(target))
}
}
}
pub fn field_names(&self) -> Result<Vec<String>> {
match self {
Self::Statement(stmt) => {
let children = stmt
.get_grouped()
.or_else(|| stmt.get_labeled().map(|x| x.1))
.context(error::NotScopeSnafu {
location: stmt.meta.location.clone(),
})?;
Ok(children.keys().cloned().collect())
}
Self::Value(value) => {
let table = value.as_table().context(error::NotScopeSnafu {
location: value.meta.location.clone(),
})?;
Ok(table.keys().cloned().collect())
}
}
}
pub fn has_field(&self, field: &str) -> bool {
self.get_child(field).is_ok() || self.walk(field).is_ok()
}
pub fn len(&self) -> usize {
match self {
Self::Statement(stmt) => stmt
.get_grouped()
.or_else(|| stmt.get_labeled().map(|x| x.1))
.map(|children| children.len())
.unwrap_or(0),
Self::Value(value) => value.as_table().map(|table| table.len()).unwrap_or(0),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Debug, Clone)]
pub enum NodeType {
Statement(crate::StatementType),
Value(ValueType),
}
impl fmt::Display for NodeType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NodeType::Statement(stmt_type) => write!(
f,
"Statement({})",
match stmt_type {
crate::StatementType::Control(_) => "Control",
crate::StatementType::Assignment(_) => "Assignment",
crate::StatementType::Block { .. } => "Block",
crate::StatementType::Section(_) => "Section",
crate::StatementType::Module(_) => "Module",
}
),
NodeType::Value(value_type) => write!(f, "Value({})", value_type),
}
}
}
impl<'source> fmt::Debug for Walk<'source> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Statement(stmt) => f
.debug_struct("Walk::Statement")
.field("id", &stmt.id)
.field("type", &stmt.type_)
.field("location", &stmt.meta.location)
.finish(),
Self::Value(value) => f
.debug_struct("Walk::Value")
.field("type", &value.type_of())
.field("location", &value.meta.location)
.finish(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::{Metadata, Statement, Value};
use indexmap::IndexMap;
fn create_test_statement() -> Statement {
let mut children = IndexMap::new();
let value = Value::new_string("test_value".to_string(), Metadata::default());
let stmt = Statement::new_assign("test_field", None, value, Metadata::default()).unwrap();
children.insert("test_field".to_string(), stmt);
Statement::new_module("test", children, Metadata::default())
}
#[test]
fn test_walker_creation() {
let stmt = create_test_statement();
let walker = Walk::new(&stmt);
assert!(walker.is_statement());
assert!(!walker.is_value());
assert_eq!(walker.get_id(), Some("test"));
}
#[test]
fn test_field_access() {
let stmt = create_test_statement();
let walker = Walk::new(&stmt);
assert!(walker.has_field("test_field"));
assert!(!walker.has_field("nonexistent"));
let value: String = walker.get("test_field").unwrap();
assert_eq!(value, "test_value");
}
#[test]
fn test_navigation() {
let stmt = create_test_statement();
let walker = Walk::new(&stmt);
let child_walker = walker.walk("test_field").unwrap();
assert!(child_walker.is_statement());
}
#[test]
fn test_field_enumeration() {
let stmt = create_test_statement();
let walker = Walk::new(&stmt);
let fields = walker.field_names().unwrap();
assert_eq!(fields.len(), 1);
assert!(fields.contains(&"test_field".to_string()));
assert_eq!(walker.len(), 1);
assert!(!walker.is_empty());
}
#[test]
fn test_node_type() {
let stmt = create_test_statement();
let walker = Walk::new(&stmt);
let node_type = walker.node_type();
assert!(matches!(node_type, NodeType::Statement(_)));
}
}