use std::collections::BTreeMap;
use crate::core::{
Document, ElementData, ErrorKind, NamespaceUri, NodeId, NodeKind, QName, Span, XmlError,
XmlResult,
};
use crate::security::QuerySecurityConfig;
/// A parsed path query that can be evaluated against any number of documents.
///
/// Build one with [`Query::parse`]; the original text stays available through
/// [`Query::source`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query {
    /// Compiled steps; the parser always places `QueryStep::Root` first.
    steps: Vec<QueryStep>,
    /// The query text exactly as given to [`Query::parse`].
    source: String,
}
impl Query {
    /// Compiles `source` into a reusable [`Query`].
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Query` error (with a span into `source`) when the
    /// text is not a valid query.
    pub fn parse(source: &str) -> XmlResult<Self> {
        let parser = Parser::new(source);
        parser.parse()
    }

    /// Evaluates the query with no namespace aliases and the default security
    /// configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Query::evaluate_with_options`].
    pub fn evaluate(&self, document: &Document) -> XmlResult<QueryResult> {
        let namespaces = NamespaceContext::default();
        self.evaluate_with_context(document, &namespaces)
    }

    /// Evaluates the query, resolving prefixed names through `namespaces`, with
    /// the default security configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Query::evaluate_with_options`].
    pub fn evaluate_with_context(
        &self,
        document: &Document,
        namespaces: &NamespaceContext,
    ) -> XmlResult<QueryResult> {
        let security = QuerySecurityConfig::default();
        self.evaluate_with_options(document, namespaces, &security)
    }

    /// Evaluates the query with explicit namespace aliases and security limits.
    ///
    /// A document without a root yields an empty [`QueryResult`].
    ///
    /// # Errors
    ///
    /// Fails when the query uses an alias missing from `namespaces`
    /// (`ErrorKind::UnknownNamespacePrefix`), when `security` rejects the
    /// number of evaluation steps, or when a node lookup in `document` fails.
    pub fn evaluate_with_options(
        &self,
        document: &Document,
        namespaces: &NamespaceContext,
        security: &QuerySecurityConfig,
    ) -> XmlResult<QueryResult> {
        match document.root() {
            Some(root) => Evaluator::new(document, namespaces, security).evaluate(root, &self.steps),
            None => Ok(QueryResult::default()),
        }
    }

    /// The query text this query was parsed from.
    pub fn source(&self) -> &str {
        self.source.as_str()
    }
}
/// The values selected by evaluating a [`Query`], in the order evaluation
/// produced them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct QueryResult {
    /// Selected nodes, text contents and attributes, possibly mixed.
    values: Vec<QueryValue>,
}
impl QueryResult {
    /// All selected values, in evaluation order.
    pub fn values(&self) -> &[QueryValue] {
        self.values.as_slice()
    }

    /// Ids of the selected nodes; text and attribute values are skipped.
    pub fn nodes(&self) -> Vec<NodeId> {
        self.values
            .iter()
            .filter_map(|value| {
                if let QueryValue::Node(id) = value {
                    Some(*id)
                } else {
                    None
                }
            })
            .collect()
    }

    /// String content of the selected text and attribute values; node values
    /// are skipped.
    pub fn strings(&self) -> Vec<&str> {
        self.values
            .iter()
            .filter_map(|value| match value {
                QueryValue::Node(_) => None,
                QueryValue::Text(text) => Some(text.as_str()),
                QueryValue::Attribute {
                    value: attribute_value,
                    ..
                } => Some(attribute_value.as_str()),
            })
            .collect()
    }

    /// Number of selected values of any kind.
    pub fn len(&self) -> usize {
        self.values().len()
    }

