Skip to main content

thread_language/
html.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use super::pre_process_pattern;
8use thread_ast_engine::Language;
9#[cfg(feature = "matching")]
10use thread_ast_engine::matcher::KindMatcher;
11#[cfg(feature = "matching")]
12use thread_ast_engine::matcher::{Pattern, PatternBuilder, PatternError};
13use thread_ast_engine::tree_sitter::{LanguageExt, TSLanguage};
14#[cfg(feature = "matching")]
15use thread_ast_engine::tree_sitter::{StrDoc, TSRange};
16#[cfg(feature = "matching")]
17use thread_ast_engine::{Doc, Node};
18#[cfg(feature = "html-embedded")]
19use thread_utilities::RapidMap;
20
/// Zero-sized marker type selecting the HTML grammar, including support for
/// languages embedded inside HTML documents.
///
/// Metavariables are spelled with `z` as the expando character because HTML
/// attributes and tag names have specific character restrictions.
///
/// ## Language Injection
///
/// Embedded code is located in two kinds of elements:
/// - `<script>` elements, reported as **JavaScript** (`js`) when no `lang`
///   attribute is present (a `lang` of `js` or `javascript` is also reported
///   as `js`)
/// - `<style>` elements, reported as **CSS** (`css`) when no `lang`
///   attribute is present
/// - any **other language** is reported under the literal value of the
///   element's `lang` attribute
///
/// The HTML-specific `extract_injections` implementation is only compiled when
/// the `html-embedded` cargo feature is enabled.
///
/// ## Examples
///
/// ```rust
/// use thread_language::Html;
/// use thread_ast_engine::{Language, LanguageExt};
///
/// let html = Html;
/// let source = r#"
/// <script>console.log('hello');</script>
/// <style>.class { color: red; }</style>
/// <script lang="ts">const x: number = 42;</script>
/// "#;
///
/// let tree = html.ast_grep(source);
/// let injections = html.extract_injections(tree.root());
/// // injections contains JavaScript, CSS, and TypeScript ranges
/// ```
///
/// ## Note
/// tree-sitter-html uses locale-dependent `iswalnum` for tag names.
/// See: <https://github.com/tree-sitter/tree-sitter-html/blob/b5d9758e22b4d3d25704b72526670759a9e4d195/src/scanner.c#L194>
#[derive(Debug, Clone, Copy)]
pub struct Html;
56impl Language for Html {
57    fn expando_char(&self) -> char {
58        'z'
59    }
60    fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
61        pre_process_pattern(self.expando_char(), query)
62    }
63    fn kind_to_id(&self, kind: &str) -> u16 {
64        crate::parsers::language_html().id_for_node_kind(kind, true)
65    }
66    fn field_to_id(&self, field: &str) -> Option<u16> {
67        crate::parsers::language_html()
68            .field_id_for_name(field)
69            .map(|f| f.get())
70    }
71    #[cfg(feature = "matching")]
72    fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
73        builder.build(|src| StrDoc::try_new(src, *self))
74    }
75}
76impl LanguageExt for Html {
77    fn get_ts_language(&self) -> TSLanguage {
78        crate::parsers::language_html()
79    }
80    fn injectable_languages(&self) -> Option<&'static [&'static str]> {
81        Some(&["css", "js", "ts", "tsx", "scss", "less", "stylus", "coffee"])
82    }
83    #[cfg(feature = "html-embedded")]
84    fn extract_injections<L: LanguageExt>(
85        &self,
86        root: Node<StrDoc<L>>,
87    ) -> RapidMap<String, Vec<TSRange>> {
88        let lang = root.lang();
89        let mut map = RapidMap::default();
90
91        // Pre-allocate common language vectors to avoid repeated allocations
92        let mut js_ranges = Vec::new();
93        let mut css_ranges = Vec::new();
94        let mut other_ranges: RapidMap<String, Vec<TSRange>> = RapidMap::default();
95
96        // Process script elements
97        let script_matcher = KindMatcher::new("script_element", lang);
98        for script in root.find_all(script_matcher) {
99            if let Some(content) = script.children().find(|c| c.kind() == "raw_text") {
100                let range = node_to_range(&content);
101
102                // Fast path for common languages
103                match find_lang(&script) {
104                    Some(lang_name) => {
105                        if lang_name == "js" || lang_name == "javascript" {
106                            js_ranges.push(range);
107                        } else {
108                            other_ranges.entry(lang_name).or_default().push(range);
109                        }
110                    }
111                    None => js_ranges.push(range), // Default to JavaScript
112                }
113            }
114        }
115
116        // Process style elements
117        let style_matcher = KindMatcher::new("style_element", lang);
118        for style in root.find_all(style_matcher) {
119            if let Some(content) = style.children().find(|c| c.kind() == "raw_text") {
120                let range = node_to_range(&content);
121
122                // Fast path for CSS (most common)
123                match find_lang(&style) {
124                    Some(lang_name) => {
125                        if lang_name == "css" {
126                            css_ranges.push(range);
127                        } else {
128                            other_ranges.entry(lang_name).or_default().push(range);
129                        }
130                    }
131                    None => css_ranges.push(range), // Default to CSS
132                }
133            }
134        }
135
136        // Only insert non-empty vectors to reduce map size
137        if !js_ranges.is_empty() {
138            map.insert("js".to_string(), js_ranges);
139        }
140        if !css_ranges.is_empty() {
141            map.insert("css".to_string(), css_ranges);
142        }
143
144        // Merge other languages
145        for (lang_name, ranges) in other_ranges {
146            if !ranges.is_empty() {
147                map.insert(lang_name, ranges);
148            }
149        }
150
151        map
152    }
153}
154
/// Returns the value of the `lang` attribute found below `node`, if any.
///
/// Every `attribute` node under `node` is inspected in traversal order. An
/// attribute is skipped when it has no `attribute_name`, when that name's text
/// is not exactly `lang`, or when it has no `attribute_value`. The text of the
/// first surviving `attribute_value` is returned as an owned `String`;
/// `None` means no attribute qualified.
#[cfg(feature = "html-embedded")]
fn find_lang<D: Doc>(node: &Node<D>) -> Option<String> {
    let html = node.lang();
    let attribute_kind = KindMatcher::new("attribute", html);
    let name_kind = KindMatcher::new("attribute_name", html);
    let value_kind = KindMatcher::new("attribute_value", html);

    for attribute in node.find_all(attribute_kind) {
        // Only attributes literally named `lang` are of interest.
        match attribute.find(&name_kind) {
            Some(name) if name.text() == "lang" => {}
            _ => continue,
        }
        // A `lang` attribute without a value is ignored; keep scanning.
        if let Some(value) = attribute.find(&value_kind) {
            return Some(value.text().to_string());
        }
    }
    None
}
/// Converts the span of `node` into a `TSRange`.
///
/// The byte offsets come from `node.range()`; the start and end points are
/// built from the `byte_point()` pairs of the node's start and end positions.
#[cfg(feature = "matching")]
fn node_to_range<D: Doc>(node: &Node<D>) -> TSRange {
    let span = node.range();

    let head = node.start_pos();
    let tail = node.end_pos();
    let head_point = head.byte_point();
    let tail_point = tail.byte_point();

    TSRange {
        start_byte: span.start,
        end_byte: span.end,
        start_point: tree_sitter::Point::new(head_point.0, head_point.1),
        end_point: tree_sitter::Point::new(tail_point.0, tail_point.1),
    }
}
186
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `query` matches somewhere in `source` when parsed as HTML.
    fn test_match(query: &str, source: &str) {
        use crate::test::test_match_lang;
        test_match_lang(query, source, Html);
    }

    /// Asserts that `query` matches nowhere in `source` when parsed as HTML.
    fn test_non_match(query: &str, source: &str) {
        use crate::test::test_non_match_lang;
        test_non_match_lang(query, source, Html);
    }

    #[test]
    fn test_html_match() {
        test_match("<input>", "<input>");
        test_match("<$TAG>", "<input>");
        test_match("<$TAG class='foo'>$$$</$TAG>", "<div class='foo'></div>");
        test_match("<div>$$$</div>", "<div>123</div>");
        test_non_match("<$TAG class='foo'>$$$</$TAG>", "<div></div>");
        test_non_match("<div>$$$</div>", "<div class='foo'>123</div>");
    }

    /// Applies `replacer` to the match of `pattern` in `src` and returns the
    /// rewritten source.
    fn test_replace(src: &str, pattern: &str, replacer: &str) -> String {
        use crate::test::test_replace_lang;
        test_replace_lang(src, pattern, replacer, Html)
    }

    #[test]
    fn test_html_replace() {
        let ret = test_replace(
            r#"<div class='foo'>bar</div>"#,
            r#"<$TAG class='foo'>$$$B</$TAG>"#,
            r#"<$TAG class='$$$B'>foo</$TAG>"#,
        );
        assert_eq!(ret, r#"<div class='bar'>foo</div>"#);
    }

    // The injection tests below rely on `RapidMap` and on the HTML-specific
    // `extract_injections`, both of which only exist with the `html-embedded`
    // feature; gate them so the test build works without that feature.

    /// Parses `src` as HTML and returns its embedded-language ranges.
    #[cfg(feature = "html-embedded")]
    fn extract(src: &str) -> RapidMap<String, Vec<TSRange>> {
        let root = Html.ast_grep(src);
        Html.extract_injections(root.root())
    }

    #[cfg(feature = "html-embedded")]
    #[test]
    fn test_html_extraction() {
        let map = extract("<script>a</script><style>.a{}</style>");
        assert!(map.contains_key("css"));
        assert!(map.contains_key("js"));
        assert_eq!(map["css"].len(), 1);
        assert_eq!(map["js"].len(), 1);
    }

    #[cfg(feature = "html-embedded")]
    #[test]
    fn test_explicit_lang() {
        let map = extract(
            "<script lang='ts'>a</script><script lang=ts>.a{}</script><style lang=scss></style><style lang=\"scss\"></style>",
        );
        assert!(map.contains_key("ts"));
        assert_eq!(map["ts"].len(), 2);
        assert_eq!(map["scss"].len(), 2);
    }

    #[cfg(feature = "html-embedded")]
    #[test]
    fn test_javascript_alias() {
        // An explicit `lang="javascript"` is reported under the `js` key.
        let map = extract("<script lang=\"javascript\">a</script>");
        assert!(map.contains_key("js"));
        assert!(!map.contains_key("javascript"));
        assert_eq!(map["js"].len(), 1);
    }
}