#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use parking_lot::RwLock;
use std::time::Duration;
use crate::index::hnsw::HnswParams;
mod types;
#[cfg(test)]
mod tests;
pub use types::{
AutoReindexConfig, BenchmarkResult, DivergenceCheck, ReindexEvent, ReindexReason, ReindexState,
};
/// Shared handler invoked for each [`ReindexEvent`] the manager emits; must be callable from
/// any thread.
type EventCallback = Arc<dyn Fn(ReindexEvent) + Sync + Send>;

/// Drives the automatic reindex lifecycle: detects HNSW parameter divergence, enforces a
/// cooldown between reindexes, validates benchmark regressions, tracks the current
/// [`ReindexState`], and notifies an optional observer of lifecycle events.
pub struct AutoReindexManager {
    /// Current [`ReindexState`], kept as its `u8` discriminant so it can be changed atomically.
    state: AtomicU8,
    /// Runtime-adjustable settings (enable flag, thresholds, cooldown, ...).
    config: RwLock<AutoReindexConfig>,
    /// Moment the most recent reindex completed; `None` until one has completed.
    last_reindex_timestamp: RwLock<Option<std::time::Instant>>,
    /// Observer registered through `on_event`, if any.
    event_callback: RwLock<Option<EventCallback>>,
}
// `ReindexState` values are stored in the atomic as `u8` discriminants via `as` casts.
#[allow(clippy::cast_possible_truncation)]
impl AutoReindexManager {
    /// Creates a manager in [`ReindexState::Idle`] with the given configuration, no event
    /// callback and no recorded reindex timestamp.
    #[must_use]
    pub fn new(config: AutoReindexConfig) -> Self {
        Self {
            config: RwLock::new(config),
            state: AtomicU8::new(ReindexState::Idle as u8),
            event_callback: RwLock::new(None),
            last_reindex_timestamp: RwLock::new(None),
        }
    }

