use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
use tree_sitter::{Node, Parser, Tree};
/// Where in an Elasticsearch DSL document a completion request is located.
///
/// Produced by `DslParser::analyze_completion_context`, which walks from a
/// syntax node up through its ancestors.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DslCompletionContext {
    /// Inside an object that has no parent or sits directly under the
    /// `document` root.
    TopLevel,
    /// Inside the object value of a `"query"` key.
    QueryObject,
    /// Inside the object value of an `"aggs"` or `"aggregations"` key.
    AggsObject,
    /// Inside the object value of a `"bool"` key.
    BoolQuery,
    /// Inside the object value of a `"sort"` key.
    SortObject,
    /// No more specific context was found while walking up the tree.
    Default,
}
/// Bundle describing the outcome of parsing one DSL document.
///
/// NOTE(review): this type is not constructed anywhere in this file, so the
/// field descriptions below are inferred from names and types — confirm
/// against the code that builds it.
#[derive(Clone, Debug)]
pub struct DslParseResult {
    /// Syntax tree of the document; presumably `None` when no tree could be
    /// produced.
    pub tree: Option<Tree>,
    /// Diagnostics associated with the document.
    pub diagnostics: Vec<Diagnostic>,
    /// Presumably `true` when parsing succeeded — TODO confirm the exact
    /// criterion with the producer.
    pub success: bool,
    /// Presumably the text that was parsed.
    pub source: String,
}
/// Parser and validator for Elasticsearch DSL request bodies, built on a
/// tree-sitter parser configured for JSON.
pub struct DslParser {
    /// Underlying tree-sitter parser; `DslParser::new` sets its language to JSON.
    parser: Parser,
    /// Copy of the text most recently passed to `DslParser::parse`.
    source: String,
}
impl DslParser {
pub fn new() -> Self {
let language = tree_sitter::Language::from(tree_sitter_json::LANGUAGE);
let mut parser = Parser::new();
parser
.set_language(&language)
.expect("Failed to set JSON language");
Self {
parser,
source: String::new(),
}
}
pub fn parse(&mut self, dsl: &str) -> Vec<Diagnostic> {
self.source = dsl.to_string();
let (_, diagnostics) = self.parse_with_tree(dsl);
diagnostics
}
pub fn parse_with_tree(&mut self, dsl: &str) -> (Option<Tree>, Vec<Diagnostic>) {
let tree = self.parser.parse(dsl, None);
let mut diagnostics = Vec::new();
if let Some(tree) = &tree {
self.collect_errors(tree.root_node(), dsl, &mut diagnostics);
} else {
diagnostics.push(Diagnostic {
range: Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 0,
character: dsl.len() as u32,
},
},
severity: Some(DiagnosticSeverity::ERROR),
code: Some(NumberOrString::String("DSL_PARSE_ERROR".to_string())),
code_description: None,
source: Some("tree-sitter-json".to_string()),
message: "Failed to parse JSON".to_string(),
related_information: None,
tags: None,
data: None,
});
}
if diagnostics
.iter()
.all(|d| d.severity != Some(DiagnosticSeverity::ERROR))
{
self.validate_dsl_structure(tree.as_ref(), dsl, &mut diagnostics);
}
(tree, diagnostics)
}
#[allow(clippy::only_used_in_recursion)]
fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
if node.is_error() || node.is_missing() {
let start_byte = node.start_byte();
let end_byte = node.end_byte();
let start_point = node.start_position();
let end_point = node.end_position();
let node_text = if start_byte < source.len() && end_byte <= source.len() {
&source[start_byte..end_byte]
} else {
""
};
if node_text.trim().is_empty() && !node.is_missing() {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.collect_errors(child, source, diagnostics);
}
return;
}
diagnostics.push(Diagnostic {
range: Range {
start: Position {
line: start_point.row as u32,
character: start_point.column as u32,
},
end: Position {
line: end_point.row as u32,
character: end_point.column as u32,
},
},
severity: Some(if node.is_error() {
DiagnosticSeverity::ERROR
} else {
DiagnosticSeverity::WARNING
}),
code: Some(NumberOrString::String("DSL_SYNTAX_ERROR".to_string())),
code_description: None,
source: Some("tree-sitter-json".to_string()),
message: if node.is_error() {
format!("JSON syntax error: {}", node_text)
} else {
"Missing JSON element".to_string()
},
related_information: None,
tags: None,
data: None,
});
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.collect_errors(child, source, diagnostics);
}
}
fn validate_dsl_structure(
&self,
tree: Option<&Tree>,
json: &str,
diagnostics: &mut Vec<Diagnostic>,
) {
if let Some(tree) = tree {
let has_query = json.contains("\"query\"") || json.contains("'query'");
let has_aggs = json.contains("\"aggs\"") || json.contains("\"aggregations\"");
let has_sort = json.contains("\"sort\"");
if !has_query && !has_aggs && !has_sort {
diagnostics.push(Diagnostic {
range: Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 0,
character: json.len() as u32,
},
},
severity: Some(DiagnosticSeverity::HINT),
code: Some(NumberOrString::String("DSL_HINT".to_string())),
code_description: None,
source: Some("elasticsearch-dsl".to_string()),
message:
"Elasticsearch DSL typically includes 'query', 'aggs', or 'sort' fields"
.to_string(),
related_information: None,
tags: None,
data: None,
});
}
self.validate_query_structure(tree, json, diagnostics);
}
}
fn validate_query_structure(&self, tree: &Tree, json: &str, diagnostics: &mut Vec<Diagnostic>) {
let root = tree.root_node();
let valid_query_types = vec![
"match",
"match_all",
"match_none",
"match_phrase",
"match_phrase_prefix",
"multi_match",
"common",
"query_string",
"simple_query_string",
"term",
"terms",
"range",
"exists",
"prefix",
"wildcard",
"regexp",
"fuzzy",
"type",
"ids",
"constant_score",
"bool",
"boosting",
"dis_max",
"function_score",
"script_score",
"percolate",
];
if let Some(query_node) = self.find_field_in_object(root, json, "query") {
let query_value = self.get_node_text(query_node, json);
if query_node.kind() == "object" {
let mut found_valid_query = false;
self.check_query_types_recursive(
query_node,
json,
&valid_query_types,
&mut found_valid_query,
);
if !found_valid_query {
let range = self.node_range(query_node);
diagnostics.push(Diagnostic {
range,
severity: Some(DiagnosticSeverity::WARNING),
code: Some(NumberOrString::String("DSL_QUERY_TYPE".to_string())),
code_description: None,
source: Some("elasticsearch-dsl".to_string()),
message: "Query object should contain a valid query type (match, term, bool, etc.)".to_string(),
related_information: None,
tags: None,
data: None,
});
}
} else if query_value.trim().is_empty() {
let range = self.node_range(query_node);
diagnostics.push(Diagnostic {
range,
severity: Some(DiagnosticSeverity::WARNING),
code: Some(NumberOrString::String("DSL_EMPTY_QUERY".to_string())),
code_description: None,
source: Some("elasticsearch-dsl".to_string()),
message: "Query field should not be empty".to_string(),
related_information: None,
tags: None,
data: None,
});
}
}
}
fn find_field_in_object<'a>(
&self,
object_node: Node<'a>,
source: &str,
field_name: &str,
) -> Option<Node<'a>> {
if object_node.kind() != "object" {
return None;
}
let mut cursor = object_node.walk();
for child in object_node.children(&mut cursor) {
if child.kind() == "pair" {
if let Some(key_node) = child.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
if key == field_name {
return child.child(1);
}
}
}
}
}
None
}
#[allow(clippy::only_used_in_recursion)]
fn check_query_types_recursive<'a>(
&self,
node: Node<'a>,
source: &str,
valid_types: &[&str],
found: &mut bool,
) {
if *found {
return;
}
let node_kind = node.kind();
if node_kind == "pair" {
if let Some(key_node) = node.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
if valid_types.contains(&key) {
*found = true;
return;
}
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.check_query_types_recursive(child, source, valid_types, found);
}
}
fn get_node_text<'a>(&self, node: Node<'a>, source: &str) -> String {
node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
}
pub fn extract_fields(&self, tree: &Tree, source: &str) -> Vec<String> {
let mut fields = Vec::new();
self.extract_fields_recursive(tree.root_node(), source, &mut fields);
fields
}
#[allow(clippy::only_used_in_recursion)]
fn extract_fields_recursive<'a>(&self, node: Node<'a>, source: &str, fields: &mut Vec<String>) {
let node_kind = node.kind();
if node_kind == "pair" {
if let Some(key_node) = node.child(0) {
if key_node.kind() == "string" {
if let Ok(text) = key_node.utf8_text(source.as_bytes()) {
let field_name = text.trim_matches('"').trim_matches('\'');
if !field_name.is_empty() && !fields.contains(&field_name.to_string()) {
fields.push(field_name.to_string());
}
}
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.extract_fields_recursive(child, source, fields);
}
}
pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
let root = tree.root_node();
let point = tree_sitter::Point {
row: position.line as usize,
column: position.character as usize,
};
root.descendant_for_point_range(point, point)
}
pub fn node_text(&self, node: Node, source: &str) -> String {
node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
}
pub fn node_range(&self, node: Node) -> Range {
let start = node.start_position();
let end = node.end_position();
Range {
start: Position {
line: start.row as u32,
character: start.column as u32,
},
end: Position {
line: end.row as u32,
character: end.column as u32,
},
}
}
pub fn analyze_completion_context(&self, node: Node, source: &str) -> DslCompletionContext {
let mut current = Some(node);
while let Some(n) = current {
let kind = n.kind();
if kind == "pair" {
if let Some(key_node) = n.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
if let Some(value_node) = n.child(1) {
if value_node.kind() == "object" {
match key {
"query" => return DslCompletionContext::QueryObject,
"aggs" | "aggregations" => {
return DslCompletionContext::AggsObject
}
"bool" => return DslCompletionContext::BoolQuery,
"sort" => return DslCompletionContext::SortObject,
_ => {}
}
}
}
}
}
}
if kind == "object" {
if let Some(parent) = n.parent() {
if parent.kind() == "pair" {
if let Some(key_node) = parent.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
match key {
"query" => return DslCompletionContext::QueryObject,
"aggs" | "aggregations" => {
return DslCompletionContext::AggsObject
}
"bool" => return DslCompletionContext::BoolQuery,
"sort" => return DslCompletionContext::SortObject,
_ => {}
}
}
}
}
}
if n.parent().is_none()
|| (n.parent().is_some() && n.parent().unwrap().kind() == "document")
{
return DslCompletionContext::TopLevel;
}
}
current = n.parent();
}
DslCompletionContext::Default
}
pub fn is_in_field_object(&self, node: Node, source: &str, field_name: &str) -> bool {
let mut current = Some(node);
while let Some(n) = current {
if n.kind() == "object" {
if let Some(parent) = n.parent() {
if parent.kind() == "pair" {
if let Some(key_node) = parent.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
if key == field_name {
return true;
}
}
}
}
}
}
current = n.parent();
}
false
}
pub fn extract_field_name(&self, node: Node, source: &str) -> Option<String> {
if node.kind() == "pair" {
if let Some(key_node) = node.child(0) {
if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
let key = key_text.trim_matches('"').trim_matches('\'');
return Some(key.to_string());
}
}
}
if node.kind() == "string" {
if let Ok(text) = node.utf8_text(source.as_bytes()) {
let key = text.trim_matches('"').trim_matches('\'');
if let Some(parent) = node.parent() {
if parent.kind() == "pair" && parent.child(0) == Some(node) {
return Some(key.to_string());
}
}
}
}
None
}
}
impl Default for DslParser {
fn default() -> Self {
Self::new()
}
}