use crate::error::SurgeonError;
use std::sync::Arc;
use tree_sitter::{Parser, Point, Range, Tree};
/// Location of one zone's *content* inside a Vue single-file component.
///
/// The range excludes the surrounding `<tag ...>` / `</tag>` markup. Offsets are absolute
/// byte offsets into the whole SFC source, and the points carry the matching row/column
/// (column counted in bytes, as tree-sitter expects).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VueZoneRange {
    /// Byte offset of the first content byte (just after the opening tag's `>`).
    pub start_byte: usize,
    /// Byte offset one past the last content byte (the `<` of the closing tag).
    pub end_byte: usize,
    /// Row/column of `start_byte`.
    pub start_point: Point,
    /// Row/column of `end_byte`.
    pub end_point: Point,
}
/// Content ranges of the three standard Vue SFC blocks, as produced by `scan_vue_zones`.
///
/// Each field is `None` when the corresponding block was not found in the source.
#[derive(Debug, Clone, Default)]
pub struct VueZones {
    /// Content of the first `<script ...>` block.
    pub script: Option<VueZoneRange>,
    /// Content of the first `<template ...>` block.
    pub template: Option<VueZoneRange>,
    /// Content of the first `<style ...>` block.
    pub style: Option<VueZoneRange>,
}
/// Result of parsing a Vue SFC zone by zone with separate tree-sitter grammars.
///
/// Each tree is parsed over the full `source` buffer with the parser restricted to that
/// zone's byte range, so node offsets refer to positions in `source`.
#[derive(Debug, Clone)]
pub struct MultiZoneTree {
    /// TypeScript tree for the `<script>` content; `None` only when there is no script block
    /// (a script parse failure is reported as an error instead).
    pub script_tree: Option<Tree>,
    /// HTML tree for the `<template>` content; `None` when the block is missing or its parse
    /// failed (in which case `degraded` is set).
    pub template_tree: Option<Tree>,
    /// CSS tree for the `<style>` content; `None` when the block is missing or its parse
    /// failed (in which case `degraded` is set).
    pub style_tree: Option<Tree>,
    /// Zone ranges located by the scanner.
    pub zones: VueZones,
    /// Copy of the complete SFC bytes the trees were parsed from.
    pub source: Arc<[u8]>,
    /// `true` when a template or style zone was present but could not be parsed.
    pub degraded: bool,
}
/// Converts a byte offset into a tree-sitter [`Point`] (zero-based row, byte column).
///
/// Offsets past the end of `source` are clamped to `source.len()`. Rows are counted by `\n`
/// bytes only; the column is the number of bytes since the last `\n` (or since the start of
/// `source` when the offset is on the first line).
fn byte_to_point(source: &[u8], byte: usize) -> Point {
    let end = byte.min(source.len());
    let mut row = 0;
    let mut line_start = 0;
    // One pass: every newline bumps the row and moves the start of the current line.
    for (idx, &b) in source[..end].iter().enumerate() {
        if b == b'\n' {
            row += 1;
            line_start = idx + 1;
        }
    }
    Point {
        row,
        column: end - line_start,
    }
}
/// Locates the content of the first `<tag ...>...</tag>` block in a Vue SFC.
///
/// The returned range covers only the bytes *between* the opening tag's `>` and the closing
/// `</tag>`, together with the tree-sitter points of those two offsets.
///
/// Matching rules:
/// - The tag name must be followed by `>`, whitespace, or end of input, so `<script-runner>`
///   is not mistaken for `<script>`.
/// - The opening tag ends at the first `>` that is not inside a quoted attribute value, so
///   attributes such as `generic="T extends Record<string, any>"` do not cut the tag short.
/// - `script` and `style` are HTML raw-text elements: their content ends at the first closing
///   tag. Any other tag (notably `template`) is balanced against nested same-named elements,
///   so an inner `<template v-if="...">...</template>` does not end the root template zone.
/// - The scan works on raw bytes, so a source that is not valid UTF-8 is still zoned. Zone
///   boundaries always sit next to ASCII `>` / `<` bytes, so they are char boundaries in
///   valid UTF-8 text.
///
/// Returns `None` when no opening tag matches or the block is not properly closed.
fn find_zone(source: &[u8], tag: &str) -> Option<VueZoneRange> {
    let open_prefix = format!("<{tag}").into_bytes();
    let close_tag = format!("</{tag}>").into_bytes();

    let mut search_from = 0;
    let open_pos = loop {
        let candidate = find_subslice(source, &open_prefix, search_from)?;
        if is_open_tag_at(source, candidate, &open_prefix) {
            break candidate;
        }
        // A longer name such as `<script-runner`; keep looking past it.
        search_from = candidate + 1;
    };

    let content_start = start_tag_end(source, open_pos + open_prefix.len())?;
    // Raw-text elements end at the first closing tag regardless of what precedes it.
    let content_end = if matches!(tag, "script" | "style") {
        find_subslice(source, &close_tag, content_start)?
    } else {
        find_balanced_close(source, content_start, &open_prefix, &close_tag)?
    };

    Some(VueZoneRange {
        start_byte: content_start,
        end_byte: content_end,
        start_point: byte_to_point(source, content_start),
        end_point: byte_to_point(source, content_end),
    })
}

