use crate::cs::error::{Error, Result};
/// Modulus of the Rabin–Karp rolling hash; every stored hash value is reduced below it.
const PRIME: u64 = 16_777_619;
/// Radix of the rolling hash: each input byte is treated as one base-256 digit.
const BASE: u64 = 256;
/// Computes the Rabin–Karp hash of the first `m` bytes of `pattern`, together
/// with the weight of the leading byte of an `m`-byte window.
///
/// Returns `(pattern_hash, h)` where
/// * `pattern_hash` is `Σ pattern[k] * BASE^(m-1-k) mod PRIME` over `k < m`
///   (Horner's rule, reduced after every step), and
/// * `h` is `BASE^(m-1) mod PRIME`, the factor used to subtract the outgoing
///   byte when the search window slides by one position.
///
/// Callers in this module pass `m == pattern.len()`; if `pattern` were
/// shorter, only its available bytes would be hashed. `m == 0` yields `(0, 1)`
/// instead of underflowing.
fn compute_pattern_hash(pattern: &[u8], m: usize) -> (u64, u64) {
    // `1..m` runs m-1 times for m >= 1 and is empty for m == 0, whereas the
    // former `0..m - 1` underflowed (panic in debug, huge loop in release).
    let h = (1..m).fold(1, |acc, _| (acc * BASE) % PRIME);
    let pattern_hash = pattern
        .iter()
        .take(m)
        .fold(0, |acc, &b| (acc * BASE + u64::from(b)) % PRIME);
    (pattern_hash, h)
}
/// Returns the start offsets of every (possibly overlapping) occurrence of
/// `pattern` in `text`, in ascending order, using a Rabin–Karp rolling hash.
///
/// Both arguments are viewed as raw bytes via `AsRef<[u8]>`, so the returned
/// offsets are byte offsets (for `&str` input: UTF-8 byte indices).
///
/// # Errors
///
/// * `Error::empty_pattern()` if `pattern` is empty.
/// * `Error::pattern_too_long(pattern_len, text_len)` if `pattern` is longer
///   than `text` (this includes any non-empty pattern with an empty `text`).
pub fn find_all(text: impl AsRef<[u8]>, pattern: impl AsRef<[u8]>) -> Result<Vec<usize>> {
    let text = text.as_ref();
    let pattern = pattern.as_ref();
    if pattern.is_empty() {
        return Err(Error::empty_pattern());
    }
    if pattern.len() > text.len() {
        return Err(Error::pattern_too_long(pattern.len(), text.len()));
    }
    // From here on 1 <= m <= n: `n - m` cannot underflow and `text` is
    // non-empty, so no separate empty-text check is needed.
    let m = pattern.len();
    let n = text.len();
    let mut matches = Vec::new();

    let (pattern_hash, h) = compute_pattern_hash(pattern, m);
    // Hash of the first window, text[0..m], computed the same way as the pattern's.
    let mut text_hash = text[..m]
        .iter()
        .fold(0, |acc, &b| (acc * BASE + u64::from(b)) % PRIME);

    for i in 0..=n - m {
        // Equal hashes can be a collision, so confirm with a byte comparison.
        if pattern_hash == text_hash && text[i..i + m] == *pattern {
            matches.push(i);
        }
        if i < n - m {
            // Slide the window one byte: drop text[i] (weighted by h), shift by
            // one digit, append text[i + m]. Adding PRIME before subtracting
            // keeps the unsigned intermediate non-negative.
            text_hash = (BASE * (text_hash + PRIME - (h * u64::from(text[i])) % PRIME)
                + u64::from(text[i + m]))
                % PRIME;
        }
    }
    Ok(matches)
}
/// Returns the start offset of the first occurrence of `pattern` in `text`,
/// or `None` if it does not occur, using a Rabin–Karp rolling hash.
///
/// Scanning stops at the first verified match. Both arguments are viewed as
/// raw bytes via `AsRef<[u8]>`, so the offset is a byte offset.
///
/// # Errors
///
/// * `Error::empty_pattern()` if `pattern` is empty.
/// * `Error::pattern_too_long(pattern_len, text_len)` if `pattern` is longer
///   than `text` (this includes any non-empty pattern with an empty `text`).
pub fn find_first(text: impl AsRef<[u8]>, pattern: impl AsRef<[u8]>) -> Result<Option<usize>> {
    let text = text.as_ref();
    let pattern = pattern.as_ref();
    if pattern.is_empty() {
        return Err(Error::empty_pattern());
    }
    if pattern.len() > text.len() {
        return Err(Error::pattern_too_long(pattern.len(), text.len()));
    }
    // From here on 1 <= m <= n: `n - m` cannot underflow and `text` is
    // non-empty, so no separate empty-text check is needed.
    let m = pattern.len();
    let n = text.len();

    let (pattern_hash, h) = compute_pattern_hash(pattern, m);
    // Hash of the first window, text[0..m], computed the same way as the pattern's.
    let mut text_hash = text[..m]
        .iter()
        .fold(0, |acc, &b| (acc * BASE + u64::from(b)) % PRIME);

    for i in 0..=n - m {
        // Equal hashes can be a collision, so confirm with a byte comparison.
        if pattern_hash == text_hash && text[i..i + m] == *pattern {
            return Ok(Some(i));
        }
        if i < n - m {
            // Slide the window one byte: drop text[i] (weighted by h), shift by
            // one digit, append text[i + m]. Adding PRIME before subtracting
            // keeps the unsigned intermediate non-negative.
            text_hash = (BASE * (text_hash + PRIME - (h * u64::from(text[i])) % PRIME)
                + u64::from(text[i + m]))
                % PRIME;
        }
    }
    Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the base-`BASE` digits of `PRIME`, most significant first.
    ///
    /// Read as a window, these bytes hash to `PRIME % PRIME == 0`, the same hash
    /// as an all-zero pattern of equal length, giving a guaranteed collision.
    fn colliding_window() -> Vec<u8> {
        let mut digits = Vec::new();
        let mut rest = PRIME;
        while rest > 0 {
            // BASE is 256, so every digit fits in a byte.
            digits.push((rest % BASE) as u8);
            rest /= BASE;
        }
        digits.reverse();
        digits
    }

    #[test]
    fn test_empty_pattern() {
        let text = "hello";
        let pattern = "";
        assert!(matches!(find_all(text, pattern), Err(Error::EmptyPattern)));
        assert!(matches!(find_first(text, pattern), Err(Error::EmptyPattern)));
    }

    #[test]
    fn test_pattern_too_long() {
        let text = "hi";
        let pattern = "hello";
        assert!(matches!(
            find_all(text, pattern),
            Err(Error::PatternTooLong { .. })
        ));
        assert!(matches!(
            find_first(text, pattern),
            Err(Error::PatternTooLong { .. })
        ));
    }

    #[test]
    fn test_pattern_not_found() {
        let text = "hello world";
        let pattern = "xyz";
        assert_eq!(find_all(text, pattern).unwrap(), Vec::<usize>::new());
        assert_eq!(find_first(text, pattern).unwrap(), None);
    }

    #[test]
    fn test_single_match() {
        let text = "hello world";
        let pattern = "world";
        assert_eq!(find_all(text, pattern).unwrap(), vec![6]);
        assert_eq!(find_first(text, pattern).unwrap(), Some(6));
    }

    #[test]
    fn test_multiple_matches() {
        let text = "AABAACAADAABAAABAA";
        let pattern = "AABA";
        assert_eq!(find_all(text, pattern).unwrap(), vec![0, 9, 13]);
        assert_eq!(find_first(text, pattern).unwrap(), Some(0));
    }

    #[test]
    fn test_overlapping_matches() {
        let text = "AAAAA";
        let pattern = "AA";
        assert_eq!(find_all(text, pattern).unwrap(), vec![0, 1, 2, 3]);
        assert_eq!(find_first(text, pattern).unwrap(), Some(0));
    }

    #[test]
    fn test_match_at_start() {
        let text = "hello world";
        let pattern = "hello";
        assert_eq!(find_all(text, pattern).unwrap(), vec![0]);
        assert_eq!(find_first(text, pattern).unwrap(), Some(0));
    }

    #[test]
    fn test_match_at_end() {
        let text = "hello world";
        let pattern = "world";
        assert_eq!(find_all(text, pattern).unwrap(), vec![6]);
        assert_eq!(find_first(text, pattern).unwrap(), Some(6));
    }

    #[test]
    fn test_unicode_text() {
        let text = "Hello 世界!";
        let pattern = "世界";
        assert_eq!(
            find_all(text.as_bytes(), pattern.as_bytes()).unwrap(),
            vec![6]
        );
        assert_eq!(
            find_first(text.as_bytes(), pattern.as_bytes()).unwrap(),
            Some(6)
        );
    }

    #[test]
    fn test_empty_text() {
        let text = "";
        let pattern = "a";
        assert!(matches!(
            find_all(text, pattern),
            Err(Error::PatternTooLong { .. })
        ));
        assert!(matches!(
            find_first(text, pattern),
            Err(Error::PatternTooLong { .. })
        ));
    }

    #[test]
    fn test_hash_collisions() {
        // Plain match kept as a sanity check.
        assert_eq!(find_all("abcdef", "abc").unwrap(), vec![0]);
        assert_eq!(find_first("abcdef", "abc").unwrap(), Some(0));

        // Genuine collision: the leading window hashes like the pattern but
        // differs byte-wise, so only the byte comparison can reject it.
        let window = colliding_window();
        let m = window.len();
        let pattern = vec![0u8; m];
        assert_ne!(window, pattern);
        assert_eq!(
            compute_pattern_hash(&window, m).0,
            compute_pattern_hash(&pattern, m).0
        );
        let text = [window.as_slice(), pattern.as_slice()].concat();
        assert_eq!(find_all(&text, &pattern).unwrap(), vec![m]);
        assert_eq!(find_first(&text, &pattern).unwrap(), Some(m));
    }
}