    /// `true` when nothing was selected.
    pub fn is_empty(&self) -> bool {
        self.values().is_empty()
    }
}
/// A single value selected by a query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryValue {
    /// A node selected by an element step (or the root itself).
    Node(NodeId),
    /// The content of a text node selected by a `text()` step.
    Text(String),
    /// An attribute selected by an `@name` step, with an owned copy of its value.
    Attribute { name: QName, value: String },
}
/// Maps query-side aliases to namespace URIs.
///
/// Aliases are independent of the prefixes written in the document: a query
/// name `d:Item` matches any `Item` whose namespace URI equals the URI bound to
/// `d`. Unprefixed query names match only names with no prefix and no
/// namespace URI, so elements in a default namespace must be queried through an
/// alias.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NamespaceContext {
    /// Alias -> namespace URI bindings.
    aliases: BTreeMap<String, NamespaceUri>,
}
impl NamespaceContext {
pub fn new() -> Self {
Self::default()
}
pub fn with_alias(
mut self,
alias: impl Into<String>,
uri: impl Into<String>,
) -> XmlResult<Self> {
self.aliases.insert(alias.into(), NamespaceUri::new(uri)?);
Ok(self)
}
pub fn resolve(&self, alias: &str) -> Option<&NamespaceUri> {
self.aliases.get(alias)
}
}
/// Convenience methods that parse and evaluate a query in a single call.
///
/// Prefer [`Query::parse`] plus [`Query::evaluate`] when the same query runs
/// more than once, so it is parsed only once.
pub trait DocumentQueryExt {
    /// Parses `source` and evaluates it against `self` without namespace aliases.
    fn query(&self, source: &str) -> XmlResult<QueryResult>;
    /// Parses `source` and evaluates it against `self`, resolving prefixed
    /// names through `namespaces`.
    fn query_with_context(
        &self,
        source: &str,
        namespaces: &NamespaceContext,
    ) -> XmlResult<QueryResult>;
}
impl DocumentQueryExt for Document {
    /// Parses `source`, then evaluates it with `Query::evaluate`.
    fn query(&self, source: &str) -> XmlResult<QueryResult> {
        Query::parse(source).and_then(|compiled| compiled.evaluate(self))
    }

    /// Parses `source`, then evaluates it with `Query::evaluate_with_context`.
    fn query_with_context(
        &self,
        source: &str,
        namespaces: &NamespaceContext,
    ) -> XmlResult<QueryResult> {
        Query::parse(source).and_then(|compiled| compiled.evaluate_with_context(self, namespaces))
    }
}
/// One compiled step of a [`Query`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryStep {
    /// Implicit starting point; the parser always emits it first.
    Root,
    /// `/name[...]`: child elements matching the test.
    Child(NodeTest),
    /// `//name[...]`: descendant elements matching the test.
    Descendant(NodeTest),
    /// `@name`: attributes of the current elements matching the name test.
    Attribute(NameTest),
    /// `text()`: text-node children of the current nodes.
    Text,
}
/// Element test used by child and descendant steps.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NodeTest {
    /// Required element name.
    name: NameTest,
    /// Optional `[@attr='value']` filter.
    predicate: Option<Predicate>,
}
/// `[@attribute='value']` filter.
///
/// An element passes when at least one of its attributes matches `attribute`
/// and has exactly this `value`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Predicate {
    attribute: NameTest,
    value: String,
}
/// A name as written in the query, split at its single `:`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NameTest {
    /// Query alias resolved through [`NamespaceContext`], if present.
    prefix: Option<String>,
    /// Local part that must equal the node's local name.
    local: String,
}
/// Lexical token categories produced by `lex`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TokenKind {
    /// `/`
    Slash,
    /// `//`
    DoubleSlash,
    /// `@`
    At,
    /// `[`
    LBracket,
    /// `]`
    RBracket,
    /// `=`
    Eq,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// A run of name bytes, possibly containing `:`.
    Name(String),
    /// Contents of a quoted string literal, without the quotes.
    String(String),
}
/// A token together with where it starts in the query.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Token {
    kind: TokenKind,
    /// Byte offset of the token's first byte in the query source.
    position: usize,
}
/// Splits a query string into tokens.
///
/// The recognised tokens are:
/// - `/`, `//`, `@`, `[`, `]`, `=`, `(` and `)`;
/// - quoted string literals, using single or double quotes with no escapes;
/// - names made of bytes accepted by `is_name_byte`.
///
/// ASCII whitespace separates tokens and is otherwise ignored. Each token
/// records the byte offset of its first byte.
///
/// # Errors
///
/// Returns a query error for an unterminated string literal or for any
/// character that cannot start a token.
fn lex(source: &str) -> XmlResult<Vec<Token>> {
    let bytes = source.as_bytes();
    let mut tokens = Vec::new();
    let mut position = 0;
    while position < bytes.len() {
        let start = position;
        let byte = bytes[start];
        // Each arm yields the token kind and the number of bytes it spans.
        let (kind, width) = match byte {
            b'/' if bytes.get(start + 1) == Some(&b'/') => (TokenKind::DoubleSlash, 2),
            b'/' => (TokenKind::Slash, 1),
            b'@' => (TokenKind::At, 1),
            b'[' => (TokenKind::LBracket, 1),
            b']' => (TokenKind::RBracket, 1),
            b'=' => (TokenKind::Eq, 1),
            b'(' => (TokenKind::LParen, 1),
            b')' => (TokenKind::RParen, 1),
            b'\'' | b'"' => {
                // The opening quote is ASCII, so `start + 1` is a char boundary.
                let body = &source[start + 1..];
                let Some(length) = body.bytes().position(|candidate| candidate == byte) else {
                    return Err(query_error(source, start, "unterminated string literal"));
                };
                // Opening quote + contents + closing quote.
                (TokenKind::String(body[..length].to_owned()), length + 2)
            }
            _ if byte.is_ascii_whitespace() => {
                position += 1;
                continue;
            }
            _ => {
                let length = bytes[start..]
                    .iter()
                    .take_while(|&&candidate| is_name_byte(candidate))
                    .count();
                if length == 0 {
                    // Decode the whole character instead of casting one byte
                    // (`u8 as char` mangles multi-byte UTF-8).
                    // `start` is a char boundary: every token consumed so far ends
                    // on an ASCII byte.
                    let unexpected = source[start..]
                        .chars()
                        .next()
                        .unwrap_or(char::REPLACEMENT_CHARACTER);
                    return Err(query_error(
                        source,
                        start,
                        format!("unexpected character `{unexpected}`"),
                    ));
                }
                (TokenKind::Name(source[start..start + length].to_owned()), length)
            }
        };
        tokens.push(Token {
            kind,
            position: start,
        });
        position = start + width;
    }
    Ok(tokens)
}
/// Returns `true` for bytes allowed inside a query name: ASCII letters and
/// digits, plus `_`, `-`, `.` and the prefix separator `:`.
fn is_name_byte(byte: u8) -> bool {
    matches!(
        byte,
        b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-' | b'.' | b':'
    )
}
/// Hand-written parser that turns the query's tokens into [`Query`] steps.
struct Parser<'a> {
    /// Original query text; used for error spans and stored in the [`Query`].
    source: &'a str,
    /// Tokens produced by `lex` from `source`.
    tokens: Vec<Token>,
    /// Index of the next unconsumed token in `tokens`.
    position: usize,
}
impl<'a> Parser<'a> {
    /// Creates a parser for `source`.
    ///
    /// Tokenisation is deferred to [`Parser::parse`]. Lexing here as well would
    /// tokenise the source twice and silently discard any lexer error.
    fn new(source: &'a str) -> Self {
        Self {
            source,
            tokens: Vec::new(),
            position: 0,
        }
    }

