use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A bounded, async string interner.
///
/// Equal strings are deduplicated into a single shared `Arc<str>`. Every
/// request for a string bumps a per-string usage counter; when the pool is
/// full, the entries with the lowest counters are evicted.
///
/// Lock order is always `strings` first, then `usage_counts`.
pub struct StringPool {
    /// Maps string contents to the shared allocation handed out to callers.
    strings: RwLock<HashMap<String, Arc<str>>>,
    /// How many times each pooled string has been requested; drives eviction.
    usage_counts: RwLock<HashMap<Arc<str>, usize>>,
    /// Entry count at which inserting a new string first triggers eviction.
    max_size: usize,
}

impl StringPool {
    /// Creates an empty pool that starts evicting once it holds `max_size`
    /// strings. Both maps are pre-sized to half of `max_size`.
    pub fn new(max_size: usize) -> Self {
        Self {
            strings: RwLock::new(HashMap::with_capacity(max_size / 2)),
            usage_counts: RwLock::new(HashMap::with_capacity(max_size / 2)),
            max_size,
        }
    }

    /// Returns the pooled `Arc<str>` equal to `s`, inserting it if necessary.
    ///
    /// Each call counts as one use of the string.
    pub async fn intern(&self, s: &str) -> Arc<str> {
        // Fast path: the string is already pooled, so a read lock on the
        // string table is enough. The guard is kept while the counter is
        // bumped so the entry cannot be evicted in between.
        {
            let strings = self.strings.read().await;
            if let Some(interned) = strings.get(s) {
                let interned = Arc::clone(interned);
                let mut usage_counts = self.usage_counts.write().await;
                *usage_counts.entry(Arc::clone(&interned)).or_insert(0) += 1;
                return interned;
            }
        }

        // Slow path: take both write locks; `intern_locked` re-checks the
        // table because another task may have inserted `s` in the meantime.
        let mut strings = self.strings.write().await;
        let mut usage_counts = self.usage_counts.write().await;
        self.intern_locked(&mut strings, &mut usage_counts, s)
    }

    /// Interns every element of `strings`, returning one `Arc<str>` per input
    /// in the same order.
    ///
    /// Each occurrence counts as one use, so duplicates within the batch bump
    /// the counter once per occurrence.
    pub async fn intern_batch(&self, strings: &[&str]) -> Vec<Arc<str>> {
        // Pass 1 (read lock): resolve everything that is already pooled.
        // `None` marks a slot that still has to be inserted.
        let mut slots: Vec<Option<Arc<str>>> = {
            let pool_strings = self.strings.read().await;
            let slots: Vec<Option<Arc<str>>> = strings
                .iter()
                .map(|s| pool_strings.get(*s).map(Arc::clone))
                .collect();

            if slots.iter().any(Option::is_some) {
                let mut usage_counts = self.usage_counts.write().await;
                for interned in slots.iter().flatten() {
                    *usage_counts.entry(Arc::clone(interned)).or_insert(0) += 1;
                }
            }
            slots
        };

        // Pass 2 (write locks): fill each unresolved slot by its own index, so
        // repeated inputs (including the empty string) all get a real entry.
        if slots.iter().any(Option::is_none) {
            let mut pool_strings = self.strings.write().await;
            let mut usage_counts = self.usage_counts.write().await;
            for (slot, s) in slots.iter_mut().zip(strings.iter().copied()) {
                if slot.is_none() {
                    *slot = Some(self.intern_locked(&mut pool_strings, &mut usage_counts, s));
                }
            }
        }

        slots
            .into_iter()
            .map(|slot| slot.expect("every slot is filled by one of the two passes"))
            .collect()
    }

    /// Looks up or inserts `s` while the caller holds both write locks.
    ///
    /// Bumps the usage counter on a hit; on a miss, evicts first if the pool
    /// is full and then inserts `s` with a count of 1.
    fn intern_locked(
        &self,
        strings: &mut HashMap<String, Arc<str>>,
        usage_counts: &mut HashMap<Arc<str>, usize>,
        s: &str,
    ) -> Arc<str> {
        if let Some(interned) = strings.get(s) {
            let interned = Arc::clone(interned);
            *usage_counts.entry(Arc::clone(&interned)).or_insert(0) += 1;
            return interned;
        }

        if strings.len() >= self.max_size {
            self.evict_lru(strings, usage_counts);
        }

        let interned: Arc<str> = Arc::from(s);
        strings.insert(s.to_string(), Arc::clone(&interned));
        usage_counts.insert(Arc::clone(&interned), 1);
        interned
    }

