1use crate::types::TagName;
10pub use amql_predicates::{AttrOp, AttrPredicate, Predicate, PredicateOp, PredicateValue};
11use serde::Serialize;
12
/// Parsed form of a selector string: an ordered chain of compound selectors.
///
/// As built by the parser in this file, the first compound has no
/// combinator and every later compound records the combinator that joined it
/// to the compound before it.
#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
#[cfg_attr(feature = "ts", ts(export))]
#[cfg_attr(feature = "flow", flow(export))]
#[derive(Debug, Clone, Serialize)]
pub struct SelectorAst {
    /// Compound selectors in the order they appear in the source text.
    pub compounds: Vec<CompoundSelector>,
}
23
/// One link of a selector chain: an optional tag name followed by zero or
/// more bracketed attribute predicates, e.g. `controller[method="POST"]`.
#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
#[cfg_attr(feature = "ts", ts(export))]
#[cfg_attr(feature = "flow", flow(export))]
#[derive(Debug, Clone, Serialize)]
pub struct CompoundSelector {
    /// Tag name written before any `[...]` group; `None` when the compound
    /// starts directly with an attribute group (or is empty).
    pub tag: Option<TagName>,
    /// Attribute predicates collected from every `[...]` group of this
    /// compound, in source order.
    pub attrs: Vec<AttrPredicate>,
    /// Combinator that joins this compound to the previous one. The parser
    /// leaves it `None` on the first compound; it is omitted from the
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub combinator: Option<Combinator>,
}
39
/// Relationship token that joins two adjacent compound selectors.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
#[cfg_attr(feature = "flow", derive(flowjs_rs::Flow))]
#[cfg_attr(feature = "ts", ts(export))]
#[cfg_attr(feature = "flow", flow(export))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[non_exhaustive]
pub enum Combinator {
    /// Produced by the `>` token.
    Child,
    /// Produced by whitespace between two compounds.
    Descendant,
    /// Produced by the `+` token. Recognized by the parser, but
    /// `parse_selector` rejects selectors that use it.
    AdjacentSibling,
    /// Produced by the `~` token. Recognized by the parser, but
    /// `parse_selector` rejects selectors that use it.
    GeneralSibling,
}
57
58#[must_use = "parsing a selector is useless without inspecting the result"]
60pub fn parse_selector(input: &str) -> Result<SelectorAst, String> {
61 let trimmed = input.trim();
62 if trimmed.is_empty() {
63 return Err("Empty selector".to_string());
64 }
65 let mut parser = SelectorParser::new(trimmed);
66 let ast = parser.parse()?;
67 if ast
69 .compounds
70 .iter()
71 .all(|c| c.tag.is_none() && c.attrs.is_empty())
72 {
73 return Err("Empty selector".to_string());
74 }
75 for compound in &ast.compounds {
77 if matches!(
78 compound.combinator,
79 Some(Combinator::AdjacentSibling) | Some(Combinator::GeneralSibling)
80 ) {
81 return Err("Sibling combinators (+, ~) are not supported".to_string());
82 }
83 }
84 Ok(ast)
85}
86
/// Cursor-based parser over a single selector string.
struct SelectorParser<'a> {
    /// The selector text being parsed.
    input: &'a str,
    /// Byte view of `input`, used for single-byte lookahead.
    bytes: &'a [u8],
    /// Current byte offset into `input` / `bytes`.
    pos: usize,
}
92
93impl<'a> SelectorParser<'a> {
94 fn new(input: &'a str) -> Self {
95 Self {
96 input,
97 bytes: input.as_bytes(),
98 pos: 0,
99 }
100 }
101
102 fn slice(&self, start: usize, end: usize) -> &'a str {
103 &self.input[start..end]
104 }
105
106 fn parse(&mut self) -> Result<SelectorAst, String> {
107 let mut compounds = vec![self.parse_compound()?];
108
109 while self.pos < self.bytes.len() {
110 let combinator = match self.parse_combinator() {
111 Some(c) => c,
112 None => break,
113 };
114 let mut compound = self.parse_compound()?;
115 compound.combinator = Some(combinator);
116 compounds.push(compound);
117 }
118
119 self.skip_whitespace();
120 if self.pos < self.bytes.len() {
121 return Err(format!(
122 "Unexpected character '{}' at position {}",
123 self.bytes[self.pos] as char, self.pos
124 ));
125 }
126
127 Ok(SelectorAst { compounds })
128 }
129
130 fn parse_compound(&mut self) -> Result<CompoundSelector, String> {
131 self.skip_whitespace();
132 let tag = self.parse_tag();
133 let attrs = self.parse_attr_list()?;
134 Ok(CompoundSelector {
135 tag,
136 attrs,
137 combinator: None,
138 })
139 }
140
141 fn parse_tag(&mut self) -> Option<TagName> {
142 let start = self.pos;
143 while self.pos < self.bytes.len() && self.is_ident_char(self.bytes[self.pos]) {
144 self.pos += 1;
145 }
146 if self.pos > start {
147 Some(TagName::from(self.slice(start, self.pos)))
148 } else {
149 None
150 }
151 }
152
153 fn parse_attr_list(&mut self) -> Result<Vec<AttrPredicate>, String> {
154 let mut attrs = Vec::new();
155 while self.pos < self.bytes.len() && self.bytes[self.pos] == b'[' {
156 self.pos += 1; let start = self.pos;
160 let mut depth = 1;
161 while self.pos < self.bytes.len() && depth > 0 {
162 match self.bytes[self.pos] {
163 b'[' => depth += 1,
164 b']' => depth -= 1,
165 b'"' | b'\'' => {
166 let quote = self.bytes[self.pos];
167 self.pos += 1;
168 while self.pos < self.bytes.len() && self.bytes[self.pos] != quote {
169 if self.bytes[self.pos] == b'\\' {
170 self.pos += 1;
171 }
172 self.pos += 1;
173 }
174 }
175 _ => {}
176 }
177 if depth > 0 {
178 self.pos += 1;
179 }
180 }
181
182 if depth != 0 {
183 return Err(format!("Unclosed '[' at position {}", start - 1));
184 }
185
186 let bracket_content = self.slice(start, self.pos);
187 self.pos += 1; let parsed = amql_predicates::parse_predicate_list(bracket_content)?;
190 attrs.extend(parsed);
191 }
192 Ok(attrs)
193 }
194
195 fn parse_combinator(&mut self) -> Option<Combinator> {
196 let before_space = self.pos;
197 self.skip_whitespace();
198
199 if self.pos >= self.bytes.len() {
200 return None;
201 }
202
203 let ch = self.bytes[self.pos];
204 if ch == b'>' || ch == b'+' || ch == b'~' {
205 self.pos += 1;
206 self.skip_whitespace();
207 return Some(match ch {
208 b'>' => Combinator::Child,
209 b'+' => Combinator::AdjacentSibling,
210 b'~' => Combinator::GeneralSibling,
211 _ => unreachable!(),
212 });
213 }
214
215 if self.pos > before_space && self.pos < self.bytes.len() {
217 let next = self.bytes[self.pos];
218 if self.is_ident_char(next) || next == b'[' {
219 return Some(Combinator::Descendant);
220 }
221 }
222
223 None
224 }
225
226 fn skip_whitespace(&mut self) {
227 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
228 self.pos += 1;
229 }
230 }
231
232 fn is_ident_char(&self, ch: u8) -> bool {
233 ch.is_ascii_alphanumeric() || ch == b'_' || ch == b'-'
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected predicate for a bare `[name]` presence check.
    fn presence_attr(name: &str) -> AttrPredicate {
        AttrPredicate {
            name: name.to_string(),
            op: None,
            value: None,
        }
    }

    /// Expected predicate for `[name <op> "value"]` with a string value.
    fn string_attr(name: &str, op: AttrOp, value: &str) -> AttrPredicate {
        AttrPredicate {
            name: name.to_string(),
            op: Some(op),
            value: Some(PredicateValue::String(value.to_string())),
        }
    }

    /// Parses `selector`, panicking if it is rejected.
    fn parsed(selector: &str) -> SelectorAst {
        parse_selector(selector).unwrap()
    }

    #[test]
    fn parses_basic_selectors() {
        // Tag only.
        let bare = parsed("controller");
        assert_eq!(bare.compounds.len(), 1);
        assert_eq!(bare.compounds[0].tag.as_deref(), Some("controller"));
        assert!(bare.compounds[0].attrs.is_empty());

        // Tag plus attribute presence.
        let presence = parsed("function[async]");
        assert_eq!(presence.compounds.len(), 1);
        assert_eq!(presence.compounds[0].tag.as_deref(), Some("function"));
        assert_eq!(presence.compounds[0].attrs, vec![presence_attr("async")]);

        // Tag plus attribute value.
        let value = parsed(r#"controller[method="POST"]"#);
        assert_eq!(value.compounds.len(), 1);
        assert_eq!(value.compounds[0].tag.as_deref(), Some("controller"));
        assert_eq!(
            value.compounds[0].attrs,
            vec![string_attr("method", AttrOp::Eq, "POST")]
        );

        // Attribute only, no tag.
        let attr = parsed(r#"[owner="@backend"]"#);
        assert_eq!(attr.compounds.len(), 1);
        assert!(attr.compounds[0].tag.is_none());
        assert_eq!(
            attr.compounds[0].attrs,
            vec![string_attr("owner", AttrOp::Eq, "@backend")]
        );
    }

    #[test]
    fn parses_multiple_and_quoted() {
        // Comma-separated predicates inside one bracket group.
        let multi = parsed(r#"function[name="create",async]"#);
        let multi_attrs = &multi.compounds[0].attrs;
        assert_eq!(multi_attrs.len(), 2);
        assert_eq!(multi_attrs[0], string_attr("name", AttrOp::Eq, "create"));
        assert_eq!(multi_attrs[1], presence_attr("async"));

        // Single quotes are accepted as string delimiters.
        let quoted = parsed("controller[method='POST']");
        assert_eq!(
            quoted.compounds[0].attrs[0],
            string_attr("method", AttrOp::Eq, "POST")
        );
    }

    #[test]
    fn parses_operators() {
        let starts_ast = parsed(r#"[name^="handle"]"#);
        assert_eq!(
            starts_ast.compounds[0].attrs[0],
            string_attr("name", AttrOp::StartsWith, "handle")
        );

        let contains_ast = parsed(r#"[name*="user"]"#);
        assert_eq!(
            contains_ast.compounds[0].attrs[0],
            string_attr("name", AttrOp::Contains, "user")
        );

        let ends_ast = parsed(r#"[name$="Controller"]"#);
        assert_eq!(
            ends_ast.compounds[0].attrs[0],
            string_attr("name", AttrOp::EndsWith, "Controller")
        );
    }

    #[test]
    fn parses_combinators() {
        // Explicit child combinator.
        let child_ast = parsed("class > method");
        assert_eq!(child_ast.compounds.len(), 2);
        assert_eq!(child_ast.compounds[0].tag.as_deref(), Some("class"));
        assert_eq!(child_ast.compounds[1].tag.as_deref(), Some("method"));
        assert_eq!(child_ast.compounds[1].combinator, Some(Combinator::Child));

        // Whitespace-only descendant combinator.
        let desc_ast = parsed("class method");
        assert_eq!(desc_ast.compounds.len(), 2);
        assert_eq!(desc_ast.compounds[0].tag.as_deref(), Some("class"));
        assert_eq!(desc_ast.compounds[1].tag.as_deref(), Some("method"));
        assert_eq!(
            desc_ast.compounds[1].combinator,
            Some(Combinator::Descendant)
        );

        // Combinator between compounds that both carry attributes.
        let complex_ast = parsed(r#"class[name="UserService"] > method[async]"#);
        assert_eq!(complex_ast.compounds.len(), 2);

        let parent = &complex_ast.compounds[0];
        assert_eq!(parent.tag.as_deref(), Some("class"));
        assert_eq!(
            parent.attrs,
            vec![string_attr("name", AttrOp::Eq, "UserService")]
        );

        let target = &complex_ast.compounds[1];
        assert_eq!(target.tag.as_deref(), Some("method"));
        assert_eq!(target.attrs, vec![presence_attr("async")]);
        assert_eq!(target.combinator, Some(Combinator::Child));
    }

    #[test]
    fn handles_escape_sequences() {
        let escaped_quote = parsed(r#"[name="foo\"bar"]"#);
        assert_eq!(
            escaped_quote.compounds[0].attrs[0].value,
            Some(PredicateValue::String(r#"foo"bar"#.to_string()))
        );

        let escaped_backslash = parsed(r#"[name="foo\\bar"]"#);
        assert_eq!(
            escaped_backslash.compounds[0].attrs[0].value,
            Some(PredicateValue::String(r"foo\bar".to_string()))
        );
    }

    #[test]
    fn rejects_empty_selectors() {
        for blank in ["", " "] {
            assert!(parse_selector(blank).is_err());
        }
    }

    #[test]
    fn rejects_sibling_combinators() {
        for sibling in ["a + b", "a ~ b"] {
            assert!(parse_selector(sibling).is_err());
        }
    }

    #[test]
    fn parses_numeric_operators_in_selector() {
        let ast = parsed("[count>=5]");
        let predicate = &ast.compounds[0].attrs[0];

        assert_eq!(predicate.op, Some(AttrOp::Gte), "should parse >= operator");
        assert_eq!(
            predicate.value,
            Some(PredicateValue::Number(5.0)),
            "should parse numeric value"
        );
    }
}