Skip to main content

polyfont_scope/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2
3use polyfont_core::{FontAssignment, FontRule};
4
5#[cfg(test)]
6use polyfont_core::{FontSpec, FontStyle, FontWeight};
7
/// Commonly used dotted scope names.
///
/// Several of these are nested under one another in the dotted hierarchy. Patterns match
/// scopes they are a segment-wise prefix of, so a rule for the parent also covers the
/// child (for example a `keyword` rule also applies to `keyword.operator`).
pub mod constants {
    /// `keyword`.
    pub const SCOPE_KEYWORD: &str = "keyword";
    /// `keyword.operator`, nested under [`SCOPE_KEYWORD`].
    pub const SCOPE_OPERATOR: &str = "keyword.operator";

    /// `constant`.
    pub const SCOPE_CONSTANT: &str = "constant";
    /// `constant.numeric`, nested under [`SCOPE_CONSTANT`].
    pub const SCOPE_NUMBER: &str = "constant.numeric";

    /// `entity.name.function`.
    pub const SCOPE_FUNCTION: &str = "entity.name.function";
    /// `entity.name.type`.
    pub const SCOPE_TYPE: &str = "entity.name.type";
    /// `entity.name.tag`.
    pub const SCOPE_TAG: &str = "entity.name.tag";
    /// `entity.other.attribute-name`.
    pub const SCOPE_ATTRIBUTE: &str = "entity.other.attribute-name";

    /// `comment`.
    pub const SCOPE_COMMENT: &str = "comment";
    /// `string`.
    pub const SCOPE_STRING: &str = "string";
    /// `variable`.
    pub const SCOPE_VARIABLE: &str = "variable";
    /// `punctuation`.
    pub const SCOPE_PUNCTUATION: &str = "punctuation";
}

