use anyhow::Result;
use rocksdb::{DB, Options};
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info};
use post_cortex_embeddings::{VectorDB, VectorDbConfig};
use super::RealRocksDBStorage;
use super::types::EMBEDDING_DIMENSION;
impl RealRocksDBStorage {
/// Opens the storage rooted at `path`, creating it if necessary.
///
/// The RocksDB database lives in a `rocksdb` subdirectory of `path`. Once the
/// database is open, a fresh in-memory vector index is created and populated
/// by `rebuild_hnsw_index` before the storage is returned.
///
/// Directory creation and `DB::open` are synchronous calls made directly on
/// the calling task.
///
/// # Errors
///
/// Returns an error if the data directory cannot be created, RocksDB fails to
/// open, the vector index cannot be constructed, or the index rebuild fails.
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
    let data_dir = path.as_ref().to_path_buf();
    // `create_dir_all` succeeds when the directory already exists, so a
    // separate `exists()` probe is redundant (and racy between check and create).
    std::fs::create_dir_all(&data_dir)?;
    let db_path = data_dir.join("rocksdb");

    let mut opts = Options::default();
    opts.create_if_missing(true);
    opts.create_missing_column_families(true);
    opts.set_max_open_files(1000);

    // Write path: 64 MiB write buffers, up to 6 of them, merged at least 2 at a time.
    opts.set_write_buffer_size(64 * 1024 * 1024);
    opts.set_max_write_buffer_number(6);
    opts.set_min_write_buffer_number_to_merge(2);
    opts.set_target_file_size_base(64 * 1024 * 1024);

    // Level-0 file counts at which writes are slowed down and then stopped.
    opts.set_level_zero_slowdown_writes_trigger(20);
    opts.set_level_zero_stop_writes_trigger(36);

    // Sync settings: fsync disabled, 1 MiB bytes-per-sync.
    opts.set_use_fsync(false);
    opts.set_bytes_per_sync(1024 * 1024);

    opts.set_compaction_style(rocksdb::DBCompactionStyle::Universal);
    // Keys are grouped by their first 16 bytes for prefix operations.
    opts.set_prefix_extractor(rocksdb::SliceTransform::create_fixed_prefix(16));

    let db = Arc::new(DB::open(&opts, &db_path)?);
    info!(
        "RealRocksDBStorage: Initialized RocksDB at {}",
        db_path.display()
    );

    let vector_config = VectorDbConfig {
        dimension: EMBEDDING_DIMENSION,
        enable_hnsw_index: true,
        max_connections: 16,
        num_layers: 4,
        ..Default::default()
    };
    let vector_index = Arc::new(VectorDB::new(vector_config)?);

    // `db` is not used again, so it is moved into the struct rather than cloned.
    let storage = Self {
        db,
        data_dir,
        vector_index,
    };
    storage.rebuild_hnsw_index().await?;
    Ok(storage)
}
/// Repopulates the in-memory vector index from the embeddings returned by
/// `load_all_embeddings`.
///
/// Embeddings are inserted in batches of `BATCH_SIZE`. A batch that fails to
/// insert is logged and skipped rather than aborting the rebuild; the number
/// of skipped embeddings is reported in a summary warning at the end.
///
/// # Errors
///
/// Returns an error only if loading the embeddings fails.
async fn rebuild_hnsw_index(&self) -> Result<()> {
    const BATCH_SIZE: usize = 100;
    const PROGRESS_INTERVAL: usize = 500;

    let embeddings = self.load_all_embeddings().await?;
    let count = embeddings.len();
    if count == 0 {
        info!("RealRocksDBStorage: No embeddings to load into HNSW index");
        return Ok(());
    }
    info!(
        "RealRocksDBStorage: Loading {} embeddings into HNSW index...",
        count
    );

    let start = std::time::Instant::now();
    let mut loaded: usize = 0;
    let mut skipped: usize = 0;
    // Next `loaded` value at which a progress line is emitted. A threshold is
    // used instead of `loaded % 500 == 0` so that progress keeps being reported
    // (and is never repeated) when a batch fails or adds fewer vectors than
    // requested and `loaded` stops landing exactly on multiples of the interval.
    let mut next_progress = PROGRESS_INTERVAL;

    for chunk in embeddings.chunks(BATCH_SIZE) {
        let batch: Vec<(Vec<f32>, _)> = chunk
            .iter()
            .map(|embedding| (embedding.vector.clone(), embedding.to_metadata()))
            .collect();
        match self.vector_index.add_vectors_batch(batch) {
            Ok(ids) => loaded += ids.len(),
            Err(e) => {
                // Best-effort: keep going so one bad batch does not block startup.
                skipped += chunk.len();
                tracing::warn!("Failed to add embedding batch to HNSW index: {}", e);
            }
        }
        if loaded >= next_progress {
            debug!(
                "RealRocksDBStorage: HNSW rebuild progress: {}/{} vectors",
                loaded, count
            );
            next_progress = (loaded / PROGRESS_INTERVAL + 1) * PROGRESS_INTERVAL;
        }
    }

    if skipped > 0 {
        tracing::warn!(
            "RealRocksDBStorage: {} of {} embeddings were not added to the HNSW index",
            skipped,
            count
        );
    }

    let elapsed = start.elapsed();
    info!(
        "RealRocksDBStorage: HNSW index ready with {} vectors (rebuilt in {:.1}ms)",
        self.vector_index.len(),
        elapsed.as_secs_f64() * 1000.0
    );
    Ok(())
}
pub async fn get_stats(&self) -> Result<String> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || -> Result<String> {
let stats = db
.property_value(rocksdb::properties::STATS)?
.unwrap_or_else(|| "No stats available".to_string());
Ok(stats)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))?
}
pub async fn compact(&self) -> Result<()> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || -> Result<()> {
db.compact_range(None::<&[u8]>, None::<&[u8]>);
info!("RealRocksDBStorage: Database compacted");
Ok(())
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))?
}
pub async fn get_key_count(&self) -> Result<usize> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || -> Result<usize> {
if let Some(count_str) = db.property_value(rocksdb::properties::ESTIMATE_NUM_KEYS)? {
if let Ok(count) = count_str.parse::<usize>() {
return Ok(count);
}
}
let mut count = 0;
let iter = db.iterator(rocksdb::IteratorMode::Start);
for item in iter {
let _ = item?;
count += 1;
}
Ok(count)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))?
}
}