use std::collections::BinaryHeap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use crate::cost_model::CostTracker;
/// Tuning knobs for [`RerankExecutor`]: I/O and latency budgets, read coalescing and caching.
#[derive(Debug, Clone)]
pub struct RerankConfig {
    /// Maximum number of successful storage reads per rerank call.
    pub max_io_ops: u32,
    /// Maximum total bytes read from storage per rerank call.
    pub max_io_bytes: u64,
    /// Wall-clock budget for a rerank call; checked before each storage read.
    pub max_latency: Duration,
    /// Largest gap, in bytes, across which neighbouring reads are merged into one read.
    pub coalesce_threshold: u64,
    /// NOTE(review): has no effect on `RerankExecutor` in the code visible here.
    pub min_rerank_candidates: usize,
    /// Whether `RerankExecutor` keeps a `VectorCache` of decoded vectors.
    pub enable_cache: bool,
    /// Capacity, in entries, of that cache.
    pub cache_size: usize,
    /// NOTE(review): not read by the code visible here; confirm whether used elsewhere.
    pub io_queue_depth: u32,
    /// NOTE(review): not read by the code visible here; confirm whether used elsewhere.
    pub prefetch_distance: usize,
}

impl Default for RerankConfig {
    /// 100 reads / 16 MiB / 50 ms budgets, 4 KiB coalescing gap, 10 000-entry cache.
    fn default() -> Self {
        RerankConfig {
            max_io_ops: 100,
            max_io_bytes: 16 << 20,
            max_latency: Duration::from_millis(50),
            coalesce_threshold: 4096,
            min_rerank_candidates: 10,
            enable_cache: true,
            cache_size: 10_000,
            io_queue_depth: 64,
            prefetch_distance: 4,
        }
    }
}

impl RerankConfig {
    /// Returns the config with `max_io_ops` replaced by `max_ops`.
    pub fn io_budget(self, max_ops: u32) -> Self {
        RerankConfig {
            max_io_ops: max_ops,
            ..self
        }
    }

    /// Returns the config with `coalesce_threshold` replaced by `bytes`.
    pub fn coalesce_threshold(self, bytes: u64) -> Self {
        RerankConfig {
            coalesce_threshold: bytes,
            ..self
        }
    }

    /// Returns the config with `max_latency` replaced by `latency`.
    pub fn max_latency(self, latency: Duration) -> Self {
        RerankConfig {
            max_latency: latency,
            ..self
        }
    }
}
/// A contiguous byte range to read from storage, together with the candidates whose
/// vectors lie inside it.
#[derive(Debug, Clone)]
pub struct IoRange {
    /// Byte offset of the first byte in the range.
    pub offset: u64,
    /// Number of bytes in the range.
    pub length: u64,
    /// Indices of the candidates covered by this range (`IoCoalescer::coalesce` uses
    /// positions in the candidate slice it was given).
    pub candidate_indices: Vec<usize>,
}

impl IoRange {
    /// Creates a range covering exactly one candidate.
    pub fn single(offset: u64, length: u64, candidate_idx: usize) -> Self {
        Self {
            offset,
            length,
            candidate_indices: vec![candidate_idx],
        }
    }

    /// Extends `self` to cover `other` when the two overlap or are separated by at most
    /// `threshold` bytes. On success, `other`'s candidate indices are appended and `true`
    /// is returned; otherwise `self` is unchanged and `false` is returned.
    ///
    /// When merging, any gap between the two ranges is included in the merged range.
    pub fn try_merge(&mut self, other: &IoRange, threshold: u64) -> bool {
        let self_end = self.end();
        let other_end = other.end();
        // Saturating adds: a huge threshold (e.g. `u64::MAX`, "always merge") must not
        // overflow and panic.
        let within_reach = other.offset <= self_end.saturating_add(threshold)
            && self.offset <= other_end.saturating_add(threshold);
        if !within_reach {
            return false;
        }
        let new_start = self.offset.min(other.offset);
        let new_end = self_end.max(other_end);
        self.offset = new_start;
        self.length = new_end - new_start;
        self.candidate_indices
            .extend_from_slice(&other.candidate_indices);
        true
    }

