//! Fuzzy-hash clustering of large files.
//!
//! Builds a pairwise distance matrix from ssdeep similarity scores and
//! groups similar files with a simple DBSCAN implementation. An LSH-style
//! pre-bucketing step (`compute_lsh_buckets`) keeps large inputs tractable.

use crate::outliers::SimpleFileInfo;
use ndarray::Array2;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
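/// Mutable state shared by the DBSCAN main loop and cluster expansion,
/// bundled into one struct to keep `expand_cluster`'s signature small.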
struct DbscanContext<'a> {
    epsilon: f64,
    min_points: usize,
    distances: &'a Array2<f64>,
    labels: &'a mut [Option<usize>],
    visited: &'a mut [bool],
}
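/// Errors that can occur during cluster detection.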
#[derive(Debug, Error)]
pub enum ClusteringError {
    #[error("Insufficient files for clustering: {0} < {1}")]
    InsufficientFiles(usize, usize),
    #[error("Invalid similarity threshold: {0} (must be 50-100)")]
    InvalidSimilarity(u8),
    #[error("DBSCAN failed: {0}")]
    DbscanError(String),
    #[error("Hash computation failed: {0}")]
    HashError(String),
}
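/// Convenience alias for results returned by this module.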
pub type ClusteringResult<T> = Result<T, ClusteringError>;
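/// A group of files whose ssdeep hashes place them within the requested
/// similarity threshold of one another, directly or through a chain of
/// density-reachable neighbors.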
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LargeFileCluster {
    pub cluster_id: usize,
    pub files: Vec<SimpleFileInfo>,
    pub total_size: u64,
    pub avg_similarity: f64,
    pub density: f64,
}
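/// Converts an ssdeep similarity score (0-100) into a distance (100-0),
/// so that identical files sit at distance 0.
///
/// ```ignore
/// // Illustrative sketch, `ignore`d so it is not tied to a crate path.
/// assert_eq!(similarity_to_distance(100), 0.0);
/// assert_eq!(similarity_to_distance(70), 30.0);
/// ```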
pub fn similarity_to_distance(similarity: u8) -> f64 {
    100.0 - f64::from(similarity)
}
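/// Builds a symmetric `n x n` distance matrix from pairwise ssdeep
/// comparisons. The upper triangle is computed in parallel with rayon,
/// then mirrored into the lower triangle.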
pub fn build_distance_matrix(files: &[SimpleFileInfo]) -> Array2<f64> {
    let n = files.len();
    let mut distances = Array2::zeros((n, n));
    let upper_triangle: Vec<_> = (0..n)
        .into_par_iter()
        .flat_map(|i| (i + 1..n).into_par_iter().map(move |j| (i, j)))
        .map(|(i, j)| {
            let sim = calculate_similarity_safe(&files[i], &files[j]);
            (i, j, similarity_to_distance(sim))
        })
        .collect();
    for (i, j, dist) in upper_triangle {
        distances[[i, j]] = dist;
        distances[[j, i]] = dist;
    }
    distances
}
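/// Compares two files' ssdeep hashes, treating a missing hash or a failed
/// comparison as similarity 0 so callers never have to handle errors.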
fn calculate_similarity_safe(a: &SimpleFileInfo, b: &SimpleFileInfo) -> u8 {
    match (&a.ssdeep_hash, &b.ssdeep_hash) {
        (Some(h1), Some(h2)) => ssdeep::compare(h1, h2).unwrap_or(0),
        _ => 0,
    }
}
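/// Detects clusters of similar files via DBSCAN over the ssdeep distance
/// matrix. `min_similarity` must lie in 50-100; files without a hash are
/// skipped, and clusters smaller than `min_cluster_size` are dropped.
///
/// A minimal usage sketch (illustrative; assumes `files` was produced by
/// the caller's scan step):
///
/// ```ignore
/// let clusters = detect_large_file_clusters(&files, 70, 2)?;
/// for c in &clusters {
///     println!("cluster {}: {} files, {} bytes", c.cluster_id, c.files.len(), c.total_size);
/// }
/// ```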
pub fn detect_large_file_clusters(
    files: &[SimpleFileInfo],
    min_similarity: u8,
    min_cluster_size: usize,
) -> ClusteringResult<Vec<LargeFileCluster>> {
    if !(50..=100).contains(&min_similarity) {
        return Err(ClusteringError::InvalidSimilarity(min_similarity));
    }
    if files.len() < min_cluster_size {
        return Ok(vec![]);
    }
    let hashable_files: Vec<_> = files
        .iter()
        .filter(|f| f.ssdeep_hash.is_some())
        .cloned()
        .collect();
    if hashable_files.len() < min_cluster_size {
        return Ok(vec![]);
    }
    let distances = build_distance_matrix(&hashable_files);
    let epsilon = similarity_to_distance(min_similarity);
    let cluster_labels = simple_dbscan(&distances, epsilon, min_cluster_size);
    // DBSCAN assigns border points to the first cluster that reaches them,
    // so a later cluster can end up smaller than `min_points`. Filter those
    // out so the documented `min_cluster_size` contract always holds.
    Ok(aggregate_clusters(&hashable_files, cluster_labels, &distances)
        .into_iter()
        .filter(|c| c.files.len() >= min_cluster_size)
        .collect())
}
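/// A straightforward O(n^2) DBSCAN over a precomputed distance matrix.
/// Returns one label per point: `Some(cluster_id)` for clustered points,
/// `None` for noise. A point is a core point when at least `min_points`
/// points (itself included) lie within `epsilon` of it.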
fn simple_dbscan(distances: &Array2<f64>, epsilon: f64, min_points: usize) -> Vec<Option<usize>> {
    let n = distances.shape()[0];
    let mut labels = vec![None; n];
    let mut visited = vec![false; n];
    let mut cluster_id = 0;
    for i in 0..n {
        if visited[i] {
            continue;
        }
        visited[i] = true;
        let neighbors: Vec<usize> = (0..n)
            .filter(|&j| distances[[i, j]] <= epsilon)
            .collect();
        if neighbors.len() >= min_points {
            let mut ctx = DbscanContext {
                epsilon,
                min_points,
                distances,
                labels: &mut labels,
                visited: &mut visited,
            };
            expand_cluster(i, &neighbors, cluster_id, &mut ctx);
            cluster_id += 1;
        }
    }
    labels
}
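/// Grows a cluster outward from a core point: every unvisited neighbor is
/// examined, and neighbors that are themselves core points contribute
/// their own neighborhoods to the seed set (breadth-first expansion).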
fn expand_cluster(
    point: usize,
    neighbors: &[usize],
    cluster_id: usize,
    ctx: &mut DbscanContext,
) {
    ctx.labels[point] = Some(cluster_id);
    let mut seed_set = neighbors.to_vec();
    let mut i = 0;
    while i < seed_set.len() {
        let q = seed_set[i];
        if !ctx.visited[q] {
            ctx.visited[q] = true;
            let q_neighbors: Vec<usize> = (0..ctx.distances.shape()[0])
                .filter(|&j| ctx.distances[[q, j]] <= ctx.epsilon)
                .collect();
            if q_neighbors.len() >= ctx.min_points {
                for &neighbor in &q_neighbors {
                    // Linear scan keeps the code simple; switch to a
                    // HashSet if clusters grow large enough to matter.
                    if !seed_set.contains(&neighbor) {
                        seed_set.push(neighbor);
                    }
                }
            }
        }
        // Claim the point only if no earlier cluster already labeled it.
        if ctx.labels[q].is_none() {
            ctx.labels[q] = Some(cluster_id);
        }
        i += 1;
    }
}
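/// Groups labeled indices by cluster id and materializes a
/// `LargeFileCluster` for each group; noise points (`None`) are dropped.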
fn aggregate_clusters(
    files: &[SimpleFileInfo],
    cluster_labels: Vec<Option<usize>>,
    distances: &Array2<f64>,
) -> Vec<LargeFileCluster> {
    let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, label) in cluster_labels.iter().enumerate() {
        if let Some(cluster_id) = label {
            cluster_map.entry(*cluster_id).or_default().push(idx);
        }
    }
    cluster_map
        .into_iter()
        .map(|(id, indices)| build_cluster_info(id, &indices, files, distances))
        .collect()
}
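/// Assembles the summary record for one cluster: member files, combined
/// size, mean pairwise similarity, and edge density.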
fn build_cluster_info(
    cluster_id: usize,
    indices: &[usize],
    files: &[SimpleFileInfo],
    distances: &Array2<f64>,
) -> LargeFileCluster {
    let cluster_files: Vec<_> = indices.iter().map(|&i| files[i].clone()).collect();
    let total_size = cluster_files.iter().map(|f| f.size_bytes).sum();
    let avg_similarity = calculate_avg_similarity(indices, distances);
    let density = calculate_density(indices, distances);
    LargeFileCluster {
        cluster_id,
        files: cluster_files,
        total_size,
        avg_similarity,
        density,
    }
}
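/// Mean pairwise similarity across all unordered pairs in the cluster;
/// a singleton cluster is defined as 100% similar to itself.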
fn calculate_avg_similarity(indices: &[usize], distances: &Array2<f64>) -> f64 {
    if indices.len() <= 1 {
        return 100.0;
    }
    let mut total_similarity = 0.0;
    let mut count = 0;
    for i in 0..indices.len() {
        for j in i + 1..indices.len() {
            let distance = distances[[indices[i], indices[j]]];
            total_similarity += 100.0 - distance;
            count += 1;
        }
    }
    if count > 0 {
        total_similarity / count as f64
    } else {
        100.0
    }
}
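/// Fraction of member pairs that fall within a fixed closeness threshold:
/// 1.0 means every pair of members is mutually similar.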
fn calculate_density(indices: &[usize], distances: &Array2<f64>) -> f64 {
    // Edges closer than this count as "dense"; 30.0 is the distance that
    // corresponds to a similarity of 70 (see `similarity_to_distance`).
    const DENSITY_EPSILON: f64 = 30.0;
    if indices.len() <= 1 {
        return 1.0;
    }
    let max_edges = indices.len() * (indices.len() - 1) / 2;
    let mut edges_within_epsilon = 0;
    for i in 0..indices.len() {
        for j in i + 1..indices.len() {
            if distances[[indices[i], indices[j]]] < DENSITY_EPSILON {
                edges_within_epsilon += 1;
            }
        }
    }
    edges_within_epsilon as f64 / max_edges as f64
}
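/// Groups files into LSH-style buckets keyed by ssdeep block size plus the
/// first 8 characters of the first hash body, so that only likely-similar
/// files are compared pairwise later. Note that ssdeep can also match
/// hashes at adjacent block sizes, so exact-block-size bucketing trades
/// some recall for speed.
///
/// For example, a hash like `3:abc123:def456` (illustrative) lands in the
/// bucket keyed by `hash_combine(3, "abc123")`.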
pub fn compute_lsh_buckets(files: &[SimpleFileInfo]) -> HashMap<u64, Vec<SimpleFileInfo>> {
    let mut buckets = HashMap::new();
    for file in files {
        if let Some(hash) = &file.ssdeep_hash {
            let parts: Vec<_> = hash.split(':').collect();
            if parts.len() >= 3 {
                let block_size = parts[0].parse::<u64>().unwrap_or(0);
                let chunk = &parts[1][..parts[1].len().min(8)];
                let bucket_key = hash_combine(block_size, chunk);
                buckets
                    .entry(bucket_key)
                    .or_insert_with(Vec::new)
                    .push(file.clone());
            }
        }
    }
    buckets
}
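/// Hashes a (block size, chunk prefix) pair into a single `u64` bucket key.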
fn hash_combine(block_size: u64, chunk: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    block_size.hash(&mut hasher);
    chunk.hash(&mut hasher);
    hasher.finish()
}
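/// Batched variant for large inputs: instead of one O(n^2) distance matrix
/// over everything, files are first split into LSH buckets, each bucket is
/// clustered independently, and clusters that share files are merged.
/// Inputs of `batch_size` or fewer fall through to the direct path.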
pub fn detect_clusters_batched(
    files: &[SimpleFileInfo],
    min_similarity: u8,
    min_cluster_size: usize,
    batch_size: usize,
) -> ClusteringResult<Vec<LargeFileCluster>> {
    if files.len() <= batch_size {
        return detect_large_file_clusters(files, min_similarity, min_cluster_size);
    }
    let lsh_buckets = compute_lsh_buckets(files);
    let mut all_clusters = Vec::new();
    for bucket in lsh_buckets.values() {
        if bucket.len() >= min_cluster_size {
            let batch_clusters =
                detect_large_file_clusters(bucket, min_similarity, min_cluster_size)?;
            all_clusters.extend(batch_clusters);
        }
    }
    Ok(merge_overlapping_clusters(all_clusters))
}
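/// Merges clusters that share at least one file path, deduplicating files
/// and recomputing the total size. A single pass in index order is used,
/// so overlaps that only appear through a chain of later clusters may be
/// missed; the merged cluster keeps the base cluster's similarity and
/// density figures rather than recomputing them.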
fn merge_overlapping_clusters(clusters: Vec<LargeFileCluster>) -> Vec<LargeFileCluster> {
    if clusters.is_empty() {
        return vec![];
    }
    let mut merged = Vec::new();
    let mut processed = HashSet::new();
    for (i, cluster) in clusters.iter().enumerate() {
        if processed.contains(&i) {
            continue;
        }
        let mut merged_cluster = cluster.clone();
        let mut file_paths: HashSet<_> = cluster.files.iter().map(|f| &f.path).collect();
        for (j, other) in clusters.iter().enumerate().skip(i + 1) {
            if processed.contains(&j) {
                continue;
            }
            let overlap = other.files.iter().any(|f| file_paths.contains(&f.path));
            if overlap {
                for file in &other.files {
                    if file_paths.insert(&file.path) {
                        merged_cluster.files.push(file.clone());
                    }
                }
                processed.insert(j);
            }
        }
        // Recompute the size from the deduplicated file list; summing the
        // two clusters' totals would double-count shared files.
        merged_cluster.total_size = merged_cluster.files.iter().map(|f| f.size_bytes).sum();
        merged.push(merged_cluster);
        processed.insert(i);
    }
    merged
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn create_test_file(path: &str, size: u64, hash: Option<String>) -> SimpleFileInfo {
        SimpleFileInfo {
            path: PathBuf::from(path),
            size_bytes: size,
            ssdeep_hash: hash,
        }
    }
    #[test]
    fn test_similarity_to_distance() {
        assert_eq!(similarity_to_distance(100), 0.0);
        assert_eq!(similarity_to_distance(70), 30.0);
        assert_eq!(similarity_to_distance(0), 100.0);
    }
    #[test]
    fn test_distance_matrix_symmetry() {
        let files = vec![
            create_test_file("a.txt", 1000, Some("3:abc:def".to_string())),
            create_test_file("b.txt", 1000, Some("3:abc:ghi".to_string())),
            create_test_file("c.txt", 1000, Some("3:xyz:123".to_string())),
        ];
        let matrix = build_distance_matrix(&files);
        for i in 0..files.len() {
            for j in 0..files.len() {
                assert_eq!(
                    matrix[[i, j]], matrix[[j, i]],
                    "Distance matrix must be symmetric"
                );
            }
        }
    }
    #[test]
    fn test_cluster_detection_edge_cases() {
        assert_eq!(detect_large_file_clusters(&[], 70, 2).unwrap().len(), 0);
        let single = vec![create_test_file("single.txt", 1000, Some("3:abc:def".to_string()))];
        assert_eq!(detect_large_file_clusters(&single, 70, 2).unwrap().len(), 0);
        let no_hashes = vec![
            create_test_file("a.txt", 1000, None),
            create_test_file("b.txt", 1000, None),
        ];
        assert_eq!(detect_large_file_clusters(&no_hashes, 70, 2).unwrap().len(), 0);
    }
    #[test]
    fn test_invalid_similarity_threshold() {
        let files = vec![create_test_file("a.txt", 1000, Some("3:abc:def".to_string()))];
        assert!(matches!(
            detect_large_file_clusters(&files, 49, 2),
            Err(ClusteringError::InvalidSimilarity(49))
        ));
        assert!(matches!(
            detect_large_file_clusters(&files, 101, 2),
            Err(ClusteringError::InvalidSimilarity(101))
        ));
    }
    #[test]
    fn test_lsh_buckets() {
        let files = vec![
            create_test_file("a.txt", 1000, Some("3:abc123:def456".to_string())),
            create_test_file("b.txt", 1000, Some("3:abc123:ghi789".to_string())),
            create_test_file("c.txt", 1000, Some("6:xyz123:123456".to_string())),
        ];
        let buckets = compute_lsh_buckets(&files);
        assert!(!buckets.is_empty());
    }
    #[test]
    fn test_cluster_aggregation() {
        let files = vec![
            create_test_file("a.txt", 1000, Some("3:abc:def".to_string())),
            create_test_file("b.txt", 2000, Some("3:abc:def".to_string())),
            create_test_file("c.txt", 3000, Some("3:xyz:123".to_string())),
        ];
        let distances = build_distance_matrix(&files);
        let cluster_labels = vec![Some(0), Some(0), Some(1)];
        let clusters = aggregate_clusters(&files, cluster_labels, &distances);
        assert_eq!(clusters.len(), 2);
        let mut sorted_clusters = clusters.clone();
        sorted_clusters.sort_by_key(|c| c.files.len());
        assert_eq!(sorted_clusters[1].files.len(), 2);
        assert_eq!(sorted_clusters[1].total_size, 3000);
        assert_eq!(sorted_clusters[0].files.len(), 1);
        assert_eq!(sorted_clusters[0].total_size, 3000);
    }
    #[test]
    fn test_dbscan_implementation() {
        // Two tight pairs, (0, 1) and (2, 3), far away from each other.
        let mut distances = Array2::zeros((4, 4));
        distances[[0, 1]] = 10.0; distances[[1, 0]] = 10.0;
        distances[[2, 3]] = 10.0; distances[[3, 2]] = 10.0;
        distances[[0, 2]] = 90.0; distances[[2, 0]] = 90.0;
        distances[[0, 3]] = 90.0; distances[[3, 0]] = 90.0;
        distances[[1, 2]] = 90.0; distances[[2, 1]] = 90.0;
        distances[[1, 3]] = 90.0; distances[[3, 1]] = 90.0;
        let labels = simple_dbscan(&distances, 20.0, 2);
        assert!(labels[0].is_some());
        assert!(labels[1].is_some());
        assert!(labels[2].is_some());
        assert!(labels[3].is_some());
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }
    #[test]
    fn test_cluster_density_calculation() {
        let indices = vec![0, 1, 2];
        let mut distances = Array2::zeros((3, 3));
        distances[[0, 1]] = 10.0; distances[[1, 0]] = 10.0;
        distances[[0, 2]] = 10.0; distances[[2, 0]] = 10.0;
        distances[[1, 2]] = 10.0; distances[[2, 1]] = 10.0;
        let density = calculate_density(&indices, &distances);
        assert!(density > 0.9);
    }
    #[test]
    fn test_average_similarity_calculation() {
        let indices = vec![0, 1, 2];
        let mut distances = Array2::zeros((3, 3));
        distances[[0, 1]] = 20.0; distances[[1, 0]] = 20.0;
        distances[[0, 2]] = 30.0; distances[[2, 0]] = 30.0;
        distances[[1, 2]] = 10.0; distances[[2, 1]] = 10.0;
        let avg_sim = calculate_avg_similarity(&indices, &distances);
        assert!((avg_sim - 80.0).abs() < 0.1);
    }
    #[test]
    fn test_merge_overlapping_clusters() {
        let cluster1 = LargeFileCluster {
            cluster_id: 0,
            files: vec![
                create_test_file("a.txt", 1000, Some("3:abc:def".to_string())),
                create_test_file("b.txt", 2000, Some("3:abc:def".to_string())),
            ],
            total_size: 3000,
            avg_similarity: 90.0,
            density: 0.8,
        };
        let cluster2 = LargeFileCluster {
            cluster_id: 1,
            files: vec![
                create_test_file("b.txt", 2000, Some("3:abc:def".to_string())),
                create_test_file("c.txt", 3000, Some("3:abc:ghi".to_string())),
            ],
            total_size: 5000,
            avg_similarity: 85.0,
            density: 0.7,
        };
        let merged = merge_overlapping_clusters(vec![cluster1, cluster2]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].files.len(), 3);
        assert_eq!(merged[0].total_size, 6000);
    }
    #[test]
    fn test_batch_clustering() {
        let files: Vec<_> = (0..10)
            .map(|i| {
                create_test_file(
                    &format!("file{}.txt", i),
                    1000 * (i as u64 + 1),
                    Some(format!("3:abc{}:def", i % 3)),
                )
            })
            .collect();
        let result = detect_clusters_batched(&files, 70, 2, 5);
        assert!(result.is_ok());
        let clusters = result.unwrap();
        assert!(!clusters.is_empty());
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    prop_compose! {
        fn arb_file_set()(
            count in 2..50usize,
            base_hash in "[0-9a-f]{6}",
        ) -> Vec<SimpleFileInfo> {
            (0..count)
                .map(|i| SimpleFileInfo {
                    path: std::path::PathBuf::from(format!("file_{}.dat", i)),
                    size_bytes: 1024 * (i as u64 + 1),
                    ssdeep_hash: Some(format!("3:{}:{}", base_hash, i % 5)),
                })
                .collect()
        }
    }
    proptest! {
        #[test]
        fn prop_cluster_similarity_invariant(
            files in arb_file_set(),
            min_sim in 50..100u8,
        ) {
            let result = detect_large_file_clusters(&files, min_sim, 2);
            if let Ok(clusters) = result {
                // DBSCAN guarantees density-reachability, not all-pairs
                // similarity: two members may be linked only through a
                // chain of intermediate files. The sound invariant is that
                // every member has at least one sufficiently similar
                // neighbor inside its own cluster.
                for cluster in clusters {
                    for i in 0..cluster.files.len() {
                        let has_neighbor = (0..cluster.files.len())
                            .filter(|&j| j != i)
                            .any(|j| match (
                                &cluster.files[i].ssdeep_hash,
                                &cluster.files[j].ssdeep_hash,
                            ) {
                                (Some(h1), Some(h2)) => ssdeep::compare(h1, h2)
                                    .map(|sim| sim >= min_sim)
                                    .unwrap_or(false),
                                _ => false,
                            });
                        prop_assert!(
                            has_neighbor,
                            "File {:?} has no in-cluster neighbor with similarity >= {}",
                            cluster.files[i].path, min_sim
                        );
                    }
                }
            }
        }
        #[test]
        fn prop_cluster_disjoint_invariant(
            files in arb_file_set(),
            min_sim in 50..100u8,
        ) {
            let result = detect_large_file_clusters(&files, min_sim, 2);
            if let Ok(clusters) = result {
                let mut seen = HashSet::new();
                for cluster in &clusters {
                    for file in &cluster.files {
                        prop_assert!(
                            seen.insert(&file.path),
                            "File {:?} appears in multiple clusters",
                            file.path
                        );
                    }
                }
            }
        }
        #[test]
        fn prop_minimum_cluster_size(
            files in arb_file_set(),
            min_sim in 50..100u8,
            min_size in 2..10usize,
        ) {
            let result = detect_large_file_clusters(&files, min_sim, min_size);
            if let Ok(clusters) = result {
                for cluster in clusters {
                    prop_assert!(
                        cluster.files.len() >= min_size,
                        "Cluster size {} < minimum {}",
                        cluster.files.len(), min_size
                    );
                }
            }
        }
        #[test]
        fn prop_distance_matrix_symmetric(files in arb_file_set()) {
            let matrix = build_distance_matrix(&files);
            let n = files.len();
            for i in 0..n {
                for j in 0..n {
                    prop_assert_eq!(
                        matrix[[i, j]], matrix[[j, i]],
                        "Distance matrix must be symmetric at [{}, {}]", i, j
                    );
                }
            }
        }
        #[test]
        fn prop_similarity_distance_monotonic(s1 in 0..=100u8, s2 in 0..=100u8) {
            let d1 = similarity_to_distance(s1);
            let d2 = similarity_to_distance(s2);
            if s1 < s2 {
                prop_assert!(d1 > d2, "Higher similarity must yield lower distance");
            } else if s1 > s2 {
                prop_assert!(d1 < d2, "Lower similarity must yield higher distance");
            } else {
                prop_assert_eq!(d1, d2, "Equal similarity must yield equal distance");
            }
        }
    }
}