// hermes_core/segment/merger/mod.rs
//! Segment merger for combining multiple segments

mod dense_vectors;
mod sparse_vectors;

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::io::Write;
use std::sync::Arc;

use rustc_hash::FxHashMap;

use super::reader::SegmentReader;
use super::store::StoreMerger;
use super::types::{FieldStats, SegmentFiles, SegmentId, SegmentMeta};
use crate::Result;
use crate::directories::{Directory, DirectoryWriter, StreamingWriter};
use crate::dsl::Schema;
use crate::structures::{
    BlockPostingList, PositionPostingList, PostingList, SSTableWriter, TERMINATED, TermInfo,
};
22
/// Write adapter that tracks bytes written.
///
/// Concrete type so it works with generic `serialize<W: Write>` functions
/// (unlike `dyn StreamingWriter` which isn't `Sized`).
pub(crate) struct OffsetWriter {
    /// Destination writer that receives every byte.
    inner: Box<dyn StreamingWriter>,
    /// Running total of bytes written; doubles as the current file offset.
    offset: u64,
}
31
32impl OffsetWriter {
33    fn new(inner: Box<dyn StreamingWriter>) -> Self {
34        Self { inner, offset: 0 }
35    }
36
37    /// Current write position (total bytes written so far).
38    fn offset(&self) -> u64 {
39        self.offset
40    }
41
42    /// Finalize the underlying streaming writer.
43    fn finish(self) -> std::io::Result<()> {
44        self.inner.finish()
45    }
46}
47
48impl Write for OffsetWriter {
49    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
50        let n = self.inner.write(buf)?;
51        self.offset += n as u64;
52        Ok(n)
53    }
54
55    fn flush(&mut self) -> std::io::Result<()> {
56        self.inner.flush()
57    }
58}
59
/// Format byte count as human-readable string
///
/// Uses binary thresholds (1024-based) and two decimal places for KB and up;
/// values below 1 KB are printed as plain bytes.
fn format_bytes(bytes: usize) -> String {
    const KB: usize = 1024;
    const MB: usize = KB * 1024;
    const GB: usize = MB * 1024;
    match bytes {
        b if b >= GB => format!("{:.2} GB", b as f64 / GB as f64),
        b if b >= MB => format!("{:.2} MB", b as f64 / MB as f64),
        b if b >= KB => format!("{:.2} KB", b as f64 / KB as f64),
        b => format!("{} B", b),
    }
}
72
73/// Compute per-segment doc ID offsets (each segment's docs start after the previous)
74fn doc_offsets(segments: &[SegmentReader]) -> Vec<u32> {
75    let mut offsets = Vec::with_capacity(segments.len());
76    let mut acc = 0u32;
77    for seg in segments {
78        offsets.push(acc);
79        acc += seg.num_docs();
80    }
81    offsets
82}
83
/// Statistics for merge operations
///
/// Byte counts are measured from the output writers, so they reflect the
/// sizes of the newly written merged files.
#[derive(Debug, Clone, Default)]
pub struct MergeStats {
    /// Number of terms processed
    pub terms_processed: usize,
    /// Term dictionary output size
    pub term_dict_bytes: usize,
    /// Postings output size
    pub postings_bytes: usize,
    /// Store output size
    pub store_bytes: usize,
    /// Vector index output size
    pub vectors_bytes: usize,
    /// Sparse vector index output size
    pub sparse_bytes: usize,
}
100
/// Entry for k-way merge heap
struct MergeEntry {
    /// Term key bytes — the sole ordering criterion for the heap.
    key: Vec<u8>,
    /// Term info read from the source segment's dictionary.
    term_info: TermInfo,
    /// Index of the source segment in the `segments` slice.
    segment_idx: usize,
    /// Doc ID offset to apply when remapping this segment's postings.
    doc_offset: u32,
}
108
109impl PartialEq for MergeEntry {
110    fn eq(&self, other: &Self) -> bool {
111        self.key == other.key
112    }
113}
114
115impl Eq for MergeEntry {}
116
117impl PartialOrd for MergeEntry {
118    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
119        Some(self.cmp(other))
120    }
121}
122
123impl Ord for MergeEntry {
124    fn cmp(&self, other: &Self) -> Ordering {
125        // Reverse order for min-heap (BinaryHeap is max-heap by default)
126        other.key.cmp(&self.key)
127    }
128}
129
// TrainedVectorStructures is defined in super::types (available on all platforms)
pub use super::types::TrainedVectorStructures;
132
/// Segment merger - merges multiple segments into one
pub struct SegmentMerger {
    /// Schema shared by the segments being merged (stored by `new`;
    /// presumably consumed by the vector merge submodules — confirm there).
    schema: Arc<Schema>,
}
137
impl SegmentMerger {
    /// Create a merger for segments that share `schema`.
    pub fn new(schema: Arc<Schema>) -> Self {
        Self { schema }
    }

