caro 0.7.1

caro: creation-addressed replicated objects
Documentation
use futures::{future, stream, FutureExt, StreamExt};
use litl::{impl_debug_as_litl, impl_nested_tagged_data_serde, NestedTaggedData};
use ridl::{
    hashing::{HashOf, StrongHash, StrongHasher},
    signing::{SignatureError, Signed, SignerID, SignerSecret},
};
use serde::Serialize;
use serde_derive::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{error, info};

use crate::{
    telepathic::{
        ApplyDiffErrorFor, ApplyDiffResult, ApplyDiffSuccess, Telepathic, TelepathicDiff,
    },
    StorageBackend,
};

/// Header of a log.
///
/// The log's [`LogID`] is the hash of this header (see `LogState::new`), so a received header
/// can be checked against a known ID simply by re-hashing it.
///
/// NOTE(review): because the ID is derived from the serialized header, the field names and
/// order are presumably part of the ID format — avoid renaming or reordering them.
#[derive(Clone, Serialize, Deserialize)]
pub struct LogHeader {
    /// Public ID of the signer whose signature every appended hash must carry.
    pub append_signer: SignerID,
    /// Optional application-defined metadata, stored as a `litl` value.
    pub meta: Option<litl::Val>,
}

impl_debug_as_litl!(LogHeader);

/// In-memory state of a single append-only log.
#[derive(Debug)]
pub struct LogState {
    /// ID of the log (the hash of its header).
    pub id: LogID,
    /// The header, or `None` until it has been received (see `new_empty`).
    pub header: Option<LogHeader>,
    /// All log bytes known so far, in append order.
    pub data: Vec<u8>,
    /// Signed hash covering all of `data`; `None` while no append has been applied.
    pub last_hash: Option<Signed<StrongHash>>,
    /// Running hasher fed with exactly the bytes in `data`. It is cloned to compute the hash
    /// of a prospective append without re-hashing the whole log.
    hasher: StrongHasher,
}

impl LogState {
    /// Creates a placeholder state for a log that is known only by its ID.
    ///
    /// Header, data and hash are filled in later by applying diffs.
    pub(crate) fn new_empty(id: LogID) -> Self {
        Self {
            id,
            header: None,
            data: Vec::new(),
            last_hash: None,
            hasher: StrongHasher::default(),
        }
    }

    /// Creates a brand-new, empty log with a freshly generated append signer.
    ///
    /// The log ID is the hash of a header that contains the signer's public ID and the
    /// optional `meta` value. The function returns the new state together with the write
    /// access that holds the signer secret.
    ///
    /// # Panics
    /// Panics if `meta` cannot be converted into a `litl::Val`.
    pub(crate) fn new<M: Serialize>(meta: Option<M>) -> (Self, LogWriteAccess) {
        let secret = SignerSecret::new_random();

        let append_signer = secret.pub_id();
        let meta = meta.map(|meta| litl::to_val(&meta).unwrap());
        let header = LogHeader {
            append_signer,
            meta,
        };
        let id = LogID(HashOf::hash(&header));

        let state = Self {
            header: Some(header),
            ..Self::new_empty(id)
        };

        (state, LogWriteAccess(secret))
    }

    /// Builds the diff that appends `append` to the current end of this log.
    ///
    /// The diff carries the hash of the whole log after the append. That hash is signed with
    /// `write_access`. The state itself is not modified.
    pub(crate) fn diff_for_new_append(
        &self,
        append: &[u8],
        write_access: &LogWriteAccess,
    ) -> LogDiff {
        let signed_hash = {
            let mut extended = self.hasher.clone();
            extended.update(append);
            write_access.0.sign(extended.finalize())
        };

        LogDiff {
            id: self.id,
            header: None,
            after: self.data.len(),
            append: append.to_vec(),
            new_hash: Some(signed_hash),
        }
    }
}

/// Identifies a log by the hash of its [`LogHeader`].
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct LogID(pub HashOf<LogHeader>);

impl NestedTaggedData for LogID {
    const TAG: &'static str = "log";
    type Inner = HashOf<LogHeader>;

    /// Borrows the wrapped header hash.
    fn as_inner(&self) -> &Self::Inner {
        let LogID(header_hash) = self;
        header_hash
    }

