use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use rustc_hash::FxHashMap;
use crate::directories::DirectoryWriter;
use crate::dsl::{Document, Field, Schema};
use crate::error::{Error, Result};
use crate::segment::{SegmentBuilder, SegmentBuilderConfig, SegmentId};
use crate::tokenizer::BoxedTokenizer;
use super::IndexConfig;
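/// Capacity of the bounded document channel between the writer and its
/// worker threads; `add_document` fails fast with `Error::QueueFull` once
/// this many documents are buffered.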
const PIPELINE_MAX_SIZE_IN_DOCS: usize = 10_000;
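/// Multi-threaded writer that turns [`Document`]s into on-disk segments.
///
/// Documents flow through a bounded channel to a pool of worker threads;
/// each worker accumulates documents in a [`SegmentBuilder`] and builds a
/// segment whenever its share of the memory budget is reached. Commits are
/// two-phase: [`IndexWriter::prepare_commit`] drains the pipeline and
/// collects the built segments, and [`PreparedCommit::commit`] publishes
/// them through the segment manager.
///
/// A minimal usage sketch. `directory` and `make_schema` are placeholders
/// for whatever `DirectoryWriter` implementation and schema construction
/// the caller has; treat the snippet as illustrative, not as a verbatim
/// API reference:
///
/// ```ignore
/// let mut writer =
///     IndexWriter::create(directory, make_schema(), IndexConfig::default()).await?;
/// writer.add_document(doc)?;  // non-blocking; can fail with Error::QueueFull
/// writer.commit().await?;     // drain workers, then publish segments
/// ```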
pub struct IndexWriter<D: DirectoryWriter + 'static> {
pub(super) directory: Arc<D>,
pub(super) schema: Arc<Schema>,
pub(super) config: IndexConfig,
doc_sender: async_channel::Sender<Document>,
workers: Vec<std::thread::JoinHandle<()>>,
worker_state: Arc<WorkerState<D>>,
pub(super) segment_manager: Arc<crate::merge::SegmentManager<D>>,
flushed_segments: Vec<(String, u32)>,
primary_key_index: Option<super::primary_key::PrimaryKeyIndex>,
}
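/// State shared between the writer and its worker threads.
///
/// `flush_mutex`/`flush_cvar` form the barrier that `prepare_commit` waits
/// on, while `resume_receiver`/`resume_epoch`/`resume_cvar` implement the
/// epoch-based handshake that hands workers a fresh channel after a commit
/// or abort. `shutdown` tells parked workers to exit instead of resuming.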
struct WorkerState<D: DirectoryWriter + 'static> {
directory: Arc<D>,
schema: Arc<Schema>,
builder_config: SegmentBuilderConfig,
tokenizers: parking_lot::RwLock<FxHashMap<Field, BoxedTokenizer>>,
memory_budget_per_worker: usize,
segment_manager: Arc<crate::merge::SegmentManager<D>>,
built_segments: parking_lot::Mutex<Vec<(String, u32)>>,
flush_count: AtomicUsize,
flush_mutex: parking_lot::Mutex<()>,
flush_cvar: parking_lot::Condvar,
resume_receiver: parking_lot::Mutex<Option<async_channel::Receiver<Document>>>,
resume_epoch: AtomicUsize,
resume_cvar: parking_lot::Condvar,
shutdown: AtomicBool,
num_workers: usize,
}
impl<D: DirectoryWriter + 'static> IndexWriter<D> {
pub async fn create(directory: D, schema: Schema, config: IndexConfig) -> Result<Self> {
Self::create_with_config(directory, schema, config, SegmentBuilderConfig::default()).await
}
pub async fn create_with_config(
directory: D,
schema: Schema,
config: IndexConfig,
builder_config: SegmentBuilderConfig,
) -> Result<Self> {
let directory = Arc::new(directory);
let schema = Arc::new(schema);
let metadata = super::IndexMetadata::new((*schema).clone());
let segment_manager = Arc::new(crate::merge::SegmentManager::new(
Arc::clone(&directory),
Arc::clone(&schema),
metadata,
config.merge_policy.clone_box(),
config.term_cache_blocks,
config.max_concurrent_merges,
));
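// Persist the fresh metadata (no-op update forces a write) so the empty
// index exists on disk before any documents arrive.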
segment_manager.update_metadata(|_| {}).await?;
Ok(Self::new_with_parts(
directory,
schema,
config,
builder_config,
segment_manager,
))
}
pub async fn open(directory: D, config: IndexConfig) -> Result<Self> {
Self::open_with_config(directory, config, SegmentBuilderConfig::default()).await
}
pub async fn open_with_config(
directory: D,
config: IndexConfig,
builder_config: SegmentBuilderConfig,
) -> Result<Self> {
let directory = Arc::new(directory);
let metadata = super::IndexMetadata::load(directory.as_ref()).await?;
let schema = Arc::new(metadata.schema.clone());
let segment_manager = Arc::new(crate::merge::SegmentManager::new(
Arc::clone(&directory),
Arc::clone(&schema),
metadata,
config.merge_policy.clone_box(),
config.term_cache_blocks,
config.max_concurrent_merges,
));
segment_manager.load_and_publish_trained().await;
Ok(Self::new_with_parts(
directory,
schema,
config,
builder_config,
segment_manager,
))
}
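/// Builds a writer that shares an existing index's directory, schema, and
/// segment manager. Must be called from within a Tokio runtime: the worker
/// threads capture `Handle::current()` when they are spawned.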
pub fn from_index(index: &super::Index<D>) -> Self {
Self::new_with_parts(
Arc::clone(&index.directory),
Arc::clone(&index.schema),
index.config.clone(),
SegmentBuilderConfig::default(),
Arc::clone(&index.segment_manager),
)
}
fn new_with_parts(
directory: Arc<D>,
schema: Arc<Schema>,
config: IndexConfig,
builder_config: SegmentBuilderConfig,
segment_manager: Arc<crate::merge::SegmentManager<D>>,
) -> Self {
let registry = crate::tokenizer::TokenizerRegistry::new();
let mut tokenizers = FxHashMap::default();
for (field, entry) in schema.fields() {
if matches!(entry.field_type, crate::dsl::FieldType::Text)
&& let Some(ref tok_name) = entry.tokenizer
&& let Some(tok) = registry.get(tok_name)
{
tokenizers.insert(field, tok);
}
}
let num_workers = config.num_indexing_threads.max(1);
let worker_state = Arc::new(WorkerState {
directory: Arc::clone(&directory),
schema: Arc::clone(&schema),
builder_config,
tokenizers: parking_lot::RwLock::new(tokenizers),
memory_budget_per_worker: config.max_indexing_memory_bytes / num_workers,
segment_manager: Arc::clone(&segment_manager),
built_segments: parking_lot::Mutex::new(Vec::new()),
flush_count: AtomicUsize::new(0),
flush_mutex: parking_lot::Mutex::new(()),
flush_cvar: parking_lot::Condvar::new(),
resume_receiver: parking_lot::Mutex::new(None),
resume_epoch: AtomicUsize::new(0),
resume_cvar: parking_lot::Condvar::new(),
shutdown: AtomicBool::new(false),
num_workers,
});
let (doc_sender, workers) = Self::spawn_workers(&worker_state, num_workers);
Self {
directory,
schema,
config,
doc_sender,
workers,
worker_state,
segment_manager,
flushed_segments: Vec::new(),
primary_key_index: None,
}
}
fn spawn_workers(
worker_state: &Arc<WorkerState<D>>,
num_workers: usize,
) -> (
async_channel::Sender<Document>,
Vec<std::thread::JoinHandle<()>>,
) {
let (sender, receiver) = async_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS);
let handle = tokio::runtime::Handle::current();
let mut workers = Vec::with_capacity(num_workers);
for i in 0..num_workers {
let state = Arc::clone(worker_state);
let rx = receiver.clone();
let rt = handle.clone();
workers.push(
std::thread::Builder::new()
.name(format!("index-worker-{}", i))
.spawn(move || Self::worker_loop(state, rx, rt))
.expect("failed to spawn index worker thread"),
);
}
(sender, workers)
}
pub fn schema(&self) -> &Schema {
&self.schema
}
pub fn set_tokenizer<T: crate::tokenizer::Tokenizer>(&mut self, field: Field, tokenizer: T) {
self.worker_state
.tokenizers
.write()
.insert(field, Box::new(tokenizer));
}
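/// Enables primary-key deduplication by loading per-segment PK data and a
/// persisted bloom filter. Does nothing if the schema has no primary field.
///
/// If the bloom cache on disk already covers some of the current segments,
/// only the keys of segments added since it was written are inserted (on a
/// blocking task), and the cache is re-persisted. Without a usable cache,
/// the whole index is rebuilt from scratch.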
pub async fn init_primary_key_dedup(&mut self) -> Result<()> {
use super::primary_key::{PK_BLOOM_FILE, deserialize_pk_bloom};
let field = match self.schema.primary_field() {
Some(f) => f,
None => return Ok(()),
};
let snapshot = self.segment_manager.acquire_snapshot().await;
let current_seg_ids: Vec<String> = snapshot.segment_ids().to_vec();
let cached = match self
.directory
.open_read(std::path::Path::new(PK_BLOOM_FILE))
.await
{
Ok(handle) => {
let data = handle.read_bytes_range(0..handle.len()).await;
match data {
Ok(bytes) => deserialize_pk_bloom(bytes.as_slice()),
Err(_) => None,
}
}
Err(_) => None,
};
let load_futures: Vec<_> = current_seg_ids
.iter()
.map(|seg_id_str| {
let seg_id_str = seg_id_str.clone();
let dir = self.directory.as_ref();
let schema = Arc::clone(&self.schema);
async move { load_pk_segment_data(dir, &seg_id_str, &schema).await }
})
.collect();
let all_data = futures::future::try_join_all(load_futures).await?;
if let Some((persisted_seg_ids, bloom)) = cached {
let mut pk_data = Vec::with_capacity(all_data.len());
let mut new_data = Vec::new();
for d in all_data {
if persisted_seg_ids.contains(&d.segment_id) {
pk_data.push(d);
} else {
new_data.push(d);
}
}
let needs_persist = !new_data.is_empty();
let new_start = pk_data.len();
pk_data.extend(new_data);
let pk_index = if new_start == pk_data.len() {
super::primary_key::PrimaryKeyIndex::from_persisted(
field,
bloom,
pk_data,
&[],
snapshot,
)
} else {
tokio::task::spawn_blocking(move || {
let mut bloom = bloom;
let mut added = 0usize;
let num_new = pk_data.len() - new_start;
for data in &pk_data[new_start..] {
if let Some(ff) = data.fast_fields.get(&field.0)
&& let Some(dict) = ff.text_dict()
{
for key in dict.iter() {
bloom.insert(key.as_bytes());
added += 1;
}
}
}
if added > 0 {
log::info!(
"[primary_key] bloom: added {} keys from {} new segment(s)",
added,
num_new,
);
}
super::primary_key::PrimaryKeyIndex::from_persisted(
field,
bloom,
pk_data,
&[],
snapshot,
)
})
.await
.map_err(|e| Error::Internal(format!("spawn_blocking failed: {}", e)))?
};
if needs_persist {
self.persist_pk_bloom(&pk_index, &current_seg_ids).await;
}
self.primary_key_index = Some(pk_index);
} else {
let pk_index = tokio::task::spawn_blocking(move || {
super::primary_key::PrimaryKeyIndex::new(field, all_data, snapshot)
})
.await
.map_err(|e| Error::Internal(format!("spawn_blocking failed: {}", e)))?;
self.persist_pk_bloom(&pk_index, &current_seg_ids).await;
self.primary_key_index = Some(pk_index);
}
Ok(())
}
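/// Best-effort persistence of the PK bloom cache; a failure is logged and
/// otherwise ignored, since the cache can always be rebuilt on next open.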
async fn persist_pk_bloom(
&self,
pk_index: &super::primary_key::PrimaryKeyIndex,
segment_ids: &[String],
) {
use super::primary_key::{PK_BLOOM_FILE, serialize_pk_bloom};
let bloom_bytes = pk_index.bloom_to_bytes();
let data = serialize_pk_bloom(segment_ids, &bloom_bytes);
if let Err(e) = self
.directory
.write(std::path::Path::new(PK_BLOOM_FILE), &data)
.await
{
log::warn!("[primary_key] failed to persist bloom cache: {}", e);
}
}
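/// Enqueues one document for indexing without blocking.
///
/// When primary-key dedup is enabled, the key is checked and tentatively
/// inserted first, then rolled back if the send fails. Returns
/// [`Error::QueueFull`] when the bounded pipeline is at capacity, leaving
/// retry policy to the caller. One possible shape of a retry loop (the
/// backoff here is purely illustrative):
///
/// ```ignore
/// loop {
///     match writer.add_document(doc.clone()) {
///         Ok(()) => break,
///         Err(Error::QueueFull) => tokio::task::yield_now().await, // back off, retry
///         Err(e) => return Err(e),
///     }
/// }
/// ```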
pub fn add_document(&self, doc: Document) -> Result<()> {
if let Some(ref pk_index) = self.primary_key_index {
pk_index.check_and_insert(&doc)?;
}
match self.doc_sender.try_send(doc) {
Ok(()) => Ok(()),
Err(async_channel::TrySendError::Full(doc)) => {
if let Some(ref pk_index) = self.primary_key_index {
pk_index.rollback_uncommitted_key(&doc);
}
Err(Error::QueueFull)
}
Err(async_channel::TrySendError::Closed(doc)) => {
if let Some(ref pk_index) = self.primary_key_index {
pk_index.rollback_uncommitted_key(&doc);
}
Err(Error::Internal("Document channel closed".into()))
}
}
}
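/// Enqueues a batch, stopping at the first full-queue error.
///
/// Returns the number of documents actually enqueued: `Ok(len)` on full
/// success, or `Ok(i)` with `i < len` when the pipeline filled up after
/// `i` documents. Any other error is returned as-is.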
pub fn add_documents(&self, documents: Vec<Document>) -> Result<usize> {
let total = documents.len();
for (i, doc) in documents.into_iter().enumerate() {
match self.add_document(doc) {
Ok(()) => {}
Err(Error::QueueFull) => return Ok(i),
Err(e) => return Err(e),
}
}
Ok(total)
}
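/// Body of one indexing thread.
///
/// Each cycle drains the channel until it is closed, building a segment
/// whenever the per-worker memory budget is reached and flushing whatever
/// remains once the channel closes. The cycle runs inside `catch_unwind`
/// so a panic cannot wedge the flush barrier. Afterwards the worker
/// signals the barrier and parks until `resume_workers` publishes a new
/// receiver under a higher `resume_epoch`, or `shutdown` is set.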
fn worker_loop(
state: Arc<WorkerState<D>>,
initial_receiver: async_channel::Receiver<Document>,
handle: tokio::runtime::Handle,
) {
let mut receiver = initial_receiver;
let mut my_epoch = 0usize;
loop {
let build_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut builder: Option<SegmentBuilder> = None;
while let Ok(doc) = receiver.recv_blocking() {
if builder.is_none() {
match SegmentBuilder::new(
Arc::clone(&state.schema),
state.builder_config.clone(),
) {
Ok(mut b) => {
for (field, tokenizer) in state.tokenizers.read().iter() {
b.set_tokenizer(*field, tokenizer.clone_box());
}
builder = Some(b);
}
Err(e) => {
// Without a builder the document cannot be indexed; it is dropped.
log::error!("Failed to create segment builder, dropping document: {:?}", e);
continue;
}
}
}
let b = builder.as_mut().unwrap();
if let Err(e) = b.add_document(doc) {
log::error!("Failed to index document: {:?}", e);
continue;
}
let builder_memory = b.estimated_memory_bytes();
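// Log progress every 16,384 documents.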
if b.num_docs() & 0x3FFF == 0 {
log::debug!(
"[indexing] docs={}, memory={:.2} MB, budget={:.2} MB",
b.num_docs(),
builder_memory as f64 / (1024.0 * 1024.0),
state.memory_budget_per_worker as f64 / (1024.0 * 1024.0)
);
}
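// Flush at 80% of the per-worker budget to leave headroom for the build
// itself, and never flush a segment that is still trivially small.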
const MIN_DOCS_BEFORE_FLUSH: u32 = 100;
let effective_budget = state.memory_budget_per_worker * 4 / 5;
if builder_memory >= effective_budget && b.num_docs() >= MIN_DOCS_BEFORE_FLUSH {
log::info!(
"[indexing] memory budget reached, building segment: \
docs={}, memory={:.2} MB, budget={:.2} MB",
b.num_docs(),
builder_memory as f64 / (1024.0 * 1024.0),
state.memory_budget_per_worker as f64 / (1024.0 * 1024.0),
);
let full_builder = builder.take().unwrap();
Self::build_segment_inline(&state, full_builder, &handle);
}
}
if let Some(b) = builder.take()
&& b.num_docs() > 0
{
Self::build_segment_inline(&state, b, &handle);
}
}));
if build_result.is_err() {
log::error!(
"[worker] panic during indexing cycle; documents in this cycle may be lost"
);
}
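// Flush barrier: the last worker to finish its cycle wakes prepare_commit.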
let prev = state.flush_count.fetch_add(1, Ordering::Release);
if prev + 1 == state.num_workers {
let _lock = state.flush_mutex.lock();
state.flush_cvar.notify_one();
}
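// Park until shutdown, or until resume_workers publishes a fresh
// receiver under a newer epoch.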
{
let mut lock = state.resume_receiver.lock();
loop {
if state.shutdown.load(Ordering::Acquire) {
return;
}
let current_epoch = state.resume_epoch.load(Ordering::Acquire);
if current_epoch > my_epoch
&& let Some(rx) = lock.as_ref()
{
receiver = rx.clone();
my_epoch = current_epoch;
break;
}
state.resume_cvar.wait(&mut lock);
}
}
}
}
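/// Builds one segment on the calling worker thread, using `block_on` to
/// drive the async directory I/O, and records non-empty results in
/// `built_segments` for the next commit.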
fn build_segment_inline(
state: &WorkerState<D>,
builder: SegmentBuilder,
handle: &tokio::runtime::Handle,
) {
let segment_id = SegmentId::new();
let segment_hex = segment_id.to_hex();
let trained = state.segment_manager.trained();
let doc_count = builder.num_docs();
let build_start = std::time::Instant::now();
log::info!(
"[segment_build] segment_id={} doc_count={} ann={}",
segment_hex,
doc_count,
trained.is_some()
);
match handle.block_on(builder.build(
state.directory.as_ref(),
segment_id,
trained.as_deref(),
)) {
Ok(meta) if meta.num_docs > 0 => {
let duration_ms = build_start.elapsed().as_millis() as u64;
log::info!(
"[segment_build_done] segment_id={} doc_count={} duration_ms={}",
segment_hex,
meta.num_docs,
duration_ms,
);
state
.built_segments
.lock()
.push((segment_hex, meta.num_docs));
}
Ok(_) => {} // segment came out empty; nothing to register
Err(e) => {
log::error!(
"[segment_build_failed] segment_id={} error={:?}",
segment_hex,
e
);
}
}
}
pub async fn maybe_merge(&self) {
self.segment_manager.maybe_merge().await;
}
pub async fn abort_merges(&self) {
self.segment_manager.abort_merges().await;
}
pub async fn wait_for_merging_thread(&self) {
self.segment_manager.wait_for_merging_thread().await;
}
pub async fn wait_for_all_merges(&self) {
self.segment_manager.wait_for_all_merges().await;
}
pub fn tracker(&self) -> std::sync::Arc<crate::segment::SegmentTracker> {
self.segment_manager.tracker()
}
pub async fn acquire_snapshot(&self) -> crate::segment::SegmentSnapshot {
self.segment_manager.acquire_snapshot().await
}
pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
self.segment_manager.cleanup_orphan_segments().await
}
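/// First phase of a commit: closes the document channel, waits up to five
/// minutes for every worker to flush its in-progress segment, and collects
/// the built segments. The returned [`PreparedCommit`] must be resolved
/// via `commit` or `abort`; dropping it unresolved aborts automatically.
/// On timeout the workers are resumed and an error is returned.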
pub async fn prepare_commit(&mut self) -> Result<PreparedCommit<'_, D>> {
self.doc_sender.close();
self.worker_state.resume_cvar.notify_all();
let state = Arc::clone(&self.worker_state);
let all_flushed = tokio::task::spawn_blocking(move || {
let mut lock = state.flush_mutex.lock();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300);
while state.flush_count.load(Ordering::Acquire) < state.num_workers {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
if remaining.is_zero() {
log::error!(
"[prepare_commit] timed out waiting for workers: {}/{} flushed",
state.flush_count.load(Ordering::Acquire),
state.num_workers
);
return false;
}
state.flush_cvar.wait_for(&mut lock, remaining);
}
true
})
.await
.map_err(|e| Error::Internal(format!("Failed to wait for workers: {}", e)))?;
if !all_flushed {
self.resume_workers();
return Err(Error::Internal(format!(
"prepare_commit timed out: {}/{} workers flushed",
self.worker_state.flush_count.load(Ordering::Acquire),
self.worker_state.num_workers
)));
}
let built = std::mem::take(&mut *self.worker_state.built_segments.lock());
self.flushed_segments.extend(built);
Ok(PreparedCommit {
writer: self,
is_resolved: false,
})
}
pub async fn commit(&mut self) -> Result<bool> {
self.prepare_commit().await?.commit().await
}
pub async fn force_merge(&mut self) -> Result<()> {
self.prepare_commit().await?.commit().await?;
self.segment_manager.force_merge().await
}
pub async fn reorder(&mut self) -> Result<()> {
self.prepare_commit().await?.commit().await?;
self.segment_manager.reorder_segments().await
}
pub fn segment_manager(&self) -> &Arc<crate::merge::SegmentManager<D>> {
&self.segment_manager
}
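/// Re-arms the worker pool after a commit or abort: resets the flush
/// barrier, installs a fresh bounded channel, and bumps `resume_epoch` so
/// parked workers pick up the new receiver. Outside a Tokio runtime the
/// pool is shut down instead of re-armed.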
fn resume_workers(&mut self) {
if tokio::runtime::Handle::try_current().is_err() {
self.worker_state.shutdown.store(true, Ordering::Release);
self.worker_state.resume_cvar.notify_all();
return;
}
self.worker_state.flush_count.store(0, Ordering::Release);
let (sender, receiver) = async_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS);
self.doc_sender = sender;
{
let mut lock = self.worker_state.resume_receiver.lock();
*lock = Some(receiver);
}
self.worker_state
.resume_epoch
.fetch_add(1, Ordering::Release);
self.worker_state.resume_cvar.notify_all();
}
}
impl<D: DirectoryWriter + 'static> Drop for IndexWriter<D> {
fn drop(&mut self) {
self.worker_state.shutdown.store(true, Ordering::Release);
self.doc_sender.close();
self.worker_state.resume_cvar.notify_all();
for w in std::mem::take(&mut self.workers) {
let _ = w.join();
}
}
}
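/// Guard for the second phase of a commit.
///
/// Holds the writer exclusively until the prepared segments are either
/// published (`commit`) or discarded (`abort`); dropping the guard
/// unresolved logs a warning and aborts.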
pub struct PreparedCommit<'a, D: DirectoryWriter + 'static> {
writer: &'a mut IndexWriter<D>,
is_resolved: bool,
}
impl<'a, D: DirectoryWriter + 'static> PreparedCommit<'a, D> {
pub async fn commit(mut self) -> Result<bool> {
self.is_resolved = true;
let segments = std::mem::take(&mut self.writer.flushed_segments);
if segments.is_empty() {
log::debug!("[commit] no segments to commit, skipping");
self.writer.resume_workers();
return Ok(false);
}
self.writer.segment_manager.commit(segments).await?;
if let Some(ref mut pk_index) = self.writer.primary_key_index {
let snapshot = self.writer.segment_manager.acquire_snapshot().await;
let existing_ids: std::collections::HashSet<&str> =
pk_index.committed_segment_ids().collect();
let load_futures: Vec<_> = snapshot
.segment_ids()
.iter()
.filter(|id| !existing_ids.contains(id.as_str()))
.map(|seg_id_str| {
let seg_id_str = seg_id_str.clone();
let dir = self.writer.directory.as_ref();
let schema = Arc::clone(&self.writer.schema);
async move { load_pk_segment_data(dir, &seg_id_str, &schema).await }
})
.collect();
let new_data = futures::future::try_join_all(load_futures).await?;
let seg_ids: Vec<String> = snapshot.segment_ids().to_vec();
pk_index.refresh_incremental(new_data, snapshot);
let bloom_bytes = pk_index.bloom_to_bytes();
let data = super::primary_key::serialize_pk_bloom(&seg_ids, &bloom_bytes);
if let Err(e) = self
.writer
.directory
.write(
std::path::Path::new(super::primary_key::PK_BLOOM_FILE),
&data,
)
.await
{
log::warn!("[primary_key] failed to persist bloom cache: {}", e);
}
}
self.writer.segment_manager.maybe_merge().await;
self.writer.resume_workers();
Ok(true)
}
pub fn abort(mut self) {
self.is_resolved = true;
self.writer.flushed_segments.clear();
if let Some(ref mut pk_index) = self.writer.primary_key_index {
pk_index.clear_uncommitted();
}
self.writer.resume_workers();
}
}
impl<D: DirectoryWriter + 'static> Drop for PreparedCommit<'_, D> {
fn drop(&mut self) {
if !self.is_resolved {
log::warn!("PreparedCommit dropped without commit/abort — auto-aborting");
self.writer.flushed_segments.clear();
if let Some(ref mut pk_index) = self.writer.primary_key_index {
pk_index.clear_uncommitted();
}
self.writer.resume_workers();
}
}
}
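/// Loads the fast-field data for one segment so the PK dedup index can
/// read primary keys out of its text dictionary.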
async fn load_pk_segment_data<D: crate::directories::Directory>(
dir: &D,
seg_id_str: &str,
schema: &Arc<crate::dsl::Schema>,
) -> Result<super::primary_key::PkSegmentData> {
let seg_id = crate::segment::SegmentId::from_hex(seg_id_str)
.ok_or_else(|| Error::Internal(format!("Invalid segment id: {}", seg_id_str)))?;
let files = crate::segment::SegmentFiles::new(seg_id.0);
let fast_fields =
crate::segment::reader::loader::load_fast_fields_file(dir, &files, schema).await?;
Ok(super::primary_key::PkSegmentData {
segment_id: seg_id_str.to_string(),
fast_fields,
})
}