    /// Parses the complete query.
    ///
    /// Accepted grammar (`lex` drops whitespace between tokens):
    ///
    /// ```text
    /// query     := axis step (axis step)*
    /// axis      := '/' | '//'
    /// step      := '@' name              (only after '/')
    ///            | 'text' '(' ')'        (only after '/')
    ///            | name predicate?
    /// predicate := '[' '@' name '=' string ']'
    /// ```
    ///
    /// The returned steps always begin with `QueryStep::Root`.
    ///
    /// # Errors
    ///
    /// Every failure is a query error carrying a span: lexer errors, an empty
    /// query, a missing leading `/`, or a malformed step.
    fn parse(mut self) -> XmlResult<Query> {
        self.tokens = lex(self.source)?;
        if self.tokens.is_empty() {
            return Err(query_error(self.source, 0, "query cannot be empty"));
        }
        self.expect_slash_like_start()?;
        let mut steps = vec![QueryStep::Root];
        while !self.is_eof() {
            let axis = self.consume_axis()?;
            let step = self.parse_step(axis)?;
            steps.push(step);
        }
        Ok(Query {
            steps,
            source: self.source.to_owned(),
        })
    }

    /// Checks, without consuming it, that the first token is `/` or `//`.
    fn expect_slash_like_start(&mut self) -> XmlResult<()> {
        match self.peek_kind() {
            Some(TokenKind::Slash | TokenKind::DoubleSlash) => Ok(()),
            _ => Err(query_error(
                self.source,
                self.peek_position(),
                "query must start with `/` or `//`",
            )),
        }
    }

    /// Consumes the `/` or `//` that precedes a step.
    ///
    /// The error position is taken before consuming, so the span points at the
    /// offending token rather than at the one after it.
    fn consume_axis(&mut self) -> XmlResult<Axis> {
        let position = self.peek_position();
        match self.next_kind() {
            Some(TokenKind::Slash) => Ok(Axis::Child),
            Some(TokenKind::DoubleSlash) => Ok(Axis::Descendant),
            _ => Err(query_error(self.source, position, "expected `/` or `//`")),
        }
    }

