use std::collections::HashMap;
use std::sync::RwLock;
use std::time::SystemTime;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::checkpoint::NodeId;
use crate::checkpoint::{CheckpointId, CheckpointStoreError, TraceId};
/// One recorded mutation belonging to a trace, as stored by a [`MutationLogStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MutationLogEntry {
    /// Trace this mutation belongs to; the in-memory store indexes entries by it.
    pub trace_id: TraceId,
    /// Step at which the mutation was recorded; the in-memory store's `replay`
    /// and `truncate` filter on this value.
    pub step: usize,
    /// Node associated with the mutation (presumably the node that emitted it — TODO confirm).
    pub node_id: NodeId,
    /// Checkpoint this mutation is associated with, if any.
    pub checkpoint_id: Option<CheckpointId>,
    /// Position of the mutation within the batch it came from;
    /// [`mutations_to_log_entries`] assigns it from the input's enumeration order.
    pub mutation_index: usize,
    /// The mutation itself, serialized to a JSON value.
    pub mutation: serde_json::Value,
    /// Wall-clock creation time; [`MutationLogEntry::new`] sets it to `SystemTime::now()`.
    pub timestamp: SystemTime,
}
impl MutationLogEntry {
    /// Builds an entry from its parts and stamps it with the current wall-clock time.
    ///
    /// `mutation` is expected to already be serialized to JSON; see
    /// [`mutations_to_log_entries`] for the batch helper that does this.
    pub fn new(
        trace_id: TraceId,
        step: usize,
        node_id: NodeId,
        checkpoint_id: Option<CheckpointId>,
        mutation_index: usize,
        mutation: serde_json::Value,
    ) -> Self {
        let timestamp = SystemTime::now();
        Self {
            trace_id,
            step,
            node_id,
            checkpoint_id,
            mutation_index,
            mutation,
            timestamp,
        }
    }
}
/// Storage backend for [`MutationLogEntry`] records, grouped by trace.
#[async_trait]
pub trait MutationLogStore: Send + Sync {
    /// Appends a single entry to the log.
    async fn append(&self, entry: MutationLogEntry) -> Result<(), CheckpointStoreError>;

    /// Appends every entry of `entries`, in order, via [`MutationLogStore::append`].
    ///
    /// The default implementation is not atomic. It stops at the first error,
    /// and any entries appended before that failure stay in the store.
    /// Backends that can write a batch in one operation may override it.
    async fn append_batch(
        &self,
        entries: Vec<MutationLogEntry>,
    ) -> Result<(), CheckpointStoreError> {
        for record in entries {
            // `?` short-circuits on the first failing append.
            self.append(record).await?;
        }
        Ok(())
    }

    /// Returns the entries recorded for `trace_id` whose step is at least `from_step`.
    async fn replay(
        &self,
        trace_id: &TraceId,
        from_step: usize,
    ) -> Result<Vec<MutationLogEntry>, CheckpointStoreError>;

    /// Intended to discard entries of `trace_id` whose step is below
    /// `keep_from_step`, keeping those at or after it. Returns the number of
    /// entries removed.
    async fn truncate(
        &self,
        trace_id: &TraceId,
        keep_from_step: usize,
    ) -> Result<usize, CheckpointStoreError>;
}
/// In-memory [`MutationLogStore`]: one vector of all entries plus a per-trace
/// index into it, each behind a `std::sync::RwLock`.
///
/// Nothing is persisted; the contents live only as long as this value.
#[derive(Default)]
pub struct InMemoryMutationLog {
    /// All entries, in the order they were appended.
    entries: RwLock<Vec<MutationLogEntry>>,
    /// For each trace, the positions in `entries` of that trace's entries, in append order.
    index: RwLock<HashMap<TraceId, Vec<usize>>>,
}
impl InMemoryMutationLog {
    /// Creates an empty in-memory log.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of entries currently held, across all traces.
    ///
    /// # Panics
    /// Panics if the `entries` lock has been poisoned.
    pub fn len(&self) -> usize {
        let guard = self.entries.read().unwrap();
        guard.len()
    }

    /// Returns `true` when the log holds no entries at all.
    ///
    /// # Panics
    /// Panics if the `entries` lock has been poisoned.
    pub fn is_empty(&self) -> bool {
        self.entries.read().unwrap().is_empty()
    }
}
/// Maps a poisoned-lock error onto the store's storage error variant.
fn lock_poisoned<E: std::fmt::Display>(err: E) -> CheckpointStoreError {
    CheckpointStoreError::Storage(err.to_string())
}