    /// One past the last byte of the range, saturating at `u64::MAX`.
    pub fn end(&self) -> u64 {
        self.offset.saturating_add(self.length)
    }
}
/// A candidate awaiting exact re-scoring: its id, its proxy score, and where its vector
/// bytes live in storage.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// Identifier reported back in the rerank result.
    pub id: u32,
    /// Score carried over from the earlier (approximate) stage.
    pub proxy_score: f32,
    /// Byte offset of the stored vector.
    pub disk_offset: u64,
    /// Length of the stored vector in bytes.
    pub vector_size: u32,
}

impl RerankCandidate {
    /// Bundles the four fields into a candidate.
    pub fn new(id: u32, proxy_score: f32, disk_offset: u64, vector_size: u32) -> Self {
        RerankCandidate { id, proxy_score, disk_offset, vector_size }
    }
}
/// An exactly re-scored candidate.
///
/// Equality compares only `id`. Ordering compares only `true_score`, reversed, so that a
/// `std::collections::BinaryHeap` pops the *smallest* score first. A NaN score yields
/// `None` from `partial_cmp` and `Equal` from `cmp`.
///
/// NOTE(review): `Eq` and `Ord` are therefore inconsistent (two results can be equal yet
/// order differently). The code visible here relies only on the heap ordering.
#[derive(Debug, Clone)]
pub struct RerankResult {
    pub id: u32,
    pub true_score: f32,
    pub from_cache: bool,
}

impl PartialEq for RerankResult {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.id, other.id);
        lhs == rhs
    }
}

impl Eq for RerankResult {}

impl PartialOrd for RerankResult {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Reverse of the natural float order: lower scores rank as "greater".
        self.true_score
            .partial_cmp(&other.true_score)
            .map(std::cmp::Ordering::reverse)
    }
}

impl Ord for RerankResult {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.partial_cmp(other) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    }
}
/// Counters and outcome of one rerank call, as filled in by `RerankExecutor`.
#[derive(Debug, Clone, Default)]
pub struct RerankStats {
    /// Number of candidates passed in.
    pub candidates_requested: usize,
    /// Candidates that received an exact score (cache hits plus vectors read from storage).
    pub candidates_reranked: usize,
    /// Successful storage reads.
    pub io_ops: u32,
    /// Bytes requested by successful storage reads, including coalescing gaps.
    pub io_bytes: u64,
    /// Number of ranges produced by coalescing the uncached candidates.
    pub coalesced_ranges: usize,
    /// Candidates scored from cached vectors.
    pub cache_hits: usize,
    /// Candidates scored from vectors read from storage.
    pub cache_misses: usize,
    /// Whether a budget stopped the read loop early.
    pub budget_exhausted: bool,
    /// `"complete"`, or the budget that stopped the loop (e.g. `"io_ops_exceeded"`).
    pub stop_reason: String,
    /// Wall-clock duration of the call.
    pub duration: Duration,
}

impl RerankStats {
    /// Vector dimensionality assumed by [`RerankStats::io_amplification`].
    const ASSUMED_DIM: usize = 768;

    /// Bytes read per reranked candidate, relative to one 768-dimensional `f32` vector.
    ///
    /// Same as `io_amplification_for_dim(768)`. Returns 0.0 when nothing was reranked.
    pub fn io_amplification(&self) -> f32 {
        self.io_amplification_for_dim(Self::ASSUMED_DIM)
    }

    /// Bytes read per reranked candidate, relative to one `dim`-dimensional `f32` vector
    /// (`dim * 4` bytes).
    ///
    /// Returns 0.0 when nothing was reranked or `dim` is 0, instead of dividing by zero.
    /// Cache hits count as reranked but read no bytes, so a warm cache lowers the ratio.
    pub fn io_amplification_for_dim(&self, dim: usize) -> f32 {
        if self.candidates_reranked == 0 || dim == 0 {
            return 0.0;
        }
        self.io_bytes as f32 / (self.candidates_reranked as f32 * 4.0 * dim as f32)
    }

