use std::sync::Arc;
use crate::error::Result;
use crate::io::Io;
use crate::segment::Frame;
use crate::shared_wal::SharedWal;
use crate::transaction::{Transaction, TxGuardOwned};
pub struct Injector<IO: Io> {
wal: Arc<SharedWal<IO>>,
buffer: Vec<Box<Frame>>,
capacity: usize,
tx: Option<TxGuardOwned<IO::File>>,
max_tx_frame_no: u64,
previous_durable_frame_no: u64,
}
impl<IO: Io> Injector<IO> {
pub fn new(wal: Arc<SharedWal<IO>>, buffer_capacity: usize) -> Result<Self> {
Ok(Self {
wal,
buffer: Vec::with_capacity(buffer_capacity),
capacity: buffer_capacity,
tx: None,
max_tx_frame_no: 0,
previous_durable_frame_no: 0,
})
}
pub fn set_durable(&mut self, durable_frame_no: u64) {
let mut old = self.wal.durable_frame_no.lock();
if *old <= durable_frame_no {
self.previous_durable_frame_no = *old;
*old = durable_frame_no;
} else {
todo!("primary reported older frame_no than current");
}
}
pub fn current_durable(&self) -> u64 {
*self.wal.durable_frame_no.lock()
}
pub fn maybe_begin_txn(&mut self) -> Result<()> {
if self.tx.is_none() {
let mut tx = Transaction::Read(self.wal.begin_read(u64::MAX));
self.wal.upgrade(&mut tx)?;
let tx = tx
.into_write()
.unwrap_or_else(|_| unreachable!())
.into_lock_owned();
assert!(self.tx.replace(tx).is_none());
}
Ok(())
}
pub async fn insert_frame(&mut self, frame: Box<Frame>) -> Result<Option<u64>> {
self.maybe_begin_txn()?;
let size_after = frame.size_after();
self.max_tx_frame_no = self.max_tx_frame_no.max(frame.header().frame_no());
self.buffer.push(frame);
if size_after.is_some() || self.capacity == self.buffer.len() {
self.flush(size_after).await?;
}
Ok(size_after.map(|_| self.max_tx_frame_no))
}
pub async fn flush(&mut self, size_after: Option<u32>) -> Result<()> {
if !self.buffer.is_empty() && self.tx.is_some() {
let last_committed_frame_no = self.max_tx_frame_no;
{
let tx = self.tx.as_mut().expect("we just checked that tx was there");
let buffer = std::mem::take(&mut self.buffer);
let current = self.wal.current.load();
let commit_data = size_after.map(|size| (size, self.max_tx_frame_no));
if commit_data.is_some() {
self.max_tx_frame_no = 0;
}
let buffer = current.inject_frames(buffer, commit_data, tx).await?;
self.buffer = buffer;
self.buffer.clear();
}
if size_after.is_some() {
let mut tx = self.tx.take().unwrap();
self.wal
.new_frame_notifier
.send_replace(last_committed_frame_no);
if self.current_durable() != self.previous_durable_frame_no
&& self.current_durable() >= self.max_tx_frame_no
{
let wal = self.wal.clone();
tokio::task::spawn_blocking(move || {
tx.commit();
wal.swap_current(&tx)
})
.await
.unwrap()?
}
}
}
Ok(())
}
pub fn rollback(&mut self) {
self.buffer.clear();
if let Some(tx) = self.tx.as_mut() {
tx.reset(0);
}
}
}
#[cfg(test)]
mod test {
use tokio_stream::StreamExt;
use crate::replication::replicator::Replicator;
use crate::test::TestEnv;
use super::*;
#[tokio::test]
async fn inject_basic() {
let primary_env = TestEnv::new();
let primary_conn = primary_env.open_conn("test");
let primary_shared = primary_env.shared("test");
let replicator = Replicator::new(primary_shared.clone(), 1, true);
let stream = replicator.into_frame_stream();
tokio::pin!(stream);
let replica_env = TestEnv::new();
let replica_conn = replica_env.open_conn("test");
let replica_shared = replica_env.shared("test");
let mut injector = Injector::new(replica_shared.clone(), 10).unwrap();
primary_conn.execute("create table test (x)", ()).unwrap();
primary_shared.last_committed_frame_no();
for _ in 0..2 {
let frame = stream.next().await.unwrap().unwrap();
injector.insert_frame(frame).await.unwrap();
}
replica_conn
.query_row("select count(*) from test", (), |r| {
assert_eq!(r.get_unwrap::<_, usize>(0), 0);
Ok(())
})
.unwrap();
primary_conn
.execute("insert into test values (123)", ())
.unwrap();
primary_conn
.execute("insert into test values (123)", ())
.unwrap();
primary_conn
.execute("insert into test values (123)", ())
.unwrap();
let frame = stream.next().await.unwrap().unwrap();
injector.insert_frame(frame).await.unwrap();
replica_conn
.query_row("select count(*) from test", (), |r| {
assert_eq!(r.get_unwrap::<_, usize>(0), 3);
Ok(())
})
.unwrap();
}
}