use std::sync::Arc;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use crate::dsl::Field;
use crate::segment::SegmentReader;
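/// Lazily computed corpus-wide statistics over a set of segment readers.
///
/// Nothing is aggregated up front: document frequencies, IDF weights, and
/// average field lengths are computed on first request and memoized in
/// per-field caches behind `RwLock`s, so repeated lookups are served from
/// the caches without touching the segments again.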
pub struct LazyGlobalStats {
segments: Vec<Arc<SegmentReader>>,
total_docs: u64,
sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
sparse_total_vectors_cache: RwLock<FxHashMap<u32, u64>>,
text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}
impl LazyGlobalStats {
pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
Self {
segments,
total_docs,
sparse_idf_cache: RwLock::new(FxHashMap::default()),
sparse_total_vectors_cache: RwLock::new(FxHashMap::default()),
text_idf_cache: RwLock::new(FxHashMap::default()),
avg_field_len_cache: RwLock::new(FxHashMap::default()),
}
}
#[inline]
pub fn total_docs(&self) -> u64 {
self.total_docs
}
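    /// IDF weight for a single sparse dimension: `ln(N / df)`, clamped to be
    /// non-negative. `df` is the dimension's document frequency summed over
    /// all segments and `N` comes from `cached_sparse_n`.
    ///
    /// The fast path takes only a read lock; on a miss the value is computed
    /// outside any lock and then published under a short write lock.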
pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
{
let cache = self.sparse_idf_cache.read();
if let Some(field_cache) = cache.get(&field.0)
&& let Some(&idf) = field_cache.get(&dim_id)
{
return idf;
}
}
let df = self.compute_sparse_df(field, dim_id);
let n = self.cached_sparse_n(field);
let idf = if df > 0 && n > 0 {
(n as f32 / df as f32).ln().max(0.0)
} else {
0.0
};
{
let mut cache = self.sparse_idf_cache.write();
cache.entry(field.0).or_default().insert(dim_id, idf);
}
idf
}
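    /// Batch variant of [`Self::sparse_idf`]: resolves all cached dimensions
    /// under one read lock, computes only the misses, and publishes the new
    /// entries with a single write lock instead of one per dimension.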
pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
let mut result = vec![0.0f32; dim_ids.len()];
let mut misses: Vec<usize> = Vec::new();
{
let cache = self.sparse_idf_cache.read();
if let Some(field_cache) = cache.get(&field.0) {
for (i, &dim_id) in dim_ids.iter().enumerate() {
if let Some(&idf) = field_cache.get(&dim_id) {
result[i] = idf;
} else {
misses.push(i);
}
}
} else {
misses.extend(0..dim_ids.len());
}
}
if misses.is_empty() {
return result;
}
let n = self.cached_sparse_n(field);
let mut new_entries: Vec<(u32, f32)> = Vec::with_capacity(misses.len());
for &i in &misses {
let dim_id = dim_ids[i];
let df = self.compute_sparse_df(field, dim_id);
let idf = if df > 0 && n > 0 {
(n as f32 / df as f32).ln().max(0.0)
} else {
0.0
};
result[i] = idf;
new_entries.push((dim_id, idf));
}
{
let mut cache = self.sparse_idf_cache.write();
let field_cache = cache.entry(field.0).or_default();
for (dim_id, idf) in new_entries {
field_cache.insert(dim_id, idf);
}
}
result
}
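    /// The `N` used for sparse IDF: the field's total vector count across
    /// all segments, computed once and cached, floored at `total_docs`.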
fn cached_sparse_n(&self, field: Field) -> u64 {
{
let cache = self.sparse_total_vectors_cache.read();
if let Some(&tv) = cache.get(&field.0) {
return tv.max(self.total_docs);
}
}
let tv = self.compute_sparse_total_vectors(field);
self.sparse_total_vectors_cache.write().insert(field.0, tv);
tv.max(self.total_docs)
}
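    /// BM25-style IDF for a text term, `ln((N - df + 0.5) / (df + 0.5) + 1)`,
    /// cached per field and term.
    ///
    /// This relies on `compute_text_df`, which is currently a stub returning
    /// zero, so the lazy path reports an IDF of `0.0` for every term.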
pub fn text_idf(&self, field: Field, term: &str) -> f32 {
{
let cache = self.text_idf_cache.read();
if let Some(field_cache) = cache.get(&field.0)
&& let Some(&idf) = field_cache.get(term)
{
return idf;
}
}
let df = self.compute_text_df(field, term);
let n = self.total_docs as f32;
let df_f = df as f32;
let idf = if df > 0 {
((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
} else {
0.0
};
{
let mut cache = self.text_idf_cache.write();
cache
.entry(field.0)
.or_default()
.insert(term.to_string(), idf);
}
idf
}
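    /// Average length of `field` across segments, weighted by each segment's
    /// document count; defaults to `1.0` when no segment reports a length.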
pub fn avg_field_len(&self, field: Field) -> f32 {
{
let cache = self.avg_field_len_cache.read();
if let Some(&avg) = cache.get(&field.0) {
return avg;
}
}
let mut weighted_sum = 0.0f64;
let mut total_weight = 0u64;
for segment in &self.segments {
let avg_len = segment.avg_field_len(field);
let doc_count = segment.num_docs() as u64;
if avg_len > 0.0 && doc_count > 0 {
weighted_sum += avg_len as f64 * doc_count as f64;
total_weight += doc_count;
}
}
let avg = if total_weight > 0 {
(weighted_sum / total_weight as f64) as f32
} else {
1.0
};
{
let mut cache = self.avg_field_len_cache.write();
cache.insert(field.0, avg);
}
avg
}
fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
let mut df = 0u64;
for segment in &self.segments {
if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
df += sparse_index.doc_count(dim_id) as u64;
}
}
df
}
fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
let mut total = 0u64;
for segment in &self.segments {
if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
total += sparse_index.total_vectors as u64;
}
}
total
}
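    /// Placeholder: per-term document frequencies are not yet aggregated
    /// from the segments, so this always returns `0`.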
fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
0
}
pub fn num_segments(&self) -> usize {
self.segments.len()
}
}
impl std::fmt::Debug for LazyGlobalStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LazyGlobalStats")
.field("total_docs", &self.total_docs)
.field("num_segments", &self.segments.len())
.field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
.field("text_cache_fields", &self.text_idf_cache.read().len())
            .finish_non_exhaustive()
}
}
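/// Eagerly built, immutable snapshot of corpus-wide statistics, tagged with
/// the generation of the segment set it was computed from.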
#[derive(Debug)]
pub struct GlobalStats {
total_docs: u64,
sparse_stats: FxHashMap<u32, SparseFieldStats>,
text_stats: FxHashMap<u32, TextFieldStats>,
generation: u64,
}
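/// Per-field document frequencies for sparse vector dimensions.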
#[derive(Debug, Default)]
pub struct SparseFieldStats {
pub doc_freqs: FxHashMap<u32, u64>,
}
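/// Per-field term document frequencies and average field length for text
/// (BM25) scoring.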
#[derive(Debug, Default)]
pub struct TextFieldStats {
pub doc_freqs: FxHashMap<String, u64>,
pub avg_field_len: f32,
}
impl GlobalStats {
pub fn new() -> Self {
Self {
total_docs: 0,
sparse_stats: FxHashMap::default(),
text_stats: FxHashMap::default(),
generation: 0,
}
}
#[inline]
pub fn total_docs(&self) -> u64 {
self.total_docs
}
#[inline]
pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
if let Some(stats) = self.sparse_stats.get(&field.0)
&& let Some(&df) = stats.doc_freqs.get(&dim_id)
&& df > 0
{
            // Clamp as the lazy path does, so a dimension whose summed df
            // exceeds total_docs never contributes a negative weight.
            return (self.total_docs as f32 / df as f32).ln().max(0.0);
}
0.0
}
pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
}
#[inline]
pub fn text_idf(&self, field: Field, term: &str) -> f32 {
if let Some(stats) = self.text_stats.get(&field.0)
&& let Some(&df) = stats.doc_freqs.get(term)
{
let n = self.total_docs as f32;
let df = df as f32;
return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
}
0.0
}
#[inline]
pub fn avg_field_len(&self, field: Field) -> f32 {
self.text_stats
.get(&field.0)
.map(|s| s.avg_field_len)
.unwrap_or(1.0)
}
#[inline]
pub fn generation(&self) -> u64 {
self.generation
}
}
impl Default for GlobalStats {
fn default() -> Self {
Self::new()
}
}
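/// Accumulates statistics across segments and produces an immutable
/// [`GlobalStats`] snapshot via [`Self::build`].
///
/// A minimal usage sketch (the field id, counts, and length below are
/// illustrative, not taken from a real index):
///
/// ```ignore
/// let mut builder = GlobalStatsBuilder::new();
/// builder.total_docs = 1_000;
/// builder.add_text_df(Field(0), "rust".to_string(), 50);
/// builder.set_avg_field_len(Field(0), 12.0);
/// let stats = builder.build(1);
/// assert!(stats.text_idf(Field(0), "rust") > 0.0);
/// ```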
pub struct GlobalStatsBuilder {
pub total_docs: u64,
sparse_stats: FxHashMap<u32, SparseFieldStats>,
text_stats: FxHashMap<u32, TextFieldStats>,
}
impl GlobalStatsBuilder {
pub fn new() -> Self {
Self {
total_docs: 0,
sparse_stats: FxHashMap::default(),
text_stats: FxHashMap::default(),
}
}
pub fn add_segment(&mut self, reader: &SegmentReader) {
self.total_docs += reader.num_docs() as u64;
}
pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
let stats = self.sparse_stats.entry(field.0).or_default();
*stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
}
pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
let stats = self.text_stats.entry(field.0).or_default();
*stats.doc_freqs.entry(term).or_insert(0) += doc_count;
}
pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
let stats = self.text_stats.entry(field.0).or_default();
stats.avg_field_len = avg_len;
}
pub fn build(self, generation: u64) -> GlobalStats {
GlobalStats {
total_docs: self.total_docs,
sparse_stats: self.sparse_stats,
text_stats: self.text_stats,
generation,
}
}
}
impl Default for GlobalStatsBuilder {
fn default() -> Self {
Self::new()
}
}
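/// Shared, invalidatable holder for a [`GlobalStats`] snapshot plus a
/// generation counter that is bumped on every [`Self::invalidate`] (e.g.
/// when the segment set changes), so consumers can detect that a cached
/// snapshot is stale and rebuild it.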
pub struct GlobalStatsCache {
stats: RwLock<Option<Arc<GlobalStats>>>,
generation: RwLock<u64>,
}
impl GlobalStatsCache {
pub fn new() -> Self {
Self {
stats: RwLock::new(None),
generation: RwLock::new(0),
}
}
pub fn invalidate(&self) {
let mut current_gen = self.generation.write();
*current_gen += 1;
let mut stats = self.stats.write();
*stats = None;
}
pub fn generation(&self) -> u64 {
*self.generation.read()
}
pub fn get(&self) -> Option<Arc<GlobalStats>> {
self.stats.read().clone()
}
pub fn set(&self, stats: GlobalStats) {
let mut cached = self.stats.write();
*cached = Some(Arc::new(stats));
}
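    /// Returns the cached snapshot, or runs `compute` against a fresh
    /// builder and caches the result.
    ///
    /// This is not atomic: two threads that miss concurrently will both run
    /// `compute`, and the last write wins. The generation is captured before
    /// `compute` executes.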
pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
where
F: FnOnce(&mut GlobalStatsBuilder),
{
if let Some(stats) = self.get() {
return stats;
}
let current_gen = self.generation();
let mut builder = GlobalStatsBuilder::new();
compute(&mut builder);
let stats = Arc::new(builder.build(current_gen));
let mut cached = self.stats.write();
*cached = Some(Arc::clone(&stats));
stats
}
pub fn needs_rebuild(&self) -> bool {
self.stats.read().is_none()
}
    /// Equivalent to [`Self::set`].
    pub fn set_stats(&self, stats: GlobalStats) {
        self.set(stats);
    }
}
impl Default for GlobalStatsCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sparse_idf_computation() {
let mut builder = GlobalStatsBuilder::new();
builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);
let stats = builder.build(1);
let idf_42 = stats.sparse_idf(Field(0), 42);
let idf_43 = stats.sparse_idf(Field(0), 43);
assert!(idf_43 > idf_42);
assert!((idf_42 - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
assert!((idf_43 - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
}
#[test]
fn test_text_idf_computation() {
let mut builder = GlobalStatsBuilder::new();
builder.total_docs = 10000;
builder.add_text_df(Field(0), "common".to_string(), 5000);
builder.add_text_df(Field(0), "rare".to_string(), 10);
let stats = builder.build(1);
let idf_common = stats.text_idf(Field(0), "common");
let idf_rare = stats.text_idf(Field(0), "rare");
assert!(idf_rare > idf_common);
}
#[test]
fn test_cache_invalidation() {
let cache = GlobalStatsCache::new();
assert!(cache.get().is_none());
let stats = cache.get_or_compute(|builder| {
builder.total_docs = 100;
});
assert_eq!(stats.total_docs(), 100);
assert!(cache.get().is_some());
cache.invalidate();
assert!(cache.get().is_none());
}
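    // Extra regression check, grounded in the code above: `invalidate`
    // bumps the generation counter by one on each call.
    #[test]
    fn test_generation_increments_on_invalidate() {
        let cache = GlobalStatsCache::new();
        assert_eq!(cache.generation(), 0);
        cache.invalidate();
        assert_eq!(cache.generation(), 1);
        cache.invalidate();
        assert_eq!(cache.generation(), 2);
    }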
}