use crate::core::GLOBAL_INTERNER;
use std::path::Path;
use super::error::{IndexPersistenceError, Result};
use super::formats::StringInternerData;
use super::{INTERNER_MAGIC, MANIFEST_VERSION};
/// Persists a snapshot of every string currently held by [`GLOBAL_INTERNER`] to `path`.
///
/// The snapshot is tagged with [`INTERNER_MAGIC`] and [`MANIFEST_VERSION`]. It records the
/// number of strings alongside the strings themselves. It is written through the shared
/// `super::common::save_encoded_with_crc` helper, and can be read back with
/// [`load_string_interner`].
///
/// NOTE(review): `restore_string_interner` expects position `i` in the snapshot to map to
/// interned id `i`. This presumably relies on `get_all_strings` returning strings in id
/// order; confirm against the interner implementation.
///
/// # Errors
///
/// Propagates any error returned by `super::common::save_encoded_with_crc`.
pub fn save_string_interner(path: &Path) -> Result<()> {
    let all_strings = GLOBAL_INTERNER.get_all_strings();
    // Capture the count before the vector is moved into the struct.
    let count = all_strings.len() as u64;

    let payload = StringInternerData {
        magic: INTERNER_MAGIC,
        version: MANIFEST_VERSION,
        string_count: count,
        strings: all_strings,
    };

    super::common::save_encoded_with_crc(&payload, path)
}
/// Reads and validates a string-interner snapshot previously written by
/// [`save_string_interner`].
///
/// The file is read and CRC-checked via `super::common::load_encoded_with_crc`, capped at
/// `MAX_STRING_INTERNER_FILE_SIZE`. The decoded data is then validated in this order:
///
/// 1. The magic bytes must equal [`INTERNER_MAGIC`].
/// 2. The version must not be newer than [`MANIFEST_VERSION`].
/// 3. Neither the declared `string_count` nor the actual number of decoded strings may
///    exceed `MAX_STRING_COUNT`.
/// 4. The declared `string_count` must match the number of decoded strings.
/// 5. No string may be longer than `MAX_STRING_LENGTH` bytes.
///
/// Checks 3-5 run after decoding. The up-front bound on how much is read is the
/// file-size limit passed to the loader.
///
/// # Errors
///
/// - Any error from `load_encoded_with_crc` (I/O, size cap, CRC, or decode failure).
/// - [`IndexPersistenceError::InvalidMagic`] if the magic bytes do not match.
/// - [`IndexPersistenceError::UnsupportedVersion`] if the file's version is too new.
/// - [`IndexPersistenceError::SizeLimitExceeded`] if the string count or any string length
///   is over its limit.
/// - [`IndexPersistenceError::Serialization`] if the declared count disagrees with the
///   decoded strings.
pub fn load_string_interner(path: &Path) -> Result<StringInternerData> {
    let data: StringInternerData = super::common::load_encoded_with_crc(
        path,
        super::MAX_STRING_INTERNER_FILE_SIZE,
        "String interner",
    )?;

    if data.magic != INTERNER_MAGIC {
        return Err(IndexPersistenceError::InvalidMagic {
            path: path.to_path_buf(),
            expected: INTERNER_MAGIC,
            got: data.magic,
        });
    }

    if data.version > MANIFEST_VERSION {
        return Err(IndexPersistenceError::UnsupportedVersion {
            found: data.version,
            supported: MANIFEST_VERSION,
        });
    }

    // `string_count` is self-declared by the file. The decoded vector is what callers
    // actually iterate. Bound the larger of the two so a file cannot under-report its
    // count to slip past the limit.
    let actual_count = data.strings.len() as u64;
    let effective_count = data.string_count.max(actual_count);
    if effective_count > super::MAX_STRING_COUNT {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "String count {} exceeds maximum allowed count {}",
                effective_count,
                super::MAX_STRING_COUNT
            ),
        });
    }

    // `save_string_interner` always writes `string_count == strings.len()`; any
    // disagreement means the file was not produced by it.
    if data.string_count != actual_count {
        return Err(IndexPersistenceError::Serialization(format!(
            "String interner count mismatch: header declares {} strings but file contains {}",
            data.string_count, actual_count
        )));
    }

    // Reject the first string whose byte length exceeds the per-string limit.
    if let Some(oversized) = data
        .strings
        .iter()
        .find(|s| s.len() > super::MAX_STRING_LENGTH)
    {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "String length {} exceeds maximum allowed length {}",
                oversized.len(),
                super::MAX_STRING_LENGTH
            ),
        });
    }

    Ok(data)
}
/// Re-interns every string from a loaded snapshot into [`GLOBAL_INTERNER`].
///
/// It also verifies that each string receives the same id it had when the snapshot was
/// saved. Strings are interned in order, and the string at position `i` in `data.strings`
/// is expected to come back with id `i`. An empty string whose id differs from its
/// position is tolerated and skipped.
///
/// NOTE(review): the empty-string exemption is presumably because the interner reserves a
/// fixed id for `""`; confirm.
///
/// Interning happens before each comparison. Strings interned before a failure therefore
/// remain in the global interner, because this function performs no rollback.
///
/// # Errors
///
/// - [`IndexPersistenceError::Serialization`] if the interner rejects a string.
/// - [`IndexPersistenceError::InternerMismatch`] for the first non-empty string whose
///   interned id differs from its position.
pub fn restore_string_interner(data: &StringInternerData) -> Result<()> {
    for (position, text) in data.strings.iter().enumerate() {
        let expected = position as u32;
        let got = GLOBAL_INTERNER
            .intern(text)
            .map_err(|e| {
                IndexPersistenceError::Serialization(format!("Failed to intern string: {}", e))
            })?
            .as_u32();

        // Ids line up, or the mismatch is on the tolerated empty string: keep going.
        if got == expected || text.is_empty() {
            continue;
        }

        return Err(IndexPersistenceError::InternerMismatch { expected, got });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crc32fast::Hasher;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `data` to `path` in the framing the loader reads: the bitcode encoding
    /// followed by its CRC32 as 4 little-endian bytes.
    ///
    /// It is kept separate from the production writer so tests can hand-craft payloads
    /// (bad magic, oversized fields) that `save_string_interner` would never emit.
    fn write_with_crc(path: &std::path::Path, data: &StringInternerData) {
        let mut bytes = bitcode::encode(data);
        let mut hasher = Hasher::new();
        hasher.update(&bytes);
        let checksum = hasher.finalize();
        bytes.extend_from_slice(&checksum.to_le_bytes());
        fs::write(path, bytes).unwrap();
    }

    /// Strings interned before saving are present after loading.
    #[test]
    fn test_string_interner_round_trip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("interner.idx");
        let _idx1 = GLOBAL_INTERNER.intern("test_string_1").unwrap();
        let _idx2 = GLOBAL_INTERNER.intern("test_string_2").unwrap();
        let _idx3 = GLOBAL_INTERNER.intern("test_string_3").unwrap();
        save_string_interner(&path).unwrap();
        let loaded = load_string_interner(&path).unwrap();
        assert_eq!(loaded.magic, INTERNER_MAGIC);
        assert!(loaded.strings.contains(&"test_string_1".to_string()));
        assert!(loaded.strings.contains(&"test_string_2".to_string()));
        assert!(loaded.strings.contains(&"test_string_3".to_string()));
    }

    /// A CRC-valid file with the wrong magic bytes is rejected as `InvalidMagic`.
    #[test]
    fn test_invalid_magic_rejected() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("bad.idx");
        let bad_data = StringInternerData {
            magic: *b"BAAD",
            version: 1,
            string_count: 0,
            strings: vec![],
        };
        write_with_crc(&path, &bad_data);
        let result = load_string_interner(&path);
        assert!(matches!(
            result,
            Err(IndexPersistenceError::InvalidMagic { .. })
        ));
    }

    /// Flipping a byte in a saved file is caught by the CRC check.
    #[test]
    fn test_crc_corruption_detected() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("interner.idx");
        let _idx1 = GLOBAL_INTERNER.intern("corruption_test").unwrap();
        save_string_interner(&path).unwrap();
        let mut bytes = fs::read(&path).unwrap();
        bytes[10] ^= 0xFF;
        fs::write(&path, bytes).unwrap();
        let result = load_string_interner(&path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Index file corrupted"));
    }

    /// A file too short to hold a checksum is reported as corrupted.
    #[test]
    fn test_truncated_file_detected() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("interner.idx");
        fs::write(&path, b"ab").unwrap();
        let result = load_string_interner(&path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Index file corrupted"));
    }

    /// A declared string count above `MAX_STRING_COUNT` is rejected.
    #[test]
    fn test_string_count_limit_dos_protection() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("interner.idx");
        let bad_data = StringInternerData {
            magic: INTERNER_MAGIC,
            version: MANIFEST_VERSION,
            string_count: super::super::MAX_STRING_COUNT + 1,
            strings: vec!["test".to_string()],
        };
        write_with_crc(&path, &bad_data);
        let result = load_string_interner(&path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Size limit exceeded"));
        assert!(err.to_string().contains("String count"));
    }

    /// A single string longer than `MAX_STRING_LENGTH` is rejected.
    #[test]
    fn test_string_length_limit_dos_protection() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("interner.idx");
        let oversized_string = "x".repeat(super::super::MAX_STRING_LENGTH + 1);
        assert!(
            oversized_string.len() > super::super::MAX_STRING_LENGTH,
            "String should exceed MAX_STRING_LENGTH"
        );
        assert!(
            oversized_string.len() < super::super::MAX_STRING_INTERNER_FILE_SIZE as usize,
            "String should be within file size limit to test string length check"
        );
        let bad_data = StringInternerData {
            magic: INTERNER_MAGIC,
            version: MANIFEST_VERSION,
            string_count: 1,
            strings: vec![oversized_string],
        };
        write_with_crc(&path, &bad_data);
        let result = load_string_interner(&path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Size limit exceeded"));
        assert!(err.to_string().contains("String length"));
    }

    /// Restoring a string that the pre-warmed interner already holds at a non-zero id
    /// reports an `InternerMismatch` at position 0.
    #[test]
    fn test_restore_string_interner_mismatch() {
        let data = StringInternerData {
            magic: INTERNER_MAGIC,
            version: MANIFEST_VERSION,
            string_count: 1,
            strings: vec!["type".to_string()],
        };
        let result = restore_string_interner(&data);
        assert!(result.is_err());
        match result.unwrap_err() {
            IndexPersistenceError::InternerMismatch { expected, got } => {
                assert_eq!(expected, 0, "Expected index 0 (from data position)");
                assert_ne!(got, 0, "Got index should not be 0");
            }
            err => panic!("Expected InternerMismatch, got: {:?}", err),
        }
    }

    /// A never-seen string still mismatches at position 0, because the interner
    /// already holds other strings.
    #[test]
    fn test_restore_string_interner_mismatch_with_new_string() {
        let unique_string = format!(
            "unique_string_{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        );
        let data = StringInternerData {
            magic: INTERNER_MAGIC,
            version: MANIFEST_VERSION,
            string_count: 1,
            strings: vec![unique_string],
        };
        let result = restore_string_interner(&data);
        assert!(result.is_err());
        match result.unwrap_err() {
            IndexPersistenceError::InternerMismatch { expected, got } => {
                assert_eq!(expected, 0, "Expected index 0 (from data position)");
                assert!(
                    got > 0,
                    "Got index should be > 0 (because interner is pre-warmed)"
                );
            }
            err => panic!("Expected InternerMismatch, got: {:?}", err),
        }
    }
}