Skip to main content

hermes_core/segment/merger/
mod.rs

1//! Segment merger for combining multiple segments
2
3mod dense_vectors;
4mod sparse_vectors;
5
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::io::Write;
9use std::sync::Arc;
10
11use rustc_hash::FxHashMap;
12
13use super::reader::SegmentReader;
14use super::store::StoreMerger;
15use super::types::{FieldStats, SegmentFiles, SegmentId, SegmentMeta};
16use crate::Result;
17use crate::directories::{Directory, DirectoryWriter, StreamingWriter};
18use crate::dsl::Schema;
19use crate::structures::{
20    BlockPostingList, PositionPostingList, PostingList, SSTableWriter, TERMINATED, TermInfo,
21};
22
23/// Write adapter that tracks bytes written.
24///
25/// Concrete type so it works with generic `serialize<W: Write>` functions
26/// (unlike `dyn StreamingWriter` which isn't `Sized`).
27pub(crate) struct OffsetWriter {
28    inner: Box<dyn StreamingWriter>,
29    offset: u64,
30}
31
32impl OffsetWriter {
33    fn new(inner: Box<dyn StreamingWriter>) -> Self {
34        Self { inner, offset: 0 }
35    }
36
37    /// Current write position (total bytes written so far).
38    fn offset(&self) -> u64 {
39        self.offset
40    }
41
42    /// Finalize the underlying streaming writer.
43    fn finish(self) -> std::io::Result<()> {
44        self.inner.finish()
45    }
46}
47
48impl Write for OffsetWriter {
49    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
50        let n = self.inner.write(buf)?;
51        self.offset += n as u64;
52        Ok(n)
53    }
54
55    fn flush(&mut self) -> std::io::Result<()> {
56        self.inner.flush()
57    }
58}
59
/// Renders a byte count as a human-readable string.
///
/// Counts of at least 1 KiB are shown with two decimals in the largest binary
/// unit (`KB` = 1024, `MB` = 1024², `GB` = 1024³) that does not exceed them.
/// Smaller counts are printed as an integer followed by `B`.
fn format_bytes(bytes: usize) -> String {
    // Largest unit first so the first match is the one to use.
    const UNITS: [(&str, usize); 3] = [("GB", 1 << 30), ("MB", 1 << 20), ("KB", 1 << 10)];

    UNITS
        .iter()
        .find(|&&(_, scale)| bytes >= scale)
        .map(|&(unit, scale)| format!("{:.2} {}", bytes as f64 / scale as f64, unit))
        .unwrap_or_else(|| format!("{} B", bytes))
}
72
73/// Compute per-segment doc ID offsets (each segment's docs start after the previous)
74fn doc_offsets(segments: &[SegmentReader]) -> Vec<u32> {
75    let mut offsets = Vec::with_capacity(segments.len());
76    let mut acc = 0u32;
77    for seg in segments {
78        offsets.push(acc);
79        acc += seg.num_docs();
80    }
81    offsets
82}
83
/// Statistics reported by a segment merge.
///
/// Every field except `terms_processed` is a size in bytes.
#[derive(Debug, Clone, Default)]
pub struct MergeStats {
    /// Number of distinct terms written to the merged term dictionary
    pub terms_processed: usize,
    /// Peak memory usage in bytes (estimated).
    ///
    /// NOTE(review): in this module it is only raised from the buffered
    /// term-dictionary entries; other phases may not contribute — confirm.
    pub peak_memory_bytes: usize,
    /// Term dictionary output size
    pub term_dict_bytes: usize,
    /// Postings output size
    pub postings_bytes: usize,
    /// Store output size
    pub store_bytes: usize,
    /// Vector index output size
    pub vectors_bytes: usize,
    /// Sparse vector index output size
    pub sparse_bytes: usize,
}
102
103/// Entry for k-way merge heap
104struct MergeEntry {
105    key: Vec<u8>,
106    term_info: TermInfo,
107    segment_idx: usize,
108    doc_offset: u32,
109}
110
111impl PartialEq for MergeEntry {
112    fn eq(&self, other: &Self) -> bool {
113        self.key == other.key
114    }
115}
116
117impl Eq for MergeEntry {}
118
119impl PartialOrd for MergeEntry {
120    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
121        Some(self.cmp(other))
122    }
123}
124
125impl Ord for MergeEntry {
126    fn cmp(&self, other: &Self) -> Ordering {
127        // Reverse order for min-heap (BinaryHeap is max-heap by default)
128        other.key.cmp(&self.key)
129    }
130}
131
132/// Trained vector index structures for rebuilding segments with ANN indexes
133pub struct TrainedVectorStructures {
134    /// Trained centroids per field_id
135    pub centroids: rustc_hash::FxHashMap<u32, Arc<crate::structures::CoarseCentroids>>,
136    /// Trained PQ codebooks per field_id (for ScaNN)
137    pub codebooks: rustc_hash::FxHashMap<u32, Arc<crate::structures::PQCodebook>>,
138}
139
140/// Segment merger - merges multiple segments into one
141pub struct SegmentMerger {
142    schema: Arc<Schema>,
143}
144
145impl SegmentMerger {
146    pub fn new(schema: Arc<Schema>) -> Self {
147        Self { schema }
148    }
149
150    /// Merge segments into one, streaming postings/positions/store directly to files.
151    pub async fn merge<D: Directory + DirectoryWriter>(
152        &self,
153        dir: &D,
154        segments: &[SegmentReader],
155        new_segment_id: SegmentId,
156    ) -> Result<(SegmentMeta, MergeStats)> {
157        self.merge_core(dir, segments, new_segment_id, None).await
158    }
159
160    /// Merge segments with trained ANN structures available.
161    ///
162    /// Dense vectors use O(1) cluster merge when possible (homogeneous IVF/ScaNN),
163    /// otherwise extracts raw vectors from all index types and rebuilds with
164    /// the provided trained structures.
165    pub async fn merge_with_ann<D: Directory + DirectoryWriter>(
166        &self,
167        dir: &D,
168        segments: &[SegmentReader],
169        new_segment_id: SegmentId,
170        trained: &TrainedVectorStructures,
171    ) -> Result<(SegmentMeta, MergeStats)> {
172        self.merge_core(dir, segments, new_segment_id, Some(trained))
173            .await
174    }
175
176    /// Core merge: handles all mandatory parts (postings, positions, store, sparse, field stats, meta)
177    /// and delegates dense vector handling based on available trained structures.
178    ///
179    /// Uses streaming writers so postings, positions, and store data flow directly
180    /// to files instead of buffering everything in memory. Only the term dictionary
181    /// (compact key+TermInfo entries) is buffered.
182    async fn merge_core<D: Directory + DirectoryWriter>(
183        &self,
184        dir: &D,
185        segments: &[SegmentReader],
186        new_segment_id: SegmentId,
187        trained: Option<&TrainedVectorStructures>,
188    ) -> Result<(SegmentMeta, MergeStats)> {
189        let mut stats = MergeStats::default();
190        let files = SegmentFiles::new(new_segment_id.0);
191
192        // === Phase 1: merge postings + positions (streaming) ===
193        let mut postings_writer = OffsetWriter::new(dir.streaming_writer(&files.postings).await?);
194        let mut positions_writer = OffsetWriter::new(dir.streaming_writer(&files.positions).await?);
195        let mut term_dict_writer = OffsetWriter::new(dir.streaming_writer(&files.term_dict).await?);
196
197        let terms_processed = self
198            .merge_postings(
199                segments,
200                &mut term_dict_writer,
201                &mut postings_writer,
202                &mut positions_writer,
203                &mut stats,
204            )
205            .await?;
206        stats.terms_processed = terms_processed;
207        stats.postings_bytes = postings_writer.offset() as usize;
208        stats.term_dict_bytes = term_dict_writer.offset() as usize;
209        let positions_bytes = positions_writer.offset();
210
211        postings_writer.finish()?;
212        term_dict_writer.finish()?;
213        if positions_bytes > 0 {
214            positions_writer.finish()?;
215        } else {
216            drop(positions_writer);
217            let _ = dir.delete(&files.positions).await;
218        }
219
220        // === Phase 2: merge store files (streaming) ===
221        {
222            let mut store_writer = OffsetWriter::new(dir.streaming_writer(&files.store).await?);
223            {
224                let mut store_merger = StoreMerger::new(&mut store_writer);
225                for segment in segments {
226                    if segment.store_has_dict() {
227                        store_merger
228                            .append_store_recompressing(segment.store())
229                            .await
230                            .map_err(crate::Error::Io)?;
231                    } else {
232                        let raw_blocks = segment.store_raw_blocks();
233                        let data_slice = segment.store_data_slice();
234                        store_merger.append_store(data_slice, &raw_blocks).await?;
235                    }
236                }
237                store_merger.finish()?;
238            }
239            stats.store_bytes = store_writer.offset() as usize;
240            store_writer.finish()?;
241        }
242
243        // === Dense vectors ===
244        let vectors_bytes = self
245            .merge_dense_vectors(dir, segments, &files, trained)
246            .await?;
247        stats.vectors_bytes = vectors_bytes;
248
249        // === Mandatory: merge sparse vectors ===
250        let sparse_bytes = self.merge_sparse_vectors(dir, segments, &files).await?;
251        stats.sparse_bytes = sparse_bytes;
252
253        // === Mandatory: merge field stats + write meta ===
254        let mut merged_field_stats: FxHashMap<u32, FieldStats> = FxHashMap::default();
255        for segment in segments {
256            for (&field_id, field_stats) in &segment.meta().field_stats {
257                let entry = merged_field_stats.entry(field_id).or_default();
258                entry.total_tokens += field_stats.total_tokens;
259                entry.doc_count += field_stats.doc_count;
260            }
261        }
262
263        let total_docs: u32 = segments.iter().map(|s| s.num_docs()).sum();
264        let meta = SegmentMeta {
265            id: new_segment_id.0,
266            num_docs: total_docs,
267            field_stats: merged_field_stats,
268        };
269
270        dir.write(&files.meta, &meta.serialize()?).await?;
271
272        let label = if trained.is_some() {
273            "ANN merge"
274        } else {
275            "Merge"
276        };
277        log::info!(
278            "{} complete: {} docs, {} terms, term_dict={}, postings={}, store={}, vectors={}, sparse={}",
279            label,
280            total_docs,
281            stats.terms_processed,
282            format_bytes(stats.term_dict_bytes),
283            format_bytes(stats.postings_bytes),
284            format_bytes(stats.store_bytes),
285            format_bytes(stats.vectors_bytes),
286            format_bytes(stats.sparse_bytes),
287        );
288
289        Ok((meta, stats))
290    }
291
292    /// Merge postings from multiple segments using streaming k-way merge
293    ///
294    /// This implementation uses a min-heap to merge terms from all segments
295    /// in sorted order without loading all terms into memory at once.
296    /// Memory usage is O(num_segments) instead of O(total_terms).
297    ///
298    /// Optimization: For terms that exist in only one segment, we copy the
299    /// posting data directly without decode/encode. Only terms that exist
300    /// in multiple segments need full merge.
301    ///
302    /// Returns the number of terms processed.
303    async fn merge_postings(
304        &self,
305        segments: &[SegmentReader],
306        term_dict: &mut OffsetWriter,
307        postings_out: &mut OffsetWriter,
308        positions_out: &mut OffsetWriter,
309        stats: &mut MergeStats,
310    ) -> Result<usize> {
311        let doc_offs = doc_offsets(segments);
312
313        // Bulk-prefetch all term dict blocks (1 I/O per segment instead of ~160)
314        for (i, segment) in segments.iter().enumerate() {
315            log::debug!("Prefetching term dict for segment {} ...", i);
316            segment.prefetch_term_dict().await?;
317        }
318
319        // Create iterators for each segment's term dictionary
320        let mut iterators: Vec<_> = segments.iter().map(|s| s.term_dict_iter()).collect();
321
322        // Initialize min-heap with first entry from each segment
323        let mut heap: BinaryHeap<MergeEntry> = BinaryHeap::new();
324        for (seg_idx, iter) in iterators.iter_mut().enumerate() {
325            if let Some((key, term_info)) = iter.next().await.map_err(crate::Error::from)? {
326                heap.push(MergeEntry {
327                    key,
328                    term_info,
329                    segment_idx: seg_idx,
330                    doc_offset: doc_offs[seg_idx],
331                });
332            }
333        }
334
335        // Buffer term results - needed because SSTableWriter can't be held across await points
336        // Memory is bounded by unique terms (typically much smaller than postings)
337        let mut term_results: Vec<(Vec<u8>, TermInfo)> = Vec::new();
338        let mut terms_processed = 0usize;
339        let mut serialize_buf: Vec<u8> = Vec::new();
340
341        while !heap.is_empty() {
342            // Get the smallest key
343            let first = heap.pop().unwrap();
344            let current_key = first.key.clone();
345
346            // Collect all entries with the same key
347            let mut sources: Vec<(usize, TermInfo, u32)> =
348                vec![(first.segment_idx, first.term_info, first.doc_offset)];
349
350            // Advance the iterator that provided this entry
351            if let Some((key, term_info)) = iterators[first.segment_idx]
352                .next()
353                .await
354                .map_err(crate::Error::from)?
355            {
356                heap.push(MergeEntry {
357                    key,
358                    term_info,
359                    segment_idx: first.segment_idx,
360                    doc_offset: doc_offs[first.segment_idx],
361                });
362            }
363
364            // Check if other segments have the same key
365            while let Some(entry) = heap.peek() {
366                if entry.key != current_key {
367                    break;
368                }
369                let entry = heap.pop().unwrap();
370                sources.push((entry.segment_idx, entry.term_info, entry.doc_offset));
371
372                // Advance this iterator too
373                if let Some((key, term_info)) = iterators[entry.segment_idx]
374                    .next()
375                    .await
376                    .map_err(crate::Error::from)?
377                {
378                    heap.push(MergeEntry {
379                        key,
380                        term_info,
381                        segment_idx: entry.segment_idx,
382                        doc_offset: doc_offs[entry.segment_idx],
383                    });
384                }
385            }
386
387            // Process this term (handles both single-source and multi-source)
388            let term_info = self
389                .merge_term(
390                    segments,
391                    &sources,
392                    postings_out,
393                    positions_out,
394                    &mut serialize_buf,
395                )
396                .await?;
397
398            term_results.push((current_key, term_info));
399            terms_processed += 1;
400
401            // Log progress every 100k terms
402            if terms_processed.is_multiple_of(100_000) {
403                log::debug!("Merge progress: {} terms processed", terms_processed);
404            }
405        }
406
407        // Track memory (only term_results is buffered; postings/positions stream to disk)
408        let results_mem = term_results.capacity() * std::mem::size_of::<(Vec<u8>, TermInfo)>();
409        stats.peak_memory_bytes = stats.peak_memory_bytes.max(results_mem);
410
411        log::info!(
412            "[merge] complete: terms={}, segments={}, term_buffer={:.2} MB, postings={}, positions={}",
413            terms_processed,
414            segments.len(),
415            results_mem as f64 / (1024.0 * 1024.0),
416            format_bytes(postings_out.offset() as usize),
417            format_bytes(positions_out.offset() as usize),
418        );
419
420        // Write to SSTable (sync, no await points)
421        let mut writer = SSTableWriter::<TermInfo>::new(term_dict);
422        for (key, term_info) in term_results {
423            writer.insert(&key, &term_info)?;
424        }
425        writer.finish()?;
426
427        Ok(terms_processed)
428    }
429
430    /// Merge a single term's postings + positions from one or more source segments.
431    ///
432    /// Fast path: when all sources are external and there are multiple,
433    /// uses block-level concatenation (O(blocks) instead of O(postings)).
434    /// Otherwise: full decode → remap doc IDs → re-encode.
435    async fn merge_term(
436        &self,
437        segments: &[SegmentReader],
438        sources: &[(usize, TermInfo, u32)],
439        postings_out: &mut OffsetWriter,
440        positions_out: &mut OffsetWriter,
441        buf: &mut Vec<u8>,
442    ) -> Result<TermInfo> {
443        let mut sorted: Vec<_> = sources.to_vec();
444        sorted.sort_by_key(|(_, _, off)| *off);
445
446        let any_positions = sorted.iter().any(|(_, ti, _)| ti.position_info().is_some());
447        let all_external = sorted.iter().all(|(_, ti, _)| ti.external_info().is_some());
448
449        // === Merge postings ===
450        let (posting_offset, posting_len, doc_count) = if all_external && sorted.len() > 1 {
451            // Fast path: block-level concatenation
452            let mut block_sources = Vec::with_capacity(sorted.len());
453            for (seg_idx, ti, doc_off) in &sorted {
454                let (off, len) = ti.external_info().unwrap();
455                let bytes = segments[*seg_idx].read_postings(off, len).await?;
456                let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
457                block_sources.push((bpl, *doc_off));
458            }
459            let merged = BlockPostingList::concatenate_blocks(&block_sources)?;
460            let offset = postings_out.offset();
461            buf.clear();
462            merged.serialize(buf)?;
463            postings_out.write_all(buf)?;
464            (offset, buf.len() as u32, merged.doc_count())
465        } else {
466            // Decode all sources into a flat PostingList, remap doc IDs
467            let mut merged = PostingList::new();
468            for (seg_idx, ti, doc_off) in &sorted {
469                if let Some((ids, tfs)) = ti.decode_inline() {
470                    for (id, tf) in ids.into_iter().zip(tfs) {
471                        merged.add(id + doc_off, tf);
472                    }
473                } else {
474                    let (off, len) = ti.external_info().unwrap();
475                    let bytes = segments[*seg_idx].read_postings(off, len).await?;
476                    let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
477                    let mut it = bpl.iterator();
478                    while it.doc() != TERMINATED {
479                        merged.add(it.doc() + doc_off, it.term_freq());
480                        it.advance();
481                    }
482                }
483            }
484            // Try to inline (only when no positions)
485            if !any_positions {
486                let ids: Vec<u32> = merged.iter().map(|p| p.doc_id).collect();
487                let tfs: Vec<u32> = merged.iter().map(|p| p.term_freq).collect();
488                if let Some(inline) = TermInfo::try_inline(&ids, &tfs) {
489                    return Ok(inline);
490                }
491            }
492            let offset = postings_out.offset();
493            let block = BlockPostingList::from_posting_list(&merged)?;
494            buf.clear();
495            block.serialize(buf)?;
496            postings_out.write_all(buf)?;
497            (offset, buf.len() as u32, merged.doc_count())
498        };
499
500        // === Merge positions (if any source has them) ===
501        if any_positions {
502            let mut pos_sources = Vec::new();
503            for (seg_idx, ti, doc_off) in &sorted {
504                if let Some((pos_off, pos_len)) = ti.position_info()
505                    && let Some(bytes) = segments[*seg_idx]
506                        .read_position_bytes(pos_off, pos_len)
507                        .await?
508                {
509                    let pl = PositionPostingList::deserialize(&mut bytes.as_slice())
510                        .map_err(crate::Error::Io)?;
511                    pos_sources.push((pl, *doc_off));
512                }
513            }
514            if !pos_sources.is_empty() {
515                let merged = PositionPostingList::concatenate_blocks(&pos_sources)
516                    .map_err(crate::Error::Io)?;
517                let offset = positions_out.offset();
518                buf.clear();
519                merged.serialize(buf).map_err(crate::Error::Io)?;
520                positions_out.write_all(buf)?;
521                return Ok(TermInfo::external_with_positions(
522                    posting_offset,
523                    posting_len,
524                    doc_count,
525                    offset,
526                    buf.len() as u32,
527                ));
528            }
529        }
530
531        Ok(TermInfo::external(posting_offset, posting_len, doc_count))
532    }
533}
534
535/// Delete segment files from directory
536pub async fn delete_segment<D: Directory + DirectoryWriter>(
537    dir: &D,
538    segment_id: SegmentId,
539) -> Result<()> {
540    let files = SegmentFiles::new(segment_id.0);
541    let _ = dir.delete(&files.term_dict).await;
542    let _ = dir.delete(&files.postings).await;
543    let _ = dir.delete(&files.store).await;
544    let _ = dir.delete(&files.meta).await;
545    let _ = dir.delete(&files.vectors).await;
546    let _ = dir.delete(&files.sparse).await;
547    let _ = dir.delete(&files.positions).await;
548    Ok(())
549}