//! endr 0.9.0
//!
//! endr: append-only replicated objects.
use futures::{future, stream, FutureExt, StreamExt};
use litl::{impl_debug_as_litl, impl_nested_tagged_data_serde, NestedTaggedData};
use ridl::{
    hashing::{HashOf, StrongHash, StrongHasher},
    signing::{SignatureError, Signed, SignerID, SignerSecret},
};
use serde::Serialize;
use serde_derive::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, error, info};

use crate::{
    telepathic::{
        ApplyDiffErrorFor, ApplyDiffResult, ApplyDiffSuccess, Telepathic, TelepathicDiff,
    },
    ObjectID, StorageBackend,
};

/// The immutable header of a log: the identity of its single authorized
/// appender plus optional creator-supplied metadata.
///
/// A log's [`LogID`] is the strong hash of this header (see `LogState::new`
/// and the header check in `try_apply_diff`), so the header is fixed for the
/// lifetime of the log.
#[derive(Clone, Serialize, Deserialize)]
pub struct LogHeader {
    /// Public ID of the signer whose signed hashes are accepted on append.
    pub appender: SignerID,
    /// Optional arbitrary metadata supplied when the log was created.
    pub meta: Option<litl::Val>,
}

impl_debug_as_litl!(LogHeader);

/// In-memory replicated state of one append-only log.
#[derive(Debug)]
pub struct LogState {
    /// Identity of the log (strong hash of its header).
    pub id: LogID,
    /// The log header, once known; `None` until a diff delivers it.
    pub header: Option<LogHeader>,
    /// All entries applied so far, in append order.
    pub entries: Vec<litl::Val>,
    /// Appender-signed hash covering the current `entries`, if a diff
    /// carrying one has been applied.
    pub last_hash: Option<Signed<StrongHash>>,
    // Incremental hasher fed the canonical encoding of each applied entry;
    // cloned in `try_apply_diff` to verify candidate diffs without
    // disturbing the committed state.
    hasher: StrongHasher,
}

impl LogState {
    /// Returns a log state for `id` with no header, entries, or last hash.
    pub(crate) fn new_empty(id: LogID) -> Self {
        Self {
            id,
            header: None,
            last_hash: None,
            entries: Vec::new(),
            hasher: StrongHasher::default(),
        }
    }

    /// Creates a brand-new empty log together with its write access.
    ///
    /// A freshly generated signing key becomes the log's appender; `meta`
    /// (if given) is embedded in the header. The log's ID is the strong
    /// hash of that header.
    pub(crate) fn new<M: Serialize>(meta: Option<M>) -> (Self, LogWriteAccess) {
        let secret = SignerSecret::new_random();

        let header = LogHeader {
            appender: secret.pub_id(),
            meta: meta.map(|m| litl::to_val(&m).unwrap()),
        };

        let mut state = Self::new_empty(LogID(HashOf::hash(&header)));
        state.header = Some(header);

        (state, LogWriteAccess(secret))
    }

    /// Builds the diff that appends `entry` at the current end of the log,
    /// with a hash signed via `write_access`. Does not modify `self`.
    pub(crate) fn diff_for_new_entry(
        &self,
        entry: litl::Val,
        write_access: &LogWriteAccess,
    ) -> LogDiff {
        // Extend a copy of the running hasher with the canonical encoding
        // of the new entry; the signed result is what receivers verify.
        let mut next_hasher = self.hasher.clone();
        next_hasher.update(&litl::to_vec_canonical(&entry).unwrap());
        let signed_hash = write_access.0.sign(next_hasher.finalize());

        LogDiff {
            id: self.id,
            header: None,
            after: self.entries.len(),
            new_entries: vec![entry],
            new_hash: Some(signed_hash),
        }
    }
}

/// The identity of a log: the strong hash of its [`LogHeader`].
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct LogID(pub HashOf<LogHeader>);

// Serialize/deserialize `LogID` as its inner hash wrapped in the "log" tag.
impl NestedTaggedData for LogID {
    const TAG: &'static str = "log";

    type Inner = HashOf<LogHeader>;

    fn as_inner(&self) -> &Self::Inner {
        &self.0
    }

    fn from_inner(inner: Self::Inner) -> Self
    where
        Self: Sized,
    {
        LogID(inner)
    }
}

impl_nested_tagged_data_serde!(LogID);
impl_debug_as_litl!(LogID);

/// Capability to append to a log: the secret key matching the log header's
/// `appender` ID. Diffs are signed with this secret (see
/// `LogState::diff_for_new_entry`).
pub struct LogWriteAccess(pub(crate) SignerSecret);

// Serialize/deserialize `LogWriteAccess` as its inner secret wrapped in the
// "logWriteAccess" tag.
impl NestedTaggedData for LogWriteAccess {
    const TAG: &'static str = "logWriteAccess";

    type Inner = SignerSecret;

    fn as_inner(&self) -> &Self::Inner {
        &self.0
    }

