Skip to main content

lance_index/
mem_wal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::any::Any;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use deepsize::DeepSizeOf;
10use lance_core::Error;
11use lance_table::format::pb;
12use roaring::RoaringBitmap;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::{Index, IndexType};
17
/// Name of the MemWAL index.
pub const MEM_WAL_INDEX_NAME: &str = "__lance_mem_wal";
19
20/// Type alias for shard identifier (UUID v4).
21pub type ShardId = Uuid;
22
23/// A flushed MemTable generation and its storage location.
24#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, DeepSizeOf)]
25pub struct FlushedGeneration {
26    pub generation: u64,
27    pub path: String,
28}
29
30impl From<&FlushedGeneration> for pb::FlushedGeneration {
31    fn from(fg: &FlushedGeneration) -> Self {
32        Self {
33            generation: fg.generation,
34            path: fg.path.clone(),
35        }
36    }
37}
38
39impl From<pb::FlushedGeneration> for FlushedGeneration {
40    fn from(fg: pb::FlushedGeneration) -> Self {
41        Self {
42            generation: fg.generation,
43            path: fg.path,
44        }
45    }
46}
47
48/// A shard's merged generation, used in MemWalIndexDetails.
49#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
50pub struct MergedGeneration {
51    pub shard_id: Uuid,
52    pub generation: u64,
53}
54
55impl DeepSizeOf for MergedGeneration {
56    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
57        0 // UUID is 16 bytes fixed size, no heap allocations
58    }
59}
60
61impl MergedGeneration {
62    pub fn new(shard_id: Uuid, generation: u64) -> Self {
63        Self {
64            shard_id,
65            generation,
66        }
67    }
68}
69
70impl From<&MergedGeneration> for pb::MergedGeneration {
71    fn from(mg: &MergedGeneration) -> Self {
72        Self {
73            shard_id: Some((&mg.shard_id).into()),
74            generation: mg.generation,
75        }
76    }
77}
78
79impl TryFrom<pb::MergedGeneration> for MergedGeneration {
80    type Error = Error;
81
82    fn try_from(mg: pb::MergedGeneration) -> lance_core::Result<Self> {
83        let shard_id = mg
84            .shard_id
85            .as_ref()
86            .map(Uuid::try_from)
87            .ok_or_else(|| Error::invalid_input("Missing shard_id in MergedGeneration"))??;
88        Ok(Self {
89            shard_id,
90            generation: mg.generation,
91        })
92    }
93}
94
95/// Tracks which merged generation a base table index has been rebuilt to cover.
96/// Used to determine whether to read from flushed MemTable indexes or base table.
97#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize, DeepSizeOf)]
98pub struct IndexCatchupProgress {
99    pub index_name: String,
100    pub caught_up_generations: Vec<MergedGeneration>,
101}
102
103impl IndexCatchupProgress {
104    pub fn new(index_name: String, caught_up_generations: Vec<MergedGeneration>) -> Self {
105        Self {
106            index_name,
107            caught_up_generations,
108        }
109    }
110
111    /// Get the caught up generation for a specific shard.
112    /// Returns None if the shard is not present (assumed fully caught up).
113    pub fn caught_up_generation_for_shard(&self, shard_id: &Uuid) -> Option<u64> {
114        self.caught_up_generations
115            .iter()
116            .find(|mg| &mg.shard_id == shard_id)
117            .map(|mg| mg.generation)
118    }
119}
120
121impl From<&IndexCatchupProgress> for pb::IndexCatchupProgress {
122    fn from(icp: &IndexCatchupProgress) -> Self {
123        Self {
124            index_name: icp.index_name.clone(),
125            caught_up_generations: icp
126                .caught_up_generations
127                .iter()
128                .map(|mg| mg.into())
129                .collect(),
130        }
131    }
132}
133
134impl TryFrom<pb::IndexCatchupProgress> for IndexCatchupProgress {
135    type Error = Error;
136
137    fn try_from(icp: pb::IndexCatchupProgress) -> lance_core::Result<Self> {
138        Ok(Self {
139            index_name: icp.index_name,
140            caught_up_generations: icp
141                .caught_up_generations
142                .into_iter()
143                .map(MergedGeneration::try_from)
144                .collect::<lance_core::Result<_>>()?,
145        })
146    }
147}
148
149/// Shard manifest containing epoch-based fencing and WAL state.
150/// Each shard has exactly one active writer at any time.
151#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
152pub struct ShardManifest {
153    pub shard_id: Uuid,
154    pub version: u64,
155    pub shard_spec_id: u32,
156    /// Computed shard field values as raw Arrow scalar bytes, keyed by field id.
157    /// The byte encoding follows Arrow's little-endian convention: int32 is 4 LE
158    /// bytes, utf8 is raw UTF-8 bytes, etc. The result_type in the corresponding
159    /// ShardingField from the ShardingSpec determines how to interpret each value.
160    pub shard_field_values: HashMap<String, Vec<u8>>,
161    pub writer_epoch: u64,
162    /// The most recent WAL entry position flushed to a MemTable.
163    /// Recovery replays from `replay_after_wal_entry_position + 1`. The
164    /// default value 0 means "no flush has ever stamped this shard" — WAL
165    /// positions themselves are 1-based, so 0 is never a valid covered
166    /// position.
167    pub replay_after_wal_entry_position: u64,
168    /// The most recent WAL entry position observed at manifest write time.
169    /// Default 0 means "no entry has been written yet"; WAL positions are
170    /// 1-based.
171    pub wal_entry_position_last_seen: u64,
172    pub current_generation: u64,
173    pub flushed_generations: Vec<FlushedGeneration>,
174}
175
176impl DeepSizeOf for ShardManifest {
177    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
178        self.shard_field_values.deep_size_of_children(context)
179            + self.flushed_generations.deep_size_of_children(context)
180    }
181}
182
183impl From<&ShardManifest> for pb::ShardManifest {
184    fn from(rm: &ShardManifest) -> Self {
185        Self {
186            shard_id: Some((&rm.shard_id).into()),
187            version: rm.version,
188            shard_spec_id: rm.shard_spec_id,
189            shard_field_entries: rm
190                .shard_field_values
191                .iter()
192                .map(|(k, v)| pb::ShardFieldEntry {
193                    field_id: k.clone(),
194                    value: v.clone(),
195                })
196                .collect(),
197            writer_epoch: rm.writer_epoch,
198            replay_after_wal_entry_position: rm.replay_after_wal_entry_position,
199            wal_entry_position_last_seen: rm.wal_entry_position_last_seen,
200            current_generation: rm.current_generation,
201            flushed_generations: rm.flushed_generations.iter().map(|fg| fg.into()).collect(),
202        }
203    }
204}
205
206impl TryFrom<pb::ShardManifest> for ShardManifest {
207    type Error = Error;
208
209    fn try_from(rm: pb::ShardManifest) -> lance_core::Result<Self> {
210        let shard_id = rm
211            .shard_id
212            .as_ref()
213            .map(Uuid::try_from)
214            .ok_or_else(|| Error::invalid_input("Missing shard_id in ShardManifest"))??;
215        let shard_field_values = rm
216            .shard_field_entries
217            .into_iter()
218            .map(|e| (e.field_id, e.value))
219            .collect();
220        Ok(Self {
221            shard_id,
222            version: rm.version,
223            shard_spec_id: rm.shard_spec_id,
224            shard_field_values,
225            writer_epoch: rm.writer_epoch,
226            replay_after_wal_entry_position: rm.replay_after_wal_entry_position,
227            wal_entry_position_last_seen: rm.wal_entry_position_last_seen,
228            current_generation: rm.current_generation,
229            flushed_generations: rm
230                .flushed_generations
231                .into_iter()
232                .map(FlushedGeneration::from)
233                .collect(),
234        })
235    }
236}
237
238/// Sharding field definition.
239#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, DeepSizeOf)]
240pub struct ShardingField {
241    pub field_id: String,
242    pub source_ids: Vec<i32>,
243    pub transform: Option<String>,
244    pub expression: Option<String>,
245    pub result_type: String,
246    pub parameters: HashMap<String, String>,
247}
248
249impl From<&ShardingField> for pb::ShardingField {
250    fn from(rf: &ShardingField) -> Self {
251        Self {
252            field_id: rf.field_id.clone(),
253            source_ids: rf.source_ids.clone(),
254            transform: rf.transform.clone(),
255            expression: rf.expression.clone(),
256            result_type: rf.result_type.clone(),
257            parameters: rf.parameters.clone(),
258        }
259    }
260}
261
262impl From<pb::ShardingField> for ShardingField {
263    fn from(rf: pb::ShardingField) -> Self {
264        Self {
265            field_id: rf.field_id,
266            source_ids: rf.source_ids,
267            transform: rf.transform,
268            expression: rf.expression,
269            result_type: rf.result_type,
270            parameters: rf.parameters,
271        }
272    }
273}
274
275/// Sharding spec definition.
276#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, DeepSizeOf)]
277pub struct ShardingSpec {
278    pub spec_id: u32,
279    pub fields: Vec<ShardingField>,
280}
281
282impl From<&ShardingSpec> for pb::ShardingSpec {
283    fn from(rs: &ShardingSpec) -> Self {
284        Self {
285            spec_id: rs.spec_id,
286            fields: rs.fields.iter().map(|f| f.into()).collect(),
287        }
288    }
289}
290
291impl From<pb::ShardingSpec> for ShardingSpec {
292    fn from(rs: pb::ShardingSpec) -> Self {
293        Self {
294            spec_id: rs.spec_id,
295            fields: rs.fields.into_iter().map(ShardingField::from).collect(),
296        }
297    }
298}
299
300/// Index details for MemWAL Index, stored in IndexMetadata.index_details.
301#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize, DeepSizeOf)]
302pub struct MemWalIndexDetails {
303    pub snapshot_ts_millis: i64,
304    pub num_shards: u32,
305    pub inline_snapshots: Option<Vec<u8>>,
306    pub sharding_specs: Vec<ShardingSpec>,
307    pub maintained_indexes: Vec<String>,
308    pub merged_generations: Vec<MergedGeneration>,
309    pub index_catchup: Vec<IndexCatchupProgress>,
310    /// Default `ShardWriter` configuration values for this MemWAL index.
311    ///
312    /// Persisted so every writer — across processes and restarts — starts
313    /// from the same default writer configuration. These are defaults only;
314    /// an individual writer may still override any value at runtime in its
315    /// own (non-persisted) `ShardWriterConfig`.
316    pub writer_config_defaults: HashMap<String, String>,
317}
318
319impl From<&MemWalIndexDetails> for pb::MemWalIndexDetails {
320    fn from(details: &MemWalIndexDetails) -> Self {
321        Self {
322            snapshot_ts_millis: details.snapshot_ts_millis,
323            num_shards: details.num_shards,
324            inline_snapshots: details.inline_snapshots.clone(),
325            sharding_specs: details.sharding_specs.iter().map(|rs| rs.into()).collect(),
326            maintained_indexes: details.maintained_indexes.clone(),
327            merged_generations: details
328                .merged_generations
329                .iter()
330                .map(|mg| mg.into())
331                .collect(),
332            index_catchup: details.index_catchup.iter().map(|icp| icp.into()).collect(),
333            writer_config_defaults: details.writer_config_defaults.clone(),
334        }
335    }
336}
337
338impl TryFrom<pb::MemWalIndexDetails> for MemWalIndexDetails {
339    type Error = Error;
340
341    fn try_from(details: pb::MemWalIndexDetails) -> lance_core::Result<Self> {
342        Ok(Self {
343            snapshot_ts_millis: details.snapshot_ts_millis,
344            num_shards: details.num_shards,
345            inline_snapshots: details.inline_snapshots,
346            sharding_specs: details
347                .sharding_specs
348                .into_iter()
349                .map(ShardingSpec::from)
350                .collect(),
351            maintained_indexes: details.maintained_indexes,
352            merged_generations: details
353                .merged_generations
354                .into_iter()
355                .map(MergedGeneration::try_from)
356                .collect::<lance_core::Result<_>>()?,
357            index_catchup: details
358                .index_catchup
359                .into_iter()
360                .map(IndexCatchupProgress::try_from)
361                .collect::<lance_core::Result<_>>()?,
362            writer_config_defaults: details.writer_config_defaults,
363        })
364    }
365}
366
367/// MemWAL Index provides access to MemWAL configuration and state.
368#[derive(Debug, Clone, PartialEq, Eq, DeepSizeOf)]
369pub struct MemWalIndex {
370    pub details: MemWalIndexDetails,
371}
372
373impl MemWalIndex {
374    pub fn new(details: MemWalIndexDetails) -> Self {
375        Self { details }
376    }
377
378    pub fn merged_generation_for_shard(&self, shard_id: &Uuid) -> Option<u64> {
379        self.details
380            .merged_generations
381            .iter()
382            .find(|mg| &mg.shard_id == shard_id)
383            .map(|mg| mg.generation)
384    }
385
386    /// Get the caught up generation for a specific index and shard.
387    /// Returns None if the index is not tracked (assumed fully caught up).
388    pub fn index_caught_up_generation(&self, index_name: &str, shard_id: &Uuid) -> Option<u64> {
389        self.details
390            .index_catchup
391            .iter()
392            .find(|icp| icp.index_name == index_name)
393            .and_then(|icp| icp.caught_up_generation_for_shard(shard_id))
394    }
395
396    /// Check if an index is fully caught up for a shard.
397    /// Returns true if the index covers all merged data for the shard.
398    pub fn is_index_caught_up(&self, index_name: &str, shard_id: &Uuid) -> bool {
399        let merged_gen = self.merged_generation_for_shard(shard_id).unwrap_or(0);
400        let caught_up_gen = self.index_caught_up_generation(index_name, shard_id);
401
402        // If not tracked in index_catchup, assumed fully caught up
403        caught_up_gen.is_none_or(|generation| generation >= merged_gen)
404    }
405}
406
407#[derive(Serialize)]
408struct MemWalStatistics {
409    num_shards: u32,
410    num_merged_generations: usize,
411    num_shard_specs: usize,
412    num_maintained_indexes: usize,
413    num_index_catchup_entries: usize,
414}
415
416#[async_trait]
417impl Index for MemWalIndex {
418    fn as_any(&self) -> &dyn Any {
419        self
420    }
421
422    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
423        self
424    }
425
426    fn as_vector_index(self: Arc<Self>) -> lance_core::Result<Arc<dyn crate::vector::VectorIndex>> {
427        Err(Error::not_supported_source(
428            "MemWalIndex is not a vector index".into(),
429        ))
430    }
431
432    fn statistics(&self) -> lance_core::Result<serde_json::Value> {
433        let stats = MemWalStatistics {
434            num_shards: self.details.num_shards,
435            num_merged_generations: self.details.merged_generations.len(),
436            num_shard_specs: self.details.sharding_specs.len(),
437            num_maintained_indexes: self.details.maintained_indexes.len(),
438            num_index_catchup_entries: self.details.index_catchup.len(),
439        };
440        serde_json::to_value(stats).map_err(|e| {
441            Error::internal(format!(
442                "failed to serialize MemWAL index statistics: {}",
443                e
444            ))
445        })
446    }
447
448    async fn prewarm(&self) -> lance_core::Result<()> {
449        Ok(())
450    }
451
452    fn index_type(&self) -> IndexType {
453        IndexType::MemWal
454    }
455
456    async fn calculate_included_frags(&self) -> lance_core::Result<RoaringBitmap> {
457        Ok(RoaringBitmap::new())
458    }
459}