use super::store::VersionStore;
use super::types::*;
use crate::error::{Error, Result};
use crate::versioned_messages::MessageSerial;
use ahash::AHashMap;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;
/// Map key identifying one cached lease: an application id paired with a
/// channel name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct LeaseKey {
    app_id: String,
    channel: String,
}

impl LeaseKey {
    /// Builds a key by copying both string slices into owned `String`s.
    fn new(app_id: &str, channel: &str) -> Self {
        let app_id = String::from(app_id);
        let channel = String::from(channel);
        Self { app_id, channel }
    }
}
/// Read position inside one reserved block of delivery serials.
///
/// Serials in the half-open range `next_delivery_serial..end_exclusive` have
/// not been handed out yet.
#[derive(Debug, Clone)]
struct LeaseCursor {
    stream_id: String,
    next_delivery_serial: u64,
    end_exclusive: u64,
}

impl LeaseCursor {
    /// Creates a cursor positioned at the first serial of `block`.
    ///
    /// The exclusive end is `start + len`, clamped at `u64::MAX`.
    fn from_block(block: VersionWriteReservationBlock) -> Self {
        let first = block.start_delivery_serial;
        let end_exclusive = first.saturating_add(block.len);
        Self {
            stream_id: block.stream_id,
            next_delivery_serial: first,
            end_exclusive,
        }
    }

    /// Hands out the next unused serial and advances the cursor, or returns
    /// `None` (leaving the cursor untouched) once the block is used up.
    fn take_next(&mut self) -> Option<VersionWriteReservation> {
        let serial = self.next_delivery_serial;
        (serial < self.end_exclusive).then(|| {
            self.next_delivery_serial = serial.saturating_add(1);
            VersionWriteReservation {
                stream_id: self.stream_id.clone(),
                delivery_serial: serial,
            }
        })
    }
}
/// Mutable bookkeeping shared by all callers of a `LeasedVersionStore`.
#[derive(Default)]
struct LeaseState {
    /// Cached block cursor per key; the store removes an entry once its
    /// cursor has no serials left.
    leases: AHashMap<LeaseKey, LeaseCursor>,
    /// Keys with a block reservation currently in progress, mapped to the
    /// senders used to wake tasks waiting for that reservation to end.
    in_flight: AHashMap<LeaseKey, Vec<oneshot::Sender<()>>>,
}
/// A `VersionStore` wrapper that reserves delivery serials from the wrapped
/// store in blocks and serves individual reservations from a per-key cache.
pub struct LeasedVersionStore {
    inner: Arc<dyn VersionStore + Send + Sync>,
    block_size: u64,
    state: Mutex<LeaseState>,
}

impl LeasedVersionStore {
    /// Wraps `inner`, leasing `block_size` serials at a time.
    ///
    /// A `block_size` of `0` is raised to `1`.
    #[must_use]
    pub fn new(inner: Arc<dyn VersionStore + Send + Sync>, block_size: u64) -> Self {
        let block_size = if block_size == 0 { 1 } else { block_size };
        Self {
            inner,
            block_size,
            state: Mutex::new(LeaseState::default()),
        }
    }

    /// Locks the lease state, recovering the data if the mutex was poisoned.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, LeaseState> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Removes the in-flight entry for `key` and signals every waiter that
    /// was queued on it. Waiters whose receiver is already gone are ignored.
    fn wake_waiters(state: &mut LeaseState, key: &LeaseKey) {
        for waiter in state.in_flight.remove(key).into_iter().flatten() {
            let _ = waiter.send(());
        }
    }

    /// Takes the next serial from the cached lease for `key`, if one exists.
    ///
    /// The lease is dropped from the cache as soon as it is used up.
    fn take_cached(&self, key: &LeaseKey) -> Option<VersionWriteReservation> {
        let mut guard = self.lock_state();
        let lease = guard.leases.get_mut(key)?;
        let handed_out = lease.take_next();
        let exhausted = lease.next_delivery_serial >= lease.end_exclusive;
        if exhausted {
            guard.leases.remove(key);
        }
        handed_out
    }

    /// Takes the next cached serial for `key` that is strictly greater than
    /// `after_delivery_serial`.
    ///
    /// Cached serials at or below the threshold are skipped (and thereby
    /// discarded). Returns `None` when no lease is cached or the cached
    /// lease cannot supply a serial above the threshold; in the latter case
    /// the lease is removed.
    fn take_cached_after(
        &self,
        key: &LeaseKey,
        after_delivery_serial: u64,
    ) -> Option<VersionWriteReservation> {
        let mut guard = self.lock_state();
        let lease = guard.leases.get_mut(key)?;
        if lease.next_delivery_serial <= after_delivery_serial {
            // Fast-forward past everything the caller has already seen.
            let resume_at = after_delivery_serial.saturating_add(1);
            if resume_at >= lease.end_exclusive {
                guard.leases.remove(key);
                return None;
            }
            lease.next_delivery_serial = resume_at;
        }
        let handed_out = lease.take_next();
        let exhausted = lease.next_delivery_serial >= lease.end_exclusive;
        if exhausted {
            guard.leases.remove(key);
        }
        handed_out
    }

