use std::path::{Path, PathBuf};
use anyhow::{anyhow, bail, Context, Result};
use oxc::allocator::Allocator;
use oxc::ast::ast::{Argument, CallExpression, Expression};
use oxc::ast_visit::{walk, Visit};
use oxc::parser::Parser;
use oxc::span::{SourceType, Span};
use crate::lint::Violation;
/// Where a module specifier resolves from, as decided by [`classify`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Origin {
    /// A specifier starting with `.` or `/`: a path into this project's own code.
    FirstParty,
    /// A Node.js built-in module, with or without the `node:` scheme.
    Builtin,
    /// Any other bare specifier, e.g. `lodash` or `@scope/pkg`.
    ThirdParty,
}
/// Classifies a module specifier by where it resolves from.
///
/// * A leading `.` or `/` (relative or absolute path) means [`Origin::FirstParty`].
/// * A `node:`-prefixed specifier, or one whose first path segment is a known
///   Node built-in, means [`Origin::Builtin`].
/// * Everything else is [`Origin::ThirdParty`].
pub fn classify(specifier: &str) -> Origin {
    match specifier.chars().next() {
        Some('.' | '/') => Origin::FirstParty,
        _ if specifier.starts_with("node:") || is_node_builtin(specifier) => Origin::Builtin,
        _ => Origin::ThirdParty,
    }
}
/// Reports whether the first path segment of `specifier` names a Node built-in.
///
/// Only the part before the first `/` is compared, so subpath specifiers such as
/// `fs/promises` count as built-ins.
fn is_node_builtin(specifier: &str) -> bool {
    let package = match specifier.split_once('/') {
        Some((head, _rest)) => head,
        None => specifier,
    };
    NODE_BUILTINS.contains(&package)
}
/// Bare (unprefixed) names of Node.js built-in modules, in ascending order.
///
/// Modules reachable only through the `node:` scheme (for example `node:test`)
/// are intentionally absent; `classify` recognises those by their prefix, and a
/// bare `test` stays third-party.
const NODE_BUILTINS: &[&str] = &[
    "assert", "async_hooks", "buffer", "child_process", "cluster", "console", "constants",
    "crypto", "dgram", "diagnostics_channel", "dns", "domain", "events", "fs", "http", "http2",
    "https", "inspector", "module", "net", "os", "path", "perf_hooks", "process", "punycode",
    "querystring", "readline", "repl", "stream", "string_decoder", "sys", "timers", "tls",
    "trace_events", "tty", "url", "util", "v8", "vm", "wasi", "worker_threads", "zlib",
];
pub fn find_integration_violations(root: impl AsRef<Path>) -> Result<Vec<Violation>> {
let root = root.as_ref();
let mut files = Vec::new();
collect_ts_test_files(root, &mut files)?;
files.sort();
let mut violations = Vec::new();
for file in &files {
let source = std::fs::read_to_string(file)
.with_context(|| format!("reading test file `{}`", file.display()))?;
violations.extend(integration_violations_in(file, &source)?);
}
violations.sort_by(|a, b| a.file.cmp(&b.file).then(a.line.cmp(&b.line)));
Ok(violations)
}
/// Parses one TypeScript test file and returns its `no-first-party-mock` violations.
///
/// `file` is used in two ways:
/// * its extension selects the parser dialect via `SourceType::from_path`;
/// * it is recorded in each violation.
///
/// `source` is the file's full text.
///
/// # Errors
///
/// Fails when `SourceType::from_path` rejects the extension. It also fails when
/// the parser panics or emits any diagnostic, so a partially parsed file is
/// never linted.
fn integration_violations_in(file: &Path, source: &str) -> Result<Vec<Violation>> {
    let source_type = match SourceType::from_path(file) {
        Ok(source_type) => source_type,
        Err(err) => bail!(
            "unsupported TypeScript extension `{}`: {err}",
            file.display()
        ),
    };

    let allocator = Allocator::default();
    let parsed = Parser::new(&allocator, source, source_type).parse();
    if parsed.panicked || !parsed.diagnostics.is_empty() {
        let messages: Vec<String> = parsed
            .diagnostics
            .iter()
            .map(|diag| diag.to_string())
            .collect();
        bail!("parsing `{}` failed: {}", file.display(), messages.join("; "));
    }

    let mut visitor = MockVisitor {
        file,
        source,
        violations: Vec::new(),
    };
    visitor.visit_program(&parsed.program);
    Ok(visitor.violations)
}
/// AST visitor that accumulates `no-first-party-mock` violations for one file.
struct MockVisitor<'src> {
    /// Path of the file being visited; copied into every reported violation.
    file: &'src Path,
    /// Full source text, used to turn span offsets into line numbers.
    source: &'src str,
    /// Violations found so far, in visit order.
    violations: Vec<Violation>,
}
impl MockVisitor<'_> {
    /// Records a `no-first-party-mock` violation for the mock of `spec`.
    ///
    /// The violation is located at the line containing the start of `span`.
    fn report(&mut self, span: Span, spec: &str) {
        let message = format!(
            "integration test mocks first-party module `{spec}` — an integration test \
             runs first-party code for real; only third-party packages and Node built-ins \
             may be mocked"
        );
        let violation = Violation {
            file: self.file.to_path_buf(),
            line: line_of(self.source, span.start),
            rule: "no-first-party-mock",
            message,
        };
        self.violations.push(violation);
    }
}
impl<'a> Visit<'a> for MockVisitor<'_> {
    /// Reports `vi.mock`/`vi.doMock` calls whose target classifies as first-party.
    ///
    /// It then keeps walking into the call's children, so mocks nested inside
    /// arguments (for example inside `describe` callbacks) are also found.
    fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
        let first_party_target =
            vi_mock_target(call).filter(|spec| classify(spec) == Origin::FirstParty);
        if let Some(spec) = first_party_target {
            self.report(call.span, &spec);
        }
        walk::walk_call_expression(self, call);
    }
}
fn vi_mock_target(call: &CallExpression) -> Option<String> {
let Expression::StaticMemberExpression(member) = &call.callee else {
return None;
};
let is_vi = matches!(&member.object, Expression::Identifier(id) if id.name == "vi");
if !is_vi {
return None;
}
let method = member.property.name.as_str();
if method != "mock" && method != "doMock" {
return None;
}
match call.arguments.first() {
Some(Argument::StringLiteral(lit)) => Some(lit.value.to_string()),
_ => None,
}
}
/// Converts a byte offset into `source` to a 1-based line number.
///
/// Offsets past the end of `source` are clamped to its length. Such offsets map
/// to the last line instead of panicking.
fn line_of(source: &str, offset: u32) -> usize {
    let end = source.len().min(offset as usize);
    let newlines_before = source.bytes().take(end).filter(|&b| b == b'\n').count();
    newlines_before + 1
}
/// Recursively collects every TypeScript test file (see [`is_ts_test_file`])
/// under `dir` into `out`.
///
/// Paths are appended in `read_dir` order, which is unspecified;
/// `find_integration_violations` sorts them afterwards.
///
/// Two kinds of directory are deliberately not descended into:
///
/// * `node_modules`. It holds installed third-party packages, whose own test
///   files are not this project's integration tests. In workspaces it also
///   contains symlinks back to first-party packages, which would otherwise be
///   scanned a second time and reported as duplicate violations.
/// * Symlinked directories. `DirEntry::file_type` does not follow symlinks, so a
///   link pointing at an ancestor cannot recurse until the stack overflows.
///
/// Symlinks to files are still collected when their name matches.
///
/// # Errors
///
/// Fails if a directory, a directory entry, or an entry's file type cannot be read.
fn collect_ts_test_files(dir: &Path, out: &mut Vec<PathBuf>) -> Result<()> {
    let entries =
        std::fs::read_dir(dir).with_context(|| format!("reading directory `{}`", dir.display()))?;
    for entry in entries {
        let entry =
            entry.with_context(|| format!("reading an entry under `{}`", dir.display()))?;
        let path = entry.path();
        // Unlike `Path::is_dir`, this reports a symlink as a symlink, not its target.
        let file_type = entry
            .file_type()
            .with_context(|| format!("reading the file type of `{}`", path.display()))?;
        if file_type.is_dir() {
            if entry.file_name() != "node_modules" {
                collect_ts_test_files(&path, out)?;
            }
        } else if is_ts_test_file(&path) {
            out.push(path);
        }
    }
    Ok(())
}
/// Reports whether `path` names a TypeScript test file.
///
/// A match is a UTF-8 file name ending in `.test.ts`, `.test.tsx`, `.test.mts`
/// or `.test.cts`. Paths with no file name, or a non-UTF-8 one, never match.
fn is_ts_test_file(path: &Path) -> bool {
    const TEST_SUFFIXES: [&str; 4] = [".test.ts", ".test.tsx", ".test.mts", ".test.cts"];
    match path.file_name().and_then(|n| n.to_str()) {
        Some(name) => TEST_SUFFIXES.iter().any(|suffix| name.ends_with(suffix)),
        None => false,
    }
}
// Unit tests covering:
// - specifier classification;
// - test-file discovery;
// - line numbering;
// - the end-to-end `no-first-party-mock` check on in-memory sources.
#[cfg(test)]
mod tests {
    use super::*;

