use sqlx::PgPool;
use tracing::instrument;
use std::collections::{HashMap, HashSet};
use cala_types::{
balance::BalanceSnapshot,
outbox::OutboxEventPayload,
primitives::{
AccountId, AccountSetId, BalanceId, Currency, DebitOrCredit, EntryId, JournalId, Status,
},
};
use super::{account_balance::AccountBalance, error::BalanceError};
use crate::outbox::OutboxPublisher;
/// Advisory-lock class id: the first (`int4`) key of the two-key
/// `pg_advisory_xact_lock*` calls in this file, whose second key is
/// `hashtext(account_id::text)`.
///
/// NOTE(review): "EC" presumably stands for "eventually consistent" (cf. the
/// `cala_accounts.eventually_consistent` column) — confirm.
const EC_SET_LOCK_CLASS: i32 = 1;
/// Postgres-backed repository for account balances (`cala_balance_history`
/// and `cala_current_balances`) that also publishes balance outbox events.
#[derive(Debug, Clone)]
pub(super) struct BalanceRepo {
    /// Pool used by the lookups that are not given an explicit executor.
    pool: PgPool,
    /// Publishes `BalanceCreated` / `BalanceUpdated` events for inserted snapshots.
    publisher: OutboxPublisher,
}
impl BalanceRepo {
pub fn new(pool: &PgPool, publisher: &OutboxPublisher) -> Self {
Self {
pool: pool.clone(),
publisher: publisher.clone(),
}
}
/// Looks up the current balance of `account_id` in `journal_id` for
/// `currency`, running the query directly against the repository's pool.
///
/// # Errors
/// Same as [`Self::find_in_op`]; notably `BalanceError::NotFound` when no
/// current balance exists.
pub async fn find(
    &self,
    journal_id: JournalId,
    account_id: AccountId,
    currency: Currency,
) -> Result<AccountBalance, BalanceError> {
    let executor = &self.pool;
    self.find_in_op(executor, journal_id, account_id, currency).await
}
#[instrument(name = "balance.find_in_op", skip_all)]
/// Fetches the latest balance snapshot for one `(journal, account, currency)`
/// using the supplied executor.
///
/// The history row whose `version` equals `cala_current_balances.latest_version`
/// is joined with the account to obtain its normal balance side.
///
/// # Errors
/// `BalanceError::NotFound` when no current balance row matches; database
/// errors are propagated.
///
/// # Panics
/// If the stored snapshot JSON cannot be deserialized into `BalanceSnapshot`.
pub async fn find_in_op(
    &self,
    op: impl es_entity::IntoOneTimeExecutor<'_>,
    journal_id: JournalId,
    account_id: AccountId,
    currency: Currency,
) -> Result<AccountBalance, BalanceError> {
    let maybe_row = op
        .into_executor()
        .fetch_optional(sqlx::query!(
            r#"
SELECT h.values, a.normal_balance_type AS "normal_balance_type!: DebitOrCredit"
FROM cala_balance_history h
JOIN cala_current_balances c
ON h.journal_id = c.journal_id
AND h.account_id = c.account_id
AND h.currency = c.currency
AND h.version = c.latest_version
JOIN cala_accounts a
ON c.account_id = a.id
WHERE c.journal_id = $1
AND c.account_id = $2
AND c.currency = $3
"#,
            journal_id as JournalId,
            account_id as AccountId,
            currency.code(),
        ))
        .await?;
    match maybe_row {
        Some(found) => {
            let snapshot: BalanceSnapshot = serde_json::from_value(found.values)
                .expect("Failed to deserialize balance snapshot");
            Ok(AccountBalance::new(found.normal_balance_type, snapshot))
        }
        None => Err(BalanceError::NotFound(journal_id, account_id, currency)),
    }
}
#[instrument(name = "balance.find_all", skip_all, err(level = "warn"))]
/// Batch variant of `find` that runs against the repository's pool.
///
/// Ids without a current balance are simply absent from the returned map.
pub(super) async fn find_all(
    &self,
    ids: &[BalanceId],
) -> Result<HashMap<BalanceId, AccountBalance>, BalanceError> {
    let executor = &self.pool;
    self.find_all_in_op(executor, ids).await
}
#[instrument(name = "balance.find_all_in_op", skip_all, err(level = "warn"))]
/// Fetches the latest balance snapshot for every `(journal, account, currency)`
/// triple in `ids` with a single query, using the supplied executor.
///
/// The ids are passed as three parallel arrays and `UNNEST`ed server-side.
/// Triples with no current balance are absent from the returned map. Keys are
/// rebuilt from each deserialized snapshot.
///
/// # Errors
/// Propagates database errors.
///
/// # Panics
/// If a stored snapshot JSON cannot be deserialized into `BalanceSnapshot`.
pub(super) async fn find_all_in_op(
    &self,
    op: impl es_entity::IntoOneTimeExecutor<'_>,
    ids: &[BalanceId],
) -> Result<HashMap<BalanceId, AccountBalance>, BalanceError> {
    let mut journal_ids = Vec::with_capacity(ids.len());
    let mut account_ids = Vec::with_capacity(ids.len());
    let mut currencies = Vec::with_capacity(ids.len());
    for (journal_id, account_id, currency) in ids {
        journal_ids.push(uuid::Uuid::from(journal_id));
        account_ids.push(uuid::Uuid::from(account_id));
        currencies.push(currency.code().to_string());
    }
    let rows = op
        .into_executor()
        .fetch_all(sqlx::query!(
            r#"
WITH balance_ids AS (
SELECT * FROM UNNEST($1::uuid[], $2::uuid[], $3::text[])
AS v(journal_id, account_id, currency)
)
SELECT
h.values,
a.normal_balance_type as "normal_balance_type!: DebitOrCredit"
FROM cala_balance_history h
JOIN cala_current_balances c
ON h.journal_id = c.journal_id
AND h.account_id = c.account_id
AND h.currency = c.currency
AND h.version = c.latest_version
JOIN cala_accounts a
ON c.account_id = a.id
JOIN balance_ids b
ON c.journal_id = b.journal_id
AND c.account_id = b.account_id
AND c.currency = b.currency"#,
            &journal_ids[..],
            &account_ids[..],
            // Was mis-encoded as `¤cies[..]` (`&curren` read as an HTML
            // entity), which does not compile.
            &currencies[..]
        ))
        .await?;
    let mut balances: HashMap<BalanceId, AccountBalance> = HashMap::with_capacity(rows.len());
    balances.extend(rows.into_iter().map(|row| {
        let details: BalanceSnapshot =
            serde_json::from_value(row.values).expect("Failed to deserialize balance snapshot");
        (
            (details.journal_id, details.account_id, details.currency),
            AccountBalance::new(row.normal_balance_type, details),
        )
    }));
    Ok(balances)
}
#[instrument(name = "cala_ledger.balances.find_for_update", skip(self, op))]
/// Locks the given `(account, currency)` pairs for the current transaction,
/// then returns the current snapshot for each non-eventually-consistent account.
///
/// Locking (first query, rows visited in `account_id, currency` order):
/// * a shared advisory lock on `(EC_SET_LOCK_CLASS, hashtext(account_id))`
///   for every pair;
/// * for accounts with `eventually_consistent = false`, an exclusive advisory
///   lock on `hashtext(journal_id || account_id || currency)`.
///
/// The returned map contains only accounts with `eventually_consistent = false`.
/// A value is `None` when the pair has no row in `cala_current_balances` yet.
///
/// # Errors
/// `BalanceError::AccountLocked` for the first returned account whose status
/// is `Locked`; database errors are propagated.
///
/// # Panics
/// On undeserializable snapshot JSON or an unparsable currency code.
pub(super) async fn find_for_update(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    journal_id: JournalId,
    (account_ids, currencies): &(Vec<AccountId>, Vec<&str>),
) -> Result<HashMap<(AccountId, Currency), Option<BalanceSnapshot>>, BalanceError> {
    sqlx::query!(
        r#"
SELECT
pg_advisory_xact_lock_shared(
$1::int4, hashtext(v.account_id::text)
),
CASE WHEN NOT a.eventually_consistent THEN
pg_advisory_xact_lock(
hashtext(concat($2::text, v.account_id::text, v.currency))
)
END
FROM UNNEST($3::uuid[], $4::text[]) AS v(account_id, currency)
JOIN cala_accounts a ON a.id = v.account_id
ORDER BY v.account_id, v.currency
"#,
        EC_SET_LOCK_CLASS,
        journal_id as JournalId,
        account_ids as &[AccountId],
        currencies as &[&str],
    )
    .execute(op.as_executor())
    .await?;

    // With the locks held, read the current balances (LEFT JOIN: a missing
    // balance row yields `latest_values = NULL`).
    let rows = sqlx::query!(
        r#"
SELECT
v.account_id AS "account_id!: AccountId",
v.currency AS "currency!",
b.latest_values,
a.status AS "status!: Status"
FROM UNNEST($2::uuid[], $3::text[]) AS v(account_id, currency)
JOIN cala_accounts a ON a.id = v.account_id AND a.eventually_consistent = FALSE
LEFT JOIN cala_current_balances b
ON b.journal_id = $1
AND b.account_id = v.account_id
AND b.currency = v.currency
"#,
        journal_id as JournalId,
        account_ids as &[AccountId],
        currencies as &[&str]
    )
    .fetch_all(op.as_executor())
    .await?;

    // Collecting into `Result` stops at the first locked account, as a loop
    // with an early `return` would.
    rows.into_iter()
        .map(|record| {
            if record.status == Status::Locked {
                return Err(BalanceError::AccountLocked(record.account_id));
            }
            let snapshot = record.latest_values.map(|json| {
                serde_json::from_value::<BalanceSnapshot>(json)
                    .expect("Failed to deserialize balance snapshot")
            });
            let currency: Currency = record.currency.parse().expect("Could not parse currency");
            Ok(((record.account_id, currency), snapshot))
        })
        .collect()
}
#[instrument(
    name = "cala_ledger.balances.lock_accounts_exclusive_in_op",
    skip_all,
    err(level = "warn")
)]
/// Takes an exclusive, transaction-scoped advisory lock on
/// `(EC_SET_LOCK_CLASS, hashtext(account_id))` for every id in `account_ids`.
///
/// This conflicts with the shared locks on the same keys taken by
/// `find_for_update` and `member_has_balance_history_in_op`. An empty set is
/// a no-op and issues no query.
pub(super) async fn lock_accounts_exclusive_in_op(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    account_ids: &HashSet<AccountId>,
) -> Result<(), BalanceError> {
    if account_ids.is_empty() {
        return Ok(());
    }
    // Sort so every caller acquires these locks in the same sequence; this is
    // the usual guard against lock-order deadlocks.
    let mut ordered_ids = account_ids.iter().copied().collect::<Vec<AccountId>>();
    ordered_ids.sort();
    sqlx::query!(
        r#"
SELECT pg_advisory_xact_lock($1::int4, hashtext(account_id::text))
FROM UNNEST($2::uuid[]) AS v(account_id)
"#,
        EC_SET_LOCK_CLASS,
        &ordered_ids as &[AccountId],
    )
    .execute(op.as_executor())
    .await?;
    Ok(())
}
#[instrument(
    name = "cala_ledger.balances.member_has_balance_history_in_op",
    skip_all,
    err(level = "warn")
)]
/// Reports whether `member_id` has any row in `cala_balance_history` for
/// `journal_id`.
///
/// Before checking, it takes transaction-scoped advisory locks in the
/// `EC_SET_LOCK_CLASS` class:
/// * exclusive on `member_id`;
/// * shared on `parent_account_id`.
///
/// The two ids are visited in `account_id` order.
pub(super) async fn member_has_balance_history_in_op(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    journal_id: JournalId,
    parent_account_id: AccountId,
    member_id: AccountId,
) -> Result<bool, BalanceError> {
    sqlx::query!(
        r#"
SELECT
CASE WHEN v.account_id = $2 THEN
pg_advisory_xact_lock($1::int4, hashtext(v.account_id::text))
ELSE
pg_advisory_xact_lock_shared($1::int4, hashtext(v.account_id::text))
END
FROM UNNEST($3::uuid[]) AS v(account_id)
ORDER BY v.account_id
"#,
        EC_SET_LOCK_CLASS,
        member_id as AccountId,
        &[parent_account_id, member_id] as &[AccountId],
    )
    .execute(op.as_executor())
    .await?;

    let has_history = sqlx::query!(
        r#"
SELECT EXISTS (
SELECT 1
FROM cala_balance_history
WHERE journal_id = $1 AND account_id = $2
) AS "exists!"
"#,
        journal_id as JournalId,
        member_id as AccountId,
    )
    .fetch_one(op.as_executor())
    .await
    .map(|record| record.exists)?;
    Ok(has_history)
}
#[instrument(
    name = "cala_ledger.balances.insert_new_snapshots",
    skip(self, op, new_balances),
    fields(n_new_balances)
)]
/// Appends `new_balances` to `cala_balance_history`, upserts
/// `cala_current_balances`, and publishes one outbox event per snapshot.
///
/// The upsert works per `(journal, account, currency)` group:
/// * the highest inserted version wins;
/// * an existing current row is replaced only when the incoming version is
///   strictly newer;
/// * `latest_seq` becomes the max of the old and new values.
///
/// Snapshots with `version == 1` publish `BalanceCreated`; all others publish
/// `BalanceUpdated`.
///
/// # Errors
/// Propagates database and outbox-publishing errors.
///
/// # Panics
/// If a snapshot cannot be serialized to JSON.
pub(crate) async fn insert_new_snapshots(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    journal_id: JournalId,
    new_balances: Vec<BalanceSnapshot>,
) -> Result<(), BalanceError> {
    tracing::Span::current().record(
        "n_new_balances",
        tracing::field::display(new_balances.len()),
    );
    let n_balances = new_balances.len();
    let mut journal_ids = Vec::with_capacity(n_balances);
    let mut account_ids = Vec::with_capacity(n_balances);
    let mut entry_ids = Vec::with_capacity(n_balances);
    let mut currencies = Vec::with_capacity(n_balances);
    let mut versions = Vec::with_capacity(n_balances);
    let mut values = Vec::with_capacity(n_balances);
    for balance in new_balances.iter() {
        journal_ids.push(balance.journal_id);
        account_ids.push(balance.account_id);
        entry_ids.push(balance.entry_id);
        currencies.push(balance.currency.code());
        // The column is int4. NOTE(review): `as` truncates silently if
        // `version` ever exceeds i32::MAX.
        versions.push(balance.version as i32);
        values
            .push(serde_json::to_value(balance).expect("Failed to serialize balance snapshot"));
    }
    sqlx::query!(
        r#"
WITH new_snapshots AS (
INSERT INTO cala_balance_history (
journal_id, account_id, currency, version, latest_entry_id, values
)
SELECT * FROM UNNEST (
$1::uuid[],
$2::uuid[],
$3::text[],
$4::int4[],
$5::uuid[],
$6::jsonb[]
)
RETURNING *
)
INSERT INTO cala_current_balances AS c (
journal_id, account_id, currency, latest_version, latest_values, latest_seq
)
SELECT
journal_id,
account_id,
currency,
MAX(version) as latest_version,
(array_agg(values ORDER BY version DESC))[1] as latest_values,
MAX(seq) as latest_seq
FROM new_snapshots
GROUP BY journal_id, account_id, currency
ON CONFLICT (account_id, journal_id, currency)
DO UPDATE SET
latest_version = GREATEST(c.latest_version, EXCLUDED.latest_version),
latest_values = CASE
WHEN c.latest_version < EXCLUDED.latest_version
THEN EXCLUDED.latest_values
ELSE c.latest_values
END,
latest_seq = GREATEST(c.latest_seq, EXCLUDED.latest_seq)
"#,
        &journal_ids as &[JournalId],
        &account_ids as &[AccountId],
        // Was mis-encoded as `¤cies` (`&curren` read as an HTML entity),
        // which does not compile.
        &currencies as &[&str],
        &versions as &[i32],
        &entry_ids as &[EntryId],
        &values
    )
    .execute(op.as_executor())
    .await?;
    self.publisher
        .publish_all(
            op,
            new_balances.into_iter().map(|balance| {
                if balance.version == 1 {
                    OutboxEventPayload::BalanceCreated { balance }
                } else {
                    OutboxEventPayload::BalanceUpdated { balance }
                }
            }),
        )
        .await?;
    Ok(())
}
#[instrument(
    name = "balance.load_account_set_balances_batch",
    skip_all,
    err(level = "warn")
)]
/// Loads and row-locks (`FOR UPDATE`) the current balances of `account_ids`
/// in `journal_id`.
///
/// Returns, for every requested id:
/// * the per-currency latest snapshots;
/// * a watermark equal to the max `latest_seq` over its rows.
///
/// The watermark is `None` when the account has no rows or the max is not
/// positive. Every requested id is present in the map, possibly with an
/// empty snapshot map.
///
/// # Panics
/// If a stored snapshot JSON cannot be deserialized.
pub(crate) async fn load_account_set_balances_batch(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    journal_id: JournalId,
    account_ids: &[AccountId],
) -> Result<HashMap<AccountId, AccountSetBalanceState>, BalanceError> {
    // NOTE(review): rows are locked in `ORDER BY account_id` order only, so
    // the lock order among currencies of the same account is not pinned.
    let rows = sqlx::query!(
        r#"
SELECT account_id AS "account_id!: AccountId", latest_values, latest_seq
FROM cala_current_balances
WHERE account_id = ANY($1) AND journal_id = $2
ORDER BY account_id
FOR UPDATE
"#,
        account_ids as &[AccountId],
        journal_id as JournalId,
    )
    .fetch_all(op.as_executor())
    .await?;

    // Seed every requested id with an empty state, then fold the rows in.
    let mut states: HashMap<AccountId, AccountSetBalanceState> = account_ids
        .iter()
        .map(|id| (*id, (HashMap::new(), None)))
        .collect();
    for row in rows {
        let snapshot: BalanceSnapshot = serde_json::from_value(row.latest_values)
            .expect("Failed to deserialize balance snapshot");
        let (by_currency, watermark) = states
            .entry(row.account_id)
            .or_insert_with(|| (HashMap::new(), None));
        let seq = row.latest_seq;
        *watermark = Some(match *watermark {
            Some(highest) => highest.max(seq),
            None => seq,
        });
        by_currency.insert(snapshot.currency, snapshot);
    }
    // Non-positive watermarks are treated as "no watermark".
    for (_, watermark) in states.values_mut() {
        *watermark = watermark.filter(|&s| s > 0);
    }
    Ok(states)
}
#[instrument(
    name = "balance.fetch_batch_member_history",
    skip_all,
    err(level = "warn")
)]
/// Returns the balance history in `journal_id` of every direct member account
/// of `account_set_ids`.
///
/// Members that are themselves account sets are excluded. Each row is paired
/// with the previous version of the same `(account, currency)`, which is
/// `None` for the first version.
///
/// When `min_watermark` is `Some(w)`, only rows with `seq > w` are returned.
/// The previous version is computed before that filter, so it may lie below
/// the watermark. Rows are ordered by `seq`, then `account_id`.
///
/// # Panics
/// If a stored snapshot JSON cannot be deserialized.
pub(crate) async fn fetch_batch_member_history(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    journal_id: JournalId,
    account_set_ids: &[AccountSetId],
    min_watermark: Option<i64>,
) -> Result<Vec<MemberBalanceHistoryRow>, BalanceError> {
    let rows = sqlx::query!(
        r#"
WITH member_accounts AS (
SELECT DISTINCT m.member_account_id
FROM cala_account_set_member_accounts m
LEFT JOIN cala_account_sets s ON s.id = m.member_account_id
WHERE m.account_set_id = ANY($1)
AND s.id IS NULL
),
all_history AS (
SELECT h.values, h.account_id, h.currency, h.version, h.seq
FROM cala_balance_history h
JOIN member_accounts ma ON ma.member_account_id = h.account_id
WHERE h.journal_id = $2
),
with_prev AS (
SELECT values,
LAG(values) OVER (
PARTITION BY account_id, currency ORDER BY version
) as prev_values,
seq,
account_id
FROM all_history
)
SELECT values, prev_values, seq
FROM with_prev
WHERE ($3::bigint IS NULL OR seq > $3)
ORDER BY seq, account_id
"#,
        account_set_ids as &[AccountSetId],
        journal_id as JournalId,
        min_watermark,
    )
    .fetch_all(op.as_executor())
    .await?;

    let history: Vec<MemberBalanceHistoryRow> = rows
        .into_iter()
        .map(|record| {
            let snapshot: BalanceSnapshot = serde_json::from_value(record.values)
                .expect("Failed to deserialize balance snapshot");
            let prev_snapshot = record.prev_values.map(|json| {
                serde_json::from_value::<BalanceSnapshot>(json)
                    .expect("Failed to deserialize previous balance snapshot")
            });
            MemberBalanceHistoryRow {
                snapshot,
                prev_snapshot,
                seq: record.seq,
            }
        })
        .collect();
    Ok(history)
}
#[instrument(
    name = "balance.fetch_member_account_mappings",
    skip_all,
    err(level = "warn")
)]
/// Maps each direct member account of `account_set_ids` to the sets (among
/// those given) that contain it.
///
/// Members that are themselves account sets are excluded.
pub(crate) async fn fetch_member_account_mappings(
    &self,
    op: &mut impl es_entity::AtomicOperation,
    account_set_ids: &[AccountSetId],
) -> Result<HashMap<AccountId, Vec<AccountSetId>>, BalanceError> {
    let rows = sqlx::query!(
        r#"
SELECT
account_set_id AS "account_set_id!: AccountSetId",
member_account_id AS "member_account_id!: AccountId"
FROM cala_account_set_member_accounts m
LEFT JOIN cala_account_sets s ON s.id = m.member_account_id
WHERE m.account_set_id = ANY($1)
AND s.id IS NULL
"#,
        account_set_ids as &[AccountSetId],
    )
    .fetch_all(op.as_executor())
    .await?;

    let by_member = rows.into_iter().fold(
        HashMap::<AccountId, Vec<AccountSetId>>::new(),
        |mut acc, record| {
            acc.entry(record.member_account_id)
                .or_default()
                .push(record.account_set_id);
            acc
        },
    );
    Ok(by_member)
}
}
/// One history row returned by `BalanceRepo::fetch_batch_member_history`.
pub(crate) struct MemberBalanceHistoryRow {
    /// The balance snapshot stored in this history row.
    pub(crate) snapshot: BalanceSnapshot,
    /// The previous version for the same account and currency, or `None` for
    /// the first version.
    pub(crate) prev_snapshot: Option<BalanceSnapshot>,
    /// The history row's `seq` column; rows are returned ordered by it.
    pub(crate) seq: i64,
}
/// Per-account state built by `BalanceRepo::load_account_set_balances_batch`:
/// the latest snapshot per currency plus an optional positive `seq` watermark.
pub(crate) type AccountSetBalanceState = (HashMap<Currency, BalanceSnapshot>, Option<i64>);
/// An account id paired with per-currency snapshots and an optional watermark.
/// NOTE(review): presumably the same data as `AccountSetBalanceState` keyed by
/// its account; usage is not visible here — confirm.
pub(crate) type SetRecalcState = (AccountId, HashMap<Currency, BalanceSnapshot>, Option<i64>);