    /// Fraction of cache lookups that hit; 0.0 when there were none.
    pub fn cache_hit_ratio(&self) -> f32 {
        let total = self.cache_hits + self.cache_misses;
        if total == 0 {
            0.0
        } else {
            self.cache_hits as f32 / total as f32
        }
    }
}
/// Groups candidate reads into contiguous [`IoRange`]s so that vectors lying within
/// `threshold` bytes of each other are fetched with a single read.
pub struct IoCoalescer {
    threshold: u64,
}

impl IoCoalescer {
    /// Creates a coalescer that merges reads separated by at most `threshold` bytes.
    pub fn new(threshold: u64) -> Self {
        IoCoalescer { threshold }
    }

    /// Returns merged ranges in ascending disk-offset order.
    ///
    /// Each range's `candidate_indices` are positions in `candidates`. The sort is
    /// stable, so candidates with equal offsets keep their input order.
    pub fn coalesce(&self, candidates: &[RerankCandidate]) -> Vec<IoRange> {
        let mut order: Vec<usize> = (0..candidates.len()).collect();
        order.sort_by_key(|&pos| candidates[pos].disk_offset);

        let mut merged: Vec<IoRange> = Vec::with_capacity(candidates.len());
        let mut open: Option<IoRange> = None;
        for pos in order {
            let cand = &candidates[pos];
            let next = IoRange::single(cand.disk_offset, u64::from(cand.vector_size), pos);
            open = match open {
                None => Some(next),
                Some(mut current) => {
                    if current.try_merge(&next, self.threshold) {
                        Some(current)
                    } else {
                        merged.push(current);
                        Some(next)
                    }
                }
            };
        }
        merged.extend(open);
        merged
    }

    /// Summarises how coalescing changes the I/O for `candidates`.
    ///
    /// `reduction_ratio` is coalesced bytes divided by raw bytes (raw clamped to at
    /// least 1). A value above 1.0 means gap bytes are read. It can fall below 1.0 when
    /// candidates overlap.
    pub fn coalesce_stats(&self, candidates: &[RerankCandidate]) -> CoalesceStats {
        let ranges = self.coalesce(candidates);
        let raw_bytes = candidates
            .iter()
            .fold(0u64, |acc, cand| acc + u64::from(cand.vector_size));
        let coalesced_bytes = ranges.iter().fold(0u64, |acc, range| acc + range.length);
        let reduction_ratio = coalesced_bytes as f32 / raw_bytes.max(1) as f32;
        CoalesceStats {
            n_candidates: candidates.len(),
            n_ranges: ranges.len(),
            raw_bytes,
            coalesced_bytes,
            reduction_ratio,
        }
    }
}

/// Output of [`IoCoalescer::coalesce_stats`].
#[derive(Debug, Clone)]
pub struct CoalesceStats {
    /// Number of candidates examined.
    pub n_candidates: usize,
    /// Number of ranges after merging.
    pub n_ranges: usize,
    /// Sum of the candidates' vector sizes.
    pub raw_bytes: u64,
    /// Sum of the merged range lengths, gaps included.
    pub coalesced_bytes: u64,
    /// `coalesced_bytes / max(raw_bytes, 1)`.
    pub reduction_ratio: f32,
}
/// Bounded in-memory cache of decoded vectors keyed by candidate id, with
/// least-recently-used eviction.
///
/// Each entry stores its vector and the value of a shared access counter taken at its
/// last `insert` or `get`. Eviction removes the entry with the smallest stamp, using a
/// linear scan over the entries.
pub struct VectorCache {
    cache: parking_lot::RwLock<std::collections::HashMap<u32, (Vec<f32>, u64)>>,
    max_size: usize,
    access_counter: AtomicU64,
}