    /// Merge segments into one, streaming postings/positions/store directly to files.
    ///
    /// If `trained` is provided, dense vectors use O(1) cluster merge when possible
    /// (homogeneous IVF/ScaNN), otherwise rebuilds ANN from trained structures.
    /// Without trained structures, only flat vectors are merged.
    ///
    /// Uses streaming writers so postings, positions, and store data flow directly
    /// to files instead of buffering everything in memory. Only the term dictionary
    /// (compact key+TermInfo entries) is buffered.
    pub async fn merge<D: Directory + DirectoryWriter>(
        &self,
        dir: &D,
        segments: &[SegmentReader],
        new_segment_id: SegmentId,
        trained: Option<&TrainedVectorStructures>,
    ) -> Result<(SegmentMeta, MergeStats)> {
        let mut stats = MergeStats::default();
        let files = SegmentFiles::new(new_segment_id.0);

        // === Phase 1: merge postings + positions (streaming) ===
        let mut postings_writer = OffsetWriter::new(dir.streaming_writer(&files.postings).await?);
        let mut positions_writer = OffsetWriter::new(dir.streaming_writer(&files.positions).await?);
        let mut term_dict_writer = OffsetWriter::new(dir.streaming_writer(&files.term_dict).await?);

        // merge_postings writes the term dictionary inline as it merges,
        // so all three writers are filled by this single pass.
        let terms_processed = self
            .merge_postings(
                segments,
                &mut term_dict_writer,
                &mut postings_writer,
                &mut positions_writer,
                &mut stats,
            )
            .await?;
        stats.terms_processed = terms_processed;
        stats.postings_bytes = postings_writer.offset() as usize;
        stats.term_dict_bytes = term_dict_writer.offset() as usize;
        let positions_bytes = positions_writer.offset();

        postings_writer.finish()?;
        term_dict_writer.finish()?;
        if positions_bytes > 0 {
            positions_writer.finish()?;
        } else {
            // No positions were written: drop the writer unfinalized and
            // remove the (empty) positions file best-effort.
            drop(positions_writer);
            let _ = dir.delete(&files.positions).await;
        }

        // === Phase 2: merge store files (streaming) ===
        // Segments are appended in the same order used for doc ID offsets in
        // merge_postings, keeping stored docs aligned with remapped doc IDs.
        let phase2_start = std::time::Instant::now();
        {
            let mut store_writer = OffsetWriter::new(dir.streaming_writer(&files.store).await?);
            {
                let mut store_merger = StoreMerger::new(&mut store_writer);
                for segment in segments {
                    if segment.store_has_dict() {
                        // Dictionary-compressed store must be recompressed on append.
                        store_merger
                            .append_store_recompressing(segment.store())
                            .await
                            .map_err(crate::Error::Io)?;
                    } else {
                        // Raw blocks can be appended without recompression.
                        let raw_blocks = segment.store_raw_blocks();
                        let data_slice = segment.store_data_slice();
                        store_merger.append_store(data_slice, &raw_blocks).await?;
                    }
                }
                store_merger.finish()?;
            }
            stats.store_bytes = store_writer.offset() as usize;
            store_writer.finish()?;
        }
        log::info!(
            "[merge] store done: {} in {:.1}s",
            format_bytes(stats.store_bytes),
            phase2_start.elapsed().as_secs_f64()
        );

        // === Phase 3: Dense vectors ===
        let phase3_start = std::time::Instant::now();
        let vectors_bytes = self
            .merge_dense_vectors(dir, segments, &files, trained)
            .await?;
        stats.vectors_bytes = vectors_bytes;
        log::info!(
            "[merge] dense vectors done: {} in {:.1}s",
            format_bytes(stats.vectors_bytes),
            phase3_start.elapsed().as_secs_f64()
        );

        // === Phase 4: merge sparse vectors ===
        let phase4_start = std::time::Instant::now();
        let sparse_bytes = self.merge_sparse_vectors(dir, segments, &files).await?;
        stats.sparse_bytes = sparse_bytes;
        log::info!(
            "[merge] sparse vectors done: {} in {:.1}s",
            format_bytes(stats.sparse_bytes),
            phase4_start.elapsed().as_secs_f64()
        );

        // === Mandatory: merge field stats + write meta ===
        // Field stats are additive across segments (token and doc counts).
        let mut merged_field_stats: FxHashMap<u32, FieldStats> = FxHashMap::default();
        for segment in segments {
            for (&field_id, field_stats) in &segment.meta().field_stats {
                let entry = merged_field_stats.entry(field_id).or_default();
                entry.total_tokens += field_stats.total_tokens;
                entry.doc_count += field_stats.doc_count;
            }
        }

        let total_docs: u32 = segments.iter().map(|s| s.num_docs()).sum();
        let meta = SegmentMeta {
            id: new_segment_id.0,
            num_docs: total_docs,
            field_stats: merged_field_stats,
        };

        dir.write(&files.meta, &meta.serialize()?).await?;

        let label = if trained.is_some() {
            "ANN merge"
        } else {
            "Merge"
        };
        log::info!(
            "{} complete: {} docs, {} terms, term_dict={}, postings={}, store={}, vectors={}, sparse={}",
            label,
            total_docs,
            stats.terms_processed,
            format_bytes(stats.term_dict_bytes),
            format_bytes(stats.postings_bytes),
            format_bytes(stats.store_bytes),
            format_bytes(stats.vectors_bytes),
            format_bytes(stats.sparse_bytes),
        );

        Ok((meta, stats))
    }

