use serde::{Deserialize, Serialize};
/// String constants for node kind names.
///
/// Each constant's value is its name written in UpperCamelCase (e.g. `IF` is `"If"`,
/// `FUNCTION` is `"Function"`).
///
/// NOTE(review): nothing in this file reads these constants, and they do not match the
/// lowercase names accepted by `is_structural_kind` (e.g. `"if_expression"`). Confirm
/// which component produces/consumes these values before relying on them.
pub mod kinds {
// Control flow.
pub const IF: &str = "If";
pub const MATCH: &str = "Match";
pub const LOOP: &str = "Loop";
pub const WHILE: &str = "While";
pub const FOR: &str = "For";
pub const BREAK: &str = "Break";
pub const CONTINUE: &str = "Continue";
pub const RETURN: &str = "Return";
// Definitions / items.
pub const FUNCTION: &str = "Function";
pub const STRUCT: &str = "Struct";
pub const ENUM: &str = "Enum";
pub const TRAIT: &str = "Trait";
pub const IMPL: &str = "Impl";
pub const MODULE: &str = "Module";
pub const CLASS: &str = "Class";
pub const INTERFACE: &str = "Interface";
// Blocks, calls, bindings and assignments.
pub const BLOCK: &str = "Block";
pub const CALL: &str = "Call";
pub const ASSIGN: &str = "Assign";
pub const LET: &str = "Let";
pub const CONST: &str = "Const";
pub const STATIC: &str = "Static";
// Attributes / annotations.
pub const ATTRIBUTE: &str = "Attribute";
}
/// A syntax-tree node described by its kind and the byte span it covers.
///
/// The span is half-open: `byte_start` is inclusive and `byte_end` is exclusive (see
/// [`AstNode::contains`]). Because the fields are public and the type is
/// deserializable, no ordering between `byte_start` and `byte_end` is guaranteed.
///
/// `PartialEq`/`Eq` are derived so nodes can be compared directly; every field already
/// implements both traits.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AstNode {
    /// Identifier of this node; `None` for nodes built with [`AstNode::new`].
    pub id: Option<i64>,
    /// Identifier of the parent node, if any.
    pub parent_id: Option<i64>,
    /// Kind name of the node (an arbitrary string).
    pub kind: String,
    /// Byte offset where the node starts (inclusive).
    pub byte_start: usize,
    /// Byte offset where the node ends (exclusive).
    pub byte_end: usize,
}
impl AstNode {
    /// Creates a node with no `id` yet assigned.
    ///
    /// `kind` is converted into an owned `String`. The span is `[byte_start, byte_end)`.
    /// The offsets are stored as given, with no ordering check. An inverted span
    /// (`byte_end < byte_start`) is therefore accepted and is reported as empty by
    /// [`AstNode::len`] and [`AstNode::is_empty`].
    pub fn new(
        parent_id: Option<i64>,
        kind: impl Into<String>,
        byte_start: usize,
        byte_end: usize,
    ) -> Self {
        Self {
            id: None,
            parent_id,
            kind: kind.into(),
            byte_start,
            byte_end,
        }
    }

    /// Returns the node's byte span as `(byte_start, byte_end)`.
    pub fn span(&self) -> (usize, usize) {
        (self.byte_start, self.byte_end)
    }

    /// Returns `true` if `position` lies in the half-open range `[byte_start, byte_end)`.
    ///
    /// The end offset is exclusive, so `contains(byte_end)` is `false`. An empty or
    /// inverted span contains no positions.
    pub fn contains(&self, position: usize) -> bool {
        (self.byte_start..self.byte_end).contains(&position)
    }

    /// Returns the number of bytes covered by the span.
    ///
    /// An inverted span (`byte_end < byte_start`) yields `0` instead of underflowing.
    /// Plain subtraction would panic in debug builds and wrap to a huge value in release
    /// builds. Saturating keeps `len() == 0` consistent with [`AstNode::is_empty`].
    pub fn len(&self) -> usize {
        self.byte_end.saturating_sub(self.byte_start)
    }