impl VectorCache {
    /// Creates an empty cache holding at most `max_size` vectors.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: parking_lot::RwLock::new(std::collections::HashMap::with_capacity(max_size)),
            max_size,
            access_counter: AtomicU64::new(0),
        }
    }

    /// Returns a copy of the vector for `id`, refreshing its recency, or `None` on a miss.
    pub fn get(&self, id: u32) -> Option<Vec<f32>> {
        // Write lock because a hit updates the entry's recency stamp.
        let mut cache = self.cache.write();
        let (vector, last_access) = cache.get_mut(&id)?;
        *last_access = self.access_counter.fetch_add(1, Ordering::Relaxed);
        Some(vector.clone())
    }

    /// Inserts or replaces the vector for `id`.
    ///
    /// The least-recently-used entry is evicted only when adding a *new* key to a full
    /// cache. Replacing an existing key does not grow the map, so nothing is evicted.
    /// A cache with `max_size == 0` stores nothing.
    pub fn insert(&self, id: u32, vector: Vec<f32>) {
        if self.max_size == 0 {
            return;
        }
        let mut cache = self.cache.write();
        if cache.len() >= self.max_size && !cache.contains_key(&id) {
            let lru_id = cache
                .iter()
                .min_by_key(|(_, (_, access))| *access)
                .map(|(lru, _)| *lru);
            if let Some(lru_id) = lru_id {
                cache.remove(&lru_id);
            }
        }
        let access = self.access_counter.fetch_add(1, Ordering::Relaxed);
        cache.insert(id, (vector, access));
    }

    /// Whether `id` is cached. Does not refresh recency.
    pub fn contains(&self, id: u32) -> bool {
        self.cache.read().contains_key(&id)
    }

    /// Number of cached vectors.
    pub fn len(&self) -> usize {
        self.cache.read().len()
    }

    /// Whether the cache holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.cache.read().is_empty()
    }

    /// Removes every cached vector.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