    /// Merge postings from multiple segments using streaming k-way merge
    ///
    /// This implementation uses a min-heap to merge terms from all segments
    /// in sorted order without loading all terms into memory at once.
    ///
    /// Optimization: For terms that exist in only one segment, we copy the
    /// posting data directly without decode/encode. Only terms that exist
    /// in multiple segments need full merge.
    ///
    /// SSTable entries are written inline during the merge loop (no buffering).
    /// This is possible because SSTableWriter<W> is Send when W is Send.
    ///
    /// Returns the number of terms processed.
    async fn merge_postings(
        &self,
        segments: &[SegmentReader],
        term_dict: &mut OffsetWriter,
        postings_out: &mut OffsetWriter,
        positions_out: &mut OffsetWriter,
        _stats: &mut MergeStats,
    ) -> Result<usize> {
        let doc_offs = doc_offsets(segments);

        // Parallel prefetch all term dict blocks
        let prefetch_start = std::time::Instant::now();
        let mut futs = Vec::with_capacity(segments.len());
        for segment in segments.iter() {
            futs.push(segment.prefetch_term_dict());
        }
        let results = futures::future::join_all(futs).await;
        for (i, res) in results.into_iter().enumerate() {
            // Any prefetch failure aborts the merge (logged with segment index).
            res.map_err(|e| {
                log::error!("Prefetch failed for segment {}: {}", i, e);
                e
            })?;
        }
        log::debug!(
            "Prefetched {} term dicts in {:.1}s",
            segments.len(),
            prefetch_start.elapsed().as_secs_f64()
        );

        // Create iterators for each segment's term dictionary
        let mut iterators: Vec<_> = segments.iter().map(|s| s.term_dict_iter()).collect();

        // Initialize min-heap with first entry from each segment
        let mut heap: BinaryHeap<MergeEntry> = BinaryHeap::new();
        for (seg_idx, iter) in iterators.iter_mut().enumerate() {
            if let Some((key, term_info)) = iter.next().await.map_err(crate::Error::from)? {
                heap.push(MergeEntry {
                    key,
                    term_info,
                    segment_idx: seg_idx,
                    doc_offset: doc_offs[seg_idx],
                });
            }
        }

        // Write SSTable entries inline — no buffering needed since
        // SSTableWriter<&mut OffsetWriter> is Send (OffsetWriter is Send).
        let mut term_dict_writer = SSTableWriter::<&mut OffsetWriter, TermInfo>::new(term_dict);
        let mut terms_processed = 0usize;
        // Scratch buffer reused across terms to avoid per-term allocations.
        let mut serialize_buf: Vec<u8> = Vec::new();

        while !heap.is_empty() {
            // Get the smallest key
            let first = heap.pop().unwrap();
            let current_key = first.key.clone();

            // Collect all entries with the same key
            let mut sources: Vec<(usize, TermInfo, u32)> =
                vec![(first.segment_idx, first.term_info, first.doc_offset)];

            // Advance the iterator that provided this entry
            if let Some((key, term_info)) = iterators[first.segment_idx]
                .next()
                .await
                .map_err(crate::Error::from)?
            {
                heap.push(MergeEntry {
                    key,
                    term_info,
                    segment_idx: first.segment_idx,
                    doc_offset: doc_offs[first.segment_idx],
                });
            }

            // Check if other segments have the same key
            while let Some(entry) = heap.peek() {
                if entry.key != current_key {
                    break;
                }
                let entry = heap.pop().unwrap();
                sources.push((entry.segment_idx, entry.term_info, entry.doc_offset));

                // Advance this iterator too
                if let Some((key, term_info)) = iterators[entry.segment_idx]
                    .next()
                    .await
                    .map_err(crate::Error::from)?
                {
                    heap.push(MergeEntry {
                        key,
                        term_info,
                        segment_idx: entry.segment_idx,
                        doc_offset: doc_offs[entry.segment_idx],
                    });
                }
            }

            // Process this term (handles both single-source and multi-source)
            let term_info = self
                .merge_term(
                    segments,
                    &sources,
                    postings_out,
                    positions_out,
                    &mut serialize_buf,
                )
                .await?;

            // Write directly to SSTable (no buffering)
            term_dict_writer
                .insert(&current_key, &term_info)
                .map_err(crate::Error::Io)?;
            terms_processed += 1;

            // Log progress every 100k terms
            if terms_processed.is_multiple_of(100_000) {
                log::debug!("Merge progress: {} terms processed", terms_processed);
            }
        }

        log::info!(
            "[merge] postings done: terms={}, segments={}, postings={}, positions={}",
            terms_processed,
            segments.len(),
            format_bytes(postings_out.offset() as usize),
            format_bytes(positions_out.offset() as usize),
        );

        term_dict_writer.finish().map_err(crate::Error::Io)?;

        Ok(terms_processed)
    }

