use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
/// A transaction whose writes span multiple vShards.
///
/// Created by `TransactionCoordinator::begin`. Each distinct vShard in
/// `shard_writes` must acknowledge before the coordinator reports it committed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossShardTransaction {
    /// Coordinator-assigned id, taken from a per-coordinator counter starting at 1.
    pub txn_id: u64,
    /// Tenant identifier supplied by the caller of `begin`.
    pub tenant_id: u32,
    /// `(vshard_id, payload)` pairs; the payload bytes are opaque here
    /// (NOTE(review): encoding is defined by whoever applies the writes — confirm).
    pub shard_writes: Vec<(u16, Vec<u8>)>,
    /// `node_id` of the coordinator that created this transaction.
    pub coordinator_node: u64,
    /// Index passed to `begin` and copied into every generated `ForwardEntry`
    /// (presumably the coordinator's log position for this txn — confirm).
    pub coordinator_log_index: u64,
}
/// One vShard's slice of a `CrossShardTransaction`, as produced by
/// `TransactionCoordinator::generate_forwards`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForwardEntry {
    /// Id of the originating cross-shard transaction.
    pub txn_id: u64,
    /// Opaque write payload destined for the receiving vShard.
    pub writes: Vec<u8>,
    /// The first vShard listed in the transaction's `shard_writes`
    /// (the same value for every forward of one transaction).
    pub source_vshard: u16,
    /// Copied verbatim from the transaction's `coordinator_log_index`.
    pub coordinator_log_index: u64,
}
/// A forwarded global-secondary-index update.
///
/// NOTE(review): nothing in this file produces or consumes it besides tests;
/// presumably sent to the vShard owning `value` within `index_name` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GsiForwardEntry {
    /// Name of the secondary index being updated.
    pub index_name: String,
    /// Indexed value (stored as a string).
    pub value: String,
    /// Tenant owning the indexed document.
    pub tenant_id: u32,
    /// Collection containing the indexed document.
    pub collection: String,
    /// Id of the document the index entry points at.
    pub document_id: String,
    /// vShard that emitted this forward.
    pub source_vshard: u16,
    /// Whether to add/update or remove the index entry.
    pub action: GsiAction,
}
/// Operation carried by a `GsiForwardEntry`.
///
/// Derives `PartialEq`/`Eq` so callers can compare actions directly instead
/// of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GsiAction {
    /// Insert the index entry, or replace an existing one.
    Upsert,
    /// Delete the index entry.
    Remove,
}
/// Request to validate a graph edge whose endpoints may live on different vShards.
///
/// NOTE(review): semantics are inferred from field names only; the handler is
/// not in this file — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeValidationRequest {
    /// Source node id.
    pub src_id: String,
    /// vShard holding the source node.
    pub src_vshard: u16,
    /// Destination node id.
    pub dst_id: String,
    /// vShard holding the destination node.
    pub dst_vshard: u16,
    /// Edge label (e.g. `"KNOWS"`).
    pub label: String,
}
/// Outcome of an edge validation (presumably the reply to an
/// `EdgeValidationRequest` — the handler is not visible here).
///
/// Derives `PartialEq`/`Eq` so results can be compared with `==`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeValidationResult {
    /// The referenced endpoint exists.
    Exists,
    /// The referenced endpoint does not exist.
    NotFound,
    /// Validation could not be performed (e.g. the owning shard was unreachable).
    Unavailable,
}
/// Tracks cross-shard transactions created on one node until their
/// participating vShards have acknowledged them.
///
/// Transactions are registered by [`TransactionCoordinator::begin`] and stay in
/// memory, committed or not, until [`TransactionCoordinator::cleanup`] removes them.
#[derive(Debug)]
pub struct TransactionCoordinator {
    /// Id handed to the next transaction created by `begin`; starts at 1.
    next_txn_id: u64,
    /// Tracked transactions keyed by `txn_id`.
    pending: std::collections::HashMap<u64, TxnState>,
    /// Stamped into each transaction's `coordinator_node`.
    node_id: u64,
}
/// Coordinator-side bookkeeping for a single tracked transaction.
#[derive(Debug)]
struct TxnState {
    /// Copy of the transaction as returned by `begin`.
    txn: CrossShardTransaction,
    /// vShards that have acknowledged so far.
    acks_received: std::collections::HashSet<u16>,
    /// Distinct vShards named in `txn.shard_writes`; all must acknowledge.
    acks_needed: std::collections::HashSet<u16>,
    /// Set once every needed ack has arrived; never reset afterwards.
    committed: bool,
}
impl TransactionCoordinator {
    /// Creates an empty coordinator whose transactions are stamped with `node_id`.
    ///
    /// Transaction ids are assigned sequentially starting at 1.
    pub fn new(node_id: u64) -> Self {
        Self {
            next_txn_id: 1,
            pending: std::collections::HashMap::new(),
            node_id,
        }
    }

