use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use rayon::prelude::*;
pub mod gestalt;
pub use gestalt::gestalt_ratio;
/// Default multiplier for the position-list ("b2j") work budget.
///
/// `ratio_chars` takes the position-list path only when the summed products of
/// per-character counts of the two inputs do not exceed this factor times their
/// combined length. `work_factor` lets the `DF_WORK_FACTOR` environment variable
/// replace this default.
const B2J_WORK_FACTOR: u64 = 34;
/// Returns the similarity ratio of `a` and `b`, compared as sequences of `char`s.
///
/// Both strings are decoded into `Vec<char>` and passed to `ratio_chars`, which
/// picks between the position-list matcher and the `gestalt` implementation.
/// Two empty strings yield `1.0`.
#[must_use]
pub fn ratio(a: &str, b: &str) -> f64 {
let av: Vec<char> = a.chars().collect();
let bv: Vec<char> = b.chars().collect();
ratio_chars(&av, &bv)
}
/// Tallies the ASCII characters of `s`.
///
/// Returns a 128-slot histogram indexed by code point, together with the number
/// of characters whose code point is 128 or above.
fn ascii_counts(s: &[char]) -> ([u32; 128], u32) {
    let mut histogram = [0u32; 128];
    let mut non_ascii = 0u32;
    for ch in s.iter().copied() {
        // `get_mut` yields `None` exactly when the code point is >= 128.
        match histogram.get_mut(ch as usize) {
            Some(slot) => *slot += 1,
            None => non_ascii += 1,
        }
    }
    (histogram, non_ascii)
}
fn work_factor() -> u64 {
use std::sync::OnceLock;
static F: OnceLock<u64> = OnceLock::new();
*F.get_or_init(|| std::env::var("DF_WORK_FACTOR").ok().and_then(|s| s.parse().ok()).unwrap_or(B2J_WORK_FACTOR))
}
/// Computes the similarity ratio of two `char` slices, choosing the cheaper matcher.
///
/// * Two empty inputs give `1.0`.
/// * If either input contains a non-ASCII character, the `gestalt` implementation
///   is used, because the position-list matcher only indexes code points below 128.
/// * Otherwise the work estimate is the sum over all ASCII characters of
///   `count_in_a * count_in_b`. When it fits within `work_factor() * (a.len() + b.len())`
///   the position-list matcher runs; if not, `gestalt` is used.
#[allow(clippy::cast_precision_loss)]
#[must_use]
fn ratio_chars(a: &[char], b: &[char]) -> f64 {
    let total = a.len() + b.len();
    if total == 0 {
        return 1.0;
    }
    let (ca, oa) = ascii_counts(a);
    let (cb, ob) = ascii_counts(b);
    if oa > 0 || ob > 0 {
        return gestalt::gestalt_ratio_chars(a, b);
    }
    let work: u64 = ca
        .iter()
        .zip(cb.iter())
        .map(|(&x, &y)| u64::from(x) * u64::from(y))
        .sum();
    // The factor comes from an environment variable, so a very large value must
    // saturate instead of overflowing (panic in debug builds, wrap-around and a
    // wrong path choice in release builds).
    let budget = work_factor().saturating_mul(total as u64);
    if work <= budget {
        ratio_b2j_chars(a, b, &cb)
    } else {
        gestalt::gestalt_ratio_chars(a, b)
    }
}
/// Returns the similarity ratio of `a` and `b` using the position-list matcher.
///
/// The matcher can only index ASCII characters of `b`, so when `b` contains any
/// non-ASCII character the `gestalt` implementation is used instead.
#[must_use]
pub fn ratio_b2j(a: &str, b: &str) -> f64 {
    let lhs: Vec<char> = a.chars().collect();
    let rhs: Vec<char> = b.chars().collect();
    match ascii_counts(&rhs) {
        (counts, 0) => ratio_b2j_chars(&lhs, &rhs, &counts),
        _ => gestalt::gestalt_ratio_chars(&lhs, &rhs),
    }
}
/// Computes `ratio` for every `(a, b)` pair, in parallel via rayon.
///
/// The returned vector has one entry per input pair, collected from the
/// indexed parallel iterator over `pairs`.
#[must_use]
pub fn ratio_many(pairs: &[(String, String)]) -> Vec<f64> {
pairs.par_iter().map(|(a, b)| ratio(a, b)).collect()
}
/// Builds the character-to-positions index of `b`.
///
/// Each character maps to the ascending list of indices at which it occurs.
fn build_b2j(b: &[char]) -> HashMap<char, Vec<usize>> {
    b.iter()
        .enumerate()
        .fold(HashMap::<char, Vec<usize>>::new(), |mut index, (pos, &ch)| {
            index.entry(ch).or_default().push(pos);
            index
        })
}
/// Finds the longest common run between `a[alo..ahi]` and the part of `b`
/// covered by `[blo, bhi)`, using the position index `b2j`.
///
/// Returns `(start_in_a, start_in_b, length)`. Only a strictly longer run replaces
/// the current best, so among equally long runs the first one encountered while
/// scanning `a` left to right wins. When nothing matches the result is `(alo, blo, 0)`.
#[allow(clippy::similar_names)]
fn find_longest(
    a: &[char],
    b2j: &HashMap<char, Vec<usize>>,
    alo: usize,
    ahi: usize,
    blo: usize,
    bhi: usize,
) -> (usize, usize, usize) {
    let mut best = (alo, blo, 0usize);
    // Run length ending at each `b` index for the previous row of `a`.
    let mut prev_row: HashMap<usize, usize> = HashMap::new();
    for i in alo..ahi.min(a.len()) {
        let mut row: HashMap<usize, usize> = HashMap::new();
        let hits: &[usize] = b2j.get(&a[i]).map(Vec::as_slice).unwrap_or(&[]);
        for &j in hits.iter().filter(|&&j| j >= blo).take_while(|&&j| j < bhi) {
            // A run can only be extended from the diagonal neighbour inside the window.
            let carried = if j > blo {
                prev_row.get(&(j - 1)).copied().unwrap_or(0)
            } else {
                0
            };
            let run = carried + 1;
            row.insert(j, run);
            if run > best.2 {
                best = (i + 1 - run, j + 1 - run, run);
            }
        }
        prev_row = row;
    }
    best
}
/// Counts the characters covered by matching runs between `a` and `b`.
///
/// Repeatedly takes the longest run inside a window via `find_longest`, then
/// continues with the windows to its left and to its right. An explicit stack
/// replaces recursion.
#[allow(clippy::many_single_char_names)]
fn matching_count(a: &[char], b: &[char], b2j: &HashMap<char, Vec<usize>>) -> usize {
    let mut matched = 0usize;
    let mut pending: Vec<(usize, usize, usize, usize)> = vec![(0, a.len(), 0, b.len())];
    while let Some((alo, ahi, blo, bhi)) = pending.pop() {
        let (i, j, k) = find_longest(a, b2j, alo, ahi, blo, bhi);
        if k == 0 {
            continue;
        }
        matched += k;
        // Only descend into a side when both sequences still have room there.
        if i > alo && j > blo {
            pending.push((alo, i, blo, j));
        }
        let (a_end, b_end) = (i + k, j + k);
        if a_end < ahi && b_end < bhi {
            pending.push((a_end, ahi, b_end, bhi));
        }
    }
    matched
}
/// Straightforward reference implementation of the similarity ratio.
///
/// Builds a hash-map position index for `b`, counts matched characters with
/// `matching_count`, and returns `2 * matches / (len(a) + len(b))`, where lengths
/// are counted in `char`s. Two empty strings yield `1.0`.
#[allow(clippy::cast_precision_loss)]
#[must_use]
pub fn ratio_reference(a: &str, b: &str) -> f64 {
    let lhs: Vec<char> = a.chars().collect();
    let rhs: Vec<char> = b.chars().collect();
    let combined = lhs.len() + rhs.len();
    if combined == 0 {
        return 1.0;
    }
    let index = build_b2j(&rhs);
    let matched = matching_count(&lhs, &rhs, &index);
    2.0 * (matched as f64) / (combined as f64)
}
/// Reusable working buffers for `ratio_b2j_chars` and `find_longest_b2j`.
///
/// As filled in by `ratio_b2j_chars`:
/// - `offsets`: 129 prefix sums over the per-character counts of `b`;
///   `offsets[c]..offsets[c + 1]` is the slice of `positions` for character `c`.
/// - `positions`: indices into `b`, grouped by character.
/// - `cursor`: per-character write cursors used while filling `positions`.
/// - `j2len`: `b.len() + 1` run lengths; slot `j + 1` holds the run ending at `b[j]`.
/// - `erase` / `affect`: `(slot, value)` lists for the previous and current row of `j2len`.
/// - `stack`: pending `(alo, ahi, blo, bhi)` windows still to be searched.
#[derive(Default)]
struct B2jScratch {
offsets: Vec<u32>, positions: Vec<u32>, cursor: Vec<u32>, j2len: Vec<u32>,
erase: Vec<(u32, u32)>, affect: Vec<(u32, u32)>, stack: Vec<(usize, usize, usize, usize)>,
}
// One `B2jScratch` per thread, borrowed mutably by `ratio_b2j_chars` so the
// buffers' allocations can be reused from call to call on the same thread.
thread_local! {
static B2J: RefCell<B2jScratch> = RefCell::new(B2jScratch::default());
}
/// Estimates the work the position-list matcher would do on `a` and `b`.
///
/// The estimate is the sum over every ASCII character of
/// `count_in_a * count_in_b`, plus the product of the two non-ASCII totals.
#[must_use]
pub fn b2j_work(a: &[char], b: &[char]) -> u64 {
    let (hist_a, rest_a) = ascii_counts(a);
    let (hist_b, rest_b) = ascii_counts(b);
    hist_a
        .iter()
        .zip(hist_b.iter())
        .fold(u64::from(rest_a) * u64::from(rest_b), |work, (&x, &y)| {
            work + u64::from(x) * u64::from(y)
        })
}
/// Computes the similarity ratio of `a` and `b` with the position-list matcher.
///
/// `cb` must be the ASCII histogram of `b` (as produced by `ascii_counts`), and
/// `b` must contain only ASCII characters: a non-ASCII character in `b` indexes
/// past the 128-entry cursor table and panics. Non-ASCII characters in `a` are
/// fine; they simply never match.
///
/// Returns `1.0` when both inputs are empty, `0.0` when exactly one is empty, and
/// otherwise `2 * matches / (a.len() + b.len())`.
///
/// Positions are stored as `u32`, so inputs are assumed to be shorter than
/// `u32::MAX` characters.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
fn ratio_b2j_chars(a: &[char], b: &[char], cb: &[u32; 128]) -> f64 {
    let total = a.len() + b.len();
    if total == 0 {
        return 1.0;
    }
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }
    B2J.with_borrow_mut(|s| {
        let B2jScratch { offsets, positions, cursor, j2len, erase, affect, stack } = s;

        // Prefix sums: `offsets[c]..offsets[c + 1]` is character `c`'s slice of `positions`.
        offsets.clear();
        offsets.resize(129, 0);
        for c in 0..128 {
            offsets[c + 1] = offsets[c] + cb[c];
        }
        cursor.clear();
        cursor.extend_from_slice(&offsets[..128]);

        // Zero-fill instead of `set_len` on uninitialised memory: every slot is
        // overwritten below when `cb` matches `b`, and if it does not, the values
        // later used as unchecked indices are still valid positions in `b`.
        positions.clear();
        positions.resize(b.len(), 0);
        for (j, &ch) in b.iter().enumerate() {
            let c = ch as usize;
            positions[cursor[c] as usize] = j as u32;
            cursor[c] += 1;
        }

        j2len.clear();
        j2len.resize(b.len() + 1, 0);
        // Clear first so `reserve` is counted from an empty buffer rather than
        // on top of whatever the previous call left behind.
        erase.clear();
        erase.reserve(b.len() + 1);
        affect.clear();
        affect.reserve(b.len() + 1);

        stack.clear();
        stack.push((0, a.len(), 0, b.len()));
        let mut m = 0usize;
        while let Some((alo, ahi, blo, bhi)) = stack.pop() {
            let (bi, bj, bk) =
                find_longest_b2j(a, offsets, positions, j2len, erase, affect, alo, ahi, blo, bhi);
            if bk > 0 {
                m += bk;
                // Continue on each side of the match while both windows are non-empty.
                if alo < bi && blo < bj {
                    stack.push((alo, bi, blo, bj));
                }
                if bi + bk < ahi && bj + bk < bhi {
                    stack.push((bi + bk, ahi, bj + bk, bhi));
                }
            }
        }
        2.0 * m as f64 / total as f64
    })
}
/// Finds the longest common run between `a[alo..ahi]` and `b[blo..bhi]` using the
/// flat position tables prepared by `ratio_b2j_chars`.
///
/// * `offsets` / `positions`: for an ASCII code point `c`,
///   `positions[offsets[c]..offsets[c + 1]]` lists the indices of `c` in `b`.
/// * `j2len`: slot `j + 1` holds the length of the run ending at `b[j]` for the
///   previously processed row of `a`. Every slot this function touches is reset
///   to zero again before it returns.
/// * `erase` / `affect`: scratch lists of `(slot, value)` for the previous and
///   the current row.
///
/// Returns `(start_in_a, start_in_b, length)`. Only a strictly longer run replaces
/// the current best; with no match the result is `(alo, blo, 0)`.
///
/// All indexing is unchecked. The function relies on the caller to guarantee
/// `ahi <= a.len()`, `offsets.len() >= 129`, every `offsets` entry
/// `<= positions.len()`, every `positions` entry `< j2len.len() - 1`, and
/// `bhi < j2len.len()`.
#[allow(clippy::too_many_arguments, clippy::cast_possible_truncation, clippy::similar_names)]
fn find_longest_b2j(
a: &[char],
offsets: &[u32],
positions: &[u32],
j2len: &mut [u32],
erase: &mut Vec<(u32, u32)>,
affect: &mut Vec<(u32, u32)>,
alo: usize,
ahi: usize,
blo: usize,
bhi: usize,
) -> (usize, usize, usize) {
let (mut bi, mut bj, mut bk) = (alo, blo, 0usize);
erase.clear();
#[allow(clippy::undocumented_unsafe_blocks)]
// SAFETY (relies on the caller contract described in the function docs):
// `i < ahi <= a.len()`; `c < 128` so `c + 1 <= 128 < offsets.len()`;
// `lo..hi` lies inside `positions`; and every `j` that gets past the two range
// checks satisfies `j < bhi`, so both `j` and `j + 1` are valid `j2len` slots.
unsafe {
for i in alo..ahi {
affect.clear();
let c = *a.get_unchecked(i) as usize;
// Characters of `a` outside ASCII have no position list and never match.
if c < 128 {
let lo = *offsets.get_unchecked(c) as usize;
let hi = *offsets.get_unchecked(c + 1) as usize;
for &jj in positions.get_unchecked(lo..hi) {
let j = jj as usize;
if j < blo {
continue;
}
if j >= bhi {
break;
}
// Slot `j` still holds the previous row's run ending at `b[j - 1]`.
let k = *j2len.get_unchecked(j) as usize + 1;
// Writes are deferred so this row keeps reading the previous row's values.
affect.push((j as u32 + 1, k as u32));
if k > bk {
bi = i + 1 - k;
bj = j + 1 - k;
bk = k;
}
}
}
// Drop the previous row, then install the current one.
for &(p, _) in erase.iter() {
*j2len.get_unchecked_mut(p as usize) = 0;
}
for &(p, v) in affect.iter() {
*j2len.get_unchecked_mut(p as usize) = v;
}
// The current row becomes the one to erase on the next iteration.
std::mem::swap(erase, affect);
}
// Leave `j2len` all-zero for the next call.
for &(p, _) in erase.iter() {
*j2len.get_unchecked_mut(p as usize) = 0;
}
}
(bi, bj, bk)
}
/// Upper bound on the similarity ratio that depends only on the two lengths.
///
/// Returns `2 * min(len) / (len(a) + len(b))`, or `1.0` when both are empty.
#[allow(clippy::cast_precision_loss)]
fn real_quick_ratio(a: &[char], b: &[char]) -> f64 {
    let (la, lb) = (a.len(), b.len());
    match la + lb {
        0 => 1.0,
        combined => 2.0 * (la.min(lb) as f64) / (combined as f64),
    }
}
/// Builds a sorted histogram of the characters in `a`.
///
/// The result lists each distinct character once, in ascending order, paired
/// with its number of occurrences.
fn char_counts(a: &[char]) -> Vec<(char, u32)> {
    let mut sorted = a.to_vec();
    sorted.sort_unstable();
    let mut runs: Vec<(char, u32)> = Vec::new();
    for ch in sorted {
        // Equal characters are adjacent after sorting, so only the last run can grow.
        if let Some((prev, count)) = runs.last_mut() {
            if *prev == ch {
                *count += 1;
                continue;
            }
        }
        runs.push((ch, 1));
    }
    runs
}
/// Upper bound on the similarity ratio computed from two sorted histograms.
///
/// `ca` and `cb` must be sorted by character (as produced by `char_counts`).
/// The shared count is the sum, over characters present in both, of the smaller
/// occurrence count; the result is `2 * shared / total`, or `1.0` when `total`
/// is zero.
#[allow(clippy::cast_precision_loss)]
fn quick_ratio_counts(ca: &[(char, u32)], cb: &[(char, u32)], total: usize) -> f64 {
    use std::cmp::Ordering;

    if total == 0 {
        return 1.0;
    }
    let (mut rest_a, mut rest_b) = (ca, cb);
    let mut shared = 0u32;
    // Merge-walk both lists, always advancing the side with the smaller key.
    while let (Some((&(key_a, n_a), tail_a)), Some((&(key_b, n_b), tail_b))) =
        (rest_a.split_first(), rest_b.split_first())
    {
        match key_a.cmp(&key_b) {
            Ordering::Less => rest_a = tail_a,
            Ordering::Greater => rest_b = tail_b,
            Ordering::Equal => {
                shared += n_a.min(n_b);
                rest_a = tail_a;
                rest_b = tail_b;
            }
        }
    }
    2.0 * f64::from(shared) / total as f64
}
/// Returns the root of `x` in the union-find forest `parent`.
///
/// While walking up, each visited node is re-pointed at its grandparent, which
/// shortens the path for later lookups.
fn uf_find(parent: &mut [usize], mut x: usize) -> usize {
    loop {
        let up = parent[x];
        if up == x {
            return x;
        }
        let grand = parent[up];
        parent[x] = grand;
        x = grand;
    }
}
/// Reports whether progress output is enabled, i.e. whether the
/// `DIFFLIB_FAST_PROGRESS` environment variable is set (to any value).
fn progress_on() -> bool {
std::env::var_os("DIFFLIB_FAST_PROGRESS").is_some()
}
/// Finds every pair of inputs that `gestalt::gestalt_edge` accepts at `threshold`.
///
/// `chars[i]` is the i-th input and `sams[i]` the structure built for it by
/// `gestalt::build_sam`. The result holds `(lo, hi, ratio)` with `lo < hi`.
///
/// Candidates are pruned before the expensive check:
/// 1. Indices are visited in order of ascending length, and each one is only
///    compared with the indices that follow it in that order.
/// 2. `real_quick_ratio` (lengths only) stops the scan for the current index,
///    because all remaining candidates are at least as long.
/// 3. `quick_ratio_counts` (character histograms) skips a single candidate.
///
/// When `progress_on()` is true, a helper thread prints the number of finished
/// rows to stderr about once per second until all `n` rows are reported done.
#[allow(clippy::cast_precision_loss, clippy::many_single_char_names)]
fn qualifying_pairs(chars: &[Vec<char>], sams: &[gestalt::Sam], threshold: f64) -> Vec<(usize, usize, f64)> {
use std::sync::atomic::{AtomicUsize, Ordering};
let n = chars.len();
// Number of completed rows; read by the optional progress thread.
let rows = AtomicUsize::new(0);
std::thread::scope(|scope| {
if progress_on() {
let rows = &rows;
// NOTE(review): this loop only ends once `rows` reaches `n`; if the parallel
// section below panics before the final `store`, the scope would appear to
// wait on this thread indefinitely — confirm whether that matters.
scope.spawn(move || {
while rows.load(Ordering::Relaxed) < n {
std::thread::sleep(std::time::Duration::from_secs(1));
let done = rows.load(Ordering::Relaxed);
eprintln!(" [difflib-fast] qualifying_pairs: row {done}/{n} ({:.0}%)", done as f64 / n as f64 * 100.0);
}
});
}
// Visit inputs from shortest to longest so the length bound can end a row early.
let mut order: Vec<usize> = (0..n).collect();
order.sort_by_key(|&i| chars[i].len());
let counts: Vec<Vec<(char, u32)>> = chars.par_iter().map(|c| char_counts(c)).collect();
let pairs = (0..n)
.into_par_iter()
.flat_map_iter(|p| {
let i = order[p];
let a = &chars[i];
let mut local: Vec<(usize, usize, f64)> = Vec::new();
for &j in &order[p + 1..] {
let b = &chars[j];
// Later candidates are no shorter than `b`, so the length bound cannot recover.
if real_quick_ratio(a, b) < threshold {
break; }
if quick_ratio_counts(&counts[i], &counts[j], a.len() + b.len()) < threshold {
continue;
}
// Always pass the smaller original index first, with the SAM of the larger one.
let (lo, hi) = if i < j { (i, j) } else { (j, i) };
if let Some(r) = gestalt::gestalt_edge(&chars[lo], &chars[hi], &sams[hi], threshold) {
local.push((lo, hi, r));
}
}
rows.fetch_add(1, Ordering::Relaxed);
local
})
.collect();
// Mark everything as done so the progress thread leaves its loop.
rows.store(n, Ordering::Relaxed); pairs
})
}
/// Returns the smallest pairwise similarity among `members`, starting from `1.0`.
///
/// Pairs already present in `ratios` (keyed with the smaller index first) reuse
/// the stored value; all others are computed with `gestalt::gestalt_ratio_capped`,
/// which is handed the running minimum of the current row. Rows are processed in
/// parallel and combined with `f64::min`.
fn cluster_min_sim(
    members: &[usize],
    chars: &[Vec<char>],
    sams: &[gestalt::Sam],
    ratios: &HashMap<(usize, usize), f64>,
) -> f64 {
    members
        .par_iter()
        .enumerate()
        .map(|(pos, &i)| {
            members[pos + 1..].iter().fold(1.0_f64, |floor, &j| {
                let (lo, hi) = if i < j { (i, j) } else { (j, i) };
                let sim = ratios.get(&(lo, hi)).copied().unwrap_or_else(|| {
                    gestalt::gestalt_ratio_capped(&chars[lo], &chars[hi], &sams[hi], floor)
                });
                floor.min(sim)
            })
        })
        .reduce(|| 1.0_f64, f64::min)
}
/// Groups `n` items into clusters from the accepted `pairs`.
///
/// Every pair joins its two endpoints in a union-find forest. Each connected
/// component with at least two members becomes one cluster, reported as
/// `(sorted member indices, minimum pairwise similarity)`. Clusters are ordered
/// by their smallest member.
fn assemble(
    n: usize,
    pairs: Vec<(usize, usize, f64)>,
    chars: &[Vec<char>],
    sams: &[gestalt::Sam],
) -> Vec<(Vec<usize>, f64)> {
    let mut parent: Vec<usize> = (0..n).collect();
    let mut ratios: HashMap<(usize, usize), f64> = HashMap::with_capacity(pairs.len());
    for (i, j, r) in pairs {
        ratios.insert((i, j), r);
        let root_i = uf_find(&mut parent, i);
        let root_j = uf_find(&mut parent, j);
        if root_i != root_j {
            parent[root_i] = root_j;
        }
    }

    let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
    for node in 0..n {
        groups.entry(uf_find(&mut parent, node)).or_default().push(node);
    }

    let mut clusters: Vec<(Vec<usize>, f64)> = groups
        .into_values()
        .filter(|members| members.len() >= 2)
        .map(|mut members| {
            let min_sim = cluster_min_sim(&members, chars, sams, &ratios);
            members.sort_unstable();
            (members, min_sim)
        })
        .collect();
    // Map iteration order is arbitrary; sorting by the first member fixes the output order.
    clusters.sort_by_key(|(members, _)| members[0]);
    clusters
}
/// Clusters pre-decoded inputs by exhaustive (pruned) pairwise comparison.
///
/// Builds one `gestalt::Sam` per input in parallel, collects the accepted pairs
/// with `qualifying_pairs`, and groups them with `assemble`. Each returned entry
/// is `(sorted member indices, minimum pairwise similarity)`; inputs that end up
/// alone are not reported.
#[must_use]
pub fn cluster_canonicals_chars(chars: &[Vec<char>], threshold: f64) -> Vec<(Vec<usize>, f64)> {
let n = chars.len();
let sams: Vec<gestalt::Sam> = chars.par_iter().map(|c| gestalt::build_sam(c)).collect();
let pairs = qualifying_pairs(chars, &sams, threshold);
assemble(n, pairs, chars, &sams)
}
/// Clusters strings whose similarity passes `threshold`.
///
/// Decodes every string into a `Vec<char>` and delegates to
/// `cluster_canonicals_chars`; indices in the result refer to `canonicals`.
#[must_use]
pub fn cluster_canonicals(canonicals: &[String], threshold: f64) -> Vec<(Vec<usize>, f64)> {
let chars: Vec<Vec<char>> = canonicals.iter().map(|s| s.chars().collect()).collect();
cluster_canonicals_chars(&chars, threshold)
}
/// Shingle width, in bytes, used by `shingle_hashes`.
const SHINGLE_K: usize = 9;
/// Hashes a byte slice with 64-bit FNV-1a (xor each byte, then multiply by the prime).
fn fnv1a_bytes(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
/// Hashes a slice of `u64` words FNV-1a style: each whole word is xor-ed into the
/// state, which is then multiplied by the 64-bit FNV prime.
fn fnv1a_u64s(values: &[u64]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    values
        .iter()
        .fold(OFFSET_BASIS, |hash, &word| (hash ^ word).wrapping_mul(PRIME))
}
fn shingle_hashes(s: &str) -> Vec<u64> {
let bytes = s.as_bytes();
if bytes.len() <= SHINGLE_K {
return vec![fnv1a_bytes(bytes)];
}
let mut set: HashSet<u64> = HashSet::new();
for window in bytes.windows(SHINGLE_K) {
set.insert(fnv1a_bytes(window));
}
set.into_iter().collect()
}
/// Generates `num` `(multiplier, addend)` pairs for MinHash.
///
/// Values come from a fixed-seed xorshift generator, so the sequence is the same
/// on every call. Each multiplier has its lowest bit forced to 1.
fn make_perms(num: usize) -> Vec<(u64, u64)> {
    /// Advances the xorshift state and returns the new value.
    fn xorshift(state: &mut u64) -> u64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        *state
    }

    let mut seed = 0x9e37_79b9_7f4a_7c15u64;
    let mut perms = Vec::with_capacity(num);
    for _ in 0..num {
        // The multiplier is drawn first, then the addend.
        let mul = xorshift(&mut seed) | 1;
        let add = xorshift(&mut seed);
        perms.push((mul, add));
    }
    perms
}
/// Computes a MinHash signature: for every `(mul, add)` in `perms`, the smallest
/// value of `mul * h + add` (wrapping arithmetic) over all shingle hashes `h`.
///
/// With no shingles every signature entry is `u64::MAX`.
fn minhash(shingles: &[u64], perms: &[(u64, u64)]) -> Vec<u64> {
    let mut signature = Vec::with_capacity(perms.len());
    for &(mul, add) in perms {
        let mut lowest = u64::MAX;
        for &hash in shingles {
            lowest = lowest.min(mul.wrapping_mul(hash).wrapping_add(add));
        }
        signature.push(lowest);
    }
    signature
}
/// Collects candidate pairs by banding MinHash signatures.
///
/// Each signature is cut into consecutive bands of `band_rows` values (the band
/// count is taken from the first signature; leftover values are ignored).
/// Documents whose band hashes collide in any band form a candidate pair,
/// stored as `(smaller index, larger index)`. A `band_rows` of zero, or an empty
/// `sigs`, yields no candidates.
fn lsh_candidates(sigs: &[Vec<u64>], band_rows: usize) -> HashSet<(usize, usize)> {
    let mut pairs: HashSet<(usize, usize)> = HashSet::new();
    let sig_len = sigs.first().map_or(0, Vec::len);
    let band_count = if band_rows == 0 { 0 } else { sig_len / band_rows };
    for start in (0..band_count).map(|band| band * band_rows) {
        let mut buckets: HashMap<u64, Vec<usize>> = HashMap::new();
        for (doc, sig) in sigs.iter().enumerate() {
            let key = fnv1a_u64s(&sig[start..start + band_rows]);
            buckets.entry(key).or_default().push(doc);
        }
        // Every unordered pair inside a bucket becomes a candidate.
        for docs in buckets.values() {
            for (pos, &first) in docs.iter().enumerate() {
                for &second in &docs[pos + 1..] {
                    pairs.insert((first.min(second), first.max(second)));
                }
            }
        }
    }
    pairs
}
/// Clusters strings using MinHash/LSH to pre-select the pairs to compare.
///
/// Steps: build a `num_perm`-entry MinHash signature per string from its byte
/// shingles, band the signatures with `band_rows` rows per band to obtain
/// candidate pairs, verify each candidate with `gestalt::gestalt_edge` at
/// `threshold`, and group the accepted pairs with `assemble`.
///
/// Only pairs that collide in at least one band are ever compared. When
/// `progress_on()` is true, timing lines for each stage are written to stderr.
#[must_use]
pub fn cluster_canonicals_lsh(
    canonicals: &[String],
    threshold: f64,
    num_perm: usize,
    band_rows: usize,
) -> Vec<(Vec<usize>, f64)> {
    let verbose = progress_on();
    let started = std::time::Instant::now();
    let seconds = || started.elapsed().as_secs_f64();

    let chars: Vec<Vec<char>> = canonicals.iter().map(|s| s.chars().collect()).collect();
    let n = chars.len();

    let perms = make_perms(num_perm);
    let sigs: Vec<Vec<u64>> = canonicals
        .par_iter()
        .map(|text| minhash(&shingle_hashes(text), &perms))
        .collect();
    if verbose {
        eprintln!(" [difflib-fast] lsh: {n} signatures in {:.2}s", seconds());
    }

    let candidates = lsh_candidates(&sigs, band_rows);
    if verbose {
        eprintln!(" [difflib-fast] lsh: {} candidate pairs in {:.2}s", candidates.len(), seconds());
    }

    let sams: Vec<gestalt::Sam> = chars.par_iter().map(|c| gestalt::build_sam(c)).collect();
    let cand: Vec<(usize, usize)> = candidates.into_iter().collect();
    let pairs: Vec<(usize, usize, f64)> = cand
        .par_iter()
        .filter_map(|&(i, j)| {
            // Smaller index first, paired with the SAM of the larger index.
            let (lo, hi) = (i.min(j), i.max(j));
            gestalt::gestalt_edge(&chars[lo], &chars[hi], &sams[hi], threshold).map(|r| (lo, hi, r))
        })
        .collect();
    if verbose {
        eprintln!(" [difflib-fast] lsh: {} verified pairs in {:.2}s", pairs.len(), seconds());
    }

    assemble(n, pairs, &chars, &sams)
}
// Python bindings, compiled only when the `python` feature is enabled.
// Each wrapper forwards to the function of the same name in the parent module
// and runs it inside `py.detach(..)`.
// NOTE(review): `detach` presumably lets other Python threads run while the
// closure executes — confirm against the pyo3 version in use.
#[cfg(feature = "python")]
mod python {
use pyo3::prelude::*;
// Runs `f` on a dedicated rayon pool with `threads` workers. `threads == 0`
// runs `f` directly, and so does a failure to build the pool.
fn run_on_threads<T: Send>(threads: usize, f: impl FnOnce() -> T + Send) -> T {
if threads == 0 {
return f();
}
match rayon::ThreadPoolBuilder::new().num_threads(threads).build() {
Ok(pool) => pool.install(f),
Err(_) => f(),
}
}
// Similarity ratio of two strings; the inputs are copied into owned `String`s
// before the detached section runs.
#[pyfunction]
fn ratio(py: Python<'_>, a: &str, b: &str) -> f64 {
let (a, b) = (a.to_owned(), b.to_owned());
py.detach(|| super::ratio(&a, &b))
}
// Batch version of `ratio`; `threads` defaults to 0.
#[pyfunction]
#[pyo3(signature = (pairs, threads=0))]
#[allow(clippy::needless_pass_by_value)]
fn ratio_many(py: Python<'_>, pairs: Vec<(String, String)>, threads: usize) -> Vec<f64> {
py.detach(|| run_on_threads(threads, || super::ratio_many(&pairs)))
}
// Exhaustive clustering; `threads` defaults to 0.
#[pyfunction]
#[pyo3(signature = (canonicals, threshold, threads=0))]
#[allow(clippy::needless_pass_by_value)]
fn cluster_canonicals(py: Python<'_>, canonicals: Vec<String>, threshold: f64, threads: usize) -> Vec<(Vec<usize>, f64)> {
py.detach(|| run_on_threads(threads, || super::cluster_canonicals(&canonicals, threshold)))
}
// MinHash/LSH clustering; `threads` defaults to 0.
#[pyfunction]
#[pyo3(signature = (canonicals, threshold, num_perm, band_rows, threads=0))]
#[allow(clippy::needless_pass_by_value)]
fn cluster_canonicals_lsh(py: Python<'_>, canonicals: Vec<String>, threshold: f64, num_perm: usize, band_rows: usize, threads: usize) -> Vec<(Vec<usize>, f64)> {
py.detach(|| run_on_threads(threads, || super::cluster_canonicals_lsh(&canonicals, threshold, num_perm, band_rows)))
}
// Registers the four wrappers on the `_difflib_fast` extension module.
#[pymodule]
fn _difflib_fast(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(ratio, m)?)?;
m.add_function(wrap_pyfunction!(ratio_many, m)?)?;
m.add_function(wrap_pyfunction!(cluster_canonicals, m)?)?;
m.add_function(wrap_pyfunction!(cluster_canonicals_lsh, m)?)?;
Ok(())
}
}
// Unit tests: fixed expected values, cross-checks of the different ratio
// implementations against `ratio_reference`, and a small clustering example.
#[cfg(test)]
mod tests {
// Ratios are compared with exact equality on purpose.
#![allow(clippy::float_cmp, clippy::unreadable_literal)]
use super::{cluster_canonicals, gestalt_ratio, ratio, ratio_b2j, ratio_reference};
// Pins `gestalt_ratio` to hard-coded expected values, including a non-ASCII case.
#[test]
fn matches_known_difflib_values() {
assert_eq!(gestalt_ratio("", ""), 1.0);
assert_eq!(gestalt_ratio("", "x"), 0.0);
assert_eq!(gestalt_ratio("abc", "abc"), 1.0);
assert_eq!(gestalt_ratio("abc", "abd"), 0.6666666666666666);
assert_eq!(gestalt_ratio("the quick brown fox", "the quick brown dog"), 0.8947368421052632);
assert_eq!(gestalt_ratio("ПриветМир", "ПриветМирЪ"), 0.9473684210526315);
}
// 2000 random short strings over a 5-letter alphabet: both fast paths must
// agree exactly with the reference implementation.
#[test]
fn fast_matches_reference() {
// Fixed-seed xorshift keeps the test deterministic.
let mut s: u64 = 0x1234_5678_9abc_def1;
let mut next = || {
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
s
};
for _ in 0..2000 {
let mk = |n: usize, rng: &mut dyn FnMut() -> u64| -> String {
(0..n).map(|_| char::from(b'a' + (rng() % 5) as u8)).collect()
};
let (la, lb) = ((next() % 50) as usize, (next() % 50) as usize);
let a = mk(la, &mut next);
let b = mk(lb, &mut next);
let r = ratio_reference(&a, &b);
assert_eq!(gestalt_ratio(&a, &b), r, "SAM a={a:?} b={b:?}");
assert_eq!(ratio_b2j(&a, &b), r, "b2j a={a:?} b={b:?}");
}
}
// 40 random strings of 1400..2000 characters over a 6-letter alphabet: every
// public ratio entry point must agree with the reference implementation.
#[test]
fn long_strings_all_paths_agree() {
let mut s: u64 = 0xdead_beef_cafe_1234;
let mut next = || {
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
s
};
for _ in 0..40 {
let mk = |n: usize, rng: &mut dyn FnMut() -> u64| -> String {
(0..n).map(|_| char::from(b'a' + (rng() % 6) as u8)).collect()
};
let a = mk(1400 + (next() % 600) as usize, &mut next); let b = mk(1400 + (next() % 600) as usize, &mut next);
let r = ratio_reference(&a, &b);
assert_eq!(gestalt_ratio(&a, &b), r);
assert_eq!(ratio_b2j(&a, &b), r);
assert_eq!(ratio(&a, &b), r); }
}
// Two near-identical snippets must form the only cluster; the unrelated text
// stays unclustered.
#[test]
fn clusters_obvious_duplicates() {
let corpus: Vec<String> = vec![
"def add(a, b): return a + b".into(),
"def add(x, y): return x + y".into(),
"completely unrelated text here".into(),
];
let clusters = cluster_canonicals(&corpus, 0.5);
assert_eq!(clusters.len(), 1, "the two add() variants should cluster");
assert_eq!(clusters[0].0, vec![0, 1]);
assert!(clusters[0].1 >= 0.5);
}
}