    /// Wraps a header hash as a log ID.
    fn from_inner(inner: Self::Inner) -> Self
    where
        Self: Sized,
    {
        Self(inner)
    }
}

impl_nested_tagged_data_serde!(LogID);
impl_debug_as_litl!(LogID);

/// Write capability for a log: the secret key of its append signer.
///
/// NOTE(review): if `impl_debug_as_litl!` renders the serialized form, `{:?}` on this value
/// would print the signer secret. Avoid logging it, and confirm against the macro.
pub struct LogWriteAccess(pub(crate) SignerSecret);

impl NestedTaggedData for LogWriteAccess {
    const TAG: &'static str = "logWriteAccess";
    type Inner = SignerSecret;

    /// Borrows the wrapped signer secret.
    fn as_inner(&self) -> &Self::Inner {
        let Self(secret) = self;
        secret
    }

    /// Wraps a signer secret as write access.
    fn from_inner(inner: Self::Inner) -> Self
    where
        Self: Sized,
    {
        Self(inner)
    }
}

impl_nested_tagged_data_serde!(LogWriteAccess);
impl_debug_as_litl!(LogWriteAccess);

/// A change to a log: optionally its header, plus bytes appended at a given offset.
#[derive(Clone, Serialize, Deserialize)]
pub struct LogDiff {
    /// The log this diff applies to.
    pub id: LogID,
    /// The log header, included when the receiver may not have it yet.
    pub header: Option<LogHeader>,
    /// Offset into the log's data at which `append` starts.
    pub after: usize,
    /// Bytes to append, serialized via `litl::raw_data_serde`.
    #[serde(with = "litl::raw_data_serde")]
    pub append: Vec<u8>,
    /// Signed hash of the log data up to `after + append.len()`. It is `None` in the
    /// header-only diffs this module creates.
    pub new_hash: Option<Signed<StrongHash>>,
}

impl TelepathicDiff for LogDiff {
    type ID = LogID;

    /// The ID of the log this diff targets.
    fn id(&self) -> Self::ID {
        let Self { id, .. } = self;
        *id
    }
}

/// How much of a log a replica holds: the number of data bytes it knows.
pub type LogStateInfo = usize;

/// Persisted record of the latest signed hash and the number of data bytes it covers.
///
/// It is stored under the `hash_and_len_<id>` key. On load, only the first `len` bytes of
/// the stored data stream are used.
#[derive(Serialize, Deserialize)]
pub struct StorageHashAndLen {
    /// Latest signed hash of the log data.
    pub hash: Signed<StrongHash>,
    /// Number of data bytes covered by `hash`.
    pub len: usize,
}

impl Telepathic for LogState {
    type ID = LogID;
    type WriteAccess = LogWriteAccess;
    type StateInfo = LogStateInfo;
    type Diff = LogDiff;
    type Error = LogError;

    fn id(&self) -> Self::ID {
        self.id
    }

