use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tokio::sync::Mutex;
use tracing::warn;
use crabka_metadata::NodeId;
use crate::coordinator::unified::share::config::ShareGroupConfig;
use crate::metadata_source::MetadataSource;
use crate::partition_registry::PartitionRegistry;
use crate::share_coordinator::persister_client::SharePersister;
use crate::share_partition::session::ShareSessionCache;
use crate::share_partition::state::AcquisitionState;
/// Cache key identifying one share-partition: `(group, topic_id, partition)`.
type LeaderKey = (String, uuid::Uuid, i32);
/// Manager of share-partition acquisition state on this node.
///
/// Keeps one [`AcquisitionState`] per `(group, topic_id, partition)` in an
/// in-memory cache, loads it through the persister on first access, and
/// writes it back when it is marked dirty.
pub(crate) struct SharePartitionLeaderManager {
/// Id of this node; compared against partition leaders in `topic_leader_is_self`.
node_id: NodeId,
/// Partition registry; read for a partition's current leader epoch.
partitions: Arc<PartitionRegistry>,
/// Source of the current metadata image (topic ids, names, partition leaders).
controller: Arc<dyn MetadataSource>,
/// Client used to read and write persisted share-partition state.
persister: Arc<SharePersister>,
/// Share-group configuration (session-cache sizing, record lock duration).
config: Arc<ShareGroupConfig>,
/// Share-session cache consulted by `validate_session`.
sessions: ShareSessionCache,
/// Cached acquisition state: one async-mutex-guarded cell per key.
leaders: DashMap<LeaderKey, Arc<Mutex<AcquisitionState>>>,
}
impl std::fmt::Debug for SharePartitionLeaderManager {
    /// Formats a compact summary: the node id and the number of cached
    /// share-partitions. All other fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SharePartitionLeaderManager");
        out.field("node_id", &self.node_id);
        out.field("live_partitions", &self.leaders.len());
        out.finish_non_exhaustive()
    }
}
impl SharePartitionLeaderManager {
/// Builds a manager with an empty acquisition-state cache.
///
/// The share-session cache is sized to `max_groups * max(max_size, 1)`
/// (saturating). When `max_groups` is zero, a fallback capacity of 10 000
/// is used instead.
pub(crate) fn new(
    node_id: NodeId,
    partitions: Arc<PartitionRegistry>,
    controller: Arc<dyn MetadataSource>,
    persister: Arc<SharePersister>,
    config: Arc<ShareGroupConfig>,
) -> Self {
    let session_capacity = match config.max_groups {
        0 => 10_000,
        groups => groups.saturating_mul(config.max_size.max(1)),
    };
    let sessions = ShareSessionCache::new(session_capacity);
    Self {
        node_id,
        partitions,
        controller,
        persister,
        config,
        sessions,
        leaders: DashMap::new(),
    }
}
/// Validates the share session of `member` in `group` at `epoch` by
/// delegating to the session cache.
///
/// # Errors
/// Returns the `i16` code produced by the session cache (presumably a Kafka
/// protocol error code — confirm against `ShareSessionCache::validate`).
pub(crate) fn validate_session(
&self,
group: &str,
member: &str,
epoch: i32,
) -> Result<(), i16> {
self.sessions.validate(group, member, epoch)
}
/// Looks up `(leader_id, leader_epoch)` for a partition in the current
/// metadata image.
///
/// Returns `(-1, -1)` when the topic id or the partition is not present in
/// the image. A leader id that does not fit in an `i32` is reported as `-1`
/// while the epoch is still returned.
pub(crate) fn current_leader_of(&self, topic_id: uuid::Uuid, partition: i32) -> (i32, i32) {
    const UNKNOWN: (i32, i32) = (-1, -1);

    let image = self.controller.current_image();
    image
        .topics()
        .find(|candidate| candidate.topic_id == topic_id)
        .and_then(|topic| {
            image.partition(&topic.name, partition).map(|meta| {
                let leader = i32::try_from(meta.leader).unwrap_or(-1);
                (leader, meta.leader_epoch)
            })
        })
        .unwrap_or(UNKNOWN)
}
/// Resolves a topic id to its name using the current metadata image.
///
/// Returns `None` when no topic in the image carries `topic_id`.
pub(crate) fn topic_name_for(&self, topic_id: uuid::Uuid) -> Option<String> {
    let image = self.controller.current_image();
    let topic = image
        .topics()
        .find(|candidate| candidate.topic_id == topic_id)?;
    Some(topic.name.clone())
}
/// Reports whether this node is the leader of `partition` of the topic
/// identified by `topic_id`, according to the current metadata image.
///
/// Returns `false` when the topic id or the partition is not in the image.
pub(crate) fn topic_leader_is_self(&self, topic_id: uuid::Uuid, partition: i32) -> bool {
    let image = self.controller.current_image();
    image
        .topics()
        .find(|candidate| candidate.topic_id == topic_id)
        .is_some_and(|topic| {
            image
                .partition(&topic.name, partition)
                .is_some_and(|meta| meta.leader == self.node_id)
        })
}
/// Reads the current leader epoch of a partition from the partition registry.
///
/// The topic name is resolved from `topic_id` through the current metadata
/// image. Returns `0` when the topic is not in the image or the partition is
/// not in the registry.
fn leader_epoch_for(&self, topic_id: uuid::Uuid, partition: i32) -> i32 {
    use std::sync::atomic::Ordering;

    let image = self.controller.current_image();
    let topic = match image.topics().find(|candidate| candidate.topic_id == topic_id) {
        Some(found) => found,
        None => return 0,
    };
    match self.partitions.get(&topic.name, partition) {
        Some(entry) => entry.current_leader_epoch.load(Ordering::Acquire),
        None => 0,
    }
}
pub(crate) async fn get_or_load(
&self,
group: &str,
topic_id: uuid::Uuid,
partition: i32,
) -> Arc<Mutex<AcquisitionState>> {
let key = (group.to_string(), topic_id, partition);
if let Some(cell) = self.leaders.get(&key) {
return cell.value().clone();
}
let leader_epoch = self.leader_epoch_for(topic_id, partition);
let loaded = match self.persister.read_state(group, topic_id, partition).await {
Ok(Some(persisted)) => {
let mut st = AcquisitionState::new(persisted.start_offset);
st.load_from(
persisted.start_offset,
persisted.state_epoch,
leader_epoch,
persisted.delivery_complete_count,
&persisted.state_batches,
);
st
}
Ok(None) => {
let mut st = AcquisitionState::new(0);
st.leader_epoch = leader_epoch;
st
}
Err(e) => {
warn!(
group,
%topic_id, partition, error = %e,
"share-partition state load failed; starting from empty window"
);
let mut st = AcquisitionState::new(0);
st.leader_epoch = leader_epoch;
st
}
};
let cell = Arc::new(Mutex::new(loaded));
self.leaders.entry(key).or_insert(cell).value().clone()
}
pub(crate) fn invalidate(&self, group: &str, topic_id: uuid::Uuid, partition: i32) {
self.leaders
.remove(&(group.to_string(), topic_id, partition));
}
/// Writes `st` back through the persister when it has unsaved changes.
///
/// Does nothing when `st.dirty` is unset. After a successful write the dirty
/// flag is cleared. On failure a warning is logged and the flag stays set,
/// so a later call attempts the write again.
pub(crate) async fn persist_if_dirty(
    &self,
    group: &str,
    topic_id: uuid::Uuid,
    partition: i32,
    st: &mut AcquisitionState,
) {
    if st.dirty {
        let (start_offset, delivery_complete_count, state_batches) = st.to_persist_batches();
        let outcome = self
            .persister
            .write_state(
                group,
                topic_id,
                partition,
                st.state_epoch,
                st.leader_epoch,
                start_offset,
                delivery_complete_count,
                state_batches,
            )
            .await;
        if let Err(e) = outcome {
            warn!(
                group,
                %topic_id, partition, error = %e,
                "share-partition state persist failed; will retry on next change"
            );
        } else {
            st.dirty = false;
        }
    }
}
/// Spawns a detached background task that periodically expires timed-out
/// record locks on every cached share-partition and persists the states
/// that changed as a result.
///
/// The sweep period is half of `record_lock_duration`, but never shorter
/// than 100 ms. The task holds only a weak reference to the manager, so it
/// does not keep the manager alive and exits on the first tick after the
/// last strong reference has been dropped.
pub(crate) fn spawn_lock_sweeper(self: &Arc<Self>) {
    let period = (self.config.record_lock_duration / 2).max(Duration::from_millis(100));
    let weak = Arc::downgrade(self);
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(period);
        // A sweep awaits the persister and may overrun the period; resume
        // with regular spacing instead of firing a burst of catch-up ticks.
        tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            tick.tick().await;
            // Upgrade once per sweep so no strong reference is held while idle.
            let Some(mgr) = weak.upgrade() else {
                break;
            };
            // Snapshot the cells first so no map guard is held across an await.
            let cells: Vec<(LeaderKey, Arc<Mutex<AcquisitionState>>)> = mgr
                .leaders
                .iter()
                .map(|e| (e.key().clone(), Arc::clone(e.value())))
                .collect();
            let now = std::time::Instant::now();
            for ((group, topic_id, partition), cell) in cells {
                let mut st = cell.lock().await;
                st.expire_locks(now);
                // `persist_if_dirty` returns immediately for clean state.
                mgr.persist_if_dirty(&group, topic_id, partition, &mut st)
                    .await;
            }
        }
    });
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert2::assert;
use std::collections::BTreeSet;
use std::net::SocketAddr;
/// 30-second duration passed to `acquire` in the tests below (presumably the
/// record-lock duration — confirm against `AcquisitionState::acquire`).
const LOCK: Duration = Duration::from_secs(30);
use async_trait::async_trait;
use tokio::sync::watch;
use crabka_metadata::{MetadataImage, MetadataRecord};
use crabka_raft::{
AddVoter, Node, QuorumState, RaftError, ReconfigOutcome, RemoveVoter, SnapshotRange,
UpdateVoter,
};
use crabka_security::ListenerProtocol;
use crate::network::client::InterBrokerClient;
use crate::share_coordinator::config::ShareCoordinatorConfig;
use crate::share_coordinator::coordinator::ShareCoordinator;
/// Minimal [`MetadataSource`] stand-in for these tests: serves a fixed
/// metadata image and a leader watch channel.
struct MockSource {
/// Image returned by `current_image`.
image: Arc<MetadataImage>,
/// Receiver cloned out by `watch_leader`.
leader_rx: watch::Receiver<Option<NodeId>>,
/// Sender half of the leader channel, stored so it lives as long as the mock.
_leader_tx: watch::Sender<Option<NodeId>>,
}
impl MockSource {
    /// Creates a mock holding `MetadataImage::new(Uuid::nil())` and a leader
    /// channel whose initial value is `Some(1)`.
    fn new() -> Self {
        let (leader_tx, leader_rx) = watch::channel(Some(1));
        let image = Arc::new(MetadataImage::new(uuid::Uuid::nil()));
        Self {
            image,
            leader_rx,
            _leader_tx: leader_tx,
        }
    }
}
// Test double: only `current_image`, `watch_leader`, `submit_change`, and
// `cancel` are usable; every other method panics via `unimplemented!()`.
#[async_trait]
impl MetadataSource for MockSource {
// Hands out the fixed image stored in the mock.
fn current_image(&self) -> Arc<MetadataImage> {
self.image.clone()
}
// Not needed by these tests.
fn watch_image(&self) -> watch::Receiver<Arc<MetadataImage>> {
unimplemented!()
}
// Returns a clone of the stored leader receiver.
fn watch_leader(&self) -> watch::Receiver<Option<NodeId>> {
self.leader_rx.clone()
}
// Not needed by these tests.
fn quorum_state(&self) -> QuorumState {
unimplemented!()
}
// Accepts and discards the records, always reporting success.
async fn submit_change(&self, _records: Vec<MetadataRecord>) -> Result<(), RaftError> {
Ok(())
}
// Not needed by these tests.
async fn change_membership(&self, _new_voters: BTreeSet<NodeId>) -> Result<(), RaftError> {
unimplemented!()
}
// Not needed by these tests.
async fn add_learner(&self, _node_id: NodeId, _node: Node) -> Result<(), RaftError> {
unimplemented!()
}
// Not needed by these tests.
fn controller_bound_addr(&self) -> SocketAddr {
unimplemented!()
}
// Not needed by these tests.
fn read_snapshot_range(&self, _position: i64, _max_bytes: i32) -> SnapshotRange {
unimplemented!()
}
// Not needed by these tests.
async fn trigger_snapshot(&self) -> Result<(), RaftError> {
unimplemented!()
}
// Not needed by these tests.
async fn add_voter(&self, _req: AddVoter) -> Result<ReconfigOutcome, RaftError> {
unimplemented!()
}
// Not needed by these tests.
async fn remove_voter(&self, _req: RemoveVoter) -> Result<ReconfigOutcome, RaftError> {
unimplemented!()
}
// Not needed by these tests.
async fn update_voter(&self, _req: UpdateVoter) -> Result<ReconfigOutcome, RaftError> {
unimplemented!()
}
// No-op.
async fn cancel(&self) {}
}
/// Builds a manager for node id 1 wired to a [`MockSource`], a new
/// partition registry, and default coordinator and share-group configs.
fn manager() -> Arc<SharePartitionLeaderManager> {
    let registry = Arc::new(PartitionRegistry::new());
    let source: Arc<dyn MetadataSource> = Arc::new(MockSource::new());

    let coordinator = Arc::new(ShareCoordinator::new(
        1,
        Arc::clone(&registry),
        ShareCoordinatorConfig::default(),
    ));
    let inter_broker = Arc::new(InterBrokerClient::new(None, None));
    let persister = Arc::new(SharePersister::new(
        1,
        coordinator,
        Arc::clone(&source),
        inter_broker,
        ListenerProtocol::Plaintext,
        "INTERNAL".to_string(),
    ));

    let group_config = Arc::new(ShareGroupConfig::default());
    Arc::new(SharePartitionLeaderManager::new(
        1,
        registry,
        source,
        persister,
        group_config,
    ))
}
/// A first load yields a clean state starting at offset 0, and loading the
/// same key again returns the very same cell.
#[tokio::test]
async fn get_or_load_fresh_returns_empty_window_and_caches() {
    let mgr = manager();
    let topic = uuid::Uuid::from_bytes([21; 16]);

    let first = mgr.get_or_load("g1", topic, 0).await;
    {
        let state = first.lock().await;
        assert!(state.start_offset == 0);
        assert!(!state.dirty);
    }

    let second = mgr.get_or_load("g1", topic, 0).await;
    assert!(Arc::ptr_eq(&first, &second));
}
/// Persisting a state that is not dirty leaves it not dirty.
#[tokio::test]
async fn persist_if_dirty_is_noop_when_clean() {
    let mgr = manager();
    let topic = uuid::Uuid::from_bytes([22; 16]);
    let cell = mgr.get_or_load("g1", topic, 0).await;

    let mut state = cell.lock().await;
    assert!(!state.dirty);
    mgr.persist_if_dirty("g1", topic, 0, &mut state).await;
    assert!(!state.dirty);
}
/// When the write does not succeed, the dirty flag must stay set so the
/// state is persisted on a later attempt. The write is expected to fail
/// under this mock wiring.
#[tokio::test]
async fn persist_if_dirty_keeps_dirty_on_write_failure() {
    let mgr = manager();
    let topic = uuid::Uuid::from_bytes([25; 16]);
    let cell = mgr.get_or_load("g1", topic, 0).await;

    let mut state = cell.lock().await;
    state.materialize(4, 100);
    let now = std::time::Instant::now();
    let _ = state.acquire("m1", 10, i32::MAX, now, LOCK, 5);
    assert!(state.dirty);

    mgr.persist_if_dirty("g1", topic, 0, &mut state).await;
    assert!(state.dirty);
}
/// A topic id that is absent from the metadata image is never led by this node.
#[tokio::test]
async fn topic_leader_is_self_false_for_unknown_topic() {
    let mgr = manager();
    let unknown_topic = uuid::Uuid::from_bytes([23; 16]);
    let is_leader = mgr.topic_leader_is_self(unknown_topic, 0);
    assert!(!is_leader);
}
/// After `invalidate`, the next load for the same key builds a new cell
/// rather than returning the previously cached one.
#[tokio::test]
async fn invalidate_removes_cached_cell() {
    let mgr = manager();
    let topic = uuid::Uuid::from_bytes([24; 16]);

    let before = mgr.get_or_load("g1", topic, 0).await;
    mgr.invalidate("g1", topic, 0);
    let after = mgr.get_or_load("g1", topic, 0).await;

    assert!(!Arc::ptr_eq(&before, &after));
}
}