    fn from_inner(inner: Self::Inner) -> Self
    where
        Self: Sized,
    {
        LogWriteAccess(inner)
    }
}

impl_nested_tagged_data_serde!(LogWriteAccess);
impl_debug_as_litl!(LogWriteAccess);

/// A replicable change to a log: zero or more entries appended after index
/// `after`, optionally carrying the header and a signed hash of the
/// resulting log contents.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct LogDiff {
    /// Identity of the log this diff applies to.
    pub id: LogID,
    /// The log header, included when the receiver may not know it yet.
    pub header: Option<LogHeader>,
    /// Number of entries the sender assumes the receiver already has.
    pub after: usize,
    /// Entries to append starting at index `after`.
    pub new_entries: Vec<litl::Val>,
    /// Appender-signed running hash of the log after applying this diff;
    /// required whenever `new_entries` is non-empty (see `try_apply_diff`).
    pub new_hash: Option<Signed<StrongHash>>,
}

impl TelepathicDiff for LogDiff {
    type ID = LogID;

    fn id(&self) -> Self::ID {
        self.id
    }
}

/// State info exchanged between peers: the number of entries a peer holds.
/// A `Some(len)` value also implies the peer knows the header (see
/// `LogState::state_info`).
pub type LogStateInfo = usize;

/// Record persisted alongside the entry stream: the last signed hash and
/// the number of entries it covers.
#[derive(Serialize, Deserialize)]
pub struct StorageHashAndLen {
    /// Appender-signed running hash over the first `len` entries.
    pub hash: Signed<StrongHash>,
    /// Number of entries covered by `hash`.
    pub len: usize,
}

impl Telepathic for LogState {
    type ID = LogID;
    type WriteAccess = LogWriteAccess;
    type StateInfo = LogStateInfo;
    type Diff = LogDiff;
    type Error = LogError;

    fn id(&self) -> Self::ID {
        self.id
    }

    fn try_apply_diff(
        &mut self,
        diff: LogDiff,
    ) -> ApplyDiffResult<Self::StateInfo, Self::ID, Self::Diff, Self::Error> {
        debug_assert_eq!(diff.id, self.id);

        let (header, got_header_first_time) = match (&mut self.header, &diff.header) {
            (None, None) => {
                return Ok(None)
            }
            (own @ None, Some(diff_header)) => {
                if HashOf::hash(diff_header) != self.id.0 {
                    return Err(LogError::InvalidHeaderHash.into());
                }
                *own = Some(diff_header.clone());
                (own.as_ref().unwrap(), true)
            }
            (Some(own_header), _) => (&*own_header, false),
        };

        if diff.after == 0 && diff.new_entries.is_empty() {
            if self.entries.is_empty() {
                return Ok(Some(ApplyDiffSuccess {
                    new_state_info: 0,
                    effective_diff: LogDiff {
                        id: diff.id,
                        header: if got_header_first_time {
                            self.header.clone()
                        } else {
                            None
                        },
                        after: 0,
                        new_entries: vec![],
                        new_hash: None,
                    },
                }));
            } else {
                return Err(ApplyDiffErrorFor::InvalidKnownStateAssumption(
                    ObjectID::Log(self.id),
                    "Got initial diff for non-empty log".to_owned(),
                ));
            }
        }

        let new_hash = diff.new_hash.as_ref().expect("Expected new hash on append");
        new_hash
            .ensure_signed_by(&header.appender)
            .map_err(LogError::InvalidSignature)?;

        if diff.after > self.entries.len() {
            return Err(ApplyDiffErrorFor::InvalidKnownStateAssumption(
                ObjectID::Log(self.id),
                "Got diff later than current log length".to_owned(),
            ));
        }

        if diff.after + diff.new_entries.len() <= self.entries.len() {
            info!(
                after = diff.after,
                new_entries = diff.new_entries.len(),
                self_len = self.entries.len(),
                "Received completely redundant log update"
            );
            return Ok(None)
        }

        let overlap = self.entries.len() - diff.after;
        let data_len_before = self.entries.len();
        let effective_new_entries = &diff.new_entries[overlap..];

        let mut new_hasher = self.hasher.clone();
        for entry in &effective_new_entries[..] {
            new_hasher.update(&litl::to_vec_canonical(&entry).unwrap());
        }

        if new_hasher.finalize() == new_hash.attested {
            self.entries.extend(effective_new_entries.to_vec());
            self.last_hash = Some(new_hash.clone());
            self.hasher = new_hasher;

            let effective_diff = if overlap > 0 {
                info!(
                    overlap = overlap,
                    after = diff.after,
                    self_len = data_len_before,
                    new_entries = diff.new_entries.len(),
                    effective_new_entries = effective_new_entries.len(),
                    "Received redundant log update"
                );

                LogDiff {
                    id: diff.id,
                    header: if got_header_first_time {
                        self.header.clone()
                    } else {
                        None
                    },
                    after: data_len_before,
                    new_entries: effective_new_entries.to_vec(),
                    new_hash: diff.new_hash,
                }
            } else {
                LogDiff {
                    header: if got_header_first_time {
                        self.header.clone()
                    } else {
                        None
                    },
                    ..diff
                }
            };

            Ok(Some(ApplyDiffSuccess {
                new_state_info: self.entries.len(),
                effective_diff,
            }))
        } else {
            Err(ApplyDiffErrorFor::Other(LogError::InvalidHash))
        }
    }