    /// Applies `diff` to this log.
    ///
    /// On success it returns the new data length and the diff that actually took effect,
    /// which is suitable for storing and forwarding.
    ///
    /// A header is accepted only if it hashes to this log's ID. It is committed to
    /// `self.header` only once the whole diff has been applied successfully. This way the
    /// header always travels inside an effective diff and therefore reaches `store`.
    ///
    /// An append must carry a new hash signed by the header's append signer. That hash must
    /// equal the hash of the current data extended by the part of `append` we don't have yet.
    /// A diff that ends before our data does is stale. It is treated as a no-op instead of
    /// being sliced out of bounds.
    ///
    /// # Errors
    /// - `InvalidKnownStateAssumption` if:
    ///   - no header is known or supplied;
    ///   - `diff.after` lies beyond our data;
    ///   - a header-only diff arrives while data already exists.
    /// - `LogError::InvalidHeaderHash` if a supplied header does not hash to the log ID.
    /// - `LogError::InvalidHash` if an append carries no hash, or its hash does not match.
    /// - `LogError::InvalidSignature` if the hash is not signed by the append signer.
    fn try_apply_diff(
        &mut self,
        diff: LogDiff,
    ) -> ApplyDiffResult<Self::StateInfo, Self::ID, Self::Diff, Self::Error> {
        debug_assert_eq!(diff.id, self.id);

        // `Some` only when this diff introduces the header for the first time.
        let new_header = match (&self.header, &diff.header) {
            (None, None) => return Err(ApplyDiffErrorFor::InvalidKnownStateAssumption),
            (None, Some(diff_header)) => {
                if HashOf::hash(diff_header) != self.id.0 {
                    return Err(LogError::InvalidHeaderHash.into());
                }
                Some(diff_header.clone())
            }
            (Some(_), _) => None,
        };

        // Header-only diff: only valid while we have no data yet.
        if diff.after == 0 && diff.append.is_empty() {
            if !self.data.is_empty() {
                return Err(ApplyDiffErrorFor::InvalidKnownStateAssumption);
            }
            if new_header.is_some() {
                self.header = new_header.clone();
            }
            return Ok(ApplyDiffSuccess {
                new_state_info: 0,
                effective_diff: LogDiff {
                    id: diff.id,
                    header: new_header,
                    after: 0,
                    append: vec![],
                    new_hash: None,
                },
            });
        }

        // Diffs come from peers or storage, so a missing hash is an error, not a panic.
        let new_hash = match diff.new_hash.as_ref() {
            Some(new_hash) => new_hash,
            None => return Err(LogError::InvalidHash.into()),
        };

        let header = new_header
            .as_ref()
            .or(self.header.as_ref())
            .expect("header is either already known or was supplied by this diff");
        new_hash
            .ensure_signed_by(&header.append_signer)
            .map_err(LogError::InvalidSignature)?;

        if diff.after > self.data.len() {
            return Err(ApplyDiffErrorFor::InvalidKnownStateAssumption);
        }

        let old_len = self.data.len();
        let overlap = old_len - diff.after;

        if overlap > diff.append.len() {
            // The diff ends before our data does: everything in it is already known. Its
            // hash refers to a shorter prefix, so it can neither be checked against nor
            // replace `last_hash`. Report a no-op that restates our current state.
            info!(overlap = overlap, "Received stale log update");
            if new_header.is_some() {
                self.header = new_header.clone();
            }
            return Ok(ApplyDiffSuccess {
                new_state_info: old_len,
                effective_diff: LogDiff {
                    id: diff.id,
                    header: new_header,
                    after: old_len,
                    append: vec![],
                    new_hash: self.last_hash.clone(),
                },
            });
        }

        // Bytes before `overlap` must already match our data. A mismatch there makes this
        // hash differ from `new_hash`, so the check below covers them too.
        let mut new_hasher = self.hasher.clone();
        new_hasher.update(&diff.append[overlap..]);

        if new_hasher.finalize() != new_hash.attested {
            return Err(ApplyDiffErrorFor::Other(LogError::InvalidHash));
        }

        self.data.extend_from_slice(&diff.append[overlap..]);
        self.last_hash = Some(new_hash.clone());
        self.hasher = new_hasher;
        if new_header.is_some() {
            self.header = new_header.clone();
        }

        // Only the part of `append` past our old data actually took effect.
        let mut append = diff.append;
        if overlap > 0 {
            info!(overlap = overlap, "Received redundant log update");
            append.drain(..overlap);
        }

        Ok(ApplyDiffSuccess {
            new_state_info: self.data.len(),
            effective_diff: LogDiff {
                id: diff.id,
                header: new_header,
                after: old_len,
                append,
                new_hash: diff.new_hash,
            },
        })
    }

    /// Returns the diff that brings a peer up to date from `state_info`, or `None` if the
    /// peer already has everything.
    ///
    /// A `state_info` of `None` means the peer does not have the header yet. In that case
    /// the header is included and all data is sent from offset 0.
    fn diff_since(&self, state_info: Option<&Self::StateInfo>) -> Option<Self::Diff> {
        let (has_header, state_len) = match state_info {
            None => (false, 0),
            Some(&len) => (true, len),
        };

        if has_header && state_len >= self.data.len() {
            return None;
        }

        Some(LogDiff {
            id: self.id,
            header: if has_header {
                None
            } else {
                self.header.clone()
            },
            after: state_len,
            append: self.data[state_len..].to_vec(),
            new_hash: self.last_hash.clone(),
        })
    }

    /// The number of known data bytes, or `None` while the header is unknown.
    fn state_info(&self) -> Option<Self::StateInfo> {
        self.header.as_ref().map(|_| self.data.len())
    }

