use alloc::vec::Vec;
use core::str;
use crate::classical::StreamingFingerprinter;
use crate::error::{Error, Result};
use crate::tokenize::Tokenizer;
use super::fingerprinter::MinHashFingerprinter;
use super::sig::MinHashSig;
/// Default cap on the total bytes held by a streaming fingerprinter before
/// `update` starts rejecting input: 16 MiB.
pub const DEFAULT_MAX_BUFFER_BYTES: usize = 16 * 1024 * 1024;
/// Streaming adapter over [`MinHashFingerprinter`]: accepts UTF-8 input in
/// arbitrary byte chunks (including chunks that split a multi-byte code
/// point) and produces the MinHash signature at finalize time.
pub struct MinHashStreaming<T: Tokenizer, const H: usize> {
    // Offline fingerprinter used to sketch the accumulated text.
    inner: MinHashFingerprinter<T, H>,
    // Validated UTF-8 bytes accumulated so far.
    buffer: Vec<u8>,
    // Trailing bytes of an incomplete multi-byte UTF-8 sequence, held back
    // until the next chunk completes it (a code point is at most 4 bytes).
    carry: Vec<u8>,
    // Upper bound on buffered bytes; exceeding it makes `update` fail.
    max_bytes: usize,
}
impl<T: Tokenizer, const H: usize> MinHashStreaming<T, H> {
    /// Wraps an offline fingerprinter for incremental use, with the default
    /// buffer cap of [`DEFAULT_MAX_BUFFER_BYTES`].
    pub fn new(inner: MinHashFingerprinter<T, H>) -> Self {
        Self {
            inner,
            buffer: Vec::new(),
            // A UTF-8 code point is at most 4 bytes, so the carry never
            // needs more capacity than this.
            carry: Vec::with_capacity(4),
            max_bytes: DEFAULT_MAX_BUFFER_BYTES,
        }
    }

    /// Builder-style override of the maximum number of buffered bytes.
    #[must_use]
    pub fn with_max_bytes(self, max_bytes: usize) -> Self {
        Self { max_bytes, ..self }
    }

    /// Number of validated bytes currently buffered (excludes any pending
    /// incomplete-code-point carry).
    pub fn buffered_bytes(&self) -> usize {
        self.buffer.len()
    }
}
impl<T: Tokenizer, const H: usize> StreamingFingerprinter for MinHashStreaming<T, H> {
    type Output = MinHashSig<H>;

    /// Appends a chunk of UTF-8 bytes; chunks may split multi-byte code
    /// points, in which case the incomplete tail is carried to the next call.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` when the buffered total would exceed the
    /// configured cap, or when the stream contains bytes that can never form
    /// valid UTF-8.
    fn update(&mut self, chunk: &[u8]) -> Result<()> {
        // Count the pending carry toward the cap as well, so total retained
        // bytes can never exceed `max_bytes` (previously the carry — up to
        // 3 bytes — was not counted).
        let pending = self.buffer.len().saturating_add(self.carry.len());
        if pending.saturating_add(chunk.len()) > self.max_bytes {
            return Err(Error::InvalidInput("streaming buffer exceeded cap".into()));
        }
        let mut combined = core::mem::take(&mut self.carry);
        combined.reserve(chunk.len());
        combined.extend_from_slice(chunk);
        let valid_up_to = match str::from_utf8(&combined) {
            Ok(_) => combined.len(),
            Err(e) => {
                // `error_len() == Some(_)` marks a definitively invalid
                // sequence; `None` means the data merely ends mid code
                // point and may be completed by a later chunk.
                if e.error_len().is_some() {
                    return Err(Error::InvalidInput("invalid UTF-8 in stream".into()));
                }
                e.valid_up_to()
            }
        };
        self.buffer.extend_from_slice(&combined[..valid_up_to]);
        // `mem::take` above already left `carry` empty, so just refill it
        // with the incomplete tail (the old explicit `clear()` was a no-op).
        self.carry.extend_from_slice(&combined[valid_up_to..]);
        Ok(())
    }

