use std::collections::HashMap;
use std::path::{Path, PathBuf};
use anyhow::Result;
use clap::Args;
use tldr_core::ast::ParserPool;
use tldr_core::walker::walk_project;
use tldr_core::Language;
use tree_sitter::{Node, Parser, Tree};
use tree_sitter_python::LANGUAGE as PYTHON_LANGUAGE;
use crate::output::{OutputFormat, OutputWriter};
use super::error::{ContractsError, ContractsResult};
use super::types::{
Confidence, ExceptionSpec, FunctionSpecs, InputOutputSpec,
OutputFormat as ContractsOutputFormat, PropertySpec, SpecsByType, SpecsReport, SpecsSummary,
};
use super::validation::{check_ast_depth, read_file_safe, validate_file_path};
/// Deepest nesting level `try_eval_literal_inner` evaluates; once the depth exceeds this,
/// the node's raw source text is returned as a string instead.
const MAX_LITERAL_DEPTH: usize = 10;
/// Longest literal source text (in bytes) that is evaluated; anything longer is replaced
/// by the `"<large literal>"` placeholder.
const MAX_LITERAL_SIZE: usize = 10_000;
// Command-line arguments of the specs command, parsed by clap through `Args`.
// NOTE: plain `//` comments are used on purpose; `///` would become clap help text.
#[derive(Debug, Args)]
pub struct SpecsArgs {
// Test file or directory to extract specs from; `run` fails if the path does not exist.
#[arg(long = "from-tests", short = 't')]
pub from_tests: PathBuf,
// Hidden output-format flag (defaults to "json"); `run` prints text when this is `Text`
// or when the global output format is text.
#[arg(
long = "output-format",
short = 'o',
hide = true,
default_value = "json"
)]
pub output_format: ContractsOutputFormat,
// When set, only specs whose function name equals this value are reported.
#[arg(long)]
pub function: Option<String>,
// Optional source path; not read by the code visible in this part of the file.
#[arg(long)]
pub source: Option<PathBuf>,
}
impl SpecsArgs {
    /// Executes the command: validates `from_tests`, extracts the specs report and writes it.
    ///
    /// The report is rendered as text when either the hidden `--output-format` flag or the
    /// global `format` asks for text; otherwise it is handed to the writer as structured data.
    ///
    /// # Errors
    /// Returns `ContractsError::TestPathNotFound` when `from_tests` does not exist, and
    /// propagates any error from extraction or from the output writer.
    pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
        let writer = OutputWriter::new(format, quiet);
        let tests_root: &Path = self.from_tests.as_path();

        if !tests_root.exists() {
            let missing = ContractsError::TestPathNotFound {
                path: tests_root.to_path_buf(),
            };
            return Err(missing.into());
        }

        let banner = format!("Extracting specs from {}...", tests_root.display());
        writer.progress(&banner);

        let report = run_specs(tests_root, self.function.as_deref())?;

        let wants_text = match self.output_format {
            ContractsOutputFormat::Text => true,
            _ => matches!(format, OutputFormat::Text),
        };
        if !wants_text {
            writer.write(&report)?;
            return Ok(());
        }

        let rendered = format_specs_text(&report);
        writer.write_text(&rendered)?;
        Ok(())
    }
}
/// Extracts behavioural specs from a test file or from every test file under a directory.
///
/// * A single Python file is always parsed with the Python extractor; its errors propagate.
/// * A single file in another detected language is processed only if the test recognizer
///   reports it as a test file; an unreadable file yields an empty report.
/// * In a directory, Python files must be named `test_*.py` or `*_test.py`; a Python file
///   that fails to parse prints a warning to stderr and is skipped. Other languages go
///   through the test recognizer, and unreadable files are skipped silently.
///
/// `function_filter`, when present, keeps only the specs of the function with that exact
/// name. The returned functions are sorted by name and the summary counts reflect the
/// filtered list.
pub fn run_specs(test_path: &Path, function_filter: Option<&str>) -> ContractsResult<SpecsReport> {
    let mut merged: HashMap<String, FunctionSpecs> = HashMap::new();
    let mut files_seen = 0u32;
    let mut tests_seen = 0u32;

    if !test_path.is_file() {
        for entry in walk_project(test_path).filter(|e| e.path().is_file()) {
            let candidate = entry.path();
            let language = match super::test_recognizer::detect_language(candidate) {
                None => continue,
                Some(found) => found,
            };

            if !matches!(language, Language::Python) {
                // Non-Python files rely on the recognizer to decide whether they hold tests.
                if let Ok(text) = std::fs::read_to_string(candidate) {
                    let info = super::test_recognizer::recognize(candidate, &text, language);
                    if info.is_test_file {
                        files_seen += 1;
                        tests_seen += info.test_function_count;
                        let extracted = extract_generic_specs(candidate, &text, language);
                        merge_specs(&mut merged, extracted);
                    }
                }
                continue;
            }

            // Python files are selected purely by the pytest naming convention.
            let follows_convention = candidate
                .file_name()
                .and_then(|n| n.to_str())
                .map_or(false, |name| {
                    name.ends_with("_test.py")
                        || (name.starts_with("test_") && name.ends_with(".py"))
                });
            if !follows_convention {
                continue;
            }

            match extract_from_test_file(candidate) {
                Err(e) => {
                    eprintln!("Warning: Failed to parse {}: {}", candidate.display(), e);
                }
                Ok(file_report) => {
                    files_seen += 1;
                    tests_seen += file_report.test_functions_scanned;
                    merge_specs(&mut merged, file_report.functions);
                }
            }
        }
    } else {
        match super::test_recognizer::detect_language(test_path) {
            Some(Language::Python) => {
                let file_report = extract_from_test_file(test_path)?;
                files_seen = 1;
                tests_seen = file_report.test_functions_scanned;
                merge_specs(&mut merged, file_report.functions);
            }
            Some(language) => {
                if let Ok(text) = std::fs::read_to_string(test_path) {
                    let info = super::test_recognizer::recognize(test_path, &text, language);
                    if info.is_test_file {
                        files_seen = 1;
                        tests_seen = info.test_function_count;
                        let extracted = extract_generic_specs(test_path, &text, language);
                        merge_specs(&mut merged, extracted);
                    }
                }
            }
            None => {}
        }
    }

    let mut functions: Vec<FunctionSpecs> = merged
        .into_values()
        .filter(|f| function_filter.map_or(true, |wanted| f.function_name == wanted))
        .collect();
    functions.sort_by(|a, b| a.function_name.cmp(&b.function_name));

    let (io_total, exception_total, property_total) =
        functions
            .iter()
            .fold((0u32, 0u32, 0u32), |(io, exc, prop), f| {
                (
                    io + f.input_output_specs.len() as u32,
                    exc + f.exception_specs.len() as u32,
                    prop + f.property_specs.len() as u32,
                )
            });

    let summary = SpecsSummary {
        total_specs: io_total + exception_total + property_total,
        by_type: SpecsByType {
            input_output: io_total,
            exception: exception_total,
            property: property_total,
        },
        test_functions_scanned: tests_seen,
        test_files_scanned: files_seen,
        functions_found: functions.len() as u32,
    };
    Ok(SpecsReport { functions, summary })
}
/// Result of extracting specs from one Python test file.
struct FileSpecReport {
/// Specs grouped per function under test, each with its summary filled in.
functions: Vec<FunctionSpecs>,
/// Number of `test_*` functions and methods that were visited in the file.
test_functions_scanned: u32,
}
fn extract_from_test_file(path: &Path) -> ContractsResult<FileSpecReport> {
let canonical = validate_file_path(path)?;
let source = read_file_safe(&canonical)?;
if source.trim().is_empty() {
return Ok(FileSpecReport {
functions: vec![],
test_functions_scanned: 0,
});
}
let tree = parse_python(&source, &canonical)?;
let root = tree.root_node();
let mut specs: HashMap<String, FunctionSpecs> = HashMap::new();
let mut test_func_count = 0u32;
let mut cursor = root.walk();
for child in root.children(&mut cursor) {
match child.kind() {
"function_definition" => {
if let Some(name_node) = child.child_by_field_name("name") {
let name = get_node_text(name_node, source.as_bytes());
if name.starts_with("test_") {
test_func_count += 1;
process_test_function(child, name, source.as_bytes(), &mut specs, 0)?;
}
}
}
"class_definition" => {
if let Some(name_node) = child.child_by_field_name("name") {
let class_name = get_node_text(name_node, source.as_bytes());
if class_name.starts_with("Test") {
if let Some(body) = child.child_by_field_name("body") {
let mut class_cursor = body.walk();
for method in body.children(&mut class_cursor) {
if method.kind() == "function_definition" {
if let Some(method_name) = method.child_by_field_name("name") {
let mname = get_node_text(method_name, source.as_bytes());
if mname.starts_with("test_") {
test_func_count += 1;
process_test_function(
method,
mname,
source.as_bytes(),
&mut specs,
0,
)?;
}
}
}
}
}
}
}
}
_ => {}
}
}
let functions: Vec<FunctionSpecs> = specs
.into_values()
.map(|mut fs| {
fs.summary = generate_summary(&fs);
fs
})
.collect();
Ok(FileSpecReport {
functions,
test_functions_scanned: test_func_count,
})
}
/// Parses `source` with a fresh tree-sitter parser configured for Python.
///
/// # Errors
/// Returns `ContractsError::ParseError` (tagged with `file`) when the Python language cannot
/// be installed on the parser or when the parser produces no tree.
fn parse_python(source: &str, file: &Path) -> ContractsResult<Tree> {
    let parse_error = |message: String| ContractsError::ParseError {
        file: file.to_path_buf(),
        message,
    };

    let mut parser = Parser::new();
    if let Err(e) = parser.set_language(&PYTHON_LANGUAGE.into()) {
        return Err(parse_error(format!("Failed to set Python language: {}", e)));
    }

    match parser.parse(source, None) {
        Some(tree) => Ok(tree),
        None => Err(parse_error("Parsing returned None".to_string())),
    }
}
/// Returns the source text covered by `node`.
///
/// Yields an empty string when the node's end offset lies beyond `source` or when the
/// covered bytes are not valid UTF-8.
fn get_node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    let span = node.start_byte()..node.end_byte();
    if span.end > source.len() {
        return "";
    }
    match std::str::from_utf8(&source[span]) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Scans the direct statements of one Python test function and extracts specs from them.
