#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::Arc;
#[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
use rayon::prelude::*;
use crate::error::Result;
use crate::hyperdim::HVec10240;
use crate::singularity::{Singularity, unix_now_ns};
/// Cardinality and timing figures describing the most recent retrieval pass of a namespace.
///
/// A fresh value is stored by `Singularity::update_stats` and handed out (cloned) by
/// `Singularity::last_retrieval_stats`. The `Default` value (all zeros / `false` / `None`)
/// is what callers see when no stats are available.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalStats {
/// Number of candidates considered: the candidate list length for candidate-based
/// retrieval, or the number of stored vectors for an exact scan.
pub candidate_count: usize,
/// Number of `(index, distance)` pairs produced by scoring, counted before top-k truncation.
pub scored_count: usize,
/// `true` when the stats were written by `exact_similarity_scan`, `false` for the
/// candidate-based retrieval paths.
pub fell_back_to_exact_scan: bool,
/// Nanoseconds attributed to candidate generation. Candidate-based paths copy the
/// caller-supplied `cand_ns`; the exact scan reports the time between its `start_ns`
/// argument and the start of scoring.
pub candidate_ns: u64,
/// Nanoseconds spent computing Hamming distances.
pub scoring_ns: u64,
/// `1.0` for an exact scan, `0.0` for `scored_candidate_retrieval`, otherwise the value
/// passed to `scored_candidate_retrieval_with_stats`.
pub selectivity_ratio: f32,
/// Strategy reported by the caller of `scored_candidate_retrieval_with_stats`;
/// `None` on every other path in this file.
pub filter_strategy: Option<FilterStrategy>,
}
/// Tag naming the mechanism that produced a candidate set.
///
/// Carried in `ScoredCandidateParams::source`; the retrieval functions in this file
/// destructure it but do not read it.
// NOTE(review): the per-variant descriptions below are inferred from the variant names and the
// generator functions in this file — confirm against the code that constructs these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CandidateSource {
/// Presumably candidates selected through metadata filtering (not visible in this file).
Metadata,
/// Presumably candidates produced by `generate_graph_candidates`.
Graph,
/// Presumably candidates produced by `generate_bucket_candidates`.
Bucket,
/// Presumably the full scan performed by `exact_similarity_scan`.
ExactFallback,
}
/// Label recorded in `RetrievalStats::filter_strategy` describing how filtering was combined
/// with retrieval.
// NOTE(review): the code that chooses a strategy is not in this file; the variant meanings
// below are guesses based on the names only — confirm with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FilterStrategy {
/// Presumably: filter applied before scoring.
Pre,
/// Presumably: bucket candidates scored first, filter applied afterwards.
BucketPost,
/// Presumably: full scan scored first, filter applied afterwards.
ScanPost,
}
/// Argument bundle for `Singularity::scored_candidate_retrieval` and
/// `Singularity::scored_candidate_retrieval_with_stats`.
pub(crate) struct ScoredCandidateParams<'a> {
/// Query vector every candidate is compared against (Hamming distance).
pub query: &'a HVec10240,
/// Maximum number of results to return.
pub top_k: usize,
/// Positions into the namespace's `concept_vectors` / `concept_indices` to score.
pub candidates: Vec<usize>,
/// Start timestamp supplied by the caller; the retrieval functions in this file ignore it.
pub start_ns: u64,
/// Nanoseconds the caller spent generating candidates; copied into
/// `RetrievalStats::candidate_ns`.
pub cand_ns: u64,
/// Origin of the candidate set; the retrieval functions in this file ignore it.
pub source: CandidateSource,
/// When `true`, the computed result is not written to the namespace's query cache.
pub bypass_cache: bool,
}
/// Tunables for candidate generation, stored on `Singularity` via `set_retrieval_config`.
///
/// Only `graph_depth`, `graph_fanout` and `bucket_probe_width` are read in this file; the
/// remaining fields are consumed elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalConfig {
// NOTE(review): presumably an upper bound on the candidate set size — not read in this file.
pub max_candidates: usize,
// NOTE(review): presumably the candidate/total ratio that triggers an exact scan — not read
// in this file, confirm against the caller.
pub candidate_ratio_fallback: f32,
/// Maximum hop count from the seed concept during graph expansion; nodes at this depth
/// are not expanded further.
pub graph_depth: u8,
/// Maximum number of highest-weight neighbours followed per expanded node.
pub graph_fanout: usize,
/// Number of low bits of the first vector word used as the bucket key; the bucket
/// generator debug-asserts that it is at most 127.
pub bucket_probe_width: usize,
// NOTE(review): presumably toggles use of `generate_graph_candidates` — not read in this file.
pub enable_graph_candidates: bool,
// NOTE(review): presumably toggles use of `generate_bucket_candidates` — not read in this file.
pub enable_bucket_candidates: bool,
}
impl RetrievalConfig {
pub fn validate(&self) -> Result<()> {
crate::framework::ChaoticSemanticFramework::validate_retrieval_config(self)
}
}
impl Default for RetrievalConfig {
    /// Baseline settings: both optional candidate generators disabled, a budget of 1000
    /// candidates, a 0.5 fallback ratio, graph expansion of depth 2 / fanout 10 and a
    /// 2-bit bucket probe.
    fn default() -> Self {
        Self {
            // Optional candidate generators are opt-in.
            enable_graph_candidates: false,
            enable_bucket_candidates: false,
            // Candidate budget and fallback ratio.
            max_candidates: 1000,
            candidate_ratio_fallback: 0.5,
            // Graph expansion limits.
            graph_depth: 2,
            graph_fanout: 10,
            // Bucket probing.
            bucket_probe_width: 2,
        }
    }
}
impl Singularity {
/// Replaces the stored retrieval configuration after validating the new one.
///
/// # Errors
///
/// Returns the validation error unchanged; in that case the stored configuration is
/// left untouched.
pub fn set_retrieval_config(&mut self, config: RetrievalConfig) -> Result<()> {
    let verdict = config.validate();
    if verdict.is_ok() {
        self._retrieval_config = config;
    }
    verdict
}
pub fn last_retrieval_stats(&self, ns: &str) -> RetrievalStats {
self.get_namespace(ns)
.and_then(|n| n.last_retrieval_stats.read().ok())
.map(|s| s.clone())
.unwrap_or_default()
}
/// Collects candidate vector positions by walking the association graph breadth-first,
/// starting from the single best exact match for `query`.
///
/// Each expanded node contributes at most `graph_fanout` neighbours, taken in descending
/// link-weight order, and nodes at `graph_depth` hops are not expanded further. Visited ids
/// without an entry in `id_to_index` are dropped from the result. The returned order is
/// unspecified (it comes from a `HashSet`).
///
/// Returns an empty vector when the namespace is missing or the seed lookup finds nothing.
/// The seed lookup goes through `exact_similarity_scan` (cache bypassed), which also
/// rewrites the namespace's last retrieval stats.
pub(crate) fn generate_graph_candidates(&self, ns: &str, query: &HVec10240) -> Vec<usize> {
    let Some(ns_state) = self.get_namespace(ns) else {
        return Vec::new();
    };
    let max_depth = self._retrieval_config.graph_depth;
    let fanout = self._retrieval_config.graph_fanout;

    let seed_hits = self.exact_similarity_scan(ns, query, 1, unix_now_ns(), true);
    let mut visited = std::collections::HashSet::new();
    let mut frontier = VecDeque::new();
    if let Some((seed_id, _)) = seed_hits.first() {
        visited.insert(seed_id.clone());
        frontier.push_back((seed_id.clone(), 0u8));
    }

    while let Some((current, depth)) = frontier.pop_front() {
        if depth >= max_depth {
            continue;
        }
        let Some(links) = ns_state.associations.get(&current) else {
            continue;
        };
        // Strongest associations first, so the fanout cut keeps the heaviest links.
        let mut ranked: Vec<_> = links.iter().collect();
        ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(lhs.1));
        for (neighbor_id, _) in ranked.into_iter().take(fanout) {
            if visited.contains(neighbor_id) {
                continue;
            }
            visited.insert(neighbor_id.clone());
            frontier.push_back((neighbor_id.clone(), depth + 1));
        }
    }

