// rivven_client/consumer.rs
1//! High-level consumer API for Rivven (§10.1 fix)
2//!
3//! Provides a [`Consumer`] struct that wraps the low-level [`Client`] with:
4//! - Topic subscription with automatic partition assignment
5//! - Offset tracking per (topic, partition) pair
6//! - Auto-commit of consumed offsets to the server
7//! - Long-polling to avoid tight fetch loops
8//! - Configurable batch sizes and poll intervals
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! use rivven_client::consumer::{Consumer, ConsumerConfig};
14//!
15//! let config = ConsumerConfig::builder()
16//!     .bootstrap_server("127.0.0.1:9092")
17//!     .group_id("my-group")
18//!     .topics(vec!["events".to_string()])
19//!     .build();
20//!
21//! let mut consumer = Consumer::new(config).await?;
22//!
23//! loop {
24//!     let records = consumer.poll().await?;
25//!     for record in &records {
//!         println!("topic={} partition={} offset={}: {:?}",
//!             record.topic, record.partition, record.offset, record.data);
28//!     }
29//!     // Offsets are auto-committed periodically, or call:
30//!     consumer.commit().await?;
31//! }
32//! ```
33//!
34//! # Consumer Group Protocol
35//!
36//! When no explicit partition assignments are configured, the consumer
37//! uses server-side group coordination:
38//!
39//! 1. **JoinGroup** — register with the coordinator, receive generation ID
40//! 2. **SyncGroup** — leader computes assignments, all members receive theirs
41//! 3. **Heartbeat** — periodic keep-alive during `poll()`
42//! 4. **LeaveGroup** — graceful departure on `close()`
43//!
44//! For explicit partition assignment (static model), set
45//! [`ConsumerConfig::partitions`] to bypass the coordination protocol.
46
47use crate::client::Client;
48use crate::error::{Error, Result};
49use rivven_protocol::MessageData;
50use std::collections::HashMap;
51use std::sync::atomic::{AtomicBool, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tracing::{debug, info, warn};
55
/// A topic-partition pair used in rebalance callbacks.
///
/// `topic` is an `Arc<str>` so the pair can be cloned cheaply when the
/// same topic name is shared across many partitions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TopicPartition {
    /// Topic name (shared, reference-counted string).
    pub topic: Arc<str>,
    /// Partition number within the topic.
    pub partition: u32,
}
62
/// Callback interface for consumer group rebalance events (CLIENT-08).
///
/// Implement this trait to receive notifications when partitions are
/// revoked or assigned during a rebalance. This is critical for:
/// - Committing offsets before partitions are revoked (exactly-once)
/// - Initializing state when new partitions are assigned
/// - Cleaning up resources when partitions are lost
///
/// Both hooks are awaited inline during assignment discovery, so keep
/// them fast — long-running work here delays the rebalance.
///
/// # Example
///
/// ```rust,ignore
/// struct MyListener;
///
/// #[async_trait::async_trait]
/// impl RebalanceListener for MyListener {
///     async fn on_partitions_revoked(&self, partitions: &[TopicPartition]) {
///         // Commit offsets for revoked partitions
///     }
///     async fn on_partitions_assigned(&self, partitions: &[TopicPartition]) {
///         // Initialize state for new partitions
///     }
/// }
/// ```
#[async_trait::async_trait]
pub trait RebalanceListener: Send + Sync {
    /// Called before partitions are revoked from this consumer.
    ///
    /// Use this to commit offsets or flush state for the given partitions.
    /// This is invoked synchronously — the rebalance blocks until this returns.
    async fn on_partitions_revoked(&self, partitions: &[TopicPartition]);

    /// Called after new partitions are assigned to this consumer.
    ///
    /// Use this to initialize partition-specific state or seek to custom offsets.
    async fn on_partitions_assigned(&self, partitions: &[TopicPartition]);
}
99
/// Configuration for the high-level consumer.
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Bootstrap server addresses (host:port).
    ///
    /// On initial connect and reconnect, the consumer tries each server
    /// in round-robin order until one succeeds. Accepts a single server
    /// or multiple for failover.
    pub bootstrap_servers: Vec<String>,
    /// Consumer group ID for offset management.
    pub group_id: String,
    /// Topics to subscribe to.
    pub topics: Vec<String>,
    /// Explicit partition assignments (topic → partitions).
    /// If empty, all partitions of each subscribed topic are consumed
    /// and server-side group coordination is used.
    pub partitions: HashMap<String, Vec<u32>>,
    /// Maximum messages per partition per poll.
    pub max_poll_records: u32,
    /// Long-poll wait time in milliseconds (0 = immediate return).
    pub max_poll_interval_ms: u64,
    /// Auto-commit interval (`None` = manual commit only).
    pub auto_commit_interval: Option<Duration>,
    /// Transaction isolation level (0 = read_uncommitted, 1 = read_committed).
    pub isolation_level: u8,
    /// Authentication credentials (optional).
    pub auth: Option<ConsumerAuthConfig>,
    /// Interval for re-discovering partition assignments (default: 5 min).
    /// Set to `Duration::MAX` to disable periodic re-discovery.
    pub metadata_refresh_interval: Duration,
    /// Initial reconnect backoff delay in milliseconds (default: 100).
    pub reconnect_backoff_ms: u64,
    /// Maximum reconnect backoff delay in milliseconds (default: 10 000).
    pub reconnect_backoff_max_ms: u64,
    /// Maximum number of reconnect attempts before giving up (default: 10).
    pub max_reconnect_attempts: u32,
    /// Session timeout for group coordination in milliseconds (default: 10 000).
    /// If the coordinator does not receive a heartbeat within this interval,
    /// it considers the member dead and triggers a rebalance.
    pub session_timeout_ms: u32,
    /// Rebalance timeout in milliseconds (default: 30 000).
    /// Maximum time the coordinator waits for all members to join during a rebalance.
    pub rebalance_timeout_ms: u32,
    /// Heartbeat interval in milliseconds (default: 3 000).
    /// Should be no more than 1/3 of `session_timeout_ms`.
    pub heartbeat_interval_ms: u64,
    /// TLS configuration (optional). When set, the consumer connects
    /// over TLS instead of plaintext.
    #[cfg(feature = "tls")]
    pub tls_config: Option<rivven_core::tls::TlsConfig>,
    /// TLS server name for certificate verification.
    /// Required when `tls_config` is `Some`.
    #[cfg(feature = "tls")]
    pub tls_server_name: Option<String>,
}
154
/// Authentication configuration for the consumer.
///
/// Holds SCRAM credentials. `Debug` is implemented by hand so the
/// password is never written to logs.
#[derive(Clone)]
pub struct ConsumerAuthConfig {
    /// SCRAM username.
    pub username: String,
    /// SCRAM password — redacted from `Debug` output.
    pub password: String,
}

