use crate::config::{CompiledFilters, Granularity, PerformanceConfig};
use crate::error::Result;
use crate::kafka::client::{ConsumerGroupInfo, KafkaClient, TopicPartition};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::{debug, instrument, warn};
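/// TTL cache of per-topic compaction status, so the DescribeConfigs lookups
/// behind `fetch_compacted_topics_for` are not repeated on every collection.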
struct CompactedTopicsCache {
ttl: Duration,
entries: Mutex<HashMap<String, (bool, Instant)>>,
}
impl CompactedTopicsCache {
fn new(ttl: Duration) -> Self {
Self {
ttl,
entries: Mutex::new(HashMap::new()),
}
}
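/// Splits `monitored_topics` into topics already known to be compacted (from
/// fresh cache entries) and topics whose status must be (re-)fetched.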
fn partition<'a>(&self, monitored_topics: &'a [String]) -> (HashSet<String>, Vec<&'a str>) {
let now = Instant::now();
let entries = self.entries.lock().unwrap_or_else(|p| p.into_inner());
let mut cached_compacted = HashSet::new();
let mut to_fetch: Vec<&str> = Vec::new();
for topic in monitored_topics {
match entries.get(topic) {
Some((is_compacted, fetched_at)) if now.duration_since(*fetched_at) < self.ttl => {
if *is_compacted {
cached_compacted.insert(topic.clone());
}
}
_ => to_fetch.push(topic.as_str()),
}
}
(cached_compacted, to_fetch)
}
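/// Stores the fetched compaction status for each topic in `fetched_topics`.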
fn update(&self, fetched_topics: &[&str], compacted_result: &HashSet<String>) {
let now = Instant::now();
let mut entries = self.entries.lock().unwrap_or_else(|p| p.into_inner());
for topic in fetched_topics {
let is_compacted = compacted_result.contains(*topic);
entries.insert((*topic).to_string(), (is_compacted, now));
}
}
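/// Drops entries for topics that are no longer in the monitored set.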
fn prune_to(&self, monitored_topics: &[String]) {
let keep: HashSet<&str> = monitored_topics.iter().map(|s| s.as_str()).collect();
let mut entries = self.entries.lock().unwrap_or_else(|p| p.into_inner());
entries.retain(|k, _| keep.contains(k.as_str()));
}
}
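/// Shared view of the monitored partitions and their topic names.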
type MonitoredSet = (Arc<Vec<TopicPartition>>, Arc<Vec<String>>);
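/// TTL cache of the monitored partition/topic set derived from cluster
/// metadata. A zero TTL disables caching.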
struct MetadataCache {
ttl: Duration,
entry: Mutex<Option<(MonitoredSet, Instant)>>,
}
impl MetadataCache {
fn new(ttl: Duration) -> Self {
Self {
ttl,
entry: Mutex::new(None),
}
}
fn get(&self) -> Option<MonitoredSet> {
if self.ttl.is_zero() {
return None;
}
let guard = self.entry.lock().unwrap_or_else(|p| p.into_inner());
let (set, at) = guard.as_ref()?;
if at.elapsed() < self.ttl {
Some((Arc::clone(&set.0), Arc::clone(&set.1)))
} else {
None
}
}
fn set(&self, partitions: Vec<TopicPartition>, topics: Vec<String>) -> MonitoredSet {
let set: MonitoredSet = (Arc::new(partitions), Arc::new(topics));
if !self.ttl.is_zero() {
let mut guard = self.entry.lock().unwrap_or_else(|p| p.into_inner());
*guard = Some(((Arc::clone(&set.0), Arc::clone(&set.1)), Instant::now()));
}
set
}
}
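/// TTL cache of the listed consumer groups. A zero TTL disables caching.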
struct ConsumerGroupsCache {
ttl: Duration,
entry: Mutex<Option<(Arc<Vec<ConsumerGroupInfo>>, Instant)>>,
}
impl ConsumerGroupsCache {
fn new(ttl: Duration) -> Self {
Self {
ttl,
entry: Mutex::new(None),
}
}
fn get(&self) -> Option<Arc<Vec<ConsumerGroupInfo>>> {
if self.ttl.is_zero() {
return None;
}
let guard = self.entry.lock().unwrap_or_else(|p| p.into_inner());
let (groups, at) = guard.as_ref()?;
if at.elapsed() < self.ttl {
Some(Arc::clone(groups))
} else {
None
}
}
fn set(&self, groups: Vec<ConsumerGroupInfo>) -> Arc<Vec<ConsumerGroupInfo>> {
let arc = Arc::new(groups);
if !self.ttl.is_zero() {
let mut guard = self.entry.lock().unwrap_or_else(|p| p.into_inner());
*guard = Some((Arc::clone(&arc), Instant::now()));
}
arc
}
}
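/// Collects consumer-group offsets, partition watermarks, and compaction
/// metadata for one cluster, using TTL caches to avoid redundant admin calls.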
pub struct OffsetCollector {
client: Arc<KafkaClient>,
filters: CompiledFilters,
performance: PerformanceConfig,
granularity: Granularity,
compacted_cache: CompactedTopicsCache,
metadata_cache: MetadataCache,
consumer_groups_cache: ConsumerGroupsCache,
}
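/// Wall-clock duration of each collection phase, in milliseconds.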
#[derive(Default)]
struct PhaseTimings {
list_groups_ms: u64,
describe_groups_ms: u64,
metadata_ms: u64,
watermarks_ms: u64,
group_offsets_ms: u64,
compacted_ms: u64,
}
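/// Point-in-time view of consumer groups, partition watermarks, and compacted
/// topics for a single cluster.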
#[derive(Debug, Clone)]
pub struct OffsetsSnapshot {
pub cluster_name: String,
pub groups: Vec<GroupSnapshot>,
pub watermarks: HashMap<TopicPartition, (i64, i64)>,
pub compacted_topics: HashSet<String>,
#[allow(dead_code)]
pub timestamp_ms: i64,
}
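/// State, membership, and committed offsets of one consumer group.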
#[derive(Debug, Clone)]
pub struct GroupSnapshot {
pub group_id: String,
pub state: String,
pub members: Vec<MemberSnapshot>,
pub offsets: HashMap<TopicPartition, i64>,
}
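/// Identity and partition assignments of one consumer-group member.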
#[derive(Debug, Clone)]
pub struct MemberSnapshot {
pub member_id: String,
pub client_id: String,
pub client_host: String,
pub assignments: Vec<TopicPartition>,
}
impl OffsetCollector {
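/// Builds a collector whose caches use the TTLs from `performance`.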
pub fn with_performance(
client: Arc<KafkaClient>,
filters: CompiledFilters,
performance: PerformanceConfig,
granularity: Granularity,
) -> Self {
let compacted_cache = CompactedTopicsCache::new(performance.compacted_topics_cache_ttl);
let metadata_cache = MetadataCache::new(performance.metadata_cache_ttl);
let consumer_groups_cache = ConsumerGroupsCache::new(performance.consumer_groups_cache_ttl);
Self {
client,
filters,
performance,
granularity,
compacted_cache,
metadata_cache,
consumer_groups_cache,
}
}
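/// Runs one full collection pass: list and describe consumer groups, resolve
/// the monitored topic/partition set, fetch watermarks and committed offsets,
/// and refresh the compacted-topic cache, recording per-phase timings.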
#[instrument(skip(self), fields(cluster = %self.client.cluster_name()))]
pub async fn collect_parallel(&self) -> Result<OffsetsSnapshot> {
let start = std::time::Instant::now();
let mut timings = PhaseTimings::default();
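// Phase 1: list consumer groups, served from cache while the entry is fresh.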
let phase_start = std::time::Instant::now();
let all_groups = if let Some(cached) = self.consumer_groups_cache.get() {
debug!(total_groups = cached.len(), "Consumer groups cache hit");
cached
} else {
let fresh = self.client.list_consumer_groups()?;
debug!(
total_groups = fresh.len(),
"Listed all consumer groups (fresh)"
);
self.consumer_groups_cache.set(fresh)
};
timings.list_groups_ms = phase_start.elapsed().as_millis() as u64;
let filtered_groups: Vec<_> = all_groups
.iter()
.filter(|g| self.filters.matches_group(&g.group_id))
.collect();
debug!(
filtered_groups = filtered_groups.len(),
"Filtered consumer groups"
);
let group_ids: Vec<&str> = filtered_groups
.iter()
.map(|g| g.group_id.as_str())
.collect();
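// Phase 2: describe the filtered groups; member assignments are only parsed
// at partition granularity.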
let parse_assignments = matches!(self.granularity, Granularity::Partition);
let phase_start = std::time::Instant::now();
let descriptions = self
.client
.describe_consumer_groups(
&group_ids,
parse_assignments,
self.performance.max_concurrent_groups,
)
.await?;
timings.describe_groups_ms = phase_start.elapsed().as_millis() as u64;
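// Phase 3: resolve the monitored topic/partition set from metadata (cached).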
let phase_start = std::time::Instant::now();
let (monitored_partitions, monitored_topics) = self.list_monitored_partitions()?;
timings.metadata_ms = phase_start.elapsed().as_millis() as u64;
debug!(
partitions = monitored_partitions.len(),
topics = monitored_topics.len(),
"Computed monitored topic + partition set"
);
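// Phase 4: fetch low/high watermarks for all monitored partitions on a
// blocking task.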
let phase_start = std::time::Instant::now();
let watermarks = {
let client = Arc::clone(&self.client);
tokio::task::spawn_blocking(move || {
client.fetch_watermarks_for_partitions(&monitored_partitions)
})
.await
.map_err(|e| {
crate::error::KlagError::Admin(format!("watermark task panicked: {e}"))
})??
};
timings.watermarks_ms = phase_start.elapsed().as_millis() as u64;
debug!(
partitions = watermarks.len(),
"Fetched watermarks (batched)"
);
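// Phase 5: fetch committed offsets for every filtered group.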
let phase_start = std::time::Instant::now();
let group_offsets = self.fetch_all_group_offsets_batched(&group_ids).await;
timings.group_offsets_ms = phase_start.elapsed().as_millis() as u64;
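// Phase 6: refresh compaction status only for topics missing from the cache
// or whose entries have expired.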
let phase_start = std::time::Instant::now();
let (mut compacted_topics, to_fetch) = self.compacted_cache.partition(&monitored_topics);
if !to_fetch.is_empty() {
debug!(
to_fetch = to_fetch.len(),
cached = compacted_topics.len(),
"Compacted-topic cache partial miss — refreshing"
);
let to_fetch_owned: Vec<String> = to_fetch.iter().map(|s| s.to_string()).collect();
match self
.client
.fetch_compacted_topics_for(&to_fetch_owned)
.await
{
Ok(freshly_compacted) => {
self.compacted_cache.update(&to_fetch, &freshly_compacted);
compacted_topics.extend(freshly_compacted);
}
Err(e) => warn!(error = %e, "Failed to refresh compacted topics"),
}
} else {
debug!(
cached = compacted_topics.len(),
"Compacted-topic cache fully hit — no DescribeConfigs RPC"
);
}
self.compacted_cache.prune_to(&monitored_topics);
timings.compacted_ms = phase_start.elapsed().as_millis() as u64;
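// Assemble per-group snapshots, keeping only offsets for topics that pass the
// topic filters.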
let mut groups = Vec::with_capacity(descriptions.len());
for desc in descriptions {
let offsets = group_offsets
.get(&desc.group_id)
.cloned()
.unwrap_or_default();
let filtered_offsets: HashMap<TopicPartition, i64> = offsets
.into_iter()
.filter(|(tp, _)| self.filters.matches_topic(&tp.topic))
.collect();
let members = desc
.members
.into_iter()
.map(|m| MemberSnapshot {
member_id: m.member_id,
client_id: m.client_id,
client_host: m.client_host,
assignments: m.assignments,
})
.collect();
groups.push(GroupSnapshot {
group_id: desc.group_id,
state: desc.state,
members,
offsets: filtered_offsets,
});
}
let elapsed = start.elapsed();
debug!(
elapsed_ms = elapsed.as_millis(),
list_groups_ms = timings.list_groups_ms,
describe_groups_ms = timings.describe_groups_ms,
metadata_ms = timings.metadata_ms,
watermarks_ms = timings.watermarks_ms,
group_offsets_ms = timings.group_offsets_ms,
compacted_ms = timings.compacted_ms,
monitored_topics = monitored_topics.len(),
compacted_topics = compacted_topics.len(),
"Batched collection completed"
);
Ok(OffsetsSnapshot {
cluster_name: self.client.cluster_name().to_string(),
groups,
watermarks,
compacted_topics,
timestamp_ms: unix_timestamp_ms(),
})
}
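/// Resolves the monitored partitions and topic names from cluster metadata,
/// applying the topic filters and honoring the metadata cache.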
fn list_monitored_partitions(&self) -> Result<MonitoredSet> {
if let Some(cached) = self.metadata_cache.get() {
debug!("Metadata cache hit");
return Ok(cached);
}
let metadata = self.client.fetch_metadata()?;
let mut partitions = Vec::new();
let mut topics = Vec::new();
for topic in metadata.topics() {
let name = topic.name();
if !self.filters.matches_topic(name) {
continue;
}
topics.push(name.to_string());
let topic_arc: Arc<str> = Arc::from(name);
for p in topic.partitions() {
partitions.push(TopicPartition::new(Arc::clone(&topic_arc), p.id()));
}
}
Ok(self.metadata_cache.set(partitions, topics))
}
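/// Fetches committed offsets for each group, one call per group, fanned out
/// across blocking tasks and capped by `max_concurrent_groups`. Failures for
/// individual groups are logged and skipped rather than failing the pass.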
async fn fetch_all_group_offsets_batched(
&self,
group_ids: &[&str],
) -> HashMap<String, HashMap<TopicPartition, i64>> {
use crate::kafka::admin::list_consumer_group_offsets_batched;
if group_ids.is_empty() {
return HashMap::new();
}
const PER_CALL_CHUNK: usize = 1;
let offset_timeout = self.performance.offset_fetch_timeout;
let max_concurrent = self.performance.max_concurrent_groups;
debug!(
groups = group_ids.len(),
per_call_chunk = PER_CALL_CHUNK,
max_concurrent = max_concurrent,
"Fetching group offsets (one call per group, fanned out)"
);
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let client = Arc::clone(&self.client);
let mut handles = Vec::with_capacity(group_ids.len());
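// One task per group; the semaphore bounds concurrent blocking admin calls.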
for gid in group_ids {
let gid = gid.to_string();
let permit = semaphore.clone();
let client_clone = Arc::clone(&client);
handles.push(tokio::spawn(async move {
let _permit: OwnedSemaphorePermit =
permit.acquire_owned().await.expect("semaphore closed");
let result = tokio::task::spawn_blocking({
let gid = gid.clone();
move || {
list_consumer_group_offsets_batched(
&client_clone.admin_handle(),
&[gid.as_str()],
offset_timeout,
PER_CALL_CHUNK,
)
}
})
.await;
(gid, result)
}));
}
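// Merge per-group results, logging individual failures instead of aborting.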
let results = futures::future::join_all(handles).await;
let mut merged: HashMap<String, HashMap<TopicPartition, i64>> = HashMap::new();
for r in results {
match r {
Ok((_gid, Ok(Ok(map)))) => merged.extend(map),
Ok((gid, Ok(Err(e)))) => {
warn!(group = %gid, error = %e, "Group-offset call failed")
}
Ok((gid, Err(e))) => {
warn!(group = %gid, error = %e, "Group-offset call task panicked")
}
Err(e) => warn!(error = %e, "Group-offset join error"),
}
}
merged
}
}
impl OffsetsSnapshot {
#[allow(dead_code)]
pub fn filtered_groups(&self) -> Vec<&str> {
self.groups.iter().map(|g| g.group_id.as_str()).collect()
}
pub fn get_watermark(&self, tp: &TopicPartition) -> Option<(i64, i64)> {
self.watermarks.get(tp).copied()
}
#[allow(dead_code)]
pub fn get_high_watermark(&self, tp: &TopicPartition) -> Option<i64> {
self.watermarks.get(tp).map(|(_, high)| *high)
}
}
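/// Milliseconds since the Unix epoch.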
fn unix_timestamp_ms() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_offsets_snapshot_filtered_groups() {
let snapshot = OffsetsSnapshot {
cluster_name: "test".to_string(),
groups: vec![
GroupSnapshot {
group_id: "group1".to_string(),
state: "Stable".to_string(),
members: vec![],
offsets: HashMap::new(),
},
GroupSnapshot {
group_id: "group2".to_string(),
state: "Stable".to_string(),
members: vec![],
offsets: HashMap::new(),
},
],
watermarks: HashMap::new(),
compacted_topics: HashSet::new(),
timestamp_ms: 0,
};
let groups = snapshot.filtered_groups();
assert_eq!(groups.len(), 2);
assert!(groups.contains(&"group1"));
assert!(groups.contains(&"group2"));
}
#[test]
fn compacted_cache_empty_cache_requests_all() {
let cache = CompactedTopicsCache::new(Duration::from_secs(60));
let topics = vec!["a".to_string(), "b".to_string(), "c".to_string()];
let (cached, to_fetch) = cache.partition(&topics);
assert!(cached.is_empty());
assert_eq!(to_fetch, vec!["a", "b", "c"]);
}
#[test]
fn compacted_cache_hit_skips_fetch() {
let cache = CompactedTopicsCache::new(Duration::from_secs(60));
let topics = vec!["a".to_string(), "b".to_string()];
let mut compacted = HashSet::new();
compacted.insert("a".to_string());
let (_cached, to_fetch) = cache.partition(&topics);
assert_eq!(to_fetch.len(), 2);
cache.update(&to_fetch, &compacted);
let (cached, to_fetch) = cache.partition(&topics);
assert_eq!(to_fetch.len(), 0);
assert_eq!(cached.len(), 1);
assert!(cached.contains("a"));
}
#[test]
fn compacted_cache_expired_entries_re_fetched() {
let cache = CompactedTopicsCache::new(Duration::from_millis(50));
let topics = vec!["a".to_string()];
let mut compacted = HashSet::new();
compacted.insert("a".to_string());
let (_cached, to_fetch) = cache.partition(&topics);
cache.update(&to_fetch, &compacted);
let (cached, to_fetch) = cache.partition(&topics);
assert!(cached.contains("a"));
assert!(to_fetch.is_empty());
std::thread::sleep(Duration::from_millis(70));
let (cached, to_fetch) = cache.partition(&topics);
assert!(cached.is_empty());
assert_eq!(to_fetch, vec!["a"]);
}
#[test]
fn metadata_cache_hit_returns_cached() {
let cache = MetadataCache::new(Duration::from_secs(60));
assert!(cache.get().is_none(), "empty cache should miss");
let (parts, topics) = cache.set(vec![TopicPartition::new("t", 0)], vec!["t".to_string()]);
assert_eq!(parts.len(), 1);
assert_eq!(topics.len(), 1);
let cached = cache.get().expect("cache should be populated");
assert_eq!(cached.0.len(), 1);
assert_eq!(cached.1.len(), 1);
}
#[test]
fn metadata_cache_disabled_by_zero_ttl() {
let cache = MetadataCache::new(Duration::ZERO);
let (_parts, _topics) = cache.set(vec![TopicPartition::new("t", 0)], vec!["t".into()]);
assert!(
cache.get().is_none(),
"zero-TTL cache must never return a hit"
);
}
#[test]
fn metadata_cache_expires() {
let cache = MetadataCache::new(Duration::from_millis(50));
cache.set(vec![], vec![]);
assert!(cache.get().is_some());
std::thread::sleep(Duration::from_millis(70));
assert!(cache.get().is_none(), "expired entry must miss");
}
#[test]
fn consumer_groups_cache_hit_returns_cached() {
let cache = ConsumerGroupsCache::new(Duration::from_secs(60));
assert!(cache.get().is_none(), "empty cache should miss");
let arc = cache.set(vec![ConsumerGroupInfo {
group_id: "g".into(),
protocol_type: String::new(),
state: String::new(),
}]);
assert_eq!(arc.len(), 1);
let cached = cache.get().expect("should hit after set");
assert_eq!(cached.len(), 1);
}
#[test]
fn consumer_groups_cache_zero_ttl_disabled() {
let cache = ConsumerGroupsCache::new(Duration::ZERO);
let _arc = cache.set(vec![ConsumerGroupInfo {
group_id: "g".into(),
protocol_type: String::new(),
state: String::new(),
}]);
assert!(cache.get().is_none());
}
#[test]
fn consumer_groups_cache_expires() {
let cache = ConsumerGroupsCache::new(Duration::from_millis(50));
cache.set(vec![]);
assert!(cache.get().is_some());
std::thread::sleep(Duration::from_millis(70));
assert!(cache.get().is_none());
}
#[test]
fn compacted_cache_prune_removes_unseen_topics() {
let cache = CompactedTopicsCache::new(Duration::from_secs(60));
let initial = vec!["a".to_string(), "b".to_string()];
let mut compacted = HashSet::new();
compacted.insert("a".to_string());
let (_cached, to_fetch) = cache.partition(&initial);
cache.update(&to_fetch, &compacted);
let remaining = vec!["a".to_string()];
cache.prune_to(&remaining);
let entries = cache.entries.lock().unwrap();
assert!(entries.contains_key("a"));
assert!(!entries.contains_key("b"));
}
}