    // Runs the per-file check on an in-memory source. `name` only selects the
    // parser dialect through its extension. Panics if the source fails to parse.
    fn violations(name: &str, source: &str) -> Vec<Violation> {
        integration_violations_in(Path::new(name), source).expect("source should parse")
    }

    // Any specifier starting with `.` or `/` is first-party.
    #[test]
    fn classify_relative_is_first_party() {
        assert_eq!(classify("./service"), Origin::FirstParty);
        assert_eq!(classify("../pkg/util"), Origin::FirstParty);
        assert_eq!(classify("/abs/path"), Origin::FirstParty);
    }

    // Bare built-in names, their subpaths, and anything under `node:` are built-ins.
    #[test]
    fn classify_node_builtins() {
        assert_eq!(classify("fs"), Origin::Builtin);
        assert_eq!(classify("node:fs"), Origin::Builtin);
        assert_eq!(classify("fs/promises"), Origin::Builtin);
        assert_eq!(classify("node:test"), Origin::Builtin);
        assert_eq!(classify("child_process"), Origin::Builtin);
        assert_eq!(classify("node:some-future-builtin"), Origin::Builtin);
    }

    // `test` only exists as `node:test`, so the bare name is an npm package.
    #[test]
    fn classify_third_party() {
        assert_eq!(classify("lodash"), Origin::ThirdParty);
        assert_eq!(classify("@scope/pkg"), Origin::ThirdParty);
        assert_eq!(classify("stripe/lib/client"), Origin::ThirdParty);
        assert_eq!(classify("test"), Origin::ThirdParty);
    }

