use std::cmp::Ordering;
use std::collections::BinaryHeap;
use serde::Serialize;
use crate::store::StoredChunk;
use super::cosine_similarity;
fn chunk_key(c: &DuplicateChunk) -> (&str, &str, u32) {
(&c.repo, &c.file_path, c.line_start)
}
/// Orders duplicate pairs from best to worst.
///
/// A pair with the higher `similarity` compares as `Ordering::Less`, so an
/// ascending sort lists the most similar pairs first, and the maximum of a
/// `BinaryHeap` ordered by this function is the weakest pair. Pairs with the
/// same similarity are ordered by the chunk key of side `a`, then of side `b`,
/// which keeps the result deterministic.
///
/// Similarities are compared with `f32::total_cmp` rather than `partial_cmp`,
/// so the relation is a total order even when a score is NaN (NaN is placed
/// according to the IEEE 754 total ordering, i.e. by its sign bit). Both the
/// `Ord` impl for `HeapEntry` and `sort_by` need a total order to behave
/// predictably.
fn pair_rank(a: &DuplicatePair, b: &DuplicatePair) -> Ordering {
    // Operands are swapped on purpose: a larger similarity must sort first.
    b.similarity
        .total_cmp(&a.similarity)
        .then_with(|| chunk_key(&a.a).cmp(&chunk_key(&b.a)))
        .then_with(|| chunk_key(&a.b).cmp(&chunk_key(&b.b)))
}
/// Newtype that lets a [`DuplicatePair`] live in a `BinaryHeap`.
///
/// All comparison traits delegate to [`pair_rank`]. Because `pair_rank` sorts
/// the best pair first, the heap's maximum is the weakest pair it holds.
struct HeapEntry(DuplicatePair);

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let (HeapEntry(lhs), HeapEntry(rhs)) = (self, other);
        pair_rank(lhs, rhs)
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// Equality is defined through `cmp` so that it always agrees with the ordering.
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}

