automorph 0.2.0

Derive macros for bidirectional Automerge-Rust struct synchronization
Documentation
//! Implementations for generic arrays `[T; N]`.
//!
//! This uses const generics to provide implementations for arrays of any size
//! where T: Automorph.

use automerge::{ChangeHash, ObjId, ObjType, Prop, ReadDoc, Value, transaction::Transactable};

use crate::{Automorph, ChangeReport, Error, PrimitiveChanged, Result, ScalarCursor};

/// Fixed-size arrays are stored as an Automerge list holding exactly `N` elements.
///
/// Change tracking is coarse: `Changes` is `PrimitiveChanged`, so `diff`/`update`
/// only report whether *any* element changed, not which one.
impl<T: Automorph, const N: usize> Automorph for [T; N] {
    type Changes = PrimitiveChanged;
    type Cursor = ScalarCursor;

    /// Writes the array to `obj[prop]` as a list.
    ///
    /// An existing list at `prop` is reused, trimmed to `N` items when it is longer;
    /// any other value (or no value) is replaced by a newly created list. Each element
    /// is then saved at its own index.
    fn save<D: Transactable + ReadDoc>(
        &self,
        doc: &mut D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
    ) -> Result<()> {
        let prop: Prop = prop.into();
        let obj = obj.as_ref();

        let list_id = match doc.get(obj, prop.clone())? {
            Some((Value::Object(ObjType::List), id)) => {
                let current_len = doc.length(&id);
                // Drop surplus items: each delete at index `N` shifts the next surplus
                // item into that slot. The range is empty when `current_len <= N`.
                for _ in N..current_len {
                    doc.delete(&id, N)?;
                }
                id
            }
            _ => doc.put_object(obj, prop, ObjType::List)?,
        };

        for (i, item) in self.iter().enumerate() {
            item.save(doc, &list_id, i)?;
        }

        Ok(())
    }

    /// Reads an array from the list at `obj[prop]`.
    ///
    /// # Errors
    /// Fails if the value is missing, is not a list, has a length other than `N`,
    /// or any element fails to load.
    fn load<D: ReadDoc>(doc: &D, obj: impl AsRef<ObjId>, prop: impl Into<Prop>) -> Result<Self> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), None)?;
        check_len(N, len)?;
        collect_array((0..N).map(|i| T::load(doc, &list_id, i)))
    }

    /// Like `load`, but reads the document as of `heads`.
    fn load_at<D: ReadDoc>(
        doc: &D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
        heads: &[ChangeHash],
    ) -> Result<Self> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), Some(heads))?;
        check_len(N, len)?;
        collect_array((0..N).map(|i| T::load_at(doc, &list_id, i, heads)))
    }

    /// Reports whether the stored list differs from `self`.
    ///
    /// A length mismatch counts as a change (not an error). Element diffs
    /// short-circuit on the first changed element.
    fn diff<D: ReadDoc>(
        &self,
        doc: &D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
    ) -> Result<Self::Changes> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), None)?;
        if len != N {
            return Ok(PrimitiveChanged::new(true));
        }

        for (i, item) in self.iter().enumerate() {
            if item.diff(doc, &list_id, i)?.any() {
                return Ok(PrimitiveChanged::new(true));
            }
        }

        Ok(PrimitiveChanged::new(false))
    }

    /// Like `diff`, but compares against the document as of `heads`.
    fn diff_at<D: ReadDoc>(
        &self,
        doc: &D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
        heads: &[ChangeHash],
    ) -> Result<Self::Changes> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), Some(heads))?;
        if len != N {
            return Ok(PrimitiveChanged::new(true));
        }

        for (i, item) in self.iter().enumerate() {
            if item.diff_at(doc, &list_id, i, heads)?.any() {
                return Ok(PrimitiveChanged::new(true));
            }
        }

        Ok(PrimitiveChanged::new(false))
    }

    /// Updates every element in place from the stored list.
    ///
    /// Every element is visited, even after one has already reported a change.
    ///
    /// # Errors
    /// Fails if the value is missing, is not a list, has a length other than `N`,
    /// or an element update fails. Elements updated before a failing one keep
    /// their new values.
    fn update<D: ReadDoc>(
        &mut self,
        doc: &D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
    ) -> Result<Self::Changes> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), None)?;
        check_len(N, len)?;

        let mut any_changed = false;
        for (i, item) in self.iter_mut().enumerate() {
            // `|=` always evaluates its right side, so no element is skipped.
            any_changed |= item.update(doc, &list_id, i)?.any();
        }

        Ok(PrimitiveChanged::new(any_changed))
    }

    /// Like `update`, but reads the document as of `heads`.
    fn update_at<D: ReadDoc>(
        &mut self,
        doc: &D,
        obj: impl AsRef<ObjId>,
        prop: impl Into<Prop>,
        heads: &[ChangeHash],
    ) -> Result<Self::Changes> {
        let (list_id, len) = resolve_list(doc, obj.as_ref(), prop.into(), Some(heads))?;
        check_len(N, len)?;

        let mut any_changed = false;
        for (i, item) in self.iter_mut().enumerate() {
            any_changed |= item.update_at(doc, &list_id, i, heads)?.any();
        }

        Ok(PrimitiveChanged::new(any_changed))
    }
}

