mod scan_state;
use scan_state::{BraceAction, ScanState};
use crate::EjectError;
/// Byte offsets into the scanned source describing an inline test module
/// (`#[cfg(test)]`, optional extra attributes, then `mod tests { ... }`), as
/// produced by `find_test_module_region`.
#[derive(Debug)]
pub(crate) struct TestModuleRegion {
/// Offset of the `#[cfg(test)]` attribute that begins the item.
pub(crate) outer_start: usize,
/// Exclusive end of the whole item: just past the module's closing `}`; it may
/// also include the line break immediately following that brace.
pub(crate) outer_end: usize,
/// Offset just past the module's opening `{` (start of the body).
pub(crate) inner_start: usize,
/// Offset of the module's closing `}` (exclusive end of the body).
pub(crate) inner_end: usize,
/// Offset just past `#[cfg(test)]`.
pub(crate) attrs_start: usize,
/// Offset of the `mod tests` keyword. The span `attrs_start..attrs_end` holds
/// whatever whitespace and additional attributes sit between `#[cfg(test)]`
/// and the module declaration.
pub(crate) attrs_end: usize,
}
/// Locates the first inline test module in `source`: an `#[cfg(test)]`
/// attribute in real code (not in a comment or literal), optionally followed by
/// more outer attributes, then `mod tests { ... }`.
///
/// Candidates whose attribute is not followed by a `mod tests` item (e.g. an
/// `#[cfg(test)]` on a helper function or on `mod tests_util`) are skipped.
///
/// # Errors
///
/// * [`EjectError::AlreadyExternal`] if the first matching module is declared as
///   `mod tests;`, i.e. its body does not live inline.
/// * [`EjectError::RegionNotFound`] if the module's opening `{` is never
///   balanced, or if a computed offset is not a valid position in `source`.
/// * [`EjectError::NoTestModule`] if no candidate leads to a `mod tests` item.
pub(crate) fn find_test_module_region(source: &str) -> Result<TestModuleRegion, EjectError> {
    let cfg_test = "#[cfg(test)]";
    let code_positions = find_cfg_test_in_code(source, cfg_test);
    log::debug!(
        "found {} #[cfg(test)] candidate(s) in code",
        code_positions.len()
    );
    for cfg_pos in code_positions {
        let after_cfg = cfg_pos + cfg_test.len();
        let rest = source.get(after_cfg..).ok_or(EjectError::RegionNotFound)?;
        let Some(mod_offset) = find_mod_tests_after_attrs(rest) else {
            // This attribute decorates something other than `mod tests`.
            continue;
        };
        let mod_pos = after_cfg + mod_offset;
        let after_kw = mod_pos + "mod tests".len();
        let after_mod = source.get(after_kw..).ok_or(EjectError::RegionNotFound)?;
        let trimmed = after_mod.trim_start();
        if trimmed.starts_with(';') {
            // `mod tests;` has no inline body to extract.
            return Err(EjectError::AlreadyExternal);
        }
        if !trimmed.starts_with('{') {
            continue;
        }
        let open_brace = after_kw + (after_mod.len() - trimmed.len());
        let close_brace = find_matching_close_brace(source, open_brace)?;
        let after_close = close_brace + 1;
        // Swallow the line terminator right after `}` so that cutting
        // `outer_start..outer_end` leaves no empty line behind. CRLF must be
        // handled as a unit; otherwise a lone `\r\n` remains in CRLF files.
        let outer_end = match source.get(after_close..) {
            Some(tail) if tail.starts_with("\r\n") => after_close + 2,
            Some(tail) if tail.starts_with('\n') => after_close + 1,
            _ => after_close,
        };
        return Ok(TestModuleRegion {
            outer_start: cfg_pos,
            outer_end,
            inner_start: open_brace + 1,
            inner_end: close_brace,
            attrs_start: after_cfg,
            attrs_end: mod_pos,
        });
    }
    Err(EjectError::NoTestModule)
}
/// Returns, in ascending order, the byte offset of every occurrence of `needle`
/// that begins while the scanner is in ordinary code, i.e. outside the comments
/// and literals that `ScanState` tracks.
fn find_cfg_test_in_code(source: &str, needle: &str) -> Vec<usize> {
    let pattern = needle.as_bytes();
    let raw = source.as_bytes();
    let mut scanner = ScanState::Normal;
    let mut hits = Vec::new();
    for (offset, ch) in source.char_indices() {
        // Test the state *before* consuming `ch`: the needle would start at
        // `ch`, so what matters is whether `ch` itself begins in code.
        if matches!(scanner, ScanState::Normal) && starts_with_at(raw, pattern, offset) {
            hits.push(offset);
        }
        scanner = scanner.advance(ch).next;
    }
    hits
}
/// Reports whether `needle` appears in `haystack` starting exactly at byte
/// `offset`. An `offset` beyond the end of `haystack` yields `false`.
fn starts_with_at(haystack: &[u8], needle: &[u8], offset: usize) -> bool {
    match haystack.get(offset..) {
        Some(tail) => tail.starts_with(needle),
        None => false,
    }
}
/// Skips leading whitespace and any number of outer attributes (`#[...]`) at the
/// start of `source`, then checks whether what follows is the `mod tests`
/// keyword pair.
///
/// Returns the byte offset of `mod tests` within `source` when the keyword is
/// followed by `{`, `;`, a space, tab, `\r`, `\n`, or end of input. Returns
/// `None` if the identifier continues (e.g. `mod tests_util`), an attribute is
/// unterminated, or any other token (including a comment) is encountered.
fn find_mod_tests_after_attrs(source: &str) -> Option<usize> {
    const MOD_KW: &str = "mod tests";
    let mut cursor: usize = 0;
    loop {
        let remaining = source.get(cursor..)?;
        let content = remaining.trim_start();
        cursor += remaining.len() - content.len();
        if let Some(tail) = content.strip_prefix(MOD_KW) {
            let at_boundary = matches!(
                tail.chars().next(),
                None | Some('{' | ';' | ' ' | '\t' | '\n' | '\r')
            );
            return at_boundary.then_some(cursor);
        }
        // Anything that is neither `mod tests` nor an attribute ends the search.
        let attr_body = content.strip_prefix("#[")?;
        let close_idx = find_attr_close(attr_body)?;
        // Step over `#[`, the attribute body, and the closing `]`.
        cursor += "#[".len() + close_idx + ']'.len_utf8();
    }
}
/// Finds the `]` that closes an attribute whose body starts at the beginning of
/// `source`, i.e. `source` is the text immediately after `#[`.
///
/// Square brackets nested inside the attribute (for example
/// `#[cfg_attr(x, foo([1, 2]))]`) are balanced, because attribute input is a
/// delimited token tree. Brackets inside double-quoted string literals are
/// ignored, and backslash escapes inside those strings are honoured.
///
/// Returns the byte offset of the closing `]` within `source`, or `None` if the
/// attribute is never closed.
pub(crate) fn find_attr_close(source: &str) -> Option<usize> {
    // Number of `[` seen (outside strings) that are still awaiting their `]`.
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (idx, ch) in source.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        if in_string {
            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '[' => depth += 1,
            ']' if depth == 0 => return Some(idx),
            ']' => depth -= 1,
            _ => {}
        }
    }
    None
}
/// Given the byte offset `open_pos` of an opening `{` in `source`, returns the
/// byte offset of the `}` that balances it.
///
/// Scanning starts in `ScanState::Normal` just after the opening brace. A brace
/// is counted only when the scanner reports a `BraceAction` for it, which lets
/// the scanner exclude braces it classifies as inside comments or literals.
///
/// # Errors
///
/// [`EjectError::RegionNotFound`] if `open_pos + 1` is not a valid slice
/// position in `source`, or if the input ends before the brace is balanced.
fn find_matching_close_brace(source: &str, open_pos: usize) -> Result<usize, EjectError> {
    let body_start = open_pos + 1;
    let Some(body) = source.get(body_start..) else {
        return Err(EjectError::RegionNotFound);
    };
    // The brace at `open_pos` is already open, so start one level deep.
    let mut nesting: u32 = 1;
    let mut scanner = ScanState::Normal;
    for (rel, ch) in body.char_indices() {
        let step = scanner.advance(ch);
        scanner = step.next;
        match step.brace {
            BraceAction::None => continue,
            BraceAction::Open => nesting += 1,
            BraceAction::Close => {
                nesting -= 1;
                if nesting == 0 {
                    return Ok(body_start + rel);
                }
            }
        }
    }
    Err(EjectError::RegionNotFound)
}
// Unit tests for `find_test_module_region`. Each test builds a small source
// string and checks either the returned offsets or the returned error variant.
#[cfg(test)]
mod tests {
use super::*;
// Happy path: a multi-line test module after ordinary code. The outer span must
// begin at the cfg attribute and contain the `mod tests` declaration.
#[test]
fn basic_region() {
let src = "fn main() {}\n\n#[cfg(test)]\nmod tests {\n use super::*;\n}\n";
let rg = find_test_module_region(src).expect("should find region");
let outer = src.get(rg.outer_start..rg.outer_end).expect("valid range");
assert!(outer.starts_with("#[cfg(test)]"));
assert!(outer.contains("mod tests"));
}
// `mod tests;` (here with an extra `#[path]` attribute in between) has no inline
// body and must be reported as AlreadyExternal.
#[test]
fn already_external() {
let src = "#[cfg(test)]\n#[path = \"t.rs\"]\nmod tests;\n";
let err = find_test_module_region(src).expect_err("should fail");
assert!(matches!(err, EjectError::AlreadyExternal));
}
// Source without any cfg-test attribute yields NoTestModule.
#[test]
fn no_test_module() {
let src = "fn main() {}\n";
let err = find_test_module_region(src).expect_err("should fail");
assert!(matches!(err, EjectError::NoTestModule));
}
// A `}` inside a string literal must not close the module early.
#[test]
fn braces_in_string() {
let src = concat!(
"#[cfg(test)]\nmod tests {\n",
" fn t() { let ss = \"}\"; }\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle string braces");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("let ss"));
}
// Braces inside line and block comments must be ignored by the brace matcher.
#[test]
fn braces_in_comments() {
let src = concat!(
"#[cfg(test)]\nmod tests {\n",
" // }\n",
" /* } */\n",
" fn t() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle comment braces");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("fn t()"));
}
// The attribute and the `mod tests` declaration may share a line.
#[test]
fn same_line_cfg() {
let src = "fn main() {}\n#[cfg(test)] mod tests {\n fn t() {}\n}\n";
let rg = find_test_module_region(src).expect("should find same-line cfg");
assert!(
src.get(rg.outer_start..rg.outer_end)
.expect("valid")
.starts_with("#[cfg(test)]")
);
}
// Attribute text quoted inside a doc comment is not a real candidate; the
// genuine module further down must be found instead.
#[test]
fn cfg_test_in_doc_comment_skipped() {
let src = concat!(
"/// No inline `#[cfg(test)] mod tests` here.\n",
"pub fn foo() {}\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" fn real_test() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should skip doc comment");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("real_test"));
}
// A char literal containing a double quote must not put the scanner into
// string mode, which would hide the real module that follows.
#[test]
fn char_literal_with_quote() {
let src = concat!(
"fn foo() { let _c = '\"'; }\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" fn real_test() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle char literal with quote");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("real_test"));
}
// Attribute text inside a string literal is not a real candidate.
#[test]
fn cfg_test_in_string_literal_skipped() {
let src = concat!(
"fn foo() { let _s = \"#[cfg(test)] mod tests { }\"; }\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" fn real_test() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should skip string literal");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("real_test"));
}
// A `}` inside a raw string (`r#"..."#`) must not close the module early.
#[test]
fn braces_in_raw_string() {
let src = concat!(
"#[cfg(test)]\nmod tests {\n",
" fn t() { let ss = r#\"}\"#; }\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle raw string braces");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("let ss"));
}
// Rust block comments nest; braces anywhere inside the nested comment are ignored.
#[test]
fn nested_block_comment_braces() {
let src = concat!(
"#[cfg(test)]\nmod tests {\n",
" /* /* } */ } */\n",
" fn t() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle nested block comment");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("fn t()"));
}
// A `]` inside a string in an intermediate attribute must not end that
// attribute early.
#[test]
fn attr_with_bracket_in_string() {
let src = concat!(
"#[cfg(test)]\n",
"#[doc = \"contains ] bracket\"]\n",
"mod tests {\n",
" fn t() {}\n",
"}\n"
);
let rg = find_test_module_region(src).expect("should handle ] in attr string");
let inner = src.get(rg.inner_start..rg.inner_end).expect("valid range");
assert!(inner.contains("fn t()"));
}
}
}