
hermes_core/segment/merger/mod.rs

//! Segment merger for combining multiple segments

mod dense_vectors;
mod sparse_vectors;

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::io::Write;
use std::sync::Arc;

use rustc_hash::FxHashMap;

use super::reader::SegmentReader;
use super::store::StoreMerger;
use super::types::{FieldStats, SegmentFiles, SegmentId, SegmentMeta};
use crate::Result;
use crate::directories::{Directory, DirectoryWriter, StreamingWriter};
use crate::dsl::Schema;
use crate::structures::{
    BlockPostingList, PositionPostingList, PostingList, SSTableWriter, TERMINATED, TermInfo,
};

/// Write adapter that tracks bytes written.
///
/// Concrete type so it works with generic `serialize<W: Write>` functions
/// (unlike `dyn StreamingWriter` which isn't `Sized`).
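///
/// A minimal usage sketch (hypothetical; assumes `writer` is a boxed
/// `StreamingWriter` obtained from a `Directory`):
///
/// ```ignore
/// let mut out = OffsetWriter::new(writer);
/// out.write_all(b"abc")?;
/// assert_eq!(out.offset(), 3); // bytes written so far
/// out.finish()?;
/// ```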
pub(crate) struct OffsetWriter {
    inner: Box<dyn StreamingWriter>,
    offset: u64,
}

impl OffsetWriter {
    fn new(inner: Box<dyn StreamingWriter>) -> Self {
        Self { inner, offset: 0 }
    }

    /// Current write position (total bytes written so far).
    fn offset(&self) -> u64 {
        self.offset
    }

    /// Finalize the underlying streaming writer.
    fn finish(self) -> std::io::Result<()> {
        self.inner.finish()
    }
}

impl Write for OffsetWriter {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let n = self.inner.write(buf)?;
        self.offset += n as u64;
        Ok(n)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}

/// Format byte count as human-readable string
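///
/// A sketch of the expected output (binary units, 1 KB = 1024 B):
///
/// ```ignore
/// assert_eq!(format_bytes(512), "512 B");
/// assert_eq!(format_bytes(2048), "2.00 KB");
/// assert_eq!(format_bytes(3 * 1024 * 1024), "3.00 MB");
/// ```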
fn format_bytes(bytes: usize) -> String {
    if bytes >= 1024 * 1024 * 1024 {
        format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
    } else if bytes >= 1024 * 1024 {
        format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0))
    } else if bytes >= 1024 {
        format!("{:.2} KB", bytes as f64 / 1024.0)
    } else {
        format!("{} B", bytes)
    }
}

/// Compute per-segment doc ID offsets (each segment's docs start after the previous)
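///
/// For example, segments with 10, 5, and 7 docs yield offsets `[0, 10, 15]`.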
fn doc_offsets(segments: &[SegmentReader]) -> Vec<u32> {
    let mut offsets = Vec::with_capacity(segments.len());
    let mut acc = 0u32;
    for seg in segments {
        offsets.push(acc);
        acc += seg.num_docs();
    }
    offsets
}

/// Statistics for merge operations
#[derive(Debug, Clone, Default)]
pub struct MergeStats {
    /// Number of terms processed
    pub terms_processed: usize,
    /// Term dictionary output size
    pub term_dict_bytes: usize,
    /// Postings output size
    pub postings_bytes: usize,
    /// Store output size
    pub store_bytes: usize,
    /// Vector index output size
    pub vectors_bytes: usize,
    /// Sparse vector index output size
    pub sparse_bytes: usize,
}

/// Entry for k-way merge heap
struct MergeEntry {
    key: Vec<u8>,
    term_info: TermInfo,
    segment_idx: usize,
    doc_offset: u32,
}

impl PartialEq for MergeEntry {
    fn eq(&self, other: &Self) -> bool {
        self.key == other.key
    }
}

impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse order for min-heap (BinaryHeap is max-heap by default)
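        // Ties (equal keys from different segments) compare as equal; the
        // merge loop collects every equal-key entry before processing a term,
        // so pop order among ties does not affect the result.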
        other.key.cmp(&self.key)
    }
}

// TrainedVectorStructures is defined in super::types (available on all platforms)
pub use super::types::TrainedVectorStructures;

/// Segment merger - merges multiple segments into one
pub struct SegmentMerger {
    schema: Arc<Schema>,
}

impl SegmentMerger {
    pub fn new(schema: Arc<Schema>) -> Self {
        Self { schema }
    }

    /// Merge segments into one, streaming postings/positions/store directly to files.
    ///
    /// If `trained` is provided, dense vectors use an O(1) cluster merge when possible
    /// (homogeneous IVF/ScaNN); otherwise the ANN index is rebuilt from the trained structures.
    /// Without trained structures, only flat vectors are merged.
    ///
    /// Uses streaming writers so postings, positions, and store data flow directly
    /// to files instead of buffering everything in memory. Only the term dictionary
    /// (compact key+TermInfo entries) is buffered.
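    ///
    /// A hypothetical call site (names here are illustrative, not from this module):
    ///
    /// ```ignore
    /// let merger = SegmentMerger::new(schema.clone());
    /// let (meta, stats) = merger
    ///     .merge(&dir, &segment_readers, SegmentId(7), None)
    ///     .await?;
    /// log::info!("merged {} docs, {} terms", meta.num_docs, stats.terms_processed);
    /// ```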
    pub async fn merge<D: Directory + DirectoryWriter>(
        &self,
        dir: &D,
        segments: &[SegmentReader],
        new_segment_id: SegmentId,
        trained: Option<&TrainedVectorStructures>,
    ) -> Result<(SegmentMeta, MergeStats)> {
        let mut stats = MergeStats::default();
        let files = SegmentFiles::new(new_segment_id.0);

        // === All 4 phases concurrent: postings || store || dense || sparse ===
        // Each phase reads from independent parts of source segments (SSTable,
        // store blocks, flat vectors, sparse index) and writes to independent
        // output files. No phase consumes another's output.
        // tokio::try_join! interleaves I/O across the four futures on the same task.
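        // Note: try_join! resolves to the first error it sees; the other phase
        // futures are dropped at that point instead of running to completion.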
        let merge_start = std::time::Instant::now();

        let postings_fut = async {
            let mut postings_writer =
                OffsetWriter::new(dir.streaming_writer(&files.postings).await?);
            let mut positions_writer =
                OffsetWriter::new(dir.streaming_writer(&files.positions).await?);
            let mut term_dict_writer =
                OffsetWriter::new(dir.streaming_writer(&files.term_dict).await?);

            let terms_processed = self
                .merge_postings(
                    segments,
                    &mut term_dict_writer,
                    &mut postings_writer,
                    &mut positions_writer,
                )
                .await?;

            let postings_bytes = postings_writer.offset() as usize;
            let term_dict_bytes = term_dict_writer.offset() as usize;
            let positions_bytes = positions_writer.offset();

            postings_writer.finish()?;
            term_dict_writer.finish()?;
            if positions_bytes > 0 {
                positions_writer.finish()?;
            } else {
                drop(positions_writer);
                let _ = dir.delete(&files.positions).await;
            }
            log::info!(
                "[merge] postings done: {} terms, term_dict={}, postings={}, positions={}",
                terms_processed,
                format_bytes(term_dict_bytes),
                format_bytes(postings_bytes),
                format_bytes(positions_bytes as usize),
            );
            Ok::<(usize, usize, usize), crate::Error>((
                terms_processed,
                term_dict_bytes,
                postings_bytes,
            ))
        };

        let store_fut = async {
            let mut store_writer = OffsetWriter::new(dir.streaming_writer(&files.store).await?);
            {
                let mut store_merger = StoreMerger::new(&mut store_writer);
                for segment in segments {
                    if segment.store_has_dict() {
                        store_merger
                            .append_store_recompressing(segment.store())
                            .await
                            .map_err(crate::Error::Io)?;
                    } else {
                        let raw_blocks = segment.store_raw_blocks();
                        let data_slice = segment.store_data_slice();
                        store_merger.append_store(data_slice, &raw_blocks).await?;
                    }
                }
                store_merger.finish()?;
            }
            let bytes = store_writer.offset() as usize;
            store_writer.finish()?;
            Ok::<usize, crate::Error>(bytes)
        };

        let dense_fut = async {
            self.merge_dense_vectors(dir, segments, &files, trained)
                .await
        };

        let sparse_fut = async { self.merge_sparse_vectors(dir, segments, &files).await };

        let (postings_result, store_bytes, vectors_bytes, sparse_bytes) =
            tokio::try_join!(postings_fut, store_fut, dense_fut, sparse_fut)?;
        stats.terms_processed = postings_result.0;
        stats.term_dict_bytes = postings_result.1;
        stats.postings_bytes = postings_result.2;
        stats.store_bytes = store_bytes;
        stats.vectors_bytes = vectors_bytes;
        stats.sparse_bytes = sparse_bytes;
        log::info!(
            "[merge] all phases done: store={}, dense={}, sparse={} in {:.1}s",
            format_bytes(stats.store_bytes),
            format_bytes(stats.vectors_bytes),
            format_bytes(stats.sparse_bytes),
            merge_start.elapsed().as_secs_f64()
        );

        // === Mandatory: merge field stats + write meta ===
        let mut merged_field_stats: FxHashMap<u32, FieldStats> = FxHashMap::default();
        for segment in segments {
            for (&field_id, field_stats) in &segment.meta().field_stats {
                let entry = merged_field_stats.entry(field_id).or_default();
                entry.total_tokens += field_stats.total_tokens;
                entry.doc_count += field_stats.doc_count;
            }
        }

        let total_docs: u32 = segments.iter().map(|s| s.num_docs()).sum();
        let meta = SegmentMeta {
            id: new_segment_id.0,
            num_docs: total_docs,
            field_stats: merged_field_stats,
        };

        dir.write(&files.meta, &meta.serialize()?).await?;

        let label = if trained.is_some() {
            "ANN merge"
        } else {
            "Merge"
        };
        log::info!(
            "{} complete: {} docs, {} terms, term_dict={}, postings={}, store={}, vectors={}, sparse={}",
            label,
            total_docs,
            stats.terms_processed,
            format_bytes(stats.term_dict_bytes),
            format_bytes(stats.postings_bytes),
            format_bytes(stats.store_bytes),
            format_bytes(stats.vectors_bytes),
            format_bytes(stats.sparse_bytes),
        );

        Ok((meta, stats))
    }

    /// Merge postings from multiple segments using streaming k-way merge
    ///
    /// This implementation uses a min-heap to merge terms from all segments
    /// in sorted order without loading all terms into memory at once.
    ///
    /// Optimization: when every source for a term stores its postings externally
    /// and the term appears in more than one segment, the posting blocks are
    /// concatenated at block level instead of being decoded and re-encoded
    /// (see `merge_term`); all other terms take the full decode/remap/re-encode path.
    ///
    /// SSTable entries are written inline during the merge loop (no buffering).
    /// This is possible because SSTableWriter<W> is Send when W is Send.
    ///
    /// Returns the number of terms processed.
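    ///
    /// A minimal sketch of the k-way pattern on plain integer runs (the real
    /// code uses term-dictionary iterators and `MergeEntry`'s reversed `Ord`):
    ///
    /// ```ignore
    /// use std::cmp::Reverse;
    /// use std::collections::BinaryHeap;
    ///
    /// let runs = [vec![1, 4, 7], vec![2, 3], vec![5, 6]];
    /// let mut iters: Vec<_> = runs.iter().map(|r| r.iter().copied()).collect();
    /// let mut heap = BinaryHeap::new();
    /// for (i, it) in iters.iter_mut().enumerate() {
    ///     if let Some(v) = it.next() {
    ///         heap.push(Reverse((v, i))); // Reverse turns the max-heap into a min-heap
    ///     }
    /// }
    /// let mut merged = Vec::new();
    /// while let Some(Reverse((v, i))) = heap.pop() {
    ///     merged.push(v);
    ///     if let Some(next) = iters[i].next() {
    ///         heap.push(Reverse((next, i)));
    ///     }
    /// }
    /// assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7]);
    /// ```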
    async fn merge_postings(
        &self,
        segments: &[SegmentReader],
        term_dict: &mut OffsetWriter,
        postings_out: &mut OffsetWriter,
        positions_out: &mut OffsetWriter,
    ) -> Result<usize> {
        let doc_offs = doc_offsets(segments);

        // Parallel prefetch all term dict blocks
        let prefetch_start = std::time::Instant::now();
        let mut futs = Vec::with_capacity(segments.len());
        for segment in segments.iter() {
            futs.push(segment.prefetch_term_dict());
        }
        let results = futures::future::join_all(futs).await;
        for (i, res) in results.into_iter().enumerate() {
            res.map_err(|e| {
                log::error!("Prefetch failed for segment {}: {}", i, e);
                e
            })?;
        }
        log::debug!(
            "Prefetched {} term dicts in {:.1}s",
            segments.len(),
            prefetch_start.elapsed().as_secs_f64()
        );

        // Create iterators for each segment's term dictionary
        let mut iterators: Vec<_> = segments.iter().map(|s| s.term_dict_iter()).collect();

        // Initialize min-heap with first entry from each segment
        let mut heap: BinaryHeap<MergeEntry> = BinaryHeap::new();
        for (seg_idx, iter) in iterators.iter_mut().enumerate() {
            if let Some((key, term_info)) = iter.next().await.map_err(crate::Error::from)? {
                heap.push(MergeEntry {
                    key,
                    term_info,
                    segment_idx: seg_idx,
                    doc_offset: doc_offs[seg_idx],
                });
            }
        }

        // Write SSTable entries inline — no buffering needed since
        // SSTableWriter<&mut OffsetWriter> is Send (OffsetWriter is Send).
        let mut term_dict_writer = SSTableWriter::<&mut OffsetWriter, TermInfo>::new(term_dict);
        let mut terms_processed = 0usize;
        let mut serialize_buf: Vec<u8> = Vec::new();
        // Pre-allocate sources buffer outside loop — reused for every term
        let mut sources: Vec<(usize, TermInfo, u32)> = Vec::with_capacity(segments.len());

        while !heap.is_empty() {
            // Get the smallest key (move, not clone)
            let first = heap.pop().unwrap();
            let current_key = first.key;

            // Collect all entries with the same key
            sources.clear();
            sources.push((first.segment_idx, first.term_info, first.doc_offset));

            // Advance the iterator that provided this entry
            if let Some((key, term_info)) = iterators[first.segment_idx]
                .next()
                .await
                .map_err(crate::Error::from)?
            {
                heap.push(MergeEntry {
                    key,
                    term_info,
                    segment_idx: first.segment_idx,
                    doc_offset: doc_offs[first.segment_idx],
                });
            }

            // Check if other segments have the same key
            while let Some(entry) = heap.peek() {
                if entry.key != current_key {
                    break;
                }
                let entry = heap.pop().unwrap();
                sources.push((entry.segment_idx, entry.term_info, entry.doc_offset));

                // Advance this iterator too
                if let Some((key, term_info)) = iterators[entry.segment_idx]
                    .next()
                    .await
                    .map_err(crate::Error::from)?
                {
                    heap.push(MergeEntry {
                        key,
                        term_info,
                        segment_idx: entry.segment_idx,
                        doc_offset: doc_offs[entry.segment_idx],
                    });
                }
            }

            // Process this term (handles both single-source and multi-source)
            let term_info = self
                .merge_term(
                    segments,
                    &sources,
                    postings_out,
                    positions_out,
                    &mut serialize_buf,
                )
                .await?;

            // Write directly to SSTable (no buffering)
            term_dict_writer
                .insert(&current_key, &term_info)
                .map_err(crate::Error::Io)?;
            terms_processed += 1;

            // Log progress every 100k terms
            if terms_processed.is_multiple_of(100_000) {
                log::debug!("Merge progress: {} terms processed", terms_processed);
            }
        }

        term_dict_writer.finish().map_err(crate::Error::Io)?;

        Ok(terms_processed)
    }

    /// Merge a single term's postings + positions from one or more source segments.
    ///
    /// Fast path: when all sources are external and there are multiple,
    /// uses block-level concatenation (O(blocks) instead of O(postings)).
    /// Otherwise: full decode → remap doc IDs → re-encode.
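    ///
    /// Doc IDs are remapped by each source's doc offset: with `doc_off = 100`,
    /// source doc IDs `[0, 3, 9]` become `[100, 103, 109]` in the merged list.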
    async fn merge_term(
        &self,
        segments: &[SegmentReader],
        sources: &[(usize, TermInfo, u32)],
        postings_out: &mut OffsetWriter,
        positions_out: &mut OffsetWriter,
        buf: &mut Vec<u8>,
    ) -> Result<TermInfo> {
        let mut sorted: Vec<_> = sources.to_vec();
        sorted.sort_by_key(|(_, _, off)| *off);

        let any_positions = sorted.iter().any(|(_, ti, _)| ti.position_info().is_some());
        let all_external = sorted.iter().all(|(_, ti, _)| ti.external_info().is_some());

        // === Merge postings ===
        let (posting_offset, posting_len, doc_count) = if all_external && sorted.len() > 1 {
            // Fast path: streaming merge (blocks → output writer, no buffering)
            let mut raw_sources: Vec<(Vec<u8>, u32)> = Vec::with_capacity(sorted.len());
            for (seg_idx, ti, doc_off) in &sorted {
                let (off, len) = ti.external_info().unwrap();
                let bytes = segments[*seg_idx].read_postings(off, len).await?;
                raw_sources.push((bytes, *doc_off));
            }
            let refs: Vec<(&[u8], u32)> = raw_sources
                .iter()
                .map(|(b, off)| (b.as_slice(), *off))
                .collect();
            let offset = postings_out.offset();
            let (doc_count, bytes_written) =
                BlockPostingList::concatenate_streaming(&refs, postings_out)?;
            (offset, bytes_written as u32, doc_count)
        } else {
            // Decode all sources into a flat PostingList, remap doc IDs
            let mut merged = PostingList::new();
            for (seg_idx, ti, doc_off) in &sorted {
                if let Some((ids, tfs)) = ti.decode_inline() {
                    for (id, tf) in ids.into_iter().zip(tfs) {
                        merged.add(id + doc_off, tf);
                    }
                } else {
                    let (off, len) = ti.external_info().unwrap();
                    let bytes = segments[*seg_idx].read_postings(off, len).await?;
                    let bpl = BlockPostingList::deserialize(&bytes)?;
                    let mut it = bpl.iterator();
                    while it.doc() != TERMINATED {
                        merged.add(it.doc() + doc_off, it.term_freq());
                        it.advance();
                    }
                }
            }
            // Try to inline (only when no positions)
            if !any_positions {
                let ids: Vec<u32> = merged.iter().map(|p| p.doc_id).collect();
                let tfs: Vec<u32> = merged.iter().map(|p| p.term_freq).collect();
                if let Some(inline) = TermInfo::try_inline(&ids, &tfs) {
                    return Ok(inline);
                }
            }
            let offset = postings_out.offset();
            let block = BlockPostingList::from_posting_list(&merged)?;
            buf.clear();
            block.serialize(buf)?;
            postings_out.write_all(buf)?;
            (offset, buf.len() as u32, merged.doc_count())
        };

        // === Merge positions (if any source has them) ===
        if any_positions {
            let mut raw_pos: Vec<(Vec<u8>, u32)> = Vec::new();
            for (seg_idx, ti, doc_off) in &sorted {
                if let Some((pos_off, pos_len)) = ti.position_info()
                    && let Some(bytes) = segments[*seg_idx]
                        .read_position_bytes(pos_off, pos_len)
                        .await?
                {
                    raw_pos.push((bytes, *doc_off));
                }
            }
            if !raw_pos.is_empty() {
                let refs: Vec<(&[u8], u32)> = raw_pos
                    .iter()
                    .map(|(b, off)| (b.as_slice(), *off))
                    .collect();
                let offset = positions_out.offset();
                let (_doc_count, bytes_written) =
                    PositionPostingList::concatenate_streaming(&refs, positions_out)
                        .map_err(crate::Error::Io)?;
                return Ok(TermInfo::external_with_positions(
                    posting_offset,
                    posting_len,
                    doc_count,
                    offset,
                    bytes_written as u32,
                ));
            }
        }

        Ok(TermInfo::external(posting_offset, posting_len, doc_count))
    }
}

/// Delete segment files from directory
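///
/// Deletion is best-effort: each `delete` result is ignored, so files that are
/// already missing do not turn into errors.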
pub async fn delete_segment<D: Directory + DirectoryWriter>(
    dir: &D,
    segment_id: SegmentId,
) -> Result<()> {
    let files = SegmentFiles::new(segment_id.0);
    let _ = dir.delete(&files.term_dict).await;
    let _ = dir.delete(&files.postings).await;
    let _ = dir.delete(&files.store).await;
    let _ = dir.delete(&files.meta).await;
    let _ = dir.delete(&files.vectors).await;
    let _ = dir.delete(&files.sparse).await;
    let _ = dir.delete(&files.positions).await;
    Ok(())
}