use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::Duration;
use talea_core::store::{Committed, Store, StoreError};
use talea_core::types::Transaction;
use tokio::sync::{mpsc, oneshot};
/// Tuning knobs for the per-book write queues and their committer tasks.
#[derive(Debug, Clone)]
pub struct WriteConfig {
    /// Capacity of the bounded job queue created for each book.
    pub queue_depth: usize,
    /// Upper bound on the number of jobs gathered into a single batch.
    /// A batch always contains at least the job that started it.
    pub batch_max: usize,
    /// How long a committer waits for a new job before it unregisters
    /// itself and shuts down.
    pub idle_reap: Duration,
}

impl Default for WriteConfig {
    /// Returns a queue depth of 256, batches of at most 64 jobs and a
    /// 60 second idle timeout.
    fn default() -> Self {
        let queue_depth = 256;
        let batch_max = 64;
        let idle_reap = Duration::from_secs(60);
        Self {
            queue_depth,
            batch_max,
            idle_reap,
        }
    }
}
/// Reasons a `WriteRouter::submit` call can fail.
#[derive(Debug)]
pub enum SubmitError {
/// The target book's queue was full when the job was offered; the job was
/// not enqueued.
Overloaded,
/// The job was enqueued but its reply channel was dropped before any
/// result was sent back.
CommitterGone,
/// The store reported an error for this transaction.
Store(StoreError),
}
/// One queued write: the transaction to commit plus the channel on which the
/// committer sends that transaction's outcome back to the submitter.
struct Job {
// Transaction to hand to the store.
transaction: Transaction,
// One-shot channel carrying this transaction's individual commit result.
reply: oneshot::Sender<Result<Committed, StoreError>>,
}
/// Shared registry mapping a book key to the sending half of that book's job
/// queue. It is shared between the router and every committer task.
type BookMap = Arc<Mutex<HashMap<String, mpsc::Sender<Job>>>>;
/// Locks the book registry, ignoring mutex poisoning.
///
/// If a previous holder panicked, the guard is recovered from the
/// [`PoisonError`] and returned anyway, so callers never see a lock failure.
fn lock_books(
    books: &Mutex<HashMap<String, mpsc::Sender<Job>>>,
) -> MutexGuard<'_, HashMap<String, mpsc::Sender<Job>>> {
    match books.lock() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
