use std::collections::HashSet;
use std::path::{Path, PathBuf};
use anyhow::Result;
use clap::Args;
use regex::Regex;
use tree_sitter::{Node, Parser, Tree};
use tldr_core::ast::ParserPool;
use tldr_core::Language;
use crate::output::{OutputFormat, OutputWriter};
use super::error::{ContractsError, ContractsResult};
use super::types::{Condition, ContractsReport, OutputFormat as ContractsOutputFormat};
#[cfg(test)]
use super::types::Confidence;
use super::validation::{
check_ast_depth, read_file_safe, validate_file_path, validate_function_name, MAX_AST_DEPTH,
MAX_CONDITIONS_PER_FUNCTION,
};
#[derive(Debug, Clone)]
/// Per-language lookup table of tree-sitter node kinds, field names and
/// call names used while searching for a function and extracting its
/// contracts. One value is built per supported language by
/// `LanguageConfig::for_language`.
struct LanguageConfig {
/// Node kinds accepted by `is_function` (function/method definitions).
function_kinds: &'static [&'static str],
/// Node kinds accepted by `is_class`; the function search recurses into these.
class_kinds: &'static [&'static str],
/// Node kinds accepted by `is_if` (conditional statements/expressions).
if_kinds: &'static [&'static str],
/// Node kinds accepted by `is_assert` (assertion statements).
assert_kinds: &'static [&'static str],
/// Node kinds accepted by `is_throw`; a conditional containing one of these
/// is treated as a guard clause.
throw_kinds: &'static [&'static str],
/// Node kinds for return statements.
/// NOTE(review): in the visible part of the file this is only read (and
/// discarded) in `run_contracts` — confirm other users.
return_kinds: &'static [&'static str],
/// Node kinds accepted by `is_loop`.
loop_kinds: &'static [&'static str],
/// Node kinds accepted by `is_assignment`.
assignment_kinds: &'static [&'static str],
/// Field name used to look up a function's name node.
func_name_field: &'static str,
/// Field name used to look up a function's body node.
func_body_field: &'static str,
/// Field name of a conditional's condition — presumably used by the guard
/// extraction helpers; verify against their definitions.
if_condition_field: &'static str,
/// Field name of a conditional's "then" branch.
if_consequence_field: &'static str,
/// Field name of a conditional's "else" branch — usage not visible here.
if_alternative_field: &'static str,
/// Field name of a function's parameter list — usage not visible here.
func_params_field: &'static str,
/// Field name of a function's return type; empty for configs without one.
return_type_field: &'static str,
/// Field name used to look up a class body when searching nested functions.
class_body_field: &'static str,
/// Field name of a loop's body — usage not visible here.
loop_body_field: &'static str,
/// Negation token of the language (`"not"` or `"!"` in the built-in configs).
negation_prefix: &'static str,
/// Set only by the Python config — presumably enables `isinstance` handling;
/// TODO confirm against the consumer.
has_isinstance: bool,
/// Node kinds of parameters carrying a type — usage not visible here.
typed_param_kinds: &'static [&'static str],
/// When true, assertions and panics are looked for as macro invocations,
/// including ones wrapped in an `expression_statement`.
assert_is_macro: bool,
/// Callee names that count as assertions (checked by `is_assert_call_name`).
assert_call_names: &'static [&'static str],
/// Callee names that count as error exits (checked by `is_error_call_name`).
error_call_names: &'static [&'static str],
/// Node kinds accepted by `is_call` (call expressions).
call_kinds: &'static [&'static str],
}
impl LanguageConfig {
/// Returns the configuration for `lang`.
///
/// TypeScript and JavaScript share a single configuration; every other
/// language has its own constructor below.
fn for_language(lang: Language) -> Self {
match lang {
Language::Python => Self::python(),
Language::Go => Self::go(),
Language::Rust => Self::rust(),
Language::Java => Self::java(),
Language::TypeScript | Language::JavaScript => Self::typescript(),
Language::C => Self::c(),
Language::Cpp => Self::cpp(),
Language::Ruby => Self::ruby(),
Language::CSharp => Self::csharp(),
Language::Scala => Self::scala(),
Language::Php => Self::php(),
Language::Lua => Self::lua(),
Language::Luau => Self::luau(),
Language::Elixir => Self::elixir(),
Language::Ocaml => Self::ocaml(),
Language::Kotlin => Self::kotlin(),
Language::Swift => Self::swift(),
}
}
/// Configuration for Python. Uses dedicated `assert_statement` and
/// `raise_statement` kinds and sets `has_isinstance`.
fn python() -> Self {
Self {
function_kinds: &["function_definition"],
class_kinds: &["class_definition"],
if_kinds: &["if_statement"],
assert_kinds: &["assert_statement"],
throw_kinds: &["raise_statement"],
return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement"],
assignment_kinds: &["assignment", "expression_statement"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "return_type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "not",
has_isinstance: true,
typed_param_kinds: &["typed_parameter", "typed_default_parameter"],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["call"],
}
}
/// Configuration for Go. There are no assert or class kinds;
/// `return_statement` doubles as the throw kind and `panic` is the only
/// error call name.
fn go() -> Self {
Self {
function_kinds: &["function_declaration", "method_declaration"],
class_kinds: &[],
if_kinds: &["if_statement"],
assert_kinds: &[], throw_kinds: &["return_statement"], return_kinds: &["return_statement"],
loop_kinds: &["for_statement"],
assignment_kinds: &["assignment_statement", "short_var_declaration"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "result",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter_declaration"],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &["panic"],
call_kinds: &["call_expression"],
}
}
/// Configuration for Rust. `macro_invocation` is the assert kind and
/// `assert_is_macro` is set; `return_expression` doubles as the throw kind.
fn rust() -> Self {
Self {
function_kinds: &["function_item"],
class_kinds: &["impl_item", "trait_item"],
if_kinds: &["if_expression"],
assert_kinds: &["macro_invocation"], throw_kinds: &["return_expression"], return_kinds: &["return_expression"],
loop_kinds: &["for_expression", "while_expression", "loop_expression"],
assignment_kinds: &[
"let_declaration",
"assignment_expression",
"expression_statement",
],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "return_type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: true,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["call_expression"],
}
}
/// Configuration for Java. Methods and constructors are function kinds;
/// `assert_statement` and `throw_statement` are used directly.
fn java() -> Self {
Self {
function_kinds: &["method_declaration", "constructor_declaration"],
class_kinds: &["class_declaration", "interface_declaration"],
if_kinds: &["if_statement"],
assert_kinds: &["assert_statement"],
throw_kinds: &["throw_statement"],
return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement", "enhanced_for_statement"],
assignment_kinds: &[
"assignment_expression",
"local_variable_declaration",
"expression_statement",
],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["formal_parameter"],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["method_invocation"],
}
}
/// Configuration shared by TypeScript and JavaScript (see `for_language`).
/// No assert kinds or assert call names are configured.
fn typescript() -> Self {
Self {
function_kinds: &[
"function_declaration",
"method_definition",
"arrow_function",
],
class_kinds: &["class_declaration"],
if_kinds: &["if_statement"],
assert_kinds: &[], throw_kinds: &["throw_statement"],
return_kinds: &["return_statement"],
loop_kinds: &[
"for_statement",
"while_statement",
"for_in_statement",
"for_of_statement",
],
assignment_kinds: &[
"assignment_expression",
"variable_declaration",
"lexical_declaration",
"expression_statement",
],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "return_type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["required_parameter", "optional_parameter"],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["call_expression"],
}
}
/// Configuration for C. The function name and the parameters are both read
/// from the `declarator` field; `assert`, `abort` and `exit` are recognized
/// by call name, and `return_statement` doubles as the throw kind.
fn c() -> Self {
Self {
function_kinds: &["function_definition"],
class_kinds: &[],
if_kinds: &["if_statement"],
assert_kinds: &[], throw_kinds: &["return_statement"],
return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement", "do_statement"],
assignment_kinds: &[
"assignment_expression",
"declaration",
"expression_statement",
],
func_name_field: "declarator",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "declarator",
return_type_field: "type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter_declaration"],
assert_is_macro: false,
assert_call_names: &["assert"],
error_call_names: &["abort", "exit"],
call_kinds: &["call_expression"],
}
}
/// Configuration for C++. Same field layout as the C configuration, with
/// class/struct specifiers as class kinds and `throw_statement` added to the
/// throw kinds.
fn cpp() -> Self {
Self {
function_kinds: &["function_definition"],
class_kinds: &["class_specifier", "struct_specifier"],
if_kinds: &["if_statement"],
assert_kinds: &[],
throw_kinds: &["throw_statement", "return_statement"],
return_kinds: &["return_statement"],
loop_kinds: &[
"for_statement",
"while_statement",
"do_statement",
"for_range_loop",
],
assignment_kinds: &[
"assignment_expression",
"declaration",
"expression_statement",
],
func_name_field: "declarator",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "declarator",
return_type_field: "type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter_declaration"],
assert_is_macro: false,
assert_call_names: &["assert"],
error_call_names: &["abort", "exit"],
call_kinds: &["call_expression"],
}
}
/// Configuration for Ruby. Both `if` and `unless` are conditional kinds,
/// `call` is listed as a throw kind alongside `raise`, and there is no
/// return-type field or typed parameter kind.
fn ruby() -> Self {
Self {
function_kinds: &["method"],
class_kinds: &["class", "module"],
if_kinds: &["if", "unless"],
assert_kinds: &[],
throw_kinds: &["raise", "call"], return_kinds: &["return"],
loop_kinds: &["while", "until", "for"],
assignment_kinds: &["assignment"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &[],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &["raise"],
call_kinds: &["call"],
}
}
/// Configuration for C#. Methods and constructors are function kinds;
/// classes, interfaces and structs are class kinds.
fn csharp() -> Self {
Self {
function_kinds: &["method_declaration", "constructor_declaration"],
class_kinds: &[
"class_declaration",
"interface_declaration",
"struct_declaration",
],
if_kinds: &["if_statement"],
assert_kinds: &[],
throw_kinds: &["throw_statement"],
return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement", "foreach_statement"],
assignment_kinds: &[
"assignment_expression",
"local_declaration_statement",
"expression_statement",
],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["invocation_expression"],
}
}
/// Configuration for Scala. Assertions are recognized through the call
/// names `assert` and `require`.
fn scala() -> Self {
Self {
function_kinds: &["function_definition"],
class_kinds: &["class_definition", "object_definition", "trait_definition"],
if_kinds: &["if_expression"],
assert_kinds: &[], throw_kinds: &["throw_expression"],
return_kinds: &["return_expression"],
loop_kinds: &["while_expression", "for_expression"],
assignment_kinds: &["assignment_expression", "val_definition", "var_definition"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "return_type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: false,
assert_call_names: &["assert", "require"],
error_call_names: &[],
call_kinds: &["call_expression"],
}
}
/// Configuration for PHP. Assertions are recognized through the call name
/// `assert`; calls are `function_call_expression` nodes.
fn php() -> Self {
Self {
function_kinds: &["function_definition", "method_declaration"],
class_kinds: &["class_declaration"],
if_kinds: &["if_statement"],
assert_kinds: &[],
throw_kinds: &["throw_expression"],
return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement", "foreach_statement"],
assignment_kinds: &["assignment_expression", "expression_statement"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "return_type",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["simple_parameter"],
assert_is_macro: false,
assert_call_names: &["assert"],
error_call_names: &[],
call_kinds: &["function_call_expression"],
}
}
/// Configuration for Lua. No class kinds; `assert` and `error` are
/// recognized by call name and `return_statement` doubles as the throw kind.
fn lua() -> Self {
Self {
function_kinds: &["function_declaration", "function_definition_statement"],
class_kinds: &[],
if_kinds: &["if_statement"],
assert_kinds: &[], throw_kinds: &["return_statement"], return_kinds: &["return_statement"],
loop_kinds: &[
"for_statement",
"while_statement",
"repeat_statement",
"for_numeric_statement",
"for_generic_statement",
],
assignment_kinds: &["assignment_statement", "variable_declaration"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "not",
has_isinstance: false,
typed_param_kinds: &[],
assert_is_macro: false,
assert_call_names: &["assert"],
error_call_names: &["error"],
call_kinds: &["function_call"],
}
}
/// Configuration for Elixir. Function, class, conditional, throw and loop
/// kinds all list the generic `call` node kind; the function name is read
/// from the `target` field and `raise` is the only error call name.
fn elixir() -> Self {
Self {
function_kinds: &["call"], class_kinds: &["call"], if_kinds: &["call"], assert_kinds: &[],
throw_kinds: &["call"], return_kinds: &[], loop_kinds: &["call"], assignment_kinds: &["binary_operator"], func_name_field: "target",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "arguments",
return_type_field: "",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &[],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &["raise"],
call_kinds: &["call"],
}
}
/// Configuration for OCaml. Functions are `let_binding`/`value_definition`
/// nodes named through the `pattern` field; there are no return kinds.
fn ocaml() -> Self {
Self {
function_kinds: &["let_binding", "value_definition"],
class_kinds: &["module_definition"],
if_kinds: &["if_expression"],
assert_kinds: &["assert_expression"],
throw_kinds: &["raise_expression"],
return_kinds: &[],
loop_kinds: &["while_expression", "for_expression"],
assignment_kinds: &["let_binding"],
func_name_field: "pattern",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameter",
return_type_field: "",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "not",
has_isinstance: false,
typed_param_kinds: &[],
assert_is_macro: false,
assert_call_names: &[],
error_call_names: &[],
call_kinds: &["application"],
}
}
/// Configuration for Kotlin. Assertions are recognized through the call
/// names `assert`, `require`, `check`, `requireNotNull` and `checkNotNull`.
fn kotlin() -> Self {
Self {
function_kinds: &["function_declaration"],
class_kinds: &["class_declaration", "object_declaration"],
if_kinds: &["if_expression"],
assert_kinds: &[], throw_kinds: &["throw_expression"],
return_kinds: &["return_expression"],
loop_kinds: &["for_statement", "while_statement"],
assignment_kinds: &["assignment", "property_declaration"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "function_value_parameters",
return_type_field: "type",
class_body_field: "class_body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: false,
assert_call_names: &[
"assert",
"require",
"check",
"requireNotNull",
"checkNotNull",
],
error_call_names: &[],
call_kinds: &["call_expression"],
}
}
/// Configuration for Swift. `guard_statement` counts as a conditional kind
/// and `control_transfer_statement` serves as both the throw and the return
/// kind.
fn swift() -> Self {
Self {
function_kinds: &["function_declaration"],
class_kinds: &[
"class_declaration",
"struct_declaration",
"protocol_declaration",
],
if_kinds: &["if_statement", "guard_statement"],
assert_kinds: &[], throw_kinds: &["control_transfer_statement"], return_kinds: &["control_transfer_statement"],
loop_kinds: &["for_statement", "while_statement", "repeat_while_statement"],
assignment_kinds: &["assignment", "property_declaration"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
// NOTE(review): the parameter field is set to "body" here, unlike every
// other configuration — confirm this is intentional for the Swift grammar.
func_params_field: "body", return_type_field: "return_type",
class_body_field: "class_body",
loop_body_field: "body",
negation_prefix: "!",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: false,
assert_call_names: &["precondition", "assert", "assertionFailure"],
error_call_names: &["fatalError", "preconditionFailure"],
call_kinds: &["call_expression"],
}
}
/// Configuration for Luau. Mirrors the Lua configuration's call names, but
/// with its own function/loop kinds and `parameter` as a typed parameter kind.
fn luau() -> Self {
Self {
function_kinds: &["function_declaration", "function_definition"],
class_kinds: &[],
if_kinds: &["if_statement"],
assert_kinds: &[], throw_kinds: &["return_statement"], return_kinds: &["return_statement"],
loop_kinds: &["for_statement", "while_statement"],
assignment_kinds: &["assignment_statement", "variable_declaration"],
func_name_field: "name",
func_body_field: "body",
if_condition_field: "condition",
if_consequence_field: "consequence",
if_alternative_field: "alternative",
func_params_field: "parameters",
return_type_field: "",
class_body_field: "body",
loop_body_field: "body",
negation_prefix: "not",
has_isinstance: false,
typed_param_kinds: &["parameter"],
assert_is_macro: false,
assert_call_names: &["assert"],
error_call_names: &["error"],
call_kinds: &["function_call"],
}
}
/// True when `kind` is listed in `function_kinds`.
fn is_function(&self, kind: &str) -> bool {
self.function_kinds.contains(&kind)
}
/// True when `kind` is listed in `class_kinds`.
fn is_class(&self, kind: &str) -> bool {
self.class_kinds.contains(&kind)
}
/// True when `kind` is listed in `if_kinds`.
fn is_if(&self, kind: &str) -> bool {
self.if_kinds.contains(&kind)
}
/// True when `kind` is listed in `assert_kinds`.
fn is_assert(&self, kind: &str) -> bool {
self.assert_kinds.contains(&kind)
}
/// True when `kind` is listed in `throw_kinds`.
fn is_throw(&self, kind: &str) -> bool {
self.throw_kinds.contains(&kind)
}
/// True when `kind` is listed in `loop_kinds`.
fn is_loop(&self, kind: &str) -> bool {
self.loop_kinds.contains(&kind)
}
/// True when `kind` is listed in `assignment_kinds`.
fn is_assignment(&self, kind: &str) -> bool {
self.assignment_kinds.contains(&kind)
}
/// True when `kind` is listed in `call_kinds`.
fn is_call(&self, kind: &str) -> bool {
self.call_kinds.contains(&kind)
}
/// True when `name` is one of the configured assertion call names.
fn is_assert_call_name(&self, name: &str) -> bool {
self.assert_call_names.contains(&name)
}
/// True when `name` is one of the configured error call names.
fn is_error_call_name(&self, name: &str) -> bool {
self.error_call_names.contains(&name)
}
/// True when the language has at least one assertion call name configured.
fn has_assert_calls(&self) -> bool {
!self.assert_call_names.is_empty()
}
/// True when the language has at least one error call name configured.
fn has_error_calls(&self) -> bool {
!self.error_call_names.is_empty()
}
}
// Command-line arguments for the contracts analysis.
// Plain `//` comments are used on purpose: `///` doc comments on a
// clap-derived struct are picked up as help text and would change the
// generated CLI output.
#[derive(Debug, Args)]
pub struct ContractsArgs {
// Source file to analyze; validated via `validate_file_path` in `run`.
pub file: PathBuf,
// Name of the function to analyze; validated via `validate_function_name`.
pub function: String,
// Hidden `--output-format` / `-o` flag, defaulting to "json". A value of
// `Text` forces text output in `run`.
#[arg(
long = "output-format",
short = 'o',
hide = true,
default_value = "json"
)]
pub output_format: ContractsOutputFormat,
// Optional `--lang` / `-l` override; when absent the language is derived
// from the file path.
#[arg(long, short = 'l')]
pub lang: Option<Language>,
// Upper bound on the number of conditions reported per category
// (further capped by MAX_CONDITIONS_PER_FUNCTION in `run_contracts`).
#[arg(long, default_value = "100")]
pub limit: usize,
}
impl ContractsArgs {
    /// Executes the contracts command: validates the inputs, resolves the
    /// language, runs the analysis and writes the report.
    ///
    /// The report is written as text when either this command's
    /// `output_format` or the global `format` is `Text`; otherwise it is
    /// handed to the writer's structured output.
    ///
    /// # Errors
    ///
    /// Fails when the path or function name is rejected by validation, when
    /// no language can be determined or no grammar exists for it, when the
    /// analysis itself fails, or when writing the output fails.
    pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
        let writer = OutputWriter::new(format, quiet);

        // Validate both user inputs before doing any work.
        let canonical_path = validate_file_path(&self.file)?;
        validate_function_name(&self.function)?;

        let banner = format!(
            "Analyzing contracts for {}::{}...",
            self.file.display(),
            self.function
        );
        writer.progress(&banner);

        // An explicit --lang wins; otherwise fall back to path-based detection.
        let language = if let Some(explicit) = self.lang {
            explicit
        } else {
            let detected = Language::from_path(&self.file);
            detected.ok_or_else(|| ContractsError::ParseError {
                file: self.file.clone(),
                message: format!(
                    "Cannot determine language for '{}'. Use --lang to specify.",
                    self.file.display()
                ),
            })?
        };

        let grammar_missing = ParserPool::get_ts_language(language).is_none();
        if grammar_missing {
            let err = ContractsError::ParseError {
                file: self.file.clone(),
                message: format!("No tree-sitter grammar available for {:?}", language),
            };
            return Err(err.into());
        }

        let report = run_contracts(&canonical_path, &self.function, language, self.limit)?;

        let text_requested_locally = matches!(self.output_format, ContractsOutputFormat::Text);
        let text_requested_globally = matches!(format, OutputFormat::Text);
        if text_requested_locally || text_requested_globally {
            let rendered = format_contracts_text(&report);
            writer.write_text(&rendered)?;
        } else {
            writer.write(&report)?;
        }

        Ok(())
    }
}
/// Analyzes `function` inside `file` and collects its contracts.
///
/// The source is read and parsed with the grammar for `language`, the
/// function node is located, and then postconditions, preconditions and
/// invariants are extracted, de-duplicated and truncated to
/// `limit.min(MAX_CONDITIONS_PER_FUNCTION)` entries each.
///
/// # Errors
///
/// Returns an error when the file cannot be read or parsed, when the
/// function is not found (`ContractsError::FunctionNotFound`), or when any
/// extraction step fails.
pub fn run_contracts(
    file: &Path,
    function: &str,
    language: Language,
    limit: usize,
) -> ContractsResult<ContractsReport> {
    let source = read_file_safe(file)?;
    let config = LanguageConfig::for_language(language);
    // Explicit read of `return_kinds`; the value itself is not needed here.
    let _ = config.return_kinds;
    let tree = parse_source(&source, language, file)?;
    let bytes = source.as_bytes();

    let func_node = match find_function_node(tree.root_node(), function, bytes, &config) {
        Some(node) => node,
        None => {
            return Err(ContractsError::FunctionNotFound {
                function: function.to_string(),
                file: file.to_path_buf(),
            })
        }
    };

    // Postconditions come first: the lines they occupy are skipped when
    // scanning for preconditions.
    let mut postconditions = Vec::new();
    extract_postconditions(func_node, bytes, &mut postconditions, 0, &config)?;
    let postcondition_lines: HashSet<_> = postconditions
        .iter()
        .map(|cond| cond.source_line)
        .collect();

    let mut preconditions = Vec::new();
    extract_preconditions(
        func_node,
        bytes,
        &mut preconditions,
        &postcondition_lines,
        0,
        &config,
    )?;
    extract_type_annotation_preconditions(func_node, bytes, &mut preconditions, &config)?;
    extract_untyped_param_preconditions(func_node, bytes, &mut preconditions, &config)?;
    extract_return_type_postconditions(func_node, bytes, &mut postconditions, &config)?;
    extract_docstring_contracts(
        func_node,
        bytes,
        &mut preconditions,
        &mut postconditions,
        &config,
        language,
    )?;

    let mut invariants = Vec::new();
    extract_invariants(func_node, bytes, &mut invariants, 0, &config)?;

    let mut preconditions = deduplicate_conditions(preconditions);
    let mut postconditions = deduplicate_conditions(postconditions);
    let mut invariants = deduplicate_conditions(invariants);

    // The same cap applies to every category.
    let cap = limit.min(MAX_CONDITIONS_PER_FUNCTION);
    preconditions.truncate(cap);
    postconditions.truncate(cap);
    invariants.truncate(cap);

    Ok(ContractsReport {
        function: function.to_string(),
        file: file.to_path_buf(),
        preconditions,
        postconditions,
        invariants,
    })
}
/// Parses `source` with the tree-sitter grammar registered for `language`.
///
/// `file` is only used to label errors.
///
/// # Errors
///
/// Returns `ContractsError::ParseError` when no grammar is available, the
/// parser rejects the grammar, or parsing yields no tree.
fn parse_source(source: &str, language: Language, file: &Path) -> ContractsResult<Tree> {
    // Every failure here is a ParseError tagged with the same file.
    let parse_error = |message: String| ContractsError::ParseError {
        file: file.to_path_buf(),
        message,
    };

    let ts_language = match ParserPool::get_ts_language(language) {
        Some(grammar) => grammar,
        None => {
            return Err(parse_error(format!(
                "No tree-sitter grammar for {:?}",
                language
            )))
        }
    };

    let mut parser = Parser::new();
    if let Err(e) = parser.set_language(&ts_language) {
        return Err(parse_error(format!(
            "Failed to set {:?} language: {}",
            language, e
        )));
    }

    match parser.parse(source, None) {
        Some(tree) => Ok(tree),
        None => Err(parse_error("Parsing returned None".to_string())),
    }
}
/// Locates the node defining `function_name` somewhere below `root`.
///
/// Thin entry point for `find_function_recursive`, starting the search at
/// recursion depth zero. Returns `None` when no matching definition exists.
fn find_function_node<'a>(
    root: Node<'a>,
    function_name: &str,
    source: &[u8],
    config: &LanguageConfig,
) -> Option<Node<'a>> {
    let initial_depth = 0;
    find_function_recursive(root, function_name, source, config, initial_depth)
}
/// Depth-limited search through the children of `node` for a function
/// definition named `function_name`.
///
/// For each child, in order, this tries: the function-kind matchers
/// (including special handling for `call` and `value_definition` nodes),
/// variable declarations whose value is a function expression, class-like
/// containers, and finally a fixed list of generic container kinds that are
/// simply recursed into. The first match wins.
///
/// Returns `None` when nothing matches or when `depth` exceeds
/// `MAX_AST_DEPTH`.
fn find_function_recursive<'a>(
node: Node<'a>,
function_name: &str,
source: &[u8],
config: &LanguageConfig,
depth: usize,
) -> Option<Node<'a>> {
// Stop descending once the recursion limit is exceeded.
if depth > MAX_AST_DEPTH {
return None;
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if config.is_function(child.kind()) {
// `call` nodes (listed as a function kind by the Elixir config): only
// calls whose target text is `def` or `defp` are treated as definitions.
if child.kind() == "call" {
if let Some(target) = child.child_by_field_name("target") {
let target_text = get_node_text(target, source);
if target_text == "def" || target_text == "defp" {
if let Some(args) = child.child_by_field_name("arguments") {
if let Some(first_arg) = args.child(0) {
// The defined name is the first argument itself, or the target of
// that argument when the argument is in turn a call.
let fname = if first_arg.kind() == "call" {
first_arg
.child_by_field_name("target")
.map(|n| get_node_text(n, source))
.unwrap_or_default()
} else {
get_node_text(first_arg, source)
};
if fname == function_name {
return Some(child);
}
}
}
}
}
}
// `value_definition` nodes (listed by the OCaml config): the name lives in
// the `pattern` field of a nested `let_binding`.
if child.kind() == "value_definition" {
let mut inner = child.walk();
for sub in child.children(&mut inner) {
if sub.kind() == "let_binding" {
if let Some(pattern_node) = sub.child_by_field_name("pattern") {
let name = get_node_text(pattern_node, source);
if name == function_name {
return Some(child);
}
}
}
}
}
// Generic path: compare the text of the configured name field.
if let Some(name_node) = child.child_by_field_name(config.func_name_field) {
let name = get_node_text(name_node, source);
if name == function_name {
return Some(child);
}
// The name field may hold a larger node whose text merely contains the
// name (e.g. the `declarator` field used by the C/C++ configs); accept it
// only if an identifier inside it matches exactly.
if name.contains(function_name) {
if let Some(found) = find_identifier_match(name_node, function_name, source) {
let _ = found; return Some(child);
}
}
}
// Last resort: a direct identifier/name child equal to the function name.
if find_name_in_children(child, function_name, source) {
return Some(child);
}
}
// Variable declarations bound to a function expression: the returned node
// is the function value, not the declaration.
if matches!(child.kind(), "lexical_declaration" | "variable_declaration") {
let mut decl_cursor = child.walk();
for decl_child in child.children(&mut decl_cursor) {
if decl_child.kind() == "variable_declarator" {
if let Some(name_node) = decl_child.child_by_field_name("name") {
let var_name = get_node_text(name_node, source);
if var_name == function_name {
if let Some(value_node) = decl_child.child_by_field_name("value") {
if matches!(
value_node.kind(),
"arrow_function"
| "function"
| "function_expression"
| "generator_function"
) {
return Some(value_node);
}
}
}
}
}
}
}
// Class-like containers: search the configured body field first, then the
// class node's own children.
if config.is_class(child.kind()) {
if let Some(body) = child.child_by_field_name(config.class_body_field) {
if let Some(found) =
find_function_recursive(body, function_name, source, config, depth + 1)
{
return Some(found);
}
}
if let Some(found) =
find_function_recursive(child, function_name, source, config, depth + 1)
{
return Some(found);
}
}
// Generic container node kinds that may hold function definitions.
if child.kind() == "block"
|| child.kind() == "declaration_list"
|| child.kind() == "module"
|| child.kind() == "source_file"
|| child.kind() == "program"
|| child.kind() == "compound_statement"
|| child.kind() == "export_statement"
|| child.kind() == "export_default_declaration"
|| child.kind() == "decorated_definition"
|| child.kind() == "namespace_declaration"
|| child.kind() == "module_declaration"
|| child.kind() == "class_body"
|| child.kind() == "enum_body"
|| child.kind() == "translation_unit"
|| child.kind() == "do_block" || child.kind() == "stab_clause" || child.kind() == "body" || child.kind() == "arguments" || child.kind() == "structure" || child.kind() == "structure_item" || child.kind() == "module_definition" || child.kind() == "module_binding" || child.kind() == "functor"
{
if let Some(found) =
find_function_recursive(child, function_name, source, config, depth + 1)
{
return Some(found);
}
}
}
None
}
/// Depth-first search of `node` (inclusive) for an `identifier` or `name`
/// node whose text equals `function_name`.
///
/// Returns the first such node in pre-order, or `None` if there is none.
fn find_identifier_match<'a>(
    node: Node<'a>,
    function_name: &str,
    source: &[u8],
) -> Option<Node<'a>> {
    let is_name_like = matches!(node.kind(), "identifier" | "name");
    if is_name_like && get_node_text(node, source) == function_name {
        return Some(node);
    }

    let mut walker = node.walk();
    for descendant in node.children(&mut walker) {
        let hit = find_identifier_match(descendant, function_name, source);
        if hit.is_some() {
            return hit;
        }
    }
    None
}
/// Reports whether any direct child of `func_node` is an `identifier` or
/// `name` node whose text equals `function_name`.
fn find_name_in_children(func_node: Node, function_name: &str, source: &[u8]) -> bool {
    let mut walker = func_node.walk();
    // Bound to a local so the child iterator is dropped before `walker`.
    let matched = func_node.children(&mut walker).any(|candidate| {
        matches!(candidate.kind(), "identifier" | "name")
            && get_node_text(candidate, source) == function_name
    });
    matched
}
/// Returns the first direct child of `node` whose kind is `identifier` or
/// `name`, if any.
fn find_first_identifier(node: Node) -> Option<Node> {
    let mut walker = node.walk();
    for candidate in node.children(&mut walker) {
        if matches!(candidate.kind(), "identifier" | "name") {
            return Some(candidate);
        }
    }
    None
}
/// Returns the slice of `source` covered by `node` as UTF-8 text.
///
/// Yields an empty string when the node's end offset lies beyond `source`
/// or when the covered bytes are not valid UTF-8.
fn get_node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    let span = node.start_byte()..node.end_byte();
    if span.end > source.len() {
        return "";
    }
    match std::str::from_utf8(&source[span]) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Finds the node holding the statements of `func`.