/// Returns the offset of the first occurrence of `needle` in `haystack` at or after `from`.
fn find_subslice(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
    haystack
        .get(from..)?
        .windows(needle.len())
        .position(|window| window == needle)
        .map(|offset| from + offset)
}

/// Returns true if `open_prefix` (`<name`) starts at `pos` and the name is followed by `>`,
/// ASCII whitespace, or the end of input.
fn is_open_tag_at(bytes: &[u8], pos: usize, open_prefix: &[u8]) -> bool {
    let name_end = pos + open_prefix.len();
    bytes.get(pos..name_end) == Some(open_prefix)
        && matches!(
            bytes.get(name_end).copied().unwrap_or(b'>'),
            b'>' | b' ' | b'\t' | b'\n' | b'\r'
        )
}

/// Returns the offset just past the `>` that closes a start tag whose attributes begin at
/// `from`, ignoring `>` bytes inside quoted attribute values (a quote only opens a value
/// when it follows `=`, optionally separated by whitespace).
fn start_tag_end(bytes: &[u8], from: usize) -> Option<usize> {
    let mut quote: Option<u8> = None;
    let mut after_equals = false;
    for (offset, &b) in bytes.get(from..)?.iter().enumerate() {
        if let Some(q) = quote {
            if b == q {
                quote = None;
            }
            continue;
        }
        match b {
            b'>' => return Some(from + offset + 1),
            b'"' | b'\'' if after_equals => {
                quote = Some(b);
                after_equals = false;
            }
            b'=' => after_equals = true,
            b' ' | b'\t' | b'\n' | b'\r' => {}
            _ => after_equals = false,
        }
    }
    None
}