    /// Parses one step that follows `axis`.
    ///
    /// `@name` and `text()` are only accepted after `/`. The evaluator applies
    /// these steps to the current context nodes only. Accepting them after
    /// `//` would silently evaluate `//@name` as `/@name` and `//text()` as
    /// `/text()`.
    fn parse_step(&mut self, axis: Axis) -> XmlResult<QueryStep> {
        let step_position = self.peek_position();
        if self.consume_at() {
            if axis == Axis::Descendant {
                return Err(query_error(
                    self.source,
                    step_position,
                    "attribute steps cannot follow `//`; use `//element/@name`",
                ));
            }
            let name = self.parse_name()?;
            return Ok(QueryStep::Attribute(name));
        }
        let name = self.parse_name()?;
        // `text` without parentheses is an ordinary element name.
        if name.prefix.is_none() && name.local == "text" && self.consume_lparen() {
            self.expect_rparen()?;
            if axis == Axis::Descendant {
                return Err(query_error(
                    self.source,
                    step_position,
                    "`text()` cannot follow `//`; use `//element/text()`",
                ));
            }
            return Ok(QueryStep::Text);
        }
        let predicate = if self.consume_lbracket() {
            Some(self.parse_predicate()?)
        } else {
            None
        };
        let test = NodeTest { name, predicate };
        Ok(match axis {
            Axis::Child => QueryStep::Child(test),
            Axis::Descendant => QueryStep::Descendant(test),
        })
    }

    /// Parses `@name='value']`; the opening `[` has already been consumed.
    fn parse_predicate(&mut self) -> XmlResult<Predicate> {
        if !self.consume_at() {
            return Err(query_error(
                self.source,
                self.peek_position(),
                "predicate must select an attribute with `@`",
            ));
        }
        let attribute = self.parse_name()?;
        self.expect_eq()?;
        let value = self.parse_string()?;
        self.expect_rbracket()?;
        Ok(Predicate { attribute, value })
    }

    /// Consumes a name token and splits it into an optional prefix and a
    /// local part.
    fn parse_name(&mut self) -> XmlResult<NameTest> {
        match self.next() {
            Some(Token {
                kind: TokenKind::Name(name),
                position,
            }) => name_test(self.source, position, &name),
            Some(token) => Err(query_error(
                self.source,
                token.position,
                "expected XML name in query step",
            )),
            None => Err(query_error(
                self.source,
                self.source.len(),
                "expected XML name in query step",
            )),
        }
    }

    /// Consumes a string-literal token and returns its contents.
    fn parse_string(&mut self) -> XmlResult<String> {
        match self.next() {
            Some(Token {
                kind: TokenKind::String(value),
                ..
            }) => Ok(value),
            Some(token) => Err(query_error(
                self.source,
                token.position,
                "expected string literal",
            )),
            None => Err(query_error(
                self.source,
                self.source.len(),
                "expected string literal",
            )),
        }
    }

    /// Consumes `@` if it is next.
    fn consume_at(&mut self) -> bool {
        self.consume(|kind| matches!(kind, TokenKind::At))
    }

    /// Consumes `(` if it is next.
    fn consume_lparen(&mut self) -> bool {
        self.consume(|kind| matches!(kind, TokenKind::LParen))
    }

    /// Consumes `[` if it is next.
    fn consume_lbracket(&mut self) -> bool {
        self.consume(|kind| matches!(kind, TokenKind::LBracket))
    }

    /// Requires and consumes `)`.
    fn expect_rparen(&mut self) -> XmlResult<()> {
        self.expect(|kind| matches!(kind, TokenKind::RParen), "expected `)`")
    }

    /// Requires and consumes `]`.
    fn expect_rbracket(&mut self) -> XmlResult<()> {
        self.expect(|kind| matches!(kind, TokenKind::RBracket), "expected `]`")
    }

    /// Requires and consumes `=`.
    fn expect_eq(&mut self) -> XmlResult<()> {
        self.expect(|kind| matches!(kind, TokenKind::Eq), "expected `=`")
    }

    /// Consumes the next token and fails with `message` unless it satisfies
    /// `matches`. At end of input, the error points at the end of the source.
    fn expect(&mut self, matches: impl FnOnce(&TokenKind) -> bool, message: &str) -> XmlResult<()> {
        match self.next() {
            Some(token) if matches(&token.kind) => Ok(()),
            Some(token) => Err(query_error(self.source, token.position, message)),
            None => Err(query_error(self.source, self.source.len(), message)),
        }
    }

    /// Consumes the next token only if it satisfies `matches`.
    fn consume(&mut self, matches: impl FnOnce(&TokenKind) -> bool) -> bool {
        if self
            .tokens
            .get(self.position)
            .is_some_and(|token| matches(&token.kind))
        {
            self.position += 1;
            true
        } else {
            false
        }
    }

    /// Consumes the next token and returns only its kind.
    fn next_kind(&mut self) -> Option<TokenKind> {
        self.next().map(|token| token.kind)
    }

    /// Consumes and returns the next token, or `None` at end of input.
    fn next(&mut self) -> Option<Token> {
        let token = self.tokens.get(self.position).cloned();
        if token.is_some() {
            self.position += 1;
        }
        token
    }

    /// Kind of the next token without consuming it.
    fn peek_kind(&self) -> Option<&TokenKind> {
        self.tokens.get(self.position).map(|token| &token.kind)
    }

