1mod error;
2mod selector;
3
4use fnv::FnvHashMap;
5use serde_derive::{self, Deserialize, Serialize};
6use std::{cmp, collections::HashMap, convert::TryFrom};
7
8use error::Result;
9use tree_sitter::Language;
10
11use crate::selector::{map_node_kind_names, Selector};
12
13pub use crate::selector::SelectorNodeId;
14
15#[derive(Clone, Debug, Deserialize, Serialize)]
16pub struct HighlightRules {
17 name: String,
18 node_id_to_selector_id: FnvHashMap<u16, SelectorNodeId>,
19
20 #[serde(default)]
21 rules: Vec<HighlightRule>,
22}
23
24#[derive(Clone, Debug, Deserialize, Serialize)]
25pub struct HighlightRule {
26 selectors: Vec<Selector>,
27 scope: ScopePattern,
28}
29
30impl HighlightRules {
31 #[inline]
32 pub fn get_selector_node_id(&self, node_kind_id: u16) -> SelectorNodeId {
33 self.node_id_to_selector_id
34 .get(&node_kind_id)
35 .copied()
36 .unwrap_or_else(|| {
37 SelectorNodeId(u16::try_from(self.node_id_to_selector_id.len()).unwrap())
38 })
39 }
40
41 #[inline]
42 pub fn matches(
43 &self,
44 node_stack: &[SelectorNodeId],
45 nth_children: &[u16],
46 content: &str,
47 ) -> Option<&Scope> {
48 if node_stack.is_empty() {
49 return None;
50 }
51
52 let mut distance_to_match = std::usize::MAX;
53 let mut num_nodes_match = 0;
54 let mut scope_pattern = None;
55 for rule in self.rules.iter() {
56 let rule_scope = match rule.scope.matches(content) {
57 Some(scope) => scope,
58 None => continue,
59 };
60
61 for selector in rule.selectors.iter() {
62 let selector_node_kinds = selector.node_kinds();
63 let selector_nth_children = selector.nth_children();
64
65 assert!(!selector_node_kinds.is_empty());
69 if selector_node_kinds.len() > node_stack.len() {
70 continue;
71 }
72
73 for start in 0..=cmp::min(
75 node_stack.len().saturating_sub(selector_node_kinds.len()),
76 distance_to_match,
77 ) {
78 let span_range = || start..start + selector_node_kinds.len();
79
80 if selector_node_kinds
82 != &node_stack[start..(start + selector_node_kinds.len())]
83 {
84 continue;
85 }
86
87 let nth_child_not_satisfied = selector_nth_children
89 .iter()
90 .zip(nth_children[span_range()].iter())
91 .any(|(&nth_child_selector, &node_sibling_index)| {
92 nth_child_selector >= 0
93 && nth_child_selector as u16 != node_sibling_index
94 });
95 if nth_child_not_satisfied {
96 continue;
97 }
98
99 if start == distance_to_match && num_nodes_match > selector_node_kinds.len() {
102 break;
103 }
104
105 assert!(start <= distance_to_match);
106 distance_to_match = start;
115 num_nodes_match = selector_node_kinds.len();
116 scope_pattern = Some(rule_scope);
117 break;
118 }
119 }
120 }
121
122 scope_pattern
123 }
124}
125
126#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct RawHighlightRules {
128 name: String,
129
130 #[serde(default)]
131 pub scopes: HashMap<String, ScopePattern>,
132}
133
134impl RawHighlightRules {
135 fn compile(self, language: Language) -> Result<HighlightRules> {
136 let (node_name_to_selector_id, node_id_to_selector_id) =
137 build_node_to_selector_id_maps(language);
138 let RawHighlightRules { name, scopes } = self;
139
140 scopes
141 .into_iter()
142 .map(|(selector_str, scope)| {
143 let selectors = selector::parse(&selector_str)?;
144 let selectors = selectors
145 .into_iter()
146 .map(|selector| map_node_kind_names(&node_name_to_selector_id, selector))
147 .collect::<Result<Vec<_>>>()?;
148 Ok(HighlightRule { selectors, scope })
149 })
150 .collect::<Result<Vec<_>>>()
151 .map(|rules| HighlightRules {
152 name,
153 rules,
154 node_id_to_selector_id,
155 })
156 }
157}
158
159fn build_node_to_selector_id_maps(
160 language: Language,
161) -> (
162 FnvHashMap<&'static str, SelectorNodeId>,
163 FnvHashMap<u16, SelectorNodeId>,
164) {
165 let mut node_name_to_selector_id =
166 FnvHashMap::with_capacity_and_hasher(language.node_kind_count(), Default::default());
167 let mut node_id_to_selector_id =
168 FnvHashMap::with_capacity_and_hasher(language.node_kind_count(), Default::default());
169
170 let node_id_range =
171 0..u16::try_from(language.node_kind_count()).expect("node_kind_count() should fit in u16");
172 for node_id in node_id_range {
173 let node_name = language
174 .node_kind_for_id(node_id)
175 .expect("node kind available for node_id in range");
176 let next_selector_id =
177 SelectorNodeId(u16::try_from(node_name_to_selector_id.len()).unwrap());
178 let selector_id = node_name_to_selector_id
179 .entry(node_name)
180 .or_insert_with(|| next_selector_id);
181 node_id_to_selector_id.insert(node_id, *selector_id);
182 }
183
184 (node_name_to_selector_id, node_id_to_selector_id)
192}
193
194#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
195#[serde(untagged)]
196pub enum ScopePattern {
197 All(Scope),
198 Exact {
199 exact: String,
200 scopes: Scope,
201 },
202 Regex {
203 #[serde(rename = "match")]
204 regex: Regex,
205 scopes: Scope,
206 },
207 Vec(Vec<ScopePattern>),
208}
209
210#[derive(Clone, Debug, Deserialize, Serialize)]
211pub struct Regex(#[serde(with = "serde_regex")] regex::Regex);
212
213impl Regex {
214 fn is_match(&self, text: &str) -> bool {
215 self.0.is_match(text)
216 }
217}
218
219impl PartialEq for Regex {
220 fn eq(&self, other: &Self) -> bool {
221 self.0.as_str() == other.0.as_str()
222 }
223}
224
225impl ScopePattern {
226 fn matches(&self, content: &str) -> Option<&Scope> {
227 match self {
228 ScopePattern::All(ref scopes) => Some(scopes),
229 ScopePattern::Exact {
230 ref exact,
231 ref scopes,
232 } if exact.as_str() == content => Some(scopes),
233 ScopePattern::Regex {
234 ref regex,
235 ref scopes,
236 } if regex.is_match(content) => Some(scopes),
237 ScopePattern::Vec(ref scope_patterns) => {
238 for scope_pattern in scope_patterns.iter() {
239 let maybe_scope = scope_pattern.matches(content);
240 if maybe_scope.is_some() {
241 return maybe_scope;
242 }
243 }
244 None
245 }
246 _ => None,
247 }
248 }
249}
250
251#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
252pub struct Scope(pub String);
253
254pub fn parse_rules_unwrap(language: Language, source: &str) -> HighlightRules {
255 let raw_rules =
256 serde_json::from_str::<RawHighlightRules>(source).expect("valid json file for rules");
257 let name = format!("valid rules for {}", raw_rules.name);
258 raw_rules.compile(language).expect(&name)
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use maplit::hashmap;

    #[test]
    fn deserialize_no_scopes() {
        let style_str = r#"{"name": "Rust"}"#;
        let expected = RawHighlightRules {
            name: "Rust".into(),
            scopes: Default::default(),
        };
        let actual: RawHighlightRules = serde_json::from_str(style_str).expect("valid json");
        assert_eq!(expected.name, actual.name);
        // `scopes` is `#[serde(default)]`: a missing key must yield an empty map.
        assert_eq!(expected.scopes, actual.scopes);
        assert!(actual.scopes.is_empty());
    }

    #[test]
    fn deserialize_all_scope_types() {
        let style_str = r#"{
            "name": "Rust",
            "scopes": {
                "type_identifier": "support.type",
                "\"let\"": {"exact": "let", "scopes": "keyword.control" }
            }
        }"#;
        let expected = RawHighlightRules {
            name: "Rust".into(),
            scopes: hashmap! {
                "type_identifier".into() => ScopePattern::All(Scope("support.type".into())),
                "\"let\"".into() => ScopePattern::Exact {
                    exact: "let".into(),
                    scopes: Scope("keyword.control".into())
                },
            },
        };
        let actual: RawHighlightRules = serde_json::from_str(style_str).expect("valid json");
        assert_eq!(expected.name, actual.name);
        assert_eq!(expected.scopes, actual.scopes);
    }

    #[test]
    fn scope_pattern_vec_uses_first_applicable_pattern() {
        let pattern: ScopePattern = serde_json::from_str(
            r#"[
                {"exact": "self", "scopes": "variable.language"},
                {"match": "^[A-Z]", "scopes": "entity.name"},
                "variable"
            ]"#,
        )
        .expect("valid json");

        assert_eq!(
            pattern.matches("self"),
            Some(&Scope("variable.language".into()))
        );
        assert_eq!(pattern.matches("Foo"), Some(&Scope("entity.name".into())));
        assert_eq!(pattern.matches("foo"), Some(&Scope("variable".into())));
    }

    #[test]
    fn scope_pattern_exact_and_regex_reject_non_matching_text() {
        let exact = ScopePattern::Exact {
            exact: "let".into(),
            scopes: Scope("keyword.control".into()),
        };
        assert_eq!(exact.matches("letter"), None);

        let regex: ScopePattern =
            serde_json::from_str(r#"{"match": "^[0-9]+$", "scopes": "constant.numeric"}"#)
                .expect("valid json");
        assert_eq!(
            regex.matches("42"),
            Some(&Scope("constant.numeric".into()))
        );
        assert_eq!(regex.matches("4a2"), None);
    }
}