    /// Merge a single term's postings + positions from one or more source segments.
    ///
    /// Fast path: when all sources are external and there are multiple,
    /// uses block-level concatenation (O(blocks) instead of O(postings)).
    /// Otherwise: full decode → remap doc IDs → re-encode.
    async fn merge_term(
        &self,
        segments: &[SegmentReader],
        sources: &[(usize, TermInfo, u32)],
        postings_out: &mut OffsetWriter,
        positions_out: &mut OffsetWriter,
        buf: &mut Vec<u8>,
    ) -> Result<TermInfo> {
        // Sort by doc offset so concatenated/remapped doc IDs stay ascending.
        let mut sorted: Vec<_> = sources.to_vec();
        sorted.sort_by_key(|(_, _, off)| *off);

        let any_positions = sorted.iter().any(|(_, ti, _)| ti.position_info().is_some());
        let all_external = sorted.iter().all(|(_, ti, _)| ti.external_info().is_some());

        // === Merge postings ===
        let (posting_offset, posting_len, doc_count) = if all_external && sorted.len() > 1 {
            // Fast path: block-level concatenation
            let mut block_sources = Vec::with_capacity(sorted.len());
            for (seg_idx, ti, doc_off) in &sorted {
                let (off, len) = ti.external_info().unwrap();
                let bytes = segments[*seg_idx].read_postings(off, len).await?;
                let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
                block_sources.push((bpl, *doc_off));
            }
            let merged = BlockPostingList::concatenate_blocks(&block_sources)?;
            let offset = postings_out.offset();
            buf.clear();
            merged.serialize(buf)?;
            postings_out.write_all(buf)?;
            (offset, buf.len() as u32, merged.doc_count())
        } else {
            // Decode all sources into a flat PostingList, remap doc IDs
            let mut merged = PostingList::new();
            for (seg_idx, ti, doc_off) in &sorted {
                if let Some((ids, tfs)) = ti.decode_inline() {
                    // Inline postings: doc IDs/frequencies packed in the TermInfo itself.
                    for (id, tf) in ids.into_iter().zip(tfs) {
                        merged.add(id + doc_off, tf);
                    }
                } else {
                    let (off, len) = ti.external_info().unwrap();
                    let bytes = segments[*seg_idx].read_postings(off, len).await?;
                    let bpl = BlockPostingList::deserialize(&mut bytes.as_slice())?;
                    let mut it = bpl.iterator();
                    while it.doc() != TERMINATED {
                        merged.add(it.doc() + doc_off, it.term_freq());
                        it.advance();
                    }
                }
            }
            // Try to inline (only when no positions)
            if !any_positions {
                let ids: Vec<u32> = merged.iter().map(|p| p.doc_id).collect();
                let tfs: Vec<u32> = merged.iter().map(|p| p.term_freq).collect();
                if let Some(inline) = TermInfo::try_inline(&ids, &tfs) {
                    return Ok(inline);
                }
            }
            let offset = postings_out.offset();
            let block = BlockPostingList::from_posting_list(&merged)?;
            buf.clear();
            block.serialize(buf)?;
            postings_out.write_all(buf)?;
            (offset, buf.len() as u32, merged.doc_count())
        };

        // === Merge positions (if any source has them) ===
        if any_positions {
            let mut pos_sources = Vec::new();
            for (seg_idx, ti, doc_off) in &sorted {
                if let Some((pos_off, pos_len)) = ti.position_info()
                    && let Some(bytes) = segments[*seg_idx]
                        .read_position_bytes(pos_off, pos_len)
                        .await?
                {
                    let pl = PositionPostingList::deserialize(&mut bytes.as_slice())
                        .map_err(crate::Error::Io)?;
                    pos_sources.push((pl, *doc_off));
                }
            }
            if !pos_sources.is_empty() {
                let merged = PositionPostingList::concatenate_blocks(&pos_sources)
                    .map_err(crate::Error::Io)?;
                let offset = positions_out.offset();
                buf.clear();
                merged.serialize(buf).map_err(crate::Error::Io)?;
                positions_out.write_all(buf)?;
                return Ok(TermInfo::external_with_positions(
                    posting_offset,
                    posting_len,
                    doc_count,
                    offset,
                    buf.len() as u32,
                ));
            }
        }

        Ok(TermInfo::external(posting_offset, posting_len, doc_count))
    }
}
530
531/// Delete segment files from directory
532pub async fn delete_segment<D: Directory + DirectoryWriter>(
533    dir: &D,
534    segment_id: SegmentId,
535) -> Result<()> {
536    let files = SegmentFiles::new(segment_id.0);
537    let _ = dir.delete(&files.term_dict).await;
538    let _ = dir.delete(&files.postings).await;
539    let _ = dir.delete(&files.store).await;
540    let _ = dir.delete(&files.meta).await;
541    let _ = dir.delete(&files.vectors).await;
542    let _ = dir.delete(&files.sparse).await;
543    let _ = dir.delete(&files.positions).await;
544    Ok(())
545}