    /// Creates a manager using [`AutoReindexConfig::default`].
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(AutoReindexConfig::default())
    }

    /// Returns the current lifecycle state.
    #[must_use]
    pub fn state(&self) -> ReindexState {
        ReindexState::from(self.state.load(Ordering::Acquire))
    }

    /// Returns whether automatic reindexing is enabled in the current configuration.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.read().enabled
    }

    /// Enables or disables automatic reindexing without touching other settings.
    pub fn set_enabled(&self, enabled: bool) {
        self.config.write().enabled = enabled;
    }

    /// Replaces the whole configuration.
    pub fn set_config(&self, config: AutoReindexConfig) {
        *self.config.write() = config;
    }

    /// Returns a snapshot (clone) of the current configuration.
    #[must_use]
    pub fn config(&self) -> AutoReindexConfig {
        self.config.read().clone()
    }

    /// Registers `callback` to receive every subsequent [`ReindexEvent`], replacing any
    /// previously registered callback.
    pub fn on_event<F>(&self, callback: F)
    where
        F: Fn(ReindexEvent) + Send + Sync + 'static,
    {
        *self.event_callback.write() = Some(Arc::new(callback));
    }

    /// Delivers `event` to the registered callback, if any.
    fn emit_event(&self, event: ReindexEvent) {
        // Clone the Arc and release the read guard *before* invoking user code: a callback
        // that calls back into the manager (e.g. `on_event`, which takes the write lock)
        // would otherwise deadlock on this non-reentrant lock.
        let callback = self.event_callback.read().clone();
        if let Some(callback) = callback {
            callback(event);
        }
    }

    /// Compares the index's current `max_connections` (M) with the value
    /// `HnswParams::for_dataset_size(dimension, current_size)` produces.
    ///
    /// Datasets smaller than `min_size_for_reindex` are never flagged and report a ratio of
    /// `1.0`. Otherwise `ratio = optimal_m / current_m` (infinite when `current_m == 0`), and
    /// `should_reindex` is set only when auto-reindex is enabled and
    /// `ratio >= param_divergence_threshold`.
    #[must_use]
    pub fn check_divergence(
        &self,
        current_params: &HnswParams,
        current_size: usize,
        dimension: usize,
    ) -> DivergenceCheck {
        let config = self.config.read();
        if current_size < config.min_size_for_reindex {
            return DivergenceCheck {
                should_reindex: false,
                current_m: current_params.max_connections,
                optimal_m: current_params.max_connections,
                ratio: 1.0,
                reason: None,
            };
        }

        let optimal_params = HnswParams::for_dataset_size(dimension, current_size);
        let current_m = current_params.max_connections;
        let optimal_m = optimal_params.max_connections;
        let ratio = if current_m > 0 {
            optimal_m as f64 / current_m as f64
        } else {
            f64::INFINITY
        };

        let should_reindex = config.enabled && ratio >= config.param_divergence_threshold;
        let reason = should_reindex.then_some(ReindexReason::ParamDivergence {
            current_m,
            optimal_m,
            ratio,
        });

        DivergenceCheck {
            should_reindex,
            current_m,
            optimal_m,
            ratio,
            reason,
        }
    }

    /// Returns `true` when a reindex should be started now.
    ///
    /// All of these must hold:
    /// - the cooldown since the last completed reindex has elapsed;
    /// - the manager is [`ReindexState::Idle`];
    /// - [`Self::check_divergence`] recommends a reindex.
    ///
    /// This is an advisory check: use the `bool` returned by [`Self::start_reindex`] to learn
    /// whether this caller actually won the right to reindex.
    #[must_use]
    pub fn should_reindex(
        &self,
        current_params: &HnswParams,
        current_size: usize,
        dimension: usize,
    ) -> bool {
        // Copy the `Instant` out so the timestamp lock is released before taking the config lock.
        let last_reindex = *self.last_reindex_timestamp.read();
        if let Some(last) = last_reindex {
            if last.elapsed() < self.config.read().cooldown {
                return false;
            }
        }
        if self.state() != ReindexState::Idle {
            return false;
        }
        self.check_divergence(current_params, current_size, dimension)
            .should_reindex
    }

    /// Checks a rebuilt index's benchmark against the old one.
    ///
    /// Latency regression is the relative increase of p99 latency, in percent; it is skipped
    /// when the old p99 is `0`. Recall regression is `(old - new) * 100`; it is skipped when the
    /// old recall estimate is not positive.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when either regression exceeds the configured maximum
    /// (`max_latency_regression_percent` / `max_recall_regression_percent`).
    pub fn validate_benchmark(
        &self,
        old_benchmark: &BenchmarkResult,
        new_benchmark: &BenchmarkResult,
    ) -> Result<(), String> {
        let config = self.config.read();
        if old_benchmark.latency_p99_us > 0 {
            let latency_change = (new_benchmark.latency_p99_us as f64
                - old_benchmark.latency_p99_us as f64)
                / old_benchmark.latency_p99_us as f64
                * 100.0;
            if latency_change > config.max_latency_regression_percent {
                return Err(format!(
                    "Latency regression: {:.1}% (max allowed: {:.1}%)",
                    latency_change, config.max_latency_regression_percent
                ));
            }
        }
        if old_benchmark.recall_estimate > 0.0 {
            let recall_change =
                (old_benchmark.recall_estimate - new_benchmark.recall_estimate) * 100.0;
            if recall_change > config.max_recall_regression_percent {
                return Err(format!(
                    "Recall regression: {:.1}% (max allowed: {:.1}%)",
                    recall_change, config.max_recall_regression_percent
                ));
            }
        }
        Ok(())
    }

    /// Atomically moves the state machine from `from` to `to`.
    ///
    /// Returns `false` and leaves the state untouched when the current state is not `from`.
    /// Because the check and the update are a single CAS, two racing callers can never both
    /// succeed.
    fn try_transition(&self, from: ReindexState, to: ReindexState) -> bool {
        self.state
            .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Moves `Idle -> Building` and emits [`ReindexEvent::Started`]; returns `false` (no event)
    /// if the manager was not idle.
    fn try_start(
        &self,
        reason: ReindexReason,
        old_params: HnswParams,
        new_params: HnswParams,
    ) -> bool {
        if !self.try_transition(ReindexState::Idle, ReindexState::Building) {
            return false;
        }
        self.emit_event(ReindexEvent::Started {
            reason,
            old_params,
            new_params,
        });
        true
    }

    /// Starts a manual reindex, reporting default parameters as both old and new.
    ///
    /// Returns `false` if a reindex is already in progress.
    pub fn trigger_manual_reindex(&self) -> bool {
        self.try_start(
            ReindexReason::Manual,
            HnswParams::default(),
            HnswParams::default(),
        )
    }

    /// Starts a reindex for `reason`, moving `Idle -> Building`.
    ///
    /// Returns `false` if the manager was not idle; at most one concurrent caller wins.
    pub fn start_reindex(
        &self,
        reason: ReindexReason,
        old_params: HnswParams,
        new_params: HnswParams,
    ) -> bool {
        self.try_start(reason, old_params, new_params)
    }

    /// Emits a progress event (percent clamped to `100`) while in [`ReindexState::Building`];
    /// otherwise does nothing.
    pub fn report_progress(&self, percent: u8) {
        if self.state() == ReindexState::Building {
            self.emit_event(ReindexEvent::Progress {
                percent: percent.min(100),
            });
        }
    }

    /// Moves `Building -> Validating` and emits [`ReindexEvent::Validating`] with the given
    /// latencies.
    ///
    /// Returns `false` (no event) if the manager was not building.
    pub fn start_validation(&self, old_latency_p99_us: u64, new_latency_p99_us: u64) -> bool {
        if !self.try_transition(ReindexState::Building, ReindexState::Validating) {
            return false;
        }
        self.emit_event(ReindexEvent::Validating {
            old_latency_p99_us,
            new_latency_p99_us,
        });
        true
    }

    /// Finishes a reindex from [`ReindexState::Validating`] or [`ReindexState::Swapping`].
    ///
    /// On success it:
    /// - records the completion time, which starts the cooldown;
    /// - returns the manager to `Idle`;
    /// - emits [`ReindexEvent::Completed`].
    ///
    /// Returns `false` (nothing changed) from any other state.
    pub fn complete_reindex(&self, duration: Duration) -> bool {
        {
            // Hold the timestamp lock across the state change so the cooldown is recorded only
            // by the single caller that actually completed the reindex.
            let mut last_reindex = self.last_reindex_timestamp.write();
            let finished = self
                .state
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |raw| {
                    match ReindexState::from(raw) {
                        ReindexState::Validating | ReindexState::Swapping => {
                            Some(ReindexState::Idle as u8)
                        }
                        _ => None,
                    }
                })
                .is_ok();
            if !finished {
                return false;
            }
            *last_reindex = Some(std::time::Instant::now());
        }
        // The timestamp guard is dropped, so the callback may safely call back into the manager.
        self.emit_event(ReindexEvent::Completed { duration });
        true
    }

    /// Aborts an in-progress reindex, returning to `Idle` and emitting
    /// [`ReindexEvent::RolledBack`].
    ///
    /// Returns `false` (no event) if the manager was already idle.
    pub fn rollback(&self, reason: String) -> bool {
        // A single swap both observes the previous state and resets it, so concurrent rollbacks
        // emit exactly one event.
        let previous = self.state.swap(ReindexState::Idle as u8, Ordering::AcqRel);
        if ReindexState::from(previous) == ReindexState::Idle {
            return false;
        }
        self.emit_event(ReindexEvent::RolledBack { reason });
        true
    }

    /// Forces the state back to `Idle` without emitting an event or touching the cooldown.
    pub fn reset(&self) {
        self.state
            .store(ReindexState::Idle as u8, Ordering::Release);
    }
}
impl Default for AutoReindexManager {
fn default() -> Self {
Self::with_defaults()
}
}