pub type DistanceFn = dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync;
pub type StorageReader = dyn Fn(u64, u64) -> std::io::Result<Vec<u8>> + Send + Sync;
pub struct RerankExecutor {
config: RerankConfig,
coalescer: IoCoalescer,
cache: Option<VectorCache>,
distance_fn: Box<DistanceFn>,
reader: Box<StorageReader>,
dim: usize,
}
impl RerankExecutor {
pub fn new<D, R>(config: RerankConfig, distance_fn: D, reader: R, dim: usize) -> Self
where
D: Fn(&[f32], &[f32]) -> f32 + Send + Sync + 'static,
R: Fn(u64, u64) -> std::io::Result<Vec<u8>> + Send + Sync + 'static,
{
let cache = if config.enable_cache {
Some(VectorCache::new(config.cache_size))
} else {
None
};
Self {
coalescer: IoCoalescer::new(config.coalesce_threshold),
cache,
config,
distance_fn: Box::new(distance_fn),
reader: Box::new(reader),
dim,
}
}
pub fn rerank(
&self,
candidates: &[RerankCandidate],
query: &[f32],
k: usize,
) -> (Vec<RerankResult>, RerankStats) {
self.rerank_with_tracker(candidates, query, k, None)
}
pub fn rerank_with_tracker(
&self,
candidates: &[RerankCandidate],
query: &[f32],
k: usize,
cost_tracker: Option<&CostTracker>,
) -> (Vec<RerankResult>, RerankStats) {
let start = Instant::now();
let mut stats = RerankStats {
candidates_requested: candidates.len(),
..Default::default()
};
let (cached_ids, uncached): (Vec<_>, Vec<_>) =
candidates.iter().enumerate().partition(|(_, c)| {
self.cache
.as_ref()
.map(|cache| cache.contains(c.id))
.unwrap_or(false)
});
let mut results: BinaryHeap<RerankResult> = BinaryHeap::new();
for (_idx, candidate) in cached_ids {
if let Some(ref cache) = self.cache {
if let Some(vector) = cache.get(candidate.id) {
let score = (self.distance_fn)(query, &vector);
results.push(RerankResult {
id: candidate.id,
true_score: score,
from_cache: true,
});
stats.cache_hits += 1;
stats.candidates_reranked += 1;
}
}
}
let uncached_candidates: Vec<_> = uncached.iter().map(|(_, c)| (*c).clone()).collect();
let ranges = self.coalescer.coalesce(&uncached_candidates);
stats.coalesced_ranges = ranges.len();
let mut io_ops = 0u32;
let mut io_bytes = 0u64;
for range in &ranges {
if io_ops >= self.config.max_io_ops {
stats.budget_exhausted = true;
stats.stop_reason = "io_ops_exceeded".to_string();
break;
}
if io_bytes + range.length > self.config.max_io_bytes {
stats.budget_exhausted = true;
stats.stop_reason = "io_bytes_exceeded".to_string();
break;
}
if start.elapsed() > self.config.max_latency {
stats.budget_exhausted = true;
stats.stop_reason = "latency_exceeded".to_string();
break;
}
if let Some(tracker) = cost_tracker {
if !tracker.add_ssd_sequential_bytes(range.length) {
stats.budget_exhausted = true;
stats.stop_reason = "cost_budget_exhausted".to_string();
break;
}
}
let data = match (self.reader)(range.offset, range.length) {
Ok(data) => data,
Err(_) => continue, };
io_ops += 1;
io_bytes += range.length;
for &candidate_idx in &range.candidate_indices {
let candidate = &uncached_candidates[candidate_idx];
let offset_in_range = candidate.disk_offset - range.offset;
let start_byte = offset_in_range as usize;
let end_byte = start_byte + candidate.vector_size as usize;
if end_byte > data.len() {
continue; }
let vector_bytes = &data[start_byte..end_byte];
let vector: Vec<f32> = vector_bytes
.chunks(4)
.map(|chunk| {
let arr: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
f32::from_le_bytes(arr)
})
.collect();
let score = (self.distance_fn)(query, &vector);
results.push(RerankResult {
id: candidate.id,
true_score: score,
from_cache: false,
});
if let Some(ref cache) = self.cache {
cache.insert(candidate.id, vector);
}
stats.cache_misses += 1;
stats.candidates_reranked += 1;
if results.len() >= k * 2
&& stats.candidates_reranked >= self.config.min_rerank_candidates
{
}
}
}
stats.io_ops = io_ops;
stats.io_bytes = io_bytes;
stats.duration = start.elapsed();
if stats.stop_reason.is_empty() {
stats.stop_reason = "complete".to_string();
}
let mut top_k: Vec<RerankResult> = Vec::with_capacity(k);
while top_k.len() < k && !results.is_empty() {
if let Some(result) = results.pop() {
top_k.push(result);
}
}
top_k.sort_by(|a, b| b.true_score.partial_cmp(&a.true_score).unwrap());
(top_k, stats)
}
pub fn config(&self) -> &RerankConfig {
&self.config
}
pub fn cache_stats(&self) -> Option<usize> {
self.cache.as_ref().map(|c| c.len())
}
}
/// In-memory stand-in for vector storage: `n_vectors` vectors of `dim` little-endian
/// `f32` components, stored back to back.
pub struct MockStorage {
    data: Vec<u8>,
}

impl MockStorage {
    /// Fills the storage with deterministic values: component `j` of vector `i` is
    /// `(i + j) / (n_vectors + dim)`.
    pub fn new(n_vectors: usize, dim: usize) -> Self {
        let mut data = Vec::with_capacity(n_vectors * dim * 4);
        for i in 0..n_vectors {
            for j in 0..dim {
                let val = (i + j) as f32 / (n_vectors + dim) as f32;
                data.extend_from_slice(&val.to_le_bytes());
            }
        }
        Self { data }
    }

