Skip to main content

hermes_core/segment/merger/
mod.rs

1//! Segment merger for combining multiple segments
2
3mod dense_vectors;
4mod sparse_vectors;
5
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::io::Write;
9use std::sync::Arc;
10
11use rustc_hash::FxHashMap;
12
13use super::reader::SegmentReader;
14use super::store::StoreMerger;
15use super::types::{FieldStats, SegmentFiles, SegmentId, SegmentMeta};
16use crate::Result;
17use crate::directories::{Directory, DirectoryWriter, StreamingWriter};
18use crate::dsl::Schema;
19use crate::structures::{
20    BlockPostingList, PositionPostingList, PostingList, SSTableWriter, TERMINATED, TermInfo,
21};
22
23/// Write adapter that tracks bytes written.
24///
25/// Concrete type so it works with generic `serialize<W: Write>` functions
26/// (unlike `dyn StreamingWriter` which isn't `Sized`).
27pub(crate) struct OffsetWriter {
28    inner: Box<dyn StreamingWriter>,
29    offset: u64,
30}
31
32impl OffsetWriter {
33    fn new(inner: Box<dyn StreamingWriter>) -> Self {
34        Self { inner, offset: 0 }
35    }
36
37    /// Current write position (total bytes written so far).
38    fn offset(&self) -> u64 {
39        self.offset
40    }
41
42    /// Finalize the underlying streaming writer.
43    fn finish(self) -> std::io::Result<()> {
44        self.inner.finish()
45    }
46}
47
48impl Write for OffsetWriter {
49    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
50        let n = self.inner.write(buf)?;
51        self.offset += n as u64;
52        Ok(n)
53    }
54
55    fn flush(&mut self) -> std::io::Result<()> {
56        self.inner.flush()
57    }
58}
59
/// Render a byte count for log output.
///
/// Uses 1024-based units labelled `KB`/`MB`/`GB` with two decimals; counts
/// below 1024 are printed as a plain integer followed by ` B`.
fn format_bytes(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * 1024;
    const GIB: usize = MIB * 1024;

    let amount = bytes as f64;
    match bytes {
        n if n >= GIB => format!("{:.2} GB", amount / GIB as f64),
        n if n >= MIB => format!("{:.2} MB", amount / MIB as f64),
        n if n >= KIB => format!("{:.2} KB", amount / KIB as f64),
        _ => format!("{} B", bytes),
    }
}
72
73/// Compute per-segment doc ID offsets (each segment's docs start after the previous)
74fn doc_offsets(segments: &[SegmentReader]) -> Vec<u32> {
75    let mut offsets = Vec::with_capacity(segments.len());
76    let mut acc = 0u32;
77    for seg in segments {
78        offsets.push(acc);
79        acc += seg.num_docs();
80    }
81    offsets
82}
83
/// Counters and output sizes collected while merging segments.
///
/// All `*_bytes` fields are byte counts.
#[derive(Clone, Debug, Default)]
pub struct MergeStats {
    /// Number of distinct terms written to the merged term dictionary.
    pub terms_processed: usize,
    /// Estimated peak memory held by buffered merge state.
    pub peak_memory_bytes: usize,
    /// Size of the term dictionary output.
    pub term_dict_bytes: usize,
    /// Size of the postings output.
    pub postings_bytes: usize,
    /// Size of the document store output.
    pub store_bytes: usize,
    /// Size of the dense vector index output.
    pub vectors_bytes: usize,
    /// Size of the sparse vector index output.
    pub sparse_bytes: usize,
}
102
103/// Entry for k-way merge heap
104struct MergeEntry {
105    key: Vec<u8>,
106    term_info: TermInfo,
107    segment_idx: usize,
108    doc_offset: u32,
109}
110
111impl PartialEq for MergeEntry {
112    fn eq(&self, other: &Self) -> bool {
113        self.key == other.key
114    }
115}
116
117impl Eq for MergeEntry {}
118
119impl PartialOrd for MergeEntry {
120    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
121        Some(self.cmp(other))
122    }
123}
124
125impl Ord for MergeEntry {
126    fn cmp(&self, other: &Self) -> Ordering {
127        // Reverse order for min-heap (BinaryHeap is max-heap by default)
128        other.key.cmp(&self.key)
129    }
130}
131
132/// Trained vector index structures for rebuilding segments with ANN indexes
133#[derive(Clone)]
134pub struct TrainedVectorStructures {
135    /// Trained centroids per field_id
136    pub centroids: rustc_hash::FxHashMap<u32, Arc<crate::structures::CoarseCentroids>>,
137    /// Trained PQ codebooks per field_id (for ScaNN)
138    pub codebooks: rustc_hash::FxHashMap<u32, Arc<crate::structures::PQCodebook>>,
139}
140
141/// Segment merger - merges multiple segments into one
142pub struct SegmentMerger {
143    schema: Arc<Schema>,
144}
145
146impl SegmentMerger {
147    pub fn new(schema: Arc<Schema>) -> Self {
148        Self { schema }
149    }
150
151    /// Merge segments into one, streaming postings/positions/store directly to files.
152    ///
153    /// If `trained` is provided, dense vectors use O(1) cluster merge when possible
154    /// (homogeneous IVF/ScaNN), otherwise rebuilds ANN from trained structures.
155    /// Without trained structures, only flat vectors are merged.
156    ///
157    /// Uses streaming writers so postings, positions, and store data flow directly
158    /// to files instead of buffering everything in memory. Only the term dictionary
159    /// (compact key+TermInfo entries) is buffered.
160    pub async fn merge<D: Directory + DirectoryWriter>(
161        &self,
162        dir: &D,
163        segments: &[SegmentReader],
164        new_segment_id: SegmentId,
165        trained: Option<&TrainedVectorStructures>,
166    ) -> Result<(SegmentMeta, MergeStats)> {
167        let mut stats = MergeStats::default();
168        let files = SegmentFiles::new(new_segment_id.0);
169
170        // === Phase 1: merge postings + positions (streaming) ===
171        let mut postings_writer = OffsetWriter::new(dir.streaming_writer(&files.postings).await?);
172        let mut positions_writer = OffsetWriter::new(dir.streaming_writer(&files.positions).await?);
173        let mut term_dict_writer = OffsetWriter::new(dir.streaming_writer(&files.term_dict).await?);
174
175        let terms_processed = self
176            .merge_postings(
177                segments,
178                &mut term_dict_writer,
179                &mut postings_writer,
180                &mut positions_writer,
181                &mut stats,
182            )
183            .await?;
184        stats.terms_processed = terms_processed;
185        stats.postings_bytes = postings_writer.offset() as usize;
186        stats.term_dict_bytes = term_dict_writer.offset() as usize;
187        let positions_bytes = positions_writer.offset();
188
189        postings_writer.finish()?;
190        term_dict_writer.finish()?;
191        if positions_bytes > 0 {
192            positions_writer.finish()?;
193        } else {
194            drop(positions_writer);
195            let _ = dir.delete(&files.positions).await;
196        }
197
198        // === Phase 2: merge store files (streaming) ===
199        {
200            let mut store_writer = OffsetWriter::new(dir.streaming_writer(&files.store).await?);
201            {
202                let mut store_merger = StoreMerger::new(&mut store_writer);
203                for segment in segments {
204                    if segment.store_has_dict() {
205                        store_merger
206                            .append_store_recompressing(segment.store())
207                            .await
208                            .map_err(crate::Error::Io)?;
209                    } else {
210                        let raw_blocks = segment.store_raw_blocks();
211                        let data_slice = segment.store_data_slice();
212                        store_merger.append_store(data_slice, &raw_blocks).await?;
213                    }
214                }
215                store_merger.finish()?;
216            }
217            stats.store_bytes = store_writer.offset() as usize;
218            store_writer.finish()?;
219        }
220
221        // === Dense vectors ===
222        let vectors_bytes = self
223            .merge_dense_vectors(dir, segments, &files, trained)
224            .await?;
225        stats.vectors_bytes = vectors_bytes;
226
227        // === Mandatory: merge sparse vectors ===
228        let sparse_bytes = self.merge_sparse_vectors(dir, segments, &files).await?;
229        stats.sparse_bytes = sparse_bytes;
230
231        // === Mandatory: merge field stats + write meta ===
232        let mut merged_field_stats: FxHashMap<u32, FieldStats> = FxHashMap::default();
233        for segment in segments {
234            for (&field_id, field_stats) in &segment.meta().field_stats {
235                let entry = merged_field_stats.entry(field_id).or_default();
236                entry.total_tokens += field_stats.total_tokens;
237                entry.doc_count += field_stats.doc_count;
238            }
239        }
240
241        let total_docs: u32 = segments.iter().map(|s| s.num_docs()).sum();
242        let meta = SegmentMeta {
243            id: new_segment_id.0,
244            num_docs: total_docs,
245            field_stats: merged_field_stats,
246        };
247
248        dir.write(&files.meta, &meta.serialize()?).await?;
249
250        let label = if trained.is_some() {
251            "ANN merge"
252        } else {
253            "Merge"
254        };
255        log::info!(
256            "{} complete: {} docs, {} terms, term_dict={}, postings={}, store={}, vectors={}, sparse={}",
257            label,
258            total_docs,
259            stats.terms_processed,
260            format_bytes(stats.term_dict_bytes),
261            format_bytes(stats.postings_bytes),
262            format_bytes(stats.store_bytes),
263            format_bytes(stats.vectors_bytes),
264            format_bytes(stats.sparse_bytes),
265        );
266
267        Ok((meta, stats))
268    }
269
270    /// Merge postings from multiple segments using streaming k-way merge
271    ///
272    /// This implementation uses a min-heap to merge terms from all segments
273    /// in sorted order without loading all terms into memory at once.
274    /// Memory usage is O(num_segments) instead of O(total_terms).
275    ///
276    /// Optimization: For terms that exist in only one segment, we copy the
277    /// posting data directly without decode/encode. Only terms that exist
278    /// in multiple segments need full merge.
279    ///
280    /// Returns the number of terms processed.
281    async fn merge_postings(
282        &self,
283        segments: &[SegmentReader],
284        term_dict: &mut OffsetWriter,
285        postings_out: &mut OffsetWriter,
286        positions_out: &mut OffsetWriter,
287        stats: &mut MergeStats,
288    ) -> Result<usize> {
289        let doc_offs = doc_offsets(segments);
290
291        // Bulk-prefetch all term dict blocks (1 I/O per segment instead of ~160)
292        for (i, segment) in segments.iter().enumerate() {
293            log::debug!("Prefetching term dict for segment {} ...", i);
294            segment.prefetch_term_dict().await?;
295        }
296
297        // Create iterators for each segment's term dictionary
298        let mut iterators: Vec<_> = segments.iter().map(|s| s.term_dict_iter()).collect();
299
300        // Initialize min-heap with first entry from each segment
301        let mut heap: BinaryHeap<MergeEntry> = BinaryHeap::new();
302        for (seg_idx, iter) in iterators.iter_mut().enumerate() {
303            if let Some((key, term_info)) = iter.next().await.map_err(crate::Error::from)? {
304                heap.push(MergeEntry {
305                    key,
306                    term_info,
307                    segment_idx: seg_idx,
308                    doc_offset: doc_offs[seg_idx],
309                });
310            }
311        }
312
313        // Buffer term results - needed because SSTableWriter can't be held across await points
314        // Memory is bounded by unique terms (typically much smaller than postings)
315        let mut term_results: Vec<(Vec<u8>, TermInfo)> = Vec::new();
316        let mut terms_processed = 0usize;
317        let mut serialize_buf: Vec<u8> = Vec::new();
318
319        while !heap.is_empty() {
320            // Get the smallest key
321            let first = heap.pop().unwrap();
322            let current_key = first.key.clone();
323
324            // Collect all entries with the same key
325            let mut sources: Vec<(usize, TermInfo, u32)> =
326                vec![(first.segment_idx, first.term_info, first.doc_offset)];
327
328            // Advance the iterator that provided this entry
329            if let Some((key, term_info)) = iterators[first.segment_idx]
330                .next()
331                .await
332                .map_err(crate::Error::from)?
333            {
334                heap.push(MergeEntry {
335                    key,
336                    term_info,
337                    segment_idx: first.segment_idx,
338                    doc_offset: doc_offs[first.segment_idx],
339                });
340            }
341
342            // Check if other segments have the same key
343            while let Some(entry) = heap.peek() {
344                if entry.key != current_key {
345                    break;
346                }
347                let entry = heap.pop().unwrap();
348                sources.push((entry.segment_idx, entry.term_info, entry.doc_offset));
349
350                // Advance this iterator too
351                if let Some((key, term_info)) = iterators[entry.segment_idx]
352                    .next()
353                    .await
354                    .map_err(crate::Error::from)?
355                {
356                    heap.push(MergeEntry {
357                        key,
358                        term_info,
359                        segment_idx: entry.segment_idx,
360                        doc_offset: doc_offs[entry.segment_idx],
361                    });
362                }
363            }
364
365            // Process this term (handles both single-source and multi-source)
366            let term_info = self
367                .merge_term(
368                    segments,
369                    &sources,
370                    postings_out,
371                    positions_out,
372                    &mut serialize_buf,
373                )
374                .await?;
375
376            term_results.push((current_key, term_info));
377            terms_processed += 1;
378
379            // Log progress every 100k terms
380            if terms_processed.is_multiple_of(100_000) {
381                log::debug!("Merge progress: {} terms processed", terms_processed);
382            }
383        }
384
385        // Track memory (only term_results is buffered; postings/positions stream to disk)
386        let results_mem = term_results.capacity() * std::mem::size_of::<(Vec<u8>, TermInfo)>();
387        stats.peak_memory_bytes = stats.peak_memory_bytes.max(results_mem);
388
389        log::info!(
390            "[merge] complete: terms={}, segments={}, term_buffer={:.2} MB, postings={}, positions={}",
391            terms_processed,
392            segments.len(),
393            results_mem as f64 / (1024.0 * 1024.0),
394            format_bytes(postings_out.offset() as usize),
395            format_bytes(positions_out.offset() as usize),
396        );
397
398        // Write to SSTable (sync, no await points)
399        let mut writer = SSTableWriter::<TermInfo>::new(term_dict);
400        for (key, term_info) in term_results {
401            writer.insert(&key, &term_info)?;
402        }
403        writer.finish()?;
404
405        Ok(terms_processed)
406    }
407
408    /// Merge a single term's postings + positions from one or more source segments.
409    ///
410    /// Fast path: when all sources are external and there are multiple,
411    /// uses block-level concatenation (O(blocks) instead of O(postings)).
412    /// Otherwise: full decode → remap doc IDs → re-encode.
413    async fn merge_term(
414        &self,
415        segments: &[SegmentReader],
416        sources: &[(usize, TermInfo, u32)],
417        postings_out: &mut OffsetWriter,
418        positions_out: &mut OffsetWriter,
419        buf: &mut Vec<u8>,
420    ) -> Result<TermInfo> {
421        let mut sorted: Vec<_> = sources.to_vec();
422        sorted.sort_by_key(|(_, _, off)| *off);
423
424        let any_positions = sorted.iter().any(|(_, ti, _)| ti.position_info().is_some());
425        let all_external = sorted.iter().all(|(_, ti, _)| ti.external_info().is_some());
426
427        // === Merge postings ===
428        let (posting_offset, posting_len, doc_count) = if all_external && sorted.len() > 1 {
429            // Fast path: block-level concatenation
430            let mut block_sources = Vec::with_capacity(sorted.len());
431            for (seg_idx, ti, doc_off) in &sorted {
432                let (off, len) = ti.external_info().unwrap();
433                let bytes = segments[*seg_idx].read_postings(off, len).await?;
434                let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
435                block_sources.push((bpl, *doc_off));
436            }
437            let merged = BlockPostingList::concatenate_blocks(&block_sources)?;
438            let offset = postings_out.offset();
439            buf.clear();
440            merged.serialize(buf)?;
441            postings_out.write_all(buf)?;
442            (offset, buf.len() as u32, merged.doc_count())
443        } else {
444            // Decode all sources into a flat PostingList, remap doc IDs
445            let mut merged = PostingList::new();
446            for (seg_idx, ti, doc_off) in &sorted {
447                if let Some((ids, tfs)) = ti.decode_inline() {
448                    for (id, tf) in ids.into_iter().zip(tfs) {
449                        merged.add(id + doc_off, tf);
450                    }
451                } else {
452                    let (off, len) = ti.external_info().unwrap();
453                    let bytes = segments[*seg_idx].read_postings(off, len).await?;
454                    let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
455                    let mut it = bpl.iterator();
456                    while it.doc() != TERMINATED {
457                        merged.add(it.doc() + doc_off, it.term_freq());
458                        it.advance();
459                    }
460                }
461            }
462            // Try to inline (only when no positions)
463            if !any_positions {
464                let ids: Vec<u32> = merged.iter().map(|p| p.doc_id).collect();
465                let tfs: Vec<u32> = merged.iter().map(|p| p.term_freq).collect();
466                if let Some(inline) = TermInfo::try_inline(&ids, &tfs) {
467                    return Ok(inline);
468                }
469            }
470            let offset = postings_out.offset();
471            let block = BlockPostingList::from_posting_list(&merged)?;
472            buf.clear();
473            block.serialize(buf)?;
474            postings_out.write_all(buf)?;
475            (offset, buf.len() as u32, merged.doc_count())
476        };
477
478        // === Merge positions (if any source has them) ===
479        if any_positions {
480            let mut pos_sources = Vec::new();
481            for (seg_idx, ti, doc_off) in &sorted {
482                if let Some((pos_off, pos_len)) = ti.position_info()
483                    && let Some(bytes) = segments[*seg_idx]
484                        .read_position_bytes(pos_off, pos_len)
485                        .await?
486                {
487                    let pl = PositionPostingList::deserialize(&mut bytes.as_slice())
488                        .map_err(crate::Error::Io)?;
489                    pos_sources.push((pl, *doc_off));
490                }
491            }
492            if !pos_sources.is_empty() {
493                let merged = PositionPostingList::concatenate_blocks(&pos_sources)
494                    .map_err(crate::Error::Io)?;
495                let offset = positions_out.offset();
496                buf.clear();
497                merged.serialize(buf).map_err(crate::Error::Io)?;
498                positions_out.write_all(buf)?;
499                return Ok(TermInfo::external_with_positions(
500                    posting_offset,
501                    posting_len,
502                    doc_count,
503                    offset,
504                    buf.len() as u32,
505                ));
506            }
507        }
508
509        Ok(TermInfo::external(posting_offset, posting_len, doc_count))
510    }
511}
512
513/// Delete segment files from directory
514pub async fn delete_segment<D: Directory + DirectoryWriter>(
515    dir: &D,
516    segment_id: SegmentId,
517) -> Result<()> {
518    let files = SegmentFiles::new(segment_id.0);
519    let _ = dir.delete(&files.term_dict).await;
520    let _ = dir.delete(&files.postings).await;
521    let _ = dir.delete(&files.store).await;
522    let _ = dir.delete(&files.meta).await;
523    let _ = dir.delete(&files.vectors).await;
524    let _ = dir.delete(&files.sparse).await;
525    let _ = dir.delete(&files.positions).await;
526    Ok(())
527}