/// Returns the offset of the `close_tag` that balances an element whose content starts at
/// `from`, counting nested same-named start tags (self-closing `.../>` ones excepted) and
/// skipping `<!-- ... -->` comments.
fn find_balanced_close(
    bytes: &[u8],
    from: usize,
    open_prefix: &[u8],
    close_tag: &[u8],
) -> Option<usize> {
    let mut depth = 0usize;
    let mut pos = from;
    loop {
        let lt = pos + bytes.get(pos..)?.iter().position(|&b| b == b'<')?;
        let rest = &bytes[lt..];
        if rest.starts_with(close_tag) {
            if depth == 0 {
                return Some(lt);
            }
            depth -= 1;
            pos = lt + close_tag.len();
        } else if rest.starts_with(b"<!--") {
            pos = find_subslice(bytes, b"-->", lt + 4)? + 3;
        } else if is_open_tag_at(bytes, lt, open_prefix) {
            let tag_end = start_tag_end(bytes, lt + open_prefix.len())?;
            // `tag_end - 2` is the byte before `>`; `<template v-if="x" />` opens no level.
            if bytes[tag_end - 2] != b'/' {
                depth += 1;
            }
            pos = tag_end;
        } else {
            pos = lt + 1;
        }
    }
}
/// Scans a Vue SFC for its `<script>`, `<template>`, and `<style>` blocks.
///
/// Each field of the result holds the content range of the first matching block, or `None`
/// when that block is absent. No parsing happens here; this is a textual scan only.
#[must_use]
pub fn scan_vue_zones(source: &[u8]) -> VueZones {
    let script = find_zone(source, "script");
    let template = find_zone(source, "template");
    let style = find_zone(source, "style");
    VueZones {
        script,
        template,
        style,
    }
}
pub fn parse_vue_multizone(source: &[u8]) -> Result<MultiZoneTree, SurgeonError> {
let zones = scan_vue_zones(source);
let mut degraded = false;
let script_tree = parse_zone_with_grammar(
source,
zones.script.as_ref(),
&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
"<vue-script>",
true, &mut degraded,
)?;
let template_tree = parse_zone_with_grammar(
source,
zones.template.as_ref(),
&tree_sitter_html::LANGUAGE.into(),
"<vue-template>",
false, &mut degraded,
)?;
let style_tree = parse_zone_with_grammar(
source,
zones.style.as_ref(),
&tree_sitter_css::LANGUAGE.into(),
"<vue-style>",
false, &mut degraded,
)?;
Ok(MultiZoneTree {
script_tree,
template_tree,
style_tree,
zones,
source: Arc::from(source),
degraded,
})
}
/// Parses one zone of `source` with `grammar`, restricting tree-sitter to the zone's bytes.
///
/// The whole `source` buffer is handed to the parser together with a single included range,
/// so the resulting tree's offsets refer to positions in the full buffer.
///
/// Returns `Ok(None)` when `zone` is `None`. When configuring the parser (grammar or ranges)
/// or parsing fails:
/// - if `fatal` is true, returns [`SurgeonError::ParseError`] with `zone_label` as the path;
/// - otherwise sets `*degraded = true` and returns `Ok(None)`.
///
/// `degraded` is only ever set, never cleared, so one flag can accumulate across calls.
///
/// # Errors
/// Returns [`SurgeonError::ParseError`] only for `fatal` zones, as described above.
fn parse_zone_with_grammar(
    source: &[u8],
    zone: Option<&VueZoneRange>,
    grammar: &tree_sitter::Language,
    zone_label: &str,
    fatal: bool,
    degraded: &mut bool,
) -> Result<Option<Tree>, SurgeonError> {
    let Some(zone) = zone else {
        return Ok(None);
    };

    let mut parser = Parser::new();
    // Ranges are only installed once the grammar has been accepted (`&&` short-circuits).
    let configured = parser.set_language(grammar).is_ok()
        && parser
            .set_included_ranges(&[Range {
                start_byte: zone.start_byte,
                end_byte: zone.end_byte,
                start_point: zone.start_point,
                end_point: zone.end_point,
            }])
            .is_ok();

    if !configured {
        return if fatal {
            Err(SurgeonError::ParseError {
                path: std::path::PathBuf::from(zone_label),
                reason: "failed to configure tree-sitter grammar or ranges".into(),
            })
        } else {
            *degraded = true;
            Ok(None)
        };
    }

    match parser.parse(source, None) {
        Some(tree) => Ok(Some(tree)),
        None if fatal => Err(SurgeonError::ParseError {
            path: std::path::PathBuf::from(zone_label),
            reason: "tree-sitter parse returned None".into(),
        }),
        None => {
            *degraded = true;
            Ok(None)
        }
    }
}
// Unit tests for zone scanning, byte-offset-to-point conversion, and multi-zone parsing.
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    /// A representative SFC containing all three blocks: a template with a component and a
    /// self-closing element, a `<script setup lang="ts">` block, and a scoped style block.
    const BASIC_SFC: &[u8] = br#"<template>