    /// Returns a reader closure with the `(offset, length)` shape expected by
    /// `RerankExecutor`.
    ///
    /// A read extending past the end is truncated, giving a short read. An `offset`
    /// beyond the end of the data returns an `UnexpectedEof` error instead of panicking
    /// on the slice index.
    pub fn reader(&self) -> impl Fn(u64, u64) -> std::io::Result<Vec<u8>> + '_ {
        move |offset: u64, length: u64| -> std::io::Result<Vec<u8>> {
            let len = self.data.len();
            // Clamp rather than truncate where `usize` is narrower than `u64`.
            let start = usize::try_from(offset).unwrap_or(usize::MAX);
            if start > len {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "offset past end of mock storage",
                ));
            }
            let requested = usize::try_from(length).unwrap_or(usize::MAX);
            let end = start.saturating_add(requested).min(len);
            Ok(self.data[start..end].to_vec())
        }
    }

    /// Byte offset of vector `id` when each vector has `dim` four-byte components.
    pub fn offset(&self, id: u32, dim: usize) -> u64 {
        (id as usize * dim * 4) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candidates that share `vector_size` from `(id, proxy_score, offset)` triples.
    fn candidates_at(specs: &[(u32, f32, u64)], vector_size: u32) -> Vec<RerankCandidate> {
        specs
            .iter()
            .map(|&(id, score, offset)| RerankCandidate::new(id, score, offset, vector_size))
            .collect()
    }

    #[test]
    fn test_io_coalescing() {
        // 0 and 1 are contiguous; 2 and 3 overlap and sit far beyond the first pair.
        let candidates = candidates_at(
            &[(0, 0.9, 0), (1, 0.8, 3072), (2, 0.7, 10000), (3, 0.6, 10500)],
            3072,
        );
        let ranges = IoCoalescer::new(1024).coalesce(&candidates);
        assert_eq!(ranges.len(), 2);
        let first = &ranges[0];
        assert_eq!(first.offset, 0);
        assert!(first.length >= 6144);
    }

    #[test]
    fn test_vector_cache() {
        let cache = VectorCache::new(3);
        for &(id, first) in &[(1u32, 1.0f32), (2, 4.0), (3, 7.0)] {
            cache.insert(id, vec![first, first + 1.0, first + 2.0]);
        }
        for id in 1..=3 {
            assert!(cache.contains(id));
        }
        // Touch 1 and 2 so that 3 becomes the least recently used entry.
        let _ = cache.get(1);
        let _ = cache.get(2);
        cache.insert(4, vec![10.0, 11.0, 12.0]);
        assert!(cache.contains(1));
        assert!(cache.contains(2));
        assert!(!cache.contains(3));
        assert!(cache.contains(4));
    }

    #[test]
    fn test_rerank_executor() {
        let dim = 4;
        let storage = MockStorage::new(100, dim);
        let bytes = storage.data.clone();
        let reader = move |offset: u64, length: u64| -> std::io::Result<Vec<u8>> {
            let from = offset as usize;
            let to = (from + length as usize).min(bytes.len());
            Ok(bytes[from..to].to_vec())
        };
        let dot =
            |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() };
        let executor = RerankExecutor::new(RerankConfig::default(), dot, reader, dim);

        let candidates: Vec<RerankCandidate> = (0..10u32)
            .map(|id| {
                RerankCandidate::new(
                    id,
                    0.9 - id as f32 * 0.01,
                    storage.offset(id, dim),
                    (dim * 4) as u32,
                )
            })
            .collect();
        let query = [1.0f32; 4];
        let (results, stats) = executor.rerank(&candidates, &query, 5);
        assert!(results.len() <= 5);
        assert!(stats.candidates_reranked > 0);
        assert!(stats.io_ops > 0);
    }

    #[test]
    fn test_coalesce_stats() {
        // Ten back-to-back 50-byte vectors, all within the 100-byte merge threshold.
        let candidates: Vec<RerankCandidate> = (0..10u32)
            .map(|id| RerankCandidate::new(id, 0.9, u64::from(id) * 50, 50))
            .collect();
        let stats = IoCoalescer::new(100).coalesce_stats(&candidates);
        assert_eq!(stats.n_candidates, 10);
        assert!(stats.n_ranges < 10);
        assert!(stats.reduction_ratio >= 1.0);
    }
}