/// Routes submitted transactions to one queue per book; each queue is served
/// by its own spawned committer task that writes batches to the store.
pub struct WriteRouter {
// Store that committer tasks write batches to.
store: Arc<dyn Store>,
// Queue depth, batch size and idle timeout applied to every book.
cfg: WriteConfig,
// Registry of currently registered per-book queue senders.
books: BookMap,
}
impl WriteRouter {
pub fn new(store: Arc<dyn Store>, cfg: WriteConfig) -> Self {
Self {
store,
cfg,
books: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn submit(&self, transaction: Transaction) -> Result<Committed, SubmitError> {
let book_key = transaction.book.0.clone();
let (reply_tx, reply_rx) = oneshot::channel();
let mut job = Job {
transaction,
reply: reply_tx,
};
loop {
let sender = self.sender_for(&book_key);
match sender.try_send(job) {
Ok(()) => break,
Err(mpsc::error::TrySendError::Full(_)) => return Err(SubmitError::Overloaded),
Err(mpsc::error::TrySendError::Closed(returned)) => job = returned,
}
}
match reply_rx.await {
Ok(result) => result.map_err(SubmitError::Store),
Err(_) => Err(SubmitError::CommitterGone),
}
}
pub fn active_books(&self) -> usize {
lock_books(&self.books).len()
}
pub fn queued_jobs(&self) -> usize {
lock_books(&self.books)
.values()
.map(|s| s.max_capacity() - s.capacity())
.sum()
}
fn sender_for(&self, book_key: &str) -> mpsc::Sender<Job> {
let mut books = lock_books(&self.books);
if let Some(sender) = books.get(book_key)
&& !sender.is_closed()
{
return sender.clone();
}
let (tx, rx) = mpsc::channel(self.cfg.queue_depth);
books.insert(book_key.to_string(), tx.clone());
tokio::spawn(run_committer(
Arc::clone(&self.store),
Arc::clone(&self.books),
book_key.to_string(),
rx,
self.cfg.batch_max,
self.cfg.idle_reap,
));
tx
}
}
/// Committer loop for a single book.
///
/// Waits up to `idle_reap` for a job. When one arrives, it is combined with
/// whatever else is already queued (up to `batch_max` jobs) and committed as
/// one batch. The task ends when the queue reports that every sender is gone,
/// or after an idle timeout.
///
/// On idle timeout the book is first removed from `books`, then the receiver
/// is closed, and finally any jobs that slipped in before the close are
/// committed so no accepted job is left unanswered.
///
/// NOTE(review): once the entry is removed, `sender_for` may spawn a
/// replacement committer for the same book while this one is still draining;
/// confirm that briefly overlapping committers for one book are acceptable.
async fn run_committer(
    store: Arc<dyn Store>,
    books: BookMap,
    book_key: String,
    mut rx: mpsc::Receiver<Job>,
    batch_max: usize,
    idle_reap: Duration,
) {
    loop {
        let Ok(received) = tokio::time::timeout(idle_reap, rx.recv()).await else {
            // Idle: unregister before closing so new submitters stop finding
            // this queue, then flush whatever was already accepted.
            lock_books(&books).remove(&book_key);
            rx.close();
            while let Some(straggler) = rx.recv().await {
                let leftovers = drain_batch(straggler, &mut rx, batch_max);
                commit_batch_and_reply(&*store, leftovers).await;
            }
            return;
        };
        let Some(head) = received else {
            // Every sender has been dropped; nothing more can arrive.
            return;
        };
        let batch = drain_batch(head, &mut rx, batch_max);
        commit_batch_and_reply(&*store, batch).await;
    }
}
/// Builds a batch starting with `first` and topping it up with jobs that are
/// already waiting in `rx`, without blocking.
///
/// The batch holds at most `batch_max` jobs, but never fewer than one: `first`
/// is always included, even when `batch_max` is zero. Collection stops at the
/// first `try_recv` that yields no job.
fn drain_batch(first: Job, rx: &mut mpsc::Receiver<Job>, batch_max: usize) -> Vec<Job> {
    let mut batch = vec![first];
    // Room left after `first`; `take` never polls the queue once it is used up.
    let spare = batch_max.saturating_sub(1);
    batch.extend(std::iter::from_fn(|| rx.try_recv().ok()).take(spare));
    batch
}
/// Commits `batch` to `store` in one call and sends each job its own result.
///
/// The batch size is recorded in the `talea_write_batch_size` histogram.
/// Results are matched to jobs by position. The jobs are split into their
/// transactions and reply senders up front, so the transactions are moved
/// into the slice handed to the store rather than cloned.
///
/// A failed `send` (the submitter stopped waiting) is ignored. If the store
/// returns fewer results than transactions, the surplus reply senders are
/// dropped without a value.
async fn commit_batch_and_reply(store: &dyn Store, batch: Vec<Job>) {
    metrics::histogram!("talea_write_batch_size").record(batch.len() as f64);
    let (txs, replies): (Vec<Transaction>, Vec<_>) = batch
        .into_iter()
        .map(|job| (job.transaction, job.reply))
        .unzip();
    let results = store.commit_batch(&txs).await;
    for (reply, result) in replies.into_iter().zip(results) {
        let _ = reply.send(result);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicI64, Ordering};
use talea_core::store::{
AccountCfg, BalanceSnapshot, EventStream, PostingRecord, Sequenced, StoredTransaction,
TrialBalanceRow,
};
use talea_core::types::{
AccountDef, AccountId, AssetDef, AssetId, Book, Direction, IdempotencyKey, Posting, TxId,
};
use tokio::sync::{Notify, Semaphore};
use chrono::Utc;
use uuid::Uuid;
/// Test double for `Store` whose `commit_batch` blocks until the test hands
/// out a permit, so tests control exactly when each batch completes.
struct GatedStore {
// Each `commit_batch` call takes (and forgets) one permit before proceeding.
gate: Arc<Semaphore>,
// Notified once at the start of every `commit_batch` call.
entered: Arc<Notify>,
// Size of every batch, recorded in the order the batches pass the gate.
batches: Arc<Mutex<Vec<usize>>>,
// Source of the `seq` values placed in successful `Committed` results.
seq: Arc<AtomicI64>,
}
/// Only `commit_batch` is implemented; every other method is a stub that
/// panics via `unimplemented!`.
#[async_trait::async_trait]
impl Store for GatedStore {
    async fn register_asset(&self, _asset: &AssetDef) -> Result<(), StoreError> {
        unimplemented!()
    }

    async fn open_account(
        &self,
        _def: &AccountDef,
        _cfg: &AccountCfg,
    ) -> Result<(), StoreError> {
        unimplemented!()
    }

    async fn commit(&self, _transaction: &Transaction) -> Result<Committed, StoreError> {
        unimplemented!("the router batches everything through commit_batch")
    }

    /// Signals `entered`, waits for one permit on `gate`, records the batch
    /// size, then produces one result per transaction in input order.
    ///
    /// A transaction whose idempotency key is `"fail"` yields
    /// `StoreError::UnknownAccount`; every other transaction succeeds with
    /// the next sequence number.
    async fn commit_batch(&self, txs: &[Transaction]) -> Vec<Result<Committed, StoreError>> {
        self.entered.notify_one();
        // The permit is forgotten, so each batch permanently uses one up.
        self.gate.acquire().await.unwrap().forget();
        self.batches.lock().unwrap().push(txs.len());
        txs.iter()
            .map(|tx| {
                if tx.idempotency_key.0 == "fail" {
                    return Err(StoreError::UnknownAccount(AccountId {
                        book: tx.book.clone(),
                        path: "fail-account".to_string(),
                    }));
                }
                Ok(Committed {
                    txid: tx.id.clone(),
                    seq: self.seq.fetch_add(1, Ordering::Relaxed),
                    at: Utc::now(),
                })
            })
            .collect()
    }

    async fn balance(
        &self,
        _account: &AccountId,
        _as_of: Option<chrono::DateTime<Utc>>,
    ) -> Result<BalanceSnapshot, StoreError> {
        unimplemented!()
    }

    async fn asset(&self, _id: &AssetId) -> Result<Option<AssetDef>, StoreError> {
        unimplemented!()
    }

    async fn account_history(
        &self,
        _account: &AccountId,
        _after_seq: Option<talea_core::types::Seq>,
        _limit: usize,
    ) -> Result<Vec<PostingRecord>, StoreError> {
        unimplemented!()
    }

    async fn transaction(&self, _txid: &TxId) -> Result<Option<StoredTransaction>, StoreError> {
        unimplemented!()
    }

    async fn trial_balance(
        &self,
        _book: &Book,
        _as_of: Option<chrono::DateTime<Utc>>,
    ) -> Result<Vec<TrialBalanceRow>, StoreError> {
        unimplemented!()
    }

    async fn read_events(
        &self,
        _book: &Book,
        _from: talea_core::types::Seq,
        _limit: usize,
    ) -> Result<Vec<Sequenced<talea_core::events::LedgerEvent>>, StoreError> {
        unimplemented!()
    }

    fn subscribe(&self, _book: &Book, _from: talea_core::types::Seq) -> EventStream {
        unimplemented!()
    }
}
/// What `make_store` returns: the store itself followed by handles to its
/// gate semaphore, its "entered" notifier and its recorded batch sizes.
type StoreBundle = (
Arc<GatedStore>,
Arc<Semaphore>,
Arc<Notify>,
Arc<Mutex<Vec<usize>>>,
);
/// Builds a `GatedStore` with a closed gate (zero permits), no recorded
/// batches and a sequence counter starting at 1, and returns it together with
/// shared handles to its gate, notifier and batch log.
fn make_store() -> StoreBundle {
    let store = Arc::new(GatedStore {
        gate: Arc::new(Semaphore::new(0)),
        entered: Arc::new(Notify::new()),
        batches: Arc::new(Mutex::new(Vec::new())),
        seq: Arc::new(AtomicI64::new(1)),
    });
    let gate = Arc::clone(&store.gate);
    let entered = Arc::clone(&store.entered);
    let batches = Arc::clone(&store.batches);
    (store, gate, entered, batches)
}
/// Builds a transaction for `book` with idempotency key `idem`.
///
/// It has a fresh v7 UUID as its id, a single 100 USD debit posting against
/// the account path `acct`, no external refs, empty JSON metadata and the
/// current time as `occurred_at`.
fn tx(book: &str, idem: &str) -> Transaction {
    let book = Book(book.to_string());
    let posting = Posting {
        account: AccountId {
            book: book.clone(),
            path: "acct".to_string(),
        },
        amount: talea_core::types::Amount::new(100, AssetId::new("USD")),
        direction: Direction::Debit,
    };
    Transaction {
        id: TxId(Uuid::now_v7()),
        book,
        postings: vec![posting],
        idempotency_key: IdempotencyKey(idem.to_string()),
        external_refs: vec![],
        metadata: serde_json::json!({}),
        occurred_at: Utc::now(),
    }
}
/// Jobs that queue up while a batch is in flight are committed together.
///
/// "a" is submitted first and its commit is held at the closed gate. "b", "c"
/// and "d" then pile up in the queue. Releasing the gate twice must yield a
/// batch of one followed by a batch of three.
#[tokio::test(start_paused = true)]
async fn queued_jobs_group_commit() {
    let (store, gate, entered, batches) = make_store();
    let cfg = WriteConfig {
        queue_depth: 256,
        batch_max: 64,
        idle_reap: Duration::from_secs(60),
    };
    let router = Arc::new(WriteRouter::new(store, cfg));
    // Submits on a separate task so the test can keep driving the gate.
    let spawn_submit = |idem: &'static str| {
        let r = Arc::clone(&router);
        tokio::spawn(async move { r.submit(tx("book1", idem)).await })
    };

    let ha = spawn_submit("a");
    // Wait until the committer is inside `commit_batch` with "a".
    entered.notified().await;
    let queued = [spawn_submit("b"), spawn_submit("c"), spawn_submit("d")];
    tokio::time::sleep(Duration::from_millis(10)).await;
    assert_eq!(router.queued_jobs(), 3);

    gate.add_permits(1);
    entered.notified().await;
    gate.add_permits(1);

    assert!(ha.await.unwrap().is_ok());
    for handle in queued {
        assert!(handle.await.unwrap().is_ok());
    }
    let recorded = batches.lock().unwrap().clone();
    assert_eq!(recorded, vec![1, 3]);
}
/// A submit against a full queue fails immediately with `Overloaded`.
///
/// With a queue depth of one, "a" is in flight at the closed gate and "b"
/// fills the only slot, so "c" must be rejected. Both accepted jobs still
/// succeed once the gate is opened.
#[tokio::test(start_paused = true)]
async fn full_queue_returns_overloaded() {
    let (store, gate, entered, _batches) = make_store();
    let cfg = WriteConfig {
        queue_depth: 1,
        batch_max: 64,
        idle_reap: Duration::from_secs(60),
    };
    let router = Arc::new(WriteRouter::new(store, cfg));
    // Submits on a separate task so the test can keep driving the gate.
    let spawn_submit = |idem: &'static str| {
        let r = Arc::clone(&router);
        tokio::spawn(async move { r.submit(tx("book1", idem)).await })
    };

    let ha = spawn_submit("a");
    entered.notified().await;
    let hb = spawn_submit("b");
    tokio::time::sleep(Duration::from_millis(10)).await;

    let result_c = router.submit(tx("book1", "c")).await;
    assert!(
        matches!(result_c, Err(SubmitError::Overloaded)),
        "expected Overloaded"
    );

    gate.add_permits(2);
    assert!(ha.await.unwrap().is_ok());
    assert!(hb.await.unwrap().is_ok());
}
/// Each submitter receives the result for its own transaction.
///
/// "fail" and "good" are queued while "a" is held at the gate. After both
/// batches are released, "fail" must see `UnknownAccount` while "a" and
/// "good" succeed.
#[tokio::test(start_paused = true)]
async fn results_are_positional() {
    let (store, gate, entered, _batches) = make_store();
    let cfg = WriteConfig {
        queue_depth: 256,
        batch_max: 64,
        idle_reap: Duration::from_secs(60),
    };
    let router = Arc::new(WriteRouter::new(store, cfg));
    // Submits on a separate task so the test can keep driving the gate.
    let spawn_submit = |idem: &'static str| {
        let r = Arc::clone(&router);
        tokio::spawn(async move { r.submit(tx("book1", idem)).await })
    };

