mod classify;
mod scanner;
#[cfg(feature = "cli")]
mod cli;
pub use classify::{Classification, classify_source};
#[cfg(feature = "cli")]
pub use cli::{
FileResult, OutputFormat, Report, apply_path, check_path, render_apply, render_check,
};
use thiserror::Error;
/// Errors returned by [`eject_tests`].
#[derive(Debug, Error)]
pub enum EjectError {
/// The source contains no inline `#[cfg(test)] mod tests { ... }` block.
#[error("no inline #[cfg(test)] mod tests block found")]
NoTestModule,
/// The test module is already declared as `mod tests;` (body lives in another file).
#[error("tests already extracted to external file")]
AlreadyExternal,
/// A byte range reported for the test module does not slice the source
/// (out of bounds or not on a `char` boundary).
#[error("could not locate test module boundaries in source")]
RegionNotFound,
/// With the `validate` feature enabled, the rewritten source failed to parse;
/// `reason` carries the parser's error message.
#[error("generated output failed to parse: {reason}")]
ValidationFailed {
reason: String,
},
}
/// Output of [`eject_tests`]: the rewritten source plus the extracted test file.
///
/// Derives the common traits so callers can log (`{:?}`), clone and compare results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EjectResult {
    /// The original source with the inline test module replaced by a
    /// `#[cfg(test)] #[path = "..."] mod tests;` declaration.
    pub modified_source: String,
    /// Contents to write into the external test file.
    pub test_content: String,
    /// File name of the external test file, `{file_stem}_tests.rs`.
    pub test_file_name: String,
}
/// Moves the inline `#[cfg(test)] mod tests { ... }` block of `source` into a
/// separate test file.
///
/// `file_stem` names the generated file (`{file_stem}_tests.rs`). In the
/// returned source, the inline module is replaced by:
///
/// ```text
/// #[cfg(test)]
/// #[path = "{file_stem}_tests.rs"]
/// mod tests;
/// ```
///
/// The module body is dedented, stripped of leading blank lines and returned
/// as `test_content`. Any outer attributes found in the scanner's attribute
/// region (e.g. `#[allow(...)]`) are re-emitted as inner `#![...]` attributes
/// at the top of `test_content`. The rewritten source ends with exactly one
/// trailing newline.
///
/// # Errors
///
/// - Errors from `scanner::find_test_module_region` are propagated as-is. The
///   unit tests expect [`EjectError::NoTestModule`] when there is no inline
///   test module, and [`EjectError::AlreadyExternal`] when it is already a
///   `mod tests;` declaration.
/// - [`EjectError::RegionNotFound`] if a reported byte range does not slice
///   `source` (out of bounds or not on a `char` boundary).
/// - [`EjectError::ValidationFailed`] (only with the `validate` feature) if
///   the rewritten source fails to parse with `syn`.
pub fn eject_tests(source: &str, file_stem: &str) -> Result<EjectResult, EjectError> {
    log::debug!("scanning {file_stem} ({} bytes)", source.len());
    let region = scanner::find_test_module_region(source)?;
    log::debug!(
        "found test module at bytes {}..{} (inner {}..{})",
        region.outer_start,
        region.outer_end,
        region.inner_start,
        region.inner_end,
    );
    // `get` instead of indexing: a bad offset from the scanner becomes an
    // error rather than a panic.
    let inner = source
        .get(region.inner_start..region.inner_end)
        .ok_or(EjectError::RegionNotFound)?;
    let attrs_region = source
        .get(region.attrs_start..region.attrs_end)
        .ok_or(EjectError::RegionNotFound)?;
    let inner_attrs = collect_inner_attrs(attrs_region);
    // The body may begin with the newline that followed `{`. Strip leading
    // blank lines in every case (not only when attributes are prepended), so
    // the test file never opens with an empty line.
    let dedented = dedent(inner);
    let body = dedented.trim_start_matches('\n');
    let test_content = format!("{inner_attrs}{body}");
    let test_file_name = format!("{file_stem}_tests.rs");
    let replacement = format!("#[cfg(test)]\n#[path = \"{test_file_name}\"]\nmod tests;\n");
    let prefix = source
        .get(..region.outer_start)
        .ok_or(EjectError::RegionNotFound)?;
    let suffix = source
        .get(region.outer_end..)
        .ok_or(EjectError::RegionNotFound)?;
    let modified_source = normalize_trailing_newlines(&format!("{prefix}{replacement}{suffix}"));
    #[cfg(feature = "validate")]
    syn::parse_file(&modified_source).map_err(|err| EjectError::ValidationFailed {
        reason: err.to_string(),
    })?;
    Ok(EjectResult {
        modified_source,
        test_content,
        test_file_name,
    })
}
/// Rewrites a run of outer attributes (`#[...]`) at the start of `text` as
/// inner attributes (`#![...]`), one per line, preserving their order.
///
/// Leading whitespace before and between attributes is skipped. Collection
/// stops at the first token that is not `#[`, or when
/// `scanner::find_attr_close` reports no closing bracket. Anything after that
/// point is ignored.
fn collect_inner_attrs(text: &str) -> String {
    let mut inner_attrs = String::new();
    let mut remaining = text.trim_start();
    loop {
        let Some(attr_and_tail) = remaining.strip_prefix("#[") else {
            break;
        };
        let Some(close_idx) = scanner::find_attr_close(attr_and_tail) else {
            break;
        };
        let attr_body = attr_and_tail.get(..close_idx).unwrap_or_default();
        inner_attrs += "#![";
        inner_attrs += attr_body;
        inner_attrs += "]\n";
        remaining = attr_and_tail
            .get(close_idx + 1..)
            .map_or("", str::trim_start);
    }
    inner_attrs
}
/// Returns `source` with all trailing whitespace (spaces, tabs, `\r`, `\n`, …)
/// removed and exactly one `\n` appended. Whitespace-only input becomes `"\n"`.
fn normalize_trailing_newlines(source: &str) -> String {
    [source.trim_end(), "\n"].concat()
}
/// Removes the common leading whitespace from every non-blank line of `text`.
///
/// The common indent is measured in bytes of leading whitespace, as stripped
/// by `str::trim_start`. Whitespace-only lines become empty and do not count
/// toward the common indent. If no line is indented, `text` is returned
/// unchanged. Otherwise lines are re-joined with `\n` (so `\r\n` endings
/// become `\n`), and a trailing newline is kept if `text` had one.
///
/// Only leading whitespace is ever removed. When the common byte count would
/// split a multi-byte whitespace character (e.g. U+00A0) on some line, that
/// line is cut at the preceding char boundary instead, so its content is kept.
///
/// NOTE(review): indentation inside multi-line string literals is stripped too.
fn dedent(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();
    let min_indent = lines
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);
    if min_indent == 0 {
        return text.to_owned();
    }
    let mut result = lines
        .iter()
        .map(|line| {
            if line.trim().is_empty() {
                return "";
            }
            // Every non-blank line has at least `min_indent` bytes of leading
            // whitespace, so any cut <= min_indent removes whitespace only.
            // Back off to a char boundary rather than dropping the line (the
            // previous `unwrap_or("")` silently deleted its content).
            let mut cut = min_indent.min(line.len());
            while !line.is_char_boundary(cut) {
                cut -= 1;
            }
            line.get(cut..).unwrap_or(line)
        })
        .collect::<Vec<_>>()
        .join("\n");
    if text.ends_with('\n') {
        result.push('\n');
    }
    result
}
#[cfg(test)]
mod tests {
// Unit tests for test-module extraction, outer->inner attribute translation,
// and the dedent / trailing-newline helpers.
use super::*;
// End-to-end: the inline module is replaced by a `#[path]` declaration and its
// body (including `use super::*;`) moves into `test_content`.
#[test]
fn basic_extraction() {
let source = concat!(
"use std::collections::HashMap;\n",
"\n",
"pub fn add(aa: i32, bb: i32) -> i32 {\n",
" aa + bb\n",
"}\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" use super::*;\n",
"\n",
" #[test]\n",
" fn test_add() {\n",
" assert_eq!(add(1, 2), 3);\n",
" }\n",
"}\n",
);
let result = eject_tests(source, "math").expect("should succeed");
assert!(
result
.modified_source
.contains("#[path = \"math_tests.rs\"]")
);
assert!(result.modified_source.contains("mod tests;"));
assert!(!result.modified_source.contains("fn test_add"));
assert!(result.test_content.contains("fn test_add"));
assert!(result.test_content.contains("use super::*;"));
assert_eq!(result.test_file_name, "math_tests.rs");
}
// A source without any inline test module is reported as `NoTestModule`.
#[test]
fn no_test_module() {
let source = "pub fn add(aa: i32, bb: i32) -> i32 { aa + bb }\n";
let result = eject_tests(source, "math");
assert!(matches!(result, Err(EjectError::NoTestModule)));
}
// An existing `#[path] mod tests;` declaration is reported as `AlreadyExternal`.
#[test]
fn already_external() {
let source = "#[cfg(test)]\n#[path = \"math_tests.rs\"]\nmod tests;\n";
let result = eject_tests(source, "math");
assert!(matches!(result, Err(EjectError::AlreadyExternal)));
}
// The common leading whitespace is removed from every non-blank line.
#[test]
fn dedent_basic() {
let input = " use super::*;\n\n #[test]\n fn test_foo() {}\n";
let result = dedent(input);
assert!(result.starts_with("use super::*;"));
assert!(result.contains("#[test]\nfn test_foo()"));
}
// Input with no indentation is returned unchanged.
#[test]
fn dedent_no_indent() {
let input = "use super::*;\nfn test_foo() {}\n";
let result = dedent(input);
assert_eq!(result, input);
}
// Items that precede the test module are kept in `modified_source`.
#[test]
fn preserves_code_before_tests() {
let source = concat!(
"pub struct Foo;\n",
"\n",
"impl Foo {\n",
" pub fn bar(&self) -> i32 { 42 }\n",
"}\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" use super::*;\n",
" #[test]\n",
" fn test_bar() {\n",
" assert_eq!(Foo.bar(), 42);\n",
" }\n",
"}\n",
);
let result = eject_tests(source, "foo").expect("should succeed");
assert!(result.modified_source.contains("pub struct Foo;"));
assert!(result.modified_source.contains("impl Foo"));
assert!(result.modified_source.contains("fn bar"));
}
// Extra outer attributes move into the test file as inner `#![...]` attributes;
// `#[cfg(test)]` stays on the generated declaration and is not copied.
#[test]
fn preserves_allow_attrs_as_inner() {
let source = concat!(
"pub fn first(arr: &[i32]) -> i32 {\n",
" arr[0]\n",
"}\n",
"\n",
"#[cfg(test)]\n",
"#[allow(clippy::unwrap_used, clippy::indexing_slicing)]\n",
"mod tests {\n",
" use super::*;\n",
" #[test]\n",
" fn test_first() {\n",
" assert_eq!(first(&[1, 2, 3]), 1);\n",
" }\n",
"}\n",
);
let result = eject_tests(source, "lift").expect("should succeed");
assert!(
result
.test_content
.starts_with("#![allow(clippy::unwrap_used, clippy::indexing_slicing)]\n")
);
assert!(
result.test_content.contains(
"#![allow(clippy::unwrap_used, clippy::indexing_slicing)]\nuse super::*;"
)
);
assert!(result.modified_source.contains("#[cfg(test)]"));
assert!(!result.modified_source.contains("#[allow"));
assert!(!result.test_content.contains("#[cfg(test)]"));
}
// Each outer attribute becomes its own inner attribute line, in source order.
#[test]
fn preserves_multiple_outer_attrs() {
let source = concat!(
"#[cfg(test)]\n",
"#[allow(clippy::unwrap_used)]\n",
"#[allow(clippy::indexing_slicing)]\n",
"mod tests {\n",
" use super::*;\n",
" #[test]\n",
" fn test_foo() {}\n",
"}\n",
);
let result = eject_tests(source, "foo").expect("should succeed");
assert!(
result.test_content.starts_with(
"#![allow(clippy::unwrap_used)]\n#![allow(clippy::indexing_slicing)]\n"
)
);
}
// Without extra attributes, no inner attributes are emitted.
#[test]
fn no_extra_attrs_keeps_plain_body() {
let source = concat!(
"#[cfg(test)]\n",
"mod tests {\n",
" use super::*;\n",
" #[test]\n",
" fn test_foo() {}\n",
"}\n",
);
let result = eject_tests(source, "foo").expect("should succeed");
assert!(!result.test_content.contains("#!["));
}
// Leading whitespace is skipped and `#[...]` is rewritten as `#![...]`.
#[test]
fn collect_inner_attrs_translates_outer_to_inner() {
let text = "\n#[allow(clippy::unwrap_used)]\n";
assert_eq!(
collect_inner_attrs(text),
"#![allow(clippy::unwrap_used)]\n"
);
}
// Whitespace-only input yields no attributes.
#[test]
fn collect_inner_attrs_empty_when_blank() {
assert_eq!(collect_inner_attrs("\n \n"), "");
}
// Blank lines before the removed module do not leave trailing blank lines:
// `modified_source` ends with the declaration and a single newline.
#[test]
fn no_trailing_blank_lines() {
let source = concat!(
"pub fn add(aa: i32, bb: i32) -> i32 {\n",
" aa + bb\n",
"}\n",
"\n",
"\n",
"#[cfg(test)]\n",
"mod tests {\n",
" use super::*;\n",
" #[test]\n",
" fn test_add() {\n",
" assert_eq!(add(1, 2), 3);\n",
" }\n",
"}\n",
);
let result = eject_tests(source, "math").expect("should succeed");
assert!(result.modified_source.ends_with("mod tests;\n"));
assert!(!result.modified_source.ends_with("\n\n"));
}
}