///
/// Only `assert` statements, `with` statements and `assert` nodes nested directly inside an
/// expression statement are inspected. A function without a body is a no-op.
///
/// # Errors
/// Fails if `check_ast_depth` rejects `depth`, or if one of the extractors fails.
fn process_test_function(
    func: Node,
    test_func_name: &str,
    source: &[u8],
    specs: &mut HashMap<String, FunctionSpecs>,
    depth: usize,
) -> ContractsResult<()> {
    check_ast_depth(depth, &PathBuf::from("<test>"))?;

    if let Some(body) = func.child_by_field_name("body") {
        let mut walker = body.walk();
        for statement in body.children(&mut walker) {
            let kind = statement.kind();
            if kind == "assert_statement" {
                extract_from_assert(statement, test_func_name, source, specs)?;
            } else if kind == "with_statement" {
                extract_from_with(statement, test_func_name, source, specs)?;
            } else if kind == "expression_statement" {
                let mut nested = statement.walk();
                let asserts = statement
                    .children(&mut nested)
                    .filter(|inner| inner.kind() == "assert_statement");
                for inner in asserts {
                    extract_from_assert(inner, test_func_name, source, specs)?;
                }
            }
        }
    }
    Ok(())
}
/// Extracts a spec from a single `assert` statement, if it matches a known shape.
///
/// The asserted expression is the first child that is not the `assert` keyword. It is tried
/// as an `isinstance(...)` check first and as a comparison second; an assertion that matches
/// neither is ignored. This function always returns `Ok(())`.
fn extract_from_assert(
    assert_stmt: Node,
    test_func_name: &str,
    source: &[u8],
    specs: &mut HashMap<String, FunctionSpecs>,
) -> ContractsResult<()> {
    let line = assert_stmt.start_position().row as u32 + 1;

    let mut cursor = assert_stmt.walk();
    let condition = assert_stmt
        .children(&mut cursor)
        .find(|child| child.kind() != "assert");

    if let Some(condition) = condition {
        // Short-circuits: the comparison extractor only runs if isinstance did not match.
        let _ = try_extract_isinstance_spec(condition, test_func_name, line, source, specs)
            || try_extract_comparison_spec(condition, test_func_name, line, source, specs);
    }
    Ok(())
}
/// Recognises `isinstance(f(...), T)` and records a "type" property spec for `f`.
///
/// Returns `true` only when a spec was recorded. The first argument must be a call that
/// `extract_call_info` accepts; the second argument's source text is used verbatim as the
/// type in the constraint `isinstance(result, T)`.
fn try_extract_isinstance_spec(
    expr: Node,
    test_func_name: &str,
    line: u32,
    source: &[u8],
    specs: &mut HashMap<String, FunctionSpecs>,
) -> bool {
    if expr.kind() != "call" {
        return false;
    }
    let calls_isinstance = expr
        .child_by_field_name("function")
        .map_or(false, |callee| get_node_text(callee, source) == "isinstance");
    if !calls_isinstance {
        return false;
    }
    let args = match expr.child_by_field_name("arguments") {
        Some(list) => list,
        None => return false,
    };

    // Take the first two arguments, ignoring parentheses and commas.
    let mut cursor = args.walk();
    let mut positional = args
        .children(&mut cursor)
        .filter(|n| !matches!(n.kind(), "(" | ")" | ","));
    let (subject, expected_type) = match (positional.next(), positional.next()) {
        (Some(first), Some(second)) => (first, second),
        _ => return false,
    };

    if subject.kind() != "call" {
        return false;
    }
    let fname = match extract_call_info(subject, source) {
        Some((name, _)) => name,
        None => return false,
    };

    let property = PropertySpec {
        function: fname.clone(),
        property_type: "type".to_string(),
        constraint: format!("isinstance(result, {})", get_node_text(expected_type, source)),
        test_function: test_func_name.to_string(),
        line,
        confidence: Confidence::High,
    };
    specs
        .entry(fname.clone())
        .or_insert_with(move || FunctionSpecs {
            function_name: fname,
            summary: String::new(),
            test_count: 0,
            input_output_specs: vec![],
            exception_specs: vec![],
            property_specs: vec![],
        })
        .property_specs
        .push(property);
    true
}
fn try_extract_comparison_spec(
expr: Node,
test_func_name: &str,
line: u32,
source: &[u8],
specs: &mut HashMap<String, FunctionSpecs>,
) -> bool {
if expr.kind() != "comparison_operator" {
return false;
}
let mut cursor = expr.walk();
let mut left = None;
let mut op: Option<&str> = None;
let mut right = None;
for child in expr.children(&mut cursor) {
let kind = child.kind();
match kind {
"==" | "!=" | "<" | ">" | "<=" | ">=" => {
op = Some(kind);
}
"in" | "not in" => {
op = Some(kind);
}
"is" | "is not" => {
op = Some(kind);
}
_ => {
if left.is_none() {
left = Some(child);
} else if right.is_none() {
right = Some(child);
}
}
}
}
let (left, op, right) = match (left, op, right) {
(Some(l), Some(o), Some(r)) => (l, o, r),
_ => return false,
};
if op == "in" && right.kind() == "call" {
if let Some((fname, _)) = extract_call_info(right, source) {
let key_text = get_node_text(left, source);
let constraint = format!("{} in result", key_text);
let fs = specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.property_specs.push(PropertySpec {
function: fname,
property_type: "membership".to_string(),
constraint,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
return true;
}
}
if op == "==" {
if left.kind() == "call" {
let left_func = left
.child_by_field_name("function")
.map(|f| get_node_text(f, source));
if left_func == Some("len") {
if let Some(inner_args) = left.child_by_field_name("arguments") {
let mut inner_cursor = inner_args.walk();
for child in inner_args.children(&mut inner_cursor) {
if child.kind() == "call" {
if let Some((fname, _)) = extract_call_info(child, source) {
let len_val = get_node_text(right, source);
let constraint = format!("len(result) == {}", len_val);
let fs =
specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.property_specs.push(PropertySpec {
function: fname,
property_type: "length".to_string(),
constraint,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::High,
});
return true;
}
}
}
}
}
}
if left.kind() == "call" {
if let Some((fname, inputs)) = extract_call_info(left, source) {
let output = try_eval_literal(right, source);
let fs = specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.input_output_specs.push(InputOutputSpec {
function: fname,
inputs,
output,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::High,
});
return true;
}
}
if right.kind() == "call" {
if let Some((fname, inputs)) = extract_call_info(right, source) {
let output = try_eval_literal(left, source);
let fs = specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.input_output_specs.push(InputOutputSpec {
function: fname,
inputs,
output,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::High,
});
return true;
}
}
}
if matches!(op, "<" | ">" | "<=" | ">=") {
let (call_side, value_side) = if left.kind() == "call" {
(left, right)
} else if right.kind() == "call" {
(right, left)
} else {
return false;
};
let call_func_name = call_side
.child_by_field_name("function")
.map(|f| get_node_text(f, source));
if call_func_name == Some("len") {
if let Some(inner_args) = call_side.child_by_field_name("arguments") {
let mut inner_cursor = inner_args.walk();
for child in inner_args.children(&mut inner_cursor) {
if child.kind() == "call" {
if let Some((fname, _)) = extract_call_info(child, source) {
let val = get_node_text(value_side, source);
let constraint = format!("len(result) {} {}", op, val);
let fs = specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.property_specs.push(PropertySpec {
function: fname,
property_type: "length".to_string(),
constraint,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
return true;
}
}
}
}
}
if let Some((fname, _)) = extract_call_info(call_side, source) {
let val = get_node_text(value_side, source);
let constraint = format!("result {} {}", op, val);
let fs = specs.entry(fname.clone()).or_insert_with(|| FunctionSpecs {
function_name: fname.clone(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
});
fs.property_specs.push(PropertySpec {
function: fname,
property_type: "bounds".to_string(),
constraint,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
return true;
}
}
false
}
fn extract_from_with(
with_stmt: Node,
test_func_name: &str,
source: &[u8],
specs: &mut HashMap<String, FunctionSpecs>,
) -> ContractsResult<()> {
let line = with_stmt.start_position().row as u32 + 1;
let mut cursor = with_stmt.walk();
let mut is_raises = false;
let mut exception_type = String::new();
let mut match_pattern: Option<String> = None;
for child in with_stmt.children(&mut cursor) {
if child.kind() == "with_clause" {
let mut clause_cursor = child.walk();
for clause_child in child.children(&mut clause_cursor) {
if clause_child.kind() == "with_item" {
if let Some(ctx_expr) = clause_child.child(0) {
if ctx_expr.kind() == "call" {
let func_text = ctx_expr
.child_by_field_name("function")
.map(|f| get_node_text(f, source))
.unwrap_or("");
if func_text == "raises" || func_text.ends_with(".raises") {
is_raises = true;
if let Some(args) = ctx_expr.child_by_field_name("arguments") {
let mut arg_cursor = args.walk();
for arg in args.children(&mut arg_cursor) {
let kind = arg.kind();
if kind == "(" || kind == ")" || kind == "," {
continue;
}
if kind == "keyword_argument" {
if let Some(key) = arg.child_by_field_name("name") {
if get_node_text(key, source) == "match" {
if let Some(val) =
arg.child_by_field_name("value")
{
let val_text = get_node_text(val, source);
match_pattern = Some(
val_text
.trim_matches('"')
.trim_matches('\'')
.to_string(),
);
}
}
}
} else if exception_type.is_empty() {
exception_type = get_node_text(arg, source).to_string();
}
}
}
}
}
}
}
}
}
}
if !is_raises || exception_type.is_empty() {
return Ok(());
}
let body = match with_stmt.child_by_field_name("body") {
Some(b) => b,
None => return Ok(()),
};
find_calls_and_add_exception_specs(
body,
source,
specs,
&exception_type,
&match_pattern,
test_func_name,
line,
);
Ok(())
}
/// Walks `block` recursively and records an exception spec for every call found below it.
///
/// Each `call` node that `extract_call_info` accepts gets a high-confidence `ExceptionSpec`
/// carrying the given exception type, optional match pattern, test name and line. Children
/// of every node, including the arguments of a recorded call, are visited as well.
fn find_calls_and_add_exception_specs(
    block: Node,
    source: &[u8],
    specs: &mut HashMap<String, FunctionSpecs>,
    exception_type: &str,
    match_pattern: &Option<String>,
    test_func_name: &str,
    line: u32,
) {
    let mut cursor = block.walk();
    for node in block.children(&mut cursor) {
        let call_info = if node.kind() == "call" {
            extract_call_info(node, source)
        } else {
            None
        };

        if let Some((callee, inputs)) = call_info {
            let record = ExceptionSpec {
                function: callee.clone(),
                inputs,
                exception_type: exception_type.to_string(),
                match_pattern: match_pattern.clone(),
                test_function: test_func_name.to_string(),
                line,
                confidence: Confidence::High,
            };
            specs
                .entry(callee.clone())
                .or_insert_with(move || FunctionSpecs {
                    function_name: callee,
                    summary: String::new(),
                    test_count: 0,
                    input_output_specs: vec![],
                    exception_specs: vec![],
                    property_specs: vec![],
                })
                .exception_specs
                .push(record);
        }

        if node.child_count() == 0 {
            continue;
        }
        find_calls_and_add_exception_specs(
            node,
            source,
            specs,
            exception_type,
            match_pattern,
            test_func_name,
            line,
        );
    }
}
/// Extracts the callee name and positional arguments of a Python call node.
///
/// The name is the identifier itself for plain calls and the final attribute for attribute
/// calls (`obj.method(...)` yields `method`); any other callee shape returns `None`. Calls to
/// the builtins listed in `IGNORED_CALLEES` also return `None`. Keyword arguments are left
/// out; every remaining argument is converted with `try_eval_literal`.
fn extract_call_info(call: Node, source: &[u8]) -> Option<(String, Vec<serde_json::Value>)> {
    const IGNORED_CALLEES: &[&str] = &[
        "len",
        "str",
        "int",
        "float",
        "bool",
        "list",
        "dict",
        "set",
        "tuple",
        "isinstance",
        "hasattr",
        "getattr",
        "print",
        "range",
        "type",
    ];

    let callee = call.child_by_field_name("function")?;
    let name_node = match callee.kind() {
        "identifier" => callee,
        "attribute" => callee.child_by_field_name("attribute")?,
        _ => return None,
    };
    let func_name = get_node_text(name_node, source).to_string();
    if IGNORED_CALLEES.contains(&func_name.as_str()) {
        return None;
    }

    let args_node = call.child_by_field_name("arguments")?;
    let mut cursor = args_node.walk();
    let inputs: Vec<serde_json::Value> = args_node
        .children(&mut cursor)
        .filter(|arg| !matches!(arg.kind(), "(" | ")" | "," | "keyword_argument"))
        .map(|arg| try_eval_literal(arg, source))
        .collect();

    Some((func_name, inputs))
}
/// Converts a Python expression node into a JSON value, starting at nesting depth zero.
///
/// See `try_eval_literal_inner` for the conversion rules.
fn try_eval_literal(node: Node, source: &[u8]) -> serde_json::Value {
    const ROOT_DEPTH: usize = 0;
    try_eval_literal_inner(node, source, ROOT_DEPTH)
}
fn try_eval_literal_inner(node: Node, source: &[u8], depth: usize) -> serde_json::Value {
if depth > MAX_LITERAL_DEPTH {
return serde_json::Value::String(get_node_text(node, source).to_string());
}
let text = get_node_text(node, source);
if text.len() > MAX_LITERAL_SIZE {
return serde_json::Value::String("<large literal>".to_string());
}
match node.kind() {
"integer" => text
.parse::<i64>()
.map(serde_json::Value::from)
.unwrap_or_else(|_| serde_json::Value::String(text.to_string())),
"float" => text
.parse::<f64>()
.map(|f| serde_json::json!(f))
.unwrap_or_else(|_| serde_json::Value::String(text.to_string())),
"string" | "concatenated_string" => {
let unquoted = strip_string_quotes(text);
serde_json::Value::String(unquoted)
}
"true" | "True" => serde_json::Value::Bool(true),
"false" | "False" => serde_json::Value::Bool(false),
"none" | "None" => serde_json::Value::Null,
"identifier" => {
match text {
"True" => serde_json::Value::Bool(true),
"False" => serde_json::Value::Bool(false),
"None" => serde_json::Value::Null,
_ => serde_json::Value::String(text.to_string()),
}
}
"list" => {
let mut items = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
let kind = child.kind();
if kind != "[" && kind != "]" && kind != "," {
items.push(try_eval_literal_inner(child, source, depth + 1));
}
}
serde_json::Value::Array(items)
}
"tuple" | "parenthesized_expression" => {
let mut items = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
let kind = child.kind();
if kind != "(" && kind != ")" && kind != "," {
items.push(try_eval_literal_inner(child, source, depth + 1));
}
}
if items.len() == 1 && node.kind() == "parenthesized_expression" {
items.into_iter().next().unwrap_or(serde_json::Value::Null)
} else {
serde_json::Value::Array(items)
}
}
"dictionary" => {
let mut obj = serde_json::Map::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "pair" {
let key_node = child.child_by_field_name("key");
let value_node = child.child_by_field_name("value");
if let (Some(k), Some(v)) = (key_node, value_node) {
let key = match try_eval_literal_inner(k, source, depth + 1) {
serde_json::Value::String(s) => s,
other => other.to_string(),
};
let value = try_eval_literal_inner(v, source, depth + 1);
obj.insert(key, value);
}
}
}
serde_json::Value::Object(obj)
}
"unary_operator" => {
let mut cursor = node.walk();
let mut op = "";
let mut operand = None;
for child in node.children(&mut cursor) {
if child.kind() == "-" {
op = "-";
} else if child.kind() == "+" {
op = "+";
} else {
operand = Some(child);
}
}
if op == "-" {
if let Some(operand) = operand {
let val = try_eval_literal_inner(operand, source, depth + 1);
if let serde_json::Value::Number(n) = val {
if let Some(i) = n.as_i64() {
return serde_json::json!(-i);
}
if let Some(f) = n.as_f64() {
return serde_json::json!(-f);
}
}
}
}
serde_json::Value::String(text.to_string())
}
_ => {
serde_json::Value::String(text.to_string())
}
}
}
/// Removes the string prefix and the surrounding quotes from a Python string literal.
///
/// Any combination of up to two prefix letters from `r`, `b`, `f`, `u` (in either case and
/// either order, e.g. `br`, `Rb`, `fr`) is dropped, but only when it is directly followed
/// by a quote, so text that merely starts with one of those letters is left untouched.
/// Triple quotes are checked before single quotes. Text without a matching pair of quotes
/// is returned trimmed, with a detected prefix removed.
fn strip_string_quotes(s: &str) -> String {
    let s = s.trim();

    // Index of the first quote; everything before it must be a valid prefix to be dropped.
    let first_quote = s.find(|c: char| c == '"' || c == '\'');
    let body = match first_quote {
        Some(idx)
            if idx <= 2
                && s[..idx]
                    .chars()
                    .all(|c| matches!(c.to_ascii_lowercase(), 'r' | 'b' | 'f' | 'u')) =>
        {
            &s[idx..]
        }
        _ => s,
    };

    // Longest delimiters first so `"""x"""` is not mistaken for a single-quoted string.
    for quote in ["\"\"\"", "'''", "\"", "'"] {
        let needed = 2 * quote.len();
        if body.len() >= needed && body.starts_with(quote) && body.ends_with(quote) {
            return body[quote.len()..body.len() - quote.len()].to_string();
        }
    }
    body.to_string()
}
/// Extracts specs from a non-Python test file using the shared parser pool.
///
/// Returns an empty list when the source is blank or cannot be parsed. Otherwise every test
/// function found in the tree is harvested and each resulting entry gets its summary
/// generated.
fn extract_generic_specs(
    path: &Path,
    source: &str,
    language: Language,
) -> Vec<FunctionSpecs> {
    if source.trim().is_empty() {
        return Vec::new();
    }

    let pool = ParserPool::new();
    let tree = match pool.parse(source, language).ok() {
        None => return Vec::new(),
        Some(parsed) => parsed,
    };

    let mut collected: HashMap<String, FunctionSpecs> = HashMap::new();
    walk_for_test_bodies(
        tree.root_node(),
        source.as_bytes(),
        language,
        path,
        &mut collected,
    );

    let mut finished: Vec<FunctionSpecs> = Vec::with_capacity(collected.len());
    for mut entry in collected.into_values() {
        entry.summary = generate_summary(&entry);
        finished.push(entry);
    }
    finished
}
/// Finds every test function below `node` and harvests the assertions inside it.
///
/// Nodes are visited in document (pre-order) sequence. Once a node is recognised as a test
/// function its subtree is handed to `harvest_assertions_in` and not searched for further
/// nested test functions.
fn walk_for_test_bodies(
    node: Node,
    source: &[u8],
    language: Language,
    path: &Path,
    specs: &mut HashMap<String, FunctionSpecs>,
) {
    // Explicit stack; children are pushed reversed so they pop in source order.
    let mut pending: Vec<Node> = vec![node];
    while let Some(current) = pending.pop() {
        if super::test_recognizer::is_test_function_node(&current, source, language) {
            let test_name = test_function_display_name(&current, source);
            harvest_assertions_in(&current, source, language, &test_name, path, specs);
            continue;
        }

        let first_new = pending.len();
        let mut cursor = current.walk();
        pending.extend(current.children(&mut cursor));
        pending[first_new..].reverse();
    }
}
/// Picks a display name for a test function node.
///
/// Uses the node's `name` field when present, otherwise the first direct child whose kind is
/// `identifier`, `simple_identifier` or `name`. Falls back to `"<anonymous>"`.
fn test_function_display_name(node: &Node, source: &[u8]) -> String {
    let mut cursor = node.walk();
    let label = node.child_by_field_name("name").or_else(|| {
        node.children(&mut cursor)
            .find(|child| matches!(child.kind(), "identifier" | "simple_identifier" | "name"))
    });

    match label {
        Some(found) => get_node_text(found, source).to_string(),
        None => "<anonymous>".to_string(),
    }
}
/// Harvests assertion calls from the subtree of one test function.
///
/// Thin entry point over `walk_for_assertion_calls`; `_path` is accepted but not used.
fn harvest_assertions_in(
    test_func: &Node,
    source: &[u8],
    language: Language,
    test_func_name: &str,
    _path: &Path,
    specs: &mut HashMap<String, FunctionSpecs>,
) {
    let subtree_root = *test_func;
    walk_for_assertion_calls(subtree_root, source, language, test_func_name, specs);
}
/// Recursively visits `node` and its descendants, classifying assertion-like constructs.
///
/// For each node, in this order:
/// 1. Go `if_statement` nodes are offered to the `if ... { t.Errorf(...) }` extractor.
/// 2. Java invocation nodes are offered to the MockMvc extractor.
/// 3. Any call-like node is classified by the last segment of its callee name, with a
///    generic suffix (`<...>`) stripped.
///
/// The extractors' boolean results are not used; children are always visited afterwards.
fn walk_for_assertion_calls(
    node: Node,
    source: &[u8],
    language: Language,
    test_func_name: &str,
    specs: &mut HashMap<String, FunctionSpecs>,
) {
    const CALL_KINDS: &[&str] = &[
        "call_expression",
        "invocation_expression",
        "method_invocation",
        "macro_invocation",
        "call",
        "function_call",
        "function_call_statement",
        "member_call_expression",
        "function_call_expression",
        "scoped_call_expression",
        "nullsafe_member_call_expression",
    ];

    let kind = node.kind();

    match language {
        Language::Go if kind == "if_statement" => {
            let _ = try_extract_go_if_t_assertion(&node, source, test_func_name, specs);
        }
        Language::Java if kind == "method_invocation" || kind == "invocation_expression" => {
            let _ = try_extract_java_mockmvc_assertion(&node, source, test_func_name, specs);
        }
        _ => {}
    }

    if CALL_KINDS.contains(&kind) {
        if let Some(callee_text) = generic_callee_name(&node, source) {
            let after_dot = callee_text.rsplit('.').next().unwrap_or(&callee_text);
            let callee_tail = after_dot.split('<').next().unwrap_or(after_dot);
            classify_assertion_call(
                &node,
                source,
                language,
                callee_tail,
                test_func_name,
                specs,
            );
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_for_assertion_calls(child, source, language, test_func_name, specs);
    }
}
/// Recognises the Go idiom `if <cond involving f(...)> { t.Errorf(...) }`.
///
/// The `if` must have both a condition and a consequence, the consequence must contain a
/// call accepted by `is_go_t_failure_method`, and the condition must contain a call whose
/// name is neither a known assertion callee nor one of those `t.*` methods. When all of
/// that holds, a medium-confidence "go_if_assertion" property is recorded for the called
/// function and `true` is returned.
fn try_extract_go_if_t_assertion(
    if_node: &Node,
    source: &[u8],
    test_func_name: &str,
    specs: &mut HashMap<String, FunctionSpecs>,
) -> bool {
    let (cond_node, consequence) = match (
        if_node.child_by_field_name("condition"),
        if_node.child_by_field_name("consequence"),
    ) {
        (Some(cond), Some(body)) => (cond, body),
        _ => return false,
    };

    if !subtree_contains_go_test_failure_call(&consequence, source) {
        return false;
    }

    let fname = match first_callable_inside(cond_node)
        .and_then(|call| generic_extract_call_info(call, source))
    {
        Some((name, _)) => name,
        None => return false,
    };
    if is_known_assertion_callee(&fname) || is_go_t_failure_method(&fname) {
        return false;
    }

    let property = PropertySpec {
        function: fname.clone(),
        property_type: "go_if_assertion".to_string(),
        constraint: "condition guards t.Errorf/t.Fatal".to_string(),
        test_function: test_func_name.to_string(),
        line: if_node.start_position().row as u32 + 1,
        confidence: Confidence::Medium,
    };
    specs
        .entry(fname.clone())
        .or_insert_with(move || FunctionSpecs {
            function_name: fname,
            summary: String::new(),
            test_count: 0,
            input_output_specs: vec![],
            exception_specs: vec![],
            property_specs: vec![],
        })
        .property_specs
        .push(property);
    true
}
/// Reports whether `node` or any descendant is a `call_expression` whose callee's last
/// dot-separated segment is accepted by `is_go_t_failure_method`.
fn subtree_contains_go_test_failure_call(node: &Node, source: &[u8]) -> bool {
    let is_failure_call = node.kind() == "call_expression"
        && node
            .child_by_field_name("function")
            .map_or(false, |callee| {
                let text = get_node_text(callee, source);
                is_go_t_failure_method(text.rsplit('.').next().unwrap_or(text))
            });
    if is_failure_call {
        return true;
    }

    let mut cursor = node.walk();
    let found_below = node
        .children(&mut cursor)
        .any(|child| subtree_contains_go_test_failure_call(&child, source));
    found_below
}
/// Returns `true` when `name` is one of the `testing.T` method names this module treats as
/// marking an assertion body. Matching is exact and case-sensitive.
fn is_go_t_failure_method(name: &str) -> bool {
    const T_METHODS: &[&str] = &[
        "Error", "Errorf", "Fatal", "Fatalf", "Fail", "FailNow", "Log", "Logf", "Skip", "Skipf",
        "Skipped",
    ];
    T_METHODS.contains(&name)
}
/// Recognises Spring MockMvc chains such as `mockMvc.perform(get(...)).andExpect(...)`.
///
/// `call` must be an `andExpect` / `andExpectAll` / `andDo` invocation whose receiver chain
/// leads back to a `perform(...)` call. The function under test is the first call found in
/// `perform`'s first argument, or the placeholder `mockMvcRequest` when none can be
/// extracted. The constraint is derived from the first argument of `call`.
///
/// Returns `true` whenever the chain is recognised. A property that duplicates an existing
/// one (same line, constraint and test) is not added again, and an input/output spec with a
/// null output is added only if this test has not contributed one for that function yet.
fn try_extract_java_mockmvc_assertion(
    call: &Node,
    source: &[u8],
    test_func_name: &str,
    specs: &mut HashMap<String, FunctionSpecs>,
) -> bool {
    const MOCKMVC_VERBS: [&str; 3] = ["andExpect", "andExpectAll", "andDo"];

    let callee = match generic_callee_name(call, source) {
        Some(name) => name,
        None => return false,
    };
    let after_dot = callee.rsplit('.').next().unwrap_or(&callee);
    let tail = after_dot.split('<').next().unwrap_or(&callee).trim();
    if !MOCKMVC_VERBS.contains(&tail) {
        return false;
    }

    let perform_call = match find_mockmvc_perform_call(*call, source) {
        Some(found) => found,
        None => return false,
    };

    // The request builder inside `perform(...)` stands in for the function under test.
    let endpoint = collect_call_args(perform_call)
        .first()
        .copied()
        .and_then(first_callable_inside)
        .and_then(|builder| generic_extract_call_info(builder, source));
    let (fut_name, fut_inputs): (String, Vec<serde_json::Value>) =
        endpoint.unwrap_or_else(|| ("mockMvcRequest".to_string(), Vec::new()));

    let constraint = match collect_call_args(*call).first() {
        Some(matcher) => classify_mockmvc_matcher(*matcher, source),
        None => "expectation".to_string(),
    };
    let line = call.start_position().row as u32 + 1;

    let entry = specs
        .entry(fut_name.clone())
        .or_insert_with(|| FunctionSpecs {
            function_name: fut_name.clone(),
            summary: String::new(),
            test_count: 0,
            input_output_specs: vec![],
            exception_specs: vec![],
            property_specs: vec![],
        });

    let duplicate = entry.property_specs.iter().any(|p| {
        p.line == line && p.constraint == constraint && p.test_function == test_func_name
    });
    if duplicate {
        return true;
    }

    let test_has_io = entry
        .input_output_specs
        .iter()
        .any(|io| io.test_function == test_func_name);
    if !test_has_io {
        entry.input_output_specs.push(InputOutputSpec {
            function: fut_name.clone(),
            inputs: fut_inputs,
            output: serde_json::Value::Null,
            test_function: test_func_name.to_string(),
            line,
            confidence: Confidence::Medium,
        });
    }
    entry.property_specs.push(PropertySpec {
        function: fut_name,
        property_type: "mockmvc_expectation".to_string(),
        constraint,
        test_function: test_func_name.to_string(),
        line,
        confidence: Confidence::Medium,
    });
    true
}
/// Follows the receiver chain of `call` back to the nearest `perform(...)` invocation.
///
/// Each step takes the `object` field (or `expression` when there is no `object`). The walk
/// stops with `None` when a node has no receiver, when the receiver is not an invocation, or
/// after 32 steps.
fn find_mockmvc_perform_call<'a>(call: Node<'a>, source: &[u8]) -> Option<Node<'a>> {
    const MAX_HOPS: usize = 32;

    let mut current = call;
    for _ in 0..MAX_HOPS {
        let receiver = current
            .child_by_field_name("object")
            .or_else(|| current.child_by_field_name("expression"))?;
        if !matches!(
            receiver.kind(),
            "method_invocation" | "invocation_expression"
        ) {
            return None;
        }

        let is_perform = generic_callee_name(&receiver, source)
            .map_or(false, |name| {
                name.rsplit('.').next().unwrap_or(&name) == "perform"
            });
        if is_perform {
            return Some(receiver);
        }
        current = receiver;
    }
    None
}
/// Builds a short label for a MockMvc matcher expression such as `status().isOk()`.
///
/// The label is `head:leaf`, where `head` is the leading identifier of the expression and
/// `leaf` is the identifier in front of the last `(` that has one and that differs from
/// `head`. When only one of the two exists it is used alone; when neither exists, or the
/// text is blank, the label is `"expectation"`.
fn classify_mockmvc_matcher(node: Node, source: &[u8]) -> String {
    let text = std::str::from_utf8(&source[node.start_byte()..node.end_byte()]).unwrap_or("");
    let trimmed = text.trim();
    if trimmed.is_empty() {
        return "expectation".to_string();
    }

    let head: String = trimmed
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();

    // Scan the opening parentheses right to left for one preceded by an identifier.
    let bytes = trimmed.as_bytes();
    let leaf: Option<&str> = trimmed.rmatch_indices('(').find_map(|(paren, _)| {
        let ident_len = bytes[..paren]
            .iter()
            .rev()
            .take_while(|b| b.is_ascii_alphanumeric() || **b == b'_')
            .count();
        if ident_len == 0 {
            return None;
        }
        let candidate = &trimmed[paren - ident_len..paren];
        if candidate != head.as_str() {
            Some(candidate)
        } else {
            None
        }
    });

    match leaf {
        None if head.is_empty() => "expectation".to_string(),
        None => head,
        Some(l) if head.is_empty() => l.to_string(),
        Some(l) => format!("{}:{}", head, l),
    }
}
/// Returns the source text of a call node's callee, across the supported grammars.
///
/// The fields `function`, `method`, `name` and `macro` are tried in that order. If none is
/// present, the first direct child that looks like an identifier or member access is used.
/// Returns `None` when nothing matches.
fn generic_callee_name(call: &Node, source: &[u8]) -> Option<String> {
    const CALLEE_FIELDS: [&str; 4] = ["function", "method", "name", "macro"];

    let by_field = CALLEE_FIELDS
        .iter()
        .find_map(|field| call.child_by_field_name(*field));

    let mut cursor = call.walk();
    let callee = by_field.or_else(|| {
        call.children(&mut cursor).find(|child| {
            matches!(
                child.kind(),
                "identifier"
                    | "simple_identifier"
                    | "field_access"
                    | "member_access_expression"
                    | "navigation_expression"
                    | "scoped_identifier"
            )
        })
    });

    callee.map(|found| get_node_text(found, source).to_string())
}
/// Picks one representative node per top-level argument of a macro invocation.
///
/// The children of the invocation's first `token_tree` child are split into
/// groups at commas seen at delimiter depth 1; the outermost delimiter pair
/// itself is dropped. For every non-empty group the representative is the
/// first `identifier` / `scoped_identifier` node, or, failing that, the first
/// node that is not a bracket or comma. Groups with no such node are omitted.
///
/// Returns an empty vector when `call` has no `token_tree` child.
fn collect_rust_macro_args<'a>(call: Node<'a>) -> Vec<Node<'a>> {
    let mut call_cursor = call.walk();
    let first_token_tree = call
        .children(&mut call_cursor)
        .find(|child| child.kind() == "token_tree");
    let token_tree = match first_token_tree {
        Some(tree) => tree,
        None => return Vec::new(),
    };

    // `groups` always holds at least one (possibly empty) group.
    let mut groups: Vec<Vec<Node<'a>>> = vec![Vec::new()];
    let mut nesting = 0i32;
    let mut tree_cursor = token_tree.walk();
    for token in token_tree.children(&mut tree_cursor) {
        let kind = token.kind();
        if matches!(kind, "(" | "[" | "{") {
            nesting += 1;
            if nesting == 1 {
                // Opening delimiter of the macro argument list itself.
                continue;
            }
        } else if matches!(kind, ")" | "]" | "}") {
            nesting -= 1;
            if nesting == 0 {
                // Matching closing delimiter of the argument list.
                continue;
            }
        } else if kind == "," && nesting == 1 {
            groups.push(Vec::new());
            continue;
        }
        if let Some(current) = groups.last_mut() {
            current.push(token);
        }
    }

    groups
        .into_iter()
        .filter_map(|group| {
            group
                .iter()
                .copied()
                .find(|n| matches!(n.kind(), "identifier" | "scoped_identifier"))
                .or_else(|| {
                    group.iter().copied().find(|n| {
                        !matches!(n.kind(), "(" | ")" | "[" | "]" | "{" | "}" | ",")
                    })
                })
        })
        .collect()
}
/// Collects the argument nodes of a call-like node.
///
/// The argument container is, in order of preference: the `arguments` field,
/// the first child whose kind is one of the known argument-list kinds, or the
/// call node itself. Punctuation children are skipped. When the call node
/// itself is scanned, `identifier` children are skipped as well. For
/// `argument` / `value_argument` wrappers only the first child that is neither
/// `:` nor of kind `name` is kept.
fn collect_call_args<'a>(call: Node<'a>) -> Vec<Node<'a>> {
    let container = match call.child_by_field_name("arguments") {
        Some(list) => list,
        None => {
            let mut walker = call.walk();
            let by_kind = call.children(&mut walker).find(|child| {
                matches!(
                    child.kind(),
                    "argument_list"
                        | "value_arguments"
                        | "arguments"
                        | "argument_list_no_paren"
                        | "token_tree"
                )
            });
            by_kind.unwrap_or(call)
        }
    };
    let scanning_call_itself = container.id() == call.id();

    let mut collected: Vec<Node<'a>> = Vec::new();
    let mut walker = container.walk();
    for item in container.children(&mut walker) {
        match item.kind() {
            "(" | ")" | "," | ":" | "{" | "}" | "[" | "]" => {}
            "identifier" if scanning_call_itself => {}
            "argument" | "value_argument" => {
                let mut inner = item.walk();
                let value = item
                    .children(&mut inner)
                    .find(|part| !matches!(part.kind(), ":" | "name"));
                // Pushes nothing when the wrapper has no usable child.
                collected.extend(value);
            }
            _ => collected.push(item),
        }
    }
    collected
}
fn classify_assertion_call(
call: &Node,
source: &[u8],
language: Language,
callee_tail: &str,
test_func_name: &str,
specs: &mut HashMap<String, FunctionSpecs>,
) {
let line = call.start_position().row as u32 + 1;
fn ensure<'a>(
specs: &'a mut HashMap<String, FunctionSpecs>,
name: &str,
) -> &'a mut FunctionSpecs {
specs.entry(name.to_string()).or_insert_with(|| FunctionSpecs {
function_name: name.to_string(),
summary: String::new(),
test_count: 0,
input_output_specs: vec![],
exception_specs: vec![],
property_specs: vec![],
})
}
let is_equality = matches!(
callee_tail,
"assertEquals"
| "assertEqual"
| "assertSame"
| "AreEqual"
| "AreSame"
| "Equal"
| "assert_eq"
| "assert_equal"
| "should_eq"
| "shouldBe"
| "shouldEqual"
);
let is_inequality = matches!(
callee_tail,
"assertNotEquals"
| "assertNotEqual"
| "AreNotEqual"
| "NotEqual"
| "assert_ne"
| "assertNotSame"
);
let is_true = matches!(
callee_tail,
"assertTrue" | "IsTrue" | "True" | "assert" | "assert_true"
);
let is_false = matches!(
callee_tail,
"assertFalse" | "IsFalse" | "False" | "assert_false"
);
let is_not_null = matches!(
callee_tail,
"assertNotNull" | "IsNotNull" | "NotNull" | "assert_some"
);
let is_null = matches!(
callee_tail,
"assertNull" | "IsNull" | "Null" | "assert_none"
);
let is_throws = matches!(
callee_tail,
"assertThrows"
| "assertFails"
| "Throws"
| "ThrowsAsync"
| "Throws_"
| "should_panic"
| "expectThrows"
);
let args = if matches!(language, Language::Rust) && call.kind() == "macro_invocation" {
collect_rust_macro_args(*call)
} else {
collect_call_args(*call)
};
if args.is_empty() {
return;
}
if is_equality && args.len() >= 2 {
let rust_macro = matches!(language, Language::Rust)
&& call.kind() == "macro_invocation";
let (call_arg, value_arg) = if rust_macro {
let lhs_callish = is_rust_macro_call_token(args[0], source);
let rhs_callish = is_rust_macro_call_token(args[1], source);
match (rhs_callish, lhs_callish) {
(true, _) => (args[1], args[0]),
(false, true) => (args[0], args[1]),
_ => return,
}
} else {
match (looks_like_call(args[1]), looks_like_call(args[0])) {
(true, _) => (args[1], args[0]),
(false, true) => (args[0], args[1]),
_ => return,
}
};
if let Some((fname, inputs)) =
extract_call_info_for_lang(call_arg, source, language)
{
let output = try_eval_literal(value_arg, source);
let fs = ensure(specs, &fname);
fs.input_output_specs.push(InputOutputSpec {
function: fname,
inputs,
output,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::High,
});
}
return;
}
if is_inequality && args.len() >= 2 {
let rust_macro = matches!(language, Language::Rust)
&& call.kind() == "macro_invocation";
let call_arg = if rust_macro {
if is_rust_macro_call_token(args[1], source) {
args[1]
} else if is_rust_macro_call_token(args[0], source) {
args[0]
} else {
return;
}
} else if looks_like_call(args[1]) {
args[1]
} else if looks_like_call(args[0]) {
args[0]
} else {
return;
};
let other = if call_arg.id() == args[1].id() {
args[0]
} else {
args[1]
};
if let Some((fname, _inputs)) = extract_call_info_for_lang(call_arg, source, language) {
let val = std::str::from_utf8(
&source[other.start_byte()..other.end_byte()],
)
.unwrap_or("");
let constraint = format!("result != {}", val);
let fs = ensure(specs, &fname);
fs.property_specs.push(PropertySpec {
function: fname,
property_type: "inequality".to_string(),
constraint,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
}
return;
}
if (is_true || is_false) && !args.is_empty() {
let mut call_arg = first_callable_inside(args[0]);
if call_arg.is_none()
&& matches!(language, Language::Rust)
&& call.kind() == "macro_invocation"
{
call_arg = find_rust_macro_inline_call(*call, source);
}
if let Some(c) = call_arg {
let fname_inputs = if matches!(language, Language::Rust)
&& call.kind() == "macro_invocation"
&& !looks_like_call(c)
{
Some((
std::str::from_utf8(&source[c.start_byte()..c.end_byte()])
.unwrap_or("")
.trim()
.rsplit('.')
.next()
.unwrap_or("")
.rsplit("::")
.next()
.unwrap_or("")
.to_string(),
Vec::<serde_json::Value>::new(),
))
.filter(|(n, _)| !n.is_empty())
} else {
generic_extract_call_info(c, source)
};
if let Some((fname, _)) = fname_inputs {
let fs = ensure(specs, &fname);
fs.property_specs.push(PropertySpec {
function: fname,
property_type: if is_true {
"truthy".to_string()
} else {
"falsy".to_string()
},
constraint: if is_true {
"result is true".to_string()
} else {
"result is false".to_string()
},
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
}
}
return;
}
if (is_null || is_not_null) && !args.is_empty() {
let call_arg = first_callable_inside(args[0]);
if let Some(c) = call_arg {
if let Some((fname, _)) = generic_extract_call_info(c, source) {
let fs = ensure(specs, &fname);
fs.property_specs.push(PropertySpec {
function: fname,
property_type: if is_not_null {
"not_null".to_string()
} else {
"null".to_string()
},
constraint: if is_not_null {
"result != null".to_string()
} else {
"result == null".to_string()
},
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
}
}
return;
}
if is_throws {
for arg in &args {
if let Some(c) = first_callable_inside(*arg) {
if let Some((fname, inputs)) = generic_extract_call_info(c, source) {
let exc = guess_exception_type(call, source);
let fs = ensure(specs, &fname);
fs.exception_specs.push(ExceptionSpec {
function: fname,
exception_type: exc,
match_pattern: None,
inputs,
test_function: test_func_name.to_string(),
line,
confidence: Confidence::Medium,
});
return;
}
}
}
}
}
/// Extracts the unqualified function name and literal-evaluated arguments of
/// a call-like node.
///
/// The callee text is cut at the first `<`, then reduced to what follows the
/// last `.` and the last `::`, and trimmed. Returns `None` when no callee is
/// found, when the resulting name is empty, or when the name is a known
/// assertion helper.
fn generic_extract_call_info(
    call: Node,
    source: &[u8],
) -> Option<(String, Vec<serde_json::Value>)> {
    let callee = generic_callee_name(&call, source)?;

    let without_generics = match callee.find('<') {
        Some(lt) => &callee[..lt],
        None => callee.as_str(),
    };
    let after_dot = match without_generics.rfind('.') {
        Some(dot) => &without_generics[dot + 1..],
        None => without_generics,
    };
    let after_path = match after_dot.rfind("::") {
        Some(sep) => &after_dot[sep + 2..],
        None => after_dot,
    };
    let func_name = after_path.trim();

    if func_name.is_empty() || is_known_assertion_callee(func_name) {
        return None;
    }

    let inputs: Vec<serde_json::Value> = collect_call_args(call)
        .into_iter()
        .map(|arg| try_eval_literal(arg, source))
        .collect();
    Some((func_name.to_string(), inputs))
}
/// Returns `true` when `name` is the unqualified name of an assertion helper.
///
/// The list mirrors every callee name that `classify_assertion_call`
/// recognizes (equality, inequality, boolean, nullness and throws families),
/// so an assertion helper is never mistaken for the function under test.
/// Matching is exact and case-sensitive.
fn is_known_assertion_callee(name: &str) -> bool {
    matches!(
        name,
        // Equality
        "assertEquals"
            | "assertEqual"
            | "assertSame"
            | "AreEqual"
            | "AreSame"
            | "Equal"
            | "assert_eq"
            | "assert_equal"
            | "should_eq"
            | "shouldBe"
            | "shouldEqual"
            // Inequality
            | "assertNotEquals"
            | "assertNotEqual"
            | "assertNotSame"
            | "AreNotEqual"
            | "NotEqual"
            | "assert_ne"
            // Boolean
            | "assertTrue"
            | "IsTrue"
            | "True"
            | "assert"
            | "assert_true"
            | "assertFalse"
            | "IsFalse"
            | "False"
            | "assert_false"
            // Nullness
            | "assertNotNull"
            | "IsNotNull"
            | "NotNull"
            | "assert_some"
            | "assertNull"
            | "IsNull"
            | "Null"
            | "assert_none"
            // Throws
            | "assertThrows"
            | "assertFails"
            | "Throws"
            | "ThrowsAsync"
            | "Throws_"
            | "should_panic"
            | "expectThrows"
    )
}
fn is_rust_macro_call_token(n: Node, source: &[u8]) -> bool {
if !matches!(n.kind(), "identifier" | "scoped_identifier") {
return false;
}
let end = n.end_byte();
let mut i = end;
while i < source.len() {
let b = source[i];
if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
i += 1;
continue;
}
return b == b'(';
}
false
}
/// Language-aware wrapper around `generic_extract_call_info`.
///
/// For Rust, an `identifier` / `scoped_identifier` node is taken as the
/// callee itself: its trimmed source text becomes the function name and the
/// input list is empty. Returns `None` when that text is not valid UTF-8 or
/// is empty. Every other node/language combination is delegated to
/// `generic_extract_call_info`.
fn extract_call_info_for_lang(
    node: Node,
    source: &[u8],
    language: Language,
) -> Option<(String, Vec<serde_json::Value>)> {
    let bare_rust_callee = matches!(language, Language::Rust)
        && matches!(node.kind(), "identifier" | "scoped_identifier");
    if !bare_rust_callee {
        return generic_extract_call_info(node, source);
    }

    let text = std::str::from_utf8(&source[node.start_byte()..node.end_byte()]).ok()?;
    match text.trim() {
        "" => None,
        name => Some((name.to_string(), Vec::new())),
    }
}
/// Returns `true` when the node's kind is one of the call-expression kinds
/// listed in `CALL_KINDS`.
fn looks_like_call(n: Node) -> bool {
    const CALL_KINDS: [&str; 11] = [
        "call_expression",
        "invocation_expression",
        "method_invocation",
        "call",
        "function_call",
        "function_call_statement",
        "macro_invocation",
        "member_call_expression",
        "function_call_expression",
        "scoped_call_expression",
        "nullsafe_member_call_expression",
    ];
    CALL_KINDS.contains(&n.kind())
}
/// Returns the first call-like node in a pre-order walk of the subtree rooted
/// at `n` (the root itself included), or `None` when the subtree has none.
fn first_callable_inside(n: Node) -> Option<Node> {
    let mut pending = vec![n];
    while let Some(candidate) = pending.pop() {
        if looks_like_call(candidate) {
            return Some(candidate);
        }
        let mut walker = candidate.walk();
        let children: Vec<_> = candidate.children(&mut walker).collect();
        // Push in reverse so the left-most child is popped (visited) first.
        pending.extend(children.into_iter().rev());
    }
    None
}
/// Searches the subtree of `root` for an `identifier`, `scoped_identifier` or
/// `field_expression` node that is directly followed by `(` in `source`.
///
/// The walk uses an explicit stack onto which children are pushed in order,
/// so later siblings are examined before earlier ones. Returns the first
/// match encountered in that order, or `None`.
fn find_rust_macro_inline_call<'a>(root: Node<'a>, source: &[u8]) -> Option<Node<'a>> {
    let mut pending = vec![root];
    while let Some(current) = pending.pop() {
        let callee_like = matches!(
            current.kind(),
            "identifier" | "scoped_identifier" | "field_expression"
        );
        if callee_like && is_followed_by_open_paren(current, source) {
            return Some(current);
        }
        let mut walker = current.walk();
        pending.extend(current.children(&mut walker));
    }
    None
}
/// Returns `true` when the first byte after `node` that is not a space, tab,
/// `\n` or `\r` is `(`. Returns `false` when only whitespace (or nothing)
/// follows the node.
fn is_followed_by_open_paren(node: Node, source: &[u8]) -> bool {
    let next_significant = source.get(node.end_byte()..).and_then(|rest| {
        rest.iter()
            .find(|b| !matches!(**b, b' ' | b'\t' | b'\n' | b'\r'))
    });
    next_significant == Some(&b'(')
}
/// Guesses the exception type expected by a "throws" assertion call.
///
/// Only the source text of `call` is inspected; see
/// [`exception_type_from_call_text`] for the recognized shapes. Falls back to
/// `"Exception"` when the text is not valid UTF-8 or no type can be read.
fn guess_exception_type(call: &Node, source: &[u8]) -> String {
    let text = std::str::from_utf8(&source[call.start_byte()..call.end_byte()]).unwrap_or("");
    exception_type_from_call_text(text)
}

/// Reads an exception type name out of the text of an assertion call.
///
/// Recognized shapes, in order:
/// 1. a generic argument in the callee, i.e. `<Type>` appearing before the
///    first `(` or `{` (e.g. `Assert.Throws<ArgumentException>(...)`);
/// 2. a class literal as the first parenthesized argument, i.e.
///    `name(Type.class, ...)` (e.g. JUnit's `assertThrows`).
///
/// Returns `"Exception"` when neither shape matches.
fn exception_type_from_call_text(text: &str) -> String {
    // The callee part ends at the first '(' or '{'.
    let opener = text.find(|c: char| c == '(' || c == '{');
    let head = match opener {
        Some(idx) => &text[..idx],
        None => text,
    };

    if let Some(lt) = head.find('<') {
        if let Some(len) = head[lt + 1..].find('>') {
            return head[lt + 1..lt + 1 + len].trim().to_string();
        }
    }

    if let Some(open) = opener {
        if text[open..].starts_with('(') {
            // Look only at the first argument, so a `.class` literal inside a
            // later argument (such as the lambda body) is never picked up.
            let rest = &text[open + 1..];
            let first_arg_end = rest
                .find(|c: char| c == ',' || c == ')')
                .unwrap_or(rest.len());
            if let Some(class_name) = rest[..first_arg_end].trim().strip_suffix(".class") {
                let class_name = class_name.trim();
                if !class_name.is_empty() {
                    return class_name.to_string();
                }
            }
        }
    }

    "Exception".to_string()
}
/// Folds `new_specs` into `all_specs`, keyed by function name.
///
/// A missing entry is created empty (blank summary, zero test count); the
/// incoming spec lists are then appended and the incoming `test_count` is
/// added. The incoming `summary` is not carried over.
fn merge_specs(all_specs: &mut HashMap<String, FunctionSpecs>, new_specs: Vec<FunctionSpecs>) {
    for incoming in new_specs {
        let FunctionSpecs {
            function_name,
            summary: _,
            test_count,
            input_output_specs,
            exception_specs,
            property_specs,
        } = incoming;

        let merged = all_specs
            .entry(function_name.clone())
            .or_insert_with(|| FunctionSpecs {
                function_name,
                summary: String::new(),
                test_count: 0,
                input_output_specs: vec![],
                exception_specs: vec![],
                property_specs: vec![],
            });
        merged.input_output_specs.extend(input_output_specs);
        merged.exception_specs.extend(exception_specs);
        merged.property_specs.extend(property_specs);
        merged.test_count += test_count;
    }
}
/// Builds a one-line summary of how many specs of each kind `fs` holds.
///
/// Kinds with a zero count are left out; the remaining parts are joined with
/// `", "` in the order input/output, raises, property. Returns `"no specs"`
/// when every list is empty.
fn generate_summary(fs: &FunctionSpecs) -> String {
    let non_zero = |count: usize| if count > 0 { Some(count) } else { None };

    let io_part = non_zero(fs.input_output_specs.len()).map(|n| format!("{} input/output", n));
    let raises_part = non_zero(fs.exception_specs.len()).map(|n| format!("{} raises", n));
    let property_part = non_zero(fs.property_specs.len()).map(|n| format!("{} property", n));

    let parts: Vec<String> = io_part
        .into_iter()
        .chain(raises_part)
        .chain(property_part)
        .collect();

    if parts.is_empty() {
        return "no specs".to_string();
    }
    parts.join(", ")
}
/// Renders a [`SpecsReport`] as plain text.
///
/// For every function the output contains a `Function:` header, one line per
/// input/output spec, one per exception spec (with the match pattern when
/// present), one per property spec, and a trailing blank line. The report
/// ends with a `Total specs:` line taken from the report summary.
pub fn format_specs_text(report: &SpecsReport) -> String {
    let mut text = String::new();

    for func in &report.functions {
        text += &format!("Function: {}\n", func.function_name);

        for io in &func.input_output_specs {
            let rendered_inputs = io
                .inputs
                .iter()
                .map(|value| value.to_string())
                .collect::<Vec<String>>()
                .join(", ");
            text += &format!(
                " IO: {}({}) == {}\n",
                func.function_name, rendered_inputs, io.output
            );
        }

        for exc in &func.exception_specs {
            let line = match &exc.match_pattern {
                Some(pattern) => {
                    format!(" Raises: {} (match='{}')\n", exc.exception_type, pattern)
                }
                None => format!(" Raises: {}\n", exc.exception_type),
            };
            text += &line;
        }

        for prop in &func.property_specs {
            text += &format!(
                " Property ({}): {}\n",
                prop.property_type, prop.constraint
            );
        }

        text.push('\n');
    }

    text += &format!("Total specs: {}\n", report.summary.total_specs);
    text
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Python test module used as the extraction fixture. It exercises
    /// equality asserts, `pytest.raises` (with and without `match=`) and the
    /// type / length / bounds / membership property patterns.
    const PYTHON_TEST_FILE: &str = r#"
import pytest
def test_add_basic():
    assert add(1, 2) == 3
    assert add(0, 0) == 0
    assert add(-1, 1) == 0
def test_add_large():
    assert add(100, 200) == 300
def test_divide_by_zero():
    with pytest.raises(ZeroDivisionError):
        divide(1, 0)
def test_validate_raises_with_match():
    with pytest.raises(ValueError, match="invalid"):
        validate(-1)
def test_result_type():
    # Direct call pattern for type check
    assert isinstance(multiply(2, 3), int)
def test_result_length():
    # Direct call pattern for length check
    assert len(get_items()) == 3
def test_result_bounds():
    # Direct call pattern for bounds check
    assert compute_value() > 0
def test_membership():
    # Direct call pattern for membership check
    assert "key" in get_config()
"#;

    /// Writes the Python fixture to a temporary `test_module.py` and runs
    /// spec extraction on it with the given function filter.
    fn python_fixture_report(function_filter: Option<&str>) -> SpecsReport {
        let temp = TempDir::new().unwrap();
        let test_path = temp.path().join("test_module.py");
        fs::write(&test_path, PYTHON_TEST_FILE).unwrap();
        run_specs(&test_path, function_filter).unwrap()
    }

    /// Returns the specs recorded for `name`, panicking when they are absent.
    fn expect_function<'r>(report: &'r SpecsReport, name: &str) -> &'r FunctionSpecs {
        report
            .functions
            .iter()
            .find(|f| f.function_name == name)
            .unwrap_or_else(|| panic!("Should find '{}' function", name))
    }

    /// Returns the first property spec of the given type, if any.
    fn find_property<'f>(func: &'f FunctionSpecs, property_type: &str) -> Option<&'f PropertySpec> {
        func.property_specs
            .iter()
            .find(|p| p.property_type == property_type)
    }

    #[test]
    fn test_specs_input_output_extraction() {
        let report = python_fixture_report(None);
        let add = expect_function(&report, "add");
        assert!(
            add.input_output_specs.len() >= 3,
            "Should extract at least 3 IO specs for add, got {}",
            add.input_output_specs.len()
        );
    }

    #[test]
    fn test_specs_exception_extraction() {
        let report = python_fixture_report(None);
        let divide = expect_function(&report, "divide");
        assert!(
            !divide.exception_specs.is_empty(),
            "Should extract exception specs for divide"
        );
        assert_eq!(
            divide.exception_specs[0].exception_type,
            "ZeroDivisionError"
        );
    }

    #[test]
    fn test_specs_exception_with_match() {
        let report = python_fixture_report(None);
        let validate = expect_function(&report, "validate");
        assert!(!validate.exception_specs.is_empty());
        assert!(validate.exception_specs[0].match_pattern.is_some());
        assert_eq!(
            validate.exception_specs[0].match_pattern.as_ref().unwrap(),
            "invalid"
        );
    }

    #[test]
    fn test_specs_property_type_extraction() {
        let report = python_fixture_report(None);
        let multiply = expect_function(&report, "multiply");
        let type_prop = find_property(multiply, "type");
        assert!(type_prop.is_some(), "Should extract type property");
        assert!(type_prop.unwrap().constraint.contains("isinstance"));
    }

    #[test]
    fn test_specs_property_length_extraction() {
        let report = python_fixture_report(None);
        let get_items = expect_function(&report, "get_items");
        assert!(
            find_property(get_items, "length").is_some(),
            "Should extract length property"
        );
    }

    #[test]
    fn test_specs_property_bounds_extraction() {
        let report = python_fixture_report(None);
        let compute = expect_function(&report, "compute_value");
        assert!(
            find_property(compute, "bounds").is_some(),
            "Should extract bounds property"
        );
    }

    #[test]
    fn test_specs_property_membership_extraction() {
        let report = python_fixture_report(None);
        let get_config = expect_function(&report, "get_config");
        assert!(
            find_property(get_config, "membership").is_some(),
            "Should extract membership property"
        );
    }

    #[test]
    fn test_specs_function_filter() {
        let report = python_fixture_report(Some("add"));
        assert_eq!(report.functions.len(), 1);
        assert_eq!(report.functions[0].function_name, "add");
    }

    #[test]
    fn test_specs_directory_scan() {
        let temp = TempDir::new().unwrap();
        let test1 = temp.path().join("test_one.py");
        fs::write(&test1, "def test_foo():\n assert foo(1) == 2\n").unwrap();
        let test2 = temp.path().join("test_two.py");
        fs::write(&test2, "def test_bar():\n assert bar(3) == 4\n").unwrap();
        let report = run_specs(temp.path(), None).unwrap();
        assert_eq!(report.summary.test_files_scanned, 2);
        assert!(report.functions.iter().any(|f| f.function_name == "foo"));
        assert!(report.functions.iter().any(|f| f.function_name == "bar"));
    }

    #[test]
    fn test_specs_json_output() {
        let report = python_fixture_report(None);
        let json = serde_json::to_string(&report).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.get("functions").is_some());
        assert!(parsed.get("summary").is_some());
    }

    #[test]
    fn test_specs_text_output() {
        let report = python_fixture_report(None);
        let text = format_specs_text(&report);
        assert!(text.contains("Function:"));
        assert!(text.contains("Total specs:"));
    }

    /// `SpecsArgs::run` must fail with `TestPathNotFound` before doing any
    /// work when `--from-tests` points at a path that does not exist.
    #[test]
    fn test_specs_test_path_not_found() {
        let temp = TempDir::new().unwrap();
        let missing = temp.path().join("no_such_tests");
        assert!(!missing.exists(), "Path should not exist for this test");
        let args = SpecsArgs {
            from_tests: missing,
            output_format: ContractsOutputFormat::Json,
            function: None,
            source: None,
        };
        let err = args
            .run(OutputFormat::Text, true)
            .expect_err("run should fail for a missing test path");
        assert!(
            matches!(
                err.downcast_ref::<ContractsError>(),
                Some(ContractsError::TestPathNotFound { .. })
            ),
            "Expected TestPathNotFound, got: {}",
            err
        );
    }

    #[test]
    fn test_specs_empty_directory() {
        let temp = TempDir::new().unwrap();
        let report = run_specs(temp.path(), None).unwrap();
        assert_eq!(report.summary.test_files_scanned, 0);
        assert_eq!(report.summary.total_specs, 0);
    }

    #[test]
    fn test_specs_summary_counts() {
        let report = python_fixture_report(None);
        assert!(report.summary.total_specs > 0);
        assert!(report.summary.by_type.input_output > 0);
        assert!(report.summary.test_functions_scanned > 0);
        assert_eq!(report.summary.test_files_scanned, 1);
    }

    /// Parses `source` as Python and returns the tree with the source bytes.
    fn parse_and_get_expr(source: &str) -> (Tree, Vec<u8>) {
        let mut parser = Parser::new();
        parser.set_language(&PYTHON_LANGUAGE.into()).unwrap();
        let tree = parser.parse(source, None).unwrap();
        (tree, source.as_bytes().to_vec())
    }

    /// Unwraps an `expression_statement` to its first child; any other node
    /// is returned unchanged.
    fn find_expr_node(node: Node) -> Node {
        if node.kind() == "expression_statement" {
            if let Some(child) = node.child(0) {
                return child;
            }
        }
        node
    }

    /// Parses a single Python expression and evaluates it with `try_eval_literal`.
    fn eval_python_literal(code: &str) -> serde_json::Value {
        let (tree, source) = parse_and_get_expr(code);
        let expr = find_expr_node(tree.root_node().child(0).unwrap());
        try_eval_literal(expr, &source)
    }

    #[test]
    fn test_literal_eval_integers() {
        assert_eq!(eval_python_literal("42"), serde_json::json!(42));
    }

    #[test]
    fn test_literal_eval_negative() {
        assert_eq!(eval_python_literal("-5"), serde_json::json!(-5));
    }

    #[test]
    fn test_literal_eval_string() {
        assert_eq!(eval_python_literal("\"hello\""), serde_json::json!("hello"));
    }

    #[test]
    fn test_literal_eval_list() {
        assert_eq!(eval_python_literal("[1, 2, 3]"), serde_json::json!([1, 2, 3]));
    }
}