    let ha = spawn_submit("a");
    entered.notified().await;
    let hfail = spawn_submit("fail");
    let hgood = spawn_submit("good");
    tokio::time::sleep(Duration::from_millis(10)).await;

    gate.add_permits(1);
    entered.notified().await;
    gate.add_permits(1);

    assert!(ha.await.unwrap().is_ok());
    let fail_result = hfail.await.unwrap();
    assert!(
        matches!(
            fail_result,
            Err(SubmitError::Store(StoreError::UnknownAccount(_)))
        ),
        "expected UnknownAccount for 'fail'"
    );
    assert!(hgood.await.unwrap().is_ok());
}
/// An idle committer unregisters itself, and a later submit for the same book
/// is still served.
///
/// The gate is opened up front so commits never block. With a one second idle
/// timeout, sleeping two seconds must leave no registered books; further
/// submits to "bookA" and "bookB" must then succeed.
#[tokio::test(start_paused = true)]
async fn idle_committer_reaps_and_respawns() {
    let (store, gate, _entered, _batches) = make_store();
    gate.add_permits(64);
    let cfg = WriteConfig {
        queue_depth: 256,
        batch_max: 64,
        idle_reap: Duration::from_secs(1),
    };
    let router = Arc::new(WriteRouter::new(store, cfg));

    let first = router.submit(tx("bookA", "x")).await;
    assert!(first.is_ok());
    assert_eq!(router.active_books(), 1);

    // Let the idle timeout elapse so the committer removes its entry.
    tokio::time::sleep(Duration::from_secs(2)).await;
    assert_eq!(router.active_books(), 0);

    let again = router.submit(tx("bookA", "y")).await;
    assert!(again.is_ok());
    let other = router.submit(tx("bookB", "z")).await;
    assert!(other.is_ok());
    assert!(router.active_books() <= 2);
}
}