Skip to main content

meta_language/
mixed_regions.rs

1use crate::configuration::RegionDetectionPolicy;
2use crate::source::{ByteRange, Point, SourceSpan};
3
4/// Embedded region discovered inside a mixed-language document.
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct EmbeddedRegion {
7    language: String,
8    span: SourceSpan,
9}
10
11impl EmbeddedRegion {
12    pub(crate) const fn new(language: String, span: SourceSpan) -> Self {
13        Self { language, span }
14    }
15
16    /// Language detected for the embedded region.
17    #[must_use]
18    pub fn language(&self) -> &str {
19        &self.language
20    }
21
22    /// Source span covered by the embedded region.
23    #[must_use]
24    pub const fn span(&self) -> SourceSpan {
25        self.span
26    }
27}
28
29pub(crate) fn detect_embedded_regions(
30    text: &str,
31    language: &str,
32    policy: RegionDetectionPolicy,
33) -> Vec<EmbeddedRegion> {
34    let mut regions = Vec::new();
35    match language.to_ascii_lowercase().as_str() {
36        "markdown" => {
37            regions.extend(detect_markdown_fenced_regions(text, policy));
38            regions.extend(detect_markdown_html_regions(text));
39        }
40        "html" => {
41            regions.extend(detect_html_element_regions(text, "script", "JavaScript"));
42            regions.extend(detect_html_element_regions(text, "style", "CSS"));
43            regions.extend(detect_html_style_attributes(text));
44        }
45        _ => {}
46    }
47    regions
48}
49
50fn detect_markdown_fenced_regions(
51    text: &str,
52    policy: RegionDetectionPolicy,
53) -> Vec<EmbeddedRegion> {
54    let mut regions = Vec::new();
55    let mut offset = 0;
56    let mut open_fence: Option<(String, usize)> = None;
57
58    for line in text.split_inclusive('\n') {
59        let trimmed = line.trim_end_matches(['\r', '\n']).trim_start();
60        if let Some((language_tag, content_start)) = open_fence.take() {
61            if trimmed.starts_with("```") {
62                if let Some(language) = region_language_from_tag_or_content(
63                    &language_tag,
64                    &text[content_start..offset],
65                    policy,
66                ) {
67                    regions.push(region_for(text, language, content_start, offset));
68                }
69            } else {
70                open_fence = Some((language_tag, content_start));
71            }
72        } else if let Some(rest) = trimmed.strip_prefix("```") {
73            let language_tag = rest
74                .split_whitespace()
75                .next()
76                .unwrap_or_default()
77                .to_string();
78            open_fence = Some((language_tag, offset + line.len()));
79        }
80        offset += line.len();
81    }
82
83    if let Some((language_tag, content_start)) = open_fence {
84        if let Some(language) =
85            region_language_from_tag_or_content(&language_tag, &text[content_start..], policy)
86        {
87            regions.push(region_for(text, language, content_start, text.len()));
88        }
89    }
90
91    regions
92}
93
94fn region_language_from_tag_or_content(
95    language_tag: &str,
96    content: &str,
97    policy: RegionDetectionPolicy,
98) -> Option<String> {
99    match policy {
100        RegionDetectionPolicy::NameDriven => {
101            (!language_tag.is_empty()).then(|| language_tag.to_string())
102        }
103        RegionDetectionPolicy::ContentDriven => sniff_language(content).map(str::to_string),
104        RegionDetectionPolicy::Both => {
105            if language_tag.is_empty() {
106                sniff_language(content).map(str::to_string)
107            } else {
108                Some(language_tag.to_string())
109            }
110        }
111    }
112}
113
114fn detect_markdown_html_regions(text: &str) -> Vec<EmbeddedRegion> {
115    let mut regions = Vec::new();
116    let mut search_start = 0;
117
118    while let Some(relative_start) = text[search_start..].find('<') {
119        let start = search_start + relative_start;
120        let Some(next) = text[start + 1..].chars().next() else {
121            break;
122        };
123        if !next.is_ascii_alphabetic() {
124            search_start = start + 1;
125            continue;
126        }
127
128        let Some(close) = text[start..].find('>') else {
129            break;
130        };
131        let first_tag_end = start + close + 1;
132        let tag_name = text[start + 1..first_tag_end - 1]
133            .split_whitespace()
134            .next()
135            .unwrap_or_default()
136            .trim_matches('/')
137            .to_ascii_lowercase();
138        if tag_name.is_empty() {
139            search_start = first_tag_end;
140            continue;
141        }
142
143        let closing_tag = format!("</{tag_name}>");
144        let region_end = text[first_tag_end..]
145            .to_ascii_lowercase()
146            .find(&closing_tag)
147            .map_or(first_tag_end, |relative_end| {
148                first_tag_end + relative_end + closing_tag.len()
149            });
150        regions.push(region_for(text, "HTML".to_string(), start, region_end));
151        search_start = region_end;
152    }
153
154    regions
155}
156
157fn detect_html_element_regions(text: &str, element: &str, language: &str) -> Vec<EmbeddedRegion> {
158    let mut regions = Vec::new();
159    let lower = text.to_ascii_lowercase();
160    let open = format!("<{element}");
161    let close = format!("</{element}>");
162    let mut search_start = 0;
163
164    while let Some(relative_start) = lower[search_start..].find(&open) {
165        let start = search_start + relative_start;
166        let Some(open_end_relative) = lower[start..].find('>') else {
167            break;
168        };
169        let content_start = start + open_end_relative + 1;
170        let Some(close_relative) = lower[content_start..].find(&close) else {
171            break;
172        };
173        let content_end = content_start + close_relative;
174        regions.push(region_for(
175            text,
176            language.to_string(),
177            content_start,
178            content_end,
179        ));
180        search_start = content_end + close.len();
181    }
182
183    regions
184}
185
186fn detect_html_style_attributes(text: &str) -> Vec<EmbeddedRegion> {
187    let mut regions = Vec::new();
188    let lower = text.to_ascii_lowercase();
189    let mut search_start = 0;
190
191    while let Some(relative_start) = lower[search_start..].find("style=\"") {
192        let value_start = search_start + relative_start + "style=\"".len();
193        let Some(value_end_relative) = text[value_start..].find('"') else {
194            break;
195        };
196        let value_end = value_start + value_end_relative;
197        regions.push(region_for(text, "CSS".to_string(), value_start, value_end));
198        search_start = value_end + 1;
199    }
200
201    regions
202}
203
/// Guesses the language of `content` from a few textual markers.
///
/// Leading whitespace is ignored. The checks run in this order and the first hit wins:
/// 1. contains `fn main` → `"rust"`
/// 2. starts with `def ` → `"Python"`
/// 3. starts with `<` → `"HTML"`
/// 4. starts with `SELECT ` (ASCII-case-insensitive) → `"SQL"`
/// 5. contains `function `, `const ` or `let ` → `"JavaScript"`
///
/// The anchored SQL check runs before the loose JavaScript substring checks; otherwise a
/// query such as `SELECT * FROM outlet …` would be classified as JavaScript because
/// `outlet ` contains `let `. Returns `None` when nothing matches.
fn sniff_language(content: &str) -> Option<&'static str> {
    const SQL_PREFIX: &[u8] = b"SELECT ";

    let trimmed = content.trim_start();
    // Compare just the prefix bytes instead of uppercasing the whole content.
    let starts_with_select = trimmed
        .as_bytes()
        .get(..SQL_PREFIX.len())
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(SQL_PREFIX));

    if trimmed.contains("fn main") {
        Some("rust")
    } else if trimmed.starts_with("def ") {
        Some("Python")
    } else if trimmed.starts_with('<') {
        Some("HTML")
    } else if starts_with_select {
        Some("SQL")
    } else if trimmed.contains("function ")
        || trimmed.contains("const ")
        || trimmed.contains("let ")
    {
        Some("JavaScript")
    } else {
        None
    }
}
225
226fn region_for(text: &str, language: String, start: usize, end: usize) -> EmbeddedRegion {
227    EmbeddedRegion::new(
228        language,
229        SourceSpan::new(
230            ByteRange::new(start, end),
231            point_at_byte(text, start),
232            point_at_byte(text, end),
233        ),
234    )
235}
236
237fn point_at_byte(text: &str, byte: usize) -> Point {
238    let mut row = 0;
239    let mut column = 0;
240
241    for (index, character) in text.char_indices() {
242        if index >= byte {
243            break;
244        }
245        if character == '\n' {
246            row += 1;
247            column = 0;
248        } else {
249            column += 1;
250        }
251    }
252
253    Point::new(row, column)
254}