use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use walkdir::WalkDir;
use crate::ast::extract::extract_file;
use crate::types::{FunctionInfo, Language};
use crate::TldrResult;
/// Weight of the signature component (parameter count plus presence of a return type).
const SIGNATURE_WEIGHT: f64 = 0.3;
/// Weight of the complexity component.
const COMPLEXITY_WEIGHT: f64 = 0.2;
/// Weight of the call-pattern component (Jaccard similarity of callee sets).
const CALL_PATTERN_WEIGHT: f64 = 0.3;
/// Weight of the lines-of-code component.
const LOC_WEIGHT: f64 = 0.2;
// Compile-time guard: the weights must sum to 1 so that the weighted score stays
// within [0, 1] whenever every component score does.
const _: () = {
let sum = SIGNATURE_WEIGHT + COMPLEXITY_WEIGHT + CALL_PATTERN_WEIGHT + LOC_WEIGHT;
assert!((sum - 1.0).abs() < 0.0001);
};
/// A component on which two functions were judged strongly similar.
///
/// `calculate_similarity` records a reason when the matching component score
/// exceeds 0.8. Serialized in `snake_case` (e.g. `same_signature`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityReason {
/// Parameter count and return-type presence are (nearly) the same.
SameSignature,
/// Complexity values are close.
SimilarComplexity,
/// The functions call largely the same set of callees.
SimilarCallPattern,
/// Line-count estimates are close.
SimilarLoc,
}
impl SimilarityReason {
pub fn description(&self) -> &'static str {
match self {
SimilarityReason::SameSignature => "same parameter count",
SimilarityReason::SimilarComplexity => "similar complexity",
SimilarityReason::SimilarCallPattern => "similar call pattern",
SimilarityReason::SimilarLoc => "similar lines of code",
}
}
}
/// Identifies one analyzed function by name and source location.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionRef {
/// Function name. Methods are stored as `Class.method` by `extract_all_functions`.
pub name: String,
/// File the function was extracted from.
pub file: PathBuf,
/// Line number copied from the extractor's `line_number` (base, 0 or 1, not visible here — confirm).
pub line: usize,
}
/// Two functions whose similarity score reached the configured threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarPair {
/// First function of the pair (earlier in analysis order).
pub func_a: FunctionRef,
/// Second function of the pair.
pub func_b: FunctionRef,
/// Weighted similarity score; within [0, 1] because the weights sum to 1.
pub score: f64,
/// Components whose individual score exceeded 0.8.
pub reasons: Vec<SimilarityReason>,
}
/// Result of a similarity scan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityReport {
/// Number of functions compared, after applying the `max_functions` cap.
pub functions_analyzed: usize,
/// Number of distinct pairs scored: `n * (n - 1) / 2`, or 0 when fewer than two functions.
pub pairs_compared: usize,
/// Number of pairs at or above the threshold, before the `max_pairs` cap.
pub similar_pairs_count: usize,
/// Threshold used for this scan.
pub threshold: f64,
/// Reported pairs, best score first, capped at `max_pairs`.
pub similar_pairs: Vec<SimilarPair>,
/// `Some(true)` only when `similar_pairs` was cut short; omitted from output otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
pub truncated: Option<bool>,
/// Total number of qualifying pairs; set only when truncated.
#[serde(skip_serializing_if = "Option::is_none")]
pub total_pairs: Option<usize>,
/// Number of pairs actually included; set only when truncated.
#[serde(skip_serializing_if = "Option::is_none")]
pub shown_pairs: Option<usize>,
}
impl Default for SimilarityReport {
fn default() -> Self {
Self {
functions_analyzed: 0,
pairs_compared: 0,
similar_pairs_count: 0,
threshold: 0.7,
similar_pairs: Vec::new(),
truncated: None,
total_pairs: None,
shown_pairs: None,
}
}
}
/// Tuning knobs for `find_similar_with_options`.
#[derive(Debug, Clone)]
pub struct SimilarityOptions {
/// Minimum weighted score (inclusive) for a pair to be reported.
pub threshold: f64,
/// Upper bound on the number of functions compared; extras are dropped before pairing.
pub max_functions: usize,
/// Upper bound on the number of pairs included in the report.
pub max_pairs: usize,
}
impl Default for SimilarityOptions {
fn default() -> Self {
Self {
threshold: 0.7,
max_functions: 500,
max_pairs: 50,
}
}
}
/// Per-function features used to score similarity.
#[derive(Debug, Clone)]
struct FunctionData {
/// Where the function lives; copied into reported pairs.
func_ref: FunctionRef,
/// Number of declared parameters.
param_count: usize,
/// Whether a return type was declared.
has_return_type: bool,
/// Complexity value. NOTE(review): `function_info_to_data` currently always sets 1.
complexity: usize,
/// Lines-of-code value. NOTE(review): currently an estimate, `2 * param_count + 5`.
loc: usize,
/// Names this function calls, taken from the file's call graph.
callees: HashSet<String>,
}
pub fn find_similar(
path: &Path,
language: Option<Language>,
threshold: f64,
max_functions: Option<usize>,
) -> TldrResult<SimilarityReport> {
let options = SimilarityOptions {
threshold,
max_functions: max_functions.unwrap_or(500),
..Default::default()
};
find_similar_with_options(path, language, &options)
}
pub fn find_similar_with_options(
path: &Path,
language: Option<Language>,
options: &SimilarityOptions,
) -> TldrResult<SimilarityReport> {
let lang = language.unwrap_or_else(|| detect_dominant_language(path));
let mut functions = extract_all_functions(path, lang)?;
if functions.len() > options.max_functions {
functions.sort_by(|a, b| b.complexity.cmp(&a.complexity));
functions.truncate(options.max_functions);
}
let function_count = functions.len();
if function_count < 2 {
return Ok(SimilarityReport {
functions_analyzed: function_count,
pairs_compared: 0,
similar_pairs_count: 0,
threshold: options.threshold,
similar_pairs: Vec::new(),
truncated: None,
total_pairs: None,
shown_pairs: None,
});
}
let pairs_count = function_count * (function_count - 1) / 2;
let pair_indices: Vec<(usize, usize)> = (0..function_count)
.flat_map(|i| ((i + 1)..function_count).map(move |j| (i, j)))
.collect();
let threshold = options.threshold;
let similar_pairs: Vec<SimilarPair> = pair_indices
.par_iter()
.filter_map(|(i, j)| {
let (score, reasons) = calculate_similarity(&functions[*i], &functions[*j]);
if score >= threshold {
Some(SimilarPair {
func_a: functions[*i].func_ref.clone(),
func_b: functions[*j].func_ref.clone(),
score,
reasons,
})
} else {
None
}
})
.collect();
let mut similar_pairs = similar_pairs;
similar_pairs.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let similar_count = similar_pairs.len();
let shown_pairs = similar_pairs.len().min(options.max_pairs);
let was_truncated = similar_pairs.len() > options.max_pairs;
similar_pairs.truncate(options.max_pairs);
Ok(SimilarityReport {
functions_analyzed: function_count,
pairs_compared: pairs_count,
similar_pairs_count: similar_count,
threshold: options.threshold,
similar_pairs,
truncated: if was_truncated { Some(true) } else { None },
total_pairs: if was_truncated {
Some(similar_count)
} else {
None
},
shown_pairs: if was_truncated {
Some(shown_pairs)
} else {
None
},
})
}
/// Returns the language with the most recognised files under `path`.
///
/// Only regular files are counted, the same check `extract_all_functions` applies.
/// When several languages share the highest count, the one encountered first during
/// the walk wins. Previously, ties were resolved by `HashMap` iteration order, which
/// is randomised per process. Falls back to `Language::Python` when no recognised
/// file is found.
fn detect_dominant_language(path: &Path) -> Language {
    // language -> (number of files, index at which the language was first seen)
    let mut counts: HashMap<Language, (usize, usize)> = HashMap::new();
    for entry in WalkDir::new(path)
        .follow_links(false)
        .into_iter()
        .filter_map(|e| e.ok())
    {
        let entry_path = entry.path();
        if !entry_path.is_file() {
            continue;
        }
        if let Some(lang) = Language::from_path(entry_path) {
            let first_seen = counts.len();
            counts.entry(lang).or_insert((0, first_seen)).0 += 1;
        }
    }
    counts
        .into_iter()
        // Highest count wins; among equal counts, the earliest-seen language wins.
        .max_by_key(|(_, (count, first_seen))| (*count, std::cmp::Reverse(*first_seen)))
        .map(|(lang, _)| lang)
        .unwrap_or(Language::Python)
}
/// Walks `path` and collects [`FunctionData`] for every function and method found in
/// files whose extension belongs to `language`.
///
/// Naming and filtering:
/// - Top-level functions keep their own name.
/// - Methods are renamed to `Class.method`.
/// - Methods whose name both starts and ends with `__` are skipped.
///
/// Files that `extract_file` fails on are skipped, so this function currently never
/// returns `Err`.
fn extract_all_functions(path: &Path, language: Language) -> TldrResult<Vec<FunctionData>> {
    let wanted: HashSet<String> = language
        .extensions()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let mut collected = Vec::new();

    let walker = WalkDir::new(path)
        .follow_links(false)
        .into_iter()
        .filter_map(|e| e.ok());
    for entry in walker {
        let file = entry.path();
        if !file.is_file() {
            continue;
        }
        let Some(ext) = file.extension().and_then(|e| e.to_str()) else {
            continue;
        };
        if !wanted.contains(&format!(".{}", ext)) {
            continue;
        }
        // Best effort: a file the extractor cannot handle is skipped rather than
        // aborting the whole scan.
        let Ok(info) = extract_file(file, Some(path)) else {
            continue;
        };
        let calls = &info.call_graph.calls;

        collected.extend(
            info.functions
                .iter()
                .filter_map(|func| function_info_to_data(func, file, calls)),
        );

        for class in &info.classes {
            for method in &class.methods {
                let is_dunder = method.name.starts_with("__") && method.name.ends_with("__");
                if is_dunder {
                    continue;
                }
                if let Some(mut data) = function_info_to_data(method, file, calls) {
                    data.func_ref.name = format!("{}.{}", class.name, method.name);
                    collected.push(data);
                }
            }
        }
    }

    Ok(collected)
}
/// Converts an extracted [`FunctionInfo`] into the features used for scoring.
///
/// Callees are looked up in `call_graph` under the function's own name. A missing
/// entry yields an empty set. This function always returns `Some`.
///
/// NOTE(review): `complexity` is hard-coded to 1, and `loc` is an estimate derived
/// only from the parameter count (`2 * params + 5`). As a result the complexity and
/// LOC components of the score carry little independent signal. Replace them with
/// real metrics if `FunctionInfo` exposes any — not visible from here.
fn function_info_to_data(
    func: &FunctionInfo,
    file: &Path,
    call_graph: &HashMap<String, Vec<String>>,
) -> Option<FunctionData> {
    let param_count = func.params.len();
    let callees = match call_graph.get(&func.name) {
        Some(targets) => targets.iter().cloned().collect::<HashSet<String>>(),
        None => HashSet::new(),
    };
    let func_ref = FunctionRef {
        name: func.name.clone(),
        file: file.to_path_buf(),
        line: func.line_number as usize,
    };

    Some(FunctionData {
        func_ref,
        param_count,
        has_return_type: func.return_type.is_some(),
        complexity: 1,
        loc: param_count * 2 + 5,
        callees,
    })
}
/// Scores how alike two functions are.
///
/// Each of the four components (signature, complexity, call pattern, lines of code)
/// is weighted by its `*_WEIGHT` constant, and the weighted values are summed in that
/// order. Returns the score together with the components whose individual score
/// exceeds 0.8, listed in that same order.
fn calculate_similarity(a: &FunctionData, b: &FunctionData) -> (f64, Vec<SimilarityReason>) {
    const REASON_CUTOFF: f64 = 0.8;

    let components = [
        (
            calculate_signature_similarity(a, b),
            SIGNATURE_WEIGHT,
            SimilarityReason::SameSignature,
        ),
        (
            calculate_complexity_similarity(a, b),
            COMPLEXITY_WEIGHT,
            SimilarityReason::SimilarComplexity,
        ),
        (
            calculate_call_pattern_similarity(a, b),
            CALL_PATTERN_WEIGHT,
            SimilarityReason::SimilarCallPattern,
        ),
        (
            calculate_loc_similarity(a, b),
            LOC_WEIGHT,
            SimilarityReason::SimilarLoc,
        ),
    ];

    let mut score = 0.0;
    let mut reasons = Vec::new();
    for (sim, weight, reason) in components {
        score += weight * sim;
        if sim > REASON_CUTOFF {
            reasons.push(reason);
        }
    }
    (score, reasons)
}
/// Compares the shape of two signatures, returning a value in `[0.25, 1]`.
///
/// The result is the average of two parts:
/// * parameter similarity: `1 - |pa - pb| / max(pa, pb)`, or 1.0 when neither
///   function takes parameters;
/// * return-type similarity: 1.0 when both or neither declare a return type, 0.5
///   otherwise.
fn calculate_signature_similarity(a: &FunctionData, b: &FunctionData) -> f64 {
    let max_params = a.param_count.max(b.param_count);
    let param_sim = if max_params == 0 {
        1.0
    } else {
        // `abs_diff` stays in `usize`; the former `as i32` casts truncated large
        // counts and could produce a wrong difference.
        let diff = a.param_count.abs_diff(b.param_count);
        1.0 - diff as f64 / max_params as f64
    };
    let return_sim = if a.has_return_type == b.has_return_type {
        1.0
    } else {
        0.5
    };
    (param_sim + return_sim) / 2.0
}
/// Relative closeness of two complexity values: `1 - |ca - cb| / max(ca, cb)`.
///
/// Returns 1.0 when both are zero and 0.0 when exactly one is zero. The result always
/// lies in `[0, 1]`.
fn calculate_complexity_similarity(a: &FunctionData, b: &FunctionData) -> f64 {
    let max_complexity = a.complexity.max(b.complexity);
    if max_complexity == 0 {
        return 1.0;
    }
    // `abs_diff` avoids truncating `usize -> i32` casts. Since diff <= max, the ratio
    // never exceeds 1, so no clamping is needed.
    let diff = a.complexity.abs_diff(b.complexity);
    1.0 - diff as f64 / max_complexity as f64
}
/// Jaccard similarity of the two functions' callee sets: `|A ∩ B| / |A ∪ B|`.
///
/// Two functions that call nothing are treated as identical (1.0). If only one of
/// them calls anything, the result is 0.0.
fn calculate_call_pattern_similarity(a: &FunctionData, b: &FunctionData) -> f64 {
    let intersection = a.callees.intersection(&b.callees).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|, which is zero exactly when both sets are empty.
    // This single check replaces the former empty-check plus unreachable union check.
    let union = a.callees.len() + b.callees.len() - intersection;
    if union == 0 {
        return 1.0;
    }
    intersection as f64 / union as f64
}
/// Relative closeness of two line counts: `1 - |la - lb| / max(la, lb)`.
///
/// Returns 1.0 when both are zero and 0.0 when exactly one is zero. The result always
/// lies in `[0, 1]`.
fn calculate_loc_similarity(a: &FunctionData, b: &FunctionData) -> f64 {
    let max_loc = a.loc.max(b.loc);
    if max_loc == 0 {
        return 1.0;
    }
    // `abs_diff` avoids truncating `usize -> i32` casts. Since diff <= max, the ratio
    // never exceeds 1, so no clamping is needed.
    let diff = a.loc.abs_diff(b.loc);
    1.0 - diff as f64 / max_loc as f64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 0.0001;

    /// Builds a fixture with complexity 1 and LOC 10. Each test supplies the name,
    /// parameter count, return-type flag and callees.
    fn fixture(
        name: &str,
        param_count: usize,
        has_return_type: bool,
        callees: &[&str],
    ) -> FunctionData {
        FunctionData {
            func_ref: FunctionRef {
                name: name.to_string(),
                file: PathBuf::from(format!("{}.py", name)),
                line: 1,
            },
            param_count,
            has_return_type,
            complexity: 1,
            loc: 10,
            callees: callees.iter().map(|c| c.to_string()).collect(),
        }
    }

    #[test]
    fn test_weights_sum_to_one() {
        let weights = [SIGNATURE_WEIGHT, COMPLEXITY_WEIGHT, CALL_PATTERN_WEIGHT, LOC_WEIGHT];
        let total: f64 = weights.iter().sum();
        assert!((total - 1.0).abs() < EPS);
    }

    #[test]
    fn test_signature_similarity_same_params() {
        let first = fixture("a", 3, true, &[]);
        let second = fixture("b", 3, true, &[]);
        let sim = calculate_signature_similarity(&first, &second);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_call_pattern_jaccard() {
        let first = fixture("a", 0, false, &["foo", "bar"]);
        let second = fixture("b", 0, false, &["foo", "baz"]);
        let sim = calculate_call_pattern_similarity(&first, &second);
        assert!((sim - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn test_similarity_report_default() {
        let report = SimilarityReport::default();
        assert_eq!(report.functions_analyzed, 0);
        assert_eq!(report.pairs_compared, 0);
        assert!(report.similar_pairs.is_empty());
    }

    #[test]
    fn test_similarity_reason_description() {
        let expectations = [
            (SimilarityReason::SameSignature, "same parameter count"),
            (SimilarityReason::SimilarComplexity, "similar complexity"),
        ];
        for (reason, expected) in expectations {
            assert_eq!(reason.description(), expected);
        }
    }
}