    visited
        .into_iter()
        .filter_map(|id| ns_state.id_to_index.get(&id).copied())
        .collect()
}
/// Returns the positions of all stored vectors whose low `bucket_probe_width` bits of the
/// first data word equal those of `query`.
///
/// Returns an empty vector when namespace `ns` does not exist. A width of zero yields an
/// all-zero mask, so every stored vector is returned. Widths of 128 or more compare the
/// whole first word.
pub(crate) fn generate_bucket_candidates(&self, ns: &str, query: &HVec10240) -> Vec<usize> {
    /// Bit width of one data word (`u128`).
    const WORD_BITS: usize = 128;

    let Some(ns_state) = self.get_namespace(ns) else {
        return Vec::new();
    };
    let probe_width = self._retrieval_config.bucket_probe_width;
    debug_assert!(probe_width <= 127);
    // `1u128 << w` overflows for `w >= 128`: release builds wrap the shift amount, which can
    // silently produce an empty mask that matches every vector. Saturate to a full-word
    // mask instead so an oversized width narrows the match rather than widening it.
    let bucket_mask = if probe_width >= WORD_BITS {
        u128::MAX
    } else {
        (1u128 << probe_width) - 1
    };
    let query_bucket = query.data[0] & bucket_mask;
    // Keeps the position of every vector that falls into the query's bucket.
    let filter = |(idx, vec): (usize, &HVec10240)| {
        if (vec.data[0] & bucket_mask) == query_bucket {
            Some(idx)
        } else {
            None
        }
    };
    #[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
    {
        ns_state
            .concept_vectors
            .par_iter()
            .enumerate()
            .filter_map(filter)
            .collect()
    }
    #[cfg(any(target_arch = "wasm32", not(feature = "parallel")))]
    {
        ns_state
            .concept_vectors
            .iter()
            .enumerate()
            .filter_map(filter)
            .collect()
    }
}
/// Scores every stored vector of namespace `ns` against `query` and returns the `top_k`
/// closest entries as `(concept id, similarity)` pairs, nearest first.
///
/// Similarity is `1.0 - hamming_distance / 5120.0`. Unless `bypass_cache` is set, the
/// result is stored in the namespace's query cache; when `put` returns `true` the
/// `evictions_total` counter is incremented. The namespace's last retrieval stats are
/// overwritten with `fell_back_to_exact_scan = true`, a selectivity of `1.0`, and
/// `candidate_ns` set to the time between `start_ns` and the start of scoring.
///
/// Returns an empty slice when the namespace does not exist or when `top_k` is zero.
pub(crate) fn exact_similarity_scan(
    &self,
    ns: &str,
    query: &HVec10240,
    top_k: usize,
    start_ns: u64,
    bypass_cache: bool,
) -> Arc<[(String, f32)]> {
    // NOTE(review): 5120 looks like half of the 10240 bits implied by `HVec10240`, which
    // would map distances onto [-1, 1] — confirm.
    const HALF_DIM: f32 = 5120.0;

    let Some(ns_state) = self.get_namespace(ns) else {
        return Arc::from(Vec::new());
    };
    let scoring_start = unix_now_ns();
    #[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
    let mut scores: Vec<(usize, u32)> = ns_state
        .concept_vectors
        .par_iter()
        .enumerate()
        .with_min_len(128)
        .map(|(idx, v)| (idx, query.hamming_distance(v)))
        .collect();
    #[cfg(any(target_arch = "wasm32", not(feature = "parallel")))]
    let mut scores: Vec<(usize, u32)> = ns_state
        .concept_vectors
        .iter()
        .enumerate()
        .map(|(idx, v)| (idx, query.hamming_distance(v)))
        .collect();
    let scoring_ns = unix_now_ns().saturating_sub(scoring_start);
    let scored_count = scores.len();

    if top_k < scored_count {
        // Partition so the `top_k` smallest distances sit in front, then drop the rest.
        // `top_k - 1` would underflow for `top_k == 0`; in that case truncating alone
        // yields the (empty) answer.
        if top_k > 0 {
            scores.select_nth_unstable_by(top_k - 1, |a, b| a.1.cmp(&b.1));
        }
        scores.truncate(top_k);
    }
    scores.sort_unstable_by_key(|&(_, dist)| dist);

    let results: Vec<(String, f32)> = scores
        .into_iter()
        .map(|(idx, dist)| {
            let similarity = 1.0 - (dist as f32 / HALF_DIM);
            (ns_state.concept_indices[idx].clone(), similarity)
        })
        .collect();
    let results_arc: Arc<[(String, f32)]> = Arc::from(results);

    if !bypass_cache {
        // Caching is best-effort: a poisoned cache lock just skips the write.
        if let Ok(mut cache) = ns_state.query_cache.write() {
            let cache_key = crate::singularity::similarity_cache_key(query, top_k);
            if cache.put(cache_key, Arc::clone(&results_arc)) {
                ns_state
                    .cache_metrics
                    .evictions_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }
    }

    self.update_stats(
        ns,
        scored_count,
        scored_count,
        true,
        scoring_start.saturating_sub(start_ns),
        scoring_ns,
        1.0,
        None,
    );
    results_arc
}
pub(crate) fn scored_candidate_retrieval(
&self,
ns: &str,
params: ScoredCandidateParams,
) -> Arc<[(String, f32)]> {
let Some(ns_state) = self.get_namespace(ns) else {
return Arc::from(Vec::new());
};
let ScoredCandidateParams {
query,
top_k,
candidates,
start_ns: _start_ns,
cand_ns,
source: _source,
bypass_cache,
} = params;
let scoring_start = unix_now_ns();
let candidate_count = candidates.len();
#[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
let mut scores: Vec<(usize, u32)> = candidates
.into_par_iter()
.map(|idx| (idx, query.hamming_distance(&ns_state.concept_vectors[idx])))
.collect();
#[cfg(any(target_arch = "wasm32", not(feature = "parallel")))]
let mut scores: Vec<(usize, u32)> = candidates
.into_iter()
.map(|idx| (idx, query.hamming_distance(&ns_state.concept_vectors[idx])))
.collect();
let scoring_ns = unix_now_ns().saturating_sub(scoring_start);
let scored_count = scores.len();
if scores.len() <= top_k {
scores.sort_unstable_by_key(|&(_, dist)| dist);
} else {
scores.select_nth_unstable_by(top_k - 1, |a, b| a.1.cmp(&b.1));
scores.truncate(top_k);
scores.sort_unstable_by_key(|&(_, dist)| dist);
}
let results: Vec<(String, f32)> = scores
.into_iter()
.map(|(idx, dist)| {
let similarity = 1.0 - (dist as f32 / 5120.0);
(ns_state.concept_indices[idx].clone(), similarity)
})
.collect();
let results_arc = Arc::from(results);
if !bypass_cache {
if let Ok(mut cache) = ns_state.query_cache.write() {
let cache_key = crate::singularity::similarity_cache_key(query, top_k);
if cache.put(cache_key, Arc::clone(&results_arc)) {
ns_state
.cache_metrics
.evictions_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
}
self.update_stats(
ns,
candidate_count,
scored_count,
false,
cand_ns,
scoring_ns,
0.0,
None,
);
results_arc
}
/// Overwrites the last-retrieval stats of namespace `ns` with the given figures.
///
/// Best-effort: does nothing when the namespace is missing or the stats lock cannot be
/// taken for writing.
#[allow(clippy::too_many_arguments)]
fn update_stats(
    &self,
    ns: &str,
    candidates: usize,
    scored: usize,
    fallback: bool,
    cand_ns: u64,
    score_ns: u64,
    selectivity: f32,
    strategy: Option<FilterStrategy>,
) {
    let Some(ns_state) = self.get_namespace(ns) else {
        return;
    };
    let Ok(mut slot) = ns_state.last_retrieval_stats.write() else {
        return;
    };
    *slot = RetrievalStats {
        candidate_count: candidates,
        scored_count: scored,
        fell_back_to_exact_scan: fallback,
        candidate_ns: cand_ns,
        scoring_ns: score_ns,
        selectivity_ratio: selectivity,
        filter_strategy: strategy,
    };
}
/// Scores the candidate positions in `params` against the query, returns the `top_k`
/// closest entries as `(concept id, similarity)` pairs (nearest first), and records
/// `selectivity` and `strategy` in the namespace's last retrieval stats.
///
/// Similarity is `1.0 - hamming_distance / 5120.0`. Unless `params.bypass_cache` is set,
/// the result is stored in the namespace's query cache; when `put` returns `true` the
/// `evictions_total` counter is incremented. `params.start_ns` and `params.source` are
/// not used.
///
/// Returns an empty slice when the namespace does not exist or when `top_k` is zero.
///
/// # Panics
///
/// Candidate positions are used to index `concept_vectors` and `concept_indices`
/// directly, so an out-of-range position is not handled here.
pub(crate) fn scored_candidate_retrieval_with_stats(
    &self,
    ns: &str,
    params: ScoredCandidateParams,
    selectivity: f32,
    strategy: Option<FilterStrategy>,
) -> Arc<[(String, f32)]> {
    // NOTE(review): 5120 looks like half of the 10240 bits implied by `HVec10240`, which
    // would map distances onto [-1, 1] — confirm.
    const HALF_DIM: f32 = 5120.0;

    let Some(ns_state) = self.get_namespace(ns) else {
        return Arc::from(Vec::new());
    };
    let ScoredCandidateParams {
        query,
        top_k,
        candidates,
        cand_ns,
        bypass_cache,
        ..
    } = params;
    let scoring_start = unix_now_ns();
    let candidate_count = candidates.len();
    #[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
    let mut scores: Vec<(usize, u32)> = candidates
        .into_par_iter()
        .map(|idx| (idx, query.hamming_distance(&ns_state.concept_vectors[idx])))
        .collect();
    #[cfg(any(target_arch = "wasm32", not(feature = "parallel")))]
    let mut scores: Vec<(usize, u32)> = candidates
        .into_iter()
        .map(|idx| (idx, query.hamming_distance(&ns_state.concept_vectors[idx])))
        .collect();
    let scoring_ns = unix_now_ns().saturating_sub(scoring_start);
    let scored_count = scores.len();

    if top_k < scored_count {
        // Partition so the `top_k` smallest distances sit in front, then drop the rest.
        // `top_k - 1` would underflow for `top_k == 0`; in that case truncating alone
        // yields the (empty) answer.
        if top_k > 0 {
            scores.select_nth_unstable_by(top_k - 1, |a, b| a.1.cmp(&b.1));
        }
        scores.truncate(top_k);
    }
    scores.sort_unstable_by_key(|&(_, dist)| dist);

    let results: Vec<(String, f32)> = scores
        .into_iter()
        .map(|(idx, dist)| {
            let similarity = 1.0 - (dist as f32 / HALF_DIM);
            (ns_state.concept_indices[idx].clone(), similarity)
        })
        .collect();
    let results_arc: Arc<[(String, f32)]> = Arc::from(results);

    if !bypass_cache {
        // Caching is best-effort: a poisoned cache lock just skips the write.
        if let Ok(mut cache) = ns_state.query_cache.write() {
            let cache_key = crate::singularity::similarity_cache_key(query, top_k);
            if cache.put(cache_key, Arc::clone(&results_arc)) {
                ns_state
                    .cache_metrics
                    .evictions_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }
    }

    self.update_stats(
        ns,
        candidate_count,
        scored_count,
        false,
        cand_ns,
        scoring_ns,
        selectivity,
        strategy,
    );
    results_arc
}
}
#[cfg(test)]
mod tests_v2 {
    use crate::singularity::{Singularity, SingularityConfig};

    /// A freshly constructed instance reports a candidate count of zero for `_default`.
    #[test]
    fn singularity_last_stats_v2() {
        let engine = Singularity::new(SingularityConfig::default());
        let stats = engine.last_retrieval_stats("_default");
        assert_eq!(stats.candidate_count, 0);
    }

    /// A freshly constructed instance exposes a retrieval config with `max_candidates == 1000`.
    #[test]
    fn singularity_get_config_v2() {
        let engine = Singularity::new(SingularityConfig::default());
        let max_candidates = engine.retrieval_config().max_candidates;
        assert_eq!(max_candidates, 1000);
    }
}