use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use anyhow::{anyhow, bail, Context, Result};
use oxc::allocator::Allocator;
use oxc::ast::ast::{Argument, CallExpression, Expression, ImportDeclaration, ImportOrExportKind};
use oxc::ast_visit::{walk, Visit};
use oxc::parser::Parser;
use oxc::span::{SourceType, Span};
use crate::lint::Violation;
/// Where an imported module comes from, as decided by [`classify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Origin {
/// The specifier starts with `.` or `/`.
FirstParty,
/// The specifier has a `node:` prefix or its first path segment is listed in `NODE_BUILTINS`.
Builtin,
/// Any specifier that is neither first-party nor a built-in.
ThirdParty,
}
/// Classifies a module specifier by the origin of the module it names.
///
/// The checks run in order: a leading `.` or `/` means [`Origin::FirstParty`];
/// a `node:` prefix or a known built-in name means [`Origin::Builtin`];
/// everything else is [`Origin::ThirdParty`].
pub fn classify(specifier: &str) -> Origin {
    let looks_like_path = specifier.starts_with('.') || specifier.starts_with('/');
    if looks_like_path {
        Origin::FirstParty
    } else if specifier.starts_with("node:") || is_node_builtin(specifier) {
        Origin::Builtin
    } else {
        Origin::ThirdParty
    }
}
/// Returns `true` when the first `/`-separated segment of `specifier` is listed
/// in `NODE_BUILTINS`, so subpaths such as `fs/promises` match through `fs`.
fn is_node_builtin(specifier: &str) -> bool {
    let package = match specifier.split_once('/') {
        Some((first_segment, _subpath)) => first_segment,
        None => specifier,
    };
    NODE_BUILTINS.iter().any(|&builtin| builtin == package)
}
/// Module names that `is_node_builtin` treats as Node built-ins when they
/// appear without a `node:` prefix. Matching is done on the first path segment
/// of a specifier, so only top-level names are listed here.
const NODE_BUILTINS: &[&str] = &[
"assert",
"async_hooks",
"buffer",
"child_process",
"cluster",
"console",
"constants",
"crypto",
"dgram",
"diagnostics_channel",
"dns",
"domain",
"events",
"fs",
"http",
"http2",
"https",
"inspector",
"module",
"net",
"os",
"path",
"perf_hooks",
"process",
"punycode",
"querystring",
"readline",
"repl",
"stream",
"string_decoder",
"sys",
"timers",
"tls",
"trace_events",
"tty",
"url",
"util",
"v8",
"vm",
"wasi",
"worker_threads",
"zlib",
];
/// Scans every TypeScript test file under `root` for integration-test rule
/// violations.
///
/// Files are visited in sorted path order and the combined result is ordered
/// by file, then by line.
///
/// # Errors
///
/// Fails when a directory or file cannot be read, or when a test file has an
/// unsupported extension or does not parse.
pub fn find_integration_violations(root: impl AsRef<Path>) -> Result<Vec<Violation>> {
    let mut test_files = Vec::new();
    collect_ts_test_files(root.as_ref(), &mut test_files)?;
    test_files.sort();

    let mut found = Vec::new();
    for path in &test_files {
        let text = std::fs::read_to_string(path)
            .with_context(|| format!("reading test file `{}`", path.display()))?;
        let mut in_file = integration_violations_in(path, &text)?;
        found.append(&mut in_file);
    }

    found.sort_by(|lhs, rhs| {
        lhs.file
            .cmp(&rhs.file)
            .then_with(|| lhs.line.cmp(&rhs.line))
    });
    Ok(found)
}
pub fn find_unit_violations(root: impl AsRef<Path>) -> Result<Vec<Violation>> {
let root = root.as_ref();
let mut files = Vec::new();
collect_ts_test_files(root, &mut files)?;
files.sort();
let mut violations = Vec::new();
for file in &files {
let source = std::fs::read_to_string(file)
.with_context(|| format!("reading test file `{}`", file.display()))?;
violations.extend(unit_violations_in(file, &source)?);
}
violations.sort_by(|a, b| a.file.cmp(&b.file).then(a.line.cmp(&b.line)));
Ok(violations)
}
/// Parses one unit-test file and reports its rule violations, ordered by line.
///
/// Two rules are checked:
/// * `unmocked-collaborator` — a non-type import that is neither the unit
///   under test, nor the test runner, nor the target of a `vi.mock`/`vi.doMock`
///   call in the same file.
/// * `untyped-mock` — a mock call whose second argument contains no
///   `vi.importActual<…>()` call with type arguments.
///
/// # Errors
///
/// Fails when `file` has an extension the parser does not recognize, or when
/// the parser panics or reports any diagnostic.
fn unit_violations_in(file: &Path, source: &str) -> Result<Vec<Violation>> {
    let source_type = SourceType::from_path(file).map_err(|err| {
        anyhow!(
            "unsupported TypeScript extension `{}`: {err}",
            file.display()
        )
    })?;
    let allocator = Allocator::default();
    let parsed = Parser::new(&allocator, source, source_type).parse();
    if parsed.panicked || !parsed.diagnostics.is_empty() {
        let messages: Vec<String> = parsed
            .diagnostics
            .iter()
            .map(|diag| diag.to_string())
            .collect();
        let detail = messages.join("; ");
        bail!("parsing `{}` failed: {detail}", file.display());
    }

    let mut collector = UnitCollector {
        source,
        imports: Vec::new(),
        mocked: BTreeSet::new(),
        untyped: Vec::new(),
    };
    collector.visit_program(&parsed.program);

    let unit = unit_under_test_specifier(file);
    // An import needs no mock when it is the unit itself, the runner, or already mocked.
    let is_exempt = |spec: &str| {
        is_unit_under_test(spec, &unit) || is_test_runner(spec) || collector.mocked.contains(spec)
    };

    let unmocked = collector
        .imports
        .iter()
        .filter(|(spec, _)| !is_exempt(spec.as_str()))
        .map(|(spec, line)| Violation {
            file: file.to_path_buf(),
            line: *line,
            rule: "unmocked-collaborator",
            message: format!(
                "unit test imports `{spec}` without mocking it — a unit test isolates the \
                 unit under test, so every collaborator must be `vi.mock()`-ed"
            ),
        });
    let untyped = collector.untyped.iter().map(|(spec, line)| Violation {
        file: file.to_path_buf(),
        line: *line,
        rule: "untyped-mock",
        message: format!(
            "`vi.mock('{spec}', …)` has an untyped factory — anchor it to the real module \
             with `vi.importActual<typeof import('{spec}')>()` so the double can't drift \
             from the source"
        ),
    });

    // Unmocked imports come first so the stable sort keeps them ahead of
    // untyped-mock reports on the same line.
    let mut violations: Vec<Violation> = unmocked.chain(untyped).collect();
    violations.sort_by_key(|violation| violation.line);
    Ok(violations)
}
/// AST visitor state gathered from one unit-test file.
struct UnitCollector<'s> {
/// Full text of the file, used to turn span offsets into line numbers.
source: &'s str,
/// `(specifier, line)` for every import declaration that is not `import type`.
imports: Vec<(String, usize)>,
/// Specifiers passed as a string literal to `vi.mock` / `vi.doMock`.
mocked: BTreeSet<String>,
/// `(specifier, line)` for mock calls whose second argument is not typed.
untyped: Vec<(String, usize)>,
}
impl<'a> Visit<'a> for UnitCollector<'_> {
    /// Records the specifier and line of every import that is not `import type`.
    fn visit_import_declaration(&mut self, decl: &ImportDeclaration<'a>) {
        let is_type_only = matches!(decl.import_kind, ImportOrExportKind::Type);
        if !is_type_only {
            let specifier = decl.source.value.to_string();
            let line = line_of(self.source, decl.span.start);
            self.imports.push((specifier, line));
        }
    }

    /// Records mock targets, noting those whose factory argument is untyped,
    /// then keeps walking so nested calls are seen too.
    fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
        if let Some(spec) = vi_mock_target(call) {
            // A mock with no second argument has no factory to check.
            let has_untyped_factory = call
                .arguments
                .get(1)
                .is_some_and(|factory| !factory_is_typed(factory));
            if has_untyped_factory {
                let line = line_of(self.source, call.span.start);
                self.untyped.push((spec.clone(), line));
            }
            self.mocked.insert(spec);
        }
        walk::walk_call_expression(self, call);
    }
}
/// Derives the relative specifier of the module a test file is expected to
/// exercise.
///
/// Takes the final path component of `file`, drops its trailing `.test.<ext>`
/// suffix and prefixes the rest with `./`, so `pkg/widget.test.ts` becomes
/// `./widget`.
///
/// Only the last `.test.` counts as the suffix. A module whose own name
/// contains `.test.` (for example `api.test.utils.test.ts`) therefore maps to
/// `./api.test.utils` instead of being cut at the first occurrence.
///
/// A name without `.test.` is used unchanged, and a path with no UTF-8 file
/// name yields `./`.
fn unit_under_test_specifier(file: &Path) -> String {
    let name = file
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or_default();
    // Split from the right so inner `.test.` segments stay part of the stem.
    let stem = name
        .rsplit_once(".test.")
        .map_or(name, |(stem, _ext)| stem);
    format!("./{stem}")
}
/// Returns `true` when `spec`, with any known module extension removed, is
/// exactly the `unit` specifier.
fn is_unit_under_test(spec: &str, unit: &str) -> bool {
    let extensionless = strip_module_ext(spec);
    unit == extensionless
}
/// Removes one trailing JavaScript/TypeScript module extension from `spec`.
///
/// The extensions are tried in a fixed order and the first match wins; a
/// specifier that ends in none of them is returned unchanged.
fn strip_module_ext(spec: &str) -> &str {
    const MODULE_EXTENSIONS: [&str; 8] = [
        ".js", ".mjs", ".cjs", ".jsx", ".ts", ".mts", ".cts", ".tsx",
    ];
    MODULE_EXTENSIONS
        .iter()
        .find_map(|&ext| spec.strip_suffix(ext))
        .unwrap_or(spec)
}
/// Returns `true` for the `vitest` package itself, any `vitest/…` subpath,
/// and any package in the `@vitest/` scope.
fn is_test_runner(spec: &str) -> bool {
    match spec {
        "vitest" => true,
        other => ["vitest/", "@vitest/"]
            .iter()
            .any(|&prefix| other.starts_with(prefix)),
    }
}
/// Returns `true` when `factory` contains, at any depth, a
/// `vi.importActual<…>()` call that carries type arguments.
fn factory_is_typed(factory: &Argument) -> bool {
    let mut search = ImportActualFinder { typed: false };
    search.visit_argument(factory);
    let ImportActualFinder { typed } = search;
    typed
}
/// Visitor that looks for a typed `vi.importActual` call inside a subtree.
struct ImportActualFinder {
/// Set to `true` once a call accepted by `is_typed_import_actual` is visited.
typed: bool,
}
impl<'a> Visit<'a> for ImportActualFinder {
    /// Latches `typed` when this call is a typed `vi.importActual`, then
    /// descends into the call's children.
    fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
        self.typed |= is_typed_import_actual(call);
        walk::walk_call_expression(self, call);
    }
}
/// Returns `true` when `call` has the shape `vi.importActual<…>(…)`: a static
/// member call on the identifier `vi`, property `importActual`, with type
/// arguments present.
fn is_typed_import_actual(call: &CallExpression) -> bool {
    if call.type_arguments.is_none() {
        return false;
    }
    match &call.callee {
        Expression::StaticMemberExpression(member) => {
            member.property.name.as_str() == "importActual"
                && matches!(&member.object, Expression::Identifier(id) if id.name == "vi")
        }
        _ => false,
    }
}
/// Parses one integration-test file and reports every `vi.mock`/`vi.doMock`
/// call that targets a first-party module, in visit order.
///
/// # Errors
///
/// Fails when `file` has an extension the parser does not recognize, or when
/// the parser panics or reports any diagnostic.
fn integration_violations_in(file: &Path, source: &str) -> Result<Vec<Violation>> {
    let source_type = match SourceType::from_path(file) {
        Ok(kind) => kind,
        Err(err) => bail!(
            "unsupported TypeScript extension `{}`: {err}",
            file.display()
        ),
    };
    let allocator = Allocator::default();
    let parsed = Parser::new(&allocator, source, source_type).parse();

    let parse_failed = parsed.panicked || !parsed.diagnostics.is_empty();
    if parse_failed {
        let mut detail = String::new();
        for (index, diagnostic) in parsed.diagnostics.iter().enumerate() {
            if index > 0 {
                detail.push_str("; ");
            }
            detail.push_str(&diagnostic.to_string());
        }
        bail!("parsing `{}` failed: {detail}", file.display());
    }

    let mut visitor = MockVisitor {
        file,
        source,
        violations: Vec::new(),
    };
    visitor.visit_program(&parsed.program);
    let MockVisitor { violations, .. } = visitor;
    Ok(violations)
}
/// Visitor that collects `no-first-party-mock` violations for one file.
struct MockVisitor<'s> {
/// Path recorded on each reported violation.
file: &'s Path,
/// Full text of the file, used to turn span offsets into line numbers.
source: &'s str,
/// Violations found so far, in visit order.
violations: Vec<Violation>,
}
impl MockVisitor<'_> {
    /// Appends a `no-first-party-mock` violation for `spec`, located at the
    /// line on which `span` starts.
    fn report(&mut self, span: Span, spec: &str) {
        let line = line_of(self.source, span.start);
        let message = format!(
            "integration test mocks first-party module `{spec}` — an integration test \
             runs first-party code for real; only third-party packages and Node built-ins \
             may be mocked"
        );
        let violation = Violation {
            file: self.file.to_path_buf(),
            line,
            rule: "no-first-party-mock",
            message,
        };
        self.violations.push(violation);
    }
}
impl<'a> Visit<'a> for MockVisitor<'_> {
    /// Reports the call when it mocks a first-party specifier, then descends
    /// so mocks nested inside callbacks are found as well.
    fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
        let first_party_target =
            vi_mock_target(call).filter(|spec| classify(spec) == Origin::FirstParty);
        if let Some(spec) = first_party_target {
            self.report(call.span, &spec);
        }
        walk::walk_call_expression(self, call);
    }
}
/// Extracts the mocked specifier from a `vi.mock(…)` or `vi.doMock(…)` call.
///
/// Returns `None` unless the callee is a static member access on the
/// identifier `vi` with property `mock` or `doMock`, and the first argument is
/// a string literal.
fn vi_mock_target(call: &CallExpression) -> Option<String> {
    let Expression::StaticMemberExpression(member) = &call.callee else {
        return None;
    };
    let Expression::Identifier(receiver) = &member.object else {
        return None;
    };
    if receiver.name != "vi" {
        return None;
    }
    if !matches!(member.property.name.as_str(), "mock" | "doMock") {
        return None;
    }
    // Only a literal target can be resolved statically; anything else is ignored.
    if let Some(Argument::StringLiteral(literal)) = call.arguments.first() {
        Some(literal.value.to_string())
    } else {
        None
    }
}
/// Converts a byte `offset` into `source` to a 1-based line number.
///
/// Counts the `\n` bytes that precede the offset; an offset past the end of
/// `source` is clamped to its length.
fn line_of(source: &str, offset: u32) -> usize {
    let end = source.len().min(offset as usize);
    let newlines_before = source
        .bytes()
        .take(end)
        .filter(|byte| *byte == b'\n')
        .count();
    newlines_before + 1
}
fn collect_ts_test_files(dir: &Path, out: &mut Vec<PathBuf>) -> Result<()> {
let entries =
std::fs::read_dir(dir).with_context(|| format!("reading directory `{}`", dir.display()))?;
for entry in entries {
let path = entry
.with_context(|| format!("reading an entry under `{}`", dir.display()))?
.path();
if path.is_dir() {
collect_ts_test_files(&path, out)?;
} else if is_ts_test_file(&path) {
out.push(path);
}
}
Ok(())
}
/// Returns `true` when the file name of `path` ends in `.test.ts`,
/// `.test.tsx`, `.test.mts` or `.test.cts`.
///
/// A path with no file name, or a name that is not valid UTF-8, is not a
/// test file.
fn is_ts_test_file(path: &Path) -> bool {
    const TEST_SUFFIXES: [&str; 4] = [".test.ts", ".test.tsx", ".test.mts", ".test.cts"];
    match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => TEST_SUFFIXES.iter().any(|&suffix| name.ends_with(suffix)),
        None => false,
    }
}
// Unit tests for the specifier classifier, the unit-test rules
// (`unmocked-collaborator`, `untyped-mock`) and the integration-test rule
// (`no-first-party-mock`). Sources are parsed from inline strings; nothing
// here touches the filesystem.
#[cfg(test)]
mod tests {
use super::*;
// Runs the integration rule over an inline source, panicking if it fails to parse.
fn violations(name: &str, source: &str) -> Vec<Violation> {
integration_violations_in(Path::new(name), source).expect("source should parse")
}
// Runs the unit rules over an inline source, panicking if it fails to parse.
fn unit_violations(name: &str, source: &str) -> Vec<Violation> {
unit_violations_in(Path::new(name), source).expect("source should parse")
}
// Unmocked imports are flagged whether relative or from a package; the unit under test is not.
#[test]
fn unit_flags_unmocked_first_party_and_external() {
let found = unit_violations(
"widget.test.ts",
"import { makeWidget } from './widget';\n\
import { format } from './formatter';\n\
import { chunk } from 'lodash';\n",
);
assert_eq!(found.len(), 2, "got: {found:?}");
assert!(found.iter().all(|v| v.rule == "unmocked-collaborator"));
assert!(found.iter().any(|v| v.message.contains("./formatter")));
assert!(found.iter().any(|v| v.message.contains("lodash")));
}
// A factory-less `vi.mock` of an imported module satisfies both unit rules.
#[test]
fn unit_mocked_collaborator_is_clean() {
let found = unit_violations(
"widget.test.ts",
"import { format } from './formatter';\nvi.mock('./formatter');\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// The runner and the unit under test (even with a `.js` extension) are exempt.
#[test]
fn unit_under_test_and_runner_are_not_flagged() {
let found = unit_violations(
"widget.test.ts",
"import { vi } from 'vitest';\n\
import { makeWidget } from './widget.js';\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// `import type` declarations are skipped by the collector.
#[test]
fn unit_type_only_import_is_not_flagged() {
let found = unit_violations(
"widget.test.ts",
"import type { Opts } from './opts';\nimport { x } from './x';\nvi.mock('./x');\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// Only the file name matters; directories and the `.test.<ext>` suffix are dropped.
#[test]
fn unit_under_test_specifier_strips_test_suffix() {
assert_eq!(
unit_under_test_specifier(Path::new("pkg/widget.test.ts")),
"./widget"
);
assert_eq!(
unit_under_test_specifier(Path::new("button.test.tsx")),
"./button"
);
}
// Specifiers without a known module extension pass through unchanged.
#[test]
fn strip_module_ext_drops_known_extensions_only() {
assert_eq!(strip_module_ext("./widget.js"), "./widget");
assert_eq!(strip_module_ext("./widget.mts"), "./widget");
assert_eq!(strip_module_ext("./widget"), "./widget");
assert_eq!(strip_module_ext("lodash"), "lodash");
}
// `vitest`, its subpaths and the `@vitest/` scope count as the runner; look-alikes do not.
#[test]
fn recognizes_the_test_runner() {
assert!(is_test_runner("vitest"));
assert!(is_test_runner("vitest/config"));
assert!(is_test_runner("@vitest/spy"));
assert!(!is_test_runner("./vitest-helpers"));
assert!(!is_test_runner("lodash"));
}
// A factory with no `vi.importActual` call at all is untyped.
#[test]
fn unit_flags_untyped_factory_mock() {
let found = unit_violations(
"widget.test.ts",
"import { x } from './x';\nvi.mock('./x', () => ({ x: vi.fn() }));\n",
);
assert_eq!(found.len(), 1, "got: {found:?}");
assert_eq!(found[0].rule, "untyped-mock");
assert!(found[0].message.contains("./x"));
}
// A factory that calls `vi.importActual` with type arguments is accepted.
#[test]
fn unit_typed_factory_mock_is_clean() {
let found = unit_violations(
"widget.test.ts",
"import { x } from './x';\n\
vi.mock('./x', async () => {\n\
\x20 const actual = await vi.importActual<typeof import('./x')>('./x');\n\
\x20 return { ...actual, x: vi.fn() };\n\
});\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// `vi.importActual` without type arguments does not make the factory typed.
#[test]
fn unit_untyped_import_actual_is_still_untyped() {
let found = unit_violations(
"widget.test.ts",
"import { x } from './x';\n\
vi.mock('./x', async () => {\n\
\x20 const actual = await vi.importActual('./x');\n\
\x20 return { ...(actual as object), x: vi.fn() };\n\
});\n",
);
assert_eq!(found.len(), 1, "got: {found:?}");
assert_eq!(found[0].rule, "untyped-mock");
}
// Relative and absolute path specifiers are first-party.
#[test]
fn classify_relative_is_first_party() {
assert_eq!(classify("./service"), Origin::FirstParty);
assert_eq!(classify("../pkg/util"), Origin::FirstParty);
assert_eq!(classify("/abs/path"), Origin::FirstParty);
}
// Built-ins match by bare name, by subpath, or by any `node:`-prefixed specifier.
#[test]
fn classify_node_builtins() {
assert_eq!(classify("fs"), Origin::Builtin);
assert_eq!(classify("node:fs"), Origin::Builtin);
assert_eq!(classify("fs/promises"), Origin::Builtin);
assert_eq!(classify("node:test"), Origin::Builtin);
assert_eq!(classify("child_process"), Origin::Builtin);
assert_eq!(classify("node:some-future-builtin"), Origin::Builtin);
}
// Bare `test` is not in the built-in table, so without `node:` it is third-party.
#[test]
fn classify_third_party() {
assert_eq!(classify("lodash"), Origin::ThirdParty);
assert_eq!(classify("@scope/pkg"), Origin::ThirdParty);
assert_eq!(classify("stripe/lib/client"), Origin::ThirdParty);
assert_eq!(classify("test"), Origin::ThirdParty);
}
// Only `.test.{ts,tsx,mts,cts}` names are picked up by the directory walk.
#[test]
fn recognizes_ts_test_files() {
assert!(is_ts_test_file(Path::new("widget.test.ts")));
assert!(is_ts_test_file(Path::new("pkg/button.test.tsx")));
assert!(is_ts_test_file(Path::new("service.test.mts")));
assert!(is_ts_test_file(Path::new("legacy.test.cts")));
assert!(!is_ts_test_file(Path::new("widget.ts")));
assert!(!is_ts_test_file(Path::new("types.d.ts")));
assert!(!is_ts_test_file(Path::new("README.md")));
}
// Line numbers are 1-based and advance after each newline byte.
#[test]
fn line_of_counts_newlines() {
let src = "a\nb\nc\n";
assert_eq!(line_of(src, 0), 1);
assert_eq!(line_of(src, 2), 2);
assert_eq!(line_of(src, 4), 3);
}
// A factory-less mock of a relative module is reported on its own line.
#[test]
fn flags_mock_of_relative_module() {
let found = violations("a.test.ts", "vi.mock('./service');\n");
assert_eq!(found.len(), 1);
assert_eq!(found[0].rule, "no-first-party-mock");
assert_eq!(found[0].line, 1);
}
// A factory argument and a `../` path do not change the verdict.
#[test]
fn flags_mock_with_factory_and_parent_path() {
let found = violations(
"a.test.ts",
"import { x } from './x';\nvi.mock('../src/ledger', () => ({ record: vi.fn() }));\n",
);
assert_eq!(found.len(), 1);
assert!(found[0].message.contains("../src/ledger"));
}
// `vi.doMock` is treated the same as `vi.mock`.
#[test]
fn flags_domock_of_relative_module() {
let found = violations("a.test.mts", "vi.doMock('./mailer');\n");
assert_eq!(found.len(), 1);
}
// Packages and Node built-ins may be mocked in integration tests.
#[test]
fn allows_mock_of_third_party_and_builtins() {
let found = violations(
"a.test.ts",
"vi.mock('stripe');\nvi.mock('node:fs');\nvi.mock('fs/promises');\nvi.mock('@scope/pkg');\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// Calls on other objects, or other `vi` methods, are not mock calls.
#[test]
fn ignores_non_vi_and_non_mock_calls() {
let found = violations(
"a.test.ts",
"describe('s', () => {});\nvi.fn();\nexpect(1).toBe(1);\nother.mock('./x');\n",
);
assert!(found.is_empty(), "got: {found:?}");
}
// A non-literal first argument cannot be classified and is ignored.
#[test]
fn ignores_dynamic_mock_target() {
let found = violations("a.test.ts", "const m = './x';\nvi.mock(m);\n");
assert!(found.is_empty(), "got: {found:?}");
}
// The visitor descends into callbacks, and the line is that of the nested call.
#[test]
fn finds_mocks_nested_in_blocks() {
let found = violations(
"a.test.ts",
"describe('s', () => {\n vi.mock('./inner');\n});\n",
);
assert_eq!(found.len(), 1);
assert_eq!(found[0].line, 2);
}
// Syntax errors surface as an error rather than an empty result.
#[test]
fn parse_error_is_reported() {
let err = integration_violations_in(Path::new("bad.test.ts"), "const x = ;\n").unwrap_err();
assert!(err.to_string().contains("parsing"), "got: {err}");
}
// An extension the parser does not know is rejected before parsing.
#[test]
fn unsupported_extension_is_reported() {
let err = integration_violations_in(Path::new("weird.test.bogus"), "vi.mock('./x');\n")
.unwrap_err();
assert!(err.to_string().contains("unsupported"), "got: {err}");
}
}