pub use constants::*;
24
25#[derive(Debug, Clone)]
26pub struct ScopePattern {
27    segments: Vec<PatternSegment>,
28    negated: bool,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32enum PatternSegment {
33    Literal(String),
34    Wildcard,
35}
36
37impl ScopePattern {
38    #[allow(clippy::missing_errors_doc)]
39    pub fn parse(pattern: &str) -> Result<Self, ScopeError> {
40        let trimmed = pattern.trim();
41        if trimmed.is_empty() {
42            return Err(ScopeError::EmptyPattern);
43        }
44
45        let (negated, inner) = trimmed
46            .strip_prefix('-')
47            .map_or((false, trimmed), |rest| (true, rest.trim()));
48
49        if inner.is_empty() {
50            return Err(ScopeError::EmptyPattern);
51        }
52
53        let segments = inner
54            .split('.')
55            .map(|s| {
56                if s == "*" {
57                    PatternSegment::Wildcard
58                } else {
59                    PatternSegment::Literal(s.to_owned())
60                }
61            })
62            .collect();
63
64        Ok(Self { segments, negated })
65    }
66
67    #[must_use]
68    pub fn matches_scope(&self, scope: &str) -> bool {
69        self.matches_raw(scope)
70    }
71
72    #[must_use]
73    pub fn matches_raw(&self, scope: &str) -> bool {
74        let scope_parts: Vec<&str> = scope.split('.').collect();
75        if scope_parts.len() < self.segments.len() {
76            return false;
77        }
78
79        for (i, seg) in self.segments.iter().enumerate() {
80            match seg {
81                PatternSegment::Wildcard => {}
82                PatternSegment::Literal(lit) => {
83                    if scope_parts[i] != lit {
84                        return false;
85                    }
86                }
87            }
88        }
89
90        true
91    }
92
93    #[must_use]
94    pub fn specificity(&self) -> usize {
95        self.segments
96            .iter()
97            .filter(|s| **s != PatternSegment::Wildcard)
98            .count()
99    }
100}
101
102#[derive(Debug, Clone)]
103pub struct ScopeSelector {
104    patterns: Vec<ScopePattern>,
105}
106
107impl ScopeSelector {
108    #[allow(clippy::missing_errors_doc)]
109    pub fn parse(selector: &str) -> Result<Self, ScopeError> {
110        let patterns = selector
111            .split(',')
112            .filter_map(|s| {
113                let trimmed = s.trim();
114                if trimmed.is_empty() {
115                    None
116                } else {
117                    Some(ScopePattern::parse(trimmed))
118                }
119            })
120            .collect::<Result<Vec<_>, _>>()?;
121
122        if patterns.is_empty() {
123            return Err(ScopeError::EmptyPattern);
124        }
125
126        Ok(Self { patterns })
127    }
128
129    #[must_use]
130    pub fn matches(&self, scope: &str) -> bool {
131        let positive_matches: Vec<&ScopePattern> =
132            self.patterns.iter().filter(|p| !p.negated).collect();
133
134        let negative_patterns: Vec<&ScopePattern> =
135            self.patterns.iter().filter(|p| p.negated).collect();
136
137        for neg in &negative_patterns {
138            if neg.matches_raw(scope) {
139                return false;
140            }
141        }
142
143        if positive_matches.is_empty() && !negative_patterns.is_empty() {
144            return true;
145        }
146
147        positive_matches.iter().any(|p| p.matches_scope(scope))
148    }
149
150    #[must_use]
151    pub fn specificity(&self) -> usize {
152        self.patterns
153            .iter()
154            .filter(|p| !p.negated)
155            .map(ScopePattern::specificity)
156            .max()
157            .unwrap_or(0)
158    }
159}
160
161pub struct ScopeMatcher;
162
163impl ScopeMatcher {
164    #[allow(clippy::missing_errors_doc)]
165    pub fn matches(scope: &str, selector: &str) -> Result<bool, ScopeError> {
166        let sel = ScopeSelector::parse(selector)?;
167        Ok(sel.matches(scope))
168    }
169
170    #[allow(clippy::missing_errors_doc)]
171    pub fn matches_any(scope: &str, selectors: &[&str]) -> Result<bool, ScopeError> {
172        for selector in selectors {
173            let sel = ScopeSelector::parse(selector)?;
174            if sel.matches(scope) {
175                return Ok(true);
176            }
177        }
178        Ok(false)
179    }
180}
181
182#[derive(Debug, Clone)]
183pub struct ResolvedScope {
184    pub assignment: FontAssignment,
185    pub rule_index: usize,
186}
187
188pub struct ScopeResolver {
189    rules: Vec<(FontRule, usize)>,
190}
191
192impl ScopeResolver {
193    #[must_use]
194    pub const fn new() -> Self {
195        Self { rules: Vec::new() }
196    }
197
198    #[must_use]
199    pub fn from_rules(rules: Vec<FontRule>) -> Self {
200        let indexed: Vec<(FontRule, usize)> =
201            rules.into_iter().enumerate().map(|(i, r)| (r, i)).collect();
202        Self { rules: indexed }
203    }
204
205    pub fn add_rule(&mut self, rule: FontRule) {
206        let index = self.rules.len();
207        self.rules.push((rule, index));
208    }
209
210    #[must_use]
211    #[allow(clippy::missing_panics_doc)]
212    pub fn resolve(&self, scope: &str) -> Option<ResolvedScope> {
213        let mut best: Option<(&FontRule, usize, usize)> = None;
214
215        for (rule, rule_index) in &self.rules {
216            if let Ok(selector) = ScopeSelector::parse(&rule.scope)
217                && selector.matches(scope)
218            {
219                let specificity = selector.specificity();
220                let should_replace = match &best {
221                    None => true,
222                    Some((_, _, best_spec)) => {
223                        specificity > *best_spec
224                            || (specificity == *best_spec
225                                && *rule_index < best.expect("checked above").1)
226                    }
227                };
228                if should_replace {
229                    best = Some((rule, *rule_index, specificity));
230                }
231            }
232        }
233
234        best.map(|(rule, rule_index, specificity)| ResolvedScope {
235            assignment: FontAssignment {
236                scope: scope.to_owned(),
237                font: rule.font.clone(),
238                specificity,
239                is_active: true,
240            },
241            rule_index,
242        })
243    }
244
245    pub fn resolve_all<'a, I>(&self, scopes: I) -> Vec<Option<ResolvedScope>>
246    where
247        I: IntoIterator<Item = &'a str>,
248    {
249        scopes.into_iter().map(|s| self.resolve(s)).collect()
250    }
251
252    pub fn clear(&mut self) {
253        self.rules.clear();
254    }
255
256    #[must_use]
257    pub const fn rule_count(&self) -> usize {
258        self.rules.len()
259    }
260}
261
262impl Default for ScopeResolver {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
/// A node of a [`ScopeTree`].
///
/// Children are keyed by the next dot-separated segment and kept sorted by the
/// `BTreeMap`. `is_terminal` records whether the path ending at this node was inserted
/// itself.
#[derive(Debug, Clone, Default)]
pub struct ScopeTreeNode {
    children: BTreeMap<String, Self>,
    is_terminal: bool,
}

/// A set of dot-separated scope names stored as a segment trie.
///
/// Supports exact lookup and whole-segment prefix queries.
#[derive(Debug, Clone)]
pub struct ScopeTree {
    root: ScopeTreeNode,
}
277
278impl ScopeTree {
279    #[must_use]
280    pub fn new() -> Self {
281        Self {
282            root: ScopeTreeNode::default(),
283        }
284    }
285
286    pub fn insert(&mut self, scope: &str) {
287        let mut node = &mut self.root;
288        for segment in scope.split('.') {
289            node = node.children.entry(segment.to_owned()).or_default();
290        }
291        node.is_terminal = true;
292    }
293
294    #[must_use]
295    pub fn contains(&self, scope: &str) -> bool {
296        let mut node = &self.root;
297        for segment in scope.split('.') {
298            match node.children.get(segment) {
299                Some(child) => node = child,
300                None => return false,
301            }
302        }
303        node.is_terminal
304    }
305
306    #[must_use]
307    pub fn has_prefix(&self, prefix: &str) -> bool {
308        let mut node = &self.root;
309        for segment in prefix.split('.') {
310            match node.children.get(segment) {
311                Some(child) => node = child,
312                None => return false,
313            }
314        }
315        true
316    }
317
318    #[must_use]
319    pub fn query_prefix(&self, prefix: &str) -> Vec<String> {
320        let mut node = &self.root;
321        for segment in prefix.split('.') {
322            match node.children.get(segment) {
323                Some(child) => node = child,
324                None => return Vec::new(),
325            }
326        }
327
328        let mut results = Vec::new();
329        collect_scopes(node, prefix, &mut results);
330        results
331    }
332
333    #[must_use]
334    pub fn len(&self) -> usize {
335        count_terminals(&self.root)
336    }
337
338    #[must_use]
339    pub fn is_empty(&self) -> bool {
340        !self.root.is_terminal && self.root.children.is_empty()
341    }
342}
343
344impl Default for ScopeTree {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350fn collect_scopes(node: &ScopeTreeNode, prefix: &str, results: &mut Vec<String>) {
351    if node.is_terminal {
352        results.push(prefix.to_owned());
353    }
354    for (name, child) in &node.children {
355        let child_path = if prefix.is_empty() {
356            name.clone()
357        } else {
358            format!("{prefix}.{name}")
359        };
360        collect_scopes(child, &child_path, results);
361    }
362}
363
364fn count_terminals(node: &ScopeTreeNode) -> usize {
365    let mut count = usize::from(node.is_terminal);
366    for child in node.children.values() {
367        count += count_terminals(child);
368    }
369    count
370}
371
372pub struct TrieScopeResolver {
373    root: TrieNode,
374    rule_count: usize,
375}
376
377#[derive(Default)]
378struct TrieNode {
379    children: HashMap<String, TrieNode>,
380    rule: Option<(FontRule, usize)>,
381}
382
383impl TrieScopeResolver {
384    #[must_use]
385    pub fn new() -> Self {
386        Self {
387            root: TrieNode::default(),
388            rule_count: 0,
389        }
390    }
391
392    #[must_use]
393    pub fn from_rules(rules: Vec<FontRule>) -> Self {
394        let mut resolver = Self::new();
395        for rule in rules {
396            resolver.add_rule(rule);
397        }
398        resolver
399    }
400
401    pub fn add_rule(&mut self, rule: FontRule) {
402        let index = self.rule_count;
403        let mut node = &mut self.root;
404        for segment in rule.scope.split('.') {
405            node = node.children.entry(segment.to_owned()).or_default();
406        }
407        node.rule = Some((rule, index));
408        self.rule_count += 1;
409    }
410
411    #[must_use]
412    pub fn resolve(&self, scope: &str) -> Option<ResolvedScope> {
413        let mut node = &self.root;
414        let mut best: Option<(&FontRule, usize, usize)> = None;
415
416        for segment in scope.split('.') {
417            let next = node
418                .children
419                .get(segment)
420                .or_else(|| node.children.get("*"));
421
422            let Some(next) = next else {
423                break;
424            };
425
426            node = next;
427            if let Some((rule, rule_index)) = &node.rule {
428                let specificity = rule.specificity();
429                let should_replace = match &best {
430                    None => true,
431                    Some((_, _, best_spec)) => {
432                        specificity > *best_spec
433                            || (specificity == *best_spec
434                                && *rule_index < best.expect("checked above").1)
435                    }
436                };
437                if should_replace {
438                    best = Some((rule, *rule_index, specificity));
439                }
440            }
441        }
442
443        best.map(|(rule, rule_index, specificity)| ResolvedScope {
444            assignment: FontAssignment {
445                scope: scope.to_owned(),
446                font: rule.font.clone(),
447                specificity,
448                is_active: true,
449            },
450            rule_index,
451        })
452    }
453
454    pub fn resolve_all(&self, scopes: &[&str]) -> Vec<Option<ResolvedScope>> {
455        scopes.iter().map(|s| self.resolve(s)).collect()
456    }
457
458    #[must_use]
459    pub const fn rule_count(&self) -> usize {
460        self.rule_count
461    }
462}
463
464impl Default for TrieScopeResolver {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470#[derive(Debug, thiserror::Error)]
471pub enum ScopeError {
472    #[error("empty scope pattern")]
473    EmptyPattern,
474    #[error("invalid scope pattern: {0}")]
475    InvalidPattern(String),
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule whose `FontSpec` spells out every field explicitly.
    fn explicit_rule(scope: &str, family: &str) -> FontRule {
        FontRule {
            scope: scope.to_owned(),
            font: FontSpec {
                family: family.to_owned(),
                fallbacks: vec![],
                weight: FontWeight::default(),
                style: FontStyle::default(),
                size: None,
                axes: vec![],
            },
        }
    }

    /// Builds a rule whose font comes from `FontSpec::default_font`.
    fn default_rule(scope: &str, family: &'static str) -> FontRule {
        FontRule {
            scope: scope.to_owned(),
            font: FontSpec::default_font(family),
        }
    }

    #[test]
    fn test_exact_match() {
        let pattern = ScopePattern::parse("entity.name.function").unwrap();
        assert!(pattern.matches_scope("entity.name.function"));
        assert!(!pattern.matches_scope("entity.name"));
        assert!(pattern.matches_scope("entity.name.function.call"));
    }

    #[test]
    fn test_hierarchical_match() {
        let pattern = ScopePattern::parse("entity.name").unwrap();
        assert!(pattern.matches_scope("entity.name"));
        assert!(pattern.matches_scope("entity.name.function"));
        assert!(pattern.matches_scope("entity.name.function.call"));
        assert!(!pattern.matches_scope("entity.type"));
    }

    #[test]
    fn test_top_level_scope() {
        let pattern = ScopePattern::parse("entity").unwrap();
        assert!(pattern.matches_scope("entity"));
        assert!(pattern.matches_scope("entity.name"));
        assert!(pattern.matches_scope("entity.name.function"));
    }

    #[test]
    fn test_wildcard_match() {
        let pattern = ScopePattern::parse("entity.*").unwrap();
        assert!(pattern.matches_scope("entity.name"));
        assert!(pattern.matches_scope("entity.name.function"));
        assert!(pattern.matches_scope("entity.type"));
        assert!(!pattern.matches_scope("keyword"));
    }

    #[test]
    fn test_wildcard_requires_a_segment() {
        let pattern = ScopePattern::parse("entity.*").unwrap();
        assert!(!pattern.matches_scope("entity"));
    }

    #[test]
    fn test_wildcard_specificity() {
        assert_eq!(ScopePattern::parse("entity.*").unwrap().specificity(), 1);
        assert_eq!(ScopePattern::parse("*.name").unwrap().specificity(), 1);
        assert_eq!(ScopePattern::parse("entity.name").unwrap().specificity(), 2);
    }

    #[test]
    fn test_negative_match() {
        let selector = ScopeSelector::parse("keyword,-keyword.operator").unwrap();
        assert!(selector.matches("keyword.control"));
        assert!(!selector.matches("keyword.operator"));
    }

    #[test]
    fn test_negative_only_selector() {
        let selector = ScopeSelector::parse("-comment").unwrap();
        assert!(selector.matches("keyword"));
        assert!(!selector.matches("comment.line"));
    }

    #[test]
    fn test_comma_separated_or() {
        let selector = ScopeSelector::parse("keyword,storage.type").unwrap();
        assert!(selector.matches("keyword"));
        assert!(selector.matches("storage.type"));
        assert!(selector.matches("storage.type.function"));
        assert!(!selector.matches("comment"));
    }

    #[test]
    fn test_scope_matcher_convenience() {
        assert!(ScopeMatcher::matches("entity.name.function", "entity.name").unwrap());
        assert!(ScopeMatcher::matches("comment.line", "comment").unwrap());
        assert!(!ScopeMatcher::matches("string.quoted", "comment").unwrap());
    }

    #[test]
    fn test_scope_resolver_specificity() {
        let rules = vec![
            explicit_rule("entity", "font-a"),
            explicit_rule("entity.name.function", "font-b"),
        ];

        let resolver = ScopeResolver::from_rules(rules);
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-b");
        assert_eq!(result.assignment.specificity, 3);
    }

    #[test]
    fn test_scope_resolver_tiebreak_by_order() {
        let rules = vec![
            explicit_rule("keyword", "first"),
            explicit_rule("keyword", "second"),
        ];

        let resolver = ScopeResolver::from_rules(rules);
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "first");
    }

    #[test]
    fn test_scope_resolver_no_match() {
        let resolver = ScopeResolver::from_rules(vec![explicit_rule("keyword", "mono")]);
        assert!(resolver.resolve("comment").is_none());
    }

    #[test]
    fn test_scope_tree_insert_and_query() {
        let mut tree = ScopeTree::new();
        tree.insert("entity.name.function");
        tree.insert("entity.name.type");
        tree.insert("entity.name");
        tree.insert("keyword.control");

        assert!(tree.contains("entity.name.function"));
        assert!(tree.contains("entity.name.type"));
        assert!(tree.contains("entity.name"));
        assert!(tree.contains("keyword.control"));
        assert!(!tree.contains("comment"));

        assert!(tree.has_prefix("entity"));
        assert!(tree.has_prefix("entity.name"));
        assert!(!tree.has_prefix("string"));
    }

    #[test]
    fn test_scope_tree_contains_requires_terminal() {
        let mut tree = ScopeTree::new();
        tree.insert("a.b.c");
        assert!(!tree.contains("a.b"));
        assert!(tree.has_prefix("a.b"));
        assert!(tree.contains("a.b.c"));
    }

    #[test]
    fn test_scope_tree_prefix_query() {
        let mut tree = ScopeTree::new();
        tree.insert("entity.name.function");
        tree.insert("entity.name.type");
        tree.insert("entity.other");

        let results = tree.query_prefix("entity.name");
        assert_eq!(results.len(), 2);
        assert!(results.contains(&"entity.name.function".to_owned()));
        assert!(results.contains(&"entity.name.type".to_owned()));
    }

    #[test]
    fn test_scope_tree_query_includes_prefix_itself() {
        let mut tree = ScopeTree::new();
        tree.insert("entity.name.function");
        tree.insert("entity.name");

        let results = tree.query_prefix("entity.name");
        assert_eq!(
            results,
            vec!["entity.name".to_owned(), "entity.name.function".to_owned()]
        );
    }

    #[test]
    fn test_scope_tree_len() {
        let mut tree = ScopeTree::new();
        assert!(tree.is_empty());

        tree.insert("entity.name");
        tree.insert("keyword");
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn test_empty_pattern_error() {
        assert!(ScopePattern::parse("").is_err());
        assert!(ScopePattern::parse("  ").is_err());
        assert!(ScopePattern::parse("-").is_err());
        assert!(ScopeSelector::parse("").is_err());
    }

    #[test]
    fn test_matches_any() {
        assert!(ScopeMatcher::matches_any("keyword.control", &["comment", "keyword"]).unwrap());
        assert!(!ScopeMatcher::matches_any("string.quoted", &["comment", "keyword"]).unwrap());
    }

    #[test]
    fn test_trie_insert_and_resolve() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(default_rule("keyword", "mono"));
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "mono");
    }

    #[test]
    fn test_trie_specificity() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(default_rule("entity", "font-a"));
        resolver.add_rule(default_rule("entity.name.function", "font-b"));
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-b");
    }

    #[test]
    fn test_trie_partial_match() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(default_rule("entity.name", "font-a"));
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-a");
    }

    #[test]
    fn test_trie_empty() {
        let resolver = TrieScopeResolver::new();
        assert!(resolver.resolve("anything").is_none());
    }

    #[test]
    fn test_trie_wildcard() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(default_rule("entity.*", "wildcard-font"));
        let result = resolver.resolve("entity.name").unwrap();
        assert_eq!(result.assignment.font.family, "wildcard-font");
    }

    #[test]
    fn test_trie_from_rules() {
        let rules = vec![
            default_rule("keyword", "mono"),
            default_rule("string", "serif"),
        ];
        let resolver = TrieScopeResolver::from_rules(rules);
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "mono");
    }

    #[test]
    fn test_trie_resolve_all() {
        let rules = vec![
            default_rule("keyword", "mono"),
            default_rule("string", "serif"),
        ];
        let resolver = TrieScopeResolver::from_rules(rules);
        let results = resolver.resolve_all(&["keyword", "string", "comment"]);
        assert!(results[0].is_some());
        assert!(results[1].is_some());
        assert!(results[2].is_none());
    }
}