impl Eq for HeapEntry {}
/// A borrowed [`StoredChunk`] tagged with the repo label it is attributed to.
///
/// [`find_duplicates`] treats two chunks as belonging to the same file when
/// both `repo` and the chunk's `file_path` are equal.
#[derive(Debug, Clone)]
pub struct LabeledChunk<'a> {
/// Repo label; copied into [`DuplicateChunk::repo`] for reported pairs.
pub repo: &'a str,
/// The chunk being compared; its `vector` is what gets scored.
pub chunk: &'a StoredChunk,
}
/// Owned, serializable description of one side of a [`DuplicatePair`].
///
/// [`find_duplicates`] fills every field by copying from a [`LabeledChunk`]
/// and the `StoredChunk` it points to.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct DuplicateChunk {
/// Repo label taken from [`LabeledChunk::repo`].
pub repo: String,
/// File path copied from the stored chunk.
pub file_path: String,
/// Start line copied from the stored chunk; also part of the tie-break key.
pub line_start: u32,
/// End line copied from the stored chunk.
pub line_end: u32,
/// Name copied from the stored chunk, if it has one.
pub name: Option<String>,
}
/// Two chunks reported as near-duplicates, together with their score.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct DuplicatePair {
/// The chunk that appeared earlier in the slice given to [`find_duplicates`].
pub a: DuplicateChunk,
/// The chunk that appeared later in the slice given to [`find_duplicates`].
pub b: DuplicateChunk,
/// Value returned by `cosine_similarity` for the two chunks' vectors.
pub similarity: f32,
}
/// Finds the `limit` most similar pairs of chunks that live in different files.
///
/// Every unordered pair of `chunks` is scored with `cosine_similarity` on the
/// chunks' vectors. A pair is ignored when
///
/// * both chunks share the same `repo` label *and* the same `file_path`,
/// * its score is below `min_similarity`, or
/// * its score is NaN (such a score cannot be ranked meaningfully and would
///   otherwise slip past the threshold, since `NaN < x` is always false).
///
/// Returns at most `limit` pairs ordered by [`pair_rank`]: highest similarity
/// first, ties broken by chunk key. The result is empty when fewer than two
/// chunks are given or when `limit` is zero.
///
/// The scan visits all `n * (n - 1) / 2` pairs, but at most `limit` candidates
/// are kept in memory at any time.
pub fn find_duplicates(
    chunks: &[LabeledChunk<'_>],
    min_similarity: f32,
    limit: usize,
) -> Vec<DuplicatePair> {
    // Upper bound on the up-front heap allocation. `limit` is caller supplied
    // and may be huge (even `usize::MAX`), so it must not drive the allocation
    // size directly; the heap still grows on demand beyond this.
    const MAX_PREALLOCATED_PAIRS: usize = 1024;

    let n = chunks.len();
    if n < 2 || limit == 0 {
        return Vec::new();
    }

    // Owned snapshot of one side of a pair.
    let describe = |c: &LabeledChunk<'_>| DuplicateChunk {
        repo: c.repo.to_owned(),
        file_path: c.chunk.file_path.clone(),
        line_start: c.chunk.line_start,
        line_end: c.chunk.line_end,
        name: c.chunk.name.clone(),
    };

    // The heap never holds more than `limit` entries, nor more entries than
    // there are pairs to examine.
    let max_pairs = n.saturating_mul(n - 1) / 2;
    let capacity = limit.min(max_pairs).min(MAX_PREALLOCATED_PAIRS);
    // Max-heap under `pair_rank`, i.e. the top element is the weakest kept pair.
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(capacity);

    for (i, a) in chunks.iter().enumerate() {
        for b in &chunks[i + 1..] {
            if a.repo == b.repo && a.chunk.file_path == b.chunk.file_path {
                continue;
            }

            let sim = cosine_similarity(&a.chunk.vector, &b.chunk.vector);
            if sim.is_nan() || sim < min_similarity {
                continue;
            }

            let full = heap.len() >= limit;

            // Cheap rejection before any strings are cloned: a strictly lower
            // score than the weakest kept pair can never make it in.
            if full {
                if let Some(weakest) = heap.peek() {
                    if sim < weakest.0.similarity {
                        continue;
                    }
                }
            }

            let candidate = DuplicatePair {
                a: describe(a),
                b: describe(b),
                similarity: sim,
            };

            if !full {
                heap.push(HeapEntry(candidate));
            } else if let Some(mut weakest) = heap.peek_mut() {
                // Replace the weakest entry only when the candidate ranks
                // strictly better (this also settles equal-score ties by key).
                // The heap restores its invariant when `weakest` is dropped.
                if pair_rank(&candidate, &weakest.0) == Ordering::Less {
                    *weakest = HeapEntry(candidate);
                }
            }
        }
    }

    let mut pairs: Vec<DuplicatePair> = heap.into_iter().map(|e| e.0).collect();
    pairs.sort_by(pair_rank);
    pairs
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::StoredChunk;

    /// Builds a `StoredChunk` fixture; only path, name and vector vary.
    fn stored(file_path: &str, name: &str, vector: Vec<f32>) -> StoredChunk {
        StoredChunk {
            chunk_id: 0,
            file_path: file_path.to_owned(),
            language: "rust".into(),
            kind: "function".into(),
            name: Some(name.to_owned()),
            line_start: 1,
            line_end: 10,
            byte_start: 0,
            byte_end: 100,
            file_hash: [0u8; 16],
            content: format!("pub fn {name}() {{}}"),
            vector,
        }
    }

    /// Attaches a repo label to a fixture chunk.
    fn labeled<'a>(repo: &'a str, chunk: &'a StoredChunk) -> LabeledChunk<'a> {
        LabeledChunk { repo, chunk }
    }

    #[test]
    fn identical_vectors_produce_a_pair() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "foo", v.clone());
        let b = stored("src/b.rs", "bar", v.clone());
        let chunks = [labeled("/repo", &a), labeled("/repo", &b)];
        let pairs = find_duplicates(&chunks, 0.9, 50);
        assert_eq!(pairs.len(), 1);
        assert!(pairs[0].similarity > 0.99);
        assert_eq!(pairs[0].a.file_path, "src/a.rs");
        assert_eq!(pairs[0].b.file_path, "src/b.rs");
    }

    #[test]
    fn no_pairs_when_below_threshold() {
        let a = stored("src/a.rs", "foo", vec![1.0, 0.0, 0.0, 0.0]);
        let b = stored("src/b.rs", "bar", vec![0.0, 1.0, 0.0, 0.0]);
        let chunks = [labeled("/repo", &a), labeled("/repo", &b)];
        let pairs = find_duplicates(&chunks, 0.85, 50);
        assert!(pairs.is_empty());
    }

    #[test]
    fn same_file_pairs_excluded() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "foo", v.clone());
        let b = stored("src/a.rs", "bar", v.clone());
        let chunks = [labeled("/repo", &a), labeled("/repo", &b)];
        let pairs = find_duplicates(&chunks, 0.0, 50);
        assert!(pairs.is_empty(), "same-file pairs must not be reported");
    }

    #[test]
    fn limit_caps_output_not_input() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "foo", v.clone());
        let b = stored("src/b.rs", "bar", v.clone());
        let c = stored("src/c.rs", "baz", v.clone());
        let chunks = [
            labeled("/repo", &a),
            labeled("/repo", &b),
            labeled("/repo", &c),
        ];
        let pairs = find_duplicates(&chunks, 0.0, 1);
        assert_eq!(pairs.len(), 1);
    }

    #[test]
    fn cross_repo_pair_when_repos_differ() {
        // Same path in two different repos is not "the same file".
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "foo", v.clone());
        let b = stored("src/a.rs", "foo", v.clone());
        let chunks = [labeled("/repo-a", &a), labeled("/repo-b", &b)];
        let pairs = find_duplicates(&chunks, 0.9, 50);
        assert_eq!(pairs.len(), 1);
        assert_ne!(pairs[0].a.repo, pairs[0].b.repo);
    }

    #[test]
    fn sorted_by_similarity_descending() {
        let high = vec![1.0_f32, 0.0, 0.0, 0.0];
        let mid = vec![0.9_f32, 0.1_f32.sqrt(), 0.0, 0.0];
        let a = stored("src/a.rs", "foo", high.clone());
        let b = stored("src/b.rs", "bar", high.clone());
        let c = stored("src/c.rs", "baz", mid.clone());
        let d = stored("src/d.rs", "qux", high.clone());
        let chunks = [
            labeled("/repo", &a),
            labeled("/repo", &b),
            labeled("/repo", &c),
            labeled("/repo", &d),
        ];
        let pairs = find_duplicates(&chunks, 0.0, 50);
        for window in pairs.windows(2) {
            assert!(
                window[0].similarity >= window[1].similarity,
                "pairs not sorted descending"
            );
        }
    }

    #[test]
    fn empty_input_returns_empty() {
        let pairs = find_duplicates(&[], 0.0, 50);
        assert!(pairs.is_empty());
    }

    #[test]
    fn single_chunk_returns_empty() {
        let a = stored("src/a.rs", "foo", vec![1.0, 0.0]);
        let chunks = [labeled("/repo", &a)];
        let pairs = find_duplicates(&chunks, 0.0, 50);
        assert!(pairs.is_empty());
    }

    #[test]
    fn heap_cap_keeps_highest_similarity_pairs() {
        let a = stored("src/a.rs", "foo", vec![1.0, 0.0, 0.0, 0.0]);
        let b = stored("src/b.rs", "bar", vec![1.0, 0.0, 0.0, 0.0]);
        // Unit-length vector (0.95² + 0.31225² ≈ 1), so cos(a, c) = cos(b, c) ≈ 0.95.
        let c = stored("src/c.rs", "baz", vec![0.95_f32, 0.31225_f32, 0.0, 0.0]);
        let chunks = [
            labeled("/repo", &a),
            labeled("/repo", &b),
            labeled("/repo", &c),
        ];
        let pairs = find_duplicates(&chunks, 0.0, 2);
        assert_eq!(pairs.len(), 2, "limit=2 must yield exactly 2 pairs");
        assert!(
            pairs[0].similarity >= pairs[1].similarity,
            "output must be sorted descending"
        );
        assert!(
            pairs[0].similarity > 0.99,
            "a↔b (identical vectors) must be the top pair"
        );
        let min_sim = pairs[1].similarity;
        assert!(
            min_sim > 0.89 && min_sim < 0.99,
            "second kept pair must be one of the ≈0.95 pairs involving c, got sim={min_sim}"
        );
    }

    #[test]
    fn late_better_pair_evicts_current_worst() {
        let a = stored("src/a.rs", "foo", vec![1.0, 0.0, 0.0, 0.0]);
        let b = stored("src/b.rs", "bar", vec![1.0, 0.0, 0.0, 0.0]);
        let c = stored("src/c.rs", "baz", vec![0.95_f32, 0.31225_f32, 0.0, 0.0]);
        // `c` goes first so that both weaker pairs (c↔a, c↔b) fill the heap
        // before the best pair (a↔b) is seen and has to evict one of them.
        let chunks = [
            labeled("/repo", &c),
            labeled("/repo", &a),
            labeled("/repo", &b),
        ];
        let pairs = find_duplicates(&chunks, 0.0, 2);
        assert_eq!(pairs.len(), 2, "limit=2 must yield exactly 2 pairs");
        assert!(
            pairs[0].similarity > 0.99,
            "the late a↔b pair must replace a weaker one and rank first"
        );
        assert_eq!(pairs[0].a.file_path, "src/a.rs");
        assert_eq!(pairs[0].b.file_path, "src/b.rs");
        assert!(
            pairs[1].similarity < 0.99,
            "the other survivor must be one of the weaker pairs involving c"
        );
        // c↔a and c↔b tie on similarity; the key tie-break keeps c↔a.
        assert_eq!(pairs[1].a.file_path, "src/c.rs");
        assert_eq!(pairs[1].b.file_path, "src/a.rs");
    }

    #[test]
    fn tied_similarity_yields_a_stable_set() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "a", v.clone());
        let b = stored("src/b.rs", "b", v.clone());
        let c = stored("src/c.rs", "c", v.clone());
        let d = stored("src/d.rs", "d", v.clone());
        let chunks = [
            labeled("/repo", &a),
            labeled("/repo", &b),
            labeled("/repo", &c),
            labeled("/repo", &d),
        ];
        let key = |p: &DuplicatePair| {
            (
                p.a.file_path.clone(),
                p.a.line_start,
                p.b.file_path.clone(),
                p.b.line_start,
            )
        };
        let baseline: Vec<_> = find_duplicates(&chunks, 0.0, 3).iter().map(key).collect();
        assert_eq!(baseline.len(), 3);
        for _ in 0..16 {
            let again: Vec<_> = find_duplicates(&chunks, 0.0, 3).iter().map(key).collect();
            assert_eq!(again, baseline, "tied-similarity survivors must be stable");
        }
    }

    #[test]
    fn tie_similarity_output_is_sorted() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let a = stored("src/a.rs", "foo", v.clone());
        let b = stored("src/b.rs", "bar", v.clone());
        let c = stored("src/c.rs", "baz", v.clone());
        let d = stored("src/d.rs", "qux", v.clone());
        let chunks = [
            labeled("/repo", &a),
            labeled("/repo", &b),
            labeled("/repo", &c),
            labeled("/repo", &d),
        ];
        let pairs = find_duplicates(&chunks, 0.0, 3);
        assert_eq!(pairs.len(), 3);
        for window in pairs.windows(2) {
            assert!(
                window[0].similarity >= window[1].similarity,
                "pairs not sorted descending on ties"
            );
        }
    }
}