Skip to main content

pdf_ast/api/
mod.rs

1use crate::ast::{NodeId, NodeType, PdfAstGraph};
2use crate::types::PdfValue;
3use regex::Regex;
4use std::collections::{HashMap, HashSet};
5
6#[derive(Debug, Clone)]
7pub enum QuerySelector {
8    NodeType(NodeType),
9    NodeTypeName(String),
10    ObjectId(u32, u16),
11    Path(Vec<QuerySelector>),
12    Child(Box<QuerySelector>, Box<QuerySelector>),
13    Descendant(Box<QuerySelector>, Box<QuerySelector>),
14    Parent(Box<QuerySelector>),
15    Ancestor(Box<QuerySelector>),
16    Sibling(Box<QuerySelector>),
17    HasProperty(String),
18    PropertyEquals(String, String),
19    PropertyMatches(String, Regex),
20    And(Vec<QuerySelector>),
21    Or(Vec<QuerySelector>),
22    Not(Box<QuerySelector>),
23    First,
24    Last,
25    Index(usize),
26    Range(usize, usize),
27}
28
29#[allow(dead_code)]
30pub struct QueryEngine<'a> {
31    graph: &'a PdfAstGraph,
32    cache: HashMap<String, Vec<NodeId>>,
33}
34
35impl<'a> QueryEngine<'a> {
36    pub fn new(graph: &'a PdfAstGraph) -> Self {
37        Self {
38            graph,
39            cache: HashMap::new(),
40        }
41    }
42
43    pub fn query(&mut self, selector: &QuerySelector) -> Vec<NodeId> {
44        self.evaluate_selector(selector, None)
45    }
46
47    pub fn query_from(&mut self, selector: &QuerySelector, context: NodeId) -> Vec<NodeId> {
48        self.evaluate_selector(selector, Some(context))
49    }
50
51    fn evaluate_selector(
52        &mut self,
53        selector: &QuerySelector,
54        context: Option<NodeId>,
55    ) -> Vec<NodeId> {
56        match selector {
57            QuerySelector::NodeType(node_type) => self.find_by_type(node_type.clone()),
58            QuerySelector::NodeTypeName(name) => self.find_by_type_name(name),
59            QuerySelector::ObjectId(num, gen) => self.find_by_object_id(*num, *gen),
60            QuerySelector::Path(selectors) => self.evaluate_path(selectors, context),
61            QuerySelector::Child(parent_sel, child_sel) => {
62                let parents = self.evaluate_selector(parent_sel, context);
63                let mut results = Vec::new();
64                for parent in parents {
65                    let children = self.graph.get_children(parent);
66                    for child in children {
67                        let matches = self.evaluate_selector(child_sel, Some(child));
68                        if matches.contains(&child) {
69                            results.push(child);
70                        }
71                    }
72                }
73                results
74            }
75            QuerySelector::Descendant(ancestor_sel, descendant_sel) => {
76                let ancestors = self.evaluate_selector(ancestor_sel, context);
77                let mut results = Vec::new();
78                for ancestor in ancestors {
79                    let descendants = self.get_all_descendants(ancestor);
80                    for desc in descendants {
81                        let matches = self.evaluate_selector(descendant_sel, Some(desc));
82                        if matches.contains(&desc) {
83                            results.push(desc);
84                        }
85                    }
86                }
87                results
88            }
89            QuerySelector::Parent(child_sel) => {
90                let children = self.evaluate_selector(child_sel, context);
91                let mut results = Vec::new();
92                for child in children {
93                    if let Some(parent) = self.graph.get_parent(child) {
94                        results.push(parent);
95                    }
96                }
97                results
98                    .into_iter()
99                    .collect::<HashSet<_>>()
100                    .into_iter()
101                    .collect()
102            }
103            QuerySelector::Ancestor(descendant_sel) => {
104                let descendants = self.evaluate_selector(descendant_sel, context);
105                let mut results = HashSet::new();
106                for desc in descendants {
107                    let mut current = desc;
108                    while let Some(parent) = self.graph.get_parent(current) {
109                        results.insert(parent);
110                        current = parent;
111                    }
112                }
113                results.into_iter().collect()
114            }
115            QuerySelector::Sibling(sel) => {
116                let nodes = self.evaluate_selector(sel, context);
117                let mut results = HashSet::new();
118                for node in nodes {
119                    if let Some(parent) = self.graph.get_parent(node) {
120                        for sibling in self.graph.get_children(parent) {
121                            if sibling != node {
122                                results.insert(sibling);
123                            }
124                        }
125                    }
126                }
127                results.into_iter().collect()
128            }
129            QuerySelector::HasProperty(prop) => self.find_with_property(prop),
130            QuerySelector::PropertyEquals(prop, value) => {
131                self.find_with_property_value(prop, value)
132            }
133            QuerySelector::PropertyMatches(prop, regex) => {
134                self.find_with_property_regex(prop, regex)
135            }
136            QuerySelector::And(selectors) => {
137                if selectors.is_empty() {
138                    return Vec::new();
139                }
140                let mut result_set: Option<HashSet<NodeId>> = None;
141                for sel in selectors {
142                    let matches: HashSet<NodeId> =
143                        self.evaluate_selector(sel, context).into_iter().collect();
144                    result_set = Some(match result_set {
145                        None => matches,
146                        Some(set) => set.intersection(&matches).cloned().collect(),
147                    });
148                }
149                result_set.unwrap_or_default().into_iter().collect()
150            }
151            QuerySelector::Or(selectors) => {
152                let mut result_set = HashSet::new();
153                for sel in selectors {
154                    result_set.extend(self.evaluate_selector(sel, context));
155                }
156                result_set.into_iter().collect()
157            }
158            QuerySelector::Not(sel) => {
159                let excluded: HashSet<NodeId> =
160                    self.evaluate_selector(sel, context).into_iter().collect();
161                let all_nodes: HashSet<NodeId> = self.graph.node_indices().into_iter().collect();
162                all_nodes.difference(&excluded).cloned().collect()
163            }
164            QuerySelector::First => {
165                if let Some(ctx) = context {
166                    self.graph.get_children(ctx).into_iter().take(1).collect()
167                } else {
168                    Vec::new()
169                }
170            }
171            QuerySelector::Last => {
172                if let Some(ctx) = context {
173                    let children = self.graph.get_children(ctx);
174                    if let Some(last) = children.last() {
175                        vec![*last]
176                    } else {
177                        Vec::new()
178                    }
179                } else {
180                    Vec::new()
181                }
182            }
183            QuerySelector::Index(idx) => {
184                if let Some(ctx) = context {
185                    self.graph
186                        .get_children(ctx)
187                        .into_iter()
188                        .nth(*idx)
189                        .map(|n| vec![n])
190                        .unwrap_or_default()
191                } else {
192                    Vec::new()
193                }
194            }
195            QuerySelector::Range(start, end) => {
196                if let Some(ctx) = context {
197                    self.graph
198                        .get_children(ctx)
199                        .into_iter()
200                        .skip(*start)
201                        .take(end - start)
202                        .collect()
203                } else {
204                    Vec::new()
205                }
206            }
207        }
208    }
209
210    fn evaluate_path(
211        &mut self,
212        selectors: &[QuerySelector],
213        context: Option<NodeId>,
214    ) -> Vec<NodeId> {
215        let mut current = if let Some(ctx) = context {
216            vec![ctx]
217        } else if let Some(root) = self.graph.get_root() {
218            vec![root]
219        } else {
220            return Vec::new();
221        };
222
223        for selector in selectors {
224            let mut next = Vec::new();
225            for node in current {
226                next.extend(self.evaluate_selector(selector, Some(node)));
227            }
228            current = next;
229        }
230
231        current
232    }
233
234    fn find_by_type(&self, node_type: NodeType) -> Vec<NodeId> {
235        self.graph
236            .node_indices()
237            .into_iter()
238            .filter(|&id| {
239                self.graph
240                    .get_node(id)
241                    .map(|n| n.node_type == node_type)
242                    .unwrap_or(false)
243            })
244            .collect()
245    }
246
247    fn find_by_type_name(&self, name: &str) -> Vec<NodeId> {
248        let node_type = match name {
249            "root" => NodeType::Root,
250            "catalog" => NodeType::Catalog,
251            "pages" => NodeType::Pages,
252            "page" => NodeType::Page,
253            "font" => NodeType::Font,
254            "image" => NodeType::Image,
255            "annotation" => NodeType::Annotation,
256            "form" => NodeType::Form,
257            "outline" => NodeType::Outline,
258            "struct" => NodeType::StructElem,
259            _ => return Vec::new(),
260        };
261        self.find_by_type(node_type)
262    }
263
264    fn find_by_object_id(&self, num: u32, gen: u16) -> Vec<NodeId> {
265        self.graph
266            .node_indices()
267            .into_iter()
268            .filter(|&id| {
269                self.graph
270                    .get_object_id(id)
271                    .map(|obj_id| obj_id.number == num && obj_id.generation == gen)
272                    .unwrap_or(false)
273            })
274            .collect()
275    }
276
277    fn find_with_property(&self, prop: &str) -> Vec<NodeId> {
278        self.graph
279            .node_indices()
280            .into_iter()
281            .filter(|&id| {
282                if let Some(node) = self.graph.get_node(id) {
283                    if let PdfValue::Dictionary(dict) = &node.value {
284                        return dict.contains_key(prop);
285                    }
286                }
287                false
288            })
289            .collect()
290    }
291
292    fn find_with_property_value(&self, prop: &str, value: &str) -> Vec<NodeId> {
293        self.graph
294            .node_indices()
295            .into_iter()
296            .filter(|&id| {
297                if let Some(node) = self.graph.get_node(id) {
298                    if let PdfValue::Dictionary(dict) = &node.value {
299                        if let Some(val) = dict.get(prop) {
300                            return self.value_matches(val, value);
301                        }
302                    }
303                }
304                false
305            })
306            .collect()
307    }
308
309    fn find_with_property_regex(&self, prop: &str, regex: &Regex) -> Vec<NodeId> {
310        self.graph
311            .node_indices()
312            .into_iter()
313            .filter(|&id| {
314                if let Some(node) = self.graph.get_node(id) {
315                    if let PdfValue::Dictionary(dict) = &node.value {
316                        if let Some(val) = dict.get(prop) {
317                            let val_str = self.value_to_string(val);
318                            return regex.is_match(&val_str);
319                        }
320                    }
321                }
322                false
323            })
324            .collect()
325    }
326
327    fn value_matches(&self, value: &PdfValue, target: &str) -> bool {
328        match value {
329            PdfValue::Name(n) => n.without_slash() == target,
330            PdfValue::String(s) => s.to_string_lossy() == target,
331            PdfValue::Integer(i) => i.to_string() == target,
332            PdfValue::Real(r) => r.to_string() == target,
333            PdfValue::Boolean(b) => b.to_string() == target,
334            _ => false,
335        }
336    }
337
338    fn value_to_string(&self, value: &PdfValue) -> String {
339        match value {
340            PdfValue::Name(n) => n.without_slash().to_string(),
341            PdfValue::String(s) => s.to_string_lossy(),
342            PdfValue::Integer(i) => i.to_string(),
343            PdfValue::Real(r) => r.to_string(),
344            PdfValue::Boolean(b) => b.to_string(),
345            _ => String::new(),
346        }
347    }
348
349    fn get_all_descendants(&self, node: NodeId) -> Vec<NodeId> {
350        let mut descendants = Vec::new();
351        let mut to_visit = vec![node];
352        let mut visited = HashSet::new();
353
354        while let Some(current) = to_visit.pop() {
355            if visited.insert(current) {
356                let children = self.graph.get_children(current);
357                descendants.extend(&children);
358                to_visit.extend(children);
359            }
360        }
361
362        descendants
363    }
364}
365
366pub struct QueryParser;
367
368impl QueryParser {
369    pub fn parse(query: &str) -> Result<QuerySelector, String> {
370        let query = query.trim();
371
372        if query.is_empty() {
373            return Err("Empty query".to_string());
374        }
375
376        // Simple parser for CSS-like selectors
377        if query.contains(" > ") {
378            let parts: Vec<&str> = query.split(" > ").collect();
379            if parts.len() == 2 {
380                let parent = Self::parse_simple(parts[0])?;
381                let child = Self::parse_simple(parts[1])?;
382                return Ok(QuerySelector::Child(Box::new(parent), Box::new(child)));
383            }
384        }
385
386        if query.contains(" ") {
387            let parts: Vec<&str> = query.split_whitespace().collect();
388            if parts.len() == 2 {
389                let ancestor = Self::parse_simple(parts[0])?;
390                let descendant = Self::parse_simple(parts[1])?;
391                return Ok(QuerySelector::Descendant(
392                    Box::new(ancestor),
393                    Box::new(descendant),
394                ));
395            }
396        }
397
398        if query.contains(',') {
399            let parts: Vec<&str> = query.split(',').map(|s| s.trim()).collect();
400            let selectors: Result<Vec<_>, _> = parts.into_iter().map(Self::parse_simple).collect();
401            return Ok(QuerySelector::Or(selectors?));
402        }
403
404        if query.starts_with(':') {
405            return Self::parse_pseudo(query);
406        }
407
408        if query.starts_with('[') && query.ends_with(']') {
409            return Self::parse_attribute(&query[1..query.len() - 1]);
410        }
411
412        Self::parse_simple(query)
413    }
414
415    fn parse_simple(query: &str) -> Result<QuerySelector, String> {
416        if let Some(id_str) = query.strip_prefix('#') {
417            // Parse object ID like #123.0
418            if let Some(dot_pos) = id_str.find('.') {
419                let num = id_str[..dot_pos]
420                    .parse::<u32>()
421                    .map_err(|_| "Invalid object number")?;
422                let gen = id_str[dot_pos + 1..]
423                    .parse::<u16>()
424                    .map_err(|_| "Invalid generation number")?;
425                return Ok(QuerySelector::ObjectId(num, gen));
426            }
427        }
428
429        // Parse node type
430        Ok(QuerySelector::NodeTypeName(query.to_lowercase()))
431    }
432
433    fn parse_pseudo(query: &str) -> Result<QuerySelector, String> {
434        match query {
435            ":first" | ":first-child" => Ok(QuerySelector::First),
436            ":last" | ":last-child" => Ok(QuerySelector::Last),
437            _ => {
438                if query.starts_with(":nth-child(") && query.ends_with(')') {
439                    let inner = &query[11..query.len() - 1];
440                    let idx = inner.parse::<usize>().map_err(|_| "Invalid index")?;
441                    Ok(QuerySelector::Index(idx))
442                } else {
443                    Err(format!("Unknown pseudo-selector: {}", query))
444                }
445            }
446        }
447    }
448
449    fn parse_attribute(attr: &str) -> Result<QuerySelector, String> {
450        if let Some(eq_pos) = attr.find('=') {
451            let prop = attr[..eq_pos].trim();
452            let value = attr[eq_pos + 1..]
453                .trim()
454                .trim_matches('"')
455                .trim_matches('\'');
456            Ok(QuerySelector::PropertyEquals(
457                prop.to_string(),
458                value.to_string(),
459            ))
460        } else {
461            Ok(QuerySelector::HasProperty(attr.trim().to_string()))
462        }
463    }
464}
465
466pub struct QueryBuilder {
467    selector: Option<QuerySelector>,
468}
469
470impl Default for QueryBuilder {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476impl QueryBuilder {
477    pub fn new() -> Self {
478        Self { selector: None }
479    }
480
481    pub fn node_type(mut self, node_type: NodeType) -> Self {
482        self.selector = Some(QuerySelector::NodeType(node_type));
483        self
484    }
485
486    pub fn object_id(mut self, num: u32, gen: u16) -> Self {
487        self.selector = Some(QuerySelector::ObjectId(num, gen));
488        self
489    }
490
491    pub fn has_property(mut self, prop: &str) -> Self {
492        let new_sel = QuerySelector::HasProperty(prop.to_string());
493        self.selector = Some(self.combine_with_and(new_sel));
494        self
495    }
496
497    pub fn property_equals(mut self, prop: &str, value: &str) -> Self {
498        let new_sel = QuerySelector::PropertyEquals(prop.to_string(), value.to_string());
499        self.selector = Some(self.combine_with_and(new_sel));
500        self
501    }
502
503    pub fn child_of(mut self, parent: QuerySelector) -> Self {
504        if let Some(current) = self.selector {
505            self.selector = Some(QuerySelector::Child(Box::new(parent), Box::new(current)));
506        }
507        self
508    }
509
510    pub fn descendant_of(mut self, ancestor: QuerySelector) -> Self {
511        if let Some(current) = self.selector {
512            self.selector = Some(QuerySelector::Descendant(
513                Box::new(ancestor),
514                Box::new(current),
515            ));
516        }
517        self
518    }
519
520    pub fn and(mut self, other: QuerySelector) -> Self {
521        self.selector = Some(self.combine_with_and(other));
522        self
523    }
524
525    pub fn or(mut self, other: QuerySelector) -> Self {
526        if let Some(current) = self.selector {
527            self.selector = Some(QuerySelector::Or(vec![current, other]));
528        } else {
529            self.selector = Some(other);
530        }
531        self
532    }
533
534    pub fn not(mut self, selector: QuerySelector) -> Self {
535        let new_sel = QuerySelector::Not(Box::new(selector));
536        self.selector = Some(self.combine_with_and(new_sel));
537        self
538    }
539
540    pub fn build(self) -> Result<QuerySelector, String> {
541        self.selector.ok_or_else(|| "Empty query".to_string())
542    }
543
544    fn combine_with_and(&self, new_sel: QuerySelector) -> QuerySelector {
545        if let Some(ref current) = self.selector {
546            QuerySelector::And(vec![current.clone(), new_sel])
547        } else {
548            new_sel
549        }
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query`, failing the test with the parser's message if it is rejected.
    fn parse_ok(query: &str) -> QuerySelector {
        QueryParser::parse(query)
            .unwrap_or_else(|err| panic!("failed to parse {:?}: {}", query, err))
    }

    #[test]
    fn test_query_parser() {
        assert!(matches!(parse_ok("page"), QuerySelector::NodeTypeName(ref n) if n == "page"));
        // Type names are normalised to lower case.
        assert!(matches!(parse_ok("Page"), QuerySelector::NodeTypeName(ref n) if n == "page"));

        assert!(matches!(parse_ok("#123.0"), QuerySelector::ObjectId(123, 0)));

        assert!(matches!(parse_ok("[Type]"), QuerySelector::HasProperty(ref p) if p == "Type"));
        assert!(matches!(
            parse_ok("[Type=Page]"),
            QuerySelector::PropertyEquals(ref p, ref v) if p == "Type" && v == "Page"
        ));
        assert!(matches!(
            parse_ok("[Type='Page']"),
            QuerySelector::PropertyEquals(ref p, ref v) if p == "Type" && v == "Page"
        ));

        match parse_ok("pages > page") {
            QuerySelector::Child(parent, child) => {
                assert!(matches!(*parent, QuerySelector::NodeTypeName(ref n) if n == "pages"));
                assert!(matches!(*child, QuerySelector::NodeTypeName(ref n) if n == "page"));
            }
            other => panic!("expected Child, got {:?}", other),
        }
        assert!(matches!(parse_ok("pages page"), QuerySelector::Descendant(_, _)));
        assert!(matches!(parse_ok("page,font"), QuerySelector::Or(ref alts) if alts.len() == 2));
    }

    #[test]
    fn test_query_parser_pseudo_selectors() {
        assert!(matches!(parse_ok(":first"), QuerySelector::First));
        assert!(matches!(parse_ok(":last-child"), QuerySelector::Last));
        assert!(matches!(parse_ok(":nth-child(2)"), QuerySelector::Index(2)));
        assert!(QueryParser::parse(":bogus").is_err());
    }

    #[test]
    fn test_query_parser_errors() {
        assert!(QueryParser::parse("").is_err());
        assert!(QueryParser::parse("   ").is_err());
        assert!(QueryParser::parse("#abc.0").is_err());
    }

    #[test]
    fn test_query_builder() {
        let query = QueryBuilder::new()
            .node_type(NodeType::Page)
            .has_property("Resources")
            .build()
            .unwrap();

        match query {
            QuerySelector::And(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(parts[0], QuerySelector::NodeType(NodeType::Page)));
                assert!(matches!(parts[1], QuerySelector::HasProperty(ref p) if p == "Resources"));
            }
            other => panic!("expected And, got {:?}", other),
        }
    }

    #[test]
    fn test_query_builder_combinators() {
        assert!(QueryBuilder::new().build().is_err());
        // `child_of` on an empty builder has nothing to wrap.
        assert!(QueryBuilder::new()
            .child_of(QuerySelector::First)
            .build()
            .is_err());

        let query = QueryBuilder::new()
            .node_type(NodeType::Page)
            .child_of(QuerySelector::NodeTypeName("pages".to_string()))
            .build()
            .unwrap();
        match query {
            QuerySelector::Child(parent, child) => {
                assert!(matches!(*parent, QuerySelector::NodeTypeName(ref n) if n == "pages"));
                assert!(matches!(*child, QuerySelector::NodeType(NodeType::Page)));
            }
            other => panic!("expected Child, got {:?}", other),
        }

        let query = QueryBuilder::new()
            .has_property("A")
            .or(QuerySelector::First)
            .build()
            .unwrap();
        assert!(matches!(query, QuerySelector::Or(ref alts) if alts.len() == 2));

        let query = QueryBuilder::new().not(QuerySelector::First).build().unwrap();
        assert!(matches!(
            query,
            QuerySelector::Not(ref inner) if matches!(**inner, QuerySelector::First)
        ));
    }
}