    /// Canonicalizes the accumulated text and computes the MinHash signature.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` when the stream ends inside a multi-byte
    /// code point or when no bytes were ever buffered.
    fn finalize(self) -> Result<Self::Output> {
        if !self.carry.is_empty() {
            return Err(Error::InvalidInput("trailing incomplete UTF-8".into()));
        }
        if self.buffer.is_empty() {
            return Err(Error::InvalidInput("empty document".into()));
        }
        // The buffer only ever receives validated prefixes, so this is a
        // defensive re-check; failure here would indicate an internal bug.
        let s = str::from_utf8(&self.buffer)
            .map_err(|e| Error::InvalidInput(alloc::format!("internal UTF-8: {e}")))?;
        let canonical = self.inner.canonicalizer().canonicalize(s);
        self.inner.sketch_canonical(&canonical)
    }

    /// Discards all buffered and carried bytes so the stream can be reused.
    fn reset(&mut self) {
        self.buffer.clear();
        self.carry.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::canonical::Canonicalizer;
    use crate::classical::Fingerprinter;
    use crate::tokenize::{ShingleTokenizer, WordTokenizer};

    /// Offline fingerprinter with the configuration shared by every test:
    /// default canonicalizer, 3-word shingles, 64 hash functions.
    fn offline() -> MinHashFingerprinter<ShingleTokenizer<WordTokenizer>, 64> {
        MinHashFingerprinter::new(
            Canonicalizer::default(),
            ShingleTokenizer {
                k: 3,
                inner: WordTokenizer,
            },
        )
    }

    /// Streaming wrapper around the shared configuration.
    fn make() -> MinHashStreaming<ShingleTokenizer<WordTokenizer>, 64> {
        MinHashStreaming::new(offline())
    }

    #[test]
    fn streaming_matches_offline_for_single_chunk() {
        let txt = "the quick brown fox jumps over the lazy dog";
        let mut streaming = make();
        streaming.update(txt.as_bytes()).unwrap();
        let stream_sig = streaming.finalize().unwrap();
        let offline_sig = offline().fingerprint(txt).unwrap();
        assert_eq!(stream_sig, offline_sig);
    }

    #[test]
    fn streaming_matches_offline_across_chunks() {
        let txt = "the quick brown fox jumps over the lazy dog";
        let mut streaming = make();
        // Feed awkwardly-sized chunks to exercise word boundaries mid-chunk.
        for piece in txt.as_bytes().chunks(7) {
            streaming.update(piece).unwrap();
        }
        let stream_sig = streaming.finalize().unwrap();
        let offline_sig = offline().fingerprint(txt).unwrap();
        assert_eq!(stream_sig, offline_sig);
    }

    #[test]
    fn empty_finalize_errors() {
        let streaming = make();
        assert!(matches!(streaming.finalize(), Err(Error::InvalidInput(_))));
    }

    #[test]
    fn invalid_utf8_errors() {
        // 0x80 is a bare continuation byte: definitively invalid UTF-8.
        let mut streaming = make();
        let result = streaming.update(&[0x80]);
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn split_multibyte_works() {
        // 0xC3 0xA9 is 'é' split across two updates.
        let mut streaming = make();
        streaming.update(&[0xC3]).unwrap();
        streaming.update(&[0xA9]).unwrap();
        streaming.update(b" cafe noir cafe noir cafe noir").unwrap();
        let sig = streaming.finalize().unwrap();
        assert_ne!(sig.hashes[0], u64::MAX);
    }

    #[test]
    fn reset_clears_buffer() {
        let mut streaming = make();
        streaming.update(b"hello world").unwrap();
        streaming.reset();
        assert_eq!(streaming.buffered_bytes(), 0);
        assert!(matches!(streaming.finalize(), Err(Error::InvalidInput(_))));
    }

    #[test]
    fn buffer_cap_enforced() {
        let mut streaming = make().with_max_bytes(16);
        streaming.update(b"0123456789ABCDEF").unwrap();
        let result = streaming.update(b"!");
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn trailing_incomplete_multibyte_errors_on_finalize() {
        // 0xC3 starts a two-byte sequence that is never completed.
        let mut streaming = make();
        streaming.update(&[0xC3]).unwrap();
        let result = streaming.finalize();
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }
}