    /// Shrinks the pool to three quarters of `max_size` by removing the
    /// entries with the smallest usage counts.
    ///
    /// Despite the name, victims are chosen by request count, not by recency.
    fn evict_lru(
        &self,
        strings: &mut HashMap<String, Arc<str>>,
        usage_counts: &mut HashMap<Arc<str>, usize>,
    ) {
        let target_size = self.max_size * 3 / 4;
        let evict_count = strings.len().saturating_sub(target_size);
        if evict_count == 0 {
            return;
        }

        let mut candidates: Vec<(usize, Arc<str>)> = usage_counts
            .iter()
            .map(|(pooled, &count)| (count, Arc::clone(pooled)))
            .collect();
        // Map iteration order is arbitrary, so stability among ties buys nothing.
        candidates.sort_unstable_by_key(|(count, _)| *count);

        for (_, victim) in candidates.into_iter().take(evict_count) {
            // `String: Borrow<str>` lets us look up by `&str` without allocating.
            strings.remove(&*victim);
            usage_counts.remove(&victim);
        }
    }

    /// Takes a snapshot of the pool's size, usage and estimated memory.
    ///
    /// `memory_bytes` sums the byte length of every key plus the inline size
    /// of each `String` key and each `(Arc<str>, usize)` counter entry.
    /// `hit_rate` is `(total_usage - unique_strings) / total_usage`, or `0.0`
    /// when `total_usage` does not exceed `unique_strings`.
    pub async fn stats(&self) -> StringPoolStats {
        let strings = self.strings.read().await;
        let usage_counts = self.usage_counts.read().await;

        let unique_strings = strings.len();
        let total_usage: usize = usage_counts.values().sum();
        let key_bytes: usize = strings.keys().map(String::len).sum();
        let memory_bytes = key_bytes
            + unique_strings * std::mem::size_of::<String>()
            + usage_counts.len() * std::mem::size_of::<(Arc<str>, usize)>();
        let hit_rate = if total_usage > unique_strings {
            (total_usage - unique_strings) as f64 / total_usage as f64
        } else {
            0.0
        };

        StringPoolStats {
            unique_strings,
            total_usage,
            memory_bytes,
            hit_rate,
        }
    }

    /// Removes every pooled string and its usage counter.
    ///
    /// Both write locks are held for the whole operation so a concurrent
    /// `intern` cannot slip an entry in between the two clears and end up
    /// pooled without a counter.
    pub async fn clear(&self) {
        let mut strings = self.strings.write().await;
        let mut usage_counts = self.usage_counts.write().await;
        strings.clear();
        usage_counts.clear();
    }
}
impl Default for StringPool {
fn default() -> Self {
Self::new(10000) }
}
/// Snapshot of a `StringPool`, as produced by `StringPool::stats`.
#[derive(Debug, Clone)]
pub struct StringPoolStats {
/// Number of distinct strings currently held in the pool.
pub unique_strings: usize,
/// Sum of the usage counters of all strings currently tracked.
pub total_usage: usize,
/// Estimated footprint in bytes: key bytes plus the inline size of each
/// `String` key and each `(Arc<str>, usize)` counter entry.
pub memory_bytes: usize,
/// `(total_usage - unique_strings) / total_usage`, or `0.0` when
/// `total_usage` does not exceed `unique_strings`.
pub hit_rate: f64,
}
/// Namespace for string helpers that borrow from their input whenever possible.
pub struct ZeroCopyStrings;

impl ZeroCopyStrings {
    /// Returns the substring of `s` that starts at character `start` and spans
    /// at most `len` characters.
    ///
    /// `start` and `len` count `char`s, not bytes. Ranges that run past the
    /// end are clamped, and a `start` beyond the end yields an empty string.
    /// The result always borrows from `s`.
    pub fn substring_cow(s: &str, start: usize, len: usize) -> Cow<'_, str> {
        // One walk over the char boundaries finds both ends of the slice and
        // avoids computing `start + len`, which can overflow.
        let mut boundaries = s.char_indices().map(|(byte_idx, _)| byte_idx);

        let byte_start = match boundaries.nth(start) {
            Some(byte_idx) => byte_idx,
            None => return Cow::Borrowed(&s[s.len()..]),
        };
        // `nth(start)` consumed the first `start + 1` boundaries, so the end
        // boundary (character `start + len`) is `len - 1` further along.
        let byte_end = if len == 0 {
            byte_start
        } else {
            boundaries.nth(len - 1).unwrap_or(s.len())
        };

