Skip to main content

polyfont_scope/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2
3use polyfont_core::{FontAssignment, FontRule};
4
5#[cfg(test)]
6use polyfont_core::{FontSpec, FontStyle, FontWeight};
7
/// Common dotted scope names, following TextMate grammar naming conventions.
///
/// Everything in this module is re-exported at the crate root.
pub mod constants {
    // Keywords and their children.
    /// Keywords (`keyword`).
    pub const SCOPE_KEYWORD: &str = "keyword";
    /// Operator keywords; a child of [`SCOPE_KEYWORD`].
    pub const SCOPE_OPERATOR: &str = "keyword.operator";

    // Literals and constants.
    /// Comments (`comment`).
    pub const SCOPE_COMMENT: &str = "comment";
    /// Strings (`string`).
    pub const SCOPE_STRING: &str = "string";
    /// Constants (`constant`).
    pub const SCOPE_CONSTANT: &str = "constant";
    /// Numeric constants; a child of [`SCOPE_CONSTANT`].
    pub const SCOPE_NUMBER: &str = "constant.numeric";

    // Names.
    /// Variables (`variable`).
    pub const SCOPE_VARIABLE: &str = "variable";
    /// Function names.
    pub const SCOPE_FUNCTION: &str = "entity.name.function";
    /// Type names.
    pub const SCOPE_TYPE: &str = "entity.name.type";
    /// Tag names.
    pub const SCOPE_TAG: &str = "entity.name.tag";
    /// Attribute names.
    pub const SCOPE_ATTRIBUTE: &str = "entity.other.attribute-name";