<div class="app">
<MyButton @click="doThing">Click me</MyButton>
<router-view />
</div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
const count = ref(0)
function doThing() { count.value++ }
</script>
<style scoped>
.app { color: red; }
#main { font-size: 16px; }
@media (max-width: 768px) { .app { display: none; } }
</style>"#;
    // Every block in the fixture is detected.
    #[test]
    fn test_scan_vue_zones_all_three_present() {
        let zones = scan_vue_zones(BASIC_SFC);
        assert!(zones.template.is_some(), "template zone should be found");
        assert!(zones.script.is_some(), "script zone should be found");
        assert!(zones.style.is_some(), "style zone should be found");
    }
    // Slicing the source with each zone's byte range yields that block's own content.
    #[test]
    fn test_scan_vue_zones_content_bytes_correct() {
        let zones = scan_vue_zones(BASIC_SFC);
        let sfc_str = std::str::from_utf8(BASIC_SFC).unwrap();
        let script = zones.script.unwrap();
        let script_content = &sfc_str[script.start_byte..script.end_byte];
        assert!(
            script_content.contains("const count = ref(0)"),
            "script content must include TS code"
        );
        let template = zones.template.unwrap();
        let tmpl_content = &sfc_str[template.start_byte..template.end_byte];
        assert!(
            tmpl_content.contains("MyButton"),
            "template content must include component tag"
        );
        let style = zones.style.unwrap();
        let style_content = &sfc_str[style.start_byte..style.end_byte];
        assert!(
            style_content.contains(".app"),
            "style content must include CSS class"
        );
    }
    // Missing blocks are reported as `None`.
    #[test]
    fn test_scan_vue_zones_template_only() {
        let sfc = b"<template><div>Hello</div></template>\n";
        let zones = scan_vue_zones(sfc);
        assert!(zones.template.is_some());
        assert!(zones.script.is_none());
        assert!(zones.style.is_none());
    }
    // A tag whose name merely starts with `script` must not be treated as `<script>`.
    #[test]
    fn test_scan_vue_zones_does_not_match_partial_tag() {
        let sfc = b"<template><script-runner /></template>\n";
        let zones = scan_vue_zones(sfc);
        assert!(
            zones.script.is_none(),
            "script-runner must not match <script>"
        );
        assert!(zones.template.is_some());
    }
    // Pins the exact start offset/point: content begins right after `<template>` (10 bytes).
    #[test]
    fn test_scan_vue_zones_byte_to_point_newline_accuracy() {
        let sfc = b"<template>\n<div/>\n</template>\n";
        let zones = scan_vue_zones(sfc);
        let tmpl = zones.template.unwrap();
        assert_eq!(tmpl.start_byte, 10, "content starts after '<template>'");
        assert_eq!(tmpl.start_point.row, 0, "should be on row 0");
        assert_eq!(tmpl.start_point.column, 10, "should be at column 10");
    }
    // An empty source yields no zones.
    #[test]
    fn test_scan_vue_zones_empty_source() {
        let zones = scan_vue_zones(b"");
        assert!(zones.script.is_none());
        assert!(zones.template.is_none());
        assert!(zones.style.is_none());
    }
    // All three grammars parse the fixture without degradation.
    #[test]
    fn test_parse_vue_multizone_produces_all_trees() {
        let result = parse_vue_multizone(BASIC_SFC).unwrap();
        assert!(result.script_tree.is_some(), "script tree should parse");
        assert!(result.template_tree.is_some(), "template tree should parse");
        assert!(result.style_tree.is_some(), "style tree should parse");
        assert!(!result.degraded, "should not be degraded");
    }
    // The script zone is parsed with the TypeScript grammar.
    #[test]
    fn test_parse_vue_multizone_script_root_is_program() {
        let result = parse_vue_multizone(BASIC_SFC).unwrap();
        let tree = result.script_tree.unwrap();
        assert_eq!(
            tree.root_node().kind(),
            "program",
            "TypeScript root node should be 'program'"
        );
    }
    // Absent blocks produce no tree, and that absence is not an error.
    #[test]
    fn test_parse_vue_multizone_no_script_block() {
        let sfc = b"<template><div>Hello</div></template>\n";
        let result = parse_vue_multizone(sfc).unwrap();
        assert!(result.script_tree.is_none());
        assert!(result.template_tree.is_some());
        assert!(result.style_tree.is_none());
    }
    // The result carries an unchanged copy of the input bytes.
    #[test]
    fn test_parse_vue_multizone_source_preserved() {
        let result = parse_vue_multizone(BASIC_SFC).unwrap();
        assert_eq!(
            result.source,
            Arc::from(BASIC_SFC),
            "source bytes should be preserved unchanged"
        );
    }
    // Script tree offsets refer to the full SFC, so the root cannot start before the zone.
    #[test]
    fn test_parse_vue_multizone_script_node_has_correct_global_offset() {
        let result = parse_vue_multizone(BASIC_SFC).unwrap();
        let tree = result.script_tree.unwrap();
        let zones = result.zones;
        let script_start = zones.script.unwrap().start_byte;
        let root_start = tree.root_node().start_byte();
        assert!(
            root_start >= script_start,
            "script root start_byte ({root_start}) should be >= zone start ({script_start})"
        );
    }
    // Offset 0 is row 0, column 0.
    #[test]
    fn test_byte_to_point_start_of_file() {
        let p = byte_to_point(b"hello", 0);
        assert_eq!(p.row, 0);
        assert_eq!(p.column, 0);
    }
    // The byte right after a newline starts the next row at column 0.
    #[test]
    fn test_byte_to_point_second_line() {
        let p = byte_to_point(b"hello\nworld", 6);
        assert_eq!(p.row, 1);
        assert_eq!(p.column, 0);
    }
    // Columns are counted from the start of the current line.
    #[test]
    fn test_byte_to_point_mid_line() {
        let p = byte_to_point(b"abc\nde", 5);
        assert_eq!(p.row, 1);
        assert_eq!(p.column, 1);
    }
}