/// Looks up `obj[prop]` (as of `heads` when given) and returns the list's id and length.
///
/// # Errors
/// Propagates document lookup errors. Returns a type-mismatch error if the value is
/// not a list, and a missing-value error if nothing is stored at `prop`.
fn resolve_list<D: ReadDoc>(
    doc: &D,
    obj: &ObjId,
    prop: Prop,
    heads: Option<&[ChangeHash]>,
) -> Result<(ObjId, usize)> {
    let entry = match heads {
        Some(heads) => doc.get_at(obj, prop, heads)?,
        None => doc.get(obj, prop)?,
    };

    match entry {
        Some((Value::Object(ObjType::List), list_id)) => {
            let len = match heads {
                Some(heads) => doc.length_at(&list_id, heads),
                None => doc.length(&list_id),
            };
            Ok((list_id, len))
        }
        Some((v, _)) => Err(Error::type_mismatch(
            "[T; N] (List)",
            Some(format!("{:?}", v)),
        )),
        None => Err(Error::missing_value()),
    }
}

/// Returns an invalid-value error unless the stored length `len` equals `expected`.
fn check_len(expected: usize, len: usize) -> Result<()> {
    if len == expected {
        Ok(())
    } else {
        Err(Error::invalid_value(format!(
            "expected array of length {}, got {}",
            expected, len
        )))
    }
}

/// Collects fallible element loads into `[T; N]`, stopping at the first error.
fn collect_array<T, const N: usize>(items: impl Iterator<Item = Result<T>>) -> Result<[T; N]> {
    let mut out = Vec::with_capacity(N);
    for item in items {
        out.push(item?);
    }
    // Only reachable if `items` yields a count other than `N`; callers pass exactly `N`.
    out.try_into()
        .map_err(|_| Error::invalid_value("failed to convert Vec to array"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use automerge::{AutoCommit, ROOT};

    #[test]
    fn test_array_roundtrip() {
        let mut doc = AutoCommit::new();

        let arr = [1i64, 2, 3, 4, 5];
        arr.save(&mut doc, &ROOT, "arr").unwrap();

        let restored = <[i64; 5]>::load(&doc, &ROOT, "arr").unwrap();
        assert_eq!(restored, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_array_wrong_size() {
        let mut doc = AutoCommit::new();

        let arr = [1i64, 2, 3];
        arr.save(&mut doc, &ROOT, "arr").unwrap();

        // Try to read as different size
        let result = <[i64; 5]>::load(&doc, &ROOT, "arr");
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_array() {
        let mut doc = AutoCommit::new();

        let arr: [i64; 0] = [];
        arr.save(&mut doc, &ROOT, "arr").unwrap();

        let restored = <[i64; 0]>::load(&doc, &ROOT, "arr").unwrap();
        let expected: [i64; 0] = [];
        assert_eq!(restored, expected);
    }

    #[test]
    fn test_array_diff() {
        let mut doc = AutoCommit::new();

        let arr = [1i64, 2, 3];
        arr.save(&mut doc, &ROOT, "arr").unwrap();

        // Same array - no changes
        let same = [1i64, 2, 3];
        let changes = same.diff(&doc, &ROOT, "arr").unwrap();
        assert!(!changes.any());

        // Different array - has changes
        let different = [1i64, 2, 4];
        let changes = different.diff(&doc, &ROOT, "arr").unwrap();
        assert!(changes.any());
    }

    #[test]
    fn test_array_update() {
        let mut doc = AutoCommit::new();

        let arr = [10i64, 20, 30];
        arr.save(&mut doc, &ROOT, "arr").unwrap();

        let mut local = [0i64, 0, 0];
        let changes = local.update(&doc, &ROOT, "arr").unwrap();
        assert!(changes.any());
        assert_eq!(local, [10, 20, 30]);
    }

    #[test]
    fn test_save_shrinks_existing_list() {
        let mut doc = AutoCommit::new();

        [1i64, 2, 3, 4, 5].save(&mut doc, &ROOT, "arr").unwrap();
        // Re-saving a shorter array into the same slot must trim the existing list.
        [9i64, 8].save(&mut doc, &ROOT, "arr").unwrap();

        let restored = <[i64; 2]>::load(&doc, &ROOT, "arr").unwrap();
        assert_eq!(restored, [9, 8]);
    }

    #[test]
    fn test_array_load_at_historical_heads() {
        let mut doc = AutoCommit::new();

        [1i64, 2, 3].save(&mut doc, &ROOT, "arr").unwrap();
        doc.commit();
        let heads = doc.get_heads();
        [4i64, 5, 6].save(&mut doc, &ROOT, "arr").unwrap();

        let past = <[i64; 3]>::load_at(&doc, &ROOT, "arr", &heads).unwrap();
        assert_eq!(past, [1, 2, 3]);
        let now = <[i64; 3]>::load(&doc, &ROOT, "arr").unwrap();
        assert_eq!(now, [4, 5, 6]);
    }

    #[test]
    fn test_array_type_mismatch_and_missing() {
        let mut doc = AutoCommit::new();
        doc.put(&ROOT, "scalar", 42i64).unwrap();

        // A scalar where a list is expected is a type mismatch.
        assert!(<[i64; 3]>::load(&doc, &ROOT, "scalar").is_err());
        // An absent key is a missing value.
        assert!(<[i64; 3]>::load(&doc, &ROOT, "missing").is_err());
    }

    #[test]
    fn test_array_length_mismatch_diff_and_update() {
        let mut doc = AutoCommit::new();
        [1i64, 2, 3].save(&mut doc, &ROOT, "arr").unwrap();

        // diff reports a length mismatch as a change...
        assert!([1i64, 2].diff(&doc, &ROOT, "arr").unwrap().any());
        // ...while update refuses a list of the wrong length.
        let mut local = [0i64; 2];
        assert!(local.update(&doc, &ROOT, "arr").is_err());
    }
}