Skip to main content

amql_selector/
selector.rs

//! Structural selector parser.
//!
//! Parses selectors like:
//!   `controller[method="POST"]`
//!   `function[async]`
//!   `class > method[name="create"]`
//!   `class method` (descendant)
//!   `[owner="@backend"]`
//!
//! Sibling combinators (`+`, `~`) are recognised but rejected.
8
9use crate::types::TagName;
10pub use amql_predicates::{AttrOp, AttrPredicate, Predicate, PredicateOp, PredicateValue};
11use serde::Serialize;
12
13/// Parsed selector AST — a chain of compound selectors.
14#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
15#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
16#[cfg_attr(feature = "ts", ts(export))]
17#[cfg_attr(feature = "flow", flow(export))]
18#[derive(Debug, Clone, Serialize)]
19pub struct SelectorAst {
20    /// Ordered chain of compound selectors (left-to-right).
21    pub compounds: Vec<CompoundSelector>,
22}
23
24/// A single compound selector: optional tag + zero or more attribute predicates.
25#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
26#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
27#[cfg_attr(feature = "ts", ts(export))]
28#[cfg_attr(feature = "flow", flow(export))]
29#[derive(Debug, Clone, Serialize)]
30pub struct CompoundSelector {
31    /// Tag filter, if any (e.g. "function").
32    pub tag: Option<TagName>,
33    /// Attribute predicates within brackets.
34    pub attrs: Vec<AttrPredicate>,
35    /// Combinator linking this compound to the previous one.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub combinator: Option<Combinator>,
38}
39
40/// Combinator between compound selectors.
41#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
42#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
43#[cfg_attr(feature = "ts", ts(export))]
44#[cfg_attr(feature = "flow", flow(export))]
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
46#[non_exhaustive]
47pub enum Combinator {
48    /// `>` — direct child
49    Child,
50    /// ` ` — descendant
51    Descendant,
52    /// `+` — adjacent sibling
53    AdjacentSibling,
54    /// `~` — general sibling
55    GeneralSibling,
56}
57
58/// Parse a selector string into an AST.
59#[must_use = "parsing a selector is useless without inspecting the result"]
60pub fn parse_selector(input: &str) -> Result<SelectorAst, String> {
61    let trimmed = input.trim();
62    if trimmed.is_empty() {
63        return Err("Empty selector".to_string());
64    }
65    let mut parser = SelectorParser::new(trimmed);
66    let ast = parser.parse()?;
67    // Reject selectors where all compounds are empty (no tag, no attrs)
68    if ast
69        .compounds
70        .iter()
71        .all(|c| c.tag.is_none() && c.attrs.is_empty())
72    {
73        return Err("Empty selector".to_string());
74    }
75    // Reject sibling combinators — they are not supported
76    for compound in &ast.compounds {
77        if matches!(
78            compound.combinator,
79            Some(Combinator::AdjacentSibling) | Some(Combinator::GeneralSibling)
80        ) {
81            return Err("Sibling combinators (+, ~) are not supported".to_string());
82        }
83    }
84    Ok(ast)
85}
86
87struct SelectorParser<'a> {
88    input: &'a str,
89    bytes: &'a [u8],
90    pos: usize,
91}
92
93impl<'a> SelectorParser<'a> {
94    fn new(input: &'a str) -> Self {
95        Self {
96            input,
97            bytes: input.as_bytes(),
98            pos: 0,
99        }
100    }
101
102    fn slice(&self, start: usize, end: usize) -> &'a str {
103        &self.input[start..end]
104    }
105
106    fn parse(&mut self) -> Result<SelectorAst, String> {
107        let mut compounds = vec![self.parse_compound()?];
108
109        while self.pos < self.bytes.len() {
110            let combinator = match self.parse_combinator() {
111                Some(c) => c,
112                None => break,
113            };
114            let mut compound = self.parse_compound()?;
115            compound.combinator = Some(combinator);
116            compounds.push(compound);
117        }
118
119        self.skip_whitespace();
120        if self.pos < self.bytes.len() {
121            return Err(format!(
122                "Unexpected character '{}' at position {}",
123                self.bytes[self.pos] as char, self.pos
124            ));
125        }
126
127        Ok(SelectorAst { compounds })
128    }
129
130    fn parse_compound(&mut self) -> Result<CompoundSelector, String> {
131        self.skip_whitespace();
132        let tag = self.parse_tag();
133        let attrs = self.parse_attr_list()?;
134        Ok(CompoundSelector {
135            tag,
136            attrs,
137            combinator: None,
138        })
139    }
140
141    fn parse_tag(&mut self) -> Option<TagName> {
142        let start = self.pos;
143        while self.pos < self.bytes.len() && self.is_ident_char(self.bytes[self.pos]) {
144            self.pos += 1;
145        }
146        if self.pos > start {
147            Some(TagName::from(self.slice(start, self.pos)))
148        } else {
149            None
150        }
151    }
152
153    fn parse_attr_list(&mut self) -> Result<Vec<AttrPredicate>, String> {
154        let mut attrs = Vec::new();
155        while self.pos < self.bytes.len() && self.bytes[self.pos] == b'[' {
156            self.pos += 1; // skip '['
157
158            // Extract bracket content and delegate to amql-predicates
159            let start = self.pos;
160            let mut depth = 1;
161            while self.pos < self.bytes.len() && depth > 0 {
162                match self.bytes[self.pos] {
163                    b'[' => depth += 1,
164                    b']' => depth -= 1,
165                    b'"' | b'\'' => {
166                        let quote = self.bytes[self.pos];
167                        self.pos += 1;
168                        while self.pos < self.bytes.len() && self.bytes[self.pos] != quote {
169                            if self.bytes[self.pos] == b'\\' {
170                                self.pos += 1;
171                            }
172                            self.pos += 1;
173                        }
174                    }
175                    _ => {}
176                }
177                if depth > 0 {
178                    self.pos += 1;
179                }
180            }
181
182            if depth != 0 {
183                return Err(format!("Unclosed '[' at position {}", start - 1));
184            }
185
186            let bracket_content = self.slice(start, self.pos);
187            self.pos += 1; // skip closing ']'
188
189            let parsed = amql_predicates::parse_predicate_list(bracket_content)?;
190            attrs.extend(parsed);
191        }
192        Ok(attrs)
193    }
194
195    fn parse_combinator(&mut self) -> Option<Combinator> {
196        let before_space = self.pos;
197        self.skip_whitespace();
198
199        if self.pos >= self.bytes.len() {
200            return None;
201        }
202
203        let ch = self.bytes[self.pos];
204        if ch == b'>' || ch == b'+' || ch == b'~' {
205            self.pos += 1;
206            self.skip_whitespace();
207            return Some(match ch {
208                b'>' => Combinator::Child,
209                b'+' => Combinator::AdjacentSibling,
210                b'~' => Combinator::GeneralSibling,
211                _ => unreachable!(),
212            });
213        }
214
215        // If we consumed whitespace and next char starts a compound, it's a descendant combinator
216        if self.pos > before_space && self.pos < self.bytes.len() {
217            let next = self.bytes[self.pos];
218            if self.is_ident_char(next) || next == b'[' {
219                return Some(Combinator::Descendant);
220            }
221        }
222
223        None
224    }
225
226    fn skip_whitespace(&mut self) {
227        while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
228            self.pos += 1;
229        }
230    }
231
232    fn is_ident_char(&self, ch: u8) -> bool {
233        ch.is_ascii_alphanumeric() || ch == b'_' || ch == b'-'
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Presence-only predicate, e.g. the `async` in `function[async]`.
    fn flag(name: &str) -> AttrPredicate {
        AttrPredicate {
            name: name.to_string(),
            op: None,
            value: None,
        }
    }

    /// Predicate comparing attribute `name` against a string literal with `op`.
    fn string_pred(name: &str, op: AttrOp, value: &str) -> AttrPredicate {
        AttrPredicate {
            name: name.to_string(),
            op: Some(op),
            value: Some(PredicateValue::String(value.to_string())),
        }
    }

    /// Parses `input`, failing the test with the offending selector on error.
    fn parse_ok(input: &str) -> SelectorAst {
        parse_selector(input).unwrap_or_else(|err| panic!("failed to parse {:?}: {}", input, err))
    }

    /// Asserts that `ast` consists of exactly one compound and returns it.
    fn only_compound(ast: &SelectorAst) -> &CompoundSelector {
        assert_eq!(ast.compounds.len(), 1, "expected exactly one compound");
        &ast.compounds[0]
    }

    #[test]
    fn parses_basic_selectors() {
        let bare_ast = parse_ok("controller");
        let bare = only_compound(&bare_ast);
        assert_eq!(bare.tag.as_deref(), Some("controller"));
        assert!(bare.attrs.is_empty());

        let presence_ast = parse_ok("function[async]");
        let presence = only_compound(&presence_ast);
        assert_eq!(presence.tag.as_deref(), Some("function"));
        assert_eq!(presence.attrs, vec![flag("async")]);

        let value_ast = parse_ok(r#"controller[method="POST"]"#);
        let value = only_compound(&value_ast);
        assert_eq!(value.tag.as_deref(), Some("controller"));
        assert_eq!(value.attrs, vec![string_pred("method", AttrOp::Eq, "POST")]);

        let attr_ast = parse_ok(r#"[owner="@backend"]"#);
        let attr_only = only_compound(&attr_ast);
        assert!(attr_only.tag.is_none());
        assert_eq!(
            attr_only.attrs,
            vec![string_pred("owner", AttrOp::Eq, "@backend")]
        );
    }

    #[test]
    fn parses_multiple_and_quoted() {
        let multi = parse_ok(r#"function[name="create",async]"#);
        let multi_attrs = &multi.compounds[0].attrs;
        assert_eq!(multi_attrs.len(), 2);
        assert_eq!(multi_attrs[0], string_pred("name", AttrOp::Eq, "create"));
        assert_eq!(multi_attrs[1], flag("async"));

        let quoted = parse_ok("controller[method='POST']");
        assert_eq!(
            quoted.compounds[0].attrs[0],
            string_pred("method", AttrOp::Eq, "POST")
        );
    }

    #[test]
    fn parses_operators() {
        let cases = [
            (
                r#"[name^="handle"]"#,
                string_pred("name", AttrOp::StartsWith, "handle"),
            ),
            (
                r#"[name*="user"]"#,
                string_pred("name", AttrOp::Contains, "user"),
            ),
            (
                r#"[name$="Controller"]"#,
                string_pred("name", AttrOp::EndsWith, "Controller"),
            ),
        ];
        for (selector, expected) in cases.iter() {
            let ast = parse_ok(selector);
            assert_eq!(&ast.compounds[0].attrs[0], expected, "selector {:?}", selector);
        }
    }

    #[test]
    fn parses_combinators() {
        for &(selector, expected) in &[
            ("class > method", Combinator::Child),
            ("class method", Combinator::Descendant),
        ] {
            let ast = parse_ok(selector);
            assert_eq!(ast.compounds.len(), 2);
            assert_eq!(ast.compounds[0].tag.as_deref(), Some("class"));
            assert_eq!(ast.compounds[1].tag.as_deref(), Some("method"));
            assert_eq!(ast.compounds[1].combinator, Some(expected));
        }

        let complex = parse_ok(r#"class[name="UserService"] > method[async]"#);
        assert_eq!(complex.compounds.len(), 2);
        let (parent, child) = (&complex.compounds[0], &complex.compounds[1]);
        assert_eq!(parent.tag.as_deref(), Some("class"));
        assert_eq!(
            parent.attrs,
            vec![string_pred("name", AttrOp::Eq, "UserService")]
        );
        assert_eq!(child.tag.as_deref(), Some("method"));
        assert_eq!(child.attrs, vec![flag("async")]);
        assert_eq!(child.combinator, Some(Combinator::Child));
    }

    #[test]
    fn handles_escape_sequences() {
        let expectations = [
            (r#"[name="foo\"bar"]"#, r#"foo"bar"#),
            (r#"[name="foo\\bar"]"#, r"foo\bar"),
        ];
        for &(selector, decoded) in &expectations {
            let ast = parse_ok(selector);
            assert_eq!(
                ast.compounds[0].attrs[0].value,
                Some(PredicateValue::String(decoded.to_string()))
            );
        }
    }

    #[test]
    fn rejects_empty_selectors() {
        for blank in &["", "   "] {
            assert!(parse_selector(blank).is_err(), "{:?} should be rejected", blank);
        }
    }

    #[test]
    fn rejects_sibling_combinators() {
        for sibling in &["a + b", "a ~ b"] {
            assert!(parse_selector(sibling).is_err(), "{:?} should be rejected", sibling);
        }
    }

    #[test]
    fn parses_numeric_operators_in_selector() {
        let ast = parse_ok("[count>=5]");
        let predicate = &ast.compounds[0].attrs[0];
        assert_eq!(predicate.op, Some(AttrOp::Gte), "should parse >= operator");
        assert_eq!(
            predicate.value,
            Some(PredicateValue::Number(5.0)),
            "should parse numeric value"
        );
    }
}