    fn diff_since(&self, state_info: Option<&Self::StateInfo>) -> Option<Self::Diff> {
        let (has_header, state_len) = match state_info {
            None => (false, 0),
            Some(len) => (true, *len as usize),
        };

        if state_len >= self.entries.len() && has_header {
            None
        } else {
            Some(LogDiff {
                id: self.id,
                header: if has_header {
                    None
                } else {
                    self.header.clone()
                },
                after: state_len,
                new_entries: self.entries[state_len..].to_vec(),
                new_hash: self.last_hash.clone(),
            })
        }
    }

    fn state_info(&self) -> Option<Self::StateInfo> {
        if self.header.is_none() {
            None
        } else {
            Some(self.entries.len())
        }
    }

    fn load(
        id: LogID,
        storage: Box<dyn StorageBackend>,
    ) -> std::pin::Pin<Box<dyn futures::Stream<Item = Self::Diff>>> {
        let key = id.to_string();

        let header_stream = {
            let key = key.clone();
            stream::once(storage.get_key(&format!("header_{}", key))).filter_map(
                move |maybe_bytes| {
                    future::ready(maybe_bytes.and_then(|bytes| {
                        litl::from_slice(&bytes)
                            .map_err(|err| {
                                error!(err = ?err, id = ?id, "Failed to read loaded log header");
                                err
                            })
                            .ok()
                            .map(|header| LogDiff {
                                id,
                                header: Some(header),
                                after: 0,
                                new_entries: vec![],
                                new_hash: None,
                            })
                    }))
                },
            )
        };

        let all_chunks_stream = {
            stream::once(storage.get_key(&format!("hash_and_len_{}", key)))
                .filter_map(move |maybe_hash_and_len_bytes| {
                    future::ready(maybe_hash_and_len_bytes.and_then(|hash_and_len_bytes| {
                        litl::from_slice::<StorageHashAndLen>(&hash_and_len_bytes).map_err(|err| {
                            error!(err = ?err, id = ?id, "Failed to read loaded log hash and len");
                            err
                        }).ok()
                    }))
                })
                .filter_map(move |hash_and_len| {
                    let storage = storage.clone_ref();
                    let key = key.clone();
                    async move {
                        let all_items = storage
                            .get_stream(&key)
                            .map(|data|litl::from_slice::<litl::Val>(&data).unwrap())
                            .collect::<Vec<_>>()
                            .await
                           ;

                        if all_items.len() < hash_and_len.len {
                            error!(id = ?id, "Storage backend returned less data than last hash");
                            None
                        } else {
                            Some(LogDiff {
                                id,
                                header: None,
                                after: 0,
                                new_entries: all_items[0..hash_and_len.len].to_vec(),
                                new_hash: Some(hash_and_len.hash),
                            })
                        }
                    }
                })
        };

        header_stream.chain(all_chunks_stream).boxed_local()
    }

    fn store(
        effective_diff: LogDiff,
        storage: Box<dyn StorageBackend>,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()>>> {
        let key = effective_diff.id.to_string();
        async move {
            if let Some(header) = &effective_diff.header {
                storage
                    .set_key(&format!("header_{}", key), litl::to_vec(header).unwrap())
                    .await;
            }

            if let Some(new_hash) = &effective_diff.new_hash {
                if effective_diff.new_entries.is_empty() {
                    return;
                }
                let new_len = effective_diff.after + effective_diff.new_entries.len();
                for entry in effective_diff.new_entries {
                    storage
                        .append_to_stream(&key, litl::to_vec(&entry).unwrap(), None)
                        .await;
                }

                storage
                    .set_key(
                        &format!("hash_and_len_{}", key),
                        litl::to_vec(&StorageHashAndLen {
                            hash: new_hash.clone(),
                            len: new_len,
                        })
                        .unwrap(),
                    )
                    .await;
            }
        }
        .boxed_local()
    }
}

/// Errors produced while applying log diffs.
#[derive(Error, Debug)]
pub enum LogError {
    /// A diff's header does not hash to the log's ID.
    #[error("Invalid header hash")]
    InvalidHeaderHash,
    /// The recomputed running hash does not match the diff's signed hash.
    #[error("Invalid hash after append")]
    InvalidHash,
    /// The diff's hash was not signed by the log's appender.
    #[error(transparent)]
    InvalidSignature(#[from] SignatureError),
    // NOTE(review): not constructed anywhere in this file — possibly used
    // elsewhere in the crate, or dead; verify before removing.
    /// An append did not line up with the current end of the log.
    #[error("Append out of order")]
    AppendOutOfOrder,
}