// icechunk/change_set.rs

use std::{
    collections::{BTreeMap, HashMap, HashSet},
    iter,
    mem::take,
};

use bytes::Bytes;
use itertools::{Either, Itertools as _};
use serde::{Deserialize, Serialize};

use crate::{
    format::{
        ChunkIndices, NodeId, Path,
        manifest::{ChunkInfo, ChunkPayload},
        snapshot::{ArrayShape, DimensionName, NodeData, NodeSnapshot},
    },
    session::SessionResult,
};

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArrayData {
    pub shape: ArrayShape,
    pub dimension_names: Option<Vec<DimensionName>>,
    pub user_data: Bytes,
}

#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct ChangeSet {
    new_groups: HashMap<Path, (NodeId, Bytes)>,
    new_arrays: HashMap<Path, (NodeId, ArrayData)>,
    updated_arrays: HashMap<NodeId, ArrayData>,
    updated_groups: HashMap<NodeId, Bytes>,
    // It's important that we keep these sorted; we rely on this ordering when
    // creating the TransactionLog.
    set_chunks: BTreeMap<NodeId, BTreeMap<ChunkIndices, Option<ChunkPayload>>>,
    deleted_groups: HashSet<(Path, NodeId)>,
    deleted_arrays: HashSet<(Path, NodeId)>,
}

impl ChangeSet {
    pub fn deleted_arrays(&self) -> impl Iterator<Item = &(Path, NodeId)> {
        self.deleted_arrays.iter()
    }

    pub fn deleted_groups(&self) -> impl Iterator<Item = &(Path, NodeId)> {
        self.deleted_groups.iter()
    }

    pub fn updated_arrays(&self) -> impl Iterator<Item = &NodeId> {
        self.updated_arrays.keys()
    }

    pub fn updated_groups(&self) -> impl Iterator<Item = &NodeId> {
        self.updated_groups.keys()
    }

    pub fn array_is_deleted(&self, path_and_id: &(Path, NodeId)) -> bool {
        self.deleted_arrays.contains(path_and_id)
    }

    pub fn chunk_changes(
        &self,
    ) -> impl Iterator<Item = (&NodeId, &BTreeMap<ChunkIndices, Option<ChunkPayload>>)>
    {
        self.set_chunks.iter()
    }

    pub fn has_chunk_changes(&self, node: &NodeId) -> bool {
        self.set_chunks.get(node).map(|m| !m.is_empty()).unwrap_or(false)
    }

    pub fn arrays_with_chunk_changes(&self) -> impl Iterator<Item = &NodeId> {
        self.chunk_changes().map(|(node, _)| node)
    }

    pub fn is_empty(&self) -> bool {
        self == &ChangeSet::default()
    }

    pub fn add_group(&mut self, path: Path, node_id: NodeId, definition: Bytes) {
        debug_assert!(!self.updated_groups.contains_key(&node_id));
        self.new_groups.insert(path, (node_id, definition));
    }

    pub fn get_group(&self, path: &Path) -> Option<&(NodeId, Bytes)> {
        self.new_groups.get(path)
    }

    pub fn get_array(&self, path: &Path) -> Option<&(NodeId, ArrayData)> {
        self.new_arrays.get(path)
    }

    /// IMPORTANT: This method does not delete the group's children. The caller
    /// is responsible for doing that.
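    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; the paths and ids are illustrative) showing
    /// that a child created in the same session survives the deletion of its parent:
    ///
    /// ```ignore
    /// let mut change_set = ChangeSet::default();
    /// let (parent, child) = (NodeId::random(), NodeId::random());
    /// change_set.add_group("/a".try_into().unwrap(), parent.clone(), Bytes::new());
    /// change_set.add_group("/a/b".try_into().unwrap(), child.clone(), Bytes::new());
    /// change_set.delete_group("/a".try_into().unwrap(), &parent);
    /// // "/a/b" is still recorded as a new group; the caller must delete it too
    /// assert!(change_set.get_group(&"/a/b".try_into().unwrap()).is_some());
    /// ```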
    pub fn delete_group(&mut self, path: Path, node_id: &NodeId) {
        self.updated_groups.remove(node_id);
        if self.new_groups.remove(&path).is_none() {
            // it's an old group, we need to flag it as deleted
            self.deleted_groups.insert((path, node_id.clone()));
        }
    }

    pub fn add_array(&mut self, path: Path, node_id: NodeId, array_data: ArrayData) {
        self.new_arrays.insert(path, (node_id, array_data));
    }

    pub fn update_array(&mut self, node_id: &NodeId, path: &Path, array_data: ArrayData) {
        match self.new_arrays.get(path) {
            Some((id, _)) => {
                debug_assert!(!self.updated_arrays.contains_key(id));
                self.new_arrays.insert(path.clone(), (node_id.clone(), array_data));
            }
            None => {
                self.updated_arrays.insert(node_id.clone(), array_data);
            }
        }
    }

    pub fn update_group(&mut self, node_id: &NodeId, path: &Path, definition: Bytes) {
        match self.new_groups.get(path) {
            Some((id, _)) => {
                debug_assert!(!self.updated_groups.contains_key(id));
                self.new_groups.insert(path.clone(), (node_id.clone(), definition));
            }
            None => {
                self.updated_groups.insert(node_id.clone(), definition);
            }
        }
    }

    pub fn delete_array(&mut self, path: Path, node_id: &NodeId) {
        // if deleting a new array created in this session, just remove the entry
        // from new_arrays
        let node_and_meta = self.new_arrays.remove(&path);
        let is_new_array = node_and_meta.is_some();
        debug_assert!(
            !is_new_array || node_and_meta.map(|n| n.0).as_ref() == Some(node_id)
        );

        self.updated_arrays.remove(node_id);
        self.set_chunks.remove(node_id);
        if !is_new_array {
            self.deleted_arrays.insert((path, node_id.clone()));
        }
    }

    pub fn is_deleted(&self, path: &Path, node_id: &NodeId) -> bool {
        let key = (path.clone(), node_id.clone());
        self.deleted_groups.contains(&key) || self.deleted_arrays.contains(&key)
    }

    pub fn get_updated_array(&self, node_id: &NodeId) -> Option<&ArrayData> {
        self.updated_arrays.get(node_id)
    }

    pub fn get_updated_group(&self, node_id: &NodeId) -> Option<&Bytes> {
        self.updated_groups.get(node_id)
    }

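    /// Record a chunk reference update for `node_id` at `coord`.
    ///
    /// Passing `None` records a chunk deletion; deletes are idempotent, so
    /// setting `None` again for the same coordinates is harmless.
    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; the node id is illustrative):
    ///
    /// ```ignore
    /// let mut change_set = ChangeSet::default();
    /// let node = NodeId::random();
    /// change_set.set_chunk_ref(node.clone(), ChunkIndices(vec![0]), None);
    /// change_set.set_chunk_ref(node.clone(), ChunkIndices(vec![0]), None); // still fine
    /// assert_eq!(
    ///     change_set.get_chunk_ref(&node, &ChunkIndices(vec![0])),
    ///     Some(&None),
    /// );
    /// ```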
    pub fn set_chunk_ref(
        &mut self,
        node_id: NodeId,
        coord: ChunkIndices,
        data: Option<ChunkPayload>,
    ) {
        // This implementation makes deletes idempotent: an already deleted chunk
        // can be deleted again by setting None for the same coordinates.
        self.set_chunks
            .entry(node_id)
            .and_modify(|h| {
                h.insert(coord.clone(), data.clone());
            })
            .or_insert(BTreeMap::from([(coord, data)]));
    }

    pub fn get_chunk_ref(
        &self,
        node_id: &NodeId,
        coords: &ChunkIndices,
    ) -> Option<&Option<ChunkPayload>> {
        self.set_chunks.get(node_id).and_then(|h| h.get(coords))
    }

    /// Drop the updated chunk references for the node.
    /// This will only drop the references for which `predicate` returns true.
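    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; assumes `ChunkIndices` exposes its
    /// coordinate vector as field `0`, as the constructor syntax in the tests
    /// below suggests): keep only the first column of a 2-d chunk grid.
    ///
    /// ```ignore
    /// change_set.drop_chunk_changes(&node, |coord| coord.0[1] != 0);
    /// ```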
    pub fn drop_chunk_changes(
        &mut self,
        node_id: &NodeId,
        predicate: impl Fn(&ChunkIndices) -> bool,
    ) {
        if let Some(changes) = self.set_chunks.get_mut(node_id) {
            changes.retain(|coord, _| !predicate(coord));
        }
    }

    pub fn array_chunks_iterator(
        &self,
        node_id: &NodeId,
        node_path: &Path,
    ) -> impl Iterator<Item = (&ChunkIndices, &Option<ChunkPayload>)> + use<'_> {
        if self.is_deleted(node_path, node_id) {
            return Either::Left(iter::empty());
        }
        match self.set_chunks.get(node_id) {
            None => Either::Left(iter::empty()),
            Some(h) => Either::Right(h.iter()),
        }
    }

    pub fn new_arrays_chunk_iterator(
        &self,
    ) -> impl Iterator<Item = (Path, ChunkInfo)> + use<'_> {
        self.new_arrays.iter().flat_map(|(path, (node_id, _))| {
            self.new_array_chunk_iterator(node_id, path).map(|ci| (path.clone(), ci))
        })
    }

    pub fn new_array_chunk_iterator<'a>(
        &'a self,
        node_id: &'a NodeId,
        node_path: &Path,
    ) -> impl Iterator<Item = ChunkInfo> + use<'a> {
        self.array_chunks_iterator(node_id, node_path).filter_map(
            move |(coords, payload)| {
                payload.as_ref().map(|p| ChunkInfo {
                    node: node_id.clone(),
                    coord: coords.clone(),
                    payload: p.clone(),
                })
            },
        )
    }

    pub fn new_nodes(&self) -> impl Iterator<Item = (&Path, &NodeId)> {
        self.new_groups().chain(self.new_arrays())
    }

    pub fn new_groups(&self) -> impl Iterator<Item = (&Path, &NodeId)> {
        self.new_groups.iter().map(|(path, (node_id, _))| (path, node_id))
    }

    pub fn new_arrays(&self) -> impl Iterator<Item = (&Path, &NodeId)> {
        self.new_arrays.iter().map(|(path, (node_id, _))| (path, node_id))
    }

    pub fn take_chunks(
        &mut self,
    ) -> BTreeMap<NodeId, BTreeMap<ChunkIndices, Option<ChunkPayload>>> {
        take(&mut self.set_chunks)
    }

    pub fn set_chunks(
        &mut self,
        chunks: BTreeMap<NodeId, BTreeMap<ChunkIndices, Option<ChunkPayload>>>,
    ) {
        self.set_chunks = chunks;
    }

    /// Merge this ChangeSet with `other`.
    ///
    /// Results of the merge are applied to `self`. Changes present in `other` take precedence over
    /// `self` changes.
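    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`): when both sides set the same chunk,
    /// the payload from `other` wins.
    ///
    /// ```ignore
    /// let (mut ours, mut theirs) = (ChangeSet::default(), ChangeSet::default());
    /// let node = NodeId::random();
    /// ours.set_chunk_ref(node.clone(), ChunkIndices(vec![0]), None);
    /// theirs.set_chunk_ref(
    ///     node.clone(),
    ///     ChunkIndices(vec![0]),
    ///     Some(ChunkPayload::Inline("new".into())),
    /// );
    /// ours.merge(theirs);
    /// assert_eq!(
    ///     ours.get_chunk_ref(&node, &ChunkIndices(vec![0])),
    ///     Some(&Some(ChunkPayload::Inline("new".into()))),
    /// );
    /// ```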
    pub fn merge(&mut self, other: ChangeSet) {
        // FIXME: this should detect conflicts, for example, when different writers
        // added different objects at the same path, or when the same path is both
        // added and deleted, etc.
        // TODO: optimize
        self.new_groups.extend(other.new_groups);
        self.new_arrays.extend(other.new_arrays);
        self.updated_groups.extend(other.updated_groups);
        self.updated_arrays.extend(other.updated_arrays);
        self.deleted_groups.extend(other.deleted_groups);
        self.deleted_arrays.extend(other.deleted_arrays);

        for (node, other_chunks) in other.set_chunks.into_iter() {
            match self.set_chunks.remove(&node) {
                Some(mut old_value) => {
                    old_value.extend(other_chunks);
                    self.set_chunks.insert(node, old_value);
                }
                None => {
                    self.set_chunks.insert(node, other_chunks);
                }
            }
        }
    }

    pub fn merge_many<T: IntoIterator<Item = ChangeSet>>(&mut self, others: T) {
        others.into_iter().fold(self, |res, change_set| {
            res.merge(change_set);
            res
        });
    }
    /// Serialize this ChangeSet
    ///
    /// This is intended to help with marshalling distributed writers back to the coordinator
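    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`) of the intended round trip:
    ///
    /// ```ignore
    /// let bytes = change_set.export_to_bytes()?;
    /// // ...ship `bytes` from the distributed writer to the coordinator...
    /// let restored = ChangeSet::import_from_bytes(&bytes)?;
    /// assert_eq!(change_set, restored);
    /// ```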
    pub fn export_to_bytes(&self) -> SessionResult<Vec<u8>> {
        Ok(rmp_serde::to_vec(self)?)
    }

    /// Deserialize a ChangeSet
    ///
    /// This is intended to help with marshalling distributed writers back to the coordinator
    pub fn import_from_bytes(bytes: &[u8]) -> SessionResult<Self> {
        Ok(rmp_serde::from_slice(bytes)?)
    }

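    /// Overlay this session's chunk changes on an iterator of existing chunks.
    ///
    /// Chunks untouched in this session pass through unchanged, chunks updated here
    /// get their new payload, and chunks deleted here are filtered out. Errors in
    /// the input iterator are passed through untouched.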
    pub fn update_existing_chunks<'a, E>(
        &'a self,
        node: NodeId,
        chunks: impl Iterator<Item = Result<ChunkInfo, E>> + 'a,
    ) -> impl Iterator<Item = Result<ChunkInfo, E>> + 'a {
        chunks.filter_map_ok(move |chunk| match self.get_chunk_ref(&node, &chunk.coord) {
            None => Some(chunk),
            Some(new_payload) => {
                new_payload.clone().map(|pl| ChunkInfo { payload: pl, ..chunk })
            }
        })
    }

    pub fn get_new_node(&self, path: &Path) -> Option<NodeSnapshot> {
        self.get_new_array(path).or(self.get_new_group(path))
    }

    pub fn get_new_array(&self, path: &Path) -> Option<NodeSnapshot> {
        self.get_array(path).map(|(id, array_data)| {
            debug_assert!(!self.updated_arrays.contains_key(id));
            NodeSnapshot {
                id: id.clone(),
                path: path.clone(),
                user_data: array_data.user_data.clone(),
                // We put no manifests in new arrays; see get_chunk_ref to understand
                // how chunks are fetched for those arrays.
                node_data: NodeData::Array {
                    shape: array_data.shape.clone(),
                    dimension_names: array_data.dimension_names.clone(),
                    manifests: vec![],
                },
            }
        })
    }

    pub fn get_new_group(&self, path: &Path) -> Option<NodeSnapshot> {
        self.get_group(path).map(|(id, definition)| {
            debug_assert!(!self.updated_groups.contains_key(id));
            NodeSnapshot {
                id: id.clone(),
                path: path.clone(),
                user_data: definition.clone(),
                node_data: NodeData::Group,
            }
        })
    }

    pub fn new_nodes_iterator(&self) -> impl Iterator<Item = NodeSnapshot> {
        self.new_nodes().filter_map(move |(path, node_id)| {
            if self.is_deleted(path, node_id) {
                return None;
            }
            // we should be able to create the full node because we
            // know it's a new node
            #[allow(clippy::expect_used)]
            let node = self.get_new_node(path).expect("Bug in new_nodes implementation");
            Some(node)
        })
    }

    /// Applies the changeset to an existing node, yielding a new node if it hasn't been deleted.
    pub fn update_existing_node(&self, node: NodeSnapshot) -> Option<NodeSnapshot> {
        if self.is_deleted(&node.path, &node.id) {
            return None;
        }

        match node.node_data {
            NodeData::Group => {
                let new_definition =
                    self.updated_groups.get(&node.id).cloned().unwrap_or(node.user_data);
                Some(NodeSnapshot { user_data: new_definition, ..node })
            }
            NodeData::Array { shape, dimension_names, manifests } => {
                let new_data =
                    self.updated_arrays.get(&node.id).cloned().unwrap_or_else(|| {
                        ArrayData { shape, dimension_names, user_data: node.user_data }
                    });
                Some(NodeSnapshot {
                    user_data: new_data.user_data,
                    node_data: NodeData::Array {
                        shape: new_data.shape,
                        dimension_names: new_data.dimension_names,
                        manifests,
                    },
                    ..node
                })
            }
        }
    }

    pub fn undo_update(&mut self, node_id: &NodeId) {
        self.updated_arrays.remove(node_id);
        self.updated_groups.remove(node_id);
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use bytes::Bytes;
    use itertools::Itertools;

    use super::ChangeSet;

    use crate::{
        change_set::ArrayData,
        format::{
            ChunkIndices, NodeId,
            manifest::{ChunkInfo, ChunkPayload},
            snapshot::ArrayShape,
        },
    };

    #[icechunk_macros::test]
    fn test_new_arrays_chunk_iterator() {
        let mut change_set = ChangeSet::default();
        assert_eq!(None, change_set.new_arrays_chunk_iterator().next());

        let shape = ArrayShape::new(vec![(2, 1), (2, 1), (2, 1)]).unwrap();
        let dimension_names = Some(vec!["x".into(), "y".into(), "t".into()]);

        let node_id1 = NodeId::random();
        let node_id2 = NodeId::random();
        let array_data = ArrayData {
            shape: shape.clone(),
            dimension_names: dimension_names.clone(),
            user_data: Bytes::from_static(b"foobar"),
        };
        change_set.add_array(
            "/foo/bar".try_into().unwrap(),
            node_id1.clone(),
            array_data.clone(),
        );
        change_set.add_array(
            "/foo/baz".try_into().unwrap(),
            node_id2.clone(),
            array_data.clone(),
        );
        assert_eq!(None, change_set.new_arrays_chunk_iterator().next());

        change_set.set_chunk_ref(node_id1.clone(), ChunkIndices(vec![0, 1]), None);
        assert_eq!(None, change_set.new_arrays_chunk_iterator().next());

        change_set.set_chunk_ref(
            node_id1.clone(),
            ChunkIndices(vec![1, 0]),
            Some(ChunkPayload::Inline("bar1".into())),
        );
        change_set.set_chunk_ref(
            node_id1.clone(),
            ChunkIndices(vec![1, 1]),
            Some(ChunkPayload::Inline("bar2".into())),
        );
        change_set.set_chunk_ref(
            node_id2.clone(),
            ChunkIndices(vec![0]),
            Some(ChunkPayload::Inline("baz1".into())),
        );
        change_set.set_chunk_ref(
            node_id2.clone(),
            ChunkIndices(vec![1]),
            Some(ChunkPayload::Inline("baz2".into())),
        );

        {
            let all_chunks: Vec<_> = change_set
                .new_arrays_chunk_iterator()
                .sorted_by_key(|c| c.1.coord.clone())
                .collect();
            let expected_chunks: Vec<_> = [
                (
                    "/foo/baz".try_into().unwrap(),
                    ChunkInfo {
                        node: node_id2.clone(),
                        coord: ChunkIndices(vec![0]),
                        payload: ChunkPayload::Inline("baz1".into()),
                    },
                ),
                (
                    "/foo/baz".try_into().unwrap(),
                    ChunkInfo {
                        node: node_id2.clone(),
                        coord: ChunkIndices(vec![1]),
                        payload: ChunkPayload::Inline("baz2".into()),
                    },
                ),
                (
                    "/foo/bar".try_into().unwrap(),
                    ChunkInfo {
                        node: node_id1.clone(),
                        coord: ChunkIndices(vec![1, 0]),
                        payload: ChunkPayload::Inline("bar1".into()),
                    },
                ),
                (
                    "/foo/bar".try_into().unwrap(),
                    ChunkInfo {
                        node: node_id1.clone(),
                        coord: ChunkIndices(vec![1, 1]),
                        payload: ChunkPayload::Inline("bar2".into()),
                    },
                ),
            ]
            .into();
            assert_eq!(all_chunks, expected_chunks);
        }
    }
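
    // A supplementary check, sketched along the lines of the test above: deleting
    // an array that was created in this same session drops its chunk refs and
    // leaves no deletion tombstone behind.
    #[icechunk_macros::test]
    fn test_delete_new_array_leaves_no_tombstone() {
        let mut change_set = ChangeSet::default();
        let shape = ArrayShape::new(vec![(2, 1)]).unwrap();
        let node_id = NodeId::random();
        let array_data = ArrayData {
            shape,
            dimension_names: None,
            user_data: Bytes::from_static(b"foobar"),
        };
        let path: crate::format::Path = "/foo/bar".try_into().unwrap();
        change_set.add_array(path.clone(), node_id.clone(), array_data);
        change_set.set_chunk_ref(
            node_id.clone(),
            ChunkIndices(vec![0]),
            Some(ChunkPayload::Inline("bar1".into())),
        );

        change_set.delete_array(path.clone(), &node_id);

        // the new array and its chunk refs are gone, and nothing is flagged deleted
        assert!(change_set.get_array(&path).is_none());
        assert_eq!(None, change_set.get_chunk_ref(&node_id, &ChunkIndices(vec![0])));
        assert!(!change_set.is_deleted(&path, &node_id));
        assert!(change_set.is_empty());
    }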
}