        Cow::Borrowed(&s[byte_start..byte_end])
    }

    /// Splits `s` on `delimiter`, yielding slices that borrow from `s`.
    pub fn split_no_alloc(s: &str, delimiter: char) -> impl Iterator<Item = &str> {
        s.split(delimiter)
    }

    /// Joins `strings` with `delimiter` between consecutive items.
    ///
    /// The pieces are gathered first so the output buffer can be sized
    /// exactly once; this works for any iterator, cloneable or not.
    pub fn join_with_capacity<'a>(
        strings: impl IntoIterator<Item = &'a str>,
        delimiter: &str,
    ) -> String {
        let parts: Vec<&str> = strings.into_iter().collect();
        parts.join(delimiter)
    }

    /// Backslash-escapes the query metacharacters
    /// `\ ( ) [ ] { } ^ ~ :` in `query`.
    ///
    /// Returns the input unchanged (borrowed) when nothing needs escaping.
    pub fn sanitize_query_single_pass(query: &str) -> Cow<'_, str> {
        fn needs_escape(c: char) -> bool {
            matches!(c, '\\' | '(' | ')' | '[' | ']' | '{' | '}' | '^' | '~' | ':')
        }

        let first_special = match query.find(needs_escape) {
            Some(byte_idx) => byte_idx,
            None => return Cow::Borrowed(query),
        };

        // Everything before the first special character is copied verbatim;
        // only the remainder is inspected character by character.
        let (clean_prefix, rest) = query.split_at(first_special);
        let mut sanitized = String::with_capacity(query.len() * 2);
        sanitized.push_str(clean_prefix);
        for ch in rest.chars() {
            if needs_escape(ch) {
                sanitized.push('\\');
            }
            sanitized.push(ch);
        }
        Cow::Owned(sanitized)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning the same text twice must hand back the same allocation.
    #[tokio::test]
    async fn test_string_pool_basic() {
        let pool = StringPool::new(100);
        let s1 = pool.intern("test").await;
        let s2 = pool.intern("test").await;
        assert_eq!(s1.as_ref(), "test");
        assert_eq!(s2.as_ref(), "test");
        assert!(Arc::ptr_eq(&s1, &s2));
    }

    /// A batch with a repeated entry returns one result per input, and the
    /// repeats share an allocation.
    #[tokio::test]
    async fn test_string_pool_batch() {
        let pool = StringPool::new(100);
        let strings = ["test1", "test2", "test1", "test3"];
        let interned = pool.intern_batch(&strings).await;
        assert_eq!(interned.len(), 4);
        assert_eq!(interned[0].as_ref(), "test1");
        assert_eq!(interned[2].as_ref(), "test1");
        assert!(Arc::ptr_eq(&interned[0], &interned[2]));
    }

    /// Inserting past `max_size` keeps the pool bounded, and an evicted string
    /// can still be interned again.
    #[tokio::test]
    async fn test_string_pool_eviction() {
        let pool = StringPool::new(3);
        // The handles are kept alive for the whole test but not inspected.
        let _s1 = pool.intern("string1").await;
        let _s2 = pool.intern("string2").await;
        let _s3 = pool.intern("string3").await;
        let _s4 = pool.intern("string4").await;
        let stats = pool.stats().await;
        assert!(stats.unique_strings <= 3);
        let s1_again = pool.intern("string1").await;
        assert_eq!(s1_again.as_ref(), "string1");
    }

    /// Stats reflect distinct strings, total requests and a positive hit rate
    /// once a string has been requested more than once.
    #[tokio::test]
    async fn test_string_pool_stats() {
        let pool = StringPool::new(100);
        pool.intern("test1").await;
        pool.intern("test2").await;
        pool.intern("test1").await;
        let stats = pool.stats().await;
        assert_eq!(stats.unique_strings, 2);
        assert_eq!(stats.total_usage, 3);
        assert!(stats.hit_rate > 0.0);
    }

    #[test]
    fn test_zero_copy_substring() {
        let s = "hello world";
        let sub1 = ZeroCopyStrings::substring_cow(s, 0, 20);
        assert!(matches!(sub1, Cow::Borrowed(_)));
        assert_eq!(sub1, "hello world");
        let sub2 = ZeroCopyStrings::substring_cow(s, 6, 5);
        assert!(matches!(sub2, Cow::Borrowed(_)));
        assert_eq!(sub2, "world");
    }

    #[test]
    fn test_zero_copy_join() {
        let strings = vec!["hello", "world", "test"];
        let joined = ZeroCopyStrings::join_with_capacity(strings.iter().copied(), " ");
        assert_eq!(joined, "hello world test");
    }

    #[test]
    fn test_sanitize_query_no_escaping() {
        let query = "simple query";
        let sanitized = ZeroCopyStrings::sanitize_query_single_pass(query);
        assert!(matches!(sanitized, Cow::Borrowed(_)));
        assert_eq!(sanitized, "simple query");
    }

    #[test]
    fn test_sanitize_query_with_escaping() {
        let query = "query with (parens) and :colons";
        let sanitized = ZeroCopyStrings::sanitize_query_single_pass(query);
        assert!(matches!(sanitized, Cow::Owned(_)));
        assert_eq!(sanitized, "query with \\(parens\\) and \\:colons");
    }

    #[test]
    fn test_split_no_alloc() {
        let text = "a,b,c,d";
        let parts: Vec<_> = ZeroCopyStrings::split_no_alloc(text, ',').collect();
        assert_eq!(parts, vec!["a", "b", "c", "d"]);
    }

    /// Offsets are counted in characters, so multi-byte text must slice cleanly.
    #[test]
    fn test_unicode_safety() {
        let unicode_text = "Hello 👋 World 🌍!";
        let sub = ZeroCopyStrings::substring_cow(unicode_text, 6, 2);
        assert_eq!(sub, "👋 ");
        let unicode_query = "search 👋 (test)";
        let sanitized = ZeroCopyStrings::sanitize_query_single_pass(unicode_query);
        assert_eq!(sanitized, "search 👋 \\(test\\)");
    }
}