    // Everything else.
    /// Punctuation (`punctuation`).
    pub const SCOPE_PUNCTUATION: &str = "punctuation";
}
22
23pub use constants::*;
24
25#[derive(Debug, Clone)]
26pub struct ScopePattern {
27    segments: Vec<PatternSegment>,
28    negated: bool,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32enum PatternSegment {
33    Literal(String),
34    Wildcard,
35}
36
37impl ScopePattern {
38    #[allow(clippy::missing_errors_doc)]
39    pub fn parse(pattern: &str) -> Result<Self, ScopeError> {
40        let trimmed = pattern.trim();
41        if trimmed.is_empty() {
42            return Err(ScopeError::EmptyPattern);
43        }
44
45        let (negated, inner) = trimmed
46            .strip_prefix('-')
47            .map_or((false, trimmed), |rest| (true, rest.trim()));
48
49        if inner.is_empty() {
50            return Err(ScopeError::EmptyPattern);
51        }
52
53        let segments = inner
54            .split('.')
55            .map(|s| {
56                if s == "*" {
57                    PatternSegment::Wildcard
58                } else {
59                    PatternSegment::Literal(s.to_owned())
60                }
61            })
62            .collect();
63
64        Ok(Self { segments, negated })
65    }
66
67    #[must_use]
68    pub fn matches_scope(&self, scope: &str) -> bool {
69        self.matches_raw(scope)
70    }
71
72    #[must_use]
73    pub fn matches_raw(&self, scope: &str) -> bool {
74        let scope_parts: Vec<&str> = scope.split('.').collect();
75        if scope_parts.len() < self.segments.len() {
76            return false;
77        }
78
79        for (i, seg) in self.segments.iter().enumerate() {
80            match seg {
81                PatternSegment::Wildcard => {}
82                PatternSegment::Literal(lit) => {
83                    if scope_parts[i] != lit {
84                        return false;
85                    }
86                }
87            }
88        }
89
90        true
91    }
92
93    #[must_use]
94    pub fn specificity(&self) -> usize {
95        self.segments
96            .iter()
97            .filter(|s| **s != PatternSegment::Wildcard)
98            .count()
99    }
100}
101
102#[derive(Debug, Clone)]
103pub struct ScopeSelector {
104    patterns: Vec<ScopePattern>,
105}
106
107impl ScopeSelector {
108    #[allow(clippy::missing_errors_doc)]
109    pub fn parse(selector: &str) -> Result<Self, ScopeError> {
110        let patterns = selector
111            .split(',')
112            .filter_map(|s| {
113                let trimmed = s.trim();
114                if trimmed.is_empty() {
115                    None
116                } else {
117                    Some(ScopePattern::parse(trimmed))
118                }
119            })
120            .collect::<Result<Vec<_>, _>>()?;
121
122        if patterns.is_empty() {
123            return Err(ScopeError::EmptyPattern);
124        }
125
126        Ok(Self { patterns })
127    }
128
129    #[must_use]
130    pub fn matches(&self, scope: &str) -> bool {
131        let positive_matches: Vec<&ScopePattern> =
132            self.patterns.iter().filter(|p| !p.negated).collect();
133
134        let negative_patterns: Vec<&ScopePattern> =
135            self.patterns.iter().filter(|p| p.negated).collect();
136
137        for neg in &negative_patterns {
138            if neg.matches_raw(scope) {
139                return false;
140            }
141        }
142
143        if positive_matches.is_empty() && !negative_patterns.is_empty() {
144            return true;
145        }
146
147        positive_matches.iter().any(|p| p.matches_scope(scope))
148    }
149
150    #[must_use]
151    pub fn specificity(&self) -> usize {
152        self.patterns
153            .iter()
154            .filter(|p| !p.negated)
155            .map(ScopePattern::specificity)
156            .max()
157            .unwrap_or(0)
158    }
159}
160
161pub struct ScopeMatcher;
162
163impl ScopeMatcher {
164    #[allow(clippy::missing_errors_doc)]
165    pub fn matches(scope: &str, selector: &str) -> Result<bool, ScopeError> {
166        let sel = ScopeSelector::parse(selector)?;
167        Ok(sel.matches(scope))
168    }
169
170    #[allow(clippy::missing_errors_doc)]
171    pub fn matches_any(scope: &str, selectors: &[&str]) -> Result<bool, ScopeError> {
172        for selector in selectors {
173            let sel = ScopeSelector::parse(selector)?;
174            if sel.matches(scope) {
175                return Ok(true);
176            }
177        }
178        Ok(false)
179    }
180}
181
182#[derive(Debug, Clone)]
183pub struct ResolvedScope {
184    pub assignment: FontAssignment,
185    pub rule_index: usize,
186}
187
188pub struct ScopeResolver {
189    rules: Vec<(FontRule, usize)>,
190}
191
192impl ScopeResolver {
193    #[must_use]
194    pub const fn new() -> Self {
195        Self { rules: Vec::new() }
196    }
197
198    #[must_use]
199    pub fn from_rules(rules: Vec<FontRule>) -> Self {
200        let indexed: Vec<(FontRule, usize)> =
201            rules.into_iter().enumerate().map(|(i, r)| (r, i)).collect();
202        Self { rules: indexed }
203    }
204
205    pub fn add_rule(&mut self, rule: FontRule) {
206        let index = self.rules.len();
207        self.rules.push((rule, index));
208    }
209
210    #[must_use]
211    #[allow(clippy::missing_panics_doc)]
212    pub fn resolve(&self, scope: &str) -> Option<ResolvedScope> {
213        let mut best: Option<(&FontRule, usize, usize)> = None;
214
215        for (rule, rule_index) in &self.rules {
216            if let Ok(selector) = ScopeSelector::parse(&rule.scope)
217                && selector.matches(scope)
218            {
219                let specificity = selector.specificity();
220                let should_replace = match &best {
221                    None => true,
222                    Some((_, _, best_spec)) => {
223                        specificity > *best_spec
224                            || (specificity == *best_spec
225                                && *rule_index < best.expect("checked above").1)
226                    }
227                };
228                if should_replace {
229                    best = Some((rule, *rule_index, specificity));
230                }
231            }
232        }
233
234        best.map(|(rule, rule_index, specificity)| ResolvedScope {
235            assignment: FontAssignment {
236                scope: scope.to_owned(),
237                font: rule.font.clone(),
238                specificity,
239                is_active: true,
240            },
241            rule_index,
242        })
243    }
244
245    pub fn resolve_all<'a, I>(&self, scopes: I) -> Vec<Option<ResolvedScope>>
246    where
247        I: IntoIterator<Item = &'a str>,
248    {
249        scopes.into_iter().map(|s| self.resolve(s)).collect()
250    }
251
252    pub fn clear(&mut self) {
253        self.rules.clear();
254    }
255
256    #[must_use]
257    pub const fn rule_count(&self) -> usize {
258        self.rules.len()
259    }
260}
261
262impl Default for ScopeResolver {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
/// One node of a [`ScopeTree`], standing for a single dotted segment.
#[derive(Debug, Clone, Default)]
pub struct ScopeTreeNode {
    /// Child nodes keyed by the next segment; a `BTreeMap` keeps enumeration sorted.
    children: BTreeMap<String, Self>,
    /// `true` when the path from the root to this node was inserted as a complete scope.
    is_terminal: bool,
}

/// A prefix tree of dotted scope names.
///
/// `entity.name.function` is stored as the path `entity` → `name` → `function`, so prefix
/// queries only visit the relevant subtree.
#[derive(Debug, Clone)]
pub struct ScopeTree {
    root: ScopeTreeNode,
}

impl ScopeTree {
    /// Creates an empty tree.
    #[must_use]
    pub fn new() -> Self {
        Self {
            root: ScopeTreeNode::default(),
        }
    }

    /// Records `scope` as a complete scope, creating intermediate nodes as needed.
    ///
    /// Inserting the same scope again has no further effect.
    pub fn insert(&mut self, scope: &str) {
        let mut node = &mut self.root;
        for segment in scope.split('.') {
            node = node.children.entry(segment.to_owned()).or_default();
        }
        node.is_terminal = true;
    }

    /// Returns `true` if `scope` itself was inserted; being only an intermediate prefix is
    /// not enough.
    #[must_use]
    pub fn contains(&self, scope: &str) -> bool {
        self.find_node(scope).is_some_and(|node| node.is_terminal)
    }

    /// Returns `true` if `prefix`, taken as whole segments, leads to a node in the tree.
    ///
    /// For example, `entity` is a prefix of an inserted `entity.name`; `ent` is not.
    #[must_use]
    pub fn has_prefix(&self, prefix: &str) -> bool {
        self.find_node(prefix).is_some()
    }

    /// Returns every inserted scope equal to or below `prefix`.
    ///
    /// Results are ordered parents first, then children in sorted segment order.
    #[must_use]
    pub fn query_prefix(&self, prefix: &str) -> Vec<String> {
        let Some(node) = self.find_node(prefix) else {
            return Vec::new();
        };
        let mut results = Vec::new();
        let mut path = prefix.to_owned();
        collect_scopes(node, &mut path, &mut results);
        results
    }

    /// Number of distinct scopes inserted; computed by walking the whole tree.
    #[must_use]
    pub fn len(&self) -> usize {
        count_terminals(&self.root)
    }

    /// Returns `true` when nothing has been inserted.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        !self.root.is_terminal && self.root.children.is_empty()
    }

    /// Follows the dotted segments of `path` from the root.
    ///
    /// Returns the node reached, or `None` as soon as a segment is missing.
    fn find_node(&self, path: &str) -> Option<&ScopeTreeNode> {
        path.split('.')
            .try_fold(&self.root, |node, segment| node.children.get(segment))
    }
}

impl Default for ScopeTree {
    fn default() -> Self {
        Self::new()
    }
}

/// Appends to `results` every complete scope at or below `node`.
///
/// `path` must hold the full dotted scope of `node` on entry. It is used as a scratch buffer
/// while descending and is restored before returning. Children are always joined with `.`,
/// so scopes that contain empty segments (such as `.foo`) are reproduced faithfully.
fn collect_scopes(node: &ScopeTreeNode, path: &mut String, results: &mut Vec<String>) {
    if node.is_terminal {
        results.push(path.clone());
    }
    for (name, child) in &node.children {
        let restore_len = path.len();
        path.push('.');
        path.push_str(name);
        collect_scopes(child, path, results);
        path.truncate(restore_len);
    }
}

/// Counts the nodes at or below `node` that mark a complete inserted scope.
fn count_terminals(node: &ScopeTreeNode) -> usize {
    usize::from(node.is_terminal) + node.children.values().map(count_terminals).sum::<usize>()
}
371
372pub struct TrieScopeResolver {
373    root: TrieNode,
374    rule_count: usize,
375}
376
377#[derive(Default)]
378struct TrieNode {
379    children: HashMap<String, TrieNode>,
380    rule: Option<(FontRule, usize)>,
381}
382
383impl TrieScopeResolver {
384    #[must_use]
385    pub fn new() -> Self {
386        Self {
387            root: TrieNode::default(),
388            rule_count: 0,
389        }
390    }
391
392    #[must_use]
393    pub fn from_rules(rules: Vec<FontRule>) -> Self {
394        let mut resolver = Self::new();
395        for rule in rules {
396            resolver.add_rule(rule);
397        }
398        resolver
399    }
400
401    pub fn add_rule(&mut self, rule: FontRule) {
402        let index = self.rule_count;
403        let mut node = &mut self.root;
404        for segment in rule.scope.split('.') {
405            node = node.children.entry(segment.to_owned()).or_default();
406        }
407        node.rule = Some((rule, index));
408        self.rule_count += 1;
409    }
410
411    #[must_use]
412    pub fn resolve(&self, scope: &str) -> Option<ResolvedScope> {
413        let mut node = &self.root;
414        let mut best: Option<(&FontRule, usize, usize)> = None;
415
416        for segment in scope.split('.') {
417            let next = node
418                .children
419                .get(segment)
420                .or_else(|| node.children.get("*"));
421
422            let Some(next) = next else {
423                break;
424            };
425
426            node = next;
427            if let Some((rule, rule_index)) = &node.rule {
428                let specificity = rule.specificity();
429                let should_replace = match &best {
430                    None => true,
431                    Some((_, _, best_spec)) => {
432                        specificity > *best_spec
433                            || (specificity == *best_spec
434                                && *rule_index < best.expect("checked above").1)
435                    }
436                };
437                if should_replace {
438                    best = Some((rule, *rule_index, specificity));
439                }
440            }
441        }
442
443        best.map(|(rule, rule_index, specificity)| ResolvedScope {
444            assignment: FontAssignment {
445                scope: scope.to_owned(),
446                font: rule.font.clone(),
447                specificity,
448                is_active: true,
449            },
450            rule_index,
451        })
452    }
453
454    pub fn resolve_all(&self, scopes: &[&str]) -> Vec<Option<ResolvedScope>> {
455        scopes.iter().map(|s| self.resolve(s)).collect()
456    }
457
458    #[must_use]
459    pub const fn rule_count(&self) -> usize {
460        self.rule_count
461    }
462}
463
464impl Default for TrieScopeResolver {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470#[derive(Debug, thiserror::Error)]
471pub enum ScopeError {
472    #[error("empty scope pattern")]
473    EmptyPattern,
474    #[error("invalid scope pattern: {0}")]
475    InvalidPattern(String),
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// A `FontSpec` for `family` with no fallbacks, default weight and style, and no size.
    fn plain_spec(family: &str) -> FontSpec {
        FontSpec {
            family: family.to_owned(),
            fallbacks: vec![],
            weight: FontWeight::default(),
            style: FontStyle::default(),
            size: None,
        }
    }

    /// A `FontRule` binding `scope` to `font`.
    fn rule(scope: &str, font: FontSpec) -> FontRule {
        FontRule {
            scope: scope.to_owned(),
            font,
        }
    }

    /// Parses `pattern` and checks `matches_scope` against each `(scope, expected)` pair.
    fn check_pattern(pattern: &str, cases: &[(&str, bool)]) {
        let parsed = ScopePattern::parse(pattern).unwrap();
        for &(scope, expected) in cases {
            assert_eq!(
                parsed.matches_scope(scope),
                expected,
                "pattern {pattern:?} vs scope {scope:?}"
            );
        }
    }

    #[test]
    fn test_exact_match() {
        check_pattern(
            "entity.name.function",
            &[
                ("entity.name.function", true),
                ("entity.name", false),
                ("entity.name.function.call", true),
            ],
        );
    }

    #[test]
    fn test_hierarchical_match() {
        check_pattern(
            "entity.name",
            &[
                ("entity.name", true),
                ("entity.name.function", true),
                ("entity.name.function.call", true),
                ("entity.type", false),
            ],
        );
    }

    #[test]
    fn test_top_level_scope() {
        check_pattern(
            "entity",
            &[
                ("entity", true),
                ("entity.name", true),
                ("entity.name.function", true),
            ],
        );
    }

    #[test]
    fn test_wildcard_match() {
        check_pattern(
            "entity.*",
            &[
                ("entity.name", true),
                ("entity.name.function", true),
                ("entity.type", true),
                ("keyword", false),
            ],
        );
    }

    #[test]
    fn test_negative_match() {
        let selector = ScopeSelector::parse("keyword,-keyword.operator").unwrap();
        assert!(selector.matches("keyword.control"));
        assert!(!selector.matches("keyword.operator"));
    }

    #[test]
    fn test_comma_separated_or() {
        let selector = ScopeSelector::parse("keyword,storage.type").unwrap();
        let accepted = ["keyword", "storage.type", "storage.type.function"];
        for scope in accepted {
            assert!(selector.matches(scope), "should accept {scope}");
        }
        assert!(!selector.matches("comment"));
    }

    #[test]
    fn test_scope_matcher_convenience() {
        let cases = [
            ("entity.name.function", "entity.name", true),
            ("comment.line", "comment", true),
            ("string.quoted", "comment", false),
        ];
        for (scope, selector, expected) in cases {
            assert_eq!(ScopeMatcher::matches(scope, selector).unwrap(), expected);
        }
    }

    #[test]
    fn test_scope_resolver_specificity() {
        let resolver = ScopeResolver::from_rules(vec![
            rule("entity", plain_spec("font-a")),
            rule("entity.name.function", plain_spec("font-b")),
        ]);
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-b");
        assert_eq!(result.assignment.specificity, 3);
    }

    #[test]
    fn test_scope_resolver_tiebreak_by_order() {
        let resolver = ScopeResolver::from_rules(vec![
            rule("keyword", plain_spec("first")),
            rule("keyword", plain_spec("second")),
        ]);
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "first");
    }

    #[test]
    fn test_scope_tree_insert_and_query() {
        let inserted = [
            "entity.name.function",
            "entity.name.type",
            "entity.name",
            "keyword.control",
        ];
        let mut tree = ScopeTree::new();
        for scope in inserted {
            tree.insert(scope);
        }

        for scope in inserted {
            assert!(tree.contains(scope), "missing {scope}");
        }
        assert!(!tree.contains("comment"));

        assert!(tree.has_prefix("entity"));
        assert!(tree.has_prefix("entity.name"));
        assert!(!tree.has_prefix("string"));
    }

    #[test]
    fn test_scope_tree_prefix_query() {
        let mut tree = ScopeTree::new();
        for scope in ["entity.name.function", "entity.name.type", "entity.other"] {
            tree.insert(scope);
        }

        let results = tree.query_prefix("entity.name");
        assert_eq!(results.len(), 2);
        for expected in ["entity.name.function", "entity.name.type"] {
            assert!(results.contains(&expected.to_owned()), "missing {expected}");
        }
    }

    #[test]
    fn test_scope_tree_len() {
        let mut tree = ScopeTree::new();
        assert!(tree.is_empty());

        tree.insert("entity.name");
        tree.insert("keyword");
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn test_empty_pattern_error() {
        for input in ["", "  ", "-"] {
            assert!(ScopePattern::parse(input).is_err(), "{input:?} should fail");
        }
        assert!(ScopeSelector::parse("").is_err());
    }

    #[test]
    fn test_matches_any() {
        let selectors = ["comment", "keyword"];
        assert!(ScopeMatcher::matches_any("keyword.control", &selectors).unwrap());
        assert!(!ScopeMatcher::matches_any("string.quoted", &selectors).unwrap());
    }

    #[test]
    fn test_trie_insert_and_resolve() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(rule("keyword", FontSpec::default_font("mono")));
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "mono");
    }

    #[test]
    fn test_trie_specificity() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(rule("entity", FontSpec::default_font("font-a")));
        resolver.add_rule(rule("entity.name.function", FontSpec::default_font("font-b")));
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-b");
    }

    #[test]
    fn test_trie_partial_match() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(rule("entity.name", FontSpec::default_font("font-a")));
        let result = resolver.resolve("entity.name.function").unwrap();
        assert_eq!(result.assignment.font.family, "font-a");
    }

    #[test]
    fn test_trie_empty() {
        assert!(TrieScopeResolver::new().resolve("anything").is_none());
    }

    #[test]
    fn test_trie_wildcard() {
        let mut resolver = TrieScopeResolver::new();
        resolver.add_rule(rule("entity.*", FontSpec::default_font("wildcard-font")));
        let result = resolver.resolve("entity.name").unwrap();
        assert_eq!(result.assignment.font.family, "wildcard-font");
    }

    #[test]
    fn test_trie_from_rules() {
        let resolver = TrieScopeResolver::from_rules(vec![
            rule("keyword", FontSpec::default_font("mono")),
            rule("string", FontSpec::default_font("serif")),
        ]);
        let result = resolver.resolve("keyword").unwrap();
        assert_eq!(result.assignment.font.family, "mono");
    }

    #[test]
    fn test_trie_resolve_all() {
        let resolver = TrieScopeResolver::from_rules(vec![
            rule("keyword", FontSpec::default_font("mono")),
            rule("string", FontSpec::default_font("serif")),
        ]);
        let results = resolver.resolve_all(&["keyword", "string", "comment"]);
        assert!(results[0].is_some());
        assert!(results[1].is_some());
        assert!(results[2].is_none());
    }
}