impl std::fmt::Debug for ConsumerAuthConfig {
    /// Renders the username normally but masks the password so
    /// credentials cannot leak through debug/trace logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ConsumerAuthConfig");
        out.field("username", &self.username);
        out.field("password", &"[REDACTED]");
        out.finish()
    }
}
170
/// Builder for [`ConsumerConfig`].
///
/// Fields mirror [`ConsumerConfig`]; see that struct for per-field docs.
/// Obtain one via [`ConsumerConfig::builder`] or [`ConsumerConfigBuilder::new`].
pub struct ConsumerConfigBuilder {
    bootstrap_servers: Vec<String>,
    // `None` until set; `build()` falls back to "default-group".
    group_id: Option<String>,
    topics: Vec<String>,
    partitions: HashMap<String, Vec<u32>>,
    max_poll_records: u32,
    max_poll_interval_ms: u64,
    auto_commit_interval: Option<Duration>,
    isolation_level: u8,
    auth: Option<ConsumerAuthConfig>,
    metadata_refresh_interval: Duration,
    reconnect_backoff_ms: u64,
    reconnect_backoff_max_ms: u64,
    max_reconnect_attempts: u32,
    session_timeout_ms: u32,
    rebalance_timeout_ms: u32,
    heartbeat_interval_ms: u64,
    #[cfg(feature = "tls")]
    tls_config: Option<rivven_core::tls::TlsConfig>,
    #[cfg(feature = "tls")]
    tls_server_name: Option<String>,
}
194
195impl ConsumerConfigBuilder {
196    pub fn new() -> Self {
197        Self {
198            bootstrap_servers: vec!["127.0.0.1:9092".to_string()],
199            group_id: None,
200            topics: Vec::new(),
201            partitions: HashMap::new(),
202            max_poll_records: 500,
203            max_poll_interval_ms: 5000,
204            auto_commit_interval: Some(Duration::from_secs(5)),
205            isolation_level: 0,
206            auth: None,
207            metadata_refresh_interval: Duration::from_secs(300),
208            reconnect_backoff_ms: 100,
209            reconnect_backoff_max_ms: 10_000,
210            max_reconnect_attempts: 10,
211            session_timeout_ms: 10_000,
212            rebalance_timeout_ms: 30_000,
213            heartbeat_interval_ms: 3_000,
214            #[cfg(feature = "tls")]
215            tls_config: None,
216            #[cfg(feature = "tls")]
217            tls_server_name: None,
218        }
219    }
220
221    /// Set a single bootstrap server (convenience for `bootstrap_servers`).
222    pub fn bootstrap_server(mut self, server: impl Into<String>) -> Self {
223        self.bootstrap_servers = vec![server.into()];
224        self
225    }
226
227    /// Set multiple bootstrap servers for failover.
228    pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
229        self.bootstrap_servers = servers;
230        self
231    }
232
233    pub fn group_id(mut self, group: impl Into<String>) -> Self {
234        self.group_id = Some(group.into());
235        self
236    }
237
238    pub fn topics(mut self, topics: Vec<String>) -> Self {
239        self.topics = topics;
240        self
241    }
242
243    pub fn topic(mut self, topic: impl Into<String>) -> Self {
244        self.topics.push(topic.into());
245        self
246    }
247
248    /// Assign specific partitions for a topic (static assignment).
249    pub fn assign(mut self, topic: impl Into<String>, partitions: Vec<u32>) -> Self {
250        self.partitions.insert(topic.into(), partitions);
251        self
252    }
253
254    pub fn max_poll_records(mut self, n: u32) -> Self {
255        self.max_poll_records = n;
256        self
257    }
258
259    pub fn max_poll_interval_ms(mut self, ms: u64) -> Self {
260        self.max_poll_interval_ms = ms;
261        self
262    }
263
264    pub fn auto_commit_interval(mut self, interval: Option<Duration>) -> Self {
265        self.auto_commit_interval = interval;
266        self
267    }
268
269    pub fn enable_auto_commit(mut self, enabled: bool) -> Self {
270        if enabled {
271            self.auto_commit_interval = Some(Duration::from_secs(5));
272        } else {
273            self.auto_commit_interval = None;
274        }
275        self
276    }
277
278    pub fn isolation_level(mut self, level: u8) -> Self {
279        self.isolation_level = level;
280        self
281    }
282
283    /// Use read_committed isolation (only see committed transactional messages).
284    pub fn read_committed(mut self) -> Self {
285        self.isolation_level = 1;
286        self
287    }
288
289    pub fn auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
290        self.auth = Some(ConsumerAuthConfig {
291            username: username.into(),
292            password: password.into(),
293        });
294        self
295    }
296
297    /// Set the interval for periodic partition re-discovery.
298    pub fn metadata_refresh_interval(mut self, interval: Duration) -> Self {
299        self.metadata_refresh_interval = interval;
300        self
301    }
302
303    /// Set initial reconnect backoff delay in milliseconds.
304    pub fn reconnect_backoff_ms(mut self, ms: u64) -> Self {
305        self.reconnect_backoff_ms = ms;
306        self
307    }
308
309    /// Set maximum reconnect backoff delay in milliseconds.
310    pub fn reconnect_backoff_max_ms(mut self, ms: u64) -> Self {
311        self.reconnect_backoff_max_ms = ms;
312        self
313    }
314
315    /// Set maximum number of reconnect attempts (default: 10).
316    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
317        self.max_reconnect_attempts = attempts;
318        self
319    }
320
321    /// Set session timeout for group coordination (default: 10 000 ms).
322    pub fn session_timeout_ms(mut self, ms: u32) -> Self {
323        self.session_timeout_ms = ms;
324        self
325    }
326
327    /// Set rebalance timeout (default: 30 000 ms).
328    pub fn rebalance_timeout_ms(mut self, ms: u32) -> Self {
329        self.rebalance_timeout_ms = ms;
330        self
331    }
332
333    /// Set heartbeat interval (default: 3 000 ms, should be ≤ 1/3 of session timeout).
334    pub fn heartbeat_interval_ms(mut self, ms: u64) -> Self {
335        self.heartbeat_interval_ms = ms;
336        self
337    }
338
339    /// Set TLS configuration for encrypted connections (CLIENT-06).
340    #[cfg(feature = "tls")]
341    pub fn tls(
342        mut self,
343        tls_config: rivven_core::tls::TlsConfig,
344        server_name: impl Into<String>,
345    ) -> Self {
346        self.tls_config = Some(tls_config);
347        self.tls_server_name = Some(server_name.into());
348        self
349    }
350
351    pub fn build(self) -> ConsumerConfig {
352        // Enforce heartbeat ≤ 1/3 of session timeout (Kafka best practice).
353        // If the user configured an unnecessarily long heartbeat, clamp it
354        // automatically instead of silently allowing session expiry.
355        let max_heartbeat = (self.session_timeout_ms as u64) / 3;
356        let heartbeat_interval_ms = if self.heartbeat_interval_ms > max_heartbeat {
357            tracing::warn!(
358                configured = self.heartbeat_interval_ms,
359                clamped_to = max_heartbeat,
360                session_timeout_ms = self.session_timeout_ms,
361                "heartbeat_interval_ms exceeds 1/3 of session_timeout_ms, clamping"
362            );
363            max_heartbeat
364        } else {
365            self.heartbeat_interval_ms
366        };
367
368        ConsumerConfig {
369            bootstrap_servers: self.bootstrap_servers,
370            group_id: self.group_id.unwrap_or_else(|| "default-group".into()),
371            topics: self.topics,
372            partitions: self.partitions,
373            max_poll_records: self.max_poll_records,
374            max_poll_interval_ms: self.max_poll_interval_ms,
375            auto_commit_interval: self.auto_commit_interval,
376            isolation_level: self.isolation_level,
377            auth: self.auth,
378            metadata_refresh_interval: self.metadata_refresh_interval,
379            reconnect_backoff_ms: self.reconnect_backoff_ms,
380            reconnect_backoff_max_ms: self.reconnect_backoff_max_ms,
381            max_reconnect_attempts: self.max_reconnect_attempts,
382            session_timeout_ms: self.session_timeout_ms,
383            rebalance_timeout_ms: self.rebalance_timeout_ms,
384            heartbeat_interval_ms,
385            #[cfg(feature = "tls")]
386            tls_config: self.tls_config,
387            #[cfg(feature = "tls")]
388            tls_server_name: self.tls_server_name,
389        }
390    }
391}
392
393impl Default for ConsumerConfigBuilder {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
399impl ConsumerConfig {
400    pub fn builder() -> ConsumerConfigBuilder {
401        ConsumerConfigBuilder::new()
402    }
403}
404
/// A consumed record with topic/partition metadata.
#[derive(Debug, Clone)]
pub struct ConsumerRecord {
    /// Topic the record was consumed from (cheap `Arc` clone per record).
    pub topic: Arc<str>,
    /// Partition number.
    pub partition: u32,
    /// Record offset within the partition.
    pub offset: u64,
    /// Record data (payload decoded by the protocol layer).
    pub data: MessageData,
}
417
/// High-level consumer that manages offset tracking and auto-commit.
///
/// Wraps one or more partitions across subscribed topics, polling them
/// round-robin and tracking the latest consumed offset per partition.
pub struct Consumer {
    /// Underlying protocol client used by the poll/commit path.
    client: Client,
    /// Configuration captured at construction.
    config: ConsumerConfig,
    /// Current offset per (topic, partition) — next offset to fetch
    offsets: HashMap<(Arc<str>, u32), u64>,
    /// Resolved partition assignments: topic → Vec<partition_id>
    assignments: HashMap<String, Vec<u32>>,
    /// Flattened assignment list cached to avoid cloning on every poll
    assignment_list: Vec<(Arc<str>, u32)>,
    /// Last auto-commit time
    last_commit: Instant,
    /// Last partition discovery time
    last_discovery: Instant,
    /// Whether initial assignment discovery has been done
    initialized: bool,
    /// Member ID assigned by the group coordinator (empty for static assignment)
    member_id: String,
    /// Current generation ID from the coordinator
    generation_id: u32,
    /// Whether this consumer is the group leader (computes assignments in SyncGroup)
    is_leader: bool,
    /// Last heartbeat time (for periodic heartbeats during poll)
    last_heartbeat: Instant,
    /// Whether this consumer uses server-side group coordination
    /// (true when no explicit partition assignments were configured)
    uses_coordination: bool,
    /// Set when a fetch response or background heartbeat signals rebalance,
    /// triggering a rejoin on the next poll (CLIENT-07 / CLIENT-02).
    /// Shared with the background heartbeat task via `Arc`.
    needs_rejoin: Arc<AtomicBool>,
    /// Optional rebalance listener for partition revocation/assignment callbacks (CLIENT-08).
    rebalance_listener: Option<Arc<dyn RebalanceListener>>,
    /// Background heartbeat task handle (CLIENT-02).
    /// Aborted on close, reconnect, or rebalance.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
}
457
458impl Consumer {
459    /// Create and connect a new consumer.
460    ///
461    /// Connects to the first available bootstrap server, authenticates if
462    /// configured, and discovers partition assignments for subscribed topics.
463    ///
464    /// ## Auto-commit semantics
465    ///
466    /// When `auto_commit_interval` is set, offsets are committed periodically
467    /// at the **next-fetch** position. This provides **at-most-once** semantics:
468    /// if the application crashes between `poll()` returning and the records
469    /// being processed, those records will be skipped on restart.
470    ///
471    /// For **at-least-once** semantics, disable auto-commit and call
472    /// `commit()` explicitly after processing each batch.
473    pub async fn new(config: ConsumerConfig) -> Result<Self> {
474        let servers = &config.bootstrap_servers;
475        if servers.is_empty() {
476            return Err(Error::ConnectionError(
477                "No bootstrap servers configured".to_string(),
478            ));
479        }
480
481        let mut last_error = None;
482        let mut client = None;
483        for server in servers {
484            // Connect with TLS when configured (CLIENT-06), otherwise plaintext.
485            #[cfg(feature = "tls")]
486            let connect_result = if let (Some(ref tls_cfg), Some(ref sni)) =
487                (&config.tls_config, &config.tls_server_name)
488            {
489                Client::connect_tls(server, tls_cfg, sni).await
490            } else {
491                Client::connect(server).await
492            };
493            #[cfg(not(feature = "tls"))]
494            let connect_result = Client::connect(server).await;
495
496            match connect_result {
497                Ok(c) => {
498                    client = Some(c);
499                    break;
500                }
501                Err(e) => {
502                    warn!(server = %server, error = %e, "Failed to connect to bootstrap server");
503                    last_error = Some(e);
504                }
505            }
506        }
507        let mut client = client.ok_or_else(|| {
508            last_error.unwrap_or_else(|| {
509                Error::ConnectionError("No bootstrap servers available".to_string())
510            })
511        })?;
512
513        // Authenticate if credentials are provided
514        if let Some(ref auth) = config.auth {
515            client
516                .authenticate_scram(&auth.username, &auth.password)
517                .await?;
518        }
519
520        let uses_coordination = config.partitions.is_empty();
521
522        let mut consumer = Self {
523            client,
524            config,
525            offsets: HashMap::new(),
526            assignments: HashMap::new(),
527            assignment_list: Vec::new(),
528            last_commit: Instant::now(),
529            last_discovery: Instant::now(),
530            initialized: false,
531            member_id: String::new(),
532            generation_id: 0,
533            is_leader: false,
534            last_heartbeat: Instant::now(),
535            uses_coordination,
536            needs_rejoin: Arc::new(AtomicBool::new(false)),
537            rebalance_listener: None,
538            heartbeat_handle: None,
539        };
540
541        consumer.discover_assignments().await?;
542
543        info!(
544            group_id = %consumer.config.group_id,
545            topics = ?consumer.config.topics,
546            partitions = ?consumer.assignments,
547            "Consumer initialized"
548        );
549
550        Ok(consumer)
551    }
552
553    /// Register a rebalance listener for partition revocation/assignment events.
554    ///
555    /// The listener is invoked during `discover_assignments()`:
556    /// - `on_partitions_revoked` is called with the old assignment before reassignment
557    /// - `on_partitions_assigned` is called with the new assignment after reassignment
558    pub fn set_rebalance_listener(&mut self, listener: Arc<dyn RebalanceListener>) {
559        self.rebalance_listener = Some(listener);
560    }
561
    /// Spawn (or restart) the background heartbeat task (CLIENT-02).
    ///
    /// The task opens its own TCP connection to the first available
    /// bootstrap server, authenticates if needed, then sends periodic
    /// heartbeats independently of the poll loop. This matches Kafka's
    /// dedicated `HeartbeatThread` design.
    ///
    /// If the heartbeat detects `REBALANCE_IN_PROGRESS` or a connection
    /// error, it sets `needs_rejoin` so the next `poll()` triggers a
    /// group rejoin.
    ///
    /// NOTE(review): the heartbeat connection always uses plaintext
    /// `Client::connect`, even when the consumer was configured with TLS —
    /// confirm whether it should mirror the TLS connect path used in `new()`.
    async fn spawn_heartbeat_task(&mut self) {
        // Abort any existing heartbeat task from a previous generation.
        if let Some(handle) = self.heartbeat_handle.take() {
            handle.abort();
        }

        // Nothing to heartbeat without a member ID.
        if self.member_id.is_empty() || !self.uses_coordination {
            return;
        }

        // Snapshot everything the task needs. `generation_id` is captured
        // at spawn time; a task for a stale generation is aborted and
        // respawned when the consumer rejoins.
        let group_id = self.config.group_id.clone();
        let member_id = self.member_id.clone();
        let generation_id = self.generation_id;
        let interval = Duration::from_millis(self.config.heartbeat_interval_ms);
        let needs_rejoin = self.needs_rejoin.clone();
        let servers = self.config.bootstrap_servers.clone();
        let auth = self.config.auth.clone();

        self.heartbeat_handle = Some(tokio::spawn(async move {
            // Establish a dedicated connection for heartbeats.
            let mut hb_client = None;
            for server in &servers {
                match Client::connect(server).await {
                    Ok(mut c) => {
                        if let Some(ref auth) = auth {
                            if let Err(e) =
                                c.authenticate_scram(&auth.username, &auth.password).await
                            {
                                warn!(
                                    server = %server,
                                    error = %e,
                                    "Heartbeat connection auth failed, trying next server"
                                );
                                continue;
                            }
                        }
                        hb_client = Some(c);
                        break;
                    }
                    Err(e) => {
                        warn!(
                            server = %server,
                            error = %e,
                            "Heartbeat connection failed, trying next server"
                        );
                    }
                }
            }

            // With no connection at all, force a rejoin from the poll loop.
            let Some(mut client) = hb_client else {
                warn!("Could not establish heartbeat connection to any server, signaling rejoin");
                needs_rejoin.store(true, Ordering::Release);
                return;
            };

            let mut ticker = tokio::time::interval(interval);
            ticker.tick().await; // first tick is immediate, skip it

            loop {
                ticker.tick().await;
                match client.heartbeat(&group_id, generation_id, &member_id).await {
                    Ok(27) => {
                        // Error code 27 = REBALANCE_IN_PROGRESS
                        info!(
                            group_id = %group_id,
                            "Background heartbeat: rebalance in progress, signaling rejoin"
                        );
                        needs_rejoin.store(true, Ordering::Release);
                    }
                    Ok(_) => {
                        // OK — heartbeat accepted
                    }
                    Err(e) => {
                        warn!(
                            group_id = %group_id,
                            error = %e,
                            "Background heartbeat failed, signaling rejoin"
                        );
                        needs_rejoin.store(true, Ordering::Release);
                        // Keep running — may be a transient error.
                        // If the generation is stale, the next poll() will
                        // rejoin and spawn a fresh heartbeat task.
                    }
                }
            }
        }));
    }
660
    /// Discover partition assignments for subscribed topics.
    ///
    /// When server-side group coordination is active (`uses_coordination`),
    /// performs the JoinGroup/SyncGroup protocol with the coordinator.
    /// Otherwise, queries metadata for each topic and assigns all partitions
    /// (or uses explicit assignments from config).
    ///
    /// Sequence (order matters):
    /// 1. `on_partitions_revoked` for the previous assignment (CLIENT-08)
    /// 2. Re-discover assignments (coordination or metadata)
    /// 3. Drop offsets for partitions no longer owned
    /// 4. Load committed offsets for newly owned partitions
    /// 5. Rebuild the cached flattened assignment list
    /// 6. `on_partitions_assigned` for the new assignment
    async fn discover_assignments(&mut self) -> Result<()> {
        // Capture old assignments for rebalance callbacks (CLIENT-08).
        let old_tps: Vec<TopicPartition> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| TopicPartition {
                    topic: arc.clone(),
                    partition: p,
                })
            })
            .collect();

        // Invoke on_partitions_revoked before changing assignments
        if !old_tps.is_empty() {
            if let Some(ref listener) = self.rebalance_listener {
                listener.on_partitions_revoked(&old_tps).await;
            }
        }

        if self.uses_coordination {
            self.discover_via_coordination().await?;
            // Spawn a dedicated background heartbeat task so heartbeats
            // are decoupled from the poll loop (CLIENT-02).
            self.spawn_heartbeat_task().await;
        } else {
            self.discover_via_metadata().await?;
        }

        // Clean up stale offsets for partitions we no longer own
        // (prevents unbounded memory growth across rebalances).
        let owned_keys: std::collections::HashSet<(Arc<str>, u32)> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();
        self.offsets.retain(|k, _| owned_keys.contains(k));

        // Initialize offsets from server (committed offsets).
        // Partitions that survived the rebalance keep their in-memory offset.
        for (topic, partitions) in &self.assignments {
            for &partition in partitions {
                let key: (Arc<str>, u32) = (Arc::from(topic.as_str()), partition);
                if self.offsets.contains_key(&key) {
                    continue;
                }
                // Try to load committed offset from server
                match self
                    .client
                    .get_offset(&self.config.group_id, topic, partition)
                    .await
                {
                    Ok(Some(offset)) => {
                        debug!(
                            topic = %topic,
                            partition,
                            offset,
                            "Resumed from committed offset"
                        );
                        self.offsets.insert(key, offset);
                    }
                    Ok(None) => {
                        // No committed offset — start from 0
                        self.offsets.insert(key, 0);
                    }
                    Err(e) => {
                        // Best-effort: a failed lookup falls back to 0 rather
                        // than failing the whole rebalance.
                        debug!(
                            topic = %topic,
                            partition,
                            error = %e,
                            "Failed to load committed offset, starting from 0"
                        );
                        self.offsets.insert(key, 0);
                    }
                }
            }
        }

        self.initialized = true;

        // Rebuild cached flattened assignment list (one Arc<str> per topic,
        // cloned cheaply for each partition).
        self.assignment_list = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();

        // Invoke on_partitions_assigned with the new assignment (CLIENT-08)
        if let Some(ref listener) = self.rebalance_listener {
            let new_tps: Vec<TopicPartition> = self
                .assignment_list
                .iter()
                .map(|(t, p)| TopicPartition {
                    topic: t.clone(),
                    partition: *p,
                })
                .collect();
            if !new_tps.is_empty() {
                listener.on_partitions_assigned(&new_tps).await;
            }
        }

        Ok(())
    }
