use crate::{error::Location, ParseError, Value};
use pattern_core::{Pattern, Subject, Symbol};
use std::collections::{HashMap, HashSet};
/// Converts a parsed tree-sitter syntax tree into the top-level patterns it contains.
///
/// # Errors
///
/// Returns a [`ParseError`] located at the root node when the root is not a
/// `gram_pattern`. Also propagates any error raised while transforming the root's children.
pub(crate) fn transform_tree(
    tree: &tree_sitter::Tree,
    input: &str,
) -> Result<Vec<Pattern<Subject>>, ParseError> {
    let root = tree.root_node();
    match root.kind() {
        "gram_pattern" => transform_gram_pattern(&root, input),
        other => Err(ParseError::new(
            Location::from_node(&root),
            format!("Expected gram_pattern root, found {}", other),
        )),
    }
}
fn transform_gram_pattern(
node: &tree_sitter::Node,
input: &str,
) -> Result<Vec<Pattern<Subject>>, ParseError> {
let mut patterns = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if !child.is_named() {
continue; }
match child.kind() {
"comment" | "record" => {
continue;
}
"node_pattern" => {
patterns.push(transform_node_pattern(&child, input)?);
}
"relationship_pattern" => {
patterns.push(transform_relationship_pattern(&child, input)?);
}
"subject_pattern" => {
patterns.push(transform_subject_pattern(&child, input)?);
}
"annotated_pattern" => {
patterns.push(transform_annotated_pattern(&child, input)?);
}
_ => {
continue;
}
}
}
Ok(patterns)
}
/// Builds an element-less pattern (via `Pattern::point`) from a `node_pattern` node's subject.
///
/// # Errors
///
/// Propagates any error from reading the node's identifier, labels or record.
fn transform_node_pattern(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<Pattern<Subject>, ParseError> {
    transform_subject(node, input).map(Pattern::point)
}
/// Transforms a `relationship_pattern` node into a two-element pattern.
///
/// The value is the edge subject taken from the `kind` child. The elements are the
/// transformed `left` and `right` operands, ordered by `handle_arrow_type`.
///
/// # Errors
///
/// Returns a missing-field error for the first absent field, checked in the order
/// `left`, `right`, `kind`. Otherwise propagates errors from transforming the operands
/// or the edge subject.
fn transform_relationship_pattern(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<Pattern<Subject>, ParseError> {
    let required = |field: &'static str| {
        node.child_by_field_name(field)
            .ok_or_else(|| ParseError::missing_field(node, field))
    };

    // All three fields are resolved before any transformation runs, keeping error precedence.
    let left_node = required("left")?;
    let right_node = required("right")?;
    let kind_node = required("kind")?;

    let (first, second) = handle_arrow_type(
        &kind_node,
        transform_pattern_node(&left_node, input)?,
        transform_pattern_node(&right_node, input)?,
    );

    Ok(Pattern {
        value: extract_edge_subject(node, input)?,
        elements: vec![first, second],
    })
}
/// Transforms a `subject_pattern` node into a pattern whose value is the node's subject.
///
/// Elements come from the first `subject_pattern_elements` child only. Inside it,
/// `pattern_reference` children become reference points. The four pattern kinds are
/// transformed recursively. Anything else, including anonymous nodes, is skipped.
///
/// # Errors
///
/// Propagates the first error from the subject or from any element transformation.
fn transform_subject_pattern(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<Pattern<Subject>, ParseError> {
    let subject = transform_subject(node, input)?;

    let mut outer = node.walk();
    let container = node
        .children(&mut outer)
        .find(|child| child.kind() == "subject_pattern_elements");

    let mut elements = Vec::new();
    if let Some(container) = container {
        let mut inner = container.walk();
        for member in container.children(&mut inner).filter(|c| c.is_named()) {
            let element = match member.kind() {
                "pattern_reference" => transform_pattern_reference(&member, input)?,
                "node_pattern" | "relationship_pattern" | "subject_pattern"
                | "annotated_pattern" => transform_pattern_node(&member, input)?,
                _ => continue,
            };
            elements.push(element);
        }
    }

    Ok(Pattern {
        value: subject,
        elements,
    })
}
/// Wraps the pattern in an `annotated_pattern`'s `elements` field inside a one-element pattern.
///
/// The wrapper's value is built from the optional `annotations` field. When that field is
/// absent, the value is a subject with an empty identity, no labels and no properties.
///
/// # Errors
///
/// Propagates annotation-extraction errors first. Returns a missing-field error when
/// `elements` is absent. Otherwise propagates errors from transforming the inner pattern.
fn transform_annotated_pattern(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<Pattern<Subject>, ParseError> {
    let annotation_subject = match node.child_by_field_name("annotations") {
        Some(annotations) => extract_annotation_subject(&annotations, input)?,
        None => Subject {
            identity: Symbol(String::new()),
            labels: HashSet::new(),
            properties: HashMap::new(),
        },
    };

    let body = node
        .child_by_field_name("elements")
        .ok_or_else(|| ParseError::missing_field(node, "elements"))?;
    let inner = transform_pattern_node(&body, input)?;

    Ok(Pattern {
        value: annotation_subject,
        elements: vec![inner],
    })
}
fn transform_pattern_node(
node: &tree_sitter::Node,
input: &str,
) -> Result<Pattern<Subject>, ParseError> {
match node.kind() {
"node_pattern" => transform_node_pattern(node, input),
"relationship_pattern" => transform_relationship_pattern(node, input),
"subject_pattern" => transform_subject_pattern(node, input),
"annotated_pattern" => transform_annotated_pattern(node, input),
_ => Err(ParseError::from_node(
node,
format!("Unknown pattern type: {}", node.kind()),
)),
}
}
fn transform_subject(node: &tree_sitter::Node, input: &str) -> Result<Subject, ParseError> {
let mut identity = Symbol(String::new());
let mut labels = HashSet::new();
let mut properties = HashMap::new();
if let Some(id_node) = node.child_by_field_name("identifier") {
let id_text = extract_identifier(&id_node, input)?;
identity = Symbol(id_text);
}
if let Some(labels_node) = node.child_by_field_name("labels") {
labels = extract_labels(&labels_node, input)?;
}
if let Some(record_node) = node.child_by_field_name("record") {
properties = transform_record(&record_node, input)?;
}
Ok(Subject {
identity,
labels,
properties,
})
}
/// Turns a `pattern_reference` node into a point pattern that carries only its identifier.
///
/// The resulting subject has no labels or properties. Its identity is empty when the
/// node has no `identifier` field.
///
/// # Errors
///
/// Propagates any UTF-8 error raised while reading the identifier text.
fn transform_pattern_reference(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<Pattern<Subject>, ParseError> {
    let identifier = match node.child_by_field_name("identifier") {
        Some(id_node) => extract_identifier(&id_node, input)?,
        None => String::new(),
    };

    let reference = Subject {
        identity: Symbol(identifier),
        labels: HashSet::new(),
        properties: HashMap::new(),
    };
    Ok(Pattern::point(reference))
}
/// Collects the `record_property` children of a record node into a property map.
///
/// Each property's `key` is read with `extract_identifier`, so surrounding double quotes
/// are stripped. Its `value` is converted to a `pattern_core::Value`. When a key repeats,
/// the later property overwrites the earlier one.
///
/// # Errors
///
/// Returns a missing-field error for a property lacking `key` or `value`. Also
/// propagates the first key/value conversion error.
fn transform_record(
    node: &tree_sitter::Node,
    input: &str,
) -> Result<HashMap<String, pattern_core::Value>, ParseError> {
    let mut walker = node.walk();
    // Bound to a local so the iterator borrowing `walker` is dropped before `walker` itself.
    let properties: Result<HashMap<String, pattern_core::Value>, ParseError> = node
        .children(&mut walker)
        .filter(|child| child.kind() == "record_property")
        .map(|property| -> Result<(String, pattern_core::Value), ParseError> {
            let key_node = property
                .child_by_field_name("key")
                .ok_or_else(|| ParseError::missing_field(&property, "key"))?;
            let key = extract_identifier(&key_node, input)?;
            let value_node = property
                .child_by_field_name("value")
                .ok_or_else(|| ParseError::missing_field(&property, "value"))?;
            let value = transform_value_to_pattern_value(&value_node, input)?;
            Ok((key, value))
        })
        .collect();
    properties
}
fn transform_value_to_pattern_value(
node: &tree_sitter::Node,
input: &str,
) -> Result<pattern_core::Value, ParseError> {
let codec_value = Value::from_tree_sitter_node(node, input)?;
match codec_value {
Value::String(s) => Ok(pattern_core::Value::VString(s)),
Value::Integer(i) => Ok(pattern_core::Value::VInteger(i)),
Value::Decimal(f) => Ok(pattern_core::Value::VDecimal(f)),
Value::Boolean(b) => Ok(pattern_core::Value::VBoolean(b)),
Value::Array(arr) => {
let converted: Result<Vec<_>, _> =
arr.into_iter().map(value_to_pattern_value).collect();
Ok(pattern_core::Value::VArray(converted?))
}
Value::Range { lower, upper } => {
Ok(pattern_core::Value::VRange(pattern_core::RangeValue {
lower: Some(lower as f64),
upper: Some(upper as f64),
}))
}
Value::TaggedString { tag, content } => {
Ok(pattern_core::Value::VTaggedString { tag, content })
}
}
}
fn value_to_pattern_value(v: Value) -> Result<pattern_core::Value, ParseError> {
match v {
Value::String(s) => Ok(pattern_core::Value::VString(s)),
Value::Integer(i) => Ok(pattern_core::Value::VInteger(i)),
Value::Decimal(f) => Ok(pattern_core::Value::VDecimal(f)),
Value::Boolean(b) => Ok(pattern_core::Value::VBoolean(b)),
Value::Array(arr) => {
let converted: Result<Vec<_>, _> =
arr.into_iter().map(value_to_pattern_value).collect();
Ok(pattern_core::Value::VArray(converted?))
}
Value::Range { lower, upper } => {
Ok(pattern_core::Value::VRange(pattern_core::RangeValue {
lower: Some(lower as f64),
upper: Some(upper as f64),
}))
}
Value::TaggedString { tag, content } => {
Ok(pattern_core::Value::VTaggedString { tag, content })
}
}
}
/// Orders a relationship's two operands according to its arrow kind.
///
/// Returns `(right, left)` when `kind_node` is a `left_arrow`. Every other kind
/// (`right_arrow`, `bidirectional_arrow`, `undirected_arrow`, or anything unrecognized)
/// keeps the source order `(left, right)`.
fn handle_arrow_type(
    kind_node: &tree_sitter::Node,
    left: Pattern<Subject>,
    right: Pattern<Subject>,
) -> (Pattern<Subject>, Pattern<Subject>) {
    if kind_node.kind() == "left_arrow" {
        (right, left)
    } else {
        (left, right)
    }
}
fn extract_edge_subject(node: &tree_sitter::Node, input: &str) -> Result<Subject, ParseError> {
let kind_node = node
.child_by_field_name("kind")
.ok_or_else(|| ParseError::missing_field(node, "kind"))?;
let identifier = if let Some(id_node) = kind_node.child_by_field_name("identifier") {
extract_identifier(&id_node, input)?
} else {
String::new()
};
let labels = if let Some(labels_node) = kind_node.child_by_field_name("labels") {
extract_labels(&labels_node, input)?
} else {
HashSet::new()
};
let properties = if let Some(record_node) = kind_node.child_by_field_name("record") {
transform_record(&record_node, input)?
} else {
HashMap::new()
};
Ok(Subject {
identity: Symbol(identifier),
labels,
properties,
})
}
fn extract_annotation_subject(
node: &tree_sitter::Node,
input: &str,
) -> Result<Subject, ParseError> {
let mut properties = HashMap::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "annotation" {
let key_node = child.child_by_field_name("key").ok_or_else(|| {
ParseError::from_node(&child, "Annotation missing key field".to_string())
})?;
let key = key_node
.utf8_text(input.as_bytes())
.map_err(|e| ParseError::from_node(&key_node, format!("UTF-8 error: {}", e)))?
.to_string();
if let Some(value_node) = child.child_by_field_name("value") {
let value = transform_value_to_pattern_value(&value_node, input)?;
properties.insert(key, value);
}
}
}
Ok(Subject {
identity: Symbol(String::new()), labels: HashSet::new(), properties, })
}
/// Returns the source text of `node`, removing one pair of enclosing double quotes if present.
///
/// Quotes are stripped only when the text both starts and ends with `"` and has at least
/// two characters. No escape processing is done on the inner text.
///
/// # Errors
///
/// Returns a [`ParseError`] when the node's text is not valid UTF-8.
fn extract_identifier(node: &tree_sitter::Node, input: &str) -> Result<String, ParseError> {
    let raw = node
        .utf8_text(input.as_bytes())
        .map_err(|e| ParseError::from_node(node, format!("UTF-8 error: {}", e)))?;
    // A lone `"` fails the suffix strip (nothing left after the prefix) and is kept verbatim.
    let unquoted = raw
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(raw);
    Ok(unquoted.to_owned())
}
fn extract_labels(node: &tree_sitter::Node, input: &str) -> Result<HashSet<String>, ParseError> {
let mut labels = HashSet::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "symbol" {
let label = child
.utf8_text(input.as_bytes())
.map_err(|e| ParseError::from_node(&child, format!("UTF-8 error: {}", e)))?
.to_string();
labels.insert(label);
}
}
Ok(labels)
}