    /// Registers a new cross-shard transaction and returns a copy of it.
    ///
    /// Every distinct vShard id in `shard_writes` must be acknowledged via
    /// [`Self::ack`] before the transaction counts as committed. Duplicate
    /// vShard ids collapse into a single required ack. A transaction with no
    /// writes has nothing to acknowledge and is never marked committed by `ack`.
    pub fn begin(
        &mut self,
        tenant_id: u32,
        shard_writes: Vec<(u16, Vec<u8>)>,
        coordinator_log_index: u64,
    ) -> CrossShardTransaction {
        let txn_id = self.next_txn_id;
        self.next_txn_id += 1;
        let acks_needed: std::collections::HashSet<u16> =
            shard_writes.iter().map(|(s, _)| *s).collect();
        let txn = CrossShardTransaction {
            txn_id,
            tenant_id,
            shard_writes,
            coordinator_node: self.node_id,
            coordinator_log_index,
        };
        self.pending.insert(
            txn_id,
            TxnState {
                txn: txn.clone(),
                acks_received: std::collections::HashSet::new(),
                acks_needed,
                committed: false,
            },
        );
        debug!(txn_id, "cross-shard transaction created");
        txn
    }

    /// Records an acknowledgement from `vshard_id` for transaction `txn_id`.
    ///
    /// Returns `true` when, after this ack, every participating vShard has
    /// acknowledged. This includes repeated acks on an already-committed
    /// transaction. Returns `false` in three cases:
    /// - the transaction is unknown (never created, or already cleaned up);
    /// - `vshard_id` is not a participant, in which case the ack is ignored;
    /// - acks are still outstanding.
    pub fn ack(&mut self, txn_id: u64, vshard_id: u16) -> bool {
        let Some(state) = self.pending.get_mut(&txn_id) else {
            warn!(txn_id, "ack for unknown transaction");
            return false;
        };
        // Only participants count. Recording a stray id would make
        // `acks_received` a strict superset of `acks_needed`, so the
        // transaction could never be seen as complete.
        if !state.acks_needed.contains(&vshard_id) {
            warn!(
                txn_id,
                vshard_id, "ack from vshard not participating in transaction"
            );
            return false;
        }
        state.acks_received.insert(vshard_id);
        // `acks_received` only ever holds ids from `acks_needed`, so equal
        // sizes imply equal sets.
        if state.acks_received.len() == state.acks_needed.len() {
            // Log the commit once, not on every duplicate ack.
            if !state.committed {
                state.committed = true;
                info!(txn_id, "cross-shard transaction fully committed");
            }
            true
        } else {
            debug!(
                txn_id,
                received = state.acks_received.len(),
                needed = state.acks_needed.len(),
                "cross-shard transaction partial ack"
            );
            false
        }
    }

    /// Returns `true` if `txn_id` is tracked and has received all its acks.
    ///
    /// Unknown or cleaned-up transactions report `false`.
    pub fn is_committed(&self, txn_id: u64) -> bool {
        self.pending.get(&txn_id).is_some_and(|s| s.committed)
    }

    /// Stops tracking `txn_id`, whether or not it committed. No-op if unknown.
    pub fn cleanup(&mut self, txn_id: u64) {
        self.pending.remove(&txn_id);
    }

    /// Returns the tracked transaction with id `txn_id`, if any.
    pub fn get_transaction(&self, txn_id: u64) -> Option<&CrossShardTransaction> {
        self.pending.get(&txn_id).map(|s| &s.txn)
    }