778
    /// Discover assignments via server-side consumer group coordination.
    ///
    /// 1. **JoinGroup** — register with the coordinator; receive member ID,
    ///    generation ID, and the list of group members (if leader).
    /// 2. **SyncGroup** — the leader computes a range-based assignment for
    ///    all members and submits it; every member receives its own
    ///    partition list.
    async fn discover_via_coordination(&mut self) -> Result<()> {
        // Step 1: JoinGroup. `self.member_id` is empty on the first join;
        // the coordinator assigns one, which we reuse on later rejoins.
        let (generation_id, _protocol_type, member_id, leader_id, members) = self
            .client
            .join_group(
                &self.config.group_id,
                &self.member_id,
                self.config.session_timeout_ms,
                self.config.rebalance_timeout_ms,
                "consumer",
                self.config.topics.clone(),
            )
            .await?;

        self.member_id = member_id.clone();
        self.generation_id = generation_id;
        self.is_leader = member_id == leader_id;

        info!(
            group_id = %self.config.group_id,
            member_id = %self.member_id,
            generation_id,
            is_leader = self.is_leader,
            member_count = members.len(),
            "Joined consumer group"
        );

        // Step 2: SyncGroup
        // Leader computes range-based assignments for all members;
        // followers submit an empty plan and just receive theirs.
        let group_assignments = if self.is_leader {
            self.compute_range_assignments(&members).await?
        } else {
            Vec::new()
        };

        let my_assignments = self
            .client
            .sync_group(
                &self.config.group_id,
                generation_id,
                &self.member_id,
                group_assignments,
            )
            .await?;

        // Apply returned assignments
        self.assignments.clear();
        for (topic, partitions) in my_assignments {
            debug!(
                topic = %topic,
                partitions = ?partitions,
                "Received partition assignment"
            );
            self.assignments.insert(topic, partitions);
        }

        self.last_heartbeat = Instant::now();

        Ok(())
    }
