use std::hash::{DefaultHasher, Hash, Hasher};
/// 64-bit hash identifying a block of tokens.
///
/// The wrapped integer is private; values are created by the hashing
/// functions in this module (see `hash_block_tokens`) and read back through
/// [`BlockHash::value`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockHash(u64);

impl BlockHash {
    /// Returns the raw 64-bit hash wrapped by this `BlockHash`.
    pub fn value(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
/// A [`BlockHash`] paired with a numeric group id.
///
/// Two values with the same `block_hash` but different `group_id` compare
/// unequal (and hash differently), so the pair can serve as a composite key.
// NOTE(review): what a "group" denotes is not visible in this file — confirm
// against the callers that construct this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockHashWithGroupId {
/// Hash of the block's contents.
pub block_hash: BlockHash,
/// Id of the group the block belongs to.
pub group_id: u32,
}
/// Extra, non-token input that is fed into a block's hash in addition to the
/// parent hash and the block's tokens.
///
/// Because `Hash` is derived, the variant participates in the hash as well as
/// the string payload, so the same string under two variants hashes differently.
#[derive(Debug, Clone, Hash)]
pub enum ExtraHashKey {
/// Identifier of a multimodal feature; `generate_mm_extra_keys` builds this
/// from `MultiModalFeature::identifier` for features overlapping a block.
MultiModalHash(String),
// Not constructed anywhere in the code visible here, hence the lint allowance.
// Presumably the name of a LoRA adapter — TODO confirm with callers.
#[allow(dead_code)]
LoraName(String),
// Not constructed anywhere in the code visible here, hence the lint allowance.
// Presumably a caller-supplied cache salt — TODO confirm with callers.
#[allow(dead_code)]
CacheSalt(String),
}
/// A multimodal feature occupying a contiguous range of token positions.
///
/// `generate_mm_extra_keys` treats the feature as covering the half-open
/// token range `[offset, offset + length)`.
#[derive(Debug, Clone)]
pub struct MultiModalFeature {
/// Identifier of the feature; copied into `ExtraHashKey::MultiModalHash`
/// for every block the feature overlaps.
pub identifier: String,
/// Index of the first token position covered by the feature.
pub offset: usize,
/// Number of token positions covered by the feature.
pub length: usize,
}
/// Value hashed in place of the parent hash for a block that has no parent.
const NONE_HASH_SEED: u64 = 0;

/// Computes the hash of a single block of tokens, chained to its parent block.
///
/// The hasher is fed, in this order:
/// 1. the parent's raw hash, or [`NONE_HASH_SEED`] when `parent_hash` is `None`;
/// 2. the `block_tokens` slice;
/// 3. every key in `extra_keys`, one after another.
///
/// `None` and `Some(&[])` for `extra_keys` contribute nothing and therefore
/// yield the same hash.
///
/// NOTE(review): the result comes from `std::hash::DefaultHasher`, whose
/// algorithm the standard library leaves unspecified across Rust releases —
/// confirm these hashes are never persisted or compared across builds.
pub fn hash_block_tokens(
    parent_hash: Option<BlockHash>,
    block_tokens: &[u32],
    extra_keys: Option<&[ExtraHashKey]>,
) -> BlockHash {
    let mut state = DefaultHasher::new();

    // Chain to the parent so identical tokens at different prefix positions
    // produce different hashes.
    let chain_seed = parent_hash.map_or(NONE_HASH_SEED, |parent| parent.0);
    chain_seed.hash(&mut state);

    block_tokens.hash(&mut state);

    // Flattening the option visits each key of the slice, or nothing at all.
    for key in extra_keys.into_iter().flatten() {
        key.hash(&mut state);
    }

    BlockHash(state.finish())
}
/// Maximum number of multimodal extra keys collected for a single block.
const MAX_MM_EXTRA_KEYS_PER_BLOCK: usize = 32;

/// Collects one [`ExtraHashKey::MultiModalHash`] for every multimodal feature
/// that overlaps a block of tokens.
///
/// The block covers the half-open token range
/// `[block_start_token, block_start_token + block_size)` and each feature
/// covers `[offset, offset + length)`. Keys are returned in the order the
/// overlapping features appear in `mm_features`.
///
/// At most [`MAX_MM_EXTRA_KEYS_PER_BLOCK`] keys are returned. If further
/// overlapping features exist beyond that cap they are dropped and a warning
/// is logged.
///
/// Range end points are computed with saturating arithmetic, so extreme
/// offsets or lengths cannot overflow.
pub fn generate_mm_extra_keys(
    block_start_token: usize,
    block_size: usize,
    mm_features: &[MultiModalFeature],
) -> Vec<ExtraHashKey> {
    // Saturate instead of overflowing: a wrapped end point would make the
    // overlap test silently wrong in release builds (and panic in debug).
    let block_end_token = block_start_token.saturating_add(block_size);
    let mut extra_keys = Vec::new();

    for feature in mm_features {
        let feature_end = feature.offset.saturating_add(feature.length);
        let overlaps = feature.offset < block_end_token && feature_end > block_start_token;
        if !overlaps {
            continue;
        }

        // Check the cap *before* pushing so the warning is emitted only when
        // an overlapping feature is actually being dropped, i.e. when there
        // really are more than the maximum.
        if extra_keys.len() >= MAX_MM_EXTRA_KEYS_PER_BLOCK {
            tracing::warn!(
                "Block at token offset {block_start_token} has more than \
                 {MAX_MM_EXTRA_KEYS_PER_BLOCK} overlapping multimodal features; \
                 capping extra keys"
            );
            break;
        }

        extra_keys.push(ExtraHashKey::MultiModalHash(feature.identifier.clone()));
    }

    extra_keys
}
/// Computes the chained hashes of every full block in `tokens`.
///
/// `tokens` is split into consecutive blocks of `block_size` tokens; a
/// trailing partial block is ignored. Each block is hashed with
/// `hash_block_tokens`, using the previous block's hash as its parent (the
/// first block has no parent).
///
/// The extra keys for a block are `extra_keys_base` followed by the keys
/// `generate_mm_extra_keys` reports for that block's token range. When that
/// combined list is empty, no extra keys are passed.
///
/// Returns one hash per full block, in block order. Returns an empty vector
/// when `block_size` is `0`, since no block can be formed.
pub fn compute_block_hashes(
    tokens: &[u32],
    block_size: usize,
    mm_features: &[MultiModalFeature],
    extra_keys_base: &[ExtraHashKey],
) -> Vec<BlockHash> {
    // Guard the division (and `chunks_exact`) against a zero block size.
    if block_size == 0 {
        return Vec::new();
    }

    // Clone the base keys once; each iteration truncates back to them and
    // appends the block-specific keys, instead of re-cloning per block.
    let base_len = extra_keys_base.len();
    let mut extra_keys = extra_keys_base.to_vec();

    let mut hashes = Vec::with_capacity(tokens.len() / block_size);
    let mut parent_hash = None;

    for (block_idx, block_tokens) in tokens.chunks_exact(block_size).enumerate() {
        let start = block_idx * block_size;

        extra_keys.truncate(base_len);
        extra_keys.extend(generate_mm_extra_keys(start, block_size, mm_features));
        let extra = (!extra_keys.is_empty()).then_some(extra_keys.as_slice());

        let hash = hash_block_tokens(parent_hash, block_tokens, extra);
        hashes.push(hash);
        parent_hash = Some(hash);
    }

    hashes
}
/// Computes the hashes of the full blocks in `tokens` that are not yet covered
/// by `existing_hashes`.
///
/// `existing_hashes` is taken to hold the hashes of the first
/// `existing_hashes.len()` blocks of `tokens`. Hashing resumes at the next
/// block and is chained from the last existing hash (or from no parent when
/// `existing_hashes` is empty). A trailing partial block is ignored.
///
/// The extra keys for a block are `extra_keys_base` followed by the keys
/// `generate_mm_extra_keys` reports for that block's token range. When that
/// combined list is empty, no extra keys are passed.
///
/// Returns only the newly computed hashes, in block order. The result is
/// empty when every full block is already covered, or when `block_size` is
/// `0`, since no block can be formed.
pub fn compute_new_block_hashes(
    tokens: &[u32],
    block_size: usize,
    existing_hashes: &[BlockHash],
    mm_features: &[MultiModalFeature],
    extra_keys_base: &[ExtraHashKey],
) -> Vec<BlockHash> {
    // Guard the division (and `chunks_exact`) against a zero block size.
    if block_size == 0 {
        return Vec::new();
    }

    let num_full_blocks = tokens.len() / block_size;
    let start_block = existing_hashes.len();
    if num_full_blocks <= start_block {
        return Vec::new();
    }

    // Clone the base keys once; each iteration truncates back to them and
    // appends the block-specific keys, instead of re-cloning per block.
    let base_len = extra_keys_base.len();
    let mut extra_keys = extra_keys_base.to_vec();

    let mut new_hashes = Vec::with_capacity(num_full_blocks - start_block);
    let mut prev_hash = existing_hashes.last().copied();

    let pending_blocks = tokens
        .chunks_exact(block_size)
        .enumerate()
        .skip(start_block);
    for (block_idx, block_tokens) in pending_blocks {
        let start = block_idx * block_size;

        extra_keys.truncate(base_len);
        extra_keys.extend(generate_mm_extra_keys(start, block_size, mm_features));
        let extra = (!extra_keys.is_empty()).then_some(extra_keys.as_slice());

        let hash = hash_block_tokens(prev_hash, block_tokens, extra);
        new_hashes.push(hash);
        prev_hash = Some(hash);
    }

    new_hashes
}
// Unit tests for block hashing, incremental hashing and multimodal extra keys.
#[cfg(test)]
mod tests {
use super::*;
// Hashing the same inputs twice yields equal hashes.
#[test]
fn test_hash_consistency() {
let tokens = vec![1, 2, 3, 4];
let h1 = hash_block_tokens(None, &tokens, None);
let h2 = hash_block_tokens(None, &tokens, None);
assert_eq!(h1, h2);
}
// Changing a single token changes the block hash.
#[test]
fn test_different_tokens_different_hash() {
let h1 = hash_block_tokens(None, &[1, 2, 3, 4], None);
let h2 = hash_block_tokens(None, &[1, 2, 3, 5], None);
assert_ne!(h1, h2);
}
// The parent hash participates: the same tokens hash differently with and
// without a parent.
#[test]
fn test_chain_hashing() {
let h1 = hash_block_tokens(None, &[5, 6, 7, 8], None);
let h2_with_parent = hash_block_tokens(Some(h1), &[9, 10, 11, 12], None);
let h2_without_parent = hash_block_tokens(None, &[9, 10, 11, 12], None);
assert_ne!(h2_with_parent, h2_without_parent);
}
// Supplying an extra key changes the hash compared with supplying none.
#[test]
fn test_extra_keys_affect_hash() {
let tokens = vec![1, 2, 3, 4];
let h1 = hash_block_tokens(None, &tokens, None);
let extra = vec![ExtraHashKey::MultiModalHash("image_abc".to_string())];
let h2 = hash_block_tokens(None, &tokens, Some(&extra));
assert_ne!(h1, h2);
}
// Two different multimodal identifiers give different hashes for the same tokens.
#[test]
fn test_different_mm_hashes_different_block_hash() {
let tokens = vec![1, 2, 3, 4];
let extra1 = vec![ExtraHashKey::MultiModalHash("image_1".to_string())];
let extra2 = vec![ExtraHashKey::MultiModalHash("image_2".to_string())];
let h1 = hash_block_tokens(None, &tokens, Some(&extra1));
let h2 = hash_block_tokens(None, &tokens, Some(&extra2));
assert_ne!(h1, h2);
}
// 16 tokens with block size 4 produce 4 hashes, and the result is repeatable.
#[test]
fn test_compute_block_hashes() {
let tokens: Vec<u32> = (0..16).collect();
let hashes = compute_block_hashes(&tokens, 4, &[], &[]);
assert_eq!(hashes.len(), 4);
let hashes2 = compute_block_hashes(&tokens, 4, &[], &[]);
assert_eq!(hashes, hashes2);
}
// 10 tokens with block size 4 produce 2 hashes; the trailing 2 tokens are ignored.
#[test]
fn test_compute_block_hashes_partial_block_ignored() {
let tokens: Vec<u32> = (0..10).collect();
let hashes = compute_block_hashes(&tokens, 4, &[], &[]);
assert_eq!(hashes.len(), 2);
}
// Hashing the first 8 tokens and then extending incrementally to 16 matches
// hashing all 16 tokens in one call.
#[test]
fn test_incremental_hashing() {
let tokens: Vec<u32> = (0..16).collect();
let all_hashes = compute_block_hashes(&tokens, 4, &[], &[]);
let first_8: Vec<u32> = (0..8).collect();
let initial = compute_block_hashes(&first_8, 4, &[], &[]);
assert_eq!(initial.len(), 2);
let new = compute_new_block_hashes(&tokens, 4, &initial, &[], &[]);
assert_eq!(new.len(), 2);
let mut combined = initial;
combined.extend(new);
assert_eq!(combined, all_hashes);
}
// A feature covering tokens [2, 8) overlaps blocks [0, 4) and [4, 8) but not [8, 12).
#[test]
fn test_mm_extra_keys_overlap() {
let feature = MultiModalFeature {
identifier: "img_hash_123".to_string(),
offset: 2,
length: 6,
};
let keys = generate_mm_extra_keys(0, 4, std::slice::from_ref(&feature));
assert_eq!(keys.len(), 1);
let keys = generate_mm_extra_keys(4, 4, std::slice::from_ref(&feature));
assert_eq!(keys.len(), 1);
let keys = generate_mm_extra_keys(8, 4, &[feature]);
assert_eq!(keys.len(), 0);
}
// With features at [0, 4) and [8, 12), blocks [0, 4) and [8, 12) each get one
// key and the block [4, 8) between them gets none.
#[test]
fn test_mm_extra_keys_multiple_features() {
let features = vec![
MultiModalFeature {
identifier: "image_1".to_string(),
offset: 0,
length: 4,
},
MultiModalFeature {
identifier: "image_2".to_string(),
offset: 8,
length: 4,
},
];
let keys = generate_mm_extra_keys(0, 4, &features);
assert_eq!(keys.len(), 1);
let keys = generate_mm_extra_keys(4, 4, &features);
assert_eq!(keys.len(), 0);
let keys = generate_mm_extra_keys(8, 4, &features);
assert_eq!(keys.len(), 1);
}
// The same block hash under two different group ids compares unequal.
#[test]
fn test_block_hash_with_group_id() {
let hash = hash_block_tokens(None, &[1, 2, 3, 4], None);
let g0 = BlockHashWithGroupId {
block_hash: hash,
group_id: 0,
};
let g1 = BlockHashWithGroupId {
block_hash: hash,
group_id: 1,
};
assert_ne!(g0, g1);
}
// Compares two inline predicates over feature ranges with a prefix of 8 tokens:
// "ends within the prefix" matches only img_a ([0, 4)), while "starts within
// the prefix" also matches img_b ([6, 10)), which is only partially covered.
// NOTE(review): this test calls no function from the parent module; it only
// checks the two closures defined here — confirm which production filter it
// is meant to guard.
#[test]
fn test_mm_feature_fully_within_prefix() {
let features = [
MultiModalFeature {
identifier: "img_a".to_string(),
offset: 0,
length: 4,
},
MultiModalFeature {
identifier: "img_b".to_string(),
offset: 6,
length: 4, },
];
let prefix_len = 8;
let fully_cached = features
.iter()
.filter(|f| f.offset + f.length <= prefix_len)
.count();
assert_eq!(fully_cached, 1);
let buggy_count = features.iter().filter(|f| f.offset < prefix_len).count();
assert_eq!(buggy_count, 2);
assert_ne!(
fully_cached, buggy_count,
"The correct filter should NOT match partially-overlapping features"
);
}
}