    /// Number of transactions currently tracked (including committed ones not
    /// yet cleaned up).
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Splits `txn` into one `(vshard_id, ForwardEntry)` pair per entry of
    /// `shard_writes`, preserving order.
    ///
    /// Every forward reports the first listed vShard as its `source_vshard`.
    /// NOTE(review): why the first entry is treated as the source is not visible
    /// here — confirm with the receiver.
    pub fn generate_forwards(txn: &CrossShardTransaction) -> Vec<(u16, ForwardEntry)> {
        // Loop-invariant: compute once. The 0 fallback applies only when there
        // are no writes, in which case no forwards are produced anyway.
        let source_vshard = txn.shard_writes.first().map_or(0, |(s, _)| *s);
        txn.shard_writes
            .iter()
            .map(|(vshard, writes)| {
                (
                    *vshard,
                    ForwardEntry {
                        txn_id: txn.txn_id,
                        writes: writes.clone(),
                        source_vshard,
                        coordinator_log_index: txn.coordinator_log_index,
                    },
                )
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Full begin → ack → commit → cleanup cycle, including duplicate and
    /// unknown-transaction acks.
    #[test]
    fn transaction_lifecycle() {
        let mut coord = TransactionCoordinator::new(1);
        let txn = coord.begin(
            1,
            vec![
                (10, b"writes_for_shard_10".to_vec()),
                (20, b"writes_for_shard_20".to_vec()),
            ],
            100,
        );
        assert_eq!(txn.txn_id, 1);
        assert_eq!(txn.coordinator_node, 1);
        assert_eq!(coord.pending_count(), 1);

        let stored = coord.get_transaction(1).expect("transaction 1 is tracked");
        assert_eq!(stored.coordinator_log_index, 100);
        assert_eq!(stored.shard_writes.len(), 2);

        assert!(!coord.is_committed(1));
        assert!(!coord.ack(1, 10));
        // A duplicate ack from the same shard must not complete the txn.
        assert!(!coord.ack(1, 10));
        assert!(!coord.is_committed(1));
        assert!(coord.ack(1, 20));
        assert!(coord.is_committed(1));
        // Re-acking a committed transaction still reports it complete.
        assert!(coord.ack(1, 20));
        assert!(coord.is_committed(1));

        // Acks for ids that were never issued are rejected.
        assert!(!coord.ack(99, 10));

        coord.cleanup(1);
        assert_eq!(coord.pending_count(), 0);
        assert!(coord.get_transaction(1).is_none());
        assert!(!coord.is_committed(1));
        assert!(!coord.ack(1, 10));
    }

    /// Transaction ids are assigned sequentially per coordinator.
    #[test]
    fn transaction_ids_are_sequential() {
        let mut coord = TransactionCoordinator::new(7);
        let first = coord.begin(1, vec![(1, Vec::new())], 0);
        let second = coord.begin(1, vec![(2, Vec::new())], 0);
        assert_eq!(first.txn_id, 1);
        assert_eq!(second.txn_id, 2);
        assert_eq!(second.coordinator_node, 7);
        assert_eq!(coord.pending_count(), 2);
    }

    #[test]
    fn generate_forwards() {
        let txn = CrossShardTransaction {
            txn_id: 42,
            tenant_id: 1,
            shard_writes: vec![(10, b"w1".to_vec()), (20, b"w2".to_vec())],
            coordinator_node: 1,
            coordinator_log_index: 100,
        };
        let forwards = TransactionCoordinator::generate_forwards(&txn);
        assert_eq!(forwards.len(), 2);
        assert_eq!(forwards[0].0, 10);
        assert_eq!(forwards[0].1.txn_id, 42);
        assert_eq!(forwards[0].1.writes, b"w1".to_vec());
        assert_eq!(forwards[1].0, 20);
        assert_eq!(forwards[1].1.txn_id, 42);
        assert_eq!(forwards[1].1.writes, b"w2".to_vec());
        // Every forward names the first listed vShard as its source.
        assert!(forwards.iter().all(|(_, f)| f.source_vshard == 10));
        assert!(forwards.iter().all(|(_, f)| f.coordinator_log_index == 100));
    }

    #[test]
    fn generate_forwards_empty_writes() {
        let txn = CrossShardTransaction {
            txn_id: 1,
            tenant_id: 1,
            shard_writes: Vec::new(),
            coordinator_node: 1,
            coordinator_log_index: 0,
        };
        assert!(TransactionCoordinator::generate_forwards(&txn).is_empty());
    }

    #[test]
    fn edge_validation_types() {
        let req = EdgeValidationRequest {
            src_id: "alice".into(),
            src_vshard: 10,
            dst_id: "bob".into(),
            dst_vshard: 20,
            label: "KNOWS".into(),
        };
        let bytes = rmp_serde::to_vec_named(&req).unwrap();
        let decoded: EdgeValidationRequest = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(decoded.src_id, "alice");
        assert_eq!(decoded.src_vshard, 10);
        assert_eq!(decoded.dst_id, "bob");
        assert_eq!(decoded.dst_vshard, 20);
        assert_eq!(decoded.label, "KNOWS");

        let bytes = rmp_serde::to_vec_named(&EdgeValidationResult::Unavailable).unwrap();
        let decoded: EdgeValidationResult = rmp_serde::from_slice(&bytes).unwrap();
        assert!(matches!(decoded, EdgeValidationResult::Unavailable));
    }

    #[test]
    fn gsi_forward_roundtrip() {
        let entry = GsiForwardEntry {
            index_name: "email_idx".into(),
            value: "alice@example.com".into(),
            tenant_id: 1,
            collection: "users".into(),
            document_id: "u1".into(),
            source_vshard: 10,
            action: GsiAction::Upsert,
        };
        let bytes = rmp_serde::to_vec_named(&entry).unwrap();
        let decoded: GsiForwardEntry = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(decoded.index_name, "email_idx");
        assert_eq!(decoded.value, "alice@example.com");
        assert_eq!(decoded.tenant_id, 1);
        assert_eq!(decoded.collection, "users");
        assert_eq!(decoded.document_id, "u1");
        assert_eq!(decoded.source_vshard, 10);
        assert!(matches!(decoded.action, GsiAction::Upsert));

        let removal = GsiForwardEntry {
            action: GsiAction::Remove,
            ..entry
        };
        let bytes = rmp_serde::to_vec_named(&removal).unwrap();
        let decoded: GsiForwardEntry = rmp_serde::from_slice(&bytes).unwrap();
        assert!(matches!(decoded.action, GsiAction::Remove));
    }
}