    /// Byte offset of the next token, or the source length at end of input.
    fn peek_position(&self) -> usize {
        self.tokens
            .get(self.position)
            .map(|token| token.position)
            .unwrap_or(self.source.len())
    }

    /// `true` once every token has been consumed.
    fn is_eof(&self) -> bool {
        self.position >= self.tokens.len()
    }
}
/// Separator written before a step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Axis {
    /// `/`
    Child,
    /// `//`
    Descendant,
}
/// Splits a raw name token into an optional prefix and a local part.
///
/// The token must be either `local` or `prefix:local`, with non-empty parts and
/// at most one `:`. Any other shape is reported as "invalid qualified name" at
/// `position`.
fn name_test(source: &str, position: usize, raw: &str) -> XmlResult<NameTest> {
    match raw.split_once(':') {
        None if !raw.is_empty() => Ok(NameTest {
            prefix: None,
            local: raw.to_owned(),
        }),
        Some((prefix, local))
            if !prefix.is_empty() && !local.is_empty() && !local.contains(':') =>
        {
            Ok(NameTest {
                prefix: Some(prefix.to_owned()),
                local: local.to_owned(),
            })
        }
        _ => Err(query_error(source, position, "invalid qualified name")),
    }
}
/// State for a single evaluation of a [`Query`] against one document.
struct Evaluator<'a> {
    /// Document being queried.
    document: &'a Document,
    /// Aliases used to resolve prefixed name tests.
    namespaces: &'a NamespaceContext,
    /// Source of the step budget checked by `bump`.
    security: &'a QuerySecurityConfig,
    /// Units of work consumed so far in this evaluation.
    steps: usize,
}
impl<'a> Evaluator<'a> {
    /// Creates an evaluator whose step counter starts at zero.
    fn new(
        document: &'a Document,
        namespaces: &'a NamespaceContext,
        security: &'a QuerySecurityConfig,
    ) -> Self {
        Self {
            document,
            namespaces,
            security,
            steps: 0,
        }
    }
    /// Evaluates the compiled `steps`, starting at `root`.
    ///
    /// `steps[0]` (the leading `Root`) is skipped. `steps[1]` is resolved
    /// against `root` by `apply_first_step`. Each later step is then applied to
    /// the values produced by the step before it.
    fn evaluate(&mut self, root: NodeId, steps: &[QueryStep]) -> XmlResult<QueryResult> {
        let mut values = match steps.get(1) {
            Some(step) => self.apply_first_step(root, step)?,
            // Only the leading step exists: select the root itself.
            None => vec![QueryValue::Node(root)],
        };
        for step in steps.iter().skip(2) {
            values = self.apply_step(values, step)?;
        }
        Ok(QueryResult { values })
    }
    /// Applies `step` to every node in `values`. Results are concatenated in
    /// the order of the context nodes.
    ///
    /// Text and attribute values are skipped; only nodes act as context.
    ///
    /// NOTE(review): results are not de-duplicated. Suppose a `Descendant` step
    /// runs over context nodes where one contains another (e.g. `//a//b` with an
    /// `a` nested inside another `a`). The shared descendants are then returned
    /// once per enclosing context.
    fn apply_step(
        &mut self,
        values: Vec<QueryValue>,
        step: &QueryStep,
    ) -> XmlResult<Vec<QueryValue>> {
        let mut next = Vec::new();
        for value in values {
            let QueryValue::Node(node_id) = value else {
                continue;
            };
            match step {
                QueryStep::Root => next.push(QueryValue::Node(node_id)),
                QueryStep::Child(test) => {
                    for child in self.element_children(node_id)? {
                        if self.node_matches(child, test)? {
                            next.push(QueryValue::Node(child));
                        }
                    }
                }
                // `descendants` excludes the context node itself.
                QueryStep::Descendant(test) => {
                    for descendant in self.descendants(node_id)? {
                        if self.node_matches(descendant, test)? {
                            next.push(QueryValue::Node(descendant));
                        }
                    }
                }
                // Non-element context nodes contribute nothing.
                QueryStep::Attribute(name) => {
                    if let Some(element) = self.element(node_id)? {
                        for attribute in element.attributes() {
                            if self.name_matches(attribute.name(), name)? {
                                next.push(QueryValue::Attribute {
                                    name: attribute.name().clone(),
                                    value: attribute.value().to_owned(),
                                });
                            }
                        }
                    }
                }
                // `element_children` lists every child node, so text nodes are
                // picked out of it here.
                QueryStep::Text => {
                    for child in self.element_children(node_id)? {
                        if let NodeKind::Text(value) = self.document.node(child)?.kind() {
                            next.push(QueryValue::Text(value.clone()));
                        }
                    }
                }
            }
        }
        Ok(next)
    }
    /// Resolves the first real step from above `root`.
    ///
    /// How each step kind is handled here:
    /// - `Child`: selects `root` only if `root` itself matches the test.
    /// - `Descendant`: tests `root` first, then all of its descendants.
    /// - Attribute and `text()`: select nothing at this level.
    fn apply_first_step(&mut self, root: NodeId, step: &QueryStep) -> XmlResult<Vec<QueryValue>> {
        Ok(match step {
            QueryStep::Root => vec![QueryValue::Node(root)],
            QueryStep::Child(test) if self.node_matches(root, test)? => {
                vec![QueryValue::Node(root)]
            }
            QueryStep::Descendant(test) => {
                let mut matches = Vec::new();
                if self.node_matches(root, test)? {
                    matches.push(QueryValue::Node(root));
                }
                matches.extend(
                    self.descendants(root)?
                        .into_iter()
                        // Errors stay in the stream so `collect` stops at the first one.
                        .filter_map(|node| match self.node_matches(node, test) {
                            Ok(true) => Some(Ok(QueryValue::Node(node))),
                            Ok(false) => None,
                            Err(error) => Some(Err(error)),
                        })
                        .collect::<XmlResult<Vec<_>>>()?,
                );
                matches
            }
            QueryStep::Attribute(_) | QueryStep::Text => Vec::new(),
            // Reached when the guarded `Child` arm above rejected `root`.
            QueryStep::Child(_) => Vec::new(),
        })
    }
    /// Returns all child node ids of `node_id` (text nodes included) when it is
    /// an element; otherwise returns an empty list. Costs one step.
    fn element_children(&mut self, node_id: NodeId) -> XmlResult<Vec<NodeId>> {
        self.bump()?;
        match self.document.node(node_id)?.kind() {
            NodeKind::Element(element) => Ok(element.children().to_vec()),
            _ => Ok(Vec::new()),
        }
    }
    /// Collects every node below `node_id`, excluding `node_id` itself.
    ///
    /// Traversal is depth-first pre-order using an explicit stack. Children
    /// are pushed in reverse so they pop in their original order.
    ///
    /// Cost: each visited node is charged two steps, one by the `bump` below
    /// and one inside `element_children`. Reading the children of `node_id`
    /// costs one more.
    fn descendants(&mut self, node_id: NodeId) -> XmlResult<Vec<NodeId>> {
        let mut descendants = Vec::new();
        let mut stack = self.element_children(node_id)?;
        stack.reverse();
        while let Some(current) = stack.pop() {
            self.bump()?;
            descendants.push(current);
            let mut children = self.element_children(current)?;
            children.reverse();
            stack.extend(children);
        }
        Ok(descendants)
    }
    /// Returns `true` when `node_id` passes `test`. Costs one step.
    ///
    /// The node must be an element whose name matches `test.name`. If
    /// `test.predicate` is set, the element's attributes must also satisfy it.
    fn node_matches(&mut self, node_id: NodeId, test: &NodeTest) -> XmlResult<bool> {
        self.bump()?;
        let Some(element) = self.element(node_id)? else {
            return Ok(false);
        };
        if !self.name_matches(element.name(), &test.name)? {
            return Ok(false);
        }
        match &test.predicate {
            Some(predicate) => self.predicate_matches(element, predicate),
            None => Ok(true),
        }
    }
    /// Returns `true` when some attribute of `element` matches the predicate's
    /// name test and has exactly the predicate's value.
    fn predicate_matches(&self, element: &ElementData, predicate: &Predicate) -> XmlResult<bool> {
        for attribute in element.attributes() {
            if self.name_matches(attribute.name(), &predicate.attribute)?
                && attribute.value() == predicate.value
            {
                return Ok(true);
            }
        }
        Ok(false)
    }
    /// Compares a document name with a query name test.
    ///
    /// Local names must be equal. How the namespace is then checked:
    /// - Prefixed test: the alias is resolved through `self.namespaces` and
    ///   namespace URIs are compared. The document's own prefix is ignored.
    /// - Unprefixed test: matches only names with neither a prefix nor a
    ///   namespace URI.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnknownNamespacePrefix` when the alias is undeclared.
    /// This is only checked after the local names match, so an undeclared alias
    /// goes unreported if no node has that local name.
    fn name_matches(&self, name: &QName, test: &NameTest) -> XmlResult<bool> {
        if name.local() != test.local {
            return Ok(false);
        }
        match &test.prefix {
            Some(prefix) => {
                let uri = self.namespaces.resolve(prefix).ok_or_else(|| {
                    XmlError::new(
                        ErrorKind::UnknownNamespacePrefix,
                        format!("namespace alias `{prefix}` is not declared"),
                    )
                })?;
                Ok(name.namespace_uri().is_some_and(|name_uri| name_uri == uri))
            }
            None => Ok(name.prefix().is_none() && name.namespace_uri().is_none()),
        }
    }
    /// Returns the element data of `node_id`, or `None` for non-element nodes.
    fn element(&self, node_id: NodeId) -> XmlResult<Option<&ElementData>> {
        Ok(match self.document.node(node_id)?.kind() {
            NodeKind::Element(element) => Some(element),
            _ => None,
        })
    }
    /// Counts one unit of work. `security.check_steps` then decides whether
    /// the running total is still allowed.
    fn bump(&mut self) -> XmlResult<()> {
        self.steps += 1;
        self.security.check_steps(self.steps)
    }
}
/// Builds an `ErrorKind::Query` error whose span points at byte offset
/// `position` of the query `source`.
fn query_error(source: &str, position: usize, message: impl Into<String>) -> XmlError {
    let span = span_for_byte(source, position);
    XmlError::new(ErrorKind::Query, message).with_span(span)
}
/// Converts a byte offset in `source` into a 1-based line/column span.
///
/// Columns count characters, not bytes. A newline moves to the next line and
/// resets the column. Offsets past the end map to the position just after
/// the last character.
fn span_for_byte(source: &str, byte_position: usize) -> Span {
    let (line, column) = source
        .char_indices()
        .take_while(|&(index, _)| index < byte_position)
        .fold((1, 1), |(line, column), (_, ch)| {
            if ch == '\n' {
                (line + 1, 1)
            } else {
                (line, column + 1)
            }
        });
    Span::new(line, column)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Attribute, NamespaceDeclaration};
    use crate::parser;
    /// Builds the fixture shared by most tests, roughly equivalent to:
    ///
    /// ```text
    /// <doc:Root xmlns:doc="urn:doc">
    ///   <doc:Item code="A1" doc:kind="primary"><Name>Alpha</Name></doc:Item>
    ///   <doc:Item code="B2"><Name>Beta</Name></doc:Item>
    ///   <Note>Loose</Note>
    /// </doc:Root>
    /// ```
    fn sample_document() -> XmlResult<Document> {
        let mut document = Document::new();
        let root = document.add_root_element(QName::qualified("doc", "Root", "urn:doc")?)?;
        document
            .add_namespace_declaration(root, NamespaceDeclaration::prefixed("doc", "urn:doc")?)?;
        let first = document.add_element(root, QName::qualified("doc", "Item", "urn:doc")?)?;
        document.add_attribute(first, Attribute::new(QName::new("code")?, "A1"))?;
        document.add_attribute(
            first,
            Attribute::new(QName::qualified("doc", "kind", "urn:doc")?, "primary"),
        )?;
        let first_name = document.add_element(first, QName::new("Name")?)?;
        document.add_text(first_name, "Alpha")?;
        let second = document.add_element(root, QName::qualified("doc", "Item", "urn:doc")?)?;
        document.add_attribute(second, Attribute::new(QName::new("code")?, "B2"))?;
        let second_name = document.add_element(second, QName::new("Name")?)?;
        document.add_text(second_name, "Beta")?;
        let note = document.add_element(root, QName::new("Note")?)?;
        document.add_text(note, "Loose")?;
        Ok(document)
    }
    /// Namespace context binding the query alias `d` to `urn:doc`; the alias
    /// differs from the document prefix `doc` on purpose.
    fn ns() -> NamespaceContext {
        NamespaceContext::new()
            .with_alias("d", "urn:doc")
            .expect("namespace alias")
    }
    #[test]
    fn query_lexer_tokenizes_path() -> XmlResult<()> {
        let tokens = lex("/d:Root//Name[@code='A1']/text()")?;
        assert!(matches!(tokens[0].kind, TokenKind::Slash));
        assert!(tokens
            .iter()
            .any(|token| matches!(token.kind, TokenKind::DoubleSlash)));
        assert!(tokens
            .iter()
            .any(|token| matches!(token.kind, TokenKind::String(_))));
        Ok(())
    }
    // Two path segments compile to three steps: the implicit `Root` plus one
    // step per segment.
    #[test]
    fn query_parser_builds_absolute_path() -> XmlResult<()> {
        let query = Query::parse("/Root/Child")?;
        assert_eq!(query.steps.len(), 3);
        assert_eq!(query.source(), "/Root/Child");
        Ok(())
    }
    #[test]
    fn query_evaluator_selects_absolute_path() -> XmlResult<()> {
        let document = sample_document()?;
        let result = document.query_with_context("/d:Root/d:Item", &ns())?;
        assert_eq!(result.len(), 2);
        Ok(())
    }
    #[test]
    fn query_evaluator_selects_descendants() -> XmlResult<()> {
        let document = sample_document()?;
        let result = document.query("//Name")?;
        assert_eq!(result.len(), 2);
        Ok(())
    }
    #[test]
    fn query_evaluator_selects_attribute() -> XmlResult<()> {
        let document = sample_document()?;
        let result = document.query_with_context("/d:Root/d:Item/@code", &ns())?;
        assert_eq!(result.strings(), vec!["A1", "B2"]);
        Ok(())
    }
    #[test]
    fn query_evaluator_selects_text() -> XmlResult<()> {
        let document = sample_document()?;
        let result = document.query_with_context("/d:Root/d:Item/Name/text()", &ns())?;
        assert_eq!(result.strings(), vec!["Alpha", "Beta"]);
        Ok(())
    }
    #[test]
    fn query_evaluator_filters_by_attribute_predicate() -> XmlResult<()> {
        let document = sample_document()?;
        let result =
            document.query_with_context("/d:Root/d:Item[@code='A1']/Name/text()", &ns())?;
        assert_eq!(result.strings(), vec!["Alpha"]);
        Ok(())
    }
    #[test]
    fn query_namespaces_alias_filters_by_namespaced_attribute_predicate() -> XmlResult<()> {
        let document = sample_document()?;
        let result =
            document.query_with_context("/d:Root/d:Item[@d:kind='primary']/Name/text()", &ns())?;
        assert_eq!(result.strings(), vec!["Alpha"]);
        Ok(())
    }
    // Unprefixed query names only match names outside any namespace, so
    // elements in a default namespace are reachable only through an alias.
    #[test]
    fn query_namespaces_default_namespace_requires_alias() -> XmlResult<()> {
        let document =
            parser::parse_str(r#"<Root xmlns="urn:default"><Child>value</Child></Root>"#)?;
        let namespaces = NamespaceContext::new().with_alias("d", "urn:default")?;
        assert!(document.query("/Root")?.is_empty());
        assert_eq!(
            document
                .query_with_context("/d:Root/d:Child/text()", &namespaces)?
                .strings(),
            vec!["value"]
        );
        Ok(())
    }
    #[test]
    fn query_compiled_query_can_be_reused() -> XmlResult<()> {
        let document = sample_document()?;
        let query = Query::parse("//Name/text()")?;
        assert_eq!(query.evaluate(&document)?.strings(), vec!["Alpha", "Beta"]);
        assert_eq!(query.evaluate(&document)?.strings(), vec!["Alpha", "Beta"]);
        Ok(())
    }
    // Table of (query, expected result count) covering every supported step kind.
    #[test]
    fn query_valid_cases_cover_mvp_surface() -> XmlResult<()> {
        let document = sample_document()?;
        let namespace = ns();
        let cases = [
            ("/d:Root", 1),
            ("/d:Root/d:Item", 2),
            ("/d:Root/Note", 1),
            ("//d:Item", 2),
            ("//Name", 2),
            ("//Name/text()", 2),
            ("/d:Root/d:Item/@code", 2),
            ("/d:Root/d:Item[@code='B2']", 1),
            ("/d:Root/d:Item[@code='B2']/Name/text()", 1),
            ("/d:Root/d:Item/@d:kind", 1),
        ];
        for (source, expected_len) in cases {
            assert_eq!(
                document.query_with_context(source, &namespace)?.len(),
                expected_len,
                "query {source}"
            );
        }
        Ok(())
    }
    // Every syntax error must be reported as `ErrorKind::Query` with a span.
    #[test]
    fn query_invalid_cases_have_structured_errors_with_span() {
        let invalid = ["", "Root", "/", "/Root[", "/Root[@id]", "/Root/text("];
        for source in invalid {
            let error = Query::parse(source).expect_err("query must fail");
            assert_eq!(error.kind(), &ErrorKind::Query);
            assert!(error.span().is_some());
        }
    }
    #[test]
    fn query_namespace_alias_must_be_declared() {
        let document = sample_document().expect("document");
        let error = document
            .query("/d:Root")
            .expect_err("missing namespace alias must fail");
        assert_eq!(error.kind(), &ErrorKind::UnknownNamespacePrefix);
    }
    // The expected error kind is whatever `QuerySecurityConfig::check_steps`
    // reports; that function is defined outside this module.
    #[test]
    fn query_security_limits_steps() {
        let document = sample_document().expect("document");
        let query = Query::parse("//Name").expect("query");
        let security = QuerySecurityConfig::default()
            .with_limits(crate::security::SecurityLimits::default().with_max_query_steps(1));
        let error = query
            .evaluate_with_options(&document, &NamespaceContext::default(), &security)
            .expect_err("query step limit must fail");
        assert_eq!(error.kind(), &ErrorKind::Parse);
    }
}