// Lock discipline: every method that touches both `entries` and `index`
// acquires `entries` first and keeps it held while it reads or writes `index`.
// Positions in `index` are only meaningful for a specific state of `entries`,
// and `truncate` shifts them. Touching one lock without the other held could
// therefore pair a trace with another trace's entry.
#[async_trait]
impl MutationLogStore for InMemoryMutationLog {
    /// Appends `entry` to the log and records its position under its trace.
    ///
    /// # Errors
    /// Returns `CheckpointStoreError::Storage` if either lock is poisoned; in
    /// that case nothing is written.
    async fn append(&self, entry: MutationLogEntry) -> Result<(), CheckpointStoreError> {
        let mut entries = self.entries.write().map_err(lock_poisoned)?;
        let mut index_map = self.index.write().map_err(lock_poisoned)?;
        // The new entry's position is the current length; record it before the push.
        index_map
            .entry(entry.trace_id)
            .or_default()
            .push(entries.len());
        entries.push(entry);
        Ok(())
    }

    /// Returns clones of `trace_id`'s entries with `step >= from_step`, in append order.
    ///
    /// An unknown trace yields an empty vector.
    ///
    /// # Errors
    /// Returns `CheckpointStoreError::Storage` if either lock is poisoned.
    async fn replay(
        &self,
        trace_id: &TraceId,
        from_step: usize,
    ) -> Result<Vec<MutationLogEntry>, CheckpointStoreError> {
        let entries = self.entries.read().map_err(lock_poisoned)?;
        let index_map = self.index.read().map_err(lock_poisoned)?;
        let replayed: Vec<MutationLogEntry> = index_map
            .get(trace_id)
            .map(|positions| {
                positions
                    .iter()
                    // `get` keeps this panic-free even if the index were ever stale.
                    .filter_map(|&pos| entries.get(pos))
                    .filter(|entry| entry.step >= from_step)
                    .cloned()
                    .collect()
            })
            .unwrap_or_default();
        Ok(replayed)
    }

    /// Removes `trace_id`'s entries whose `step` is below `keep_from_step`.
    /// Returns how many were removed.
    ///
    /// Entries of other traces and entries at or after `keep_from_step` are kept
    /// in their original relative order. When anything is removed, the per-trace
    /// index is rebuilt, which is O(total entries).
    ///
    /// # Errors
    /// Returns `CheckpointStoreError::Storage` if either lock is poisoned; in
    /// that case nothing is removed.
    async fn truncate(
        &self,
        trace_id: &TraceId,
        keep_from_step: usize,
    ) -> Result<usize, CheckpointStoreError> {
        let mut entries = self.entries.write().map_err(lock_poisoned)?;
        let mut index_map = self.index.write().map_err(lock_poisoned)?;

        let before = entries.len();
        entries.retain(|entry| entry.trace_id != *trace_id || entry.step >= keep_from_step);
        let removed = before - entries.len();

        if removed > 0 {
            // Removing elements shifted positions for every later entry, so
            // rebuild the whole index from the compacted vector.
            index_map.clear();
            for (pos, entry) in entries.iter().enumerate() {
                index_map.entry(entry.trace_id).or_default().push(pos);
            }
        }
        Ok(removed)
    }
}
/// Converts a batch of mutations that share one trace, step, node and
/// checkpoint into log entries.
///
/// Each mutation is serialized with `serde_json::to_value`, and its position in
/// `mutations` becomes the entry's `mutation_index`. Mutations that fail to
/// serialize are silently skipped, so the resulting `mutation_index` values may
/// contain gaps. Every entry is timestamped individually by
/// [`MutationLogEntry::new`].
pub fn mutations_to_log_entries<E: Serialize>(
    trace_id: TraceId,
    step: usize,
    node_id: NodeId,
    checkpoint_id: Option<CheckpointId>,
    mutations: impl IntoIterator<Item = E>,
) -> Vec<MutationLogEntry> {
    mutations
        .into_iter()
        .enumerate()
        .filter_map(|(position, mutation)| {
            let value = serde_json::to_value(&mutation).ok()?;
            Some(MutationLogEntry::new(
                trace_id,
                step,
                node_id.clone(),
                checkpoint_id.clone(),
                position,
                value,
            ))
        })
        .collect()
}