#![expect(missing_docs)]
use std::any::Any;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::try_join_all;
use itertools::Itertools as _;
use pollster::FutureExt as _;
use thiserror::Error;
use crate::dag_walk;
use crate::op_store::OpStore;
use crate::op_store::OpStoreError;
use crate::op_store::OperationId;
use crate::operation::Operation;
/// Error returned when reading from or writing to an [`OpHeadsStore`]
/// backend fails. Each variant wraps the backend-specific cause as a boxed
/// error object.
#[derive(Debug, Error)]
pub enum OpHeadsStoreError {
    /// Reading the current set of operation heads failed.
    #[error("Failed to read operation heads")]
    Read(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// Recording `new_op_id` as an operation head failed.
    #[error("Failed to record operation head {new_op_id}")]
    Write {
        // Head id we were trying to record; included in the error message.
        new_op_id: OperationId,
        // Underlying backend error (picked up as `#[source]` by field name).
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// Acquiring the store lock failed.
    #[error("Failed to lock operation heads store")]
    Lock(#[source] Box<dyn std::error::Error + Send + Sync>),
}
/// Error returned by [`resolve_op_heads()`] when resolution is impossible
/// because the operation log contains no heads at all.
#[derive(Debug, Error)]
pub enum OpHeadResolutionError {
    /// The store reported an empty set of heads even after locking.
    #[error("Operation log has no heads")]
    NoHeads,
}
pub trait OpHeadsStoreLock {}
/// Backend interface for storing and updating the set of operation heads
/// (the leaves of the operation DAG).
///
/// `Any` in the supertrait list enables runtime downcasting via
/// [`<dyn OpHeadsStore>::downcast_ref()`].
#[async_trait]
pub trait OpHeadsStore: Any + Send + Sync + Debug {
    /// Human-readable name identifying the backend implementation.
    fn name(&self) -> &str;

    /// Records `new_id` as an operation head and removes the heads listed in
    /// `old_ids` (typically ancestors of, or heads replaced by, `new_id`).
    async fn update_op_heads(
        &self,
        old_ids: &[OperationId],
        new_id: &OperationId,
    ) -> Result<(), OpHeadsStoreError>;

    /// Returns the ids of the current operation heads.
    async fn get_op_heads(&self) -> Result<Vec<OperationId>, OpHeadsStoreError>;

    /// Acquires an exclusive lock on the store, returning a guard tied to the
    /// store's lifetime. Used to make read-filter-update sequences on the
    /// head set atomic with respect to other writers.
    async fn lock(&self) -> Result<Box<dyn OpHeadsStoreLock + '_>, OpHeadsStoreError>;
}
impl dyn OpHeadsStore {
    /// Attempts to downcast this trait object to the concrete store type `T`.
    ///
    /// Returns `None` when the underlying implementation is not a `T`.
    pub fn downcast_ref<T: OpHeadsStore>(&self) -> Option<&T> {
        // Go through `dyn Any` (a supertrait of `OpHeadsStore`) to reach the
        // standard-library downcasting machinery.
        let as_any: &dyn Any = self;
        as_any.downcast_ref::<T>()
    }
}
/// Resolves the current operation heads to a single [`Operation`].
///
/// If there is exactly one head, it is returned directly. Otherwise the store
/// is locked, heads that are ancestors of other heads are pruned, and any
/// remaining divergent heads are merged by calling `resolver`. The resulting
/// operation is then recorded in the store as the new head.
///
/// # Errors
///
/// Returns [`OpHeadResolutionError::NoHeads`] (converted into `E`) if the
/// store reports no heads while locked, and propagates
/// [`OpHeadsStoreError`]/[`OpStoreError`] as well as any error produced by
/// `resolver`.
pub async fn resolve_op_heads<E>(
    op_heads_store: &dyn OpHeadsStore,
    op_store: &Arc<dyn OpStore>,
    resolver: impl AsyncFnOnce(Vec<Operation>) -> Result<Operation, E>,
) -> Result<Operation, E>
where
    E: From<OpHeadResolutionError> + From<OpHeadsStoreError> + From<OpStoreError>,
{
    // Fast path: a single head can be returned without taking the lock.
    let mut op_heads = op_heads_store.get_op_heads().await?;
    if op_heads.len() == 1 {
        let operation_id = op_heads.pop().unwrap();
        let operation = op_store.read_operation(&operation_id).await?;
        return Ok(Operation::new(op_store.clone(), operation_id, operation));
    }

    // Slow path: take the lock and re-read the heads so we see a snapshot
    // that cannot change while we rewrite the head set below. Another writer
    // may have already resolved the divergence in the meantime, hence the
    // re-checks. The guard is held until the end of the function.
    let _lock = op_heads_store.lock().await?;
    let op_head_ids = op_heads_store.get_op_heads().await?;

    if op_head_ids.is_empty() {
        return Err(OpHeadResolutionError::NoHeads.into());
    }

    if op_head_ids.len() == 1 {
        // Someone else resolved the heads while we waited for the lock.
        let op_head_id = op_head_ids[0].clone();
        let op_head = op_store.read_operation(&op_head_id).await?;
        return Ok(Operation::new(op_store.clone(), op_head_id, op_head));
    }

    // Load all head operations concurrently; fail fast on the first error.
    let op_heads: Vec<_> = try_join_all(op_head_ids.iter().map(
        async |op_id: &OperationId| -> Result<Operation, OpStoreError> {
            let data = op_store.read_operation(op_id).await?;
            Ok(Operation::new(op_store.clone(), op_id.clone(), data))
        },
    ))
    .await?;

    // Prune heads that are ancestors of other heads: `heads_ok` keeps only
    // the DAG leaves among the candidates. Track which ids were dropped so
    // the stale head records can be removed from the store as well.
    let op_head_ids_before: HashSet<_> = op_heads.iter().map(|op| op.id().clone()).collect();
    let filtered_op_heads = dag_walk::heads_ok(
        op_heads.into_iter().map(Ok),
        |op: &Operation| op.id().clone(),
        // The dag-walk callbacks are synchronous, so block on the async
        // parent lookup here (pollster's `block_on`).
        |op: &Operation| match op.parents().block_on() {
            Ok(parents) => parents.into_iter().map(Ok).collect_vec(),
            Err(err) => vec![Err(err)],
        },
    )?;
    let op_head_ids_after: HashSet<_> =
        filtered_op_heads.iter().map(|op| op.id().clone()).collect();
    let ancestor_op_heads = op_head_ids_before
        .difference(&op_head_ids_after)
        .cloned()
        .collect_vec();
    let mut op_heads = filtered_op_heads.into_iter().collect_vec();

    // If pruning left a single head, there is nothing to merge: just remove
    // the ancestor heads from the store and return it.
    if let [op_head] = &*op_heads {
        op_heads_store
            .update_op_heads(&ancestor_op_heads, op_head.id())
            .await?;
        return Ok(op_head.clone());
    }

    // Still divergent: hand the heads to the resolver in a stable order
    // (sorted by the operation's recorded end timestamp).
    op_heads.sort_by_key(|op| op.metadata().time.end.timestamp);
    let new_op = resolver(op_heads).await?;

    // Record the merge result as the new head. Both the pruned ancestor
    // heads and the merged heads (now the new op's parents) are retired.
    let mut old_op_heads = ancestor_op_heads;
    old_op_heads.extend_from_slice(new_op.parent_ids());
    op_heads_store
        .update_op_heads(&old_op_heads, new_op.id())
        .await?;
    Ok(new_op)
}