///
/// Preference order:
/// 1. the configured body field, unwrapped one level when that node itself
///    has a `body` field;
/// 2. the first direct child of kind `function_body` (returning its first
///    `block` child when present), `block`, or `compound_statement`.
///
/// Returns `None` when neither lookup finds anything.
fn get_function_body<'a>(func: Node<'a>, config: &LanguageConfig) -> Option<Node<'a>> {
    if let Some(declared_body) = func.child_by_field_name(config.func_body_field) {
        let nested = declared_body.child_by_field_name("body");
        return Some(nested.unwrap_or(declared_body));
    }

    let mut walker = func.walk();
    for candidate in func.children(&mut walker) {
        match candidate.kind() {
            "function_body" => {
                // Prefer the block wrapped by the function_body node, if any.
                let mut inner_walker = candidate.walk();
                let inner_block = candidate
                    .children(&mut inner_walker)
                    .find(|inner| inner.kind() == "block");
                return Some(inner_block.unwrap_or(candidate));
            }
            "block" | "compound_statement" => return Some(candidate),
            _ => {}
        }
    }
    None
}
/// Scans the top-level statements of `func`'s body and appends every
/// precondition found to `conditions`.
///
/// Statements starting on a line listed in `skip_lines` are ignored. A
/// statement contributes a precondition when it is, in this priority order:
/// a conditional whose body contains a throw, an assert node, an
/// `expression_statement` wrapping an assert or guard (macro-assert
/// languages only), an assert-named call, or an `expression_statement`
/// wrapping such a call.
///
/// # Errors
///
/// Propagates the error from `check_ast_depth` when `depth` is rejected.
fn extract_preconditions(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    skip_lines: &HashSet<u32>,
    depth: usize,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    let placeholder_path = PathBuf::from("<source>");
    check_ast_depth(depth, &placeholder_path)?;

    let body = if let Some(found) = get_function_body(func, config) {
        found
    } else {
        return Ok(());
    };

    let mut walker = body.walk();
    for stmt in body.children(&mut walker) {
        let line = stmt.start_position().row as u32 + 1;
        if skip_lines.contains(&line) {
            continue;
        }

        let kind = stmt.kind();
        let is_expression_statement = kind == "expression_statement";

        if config.is_if(kind) {
            // Only conditionals that bail out count as guards.
            if body_contains_throw(stmt, source, config) {
                conditions.extend(precondition_from_guard(stmt, source, config));
            }
        } else if config.is_assert(kind) {
            conditions.extend(precondition_from_assert(stmt, source, config));
        } else if config.assert_is_macro && is_expression_statement {
            // Macro-assert languages wrap asserts and guards in an
            // expression statement; look one level down.
            let mut inner_walker = stmt.walk();
            for inner in stmt.children(&mut inner_walker) {
                let inner_kind = inner.kind();
                if config.is_assert(inner_kind) {
                    conditions.extend(precondition_from_assert(inner, source, config));
                } else if config.is_if(inner_kind) && body_contains_throw(inner, source, config) {
                    conditions.extend(precondition_from_guard(inner, source, config));
                }
            }
        } else if config.has_assert_calls() && config.is_call(kind) {
            conditions.extend(precondition_from_assert_call(stmt, source, config));
        } else if config.has_assert_calls() && is_expression_statement {
            let mut inner_walker = stmt.walk();
            for inner in stmt.children(&mut inner_walker) {
                if config.is_call(inner.kind()) {
                    conditions.extend(precondition_from_assert_call(inner, source, config));
                }
            }
        }
    }
    Ok(())
}
/// Builds a high-confidence precondition from an assertion-style call.
///
/// Returns `None` when the callee name cannot be determined, is not one of
/// the configured assert call names, or the call has no non-empty first
/// argument. The trimmed first argument is used for both text fields of the
/// condition, and the line is the call's 1-based start row.
fn precondition_from_assert_call(
    call_node: Node,
    source: &[u8],
    config: &LanguageConfig,
) -> Option<Condition> {
    let callee = extract_call_name(call_node, source)?;
    if !config.is_assert_call_name(&callee) {
        return None;
    }

    let line = call_node.start_position().row as u32 + 1;
    let raw_argument = extract_first_call_argument(call_node, source)?;
    let expression = raw_argument.trim();
    if expression.is_empty() {
        return None;
    }

    Some(Condition::high(
        expression.to_string(),
        expression.to_string(),
        line,
    ))
}
/// Extracts the callee text of a call node.
///
/// Tries the `function` field, then the `name` field, and finally the first
/// direct child of kind `identifier`, `simple_identifier`, `name` or
/// `navigation_expression`. Returns `None` when none of these exist.
fn extract_call_name(call_node: Node, source: &[u8]) -> Option<String> {
    let named_field = call_node
        .child_by_field_name("function")
        .or_else(|| call_node.child_by_field_name("name"));
    if let Some(callee) = named_field {
        return Some(get_node_text(callee, source).to_string());
    }

    let mut walker = call_node.walk();
    for child in call_node.children(&mut walker) {
        let is_callee_like = matches!(
            child.kind(),
            "identifier" | "simple_identifier" | "name" | "navigation_expression"
        );
        if is_callee_like {
            return Some(get_node_text(child, source).to_string());
        }
    }
    None
}
/// Returns the source text of the first argument of a call node.
///
/// Three strategies are tried in order:
/// 1. the first non-punctuation child of the `arguments` field;
/// 2. the first usable child of a `value_arguments`, `call_suffix` or
///    `argument_list` child, looking inside `value_argument` wrappers and
///    skipping their labels;
/// 3. a textual fallback that slices the call text between the first `(`
///    and the end of the first argument.
///
/// The textual fallback tracks bracket nesting, so a nested call such as
/// `assert(f(a, b), "msg")` yields `f(a, b)` rather than being cut at the
/// inner comma. String literals are not parsed specially.
///
/// Returns `None` when no non-empty argument can be found.
fn extract_first_call_argument(call_node: Node, source: &[u8]) -> Option<String> {
    const PUNCTUATION: [&str; 5] = ["(", ")", ",", "{", "}"];

    // Strategy 1: a dedicated `arguments` field.
    if let Some(args) = call_node.child_by_field_name("arguments") {
        let mut cursor = args.walk();
        for arg in args.children(&mut cursor) {
            if !PUNCTUATION.contains(&arg.kind()) {
                return Some(get_node_text(arg, source).to_string());
            }
        }
    }

    // Strategy 2: argument containers that are plain children of the call.
    let mut cursor = call_node.walk();
    for child in call_node.children(&mut cursor) {
        if !matches!(
            child.kind(),
            "value_arguments" | "call_suffix" | "argument_list"
        ) {
            continue;
        }
        let mut inner = child.walk();
        for arg in child.children(&mut inner) {
            let arg_kind = arg.kind();
            if arg_kind == "value_argument" {
                // Skip an optional `label:` prefix and take the value itself.
                let mut va_cursor = arg.walk();
                for va_child in arg.children(&mut va_cursor) {
                    if !matches!(va_child.kind(), "value_argument_label" | ":") {
                        return Some(get_node_text(va_child, source).to_string());
                    }
                }
            } else if arg_kind != "annotated_lambda" && !PUNCTUATION.contains(&arg_kind) {
                return Some(get_node_text(arg, source).to_string());
            }
        }
    }

    // Strategy 3: textual fallback on the raw call text.
    let text = get_node_text(call_node, source);
    let open = text.find('(')?;
    let rest = &text[open + 1..];

    // Scan to the first top-level `,` or the `)` closing the call, ignoring
    // delimiters nested inside brackets. All delimiters are ASCII, so the
    // byte index from `char_indices` is a valid slice boundary.
    let mut nesting = 0usize;
    let mut end = rest.len();
    for (idx, ch) in rest.char_indices() {
        match ch {
            '(' | '[' | '{' => nesting += 1,
            ')' if nesting == 0 => {
                end = idx;
                break;
            }
            ',' if nesting == 0 => {
                end = idx;
                break;
            }
            ')' | ']' | '}' => nesting = nesting.saturating_sub(1),
            _ => {}
        }
    }

    let arg = rest[..end].trim();
    if arg.is_empty() {
        None
    } else {
        Some(arg.to_string())
    }
}
/// Decides whether a conditional statement acts as a guard, i.e. whether its
/// body contains a throw-like exit.
///
/// * `guard_statement` nodes: checks the `statements`/`else` children, then
///   the guard node itself.
/// * Other conditionals: checks the configured consequence field, then each
///   direct child (a throw kind itself, or a block-like child containing
///   one), and finally, for macro-assert languages, a panic macro inside the
///   consequence.
///
/// Block-like children (`block`, `compound_statement`, `function_body`,
/// `statements`) only count when they actually contain a throw. The kind
/// test is grouped explicitly because `&&` binds tighter than `||`; without
/// the grouping every conditional with a block child would be reported as a
/// guard.
fn body_contains_throw(if_stmt: Node, source: &[u8], config: &LanguageConfig) -> bool {
    if if_stmt.kind() == "guard_statement" {
        let mut cursor = if_stmt.walk();
        for child in if_stmt.children(&mut cursor) {
            if matches!(child.kind(), "statements" | "else")
                && node_tree_contains_throw(child, source, config)
            {
                return true;
            }
        }
        return node_tree_contains_throw(if_stmt, source, config);
    }

    let consequence = if_stmt.child_by_field_name(config.if_consequence_field);
    if let Some(consequence) = consequence {
        if node_tree_contains_throw(consequence, source, config) {
            return true;
        }
    }

    let mut cursor = if_stmt.walk();
    for child in if_stmt.children(&mut cursor) {
        if config.is_throw(child.kind()) {
            return true;
        }
        let is_block_like = matches!(
            child.kind(),
            "block" | "compound_statement" | "function_body" | "statements"
        );
        if is_block_like && node_tree_contains_throw(child, source, config) {
            return true;
        }
    }

    if config.assert_is_macro {
        if let Some(consequence) = consequence {
            if block_contains_panic_macro(consequence, source) {
                return true;
            }
        }
    }
    false
}
/// Searches the children of `node` for something that aborts control flow.
///
/// A child (or a grandchild wrapped in an `expression_statement`) counts when
/// it is:
/// * a node kind accepted by `config.is_throw`;
/// * a panic-style `macro_invocation`, when `config.assert_is_macro` is set;
/// * a call whose extracted name is accepted by `config.is_error_call_name`,
///   when `config.has_error_calls()` is true.
///
/// Block-like children (`block`, `compound_statement`, `function_body`,
/// `statements`) are searched recursively.
fn node_tree_contains_throw(node: Node, source: &[u8], config: &LanguageConfig) -> bool {
    // Shared test applied to both direct children and expression-statement grandchildren.
    let is_throw_like = |candidate: Node| -> bool {
        if config.is_throw(candidate.kind()) {
            return true;
        }
        if config.assert_is_macro
            && candidate.kind() == "macro_invocation"
            && is_panic_macro(candidate, source)
        {
            return true;
        }
        if config.has_error_calls() && config.is_call(candidate.kind()) {
            if let Some(name) = extract_call_name(candidate, source) {
                return config.is_error_call_name(&name);
            }
        }
        false
    };

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if is_throw_like(child) {
            return true;
        }
        if child.kind() == "expression_statement" {
            let mut inner = child.walk();
            if child
                .children(&mut inner)
                .any(|grandchild| is_throw_like(grandchild))
            {
                return true;
            }
        }
        // The kind test must be grouped before `&&`: because `&&` binds tighter
        // than `||`, an ungrouped chain reports `true` for any `block`,
        // `compound_statement` or `function_body` child without recursing into it.
        let is_block_like = matches!(
            child.kind(),
            "block" | "compound_statement" | "function_body" | "statements"
        );
        if is_block_like && node_tree_contains_throw(child, source, config) {
            return true;
        }
    }
    false
}
/// Reports whether `block` directly contains a panic-style macro invocation.
///
/// Only two levels are inspected: the block's own children, and the children of
/// any `expression_statement` child.
fn block_contains_panic_macro(block: Node, source: &[u8]) -> bool {
    let mut cursor = block.walk();
    let found = block.children(&mut cursor).any(|stmt| match stmt.kind() {
        "macro_invocation" => is_panic_macro(stmt, source),
        "expression_statement" => {
            let mut inner = stmt.walk();
            let nested = stmt
                .children(&mut inner)
                .any(|expr| expr.kind() == "macro_invocation" && is_panic_macro(expr, source));
            nested
        }
        _ => false,
    });
    found
}
/// Reports whether `node` is a `macro_invocation` of `panic!`, `todo!`,
/// `unreachable!` or `unimplemented!`.
///
/// The macro name is read from the node's `macro` field when present; otherwise
/// the raw node text is matched by prefix.
fn is_panic_macro(node: Node, source: &[u8]) -> bool {
    // Prefixes for the textual fallback; kept in sync with the names matched below.
    const PANIC_PREFIXES: [&str; 4] = ["panic!", "todo!", "unreachable!", "unimplemented!"];

    if node.kind() != "macro_invocation" {
        return false;
    }
    if let Some(macro_node) = node.child_by_field_name("macro") {
        return matches!(
            get_node_text(macro_node, source),
            "panic" | "todo" | "unreachable" | "unimplemented"
        );
    }
    let text = get_node_text(node, source);
    let matched = PANIC_PREFIXES
        .iter()
        .any(|prefix| text.starts_with(*prefix));
    matched
}
/// Reports whether `node` is a `macro_invocation` whose name starts with
/// `assert` or `debug_assert`.
///
/// The name is taken from the `macro` field when available, otherwise the
/// whole node text is tested.
fn is_assert_macro(node: Node, source: &[u8]) -> bool {
    if node.kind() != "macro_invocation" {
        return false;
    }
    let subject = match node.child_by_field_name("macro") {
        Some(macro_node) => get_node_text(macro_node, source),
        None => get_node_text(node, source),
    };
    subject.starts_with("assert") || subject.starts_with("debug_assert")
}
/// Derives a precondition from a guard clause (a conditional whose body aborts).
///
/// * `guard_statement` nodes state the required condition directly, so the
///   condition text is returned unchanged with high confidence.
/// * For other conditionals the condition is the failure case, so the result is
///   its negation: comparisons are negated operator-wise (high confidence),
///   explicit negations are unwrapped, and anything else is wrapped as
///   `<negation_prefix> (<condition>)` with medium confidence.
///
/// Returns `None` when no condition can be located.
fn precondition_from_guard(
    if_stmt: Node,
    source: &[u8],
    config: &LanguageConfig,
) -> Option<Condition> {
    let line = if_stmt.start_position().row as u32 + 1;
    // Generic textual negation used whenever no smarter rewrite applies.
    let wrap_negated = |text: &str| format!("{} ({})", config.negation_prefix, text);
    // Medium-confidence condition whose variable and constraint are the same text.
    let medium_same = |text: &str| Condition::medium(text.to_string(), text.to_string(), line);

    if if_stmt.kind() == "guard_statement" {
        let guard_text = match if_stmt.child_by_field_name(config.if_condition_field) {
            Some(condition_node) => get_node_text(condition_node, source),
            None => {
                // No condition field: slice the text between the leading
                // 5-byte keyword and the first `else`.
                let full_text = get_node_text(if_stmt, source);
                let else_pos = full_text.find("else")?;
                let sliced = full_text[5..else_pos].trim();
                if sliced.is_empty() {
                    return None;
                }
                sliced
            }
        };
        return Some(Condition::high(
            guard_text.to_string(),
            guard_text.to_string(),
            line,
        ));
    }

    let condition_node = if_stmt.child_by_field_name(config.if_condition_field)?;
    let condition_text = get_node_text(condition_node, source);

    match condition_node.kind() {
        "comparison_operator" | "binary_expression" => {
            Some(match negate_comparison(condition_node, source) {
                Some(negated) => Condition::high(
                    extract_left_operand(condition_node, source),
                    negated,
                    line,
                ),
                None => Condition::medium(
                    condition_text.to_string(),
                    wrap_negated(condition_text),
                    line,
                ),
            })
        }
        "not_operator" => {
            let operand = condition_node.child_by_field_name("argument")?;
            let operand_text = get_node_text(operand, source);
            let is_type_check = operand.kind() == "call"
                && config.has_isinstance
                && is_isinstance_call(operand, source);
            Some(if is_type_check {
                Condition::high(
                    extract_isinstance_var(operand, source),
                    operand_text.to_string(),
                    line,
                )
            } else {
                medium_same(operand_text)
            })
        }
        "unary_expression" | "prefix_expression" => {
            // First child that is not the negation token and has non-empty text;
            // failing that, the `operand` field.
            let mut cursor = condition_node.walk();
            let first_operand = condition_node
                .children(&mut cursor)
                .filter(|child| child.kind() != "!" && child.kind() != "not")
                .map(|child| get_node_text(child, source))
                .find(|text| !text.is_empty());
            first_operand
                .or_else(|| {
                    condition_node
                        .child_by_field_name("operand")
                        .map(|operand| get_node_text(operand, source))
                })
                .map(|text| medium_same(text))
        }
        "parenthesized_expression" => {
            let mut cursor = condition_node.walk();
            let inner = condition_node
                .children(&mut cursor)
                .find(|child| child.kind() != "(" && child.kind() != ")")?;
            if matches!(inner.kind(), "comparison_operator" | "binary_expression") {
                if let Some(negated) = negate_comparison(inner, source) {
                    return Some(Condition::high(
                        extract_left_operand(inner, source),
                        negated,
                        line,
                    ));
                }
            }
            let inner_text = get_node_text(inner, source);
            Some(Condition::medium(
                inner_text.to_string(),
                wrap_negated(inner_text),
                line,
            ))
        }
        // Calls (including isinstance-style checks) and every other node kind
        // receive the same generic negation.
        _ => Some(Condition::medium(
            condition_text.to_string(),
            wrap_negated(condition_text),
            line,
        )),
    }
}
/// Turns an assertion into a precondition.
///
/// For macro-based assertions (`config.assert_is_macro`) the macro arguments
/// become both variable and constraint, with high confidence. Otherwise the
/// asserted expression is classified: isinstance-style calls and comparisons
/// yield high confidence with the checked variable / left operand; anything
/// else yields medium confidence.
fn precondition_from_assert(
    assert_stmt: Node,
    source: &[u8],
    config: &LanguageConfig,
) -> Option<Condition> {
    let line = assert_stmt.start_position().row as u32 + 1;

    if config.assert_is_macro {
        return if is_assert_macro(assert_stmt, source) {
            extract_macro_args(assert_stmt, source)
                .map(|args| Condition::high(args.clone(), args, line))
        } else {
            None
        };
    }

    let condition_node = extract_assert_condition(assert_stmt, source)?;
    let constraint = get_node_text(condition_node, source).to_string();
    let kind = condition_node.kind();

    let condition =
        if kind == "call" && config.has_isinstance && is_isinstance_call(condition_node, source) {
            Condition::high(extract_isinstance_var(condition_node, source), constraint, line)
        } else if kind == "comparison_operator" || kind == "binary_expression" {
            Condition::high(extract_left_operand(condition_node, source), constraint, line)
        } else {
            Condition::medium(constraint.clone(), constraint, line)
        };
    Some(condition)
}
/// Returns the first child of `assert_stmt` whose node kind does not start with
/// `assert` (this also excludes the `assert` / `assert_keyword` tokens).
fn extract_assert_condition<'a>(assert_stmt: Node<'a>, _source: &[u8]) -> Option<Node<'a>> {
    let mut cursor = assert_stmt.walk();
    let condition = assert_stmt
        .children(&mut cursor)
        .find(|child| !child.kind().starts_with("assert"));
    condition
}
/// Extracts the argument text of a macro invocation.
///
/// The first `token_tree` child is preferred: its text is trimmed and, when it
/// is wrapped in `(` … `)`, the parentheses are removed. Without a `token_tree`
/// child, the text between the first `(` and the last `)` of the whole node is
/// used. Returns `None` when neither source is available.
fn extract_macro_args(macro_node: Node, source: &[u8]) -> Option<String> {
    let mut cursor = macro_node.walk();
    let token_tree = macro_node
        .children(&mut cursor)
        .find(|child| child.kind() == "token_tree");

    if let Some(tree) = token_tree {
        let raw = get_node_text(tree, source).trim();
        let args = raw
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'))
            .map_or(raw, str::trim);
        return Some(args.to_string());
    }

    let text = get_node_text(macro_node, source);
    let open = text.find('(')?;
    let close = text.rfind(')')?;
    Some(text[open + 1..close].trim().to_string())
}
/// Collects postconditions: assertions that appear in the function body at or
/// after the first statement that assigns to a variable named `result`.
///
/// Only the direct statements of the body are scanned. For macro-based
/// assertions, assert macros wrapped in an `expression_statement` are also
/// recognised.
///
/// # Errors
/// Propagates the error from `check_ast_depth` for the given `depth`.
fn extract_postconditions(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    depth: usize,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    check_ast_depth(depth, &PathBuf::from("<source>"))?;

    let body = if let Some(body) = get_function_body(func, config) {
        body
    } else {
        return Ok(());
    };

    let mut seen_result = false;
    let mut cursor = body.walk();
    for stmt in body.children(&mut cursor) {
        let kind = stmt.kind();
        seen_result = seen_result
            || (config.is_assignment(kind) && has_result_assignment(stmt, source, config));
        if !seen_result {
            continue;
        }

        if config.is_assert(kind) {
            conditions.extend(postcondition_from_assert(stmt, source, config));
        } else if config.assert_is_macro && kind == "expression_statement" {
            let mut inner = stmt.walk();
            let nested = stmt
                .children(&mut inner)
                .filter(|child| config.is_assert(child.kind()) && is_assert_macro(*child, source))
                .filter_map(|child| postcondition_from_assert(child, source, config));
            conditions.extend(nested);
        }
    }
    Ok(())
}
/// Reports whether `stmt` assigns to a variable literally named `result`.
///
/// Checked in order: the statement's `left`, `pattern` and `name` fields; then
/// each child, either as a nested assignment whose `left` is `result`, or as an
/// `identifier` / `name` token `result` immediately followed by `=` or `:=`.
fn has_result_assignment(stmt: Node, source: &[u8], config: &LanguageConfig) -> bool {
    let names_result = |node: Node| get_node_text(node, source) == "result";

    let direct_target = ["left", "pattern", "name"].iter().any(|field| {
        stmt.child_by_field_name(*field)
            .map_or(false, |target| names_result(target))
    });
    if direct_target {
        return true;
    }

    let mut cursor = stmt.walk();
    for child in stmt.children(&mut cursor) {
        let nested_assignment = config.is_assignment(child.kind())
            && child
                .child_by_field_name("left")
                .map_or(false, |target| names_result(target));
        if nested_assignment {
            return true;
        }

        if matches!(child.kind(), "identifier" | "name") && names_result(child) {
            if let Some(next) = child.next_sibling() {
                if matches!(get_node_text(next, source), "=" | ":=") {
                    return true;
                }
            }
        }
    }
    false
}
/// Turns an assertion into a high-confidence postcondition.
///
/// For macro-based assertions the variable is `result` when the macro arguments
/// mention it, otherwise the full argument text. For statement-based
/// assertions the variable is located with `find_result_var`.
fn postcondition_from_assert(
    assert_stmt: Node,
    source: &[u8],
    config: &LanguageConfig,
) -> Option<Condition> {
    let line = assert_stmt.start_position().row as u32 + 1;

    if !config.assert_is_macro {
        let condition_node = extract_assert_condition(assert_stmt, source)?;
        let constraint = get_node_text(condition_node, source).to_string();
        let subject = find_result_var(condition_node, source);
        return Some(Condition::high(subject, constraint, line));
    }

    if !is_assert_macro(assert_stmt, source) {
        return None;
    }
    let args = extract_macro_args(assert_stmt, source)?;
    let subject = if args.contains("result") {
        String::from("result")
    } else {
        args.clone()
    };
    Some(Condition::high(subject, args, line))
}
/// Depth-first search for text mentioning `result` inside `node`.
///
/// Returns the node's own text if it is an `identifier` containing `result`;
/// otherwise the first child result (in document order) that contains
/// `result`; otherwise the node's own text.
fn find_result_var(node: Node, source: &[u8]) -> String {
    let own_text = get_node_text(node, source);
    if node.kind() == "identifier" && own_text.contains("result") {
        return own_text.to_string();
    }

    let mut cursor = node.walk();
    let from_child = node
        .children(&mut cursor)
        .map(|child| find_result_var(child, source))
        .find(|candidate| candidate.contains("result"));

    from_child.unwrap_or_else(|| own_text.to_string())
}
/// Adds low-confidence preconditions derived from typed parameters.
///
/// Does nothing when the function has no parameter field
/// (`config.func_params_field`). All conditions carry the function's start line.
fn extract_type_annotation_preconditions(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    if let Some(params) = func.child_by_field_name(config.func_params_field) {
        let line = func.start_position().row as u32 + 1;
        extract_typed_params_recursive(params, source, conditions, config, line);
    }
    Ok(())
}
/// Walks a parameter list and records one low-confidence condition per typed
/// parameter.
///
/// A child counts as typed when its kind is listed in
/// `config.typed_param_kinds` and both a name (`name` field, `pattern` field,
/// or first identifier) and a type (`type` or `type_annotation` field) can be
/// found. Receivers named `self`, `cls` or `this` are skipped. Children of kind
/// `function_declarator` or `parameter_list` are searched recursively.
fn extract_typed_params_recursive(
    node: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
    line: u32,
) {
    let mut cursor = node.walk();
    for param in node.children(&mut cursor) {
        let kind = param.kind();

        if config.typed_param_kinds.contains(&kind) {
            let name_node = ["name", "pattern"]
                .iter()
                .find_map(|field| param.child_by_field_name(*field))
                .or_else(|| find_first_identifier(param));
            let type_node = ["type", "type_annotation"]
                .iter()
                .find_map(|field| param.child_by_field_name(*field));

            if let (Some(name_node), Some(type_node)) = (name_node, type_node) {
                let name = get_node_text(name_node, source);
                if matches!(name, "self" | "cls" | "this") {
                    continue;
                }
                let type_str = get_node_text(type_node, source);
                let constraint = if config.has_isinstance {
                    format!("isinstance({}, {})", name, type_str)
                } else {
                    format!("{}: {}", name, type_str)
                };
                conditions.push(Condition::low(name.to_string(), constraint, line));
            }
        }

        if matches!(kind, "function_declarator" | "parameter_list") {
            extract_typed_params_recursive(param, source, conditions, config, line);
        }
    }
}
/// Adds a low-confidence postcondition describing the declared return type.
///
/// Nothing is added when the language has no return-type field, when the
/// function has no return type, or when the type text is `None`, `void`, `()`
/// or empty.
fn extract_return_type_postconditions(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    if config.return_type_field.is_empty() {
        return Ok(());
    }

    if let Some(return_type) = func.child_by_field_name(config.return_type_field) {
        let type_str = get_node_text(return_type, source);
        if !matches!(type_str, "None" | "void" | "()" | "") {
            let line = func.start_position().row as u32 + 1;
            let constraint = if config.has_isinstance {
                format!("isinstance(return, {})", type_str)
            } else {
                format!("return: {}", type_str)
            };
            conditions.push(Condition::low("return".to_string(), constraint, line));
        }
    }
    Ok(())
}
/// Adds low-confidence preconditions for parameters without type information.
///
/// Variables that already have a condition when this function is entered are
/// not reported again. Does nothing when the function has no parameter field.
fn extract_untyped_param_preconditions(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    if let Some(params) = func.child_by_field_name(config.func_params_field) {
        let line = func.start_position().row as u32 + 1;

        // Snapshot taken up front: conditions added during the walk do not
        // influence which parameters are skipped.
        let mut existing_vars: HashSet<String> = HashSet::with_capacity(conditions.len());
        for condition in conditions.iter() {
            existing_vars.insert(condition.variable.clone());
        }

        extract_untyped_params_recursive(params, source, conditions, config, line, &existing_vars);
    }
    Ok(())
}
/// Walks a parameter list and records one low-confidence condition per
/// parameter that carries no type information.
///
/// Children whose kind is in `config.typed_param_kinds` are skipped. Handled
/// shapes: bare `identifier` / `name`, `default_parameter`, splat patterns
/// (`list_splat_pattern`, `dictionary_splat_pattern`) and rest parameters
/// (`rest_pattern`, `rest_element`). Nested `formal_parameters` / `parameters`
/// / `parameter_list` nodes are searched recursively. Names present in
/// `existing_vars` are never reported; receivers (`self`, `cls`, `this`) are
/// skipped for plain and default parameters.
fn extract_untyped_params_recursive(
    node: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
    line: u32,
    existing_vars: &HashSet<String>,
) {
    let is_receiver = |name: &str| matches!(name, "self" | "cls" | "this");

    let mut cursor = node.walk();
    for param in node.children(&mut cursor) {
        let kind = param.kind();
        if config.typed_param_kinds.contains(&kind) {
            continue;
        }

        match kind {
            // Punctuation tokens carry no parameter information.
            "(" | ")" | "," | ":" => {}
            "identifier" | "name" => {
                let name = get_node_text(param, source);
                if !is_receiver(name) && !existing_vars.contains(name) {
                    let constraint = format!("parameter {} is required", name);
                    conditions.push(Condition::low(name.to_string(), constraint, line));
                }
            }
            "default_parameter" => {
                if let Some(name_node) = param.child_by_field_name("name") {
                    let name = get_node_text(name_node, source);
                    if !is_receiver(name) && !existing_vars.contains(name) {
                        let default_val = param
                            .child_by_field_name("value")
                            .map_or("?", |value| get_node_text(value, source));
                        let constraint =
                            format!("parameter {} (default: {})", name, default_val);
                        conditions.push(Condition::low(name.to_string(), constraint, line));
                    }
                }
            }
            "list_splat_pattern" | "dictionary_splat_pattern" | "rest_pattern"
            | "rest_element" => {
                // Only the first identifier child names the parameter.
                let mut inner = param.walk();
                let ident = param
                    .children(&mut inner)
                    .find(|child| child.kind() == "identifier");
                if let Some(ident) = ident {
                    let name = get_node_text(ident, source);
                    if !existing_vars.contains(name) {
                        let constraint = match kind {
                            "list_splat_pattern" => {
                                format!("parameter *{} (variadic positional)", name)
                            }
                            "dictionary_splat_pattern" => {
                                format!("parameter **{} (variadic keyword)", name)
                            }
                            _ => format!("parameter ...{} (rest)", name),
                        };
                        conditions.push(Condition::low(name.to_string(), constraint, line));
                    }
                }
            }
            "formal_parameters" | "parameters" | "parameter_list" => {
                extract_untyped_params_recursive(
                    param,
                    source,
                    conditions,
                    config,
                    line,
                    existing_vars,
                );
            }
            _ => {}
        }
    }
}
/// Adds contracts described in documentation comments.
///
/// Python functions are scanned for Sphinx-style fields in their docstring;
/// JavaScript and TypeScript functions for JSDoc tags in a preceding comment.
/// Other languages are left untouched. Variables that already have a
/// pre-/postcondition on entry are passed along so they are not duplicated.
fn extract_docstring_contracts(
    func: Node,
    source: &[u8],
    preconditions: &mut Vec<Condition>,
    postconditions: &mut Vec<Condition>,
    config: &LanguageConfig,
    language: Language,
) -> ContractsResult<()> {
    /// Names of the variables already covered by `conds`.
    fn variable_names(conds: &[Condition]) -> HashSet<String> {
        conds.iter().map(|c| c.variable.clone()).collect()
    }

    let line = func.start_position().row as u32 + 1;
    let known_pre = variable_names(preconditions);
    let known_post = variable_names(postconditions);

    match language {
        Language::Python => {
            if let Some(doc_text) = extract_python_docstring(func, source, config) {
                extract_sphinx_params(
                    &doc_text,
                    preconditions,
                    postconditions,
                    line,
                    &known_pre,
                    &known_post,
                );
            }
        }
        Language::JavaScript | Language::TypeScript => {
            if let Some(doc_text) = extract_jsdoc_comment(func, source) {
                extract_jsdoc_params(
                    &doc_text,
                    preconditions,
                    postconditions,
                    line,
                    &known_pre,
                    &known_post,
                );
            }
        }
        _ => {}
    }
    Ok(())
}
/// Returns the docstring of a function, with triple-quote delimiters removed.
///
/// The docstring is the first `string` / `concatenated_string` child of the
/// body's first statement, provided that statement is an
/// `expression_statement`. Returns `None` otherwise.
fn extract_python_docstring<'a>(
    func: Node<'a>,
    source: &'a [u8],
    config: &LanguageConfig,
) -> Option<String> {
    let body = func.child_by_field_name(config.func_body_field)?;

    let mut cursor = body.walk();
    let first_stmt = body.children(&mut cursor).next()?;
    if first_stmt.kind() != "expression_statement" {
        return None;
    }

    let mut inner = first_stmt.walk();
    let literal = first_stmt
        .children(&mut inner)
        .find(|child| matches!(child.kind(), "string" | "concatenated_string"))?;

    // Strip opening delimiters (plain and raw-prefixed), then closing ones,
    // in this exact order.
    let mut text = get_node_text(literal, source);
    for opener in ["\"\"\"", "'''", "r\"\"\"", "r'''"].iter() {
        text = text.trim_start_matches(*opener);
    }
    for closer in ["\"\"\"", "'''"].iter() {
        text = text.trim_end_matches(*closer);
    }
    Some(text.to_string())
}
/// Returns the cleaned text of the JSDoc block (`/** … */`) attached to `func`.
///
/// When the function is wrapped in an `export_statement`,
/// `export_default_declaration` or `decorated_definition`, the wrapper's
/// siblings are searched instead of the function's own.
///
/// Only the run of `comment` siblings immediately preceding the declaration is
/// considered. The search stops at the first non-comment sibling, so a
/// function without its own JSDoc never picks up the documentation of an
/// earlier declaration.
///
/// The returned text has the `/**` / `*/` delimiters removed and each line
/// stripped of surrounding whitespace and leading `*` characters.
fn extract_jsdoc_comment(func: Node, source: &[u8]) -> Option<String> {
    let anchor = match func.parent() {
        Some(parent)
            if matches!(
                parent.kind(),
                "export_statement" | "export_default_declaration" | "decorated_definition"
            ) =>
        {
            parent
        }
        _ => func,
    };

    let mut prev = anchor.prev_sibling();
    while let Some(sibling) = prev {
        if sibling.kind() != "comment" {
            // Anything other than a comment separates the declaration from
            // earlier documentation; that documentation belongs elsewhere.
            break;
        }
        let text = get_node_text(sibling, source);
        if text.starts_with("/**") {
            let stripped = text
                .trim_start_matches("/**")
                .trim_end_matches("*/")
                .lines()
                .map(|l| l.trim().trim_start_matches('*').trim())
                .collect::<Vec<_>>()
                .join("\n");
            return Some(stripped);
        }
        // Plain comment (e.g. `// …`): keep looking further up the contiguous run.
        prev = sibling.prev_sibling();
    }
    None
}
/// Extracts contracts from Sphinx-style docstring fields.
///
/// * `:param name: desc` becomes a low-confidence precondition, enriched with
///   the matching `:type name: T` when present. Names in `existing_pre_vars`
///   are skipped.
/// * `:rtype: T` becomes a low-confidence `return` postcondition unless
///   `existing_post_vars` already contains `return`.
/// * `:return:` / `:returns:` is used only when no `return` postcondition
///   exists afterwards; at most one is added.
fn extract_sphinx_params(
    docstring: &str,
    preconditions: &mut Vec<Condition>,
    postconditions: &mut Vec<Condition>,
    line: u32,
    existing_pre_vars: &HashSet<String>,
    existing_post_vars: &HashSet<String>,
) {
    let param_re = Regex::new(r":param\s+(\w+)\s*:(.*)").unwrap();
    let type_re = Regex::new(r":type\s+(\w+)\s*:\s*(.+)").unwrap();
    let return_re = Regex::new(r":returns?\s*:(.*)").unwrap();
    let rtype_re = Regex::new(r":rtype\s*:\s*(.+)").unwrap();

    // Later `:type:` entries for the same name override earlier ones.
    let param_types: std::collections::HashMap<String, String> = type_re
        .captures_iter(docstring)
        .map(|cap| (cap[1].trim().to_string(), cap[2].trim().to_string()))
        .collect();

    for cap in param_re.captures_iter(docstring) {
        let name = cap[1].trim();
        if existing_pre_vars.contains(name) {
            continue;
        }
        let desc = cap[2].trim();
        let constraint = match param_types.get(name) {
            Some(type_str) => format!("{}: {} ({})", name, type_str, desc),
            None => format!("{}: {}", name, desc),
        };
        preconditions.push(Condition::low(name.to_string(), constraint, line));
    }

    let return_known = existing_post_vars.contains("return");

    for cap in rtype_re.captures_iter(docstring) {
        let type_str = cap[1].trim();
        if !type_str.is_empty() && !return_known {
            let constraint = format!("isinstance(return, {})", type_str);
            postconditions.push(Condition::low("return".to_string(), constraint, line));
        }
    }

    let has_rtype = postconditions.iter().any(|c| c.variable == "return");
    if !has_rtype {
        let first_description = return_re
            .captures_iter(docstring)
            .map(|cap| cap[1].trim().to_string())
            .find(|desc| !desc.is_empty() && !return_known);
        if let Some(desc) = first_description {
            let constraint = format!("returns: {}", desc);
            postconditions.push(Condition::low("return".to_string(), constraint, line));
        }
    }
}
/// Extracts contracts from JSDoc tags.
///
/// * `@param {T} name desc` and `@param name desc` become low-confidence
///   preconditions. Typed tags are processed first; a name is reported once and
///   never when it is in `existing_pre_vars`.
/// * `@return` / `@returns` becomes a single low-confidence `return`
///   postcondition unless `existing_post_vars` already contains `return`. The
///   typed form is preferred over the untyped one.
fn extract_jsdoc_params(
    jsdoc: &str,
    preconditions: &mut Vec<Condition>,
    postconditions: &mut Vec<Condition>,
    line: u32,
    existing_pre_vars: &HashSet<String>,
    existing_post_vars: &HashSet<String>,
) {
    let param_with_type_re = Regex::new(r"@param\s+\{([^}]+)\}\s+(\w+)\s*[-:]?\s*(.*)").unwrap();
    let param_no_type_re = Regex::new(r"@param\s+(\w+)\s*[-:]?\s*(.*)").unwrap();
    let returns_with_type_re = Regex::new(r"@returns?\s+\{([^}]+)\}\s*(.*)").unwrap();
    let returns_no_type_re = Regex::new(r"@returns?\s+(.+)").unwrap();

    let mut seen_params: HashSet<String> = HashSet::new();

    for cap in param_with_type_re.captures_iter(jsdoc) {
        let name = cap[2].trim();
        // `insert` returns false when the name was already recorded.
        if existing_pre_vars.contains(name) || !seen_params.insert(name.to_string()) {
            continue;
        }
        let type_str = cap[1].trim();
        let desc = cap[3].trim();
        let constraint = if desc.is_empty() {
            format!("{}: {}", name, type_str)
        } else {
            format!("{}: {} ({})", name, type_str, desc)
        };
        preconditions.push(Condition::low(name.to_string(), constraint, line));
    }

    for cap in param_no_type_re.captures_iter(jsdoc) {
        let name = cap[1].trim();
        if existing_pre_vars.contains(name) || !seen_params.insert(name.to_string()) {
            continue;
        }
        let desc = cap[2].trim();
        let constraint = if desc.is_empty() {
            format!("parameter {} is required", name)
        } else {
            format!("{}: {}", name, desc)
        };
        preconditions.push(Condition::low(name.to_string(), constraint, line));
    }

    let mut has_return = existing_post_vars.contains("return");

    if !has_return {
        // Only the first typed `@returns` tag is used.
        if let Some(cap) = returns_with_type_re.captures(jsdoc) {
            let type_str = cap[1].trim();
            let desc = cap[2].trim();
            let constraint = if desc.is_empty() {
                format!("return: {}", type_str)
            } else {
                format!("return: {} ({})", type_str, desc)
            };
            postconditions.push(Condition::low("return".to_string(), constraint, line));
            has_return = true;
        }
    }

    if !has_return {
        // First untyped description that is non-empty and does not start with `{`.
        let description = returns_no_type_re
            .captures_iter(jsdoc)
            .map(|cap| cap[1].trim().to_string())
            .find(|desc| !desc.starts_with('{') && !desc.is_empty());
        if let Some(desc) = description {
            let constraint = format!("returns: {}", desc);
            postconditions.push(Condition::low("return".to_string(), constraint, line));
        }
    }
}
/// Collects loop invariants from the body of `func`.
///
/// Functions without a body contribute nothing.
///
/// # Errors
/// Propagates the error from `check_ast_depth` for the given `depth`, and any
/// error from scanning the body.
fn extract_invariants(
    func: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    depth: usize,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    check_ast_depth(depth, &PathBuf::from("<source>"))?;
    match get_function_body(func, config) {
        Some(body) => extract_invariants_from_block(body, source, conditions, depth + 1, config),
        None => Ok(()),
    }
}
/// Scans the direct statements of `block` for loop invariants.
///
/// Assertions inside a loop body are recorded as invariants. Conditionals are
/// descended into through both their consequence and alternative branches.
/// Recursion stops silently once `depth` exceeds `MAX_AST_DEPTH`.
fn extract_invariants_from_block(
    block: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    depth: usize,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    if depth > MAX_AST_DEPTH {
        return Ok(());
    }

    let mut cursor = block.walk();
    for stmt in block.children(&mut cursor) {
        let kind = stmt.kind();

        if config.is_loop(kind) {
            if let Some(loop_body) = stmt.child_by_field_name(config.loop_body_field) {
                extract_asserts_as_invariants(loop_body, source, conditions, config)?;
            }
            continue;
        }
        if !config.is_if(kind) {
            continue;
        }

        // Consequence first, then alternative.
        for field in [config.if_consequence_field, config.if_alternative_field].iter() {
            if let Some(branch) = stmt.child_by_field_name(*field) {
                extract_invariants_from_block(branch, source, conditions, depth + 1, config)?;
            }
        }
    }
    Ok(())
}
/// Records every assertion found directly inside `block` as a medium-confidence
/// invariant.
///
/// * Statement-based assertions: the first child that is not the `assert` /
///   `assert_keyword` token supplies the constraint text.
/// * Macro-based assertions (`config.assert_is_macro`): the macro arguments
///   supply the constraint, whether the macro is the statement itself or is
///   wrapped in an `expression_statement`.
///
/// The reported line is always that of the enclosing statement.
fn extract_asserts_as_invariants(
    block: Node,
    source: &[u8],
    conditions: &mut Vec<Condition>,
    config: &LanguageConfig,
) -> ContractsResult<()> {
    let mut cursor = block.walk();
    for stmt in block.children(&mut cursor) {
        let kind = stmt.kind();
        let line = stmt.start_position().row as u32 + 1;
        let stmt_is_assert = config.is_assert(kind);

        if stmt_is_assert && !config.assert_is_macro {
            let mut inner_cursor = stmt.walk();
            let asserted = stmt
                .children(&mut inner_cursor)
                .find(|child| !matches!(child.kind(), "assert" | "assert_keyword"));
            if let Some(expr) = asserted {
                let constraint = get_node_text(expr, source).to_string();
                conditions.push(Condition::medium(constraint.clone(), constraint, line));
            }
            continue;
        }

        if !config.assert_is_macro {
            continue;
        }

        if stmt_is_assert {
            if is_assert_macro(stmt, source) {
                if let Some(args) = extract_macro_args(stmt, source) {
                    conditions.push(Condition::medium(args.clone(), args, line));
                }
            }
        } else if kind == "expression_statement" {
            let mut inner = stmt.walk();
            let macro_args = stmt
                .children(&mut inner)
                .filter(|child| config.is_assert(child.kind()) && is_assert_macro(*child, source))
                .filter_map(|child| extract_macro_args(child, source));
            for args in macro_args {
                conditions.push(Condition::medium(args.clone(), args, line));
            }
        }
    }
    Ok(())
}
/// Builds the logical negation of a simple binary comparison.
///
/// The node's children are split into a left operand, a comparison operator
/// (`<`, `>`, `<=`, `>=`, `==`, `!=`) and a right operand, and the result is
/// `"<left> <negated-op> <right>"`.
///
/// Returns `None` when an operand or a recognised operator is missing, or when
/// the node holds more than one comparison operator (a chained comparison such
/// as `a < b < c`). Negating only the first pair of a chain would silently
/// drop the remaining operands, so the caller is left to fall back to a
/// generic negation.
fn negate_comparison(node: Node, source: &[u8]) -> Option<String> {
    let mut cursor = node.walk();
    let mut left = None;
    let mut op = None;
    let mut right = None;

    for child in node.children(&mut cursor) {
        match child.kind() {
            "<" | ">" | "<=" | ">=" | "==" | "!=" => {
                if op.is_some() {
                    // Second operator: chained comparison, not safely negatable here.
                    return None;
                }
                op = Some(get_node_text(child, source));
            }
            _ => {
                if left.is_none() {
                    left = Some(child);
                } else if right.is_none() {
                    right = Some(child);
                }
            }
        }
    }

    let left_text = get_node_text(left?, source);
    let right_text = get_node_text(right?, source);
    let negated_op = negate_operator(op?)?;
    Some(format!("{} {} {}", left_text, negated_op, right_text))
}
/// Maps a comparison / identity / membership operator to its logical negation.
///
/// Returns `None` for any operator that is not in the table.
fn negate_operator(op: &str) -> Option<&'static str> {
    // (operator, negation) pairs.
    const NEGATIONS: [(&str, &str); 10] = [
        ("<", ">="),
        ("<=", ">"),
        (">", "<="),
        (">=", "<"),
        ("==", "!="),
        ("!=", "=="),
        ("is", "is not"),
        ("is not", "is"),
        ("in", "not in"),
        ("not in", "in"),
    ];

    NEGATIONS
        .iter()
        .find(|(operator, _)| *operator == op)
        .map(|(_, negation)| *negation)
}
/// Reports whether `node` is a `call` node whose `function` field reads exactly
/// `isinstance`.
fn is_isinstance_call(node: Node, source: &[u8]) -> bool {
    node.kind() == "call"
        && node
            .child_by_field_name("function")
            .map_or(false, |callee| get_node_text(callee, source) == "isinstance")
}
/// Returns the text of the first argument of a call (the value being
/// type-checked), or `"?"` when the call has no `arguments` field or no
/// non-punctuation argument.
fn extract_isinstance_var(node: Node, source: &[u8]) -> String {
    let first_arg = node.child_by_field_name("arguments").and_then(|args| {
        let mut cursor = args.walk();
        let arg = args
            .children(&mut cursor)
            .find(|candidate| !matches!(candidate.kind(), "(" | ")" | ","));
        arg
    });

    match first_arg {
        Some(arg) => get_node_text(arg, source).to_string(),
        None => "?".to_string(),
    }
}
/// Returns the text of the first child of `node` that is not a comparison,
/// identity or membership operator token. Falls back to the text of the whole
/// node when there is no such child.
fn extract_left_operand(node: Node, source: &[u8]) -> String {
    let mut cursor = node.walk();
    let operand = node.children(&mut cursor).find(|child| {
        !matches!(
            child.kind(),
            "<" | ">" | "<=" | ">=" | "==" | "!=" | "is" | "is not" | "in" | "not in"
        )
    });
    get_node_text(operand.unwrap_or(node), source).to_string()
}
/// Removes conditions that repeat an earlier `(variable, constraint)` pair,
/// keeping the first occurrence and preserving the original order.
fn deduplicate_conditions(mut conditions: Vec<Condition>) -> Vec<Condition> {
    let mut seen: HashSet<(String, String)> = HashSet::with_capacity(conditions.len());
    let mut unique = Vec::with_capacity(conditions.len());

    for condition in conditions.drain(..) {
        let key = (condition.variable.clone(), condition.constraint.clone());
        if seen.insert(key) {
            unique.push(condition);
        }
    }
    unique
}
/// Renders a contracts report as plain text.
///
/// The output starts with a `Function: <name>` line, followed by one section
/// per non-empty category (preconditions, postconditions, invariants) listing
/// each condition's constraint, variable, source line and confidence. When all
/// categories are empty, a `(none detected)` line is printed instead.
pub fn format_contracts_text(report: &ContractsReport) -> String {
    let sections = [
        ("Preconditions", &report.preconditions),
        ("Postconditions", &report.postconditions),
        ("Invariants", &report.invariants),
    ];

    let mut output = format!("Function: {}\n", report.function);

    for (label, conds) in sections.iter() {
        if conds.is_empty() {
            continue;
        }
        output += &format!(" {}:\n", label);
        for c in conds.iter() {
            output += &format!(
                " - {} ({}, line {}, {})\n",
                c.constraint, c.variable, c.source_line, c.confidence
            );
        }
    }

    let nothing_found = sections.iter().all(|(_, conds)| conds.is_empty());
    if nothing_found {
        output.push_str(" (none detected)\n");
    }
    output
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
const PYTHON_GUARD_CLAUSES: &str = r#"
def process_data(x, data):
if x < 0:
raise ValueError("x must be non-negative")
if not isinstance(data, list):
raise TypeError("data must be a list")
result = sum(data) + x
return result
"#;
const PYTHON_ASSERTS: &str = r#"
def calculate(a, b):
assert a > 0
assert isinstance(b, int)
result = a * b
return result
"#;
const PYTHON_POSTCONDITIONS: &str = r#"
def divide(a, b):
if b == 0:
raise ZeroDivisionError("Cannot divide by zero")
result = a / b
assert result is not None
return result
"#;
#[test]
fn test_guard_clause_detection() {
let temp = TempDir::new().unwrap();
let file_path = temp.path().join("guards.py");
fs::write(&file_path, PYTHON_GUARD_CLAUSES).unwrap();
let report = run_contracts(&file_path, "process_data", Language::Python, 100).unwrap();
assert!(
!report.preconditions.is_empty(),
"Should detect preconditions"
);
let has_x_precond = report
.preconditions
.iter()
.any(|p| p.variable.contains("x") && p.constraint.contains(">="));
assert!(has_x_precond, "Should detect x >= 0 precondition");
}
#[test]
fn test_assert_extraction() {
let temp = TempDir::new().unwrap();
let file_path = temp.path().join("asserts.py");
fs::write(&file_path, PYTHON_ASSERTS).unwrap();
let report = run_contracts(&file_path, "calculate", Language::Python, 100).unwrap();
assert!(
report.preconditions.len() >= 2,
"Should detect at least 2 preconditions"
);
let has_a_precond = report
.preconditions
.iter()
.any(|p| p.constraint.contains("a > 0") || p.constraint.contains("a>0"));
assert!(has_a_precond, "Should detect a > 0 precondition");
}
#[test]
fn test_postcondition_detection() {
let temp = TempDir::new().unwrap();
let file_path = temp.path().join("postcond.py");
fs::write(&file_path, PYTHON_POSTCONDITIONS).unwrap();
let report = run_contracts(&file_path, "divide", Language::Python, 100).unwrap();
assert!(
!report.postconditions.is_empty(),
"Should detect postconditions"
);
let has_result_postcond = report
.postconditions
.iter()
.any(|p| p.variable.contains("result") && p.constraint.contains("None"));
assert!(
has_result_postcond,
"Should detect result is not None postcondition"
);
}
#[test]
fn test_confidence_scoring() {
let temp = TempDir::new().unwrap();
let file_path = temp.path().join("guards.py");
fs::write(&file_path, PYTHON_GUARD_CLAUSES).unwrap();
let report = run_contracts(&file_path, "process_data", Language::Python, 100).unwrap();
for precond in &report.preconditions {
if precond.constraint.contains(">=") || precond.constraint.contains("isinstance") {
assert_eq!(
precond.confidence,
Confidence::High,
"Guard clause should have High confidence"
);
}
}
}
/// Asking for a function that is absent from the file must fail with
/// `ContractsError::FunctionNotFound` carrying the requested name.
#[test]
fn test_function_not_found() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("test.py");
    fs::write(&source_path, "def foo(): pass").unwrap();

    let result = run_contracts(&source_path, "nonexistent", Language::Python, 100);
    assert!(result.is_err());

    let error = result.unwrap_err();
    if let ContractsError::FunctionNotFound { function, .. } = error {
        assert_eq!(function, "nonexistent");
    } else {
        panic!("Expected FunctionNotFound, got {:?}", error);
    }
}
/// A function whose body is just `pass` must produce no preconditions,
/// postconditions, or invariants.
#[test]
fn test_empty_function() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("empty.py");
    let python_src = "def empty(): pass";
    fs::write(&source_path, python_src).unwrap();

    let report = run_contracts(&source_path, "empty", Language::Python, 100).unwrap();

    assert!(report.preconditions.is_empty());
    assert!(report.postconditions.is_empty());
    assert!(report.invariants.is_empty());
}
/// Two `x > 0` conditions that differ only in their third argument collapse
/// into one, leaving two conditions in total.
#[test]
fn test_deduplicate_conditions() {
    let input = Vec::from([
        Condition::high("x", "x > 0", 1),
        // Same variable and constraint as the entry above.
        Condition::high("x", "x > 0", 2),
        Condition::high("y", "y > 0", 3),
    ]);

    let unique = deduplicate_conditions(input);

    assert_eq!(unique.len(), 2);
}
/// Checks `negate_operator` against a table of operator / expected-negation pairs.
#[test]
fn test_negate_operator() {
    // (operator, expected negation)
    let expected_pairs = [
        ("<", ">="),
        ("<=", ">"),
        (">", "<="),
        (">=", "<"),
        ("==", "!="),
        ("!=", "=="),
        ("is", "is not"),
        ("is not", "is"),
    ];

    for &(operator, negated) in expected_pairs.iter() {
        assert_eq!(negate_operator(operator), Some(negated));
    }
}
/// Rust fixture: `process_data` panics when `x < 0`, asserts `data.len() > 0`,
/// then computes `result` and asserts `result >= 0` before returning it.
const RUST_GUARD_CLAUSES: &str = r#"
fn process_data(x: i32, data: &[i32]) -> i32 {
if x < 0 {
panic!("x must be non-negative");
}
assert!(data.len() > 0);
let result = data.iter().sum::<i32>() + x;
assert!(result >= 0);
result
}
"#;
/// The Rust guard-clause fixture must yield at least one precondition for
/// `process_data`.
#[test]
fn test_rust_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("guards.rs");
    fs::write(&source_path, RUST_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "process_data", Language::Rust, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "Rust: Should detect preconditions, got: {:?}",
        preconditions
    );
}
/// The `assert!(data.len() > 0)` call in the Rust fixture must surface as a
/// precondition whose constraint mentions `len`.
#[test]
fn test_rust_assert_macro() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("asserts.rs");
    fs::write(&source_path, RUST_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "process_data", Language::Rust, 100).unwrap();
    let preconditions = &report.preconditions;

    let has_assert = preconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("data.len() > 0") || text.contains("len"));
    assert!(
        has_assert,
        "Rust: Should detect assert!(data.len() > 0), got: {:?}",
        preconditions
    );
}
/// The Rust fixture must yield at least one postcondition for `process_data`.
#[test]
fn test_rust_postcondition() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("postcond.rs");
    fs::write(&source_path, RUST_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "process_data", Language::Rust, 100).unwrap();
    let postconditions = &report.postconditions;

    let found_any = !postconditions.is_empty();
    assert!(
        found_any,
        "Rust: Should detect postconditions after result assignment, got: {:?}",
        postconditions
    );
}
/// Go fixture: `processData` returns an error when `x < 0` or when
/// `len(data) == 0`, then sums `data` in a range loop and returns the total plus `x`.
const GO_GUARD_CLAUSES: &str = r#"
package main
import "errors"
func processData(x int, data []int) (int, error) {
if x < 0 {
return 0, errors.New("x must be non-negative")
}
if len(data) == 0 {
return 0, errors.New("data must not be empty")
}
result := 0
for _, v := range data {
result += v
}
return result + x, nil
}
"#;
/// The Go fixture must yield preconditions for `processData`, at least one of
/// which has a constraint containing `>=` or `x`.
#[test]
fn test_go_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("guards.go");
    fs::write(&source_path, GO_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "processData", Language::Go, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "Go: Should detect preconditions from guard clauses, got: {:?}",
        preconditions
    );

    let has_x_precond = preconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains(">=") || text.contains("x"));
    assert!(
        has_x_precond,
        "Go: Should detect x >= 0 precondition, got: {:?}",
        preconditions
    );
}
/// Java fixture: `Processor.processData` throws `IllegalArgumentException`
/// when `x < 0`, asserts `data.length > 0`, then sums `data` and returns the
/// total plus `x`.
const JAVA_GUARD_CLAUSES: &str = r#"
public class Processor {
public int processData(int x, int[] data) {
if (x < 0) {
throw new IllegalArgumentException("x must be non-negative");
}
assert data.length > 0;
int result = 0;
for (int v : data) {
result += v;
}
return result + x;
}
}
"#;
/// The Java fixture must yield at least one precondition for `processData`.
#[test]
fn test_java_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("Processor.java");
    fs::write(&source_path, JAVA_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "processData", Language::Java, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "Java: Should detect preconditions, got: {:?}",
        preconditions
    );
}
/// The Java `assert data.length > 0;` statement must surface as a precondition
/// whose constraint mentions `length`.
#[test]
fn test_java_assert_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("Assert.java");
    fs::write(&source_path, JAVA_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "processData", Language::Java, 100).unwrap();
    let preconditions = &report.preconditions;

    let has_assert = preconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("data.length") || text.contains("length"));
    assert!(
        has_assert,
        "Java: Should detect assert data.length > 0, got: {:?}",
        preconditions
    );
}
/// TypeScript fixture: `processData` throws `Error` when `x < 0` or when
/// `data.length === 0`, then returns the reduced sum of `data` plus `x`.
/// Both parameters and the return value carry type annotations.
const TS_GUARD_CLAUSES: &str = r#"
function processData(x: number, data: number[]): number {
if (x < 0) {
throw new Error("x must be non-negative");
}
if (data.length === 0) {
throw new Error("data must not be empty");
}
let result = data.reduce((a, b) => a + b, 0) + x;
return result;
}
"#;
/// The TypeScript fixture must yield at least one precondition for `processData`.
#[test]
fn test_ts_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("guards.ts");
    fs::write(&source_path, TS_GUARD_CLAUSES).unwrap();

    let report =
        run_contracts(&source_path, "processData", Language::TypeScript, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "TypeScript: Should detect preconditions from throw guards, got: {:?}",
        preconditions
    );
}
/// The TypeScript fixture must yield a `Confidence::Low` precondition whose
/// constraint mentions `number`.
#[test]
fn test_ts_type_annotations() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("typed.ts");
    fs::write(&source_path, TS_GUARD_CLAUSES).unwrap();

    let report =
        run_contracts(&source_path, "processData", Language::TypeScript, 100).unwrap();
    let preconditions = &report.preconditions;

    // Restrict to low-confidence entries before checking the constraint text.
    let has_type_precond = preconditions
        .iter()
        .filter(|cond| cond.confidence == Confidence::Low)
        .any(|cond| cond.constraint.contains("number"));
    assert!(
        has_type_precond,
        "TypeScript: Should detect type annotation preconditions, got: {:?}",
        preconditions
    );
}
/// C++ fixture: `processData` throws `std::invalid_argument` when `x < 0`,
/// then sums `data` in a range-for loop and returns the total plus `x`.
const CPP_GUARD_CLAUSES: &str = r#"
#include <stdexcept>
#include <vector>
int processData(int x, const std::vector<int>& data) {
if (x < 0) {
throw std::invalid_argument("x must be non-negative");
}
int result = 0;
for (int v : data) {
result += v;
}
return result + x;
}
"#;
/// The C++ fixture must yield at least one precondition for `processData`.
#[test]
fn test_cpp_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("guards.cpp");
    fs::write(&source_path, CPP_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "processData", Language::Cpp, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "C++: Should detect preconditions from throw guards, got: {:?}",
        preconditions
    );
}
/// C# fixture: `Processor.ProcessData` throws `ArgumentException` when `x < 0`
/// or when `data.Length == 0`, then sums `data` in a `foreach` loop and returns
/// the total plus `x`.
const CSHARP_GUARD_CLAUSES: &str = r#"
public class Processor {
public int ProcessData(int x, int[] data) {
if (x < 0) {
throw new ArgumentException("x must be non-negative");
}
if (data.Length == 0) {
throw new ArgumentException("data must not be empty");
}
int result = 0;
foreach (int v in data) {
result += v;
}
return result + x;
}
}
"#;
/// The C# fixture must yield at least one precondition for `ProcessData`.
#[test]
fn test_csharp_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("Processor.cs");
    fs::write(&source_path, CSHARP_GUARD_CLAUSES).unwrap();

    let report = run_contracts(&source_path, "ProcessData", Language::CSharp, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "C#: Should detect preconditions from throw guards, got: {:?}",
        preconditions
    );
}
/// Runs the analysis on a minimal guard-clause function in each listed
/// language; every case must succeed and report at least one precondition.
#[test]
fn test_multi_language_no_panic() {
    let scratch = TempDir::new().unwrap();

    // Each entry: (file name, language, function to analyze, source text).
    let cases: Vec<(&str, Language, &str, &str)> = vec![
        ("test.py", Language::Python, "foo", "def foo(x):\n    if x < 0:\n        raise ValueError('bad')\n    return x\n"),
        ("test.go", Language::Go, "foo", "package main\nfunc foo(x int) int {\n    if x < 0 {\n        return -1\n    }\n    return x\n}\n"),
        ("test.rs", Language::Rust, "foo", "fn foo(x: i32) -> i32 {\n    if x < 0 {\n        panic!(\"bad\");\n    }\n    x\n}\n"),
        ("test.java", Language::Java, "foo", "class T {\n    int foo(int x) {\n        if (x < 0) {\n            throw new RuntimeException(\"bad\");\n        }\n        return x;\n    }\n}\n"),
        ("test.ts", Language::TypeScript, "foo", "function foo(x: number): number {\n    if (x < 0) {\n        throw new Error('bad');\n    }\n    return x;\n}\n"),
        ("test.c", Language::C, "foo", "int foo(int x) {\n    if (x < 0) {\n        return -1;\n    }\n    return x;\n}\n"),
        ("test.cpp", Language::Cpp, "foo", "int foo(int x) {\n    if (x < 0) {\n        throw -1;\n    }\n    return x;\n}\n"),
        ("test.cs", Language::CSharp, "Foo", "class T {\n    int Foo(int x) {\n        if (x < 0) {\n            throw new Exception(\"bad\");\n        }\n        return x;\n    }\n}\n"),
    ];

    for &(file_name, language, target, code) in &cases {
        let source_path = scratch.path().join(file_name);
        fs::write(&source_path, code).unwrap();

        let outcome = run_contracts(&source_path, target, language, 100);
        let analysis_succeeded = outcome.is_ok();
        assert!(
            analysis_succeeded,
            "{:?}: Should parse without error, got: {:?}",
            language,
            outcome.err()
        );

        let report = outcome.unwrap();
        let found_any = !report.preconditions.is_empty();
        assert!(
            found_any,
            "{:?}: Should detect at least one precondition from guard clause, got: {:?}",
            language,
            report
        );
    }
}
/// A TypeScript arrow function bound to a `const` must be locatable by its
/// binding name; only success of the analysis is checked.
#[test]
fn test_find_ts_arrow_function_contracts() {
    let arrow_src = r#"
const getDuration = (start: Date, end: Date): number => {
if (!start || !end) {
throw new Error("invalid arguments");
}
return end.getTime() - start.getTime();
};
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("arrow.ts");
    fs::write(&source_path, arrow_src).unwrap();

    let outcome = run_contracts(&source_path, "getDuration", Language::TypeScript, 100);
    let lookup_succeeded = outcome.is_ok();
    assert!(
        lookup_succeeded,
        "Should find TS arrow function 'getDuration' for contracts analysis, got: {:?}",
        outcome.err()
    );
}
/// Untyped Python parameters `method` and `url` must each appear as the
/// variable of some precondition, even without guards or annotations.
#[test]
fn test_python_untyped_params_produce_preconditions() {
    let python_src = r#"
def request(method, url, **kwargs):
"""Sends a request."""
with sessions.Session() as session:
return session.request(method=method, url=url, **kwargs)
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("api.py");
    fs::write(&source_path, python_src).unwrap();

    let report = run_contracts(&source_path, "request", Language::Python, 100).unwrap();
    let preconditions = &report.preconditions;

    let has_method_param = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "method");
    let has_url_param = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "url");

    assert!(
        has_method_param,
        "Should detect 'method' parameter as precondition, got: {:?}",
        preconditions
    );
    assert!(
        has_url_param,
        "Should detect 'url' parameter as precondition, got: {:?}",
        preconditions
    );
}
/// Python type annotations alone must produce preconditions for `name` (`str`)
/// and `count` (`int`), plus a postcondition mentioning the `str` return type.
#[test]
fn test_python_typed_params_no_guards() {
    let python_src = r#"
def greet(name: str, count: int = 1) -> str:
return (name + "! ") * count
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("greet.py");
    fs::write(&source_path, python_src).unwrap();

    let report = run_contracts(&source_path, "greet", Language::Python, 100).unwrap();
    let preconditions = &report.preconditions;
    let postconditions = &report.postconditions;

    let has_name_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "name")
        .any(|cond| cond.constraint.contains("str"));
    let has_count_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "count")
        .any(|cond| cond.constraint.contains("int"));
    let has_return_type = postconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("str"));

    assert!(
        has_name_type,
        "Should detect name: str type precondition, got: {:?}",
        preconditions
    );
    assert!(
        has_count_type,
        "Should detect count: int type precondition, got: {:?}",
        preconditions
    );
    assert!(
        has_return_type,
        "Should detect -> str return type postcondition, got: {:?}",
        postconditions
    );
}
/// A Sphinx-style docstring (`:param`, `:rtype:`) must yield preconditions for
/// `method` and `url` and a postcondition mentioning `Response`.
#[test]
fn test_python_docstring_param_extraction() {
    let python_src = r#"
def request(method, url, **kwargs):
"""Constructs and sends a Request.
:param method: method for the new Request object.
:param url: URL for the new Request object.
:param params: (optional) Dictionary to send in the query string.
:return: Response object
:rtype: requests.Response
"""
with sessions.Session() as session:
return session.request(method=method, url=url, **kwargs)
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("api.py");
    fs::write(&source_path, python_src).unwrap();

    let report = run_contracts(&source_path, "request", Language::Python, 100).unwrap();
    let preconditions = &report.preconditions;
    let postconditions = &report.postconditions;

    let has_method = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "method");
    let has_url = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "url");
    assert!(
        has_method,
        "Should extract :param method from docstring, got: {:?}",
        preconditions
    );
    assert!(
        has_url,
        "Should extract :param url from docstring, got: {:?}",
        preconditions
    );

    let has_return = postconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("Response") || text.contains("requests.Response"));
    assert!(
        has_return,
        "Should extract :rtype from docstring as postcondition, got: {:?}",
        postconditions
    );
}
/// The `**kwargs` parameter must appear as a precondition variable, named
/// either `kwargs` or `**kwargs`.
#[test]
fn test_python_kwargs_parameter() {
    let python_src = r#"
def request(method, url, **kwargs):
pass
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("kwargs.py");
    fs::write(&source_path, python_src).unwrap();

    let report = run_contracts(&source_path, "request", Language::Python, 100).unwrap();
    let preconditions = &report.preconditions;

    let has_kwargs = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "kwargs" || *name == "**kwargs");
    assert!(
        has_kwargs,
        "Should detect **kwargs parameter, got: {:?}",
        preconditions
    );
}
/// TypeScript annotations alone must produce preconditions for `x` (`number`)
/// and `data` (`string`), plus a postcondition mentioning the `string` return type.
#[test]
fn test_typescript_typed_params_no_guards() {
    let ts_src = r#"
function processData(x: number, data: string[]): string {
return data.join(", ") + x.toString();
}
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("process.ts");
    fs::write(&source_path, ts_src).unwrap();

    let report =
        run_contracts(&source_path, "processData", Language::TypeScript, 100).unwrap();
    let preconditions = &report.preconditions;
    let postconditions = &report.postconditions;

    let has_x_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "x")
        .any(|cond| cond.constraint.contains("number"));
    let has_data_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "data")
        .any(|cond| cond.constraint.contains("string"));
    assert!(
        has_x_type,
        "Should detect x: number type precondition, got: {:?}",
        preconditions
    );
    assert!(
        has_data_type,
        "Should detect data: string[] type precondition, got: {:?}",
        preconditions
    );

    let has_return = postconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("string"));
    assert!(
        has_return,
        "Should detect return type string postcondition, got: {:?}",
        postconditions
    );
}
/// A JSDoc block (`@param`, `@returns`) on a JavaScript function must yield
/// preconditions for `method` and `url` and a postcondition mentioning
/// `Promise` or `Response`.
#[test]
fn test_jsdoc_param_extraction() {
    let js_src = r#"
/**
* Sends a request to the server.
* @param {string} method - The HTTP method.
* @param {string} url - The request URL.
* @returns {Promise<Response>} The server response.
*/
function request(method, url) {
return fetch(url, { method: method });
}
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("api.js");
    fs::write(&source_path, js_src).unwrap();

    let report = run_contracts(&source_path, "request", Language::JavaScript, 100).unwrap();
    let preconditions = &report.preconditions;
    let postconditions = &report.postconditions;

    let has_method = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "method");
    let has_url = preconditions
        .iter()
        .map(|cond| &cond.variable)
        .any(|name| *name == "url");
    assert!(
        has_method,
        "Should extract @param method from JSDoc, got: {:?}",
        preconditions
    );
    assert!(
        has_url,
        "Should extract @param url from JSDoc, got: {:?}",
        preconditions
    );

    let has_return = postconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("Promise") || text.contains("Response"));
    assert!(
        has_return,
        "Should extract @returns from JSDoc as postcondition, got: {:?}",
        postconditions
    );
}
/// Rust fixture: `transfer` returns `Err(..)` early when `amount <= 0.0` or
/// `amount > 10000.0`, and `Ok(())` otherwise.
const RUST_RETURN_ERR_GUARDS: &str = r#"
fn transfer(amount: f64) -> Result<(), String> {
if amount <= 0.0 {
return Err("Amount must be positive".to_string());
}
if amount > 10000.0 {
return Err("Amount exceeds limit".to_string());
}
Ok(())
}
"#;
/// `return Err(...)` guards in Rust must be treated as guard clauses: the
/// negated conditions on `amount` appear as preconditions, and every
/// precondition whose constraint contains `>` or `<=` has `Confidence::High`.
#[test]
fn test_rust_return_err_guard_clause_detection() {
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("transfer.rs");
    fs::write(&source_path, RUST_RETURN_ERR_GUARDS).unwrap();

    let report = run_contracts(&source_path, "transfer", Language::Rust, 100).unwrap();
    let preconditions = &report.preconditions;

    let found_any = !preconditions.is_empty();
    assert!(
        found_any,
        "Rust: Should detect preconditions from `return Err(...)` guard clauses, got: {:?}",
        preconditions
    );

    let has_amount_pos = preconditions
        .iter()
        .filter(|cond| cond.variable.contains("amount"))
        .any(|cond| cond.constraint.contains(">"));
    assert!(
        has_amount_pos,
        "Rust: Should detect `amount > 0.0` precondition (negation of `amount <= 0.0`), got: {:?}",
        preconditions
    );

    let has_amount_limit = preconditions
        .iter()
        .filter(|cond| cond.variable.contains("amount"))
        .any(|cond| cond.constraint.contains("<="));
    assert!(
        has_amount_limit,
        "Rust: Should detect `amount <= 10000.0` precondition (negation of `amount > 10000.0`), got: {:?}",
        preconditions
    );

    // Confidence is only checked for comparison-style constraints.
    let comparison_conditions = preconditions
        .iter()
        .filter(|cond| cond.constraint.contains(">") || cond.constraint.contains("<="));
    for cond in comparison_conditions {
        assert_eq!(
            cond.confidence,
            Confidence::High,
            "Rust `return Err(...)` guard should produce High confidence, got: {:?}",
            cond
        );
    }
}
/// Go parameter and return types alone must produce preconditions for `x` and
/// `y` (`int`) and a postcondition mentioning `int`.
#[test]
fn test_go_typed_params_no_guards() {
    let go_src = r#"
package main
func add(x int, y int) int {
return x + y
}
"#;
    let workspace = TempDir::new().unwrap();
    let source_path = workspace.path().join("add.go");
    fs::write(&source_path, go_src).unwrap();

    let report = run_contracts(&source_path, "add", Language::Go, 100).unwrap();
    let preconditions = &report.preconditions;
    let postconditions = &report.postconditions;

    let has_x_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "x")
        .any(|cond| cond.constraint.contains("int"));
    let has_y_type = preconditions
        .iter()
        .filter(|cond| cond.variable == "y")
        .any(|cond| cond.constraint.contains("int"));
    assert!(
        has_x_type,
        "Should detect x: int parameter type, got: {:?}",
        preconditions
    );
    assert!(
        has_y_type,
        "Should detect y: int parameter type, got: {:?}",
        preconditions
    );

    let has_return = postconditions
        .iter()
        .map(|cond| &cond.constraint)
        .any(|text| text.contains("int"));
    assert!(
        has_return,
        "Should detect int return type postcondition, got: {:?}",
        postconditions
    );
}
}