    // Only `.test.{ts,tsx,mts,cts}` names count; plain sources and `.d.ts` do not.
    #[test]
    fn recognizes_ts_test_files() {
        assert!(is_ts_test_file(Path::new("widget.test.ts")));
        assert!(is_ts_test_file(Path::new("pkg/button.test.tsx")));
        assert!(is_ts_test_file(Path::new("service.test.mts")));
        assert!(is_ts_test_file(Path::new("legacy.test.cts")));
        assert!(!is_ts_test_file(Path::new("widget.ts")));
        assert!(!is_ts_test_file(Path::new("types.d.ts")));
        assert!(!is_ts_test_file(Path::new("README.md")));
    }

    // Offsets map to 1-based lines; the byte just after a newline starts the next line.
    #[test]
    fn line_of_counts_newlines() {
        let src = "a\nb\nc\n";
        assert_eq!(line_of(src, 0), 1);
        assert_eq!(line_of(src, 2), 2);
        assert_eq!(line_of(src, 4), 3);
    }

    #[test]
    fn flags_mock_of_relative_module() {
        let found = violations("a.test.ts", "vi.mock('./service');\n");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].rule, "no-first-party-mock");
        assert_eq!(found[0].line, 1);
    }

    // A factory argument does not hide the target. The `import` of `./x` is not a
    // mock, so it must not be flagged.
    #[test]
    fn flags_mock_with_factory_and_parent_path() {
        let found = violations(
            "a.test.ts",
            "import { x } from './x';\nvi.mock('../src/ledger', () => ({ record: vi.fn() }));\n",
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].message.contains("../src/ledger"));
    }

    // `vi.doMock` is checked the same way as `vi.mock` (here in a `.mts` file).
    #[test]
    fn flags_domock_of_relative_module() {
        let found = violations("a.test.mts", "vi.doMock('./mailer');\n");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn allows_mock_of_third_party_and_builtins() {
        let found = violations(
            "a.test.ts",
            "vi.mock('stripe');\nvi.mock('node:fs');\nvi.mock('fs/promises');\nvi.mock('@scope/pkg');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    // Only `vi.mock`/`vi.doMock` qualify; `other.mock('./x')` has the wrong receiver.
    #[test]
    fn ignores_non_vi_and_non_mock_calls() {
        let found = violations(
            "a.test.ts",
            "describe('s', () => {});\nvi.fn();\nexpect(1).toBe(1);\nother.mock('./x');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    // A variable target cannot be resolved statically and is skipped, not guessed.
    #[test]
    fn ignores_dynamic_mock_target() {
        let found = violations("a.test.ts", "const m = './x';\nvi.mock(m);\n");
        assert!(found.is_empty(), "got: {found:?}");
    }

    // The visitor walks into callbacks, and the reported line is the nested call's line.
    #[test]
    fn finds_mocks_nested_in_blocks() {
        let found = violations(
            "a.test.ts",
            "describe('s', () => {\n vi.mock('./inner');\n});\n",
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].line, 2);
    }

    // Syntax errors abort with an error rather than linting a partial AST.
    #[test]
    fn parse_error_is_reported() {
        let err = integration_violations_in(Path::new("bad.test.ts"), "const x = ;\n").unwrap_err();
        assert!(err.to_string().contains("parsing"), "got: {err}");
    }

    // An extension the parser cannot map to a source type is an error, not a skip.
    #[test]
    fn unsupported_extension_is_reported() {
        let err = integration_violations_in(Path::new("weird.test.bogus"), "vi.mock('./x');\n")
            .unwrap_err();
        assert!(err.to_string().contains("unsupported"), "got: {err}");
    }
}