use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use clap::Args;
use tree_sitter::{Node, Parser};
use tldr_core::callgraph::{
build_project_call_graph_v2, extract_calls_for_language, BuildConfig, CallSite,
};
use tldr_core::types::Language;
use crate::output::OutputFormat as GlobalOutputFormat;
use super::error::{PatternsError, PatternsResult};
use super::types::{
OutputFormat, TemporalConstraint, TemporalExample, TemporalMetadata, TemporalReport, Trigram,
};
use super::validation::{
check_directory_file_count, read_file_safe, validate_directory_path, validate_file_path,
validate_file_path_in_project, MAX_TRIGRAMS,
};
#[derive(Debug, Args)]
pub struct TemporalArgs {
pub path: PathBuf,
#[arg(long, default_value = "2")]
pub min_support: u32,
#[arg(long, default_value = "0.5")]
pub min_confidence: f64,
#[arg(long)]
pub query: Option<String>,
#[arg(long = "source-lang", default_value = "python")]
pub source_lang: String,
#[arg(long, default_value = "1000")]
pub max_files: u32,
#[arg(long)]
pub include_trigrams: bool,
#[arg(long, default_value = "3")]
pub include_examples: u32,
#[arg(
long = "output",
short = 'o',
hide = true,
default_value = "json",
value_enum
)]
pub output_format: OutputFormat,
#[arg(long, default_value = "60")]
pub timeout: u64,
#[arg(long)]
pub project_root: Option<PathBuf>,
#[arg(long, short = 'l')]
pub lang: Option<Language>,
}
impl TemporalArgs {
pub fn run(&self, global_format: GlobalOutputFormat) -> anyhow::Result<()> {
run(self.clone(), global_format)
}
}
impl Clone for TemporalArgs {
fn clone(&self) -> Self {
Self {
path: self.path.clone(),
min_support: self.min_support,
min_confidence: self.min_confidence,
query: self.query.clone(),
source_lang: self.source_lang.clone(),
max_files: self.max_files,
include_trigrams: self.include_trigrams,
include_examples: self.include_examples,
output_format: self.output_format,
timeout: self.timeout,
project_root: self.project_root.clone(),
lang: self.lang,
}
}
}
#[derive(Debug, Default)]
pub struct SequenceExtractor {
current_function: String,
sequences: HashMap<String, Vec<String>>,
var_assignments: HashMap<String, String>,
current_line: u32,
}
impl SequenceExtractor {
pub fn new() -> Self {
Self::default()
}
pub fn extract_function(&mut self, func_node: Node, source: &[u8]) {
let func_name = self.get_function_name(func_node, source);
if func_name.is_empty() {
return;
}
self.current_function = func_name;
self.var_assignments.clear();
self.extract_calls_recursive(func_node, source, 0);
}
fn extract_calls_recursive(&mut self, node: Node, source: &[u8], depth: usize) {
if depth > 100 {
return;
}
self.current_line = node.start_position().row as u32 + 1;
match node.kind() {
"assignment" => {
self.handle_assignment(node, source);
}
"call" => {
self.handle_call(node, source);
}
"with_statement" => {
self.handle_with_statement(node, source);
}
_ => {}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.extract_calls_recursive(child, source, depth + 1);
}
}
fn handle_assignment(&mut self, node: Node, source: &[u8]) {
let var_name = if let Some(left) = node.child_by_field_name("left") {
self.node_text(left, source).to_string()
} else {
let mut var = String::new();
for child in node.children(&mut node.walk()) {
if child.kind() == "identifier" {
var = self.node_text(child, source).to_string();
break;
}
}
var
};
if var_name.is_empty() {
return;
}
if let Some(right) = node.child_by_field_name("right") {
if right.kind() == "call" {
let call_name = self.extract_call_name(right, source);
if !call_name.is_empty() {
self.var_assignments
.insert(var_name.clone(), call_name.clone());
let key = format!("{}:{}", self.current_function, var_name);
self.sequences.entry(key).or_default().push(call_name);
}
}
}
}
fn handle_call(&mut self, node: Node, source: &[u8]) {
if let Some(func) = node.child_by_field_name("function") {
if func.kind() == "attribute" {
if let Some(obj) = func.child_by_field_name("object") {
let obj_name = self.node_text(obj, source).to_string();
if let Some(method) = func.child_by_field_name("attribute") {
let method_name = self.node_text(method, source).to_string();
let key = format!("{}:{}", self.current_function, obj_name);
self.sequences.entry(key).or_default().push(method_name);
}
}
}
}
}
fn handle_with_statement(&mut self, node: Node, source: &[u8]) {
for child in node.children(&mut node.walk()) {
if child.kind() == "with_clause" {
for item in child.children(&mut child.walk()) {
if item.kind() == "with_item" {
let mut call_name = String::new();
let mut var_name = String::new();
for part in item.children(&mut item.walk()) {
if part.kind() == "call" {
call_name = self.extract_call_name(part, source);
} else if part.kind() == "as_pattern" || part.kind() == "identifier" {
if part.kind() == "identifier" {
var_name = self.node_text(part, source).to_string();
} else {
for as_child in part.children(&mut part.walk()) {
if as_child.kind() == "identifier" {
var_name = self.node_text(as_child, source).to_string();
break;
}
}
}
}
}
if !call_name.is_empty() && !var_name.is_empty() {
let key = format!("{}:{}", self.current_function, var_name);
self.sequences
.entry(key.clone())
.or_default()
.push(call_name);
self.sequences
.entry(key)
.or_default()
.push("__exit__".to_string());
}
}
}
}
}
}
fn extract_call_name(&self, node: Node, source: &[u8]) -> String {
if let Some(func) = node.child_by_field_name("function") {
return self.extract_name_from_expr(func, source);
}
for child in node.children(&mut node.walk()) {
match child.kind() {
"identifier" => return self.node_text(child, source).to_string(),
"attribute" => return self.extract_name_from_expr(child, source),
_ => continue,
}
}
String::new()
}
fn extract_name_from_expr(&self, node: Node, source: &[u8]) -> String {
match node.kind() {
"identifier" => self.node_text(node, source).to_string(),
"attribute" => {
if let Some(attr) = node.child_by_field_name("attribute") {
self.node_text(attr, source).to_string()
} else {
String::new()
}
}
_ => self.node_text(node, source).to_string(),
}
}
fn get_function_name(&self, node: Node, source: &[u8]) -> String {
for child in node.children(&mut node.walk()) {
if child.kind() == "identifier" {
return self.node_text(child, source).to_string();
}
}
String::new()
}
fn node_text<'a>(&self, node: Node, source: &'a [u8]) -> &'a str {
node.utf8_text(source).unwrap_or("")
}
pub fn get_sequences(&self) -> &HashMap<String, Vec<String>> {
&self.sequences
}
}
pub fn extract_sequences(source: &str) -> HashMap<String, Vec<String>> {
let mut extractor = SequenceExtractor::new();
let mut parser = match get_python_parser() {
Ok(p) => p,
Err(_) => return HashMap::new(),
};
let tree = match parser.parse(source, None) {
Some(t) => t,
None => return HashMap::new(),
};
let root = tree.root_node();
let source_bytes = source.as_bytes();
extract_functions_recursive(root, source_bytes, &mut extractor);
extractor.sequences
}
fn extract_functions_recursive(node: Node, source: &[u8], extractor: &mut SequenceExtractor) {
match node.kind() {
"function_definition" | "async_function_definition" => {
extractor.extract_function(node, source);
}
_ => {}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
extract_functions_recursive(child, source, extractor);
}
}
fn sequences_from_callsite_map(
file_key: &str,
calls_by_func: &HashMap<String, Vec<CallSite>>,
) -> HashMap<String, Vec<String>> {
let mut out: HashMap<String, Vec<String>> = HashMap::new();
for (caller, sites) in calls_by_func {
if sites.is_empty() {
continue;
}
let mut ordered = sites.clone();
ordered.sort_by_key(|s| s.line.unwrap_or(u32::MAX));
let names: Vec<String> = ordered
.into_iter()
.map(|s| s.target)
.filter(|t| !t.is_empty())
.collect();
if names.is_empty() {
continue;
}
let key = format!("{}::{}", file_key, caller);
out.insert(key, names);
}
out
}
struct FileSequences {
sequences: HashMap<String, Vec<String>>,
first_line: HashMap<(String, String, String), u32>,
}
fn extract_sequences_for_file(
path: &Path,
source: &str,
language: Language,
) -> PatternsResult<FileSequences> {
let file_key = path.to_string_lossy().to_string();
let mut sequences: HashMap<String, Vec<String>> = HashMap::new();
let mut first_line: HashMap<(String, String, String), u32> = HashMap::new();
if language == Language::Python {
let legacy = extract_sequences(source);
for (k, v) in legacy {
sequences.entry(k).or_default().extend(v);
}
}
let lang_str = language.as_str();
let calls_by_func = match extract_calls_for_language(lang_str, path, source) {
Ok(map) => map,
Err(_) => {
return Ok(FileSequences {
sequences,
first_line,
});
}
};
let scoped = sequences_from_callsite_map(&file_key, &calls_by_func);
for (k, v) in scoped {
sequences.entry(k).or_default().extend(v);
}
for (caller, sites) in &calls_by_func {
let mut ordered = sites.clone();
ordered.sort_by_key(|s| s.line.unwrap_or(u32::MAX));
for pair in ordered.windows(2) {
let before = pair[0].target.clone();
let after = pair[1].target.clone();
if before.is_empty() || after.is_empty() || before == after {
continue;
}
let line = pair[1].line.unwrap_or(1);
first_line
.entry((caller.clone(), before, after))
.or_insert(line);
}
}
Ok(FileSequences {
sequences,
first_line,
})
}
fn resolve_directory_language(path: &Path, args: &TemporalArgs) -> Option<Language> {
if let Some(lang) = args.lang {
return Some(lang);
}
Language::from_directory(path)
}
fn per_caller_first_line(
calls_by_func: &HashMap<String, Vec<CallSite>>,
) -> HashMap<(String, String, String), u32> {
let mut first_line: HashMap<(String, String, String), u32> = HashMap::new();
for (caller, sites) in calls_by_func {
let mut ordered = sites.clone();
ordered.sort_by_key(|s| s.line.unwrap_or(u32::MAX));
for pair in ordered.windows(2) {
let before = pair[0].target.clone();
let after = pair[1].target.clone();
if before.is_empty() || after.is_empty() || before == after {
continue;
}
let line = pair[1].line.unwrap_or(1);
first_line
.entry((caller.clone(), before, after))
.or_insert(line);
}
}
first_line
}
#[allow(clippy::too_many_arguments)]
fn aggregate_file_sequences(
file_sequences: &HashMap<String, Vec<String>>,
file_path_str: &str,
first_line: &HashMap<(String, String, String), u32>,
all_sequences: &mut HashMap<String, Vec<String>>,
bigram_counts: &mut HashMap<(String, String), u32>,
before_counts: &mut HashMap<String, u32>,
all_examples: &mut HashMap<(String, String), Vec<TemporalExample>>,
args: &TemporalArgs,
) {
for (key, calls) in file_sequences {
all_sequences
.entry(key.clone())
.or_default()
.extend(calls.clone());
let caller_for_lookup = key
.rsplit_once("::")
.map(|(_, c)| c.to_string())
.unwrap_or_default();
for i in 0..calls.len().saturating_sub(1) {
let before = &calls[i];
let after = &calls[i + 1];
if before == after {
continue;
}
let pair = (before.clone(), after.clone());
*bigram_counts.entry(pair.clone()).or_default() += 1;
*before_counts.entry(before.clone()).or_default() += 1;
let examples = all_examples.entry(pair).or_default();
if examples.len() < args.include_examples as usize {
let line = first_line
.get(&(caller_for_lookup.clone(), before.clone(), after.clone()))
.copied()
.unwrap_or(1);
examples.push(TemporalExample {
file: file_path_str.to_string(),
line,
});
}
}
}
}
#[derive(Debug, Default)]
pub struct BigramCounter {
pub counts: HashMap<(String, String), u32>,
pub before_counts: HashMap<String, u32>,
pub examples: HashMap<(String, String), Vec<TemporalExample>>,
}
impl BigramCounter {
pub fn new() -> Self {
Self::default()
}
pub fn add_sequences(&mut self, sequences: &HashMap<String, Vec<String>>, file: &str) {
for calls in sequences.values() {
let line = 1u32;
for i in 0..calls.len().saturating_sub(1) {
let before = &calls[i];
let after = &calls[i + 1];
if before == after {
continue;
}
let pair = (before.clone(), after.clone());
*self.counts.entry(pair.clone()).or_default() += 1;
*self.before_counts.entry(before.clone()).or_default() += 1;
self.examples
.entry(pair)
.or_default()
.push(TemporalExample {
file: file.to_string(),
line,
});
}
}
}
}
pub fn mine_bigrams(
sequences: &HashMap<String, Vec<String>>,
file: &str,
args: &TemporalArgs,
) -> (BigramCounter, Vec<TemporalConstraint>) {
let mut counter = BigramCounter::new();
counter.add_sequences(sequences, file);
let mut constraints = Vec::new();
for ((before, after), count) in &counter.counts {
if *count < args.min_support {
continue;
}
let before_total = *counter.before_counts.get(before).unwrap_or(&1);
let confidence = (*count as f64) / (before_total as f64);
if confidence < args.min_confidence {
continue;
}
let examples = counter
.examples
.get(&(before.clone(), after.clone()))
.map(|ex| {
ex.iter()
.take(args.include_examples as usize)
.cloned()
.collect()
})
.unwrap_or_default();
constraints.push(TemporalConstraint {
before: before.clone(),
after: after.clone(),
support: *count,
confidence,
examples,
});
}
constraints.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| b.support.cmp(&a.support))
});
(counter, constraints)
}
pub fn mine_trigrams(
sequences: &HashMap<String, Vec<String>>,
args: &TemporalArgs,
) -> Vec<Trigram> {
let mut trigram_counts: HashMap<(String, String, String), u32> = HashMap::new();
let mut bigram_follows: HashMap<(String, String), u32> = HashMap::new();
for calls in sequences.values() {
for i in 0..calls.len().saturating_sub(2) {
let a = &calls[i];
let b = &calls[i + 1];
let c = &calls[i + 2];
if a == b || b == c {
continue;
}
*trigram_counts
.entry((a.clone(), b.clone(), c.clone()))
.or_default() += 1;
if a != b {
*bigram_follows.entry((a.clone(), b.clone())).or_default() += 1;
}
}
}
let mut heap: BinaryHeap<Reverse<(u32, String, String, String)>> = BinaryHeap::new();
for ((a, b, c), count) in &trigram_counts {
if *count < args.min_support {
continue;
}
let bigram_total = *bigram_follows.get(&(a.clone(), b.clone())).unwrap_or(&1);
let confidence = (*count as f64) / (bigram_total as f64);
if confidence < args.min_confidence {
continue;
}
if heap.len() < MAX_TRIGRAMS {
heap.push(Reverse((*count, a.clone(), b.clone(), c.clone())));
} else if let Some(&Reverse((min_support, _, _, _))) = heap.peek() {
if *count > min_support {
heap.pop();
heap.push(Reverse((*count, a.clone(), b.clone(), c.clone())));
}
}
}
let mut trigrams: Vec<Trigram> = heap
.into_iter()
.map(|Reverse((support, a, b, c))| {
let bigram_total = *bigram_follows.get(&(a.clone(), b.clone())).unwrap_or(&1);
let confidence = (support as f64) / (bigram_total as f64);
Trigram {
sequence: [a, b, c],
support,
confidence,
}
})
.collect();
trigrams.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| b.support.cmp(&a.support))
});
trigrams
}
pub fn filter_by_query(
constraints: Vec<TemporalConstraint>,
query: &str,
) -> Vec<TemporalConstraint> {
constraints
.into_iter()
.filter(|c| c.before.contains(query) || c.after.contains(query))
.collect()
}
pub fn filter_trigrams_by_query(trigrams: Vec<Trigram>, query: &str) -> Vec<Trigram> {
trigrams
.into_iter()
.filter(|t| t.sequence.iter().any(|s| s.contains(query)))
.collect()
}
fn get_python_parser() -> PatternsResult<Parser> {
let mut parser = Parser::new();
let language = tree_sitter_python::LANGUAGE;
parser.set_language(&language.into()).map_err(|e| {
PatternsError::parse_error(PathBuf::new(), format!("Failed to set language: {}", e))
})?;
Ok(parser)
}
type TemporalFileAnalysis = (HashMap<String, Vec<String>>, Vec<TemporalConstraint>);
fn analyze_temporal_file(path: &Path, args: &TemporalArgs) -> PatternsResult<TemporalFileAnalysis> {
let canonical = if let Some(ref root) = args.project_root {
validate_file_path_in_project(path, root)?
} else {
validate_file_path(path)?
};
let source = read_file_safe(&canonical)?;
let file_path_str = canonical.to_string_lossy().to_string();
let language = args
.lang
.or_else(|| Language::from_path(&canonical))
.unwrap_or(Language::Python);
let file_seqs = extract_sequences_for_file(&canonical, &source, language)?;
let sequences = file_seqs.sequences;
let (_, constraints) = mine_bigrams(&sequences, &file_path_str, args);
Ok((sequences, constraints))
}
fn analyze_temporal_directory(
path: &Path,
args: &TemporalArgs,
start_time: Instant,
) -> PatternsResult<TemporalReport> {
let canonical = validate_directory_path(path)?;
let timeout = Duration::from_secs(args.timeout);
let mut all_sequences: HashMap<String, Vec<String>> = HashMap::new();
let mut all_examples: HashMap<(String, String), Vec<TemporalExample>> = HashMap::new();
let mut bigram_counts: HashMap<(String, String), u32> = HashMap::new();
let mut before_counts: HashMap<String, u32> = HashMap::new();
let mut files_analyzed = 0u32;
let resolved_lang = resolve_directory_language(&canonical, args);
let use_project_builder = matches!(resolved_lang, Some(Language::Ocaml));
if use_project_builder {
let lang = resolved_lang.expect("checked above");
let mut config = BuildConfig {
language: lang.as_str().to_string(),
respect_ignore: false,
..Default::default()
};
config.use_type_resolution = false;
match build_project_call_graph_v2(&canonical, config) {
Ok(ir) => {
for (file_path, file_ir) in &ir.files {
if start_time.elapsed() > timeout {
break;
}
files_analyzed += 1;
if files_analyzed > args.max_files {
break;
}
check_directory_file_count(files_analyzed as usize)?;
let abs_path = if file_path.is_absolute() {
file_path.clone()
} else {
canonical.join(file_path)
};
let file_key = abs_path.to_string_lossy().to_string();
let scoped = sequences_from_callsite_map(&file_key, &file_ir.calls);
aggregate_file_sequences(
&scoped,
&file_key,
&per_caller_first_line(&file_ir.calls),
&mut all_sequences,
&mut bigram_counts,
&mut before_counts,
&mut all_examples,
args,
);
}
}
Err(_) => {
}
}
} else {
for entry in tldr_core::walker::walk_project(&canonical) {
if start_time.elapsed() > timeout {
break;
}
let entry_path = entry.path();
let entry_lang = match Language::from_path(entry_path) {
Some(lang) => lang,
None => continue,
};
if let Some(forced) = args.lang {
if forced != entry_lang {
continue;
}
} else if let Some(project_lang) = resolved_lang {
if project_lang != entry_lang {
continue;
}
}
files_analyzed += 1;
if files_analyzed > args.max_files {
break;
}
check_directory_file_count(files_analyzed as usize)?;
let file_path_str = entry_path.to_string_lossy().to_string();
if let Ok(source) = read_file_safe(entry_path) {
let file_seqs = match extract_sequences_for_file(entry_path, &source, entry_lang) {
Ok(s) => s,
Err(_) => continue,
};
aggregate_file_sequences(
&file_seqs.sequences,
&file_path_str,
&file_seqs.first_line,
&mut all_sequences,
&mut bigram_counts,
&mut before_counts,
&mut all_examples,
args,
);
}
}
}
let mut constraints = Vec::new();
for ((before, after), count) in &bigram_counts {
if *count < args.min_support {
continue;
}
let before_total = *before_counts.get(before).unwrap_or(&1);
let confidence = (*count as f64) / (before_total as f64);
if confidence < args.min_confidence {
continue;
}
let examples = all_examples
.get(&(before.clone(), after.clone()))
.cloned()
.unwrap_or_default();
constraints.push(TemporalConstraint {
before: before.clone(),
after: after.clone(),
support: *count,
confidence,
examples,
});
}
constraints.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| b.support.cmp(&a.support))
});
if let Some(ref query) = args.query {
constraints = filter_by_query(constraints, query);
}
let trigrams = if args.include_trigrams {
let mut trigrams = mine_trigrams(&all_sequences, args);
if let Some(ref query) = args.query {
trigrams = filter_trigrams_by_query(trigrams, query);
}
trigrams
} else {
Vec::new()
};
let sequences_extracted: u32 = all_sequences.values().map(|v| v.len() as u32).sum();
Ok(TemporalReport {
constraints,
trigrams,
metadata: TemporalMetadata {
files_analyzed,
sequences_extracted,
min_support: args.min_support,
min_confidence: args.min_confidence,
},
})
}
pub fn format_temporal_text(report: &TemporalReport) -> String {
let mut lines = Vec::new();
lines.push("Temporal Constraints".to_string());
lines.push("=".repeat(40));
lines.push(String::new());
if report.constraints.is_empty() {
lines.push("No constraints found matching criteria.".to_string());
} else {
lines.push(format!("Found {} constraints:", report.constraints.len()));
lines.push(String::new());
for constraint in &report.constraints {
lines.push(format!(" {} -> {}", constraint.before, constraint.after));
lines.push(format!(
" support: {}, confidence: {:.2}",
constraint.support, constraint.confidence
));
if !constraint.examples.is_empty() {
lines.push(" examples:".to_string());
for example in &constraint.examples {
lines.push(format!(" - {}:{}", example.file, example.line));
}
}
lines.push(String::new());
}
}
if !report.trigrams.is_empty() {
lines.push(String::new());
lines.push("Trigrams".to_string());
lines.push("-".repeat(40));
lines.push(String::new());
for trigram in &report.trigrams {
lines.push(format!(
" {} -> {} -> {}",
trigram.sequence[0], trigram.sequence[1], trigram.sequence[2]
));
lines.push(format!(
" support: {}, confidence: {:.2}",
trigram.support, trigram.confidence
));
lines.push(String::new());
}
}
lines.push(String::new());
lines.push("Metadata".to_string());
lines.push("-".repeat(40));
lines.push(format!(
" Files analyzed: {}",
report.metadata.files_analyzed
));
lines.push(format!(
" Sequences extracted: {}",
report.metadata.sequences_extracted
));
lines.push(format!(" Min support: {}", report.metadata.min_support));
lines.push(format!(
" Min confidence: {:.2}",
report.metadata.min_confidence
));
lines.join("\n")
}
pub fn run(args: TemporalArgs, global_format: GlobalOutputFormat) -> anyhow::Result<()> {
let start_time = Instant::now();
let path = &args.path;
let source_lang_norm = args.source_lang.to_lowercase();
if source_lang_norm != "auto" && source_lang_norm.parse::<Language>().is_err() {
return Err(PatternsError::UnsupportedLanguage {
language: args.source_lang.clone(),
}
.into());
}
let report = if path.is_dir() {
analyze_temporal_directory(path, &args, start_time)?
} else {
let (sequences, mut constraints) = analyze_temporal_file(path, &args)?;
if let Some(ref query) = args.query {
constraints = filter_by_query(constraints, query);
}
let trigrams = if args.include_trigrams {
let mut trigrams = mine_trigrams(&sequences, &args);
if let Some(ref query) = args.query {
trigrams = filter_trigrams_by_query(trigrams, query);
}
trigrams
} else {
Vec::new()
};
let sequences_extracted: u32 = sequences.values().map(|v| v.len() as u32).sum();
TemporalReport {
constraints,
trigrams,
metadata: TemporalMetadata {
files_analyzed: 1,
sequences_extracted,
min_support: args.min_support,
min_confidence: args.min_confidence,
},
}
};
let use_text = matches!(global_format, GlobalOutputFormat::Text)
|| matches!(args.output_format, OutputFormat::Text);
if report.constraints.is_empty() && report.trigrams.is_empty() {
if use_text {
println!("{}", format_temporal_text(&report));
} else {
let json = serde_json::to_string_pretty(&report)?;
println!("{}", json);
}
std::process::exit(2);
}
if use_text {
println!("{}", format_temporal_text(&report));
} else {
let json = serde_json::to_string_pretty(&report)?;
println!("{}", json);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_sequences_simple() {
let code = r#"
def read_config(path):
f = open(path)
content = f.read()
f.close()
return content
"#;
let sequences = extract_sequences(code);
let has_f_sequence = sequences.keys().any(|k| k.contains(":f"));
assert!(has_f_sequence, "Should extract sequence for variable f");
}
#[test]
fn test_bigram_counter() {
let mut sequences = HashMap::new();
sequences.insert(
"func:f".to_string(),
vec!["open".to_string(), "read".to_string(), "close".to_string()],
);
let mut counter = BigramCounter::new();
counter.add_sequences(&sequences, "test.py");
assert_eq!(
counter
.counts
.get(&("open".to_string(), "read".to_string())),
Some(&1)
);
assert_eq!(
counter
.counts
.get(&("read".to_string(), "close".to_string())),
Some(&1)
);
}
#[test]
fn test_mine_bigrams_filter() {
let mut sequences = HashMap::new();
sequences.insert(
"func:f".to_string(),
vec!["open".to_string(), "read".to_string(), "close".to_string()],
);
let args = TemporalArgs {
path: PathBuf::new(),
min_support: 1,
min_confidence: 0.0,
query: None,
source_lang: "python".to_string(),
max_files: 1000,
include_trigrams: false,
include_examples: 3,
output_format: OutputFormat::Json,
timeout: 60,
project_root: None,
lang: None,
};
let (_, constraints) = mine_bigrams(&sequences, "test.py", &args);
assert!(!constraints.is_empty(), "Should find bigram constraints");
}
#[test]
fn test_filter_by_query() {
let constraints = vec![
TemporalConstraint {
before: "open".to_string(),
after: "read".to_string(),
support: 5,
confidence: 0.8,
examples: vec![],
},
TemporalConstraint {
before: "acquire".to_string(),
after: "release".to_string(),
support: 3,
confidence: 0.9,
examples: vec![],
},
];
let filtered = filter_by_query(constraints, "open");
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].before, "open");
}
#[test]
fn test_mine_trigrams_limit() {
let mut sequences = HashMap::new();
let calls: Vec<String> = (0..100).map(|i| format!("method{}", i)).collect();
sequences.insert("func:obj".to_string(), calls);
let args = TemporalArgs {
path: PathBuf::new(),
min_support: 1,
min_confidence: 0.0,
query: None,
source_lang: "python".to_string(),
max_files: 1000,
include_trigrams: true,
include_examples: 3,
output_format: OutputFormat::Json,
timeout: 60,
project_root: None,
lang: None,
};
let trigrams = mine_trigrams(&sequences, &args);
assert!(trigrams.len() <= MAX_TRIGRAMS);
}
#[test]
fn test_format_temporal_text() {
let report = TemporalReport {
constraints: vec![TemporalConstraint {
before: "open".to_string(),
after: "close".to_string(),
support: 10,
confidence: 0.95,
examples: vec![TemporalExample {
file: "test.py".to_string(),
line: 5,
}],
}],
trigrams: vec![],
metadata: TemporalMetadata {
files_analyzed: 1,
sequences_extracted: 5,
min_support: 2,
min_confidence: 0.5,
},
};
let text = format_temporal_text(&report);
assert!(text.contains("open -> close"));
assert!(text.contains("support: 10"));
assert!(text.contains("confidence: 0.95"));
}
#[test]
fn test_temporal_args_lang_flag() {
use tldr_core::types::Language;
let args = TemporalArgs {
path: PathBuf::from("src/"),
min_support: 2,
min_confidence: 0.5,
query: None,
source_lang: "python".to_string(),
max_files: 1000,
include_trigrams: false,
include_examples: 3,
output_format: OutputFormat::Json,
timeout: 60,
project_root: None,
lang: Some(Language::Python),
};
assert_eq!(args.lang, Some(Language::Python));
let args_auto = TemporalArgs {
path: PathBuf::from("src/"),
min_support: 2,
min_confidence: 0.5,
query: None,
source_lang: "python".to_string(),
max_files: 1000,
include_trigrams: false,
include_examples: 3,
output_format: OutputFormat::Json,
timeout: 60,
project_root: None,
lang: None,
};
assert_eq!(args_auto.lang, None);
}
use std::io::Write;
fn extract_for_lang(extension: &str, source: &str, language: Language) -> Vec<Vec<String>> {
let mut tmp = tempfile::Builder::new()
.suffix(&format!(".{}", extension))
.tempfile()
.expect("tempfile");
tmp.write_all(source.as_bytes()).expect("write source");
let path = tmp.path().to_path_buf();
let file_seqs = extract_sequences_for_file(&path, source, language)
.expect("extract_sequences_for_file");
file_seqs.sequences.into_values().collect()
}
fn assert_helper_then_b_util(seqs: &[Vec<String>], language_label: &str) {
let found = seqs
.iter()
.any(|seq| seq.windows(2).any(|w| w[0] == "helper" && w[1] == "b_util"));
assert!(
found,
"[{}] expected `helper -> b_util` bigram, got: {:?}",
language_label, seqs
);
}
#[test]
fn test_extract_sequences_typescript() {
let source = "\
function helper(): number { return 1; }
function b_util(): number { return 2; }
function main(): void {
helper();
b_util();
}
";
let seqs = extract_for_lang("ts", source, Language::TypeScript);
assert_helper_then_b_util(&seqs, "typescript");
}
#[test]
fn test_extract_sequences_java() {
let source = "\
class Main {
public static int helper() { return 1; }
public static int bUtil() { return 2; }
public static void main(String[] args) {
helper();
bUtil();
}
}
";
let mut tmp = tempfile::Builder::new().suffix(".java").tempfile().unwrap();
tmp.write_all(source.as_bytes()).unwrap();
let path = tmp.path().to_path_buf();
let file_seqs = extract_sequences_for_file(&path, source, Language::Java).expect("extract");
let seqs: Vec<Vec<String>> = file_seqs.sequences.into_values().collect();
let found = seqs
.iter()
.any(|seq| seq.windows(2).any(|w| w[0] == "helper" && w[1] == "bUtil"));
assert!(
found,
"[java] expected `helper -> bUtil` bigram, got: {:?}",
seqs
);
}
#[test]
fn test_extract_sequences_go() {
let source = "\
package main
func helper() int { return 1 }
func b_util() int { return 2 }
func main() {
helper()
b_util()
}
";
let seqs = extract_for_lang("go", source, Language::Go);
assert_helper_then_b_util(&seqs, "go");
}
#[test]
fn test_extract_sequences_rust() {
let source = "\
fn helper() -> i32 { 1 }
fn b_util() -> i32 { 2 }
fn main() {
let _ = helper();
let _ = b_util();
}
";
let seqs = extract_for_lang("rs", source, Language::Rust);
assert_helper_then_b_util(&seqs, "rust");
}
#[test]
fn test_extract_sequences_python_via_generalized_path() {
let source = "\
def helper():
return 1
def b_util():
return 2
def main():
helper()
b_util()
";
let seqs = extract_for_lang("py", source, Language::Python);
assert_helper_then_b_util(&seqs, "python");
}
#[test]
fn test_extract_sequences_python_legacy_receiver_aware() {
let source = "\
def read_config(path):
f = open(path)
content = f.read()
f.close()
return content
";
let mut tmp = tempfile::Builder::new().suffix(".py").tempfile().unwrap();
tmp.write_all(source.as_bytes()).unwrap();
let path = tmp.path().to_path_buf();
let file_seqs =
extract_sequences_for_file(&path, source, Language::Python).expect("extract");
let has_open_read = file_seqs
.sequences
.values()
.any(|seq| seq.windows(2).any(|w| w[0] == "open" && w[1] == "read"));
assert!(
has_open_read,
"python legacy: expected `open -> read` bigram for receiver f, got: {:?}",
file_seqs.sequences
);
}
#[test]
fn test_sequences_from_callsite_map_orders_by_line() {
use tldr_core::callgraph::CallSite;
let mut calls: HashMap<String, Vec<CallSite>> = HashMap::new();
calls.insert(
"main".to_string(),
vec![
CallSite::direct("main".to_string(), "b_util".to_string(), Some(8)),
CallSite::direct("main".to_string(), "helper".to_string(), Some(7)),
],
);
let out = sequences_from_callsite_map("/tmp/foo", &calls);
let main_seq = out.get("/tmp/foo::main").expect("main sequence");
assert_eq!(
main_seq,
&vec!["helper".to_string(), "b_util".to_string()],
"calls must be ordered by line ascending (sequences_from_callsite_map)"
);
}
}