    /// Returns `true` if the span covers no bytes, i.e. `byte_end <= byte_start`.
    pub fn is_empty(&self) -> bool {
        self.byte_end <= self.byte_start
    }
}
/// An [`AstNode`] paired with an optional text payload.
///
/// `#[serde(flatten)]` makes the node's fields serialize inline next to `text`, rather
/// than nested under a `node` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstNodeWithText {
/// The wrapped node; its fields are flattened into this object by serde.
#[serde(flatten)]
pub node: AstNode,
/// Optional text for the node. Presumably the source text of `node`'s span; this is
/// not checked here, so confirm with the producer.
pub text: Option<String>,
}
impl AstNodeWithText {
pub fn from_node(node: AstNode, text: Option<String>) -> Self {
Self { node, text }
}
pub fn text_or<'a>(&'a self, placeholder: &'a str) -> &'a str {
self.text.as_deref().unwrap_or(placeholder)
}
}
/// Wraps a bare node with no text attached.
impl From<AstNode> for AstNodeWithText {
    fn from(node: AstNode) -> Self {
        Self::from_node(node, None)
    }
}
/// Returns `true` if `kind` is one of a fixed set of node-kind names treated as structural.
///
/// The comparison is exact and case-sensitive. The set covers control-flow constructs,
/// item/definition kinds, blocks, bindings and assignments, calls, attributes/decorators,
/// and const/static items. Any other string (e.g. `"identifier"`) yields `false`.
pub fn is_structural_kind(kind: &str) -> bool {
    const STRUCTURAL_KINDS: &[&str] = &[
        // Control flow, expression forms.
        "if_expression",
        "match_expression",
        "while_expression",
        "for_expression",
        "loop_expression",
        "return_expression",
        "break_expression",
        "continue_expression",
        // Control flow, statement forms.
        "if_statement",
        "match_statement",
        "while_statement",
        "for_statement",
        "break_statement",
        "continue_statement",
        "return_statement",
        // Definitions / items.
        "function_item",
        "method_definition",
        "struct_item",
        "enum_item",
        "trait_item",
        "impl_item",
        "mod_item",
        "class_definition",
        "interface_definition",
        // Blocks.
        "block",
        "block_expression",
        "statement_block",
        // Bindings, statements and assignments.
        "let_statement",
        "let_declaration",
        "expression_statement",
        "assignment_expression",
        "augmented_assignment_expression",
        // Calls.
        "call_expression",
        // Attributes / decorators.
        "attribute_item",
        "decorated_definition",
        // Constants and statics.
        "const_item",
        "static_item",
    ];
    STRUCTURAL_KINDS.contains(&kind)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with no parent, which is what most of these tests need.
    fn orphan(kind: &str, start: usize, end: usize) -> AstNode {
        AstNode::new(None, kind, start, end)
    }

    #[test]
    fn test_ast_node_new() {
        let AstNode {
            id,
            parent_id,
            kind,
            byte_start,
            byte_end,
        } = AstNode::new(Some(1), "IfExpression", 100, 250);
        assert_eq!(id, None);
        assert_eq!(parent_id, Some(1));
        assert_eq!(kind, "IfExpression");
        assert_eq!((byte_start, byte_end), (100, 250));
    }

    #[test]
    fn test_ast_node_span() {
        assert_eq!(orphan("Block", 50, 150).span(), (50, 150));
    }

    #[test]
    fn test_ast_node_contains() {
        let block = orphan("Block", 50, 150);
        // Start is inclusive, end is exclusive.
        for inside in [50, 100] {
            assert!(block.contains(inside));
        }
        for outside in [150, 200] {
            assert!(!block.contains(outside));
        }
    }

    #[test]
    fn test_ast_node_len() {
        assert_eq!(orphan("Block", 50, 150).len(), 100);
    }

    #[test]
    fn test_ast_node_is_empty() {
        assert!(orphan("Empty", 100, 100).is_empty());
        assert!(!orphan("Block", 100, 200).is_empty());
    }

    #[test]
    fn test_is_structural_kind() {
        for kind in ["if_expression", "function_item", "block", "let_statement"] {
            assert!(is_structural_kind(kind));
        }
        for kind in ["identifier", "string_literal"] {
            assert!(!is_structural_kind(kind));
        }
    }

    #[test]
    fn test_ast_node_with_text_from_node() {
        let wrapped: AstNodeWithText = orphan("IfExpression", 10, 50).into();
        assert_eq!(wrapped.node.kind, "IfExpression");
        assert!(wrapped.text.is_none());
    }

    #[test]
    fn test_ast_node_with_text_or() {
        let with_text = AstNodeWithText::from_node(
            orphan("IfExpression", 10, 50),
            Some(String::from("if x { y }")),
        );
        assert_eq!(with_text.text_or("<missing>"), "if x { y }");

        let without_text = AstNodeWithText::from_node(orphan("Block", 0, 0), None);
        assert_eq!(without_text.text_or("<missing>"), "<missing>");
    }
}