1#![expect(missing_docs)]
16
17use std::any::Any;
18use std::collections::HashSet;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use futures::future::try_join_all;
24use itertools::Itertools as _;
25use thiserror::Error;
26
27use crate::dag_walk_async;
28use crate::op_store::OpStore;
29use crate::op_store::OpStoreError;
30use crate::op_store::OperationId;
31use crate::operation::Operation;
32
33#[derive(Debug, Error)]
34pub enum OpHeadsStoreError {
35 #[error("Failed to read operation heads")]
36 Read(#[source] Box<dyn std::error::Error + Send + Sync>),
37 #[error("Failed to record operation head {new_op_id}")]
38 Write {
39 new_op_id: OperationId,
40 source: Box<dyn std::error::Error + Send + Sync>,
41 },
42 #[error("Failed to lock operation heads store")]
43 Lock(#[source] Box<dyn std::error::Error + Send + Sync>),
44}
45
46#[derive(Debug, Error)]
47pub enum OpHeadResolutionError {
48 #[error("Operation log has no heads")]
49 NoHeads,
50}
51
/// Guard value returned by [`OpHeadsStore::lock`].
///
/// The trait declares no methods; callers such as [`resolve_op_heads`] only keep
/// the value alive for as long as they need the lock. NOTE(review): whether
/// dropping it releases an underlying lock is up to each implementation — confirm
/// per backend.
pub trait OpHeadsStoreLock {}
53
54#[async_trait]
56pub trait OpHeadsStore: Any + Send + Sync + Debug {
57 fn name(&self) -> &str;
58
59 async fn update_op_heads(
63 &self,
64 old_ids: &[OperationId],
65 new_id: &OperationId,
66 ) -> Result<(), OpHeadsStoreError>;
67
68 async fn get_op_heads(&self) -> Result<Vec<OperationId>, OpHeadsStoreError>;
69
70 async fn lock(&self) -> Result<Box<dyn OpHeadsStoreLock + '_>, OpHeadsStoreError>;
75}
76
77impl dyn OpHeadsStore {
78 pub fn downcast_ref<T: OpHeadsStore>(&self) -> Option<&T> {
80 (self as &dyn Any).downcast_ref()
81 }
82}
83
84pub async fn resolve_op_heads<E>(
89 op_heads_store: &dyn OpHeadsStore,
90 op_store: &Arc<dyn OpStore>,
91 resolver: impl AsyncFnOnce(Vec<Operation>) -> Result<Operation, E>,
92) -> Result<Operation, E>
93where
94 E: From<OpHeadResolutionError> + From<OpHeadsStoreError> + From<OpStoreError>,
95{
96 let mut op_heads = op_heads_store.get_op_heads().await?;
100
101 if op_heads.len() == 1 {
102 let operation_id = op_heads.pop().unwrap();
103 let operation = op_store.read_operation(&operation_id).await?;
104 return Ok(Operation::new(op_store.clone(), operation_id, operation));
105 }
106
107 let _lock = op_heads_store.lock().await?;
116 let op_head_ids = op_heads_store.get_op_heads().await?;
117
118 if op_head_ids.is_empty() {
119 return Err(OpHeadResolutionError::NoHeads.into());
120 }
121
122 if op_head_ids.len() == 1 {
123 let op_head_id = op_head_ids[0].clone();
124 let op_head = op_store.read_operation(&op_head_id).await?;
125 return Ok(Operation::new(op_store.clone(), op_head_id, op_head));
126 }
127
128 let op_heads: Vec<_> = try_join_all(op_head_ids.iter().map(
129 async |op_id: &OperationId| -> Result<Operation, OpStoreError> {
130 let data = op_store.read_operation(op_id).await?;
131 Ok(Operation::new(op_store.clone(), op_id.clone(), data))
132 },
133 ))
134 .await?;
135 let op_head_ids_before: HashSet<_> = op_heads.iter().map(|op| op.id().clone()).collect();
138 let filtered_op_heads = dag_walk_async::heads(
139 op_heads,
140 |op: &Operation| op.id().clone(),
141 async |op: &Operation| op.parents().await,
142 )
143 .await?;
144 let op_head_ids_after: HashSet<_> =
145 filtered_op_heads.iter().map(|op| op.id().clone()).collect();
146 let ancestor_op_heads = op_head_ids_before
147 .difference(&op_head_ids_after)
148 .cloned()
149 .collect_vec();
150 let mut op_heads = filtered_op_heads.into_iter().collect_vec();
151
152 if let [op_head] = &*op_heads {
154 op_heads_store
155 .update_op_heads(&ancestor_op_heads, op_head.id())
156 .await?;
157 return Ok(op_head.clone());
158 }
159
160 op_heads.sort_by_key(|op| op.metadata().time.end.timestamp);
161 let new_op = resolver(op_heads).await?;
162 let mut old_op_heads = ancestor_op_heads;
163 old_op_heads.extend_from_slice(new_op.parent_ids());
164 op_heads_store
165 .update_op_heads(&old_op_heads, new_op.id())
166 .await?;
167 Ok(new_op)
168}