846
847    /// Compute range-based partition assignments for all group members.
848    ///
849    /// For each topic, fetches the partition count from the server, then
850    /// distributes partitions evenly across the members that subscribe
851    /// to that topic.
852    async fn compute_range_assignments(
853        &mut self,
854        members: &[(String, Vec<String>)],
855    ) -> Result<Vec<(String, Vec<(String, Vec<u32>)>)>> {
856        // Collect all unique topics across all members
857        let mut all_topics: Vec<String> = members
858            .iter()
859            .flat_map(|(_, subs)| subs.iter().cloned())
860            .collect();
861        all_topics.sort();
862        all_topics.dedup();
863
864        // member_id → Vec<(topic, Vec<partition>)>
865        let mut result_map: HashMap<String, Vec<(String, Vec<u32>)>> = members
866            .iter()
867            .map(|(mid, _)| (mid.clone(), Vec::new()))
868            .collect();
869
870        for topic in &all_topics {
871            // Find members subscribed to this topic
872            let mut subscribed: Vec<&str> = members
873                .iter()
874                .filter(|(_, subs)| subs.iter().any(|s| s == topic))
875                .map(|(mid, _)| mid.as_str())
876                .collect();
877            subscribed.sort(); // deterministic ordering
878
879            let partition_count = match self.client.get_metadata(topic.as_str()).await {
880                Ok((_name, count)) => count,
881                Err(e) => {
882                    warn!(topic = %topic, error = %e, "Failed to get metadata for assignment");
883                    continue;
884                }
885            };
886
887            if subscribed.is_empty() || partition_count == 0 {
888                continue;
889            }
890
891            // Range assignment: distribute partitions as evenly as possible
892            let n_members = subscribed.len() as u32;
893            let per_member = partition_count / n_members;
894            let remainder = partition_count % n_members;
895
896            let mut offset = 0u32;
897            for (i, mid) in subscribed.iter().enumerate() {
898                let extra = if (i as u32) < remainder { 1 } else { 0 };
899                let count = per_member + extra;
900                let partitions: Vec<u32> = (offset..offset + count).collect();
901                offset += count;
902
903                if let Some(entry) = result_map.get_mut(*mid) {
904                    entry.push((topic.clone(), partitions));
905                }
906            }
907        }
908
909        Ok(result_map.into_iter().collect())
910    }
911
912    /// Discover assignments via metadata queries (static model).
913    ///
914    /// For each topic, queries the server for the partition count and
915    /// assigns all partitions (or uses explicit assignments from config).
916    async fn discover_via_metadata(&mut self) -> Result<()> {
917        for topic in &self.config.topics {
918            if let Some(explicit) = self.config.partitions.get(topic) {
919                // Use explicit partition assignment
920                self.assignments.insert(topic.clone(), explicit.clone());
921            } else {
922                // Discover partitions from server
923                match self.client.get_metadata(topic.as_str()).await {
924                    Ok((_name, partition_count)) => {
925                        let partitions: Vec<u32> = (0..partition_count).collect();
926                        self.assignments.insert(topic.clone(), partitions);
927                    }
928                    Err(e) => {
929                        warn!(
930                            topic = %topic,
931                            error = %e,
932                            "Failed to discover partitions, will retry on next poll"
933                        );
934                    }
935                }
936            }
937        }
938
939        Ok(())
940    }
941
942    /// Poll for new records across all assigned partitions.
943    ///
944    /// Automatically reconnects with exponential backoff on connection
945    /// errors and periodically re-discovers partition
946    /// assignments.
947    pub async fn poll(&mut self) -> Result<Vec<ConsumerRecord>> {
948        match self.poll_inner().await {
949            Ok(records) => Ok(records),
950            Err(e) if Self::is_connection_error(&e) => {
951                warn!(error = %e, "Connection error during poll, attempting reconnect");
952                self.reconnect().await?;
953                self.poll_inner().await
954            }
955            Err(e) => Err(e),
956        }
957    }
958
    /// Inner poll implementation without reconnection wrapper.
    ///
    /// Runs, in order:
    /// 1. Lazy initialization: discover assignments on the first call.
    /// 2. Group rejoin when the `needs_rejoin` flag was set (by the
    ///    heartbeat task or a previous fetch response).
    /// 3. Periodic assignment re-discovery driven by
    ///    `metadata_refresh_interval`.
    /// 4. Phase 1: pipelined, non-blocking fetch across all assigned
    ///    partitions.
    /// 5. Phase 2: one rotated long-poll when phase 1 returned nothing.
    /// 6. Auto-commit when the configured interval has elapsed.
    ///
    /// # Errors
    ///
    /// Propagates assignment-discovery and transport errors; per-partition
    /// fetch errors in phase 1 are logged and skipped instead.
    async fn poll_inner(&mut self) -> Result<Vec<ConsumerRecord>> {
        // First poll: join the group (or query metadata) lazily.
        if !self.initialized {
            self.discover_assignments().await?;
        }

        // If the background heartbeat or a fetch response signalled a
        // rebalance, rejoin the group before proceeding (CLIENT-02 / CLIENT-07).
        if self.needs_rejoin.load(Ordering::Acquire) && self.uses_coordination {
            info!(
                group_id = %self.config.group_id,
                "Rejoining group due to rebalance signal"
            );
            self.discover_assignments().await?;
            // Cleared only after a successful rejoin — the `?` above returns
            // first on failure, so the flag stays set for the next poll.
            self.needs_rejoin.store(false, Ordering::Release);
        }

        // Periodically re-discover partition assignments so
        // that newly added partitions are picked up automatically.
        if self.last_discovery.elapsed() >= self.config.metadata_refresh_interval {
            if let Err(e) = self.discover_assignments().await {
                warn!(error = %e, "Failed to re-discover assignments, continuing with existing");
            }
            // Reset even on failure so a flapping server isn't hammered on
            // every poll; retry happens after the next full interval.
            self.last_discovery = Instant::now();
        }

        let mut records = Vec::new();
        // isolation_level 0 is treated as "unset": pass None so the server
        // applies its default.
        let isolation_level = if self.config.isolation_level > 0 {
            Some(self.config.isolation_level)
        } else {
            None
        };

        // Phase 1: Pipelined (non-blocking) fetch from all partitions.
        // Sends all consume requests at once, then reads all responses —
        // eliminates per-partition round-trip latency.
        if !self.assignment_list.is_empty() {
            let fetches: Vec<(&str, u32, u64, u32, Option<u8>)> = self
                .assignment_list
                .iter()
                .map(|(topic, partition)| {
                    let key = (topic.clone(), *partition);
                    // Unknown partitions start from offset 0.
                    let offset = self.offsets.get(&key).copied().unwrap_or(0);
                    (
                        // &**topic: Arc<str> → &str for the borrowed request tuple.
                        &**topic,
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                    )
                })
                .collect();

            let results = self.client.consume_pipelined(&fetches).await?;

            // Responses arrive in request order, so index i maps back to
            // assignment_list[i].
            for (i, result) in results.into_iter().enumerate() {
                let (topic, partition) = &self.assignment_list[i];
                match result {
                    Ok(messages) if !messages.is_empty() => {
                        let key = (topic.clone(), *partition);
                        // Advance the fetch position to one past the highest
                        // offset consumed in this batch.
                        let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(0);
                        self.offsets.insert(key, max_offset + 1);

                        records.extend(messages.into_iter().map(|data| ConsumerRecord {
                            topic: topic.clone(),
                            partition: *partition,
                            offset: data.offset,
                            data,
                        }));
                    }
                    Err(e) => {
                        // Detect rebalance-related errors and set the needs_rejoin
                        // flag so the next poll triggers a group rejoin (CLIENT-07).
                        let err_str = e.to_string();
                        if err_str.contains("UNKNOWN_MEMBER_ID")
                            || err_str.contains("ILLEGAL_GENERATION")
                            || err_str.contains("REBALANCE_IN_PROGRESS")
                        {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Rebalance signal in fetch response, will rejoin group"
                            );
                            self.needs_rejoin.store(true, Ordering::Release);
                        } else {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Pipelined fetch error, skipping partition"
                            );
                        }
                    }
                    _ => {} // empty result — no data
                }
            }
        }

        // Phase 2: If nothing was returned and long-polling is enabled,
        // issue a single long-poll to avoid a busy loop. We rotate the
        // assignment list so each call long-polls a different partition,
        // preventing starvation where only the first partition is polled.
        if records.is_empty() && self.config.max_poll_interval_ms > 0 {
            if !self.assignment_list.is_empty() {
                self.assignment_list.rotate_left(1);
            }
            if let Some((topic, partition)) = self.assignment_list.first() {
                let key = (topic.clone(), *partition);
                let offset = self.offsets.get(&key).copied().unwrap_or(0);

                // Cap long-poll timeout so it returns before the next
                // heartbeat is due, preventing session expiry during
                // long-polls (CLIENT-02). Floor of 500ms keeps the wait useful
                // even with very short heartbeat intervals.
                let max_wait = if self.uses_coordination {
                    self.config.max_poll_interval_ms.min(
                        self.config
                            .heartbeat_interval_ms
                            .saturating_sub(500)
                            .max(500),
                    )
                } else {
                    self.config.max_poll_interval_ms
                };

                let messages = self
                    .client
                    .consume_long_poll(
                        topic.to_string(),
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                        max_wait,
                    )
                    .await?;

                if !messages.is_empty() {
                    // Same position-advance rule as phase 1.
                    let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(offset);
                    self.offsets.insert(key, max_offset + 1);

                    records.extend(messages.into_iter().map(|data| ConsumerRecord {
                        topic: topic.clone(),
                        partition: *partition,
                        offset: data.offset,
                        data,
                    }));
                }
            }
        }

        // Auto-commit if interval has elapsed. Failures are logged but do
        // not fail the poll — records were already fetched successfully.
        if let Some(interval) = self.config.auto_commit_interval {
            if self.last_commit.elapsed() >= interval {
                if let Err(e) = self.commit_inner().await {
                    warn!(error = %e, "Auto-commit failed");
                }
            }
        }

        Ok(records)
    }