    /// Registers interest in a block reservation for `key`.
    ///
    /// Returns `None` when the caller became the one responsible for
    /// performing the reservation (an empty in-flight entry is created), or
    /// `Some(receiver)` to wait on when a reservation is already in flight.
    fn start_or_join_reservation(&self, key: LeaseKey) -> Option<oneshot::Receiver<()>> {
        let mut guard = self.lock_state();
        match guard.in_flight.get_mut(&key) {
            Some(queue) => {
                let (notify, wait) = oneshot::channel();
                queue.push(notify);
                Some(wait)
            }
            None => {
                guard.in_flight.insert(key, Vec::new());
                None
            }
        }
    }

    /// Publishes a freshly reserved `block` as the cached lease for `key`,
    /// replacing any lease already cached, and wakes all queued waiters.
    fn finish_reservation(&self, key: LeaseKey, block: VersionWriteReservationBlock) {
        let mut guard = self.lock_state();
        // Both steps happen under the same lock, so waiters always find the
        // new lease in place by the time they can look at the state.
        Self::wake_waiters(&mut guard, &key);
        guard.leases.insert(key, LeaseCursor::from_block(block));
    }

    /// Clears the in-flight entry for `key` without caching anything and
    /// wakes all queued waiters.
    fn fail_reservation(&self, key: &LeaseKey) {
        let mut guard = self.lock_state();
        Self::wake_waiters(&mut guard, key);
    }
}
#[async_trait]
impl VersionStore for LeasedVersionStore {
async fn reserve_delivery_position(
&self,
app_id: &str,
channel: &str,
) -> Result<VersionWriteReservation> {
if self.block_size == 1 {
return self.inner.reserve_delivery_position(app_id, channel).await;
}
let key = LeaseKey::new(app_id, channel);
loop {
if let Some(reservation) = self.take_cached(&key) {
return Ok(reservation);
}
if let Some(waiter) = self.start_or_join_reservation(key.clone()) {
let _ = waiter.await;
continue;
}
match self
.inner
.reserve_delivery_positions(app_id, channel, self.block_size)
.await
{
Ok(block) => self.finish_reservation(key.clone(), block),
Err(err) => {
self.fail_reservation(&key);
return Err(err);
}
}
}
}
async fn reserve_delivery_positions(
&self,
app_id: &str,
channel: &str,
block_size: u64,
) -> Result<VersionWriteReservationBlock> {
VersionWriteReservationBlock::validate(block_size)?;
let first = self.reserve_delivery_position(app_id, channel).await?;
let mut expected_next = first.delivery_serial.saturating_add(1);
for _ in 1..block_size {
let next = self.reserve_delivery_position(app_id, channel).await?;
if next.stream_id != first.stream_id || next.delivery_serial != expected_next {
return Err(Error::Internal(
"leased version store returned a non-contiguous reservation block".to_string(),
));
}
expected_next = expected_next.saturating_add(1);
}
Ok(VersionWriteReservationBlock {
stream_id: first.stream_id,
start_delivery_serial: first.delivery_serial,
len: block_size,
})
}
async fn reserve_delivery_position_after(
&self,
app_id: &str,
channel: &str,
after_delivery_serial: u64,
) -> Result<VersionWriteReservation> {
let max_attempts = self.block_size.saturating_mul(2).max(64);
for _ in 0..max_attempts {
let key = LeaseKey::new(app_id, channel);
if let Some(reservation) = self.take_cached_after(&key, after_delivery_serial) {
return Ok(reservation);
}
if let Some(waiter) = self.start_or_join_reservation(key.clone()) {
let _ = waiter.await;
continue;
}
match self
.inner
.reserve_delivery_positions(app_id, channel, self.block_size)
.await
{
Ok(block) => self.finish_reservation(key.clone(), block),
Err(err) => {
self.fail_reservation(&key);
return Err(err);
}
}
}
Err(Error::Internal(format!(
"leased version store could not reserve delivery_serial greater than {after_delivery_serial}"
)))
}
async fn append_version(&self, record: StoredVersionRecord) -> Result<()> {
self.inner.append_version(record).await
}
async fn get_latest(
&self,
app_id: &str,
channel: &str,
message_serial: &MessageSerial,
) -> Result<Option<StoredVersionRecord>> {
self.inner.get_latest(app_id, channel, message_serial).await
}
async fn get_versions(&self, request: VersionStoreReadRequest) -> Result<VersionStorePage> {
self.inner.get_versions(request).await
}
async fn replay_after(
&self,
request: VersionReplayRequest,
) -> Result<Vec<StoredVersionRecord>> {
self.inner.replay_after(request).await
}
async fn latest_by_history(
&self,
app_id: &str,
channel: &str,
) -> Result<Vec<StoredVersionRecord>> {
self.inner.latest_by_history(app_id, channel).await
}
async fn stream_state(&self, app_id: &str, channel: &str) -> Result<VersionStreamState> {
self.inner.stream_state(app_id, channel).await
}
async fn purge_before(&self, before_ms: i64, batch_size: usize) -> Result<(u64, bool)> {
self.inner.purge_before(before_ms, batch_size).await
}
}