    /// Reconstructs a log from storage as a stream of diffs.
    ///
    /// The stream yields the header diff first, if a header is stored and decodes. It then
    /// yields a single diff with the stored data, truncated to the length recorded next to
    /// the latest hash. Decode failures are logged and skipped.
    fn load(
        id: LogID,
        storage: Box<dyn StorageBackend>,
    ) -> std::pin::Pin<Box<dyn futures::Stream<Item = Self::Diff>>> {
        let key = id.to_string();

        let header_stream = {
            let key = key.clone();
            stream::once(storage.get_key(&format!("header_{}", key))).filter_map(
                move |maybe_bytes| {
                    future::ready(maybe_bytes.and_then(|bytes| {
                        litl::from_slice(&bytes)
                            .map_err(|err| {
                                error!(err = ?err, id = ?id, "Failed to read loaded log header");
                                err
                            })
                            .ok()
                            .map(|header| LogDiff {
                                id,
                                header: Some(header),
                                after: 0,
                                append: vec![],
                                new_hash: None,
                            })
                    }))
                },
            )
        };

        let all_data_stream = {
            stream::once(storage.get_key(&format!("hash_and_len_{}", key)))
                .filter_map(move |maybe_hash_and_len_bytes| {
                    future::ready(maybe_hash_and_len_bytes.and_then(|hash_and_len_bytes| {
                        litl::from_slice::<StorageHashAndLen>(&hash_and_len_bytes).map_err(|err| {
                            error!(err = ?err, id = ?id, "Failed to read loaded log hash and len");
                            err
                        }).ok()
                    }))
                })
                .filter_map(move |hash_and_len| {
                    let storage = storage.clone_ref();
                    let key = key.clone();
                    async move {
                        let all_data = storage
                            .get_stream(&key)
                            .collect::<Vec<_>>()
                            .await
                            .into_iter()
                            .flatten()
                            .collect::<Vec<u8>>();

                        if all_data.len() < hash_and_len.len {
                            error!(id = ?id, "Storage backend returned less data than last hash");
                            None
                        } else {
                            Some(LogDiff {
                                id,
                                header: None,
                                after: 0,
                                append: all_data[0..hash_and_len.len].to_vec(),
                                new_hash: Some(hash_and_len.hash),
                            })
                        }
                    }
                })
        };

        header_stream.chain(all_data_stream).boxed_local()
    }

    /// Persists an effective diff.
    ///
    /// The header, if present, is written under `header_<id>`. If the diff carries a hash,
    /// the appended bytes go to the data stream at offset `after`, and the hash/length
    /// record is written afterwards.
    fn store(
        effective_diff: LogDiff,
        storage: Box<dyn StorageBackend>,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()>>> {
        let key = effective_diff.id.to_string();
        async move {
            if let Some(header) = &effective_diff.header {
                storage
                    .set_key(&format!("header_{}", key), litl::to_vec(header).unwrap())
                    .await;
            }

            if let Some(new_hash) = &effective_diff.new_hash {
                let new_len = effective_diff.after + effective_diff.append.len();
                storage
                    .append_to_stream(
                        &key,
                        effective_diff.append.clone(),
                        Some(effective_diff.after),
                    )
                    .await;
                storage
                    .set_key(
                        &format!("hash_and_len_{}", key),
                        litl::to_vec(&StorageHashAndLen {
                            hash: new_hash.clone(),
                            len: new_len,
                        })
                        .unwrap(),
                    )
                    .await;
            }
        }
        .boxed_local()
    }
}

/// Errors that can occur when applying a diff to a log.
#[derive(Error, Debug)]
pub enum LogError {
    /// A received header does not hash to the log's ID.
    #[error("Invalid header hash")]
    InvalidHeaderHash,
    /// The signed hash carried by an append does not match the hash of the resulting data.
    #[error("Invalid hash after append")]
    InvalidHash,
    /// The appended hash is not validly signed by the log's append signer.
    #[error(transparent)]
    InvalidSignature(#[from] SignatureError),
    /// Appends arrived out of order.
    /// NOTE(review): this variant is not constructed in the code shown in this file.
    #[error("Append out of order")]
    AppendOutOfOrder,
}