1121
1122    /// Commit current offsets to the server.
1123    ///
1124    /// Automatically reconnects on connection errors.
1125    pub async fn commit(&mut self) -> Result<()> {
1126        match self.commit_inner().await {
1127            Ok(()) => Ok(()),
1128            Err(e) if Self::is_connection_error(&e) => {
1129                warn!(error = %e, "Connection error during commit, attempting reconnect");
1130                self.reconnect().await?;
1131                self.commit_inner().await
1132            }
1133            Err(e) => Err(e),
1134        }
1135    }
1136
1137    /// Inner commit implementation without reconnection wrapper.
1138    ///
1139    /// Uses request pipelining to send all offset commits at once, then
1140    /// reads all responses. Collects all errors instead of only the last.
1141    async fn commit_inner(&mut self) -> Result<()> {
1142        if self.offsets.is_empty() {
1143            return Ok(());
1144        }
1145
1146        // Build commit requests
1147        let commits: Vec<(String, u32, u64)> = self
1148            .offsets
1149            .iter()
1150            .map(|((topic, partition), offset)| (topic.to_string(), *partition, *offset))
1151            .collect();
1152
1153        let mut errors = Vec::new();
1154
1155        // Use pipelining: send all commit requests back-to-back, then read responses
1156        if self.client.is_poisoned() {
1157            // Stream is desynchronized — sequential fallback would also fail.
1158            // Trigger reconnect by returning connection error immediately.
1159            return Err(Error::ConnectionError(
1160                "Client stream is desynchronized — reconnect required".into(),
1161            ));
1162        }
1163
1164        {
1165            let results = self
1166                .client
1167                .commit_offsets_pipelined(&self.config.group_id, &commits)
1168                .await;
1169
1170            match results {
1171                Ok(per_partition) => {
1172                    for (i, result) in per_partition.into_iter().enumerate() {
1173                        if let Err(e) = result {
1174                            let (topic, partition, offset) = &commits[i];
1175                            warn!(
1176                                topic = %topic, partition, offset, error = %e,
1177                                "Failed to commit offset"
1178                            );
1179                            errors.push(e);
1180                        }
1181                    }
1182                }
1183                Err(e) => {
1184                    // Transport-level failure — all commits failed
1185                    errors.push(e);
1186                }
1187            }
1188        }
1189
1190        self.last_commit = Instant::now();
1191
1192        if errors.is_empty() {
1193            debug!(
1194                group_id = %self.config.group_id,
1195                partitions = self.offsets.len(),
1196                "Offsets committed"
1197            );
1198            Ok(())
1199        } else {
1200            // Return the first error (all are logged above)
1201            Err(errors.into_iter().next().expect("errors is non-empty"))
1202        }
1203    }
1204
1205    /// Seek a specific partition to a given offset.
1206    ///
1207    /// The next `poll()` will fetch from this offset for the specified partition.
1208    pub fn seek(&mut self, topic: impl Into<String>, partition: u32, offset: u64) {
1209        let arc: Arc<str> = Arc::from(topic.into());
1210        self.offsets.insert((arc, partition), offset);
1211    }
1212
1213    /// Seek all partitions of a topic to the beginning (offset 0).
1214    pub fn seek_to_beginning(&mut self, topic: &str) {
1215        if let Some(partitions) = self.assignments.get(topic) {
1216            let arc: Arc<str> = Arc::from(topic);
1217            for &p in partitions {
1218                self.offsets.insert((arc.clone(), p), 0);
1219            }
1220        }
1221    }
1222
1223    /// Get the current offset position for a (topic, partition) pair.
1224    pub fn position(&self, topic: &str, partition: u32) -> Option<u64> {
1225        self.offsets
1226            .get(&(Arc::<str>::from(topic), partition))
1227            .copied()
1228    }
1229
1230    /// Get current partition assignments.
1231    pub fn assignments(&self) -> &HashMap<String, Vec<u32>> {
1232        &self.assignments
1233    }
1234
1235    /// Get the consumer group ID.
1236    pub fn group_id(&self) -> &str {
1237        &self.config.group_id
1238    }
1239
1240    // ========================================================================
1241    // Reconnection
1242    // ========================================================================
1243
1244    /// Attempt to reconnect to a bootstrap server with exponential backoff.
1245    ///
1246    /// Tries each configured bootstrap server in round-robin order.
1247    async fn reconnect(&mut self) -> Result<()> {
1248        // Abort background heartbeat — it holds a stale connection and
1249        // generation. A new one will be spawned after rejoining.
1250        if let Some(handle) = self.heartbeat_handle.take() {
1251            handle.abort();
1252        }
1253
1254        let mut backoff = Duration::from_millis(self.config.reconnect_backoff_ms);
1255        let max_backoff = Duration::from_millis(self.config.reconnect_backoff_max_ms);
1256        let servers = &self.config.bootstrap_servers;
1257
1258        if servers.is_empty() {
1259            return Err(Error::ConnectionError(
1260                "No bootstrap servers configured".to_string(),
1261            ));
1262        }
1263
1264        for attempt in 1..=self.config.max_reconnect_attempts {
1265            // Round-robin across bootstrap servers
1266            let server = &servers[(attempt as usize - 1) % servers.len()];
1267            info!(
1268                attempt,
1269                server = %server,
1270                "Attempting to reconnect"
1271            );
1272            match Client::connect(server).await {
1273                Ok(mut new_client) => {
1274                    // Re-authenticate if credentials are configured
1275                    if let Some(ref auth) = self.config.auth {
1276                        if let Err(e) = new_client
1277                            .authenticate_scram(&auth.username, &auth.password)
1278                            .await
1279                        {
1280                            warn!(error = %e, attempt, "Re-authentication failed during reconnect");
1281                            tokio::time::sleep(backoff).await;
1282                            backoff = (backoff * 2).min(max_backoff);
1283                            continue;
1284                        }
1285                    }
1286                    self.client = new_client;
1287                    info!("Consumer reconnected successfully to {}", server);
1288
1289                    // Rejoin the consumer group on the new connection.
1290                    // The server has no state for this member on a new TCP
1291                    // connection, so we must re-discover assignments.
1292                    if self.uses_coordination {
1293                        if let Err(e) = self.discover_assignments().await {
1294                            warn!(error = %e, "Failed to rejoin group after reconnect");
1295                        }
1296                    }
1297
1298                    return Ok(());
1299                }
1300                Err(e) => {
1301                    warn!(error = %e, attempt, server = %server, "Reconnect attempt failed");
1302                    tokio::time::sleep(backoff).await;
1303                    backoff = (backoff * 2).min(max_backoff);
1304                }
1305            }
1306        }
1307        Err(Error::ConnectionError(format!(
1308            "Failed to reconnect to any of {:?} after {} attempts",
1309            servers, self.config.max_reconnect_attempts
1310        )))
1311    }
1312
1313    /// Check whether an error indicates a broken connection.
1314    fn is_connection_error(e: &Error) -> bool {
1315        matches!(
1316            e,
1317            Error::ConnectionError(_)
1318                | Error::IoError(_, _)
1319                | Error::Timeout
1320                | Error::TimeoutWithMessage(_)
1321                | Error::ProtocolError(_)
1322                | Error::ResponseTooLarge(_, _)
1323        )
1324    }
1325
1326    /// Close the consumer, committing final offsets and leaving the group.
1327    pub async fn close(mut self) -> Result<()> {
1328        // Stop background heartbeat first (CLIENT-02).
1329        if let Some(handle) = self.heartbeat_handle.take() {
1330            handle.abort();
1331        }
1332
1333        if self.config.auto_commit_interval.is_some() {
1334            self.commit().await?;
1335        }
1336
1337        // Leave group if using server-side coordination
1338        if self.uses_coordination && !self.member_id.is_empty() {
1339            if let Err(e) = self
1340                .client
1341                .leave_group(&self.config.group_id, &self.member_id)
1342                .await
1343            {
1344                warn!(
1345                    error = %e,
1346                    group_id = %self.config.group_id,
1347                    member_id = %self.member_id,
1348                    "Failed to leave group gracefully"
1349                );
1350            } else {
1351                info!(
1352                    group_id = %self.config.group_id,
1353                    member_id = %self.member_id,
1354                    "Left consumer group"
1355                );
1356            }
1357        }
1358
1359        info!(group_id = %self.config.group_id, "Consumer closed");
1360        Ok(())
1361    }
1362}