use super::store::VersionStore;
use super::types::*;
use crate::error::{Error, Result};
use crate::history::now_ms;
use crate::versioned_messages::{
MessageSerial, validate_replay_continuity, validate_version_chain,
};
use async_trait::async_trait;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// In-memory [`VersionStore`] implementation.
///
/// All per-channel state lives in a single map guarded by an async `RwLock`.
/// The map is held behind an `Arc`, so cloning the store yields another handle
/// onto the same underlying data rather than an independent copy.
#[derive(Clone, Default)]
pub struct MemoryVersionStore {
// Keyed by the string produced by `MemoryVersionStore::channel_key`
// (app id and channel name joined with a NUL byte).
channels: Arc<RwLock<BTreeMap<String, MemoryVersionChannel>>>,
}
/// State kept for a single `(app_id, channel)` pair.
#[derive(Clone)]
struct MemoryVersionChannel {
// Identifier handed out with every reservation; generated once when the
// channel state is created (see the `Default` impl).
stream_id: String,
// Next delivery serial to hand out to a reservation; starts at 1.
next_delivery_serial: u64,
// Version records grouped by message serial (the serial's string form is the key).
messages: BTreeMap<String, Vec<StoredVersionRecord>>,
// Every appended record, keyed by its delivery serial.
replay: BTreeMap<u64, StoredVersionRecord>,
// `now_ms()` timestamp taken when the record with the given delivery serial
// was appended; consulted by `purge_before`.
created_at: BTreeMap<u64, i64>,
}
impl Default for MemoryVersionChannel {
fn default() -> Self {
Self {
stream_id: uuid::Uuid::new_v4().to_string(),
next_delivery_serial: 1,
messages: BTreeMap::new(),
replay: BTreeMap::new(),
created_at: BTreeMap::new(),
}
}
}
impl MemoryVersionStore {
    /// Creates an empty store. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds the map key for a channel: `app_id`, a NUL byte, then `channel`.
    fn channel_key(app_id: &str, channel: &str) -> String {
        let mut key = String::with_capacity(app_id.len() + 1 + channel.len());
        key.push_str(app_id);
        key.push('\0');
        key.push_str(channel);
        key
    }
}
#[async_trait]
impl VersionStore for MemoryVersionStore {
/// Reserves the next delivery serial for `channel` within `app_id`.
///
/// The channel state is created on first use. The returned reservation
/// carries the channel's stream id and the serial that was current before
/// the counter was advanced (the advance saturates at `u64::MAX`).
async fn reserve_delivery_position(
    &self,
    app_id: &str,
    channel: &str,
) -> Result<VersionWriteReservation> {
    let key = Self::channel_key(app_id, channel);
    let mut guard = self.channels.write().await;
    let state = guard.entry(key).or_default();

    let delivery_serial = state.next_delivery_serial;
    state.next_delivery_serial = delivery_serial.saturating_add(1);

    Ok(VersionWriteReservation {
        stream_id: state.stream_id.clone(),
        delivery_serial,
    })
}
/// Reserves a contiguous block of `block_size` delivery serials.
///
/// # Errors
/// Propagates whatever `VersionWriteReservationBlock::validate` returns for
/// an unacceptable `block_size`; in that case no state is touched.
async fn reserve_delivery_positions(
    &self,
    app_id: &str,
    channel: &str,
    block_size: u64,
) -> Result<VersionWriteReservationBlock> {
    VersionWriteReservationBlock::validate(block_size)?;

    let key = Self::channel_key(app_id, channel);
    let mut guard = self.channels.write().await;
    let state = guard.entry(key).or_default();

    // The block starts at the current counter; the counter then jumps past it
    // (saturating at `u64::MAX`).
    let start_delivery_serial = state.next_delivery_serial;
    state.next_delivery_serial = start_delivery_serial.saturating_add(block_size);

    Ok(VersionWriteReservationBlock {
        stream_id: state.stream_id.clone(),
        start_delivery_serial,
        len: block_size,
    })
}
/// Appends `record` to its message's version chain and to the channel's
/// replay log.
///
/// The channel state is created if it does not exist yet. On success the
/// channel's `next_delivery_serial` is raised so that it is strictly greater
/// than the record's delivery serial.
///
/// # Errors
/// * `Error::InvalidMessageFormat` if a record with the same delivery serial
///   is already present in the replay log.
/// * Whatever `validate_version_chain` returns when the existing chain plus
///   the new record is rejected. Nothing is stored in that case.
async fn append_version(&self, record: StoredVersionRecord) -> Result<()> {
    let key = Self::channel_key(&record.app_id, &record.channel);
    let delivery_serial = record.delivery_serial();
    let mut channels = self.channels.write().await;
    let channel_state = channels.entry(key).or_default();

    if let Some(existing) = channel_state.replay.get(&delivery_serial) {
        return Err(Error::InvalidMessageFormat(format!(
            "duplicate delivery_serial {} in version replay log for {}:{} (existing message_serial {}, incoming {})",
            delivery_serial,
            record.app_id,
            record.channel,
            existing.message_serial().as_str(),
            record.message_serial().as_str()
        )));
    }

    // Validate the would-be chain (existing entries followed by the new one)
    // before mutating anything. Only the messages are cloned for validation;
    // the stored records themselves are left in place.
    let candidate_messages: Vec<_> = channel_state
        .messages
        .get(record.message_serial().as_str())
        .into_iter()
        .flatten()
        .map(|entry| entry.message.clone())
        .chain(std::iter::once(record.message.clone()))
        .collect();
    validate_version_chain(&candidate_messages)?;

    channel_state
        .messages
        .entry(record.message_serial().as_str().to_string())
        .or_default()
        .push(record.clone());
    channel_state.created_at.insert(delivery_serial, now_ms());
    // `record` is owned and no longer needed afterwards, so move it in.
    channel_state.replay.insert(delivery_serial, record);
    channel_state.next_delivery_serial = channel_state
        .next_delivery_serial
        .max(delivery_serial.saturating_add(1));
    Ok(())
}
async fn get_latest(
&self,
app_id: &str,
channel: &str,
message_serial: &MessageSerial,
) -> Result<Option<StoredVersionRecord>> {
let key = Self::channel_key(app_id, channel);
let channels = self.channels.read().await;
let Some(channel_state) = channels.get(&key) else {
return Ok(None);
};
let Some(chain) = channel_state.messages.get(message_serial.as_str()) else {
return Ok(None);
};
let latest = chain
.iter()
.max_by(|left, right| left.version_serial().cmp(right.version_serial()))
.cloned()
.ok_or_else(|| Error::InvalidMessageFormat("version chain must not be empty".into()))?;
Ok(Some(latest))
}
/// Returns one page of the version chain for `request.message_serial`.
///
/// Records are ordered by version serial, ascending for `OldestFirst` and
/// descending for `NewestFirst`. When a cursor is supplied, only records
/// strictly beyond the cursor's version serial (in the requested direction)
/// are returned. At most `request.limit` records are returned; `has_more`
/// reports whether at least one further record exists, and in that case
/// `next_cursor` points at the last returned record.
///
/// An unknown channel or message yields an empty page.
///
/// # Errors
/// Propagates the error from `request.validate()`.
async fn get_versions(&self, request: VersionStoreReadRequest) -> Result<VersionStorePage> {
    request.validate()?;
    let key = Self::channel_key(&request.app_id, &request.channel);
    let channels = self.channels.read().await;
    let chain = channels
        .get(&key)
        .and_then(|state| state.messages.get(request.message_serial.as_str()));
    let Some(chain) = chain else {
        return Ok(VersionStorePage {
            items: Vec::new(),
            next_cursor: None,
            has_more: false,
        });
    };

    // Sort references rather than cloned records; only the records that end
    // up in the page are cloned below.
    let mut ordered: Vec<&StoredVersionRecord> = chain.iter().collect();
    ordered.sort_by(|left, right| left.version_serial().cmp(right.version_serial()));
    if matches!(request.direction, VersionStoreDirection::NewestFirst) {
        ordered.reverse();
    }

    let cursor = request.cursor.as_ref();
    // Fetch one record beyond the limit to learn whether another page exists.
    // `saturating_add` keeps `limit == usize::MAX` from overflowing.
    let mut window: Vec<&StoredVersionRecord> = ordered
        .into_iter()
        .filter(|item| {
            cursor.is_none_or(|cursor| match request.direction {
                VersionStoreDirection::NewestFirst => {
                    item.version_serial() < &cursor.version_serial
                }
                VersionStoreDirection::OldestFirst => {
                    item.version_serial() > &cursor.version_serial
                }
            })
        })
        .take(request.limit.saturating_add(1))
        .collect();

    let has_more = window.len() > request.limit;
    window.truncate(request.limit);
    let items: Vec<StoredVersionRecord> = window.into_iter().cloned().collect();

    let next_cursor = if has_more {
        items.last().map(|item| VersionStoreCursor {
            version: 1,
            version_serial: item.version_serial().clone(),
            direction: request.direction,
        })
    } else {
        None
    };
    Ok(VersionStorePage {
        items,
        next_cursor,
        has_more,
    })
}
/// Returns up to `request.limit` records from the replay log whose delivery
/// serial is strictly greater than `request.after_delivery_serial`, in
/// ascending delivery-serial order.
///
/// An unknown channel yields an empty vector.
///
/// # Errors
/// Propagates errors from `request.validate()` and from
/// `validate_replay_continuity`, which is run over the selected records.
async fn replay_after(
    &self,
    request: VersionReplayRequest,
) -> Result<Vec<StoredVersionRecord>> {
    request.validate()?;
    let key = Self::channel_key(&request.app_id, &request.channel);
    let channels = self.channels.read().await;
    let Some(channel_state) = channels.get(&key) else {
        return Ok(Vec::new());
    };

    // An exclusive lower bound expresses "strictly after" directly. Computing
    // `after + 1` with saturation would turn the bound inclusive when
    // `after == u64::MAX`.
    let bounds: (std::ops::Bound<u64>, std::ops::Bound<u64>) = (
        std::ops::Bound::Excluded(request.after_delivery_serial),
        std::ops::Bound::Unbounded,
    );
    let items: Vec<StoredVersionRecord> = channel_state
        .replay
        .range(bounds)
        .take(request.limit)
        .map(|(_, value)| value.clone())
        .collect();

    validate_replay_continuity(
        &items
            .iter()
            .map(|entry| entry.message.clone())
            .collect::<Vec<_>>(),
        request.after_delivery_serial,
    )?;
    Ok(items)
}
/// Returns, for every message in the channel, the record with the greatest
/// version serial, sorted by `StoredVersionRecord::history_serial`.
///
/// An unknown channel yields an empty vector; chains without records are
/// skipped.
async fn latest_by_history(
    &self,
    app_id: &str,
    channel: &str,
) -> Result<Vec<StoredVersionRecord>> {
    let key = Self::channel_key(app_id, channel);
    let guard = self.channels.read().await;
    let Some(state) = guard.get(&key) else {
        return Ok(Vec::new());
    };

    let mut heads: Vec<StoredVersionRecord> = Vec::new();
    for chain in state.messages.values() {
        let newest = chain
            .iter()
            .max_by(|left, right| left.version_serial().cmp(right.version_serial()));
        if let Some(record) = newest {
            heads.push(record.clone());
        }
    }
    heads.sort_by_key(StoredVersionRecord::history_serial);
    Ok(heads)
}
/// Reports the channel's stream id, the next delivery serial to be handed
/// out, and the smallest and largest delivery serials currently present in
/// the replay log.
///
/// An unknown channel yields `VersionStreamState::default()`.
async fn stream_state(&self, app_id: &str, channel: &str) -> Result<VersionStreamState> {
    let key = Self::channel_key(app_id, channel);
    let guard = self.channels.read().await;
    match guard.get(&key) {
        None => Ok(VersionStreamState::default()),
        Some(state) => {
            // The replay map is ordered by delivery serial, so its first and
            // last keys are the oldest and newest available serials.
            let oldest = state.replay.keys().next().copied();
            let newest = state.replay.keys().next_back().copied();
            Ok(VersionStreamState {
                stream_id: Some(state.stream_id.clone()),
                next_delivery_serial: Some(state.next_delivery_serial),
                oldest_available_delivery_serial: oldest,
                newest_available_delivery_serial: newest,
            })
        }
    }
}
/// Removes records whose recorded append time is earlier than `before_ms`,
/// deleting at most `batch_size` records across all channels in one call.
///
/// Within each channel the scan walks delivery serials in ascending order and
/// stops at the first record that is not old enough, so only a leading run of
/// expired records is removed.
///
/// Returns `(deleted, has_more)`: the number of records removed from the
/// replay log, and whether the batch budget ran out before the scan finished.
/// A `batch_size` of zero returns `(0, false)` without taking the lock.
async fn purge_before(&self, before_ms: i64, batch_size: usize) -> Result<(u64, bool)> {
    if batch_size == 0 {
        return Ok((0, false));
    }

    let mut guard = self.channels.write().await;
    let mut purged: u64 = 0;
    let mut more_pending = false;

    for channel_state in guard.values_mut() {
        let budget = batch_size.saturating_sub(purged as usize);
        if budget == 0 {
            more_pending = true;
            break;
        }

        // Pick up to `budget` serials from the expired prefix, then peek
        // whether yet another expired entry follows.
        let victims: Vec<u64> = {
            let mut expired = channel_state
                .created_at
                .iter()
                .take_while(|&(_, &created_ms)| created_ms < before_ms)
                .map(|(&delivery_serial, _)| delivery_serial);
            let picked: Vec<u64> = expired.by_ref().take(budget).collect();
            if expired.next().is_some() {
                more_pending = true;
            }
            picked
        };

        for delivery_serial in victims {
            channel_state.created_at.remove(&delivery_serial);
            // A timestamp without a replay record is dropped without being counted.
            if let Some(record) = channel_state.replay.remove(&delivery_serial) {
                let message_key = record.message_serial().as_str();
                let chain_now_empty = channel_state
                    .messages
                    .get_mut(message_key)
                    .is_some_and(|chain| {
                        chain.retain(|entry| entry.version_serial() != record.version_serial());
                        chain.is_empty()
                    });
                if chain_now_empty {
                    channel_state.messages.remove(message_key);
                }
                purged += 1;
            }
        }

        // Defensive re-check of the oldest remaining entry. The scan above stops
        // at the first non-expired entry, so this is not expected to fire.
        let front_expired = channel_state
            .created_at
            .values()
            .next()
            .is_some_and(|&created_ms| created_ms < before_ms);
        more_pending |= front_expired;
    }

    Ok((purged, more_pending))
}
}