1use crate::client::Client;
48use crate::error::{Error, Result};
49use rivven_protocol::MessageData;
50use std::collections::HashMap;
51use std::sync::atomic::{AtomicBool, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tracing::{debug, info, warn};
55
/// Identifies one partition of one topic — the unit of assignment and
/// revocation reported to a [`RebalanceListener`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TopicPartition {
    /// Topic name (shared, cheaply clonable `Arc<str>`).
    pub topic: Arc<str>,
    /// Zero-based partition index within the topic.
    pub partition: u32,
}
62
/// Callbacks invoked around partition rebalances.
///
/// The consumer calls these during assignment discovery: revocation fires for
/// the previously held partitions before a new assignment is negotiated, and
/// assignment fires after the new ownership has been applied.
#[async_trait::async_trait]
pub trait RebalanceListener: Send + Sync {
    /// Invoked with the partitions this consumer is about to give up.
    async fn on_partitions_revoked(&self, partitions: &[TopicPartition]);

    /// Invoked with the partitions this consumer now owns.
    async fn on_partitions_assigned(&self, partitions: &[TopicPartition]);
}
99
/// Fully-resolved consumer configuration, normally produced via
/// [`ConsumerConfig::builder`].
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Broker addresses, tried in order when connecting and reconnecting.
    pub bootstrap_servers: Vec<String>,
    /// Consumer-group id used for offset commits and group coordination.
    pub group_id: String,
    /// Topics to consume.
    pub topics: Vec<String>,
    /// Explicit topic -> partitions assignment; when non-empty, group
    /// coordination is bypassed entirely (static assignment mode).
    pub partitions: HashMap<String, Vec<u32>>,
    /// Maximum records fetched per partition per poll.
    pub max_poll_records: u32,
    /// Long-poll wait budget in milliseconds; 0 disables the long-poll path.
    pub max_poll_interval_ms: u64,
    /// Auto-commit period; `None` disables auto-commit.
    pub auto_commit_interval: Option<Duration>,
    /// Isolation level forwarded with fetches (0 = read-uncommitted; values
    /// above 0 are sent as-is).
    pub isolation_level: u8,
    /// Optional SCRAM credentials applied to every connection.
    pub auth: Option<ConsumerAuthConfig>,
    /// How often assignments/partition metadata are refreshed.
    pub metadata_refresh_interval: Duration,
    /// Initial reconnect backoff in milliseconds (doubles per failed attempt).
    pub reconnect_backoff_ms: u64,
    /// Upper bound on the exponential reconnect backoff, in milliseconds.
    pub reconnect_backoff_max_ms: u64,
    /// Maximum reconnect attempts before giving up.
    pub max_reconnect_attempts: u32,
    /// Group-membership session timeout in milliseconds.
    pub session_timeout_ms: u32,
    /// Maximum time the coordinator waits for members during a rebalance (ms).
    pub rebalance_timeout_ms: u32,
    /// Background heartbeat period in milliseconds (the builder clamps this
    /// to at most one third of `session_timeout_ms`).
    pub heartbeat_interval_ms: u64,
    /// TLS client configuration; `None` means plaintext.
    #[cfg(feature = "tls")]
    pub tls_config: Option<rivven_core::tls::TlsConfig>,
    /// Server name (SNI) presented during the TLS handshake.
    #[cfg(feature = "tls")]
    pub tls_server_name: Option<String>,
}
154
/// SCRAM credentials used to authenticate the consumer's connections.
#[derive(Clone)]
pub struct ConsumerAuthConfig {
    pub username: String,
    pub password: String,
}

/// Hand-written `Debug` so the password can never leak into logs or error
/// messages: it is always rendered as the literal `[REDACTED]`.
impl std::fmt::Debug for ConsumerAuthConfig {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = fmt.debug_struct("ConsumerAuthConfig");
        out.field("username", &self.username);
        out.field("password", &"[REDACTED]");
        out.finish()
    }
}
170
/// Fluent builder for [`ConsumerConfig`]; obtain one via
/// [`ConsumerConfig::builder`] or [`ConsumerConfigBuilder::new`], chain the
/// setters, and finish with [`ConsumerConfigBuilder::build`].
pub struct ConsumerConfigBuilder {
    bootstrap_servers: Vec<String>,
    // `None` until set; `build` falls back to "default-group".
    group_id: Option<String>,
    topics: Vec<String>,
    partitions: HashMap<String, Vec<u32>>,
    max_poll_records: u32,
    max_poll_interval_ms: u64,
    auto_commit_interval: Option<Duration>,
    isolation_level: u8,
    auth: Option<ConsumerAuthConfig>,
    metadata_refresh_interval: Duration,
    reconnect_backoff_ms: u64,
    reconnect_backoff_max_ms: u64,
    max_reconnect_attempts: u32,
    session_timeout_ms: u32,
    rebalance_timeout_ms: u32,
    heartbeat_interval_ms: u64,
    #[cfg(feature = "tls")]
    tls_config: Option<rivven_core::tls::TlsConfig>,
    #[cfg(feature = "tls")]
    tls_server_name: Option<String>,
}
194
195impl ConsumerConfigBuilder {
196 pub fn new() -> Self {
197 Self {
198 bootstrap_servers: vec!["127.0.0.1:9092".to_string()],
199 group_id: None,
200 topics: Vec::new(),
201 partitions: HashMap::new(),
202 max_poll_records: 500,
203 max_poll_interval_ms: 5000,
204 auto_commit_interval: Some(Duration::from_secs(5)),
205 isolation_level: 0,
206 auth: None,
207 metadata_refresh_interval: Duration::from_secs(300),
208 reconnect_backoff_ms: 100,
209 reconnect_backoff_max_ms: 10_000,
210 max_reconnect_attempts: 10,
211 session_timeout_ms: 10_000,
212 rebalance_timeout_ms: 30_000,
213 heartbeat_interval_ms: 3_000,
214 #[cfg(feature = "tls")]
215 tls_config: None,
216 #[cfg(feature = "tls")]
217 tls_server_name: None,
218 }
219 }
220
221 pub fn bootstrap_server(mut self, server: impl Into<String>) -> Self {
223 self.bootstrap_servers = vec![server.into()];
224 self
225 }
226
227 pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
229 self.bootstrap_servers = servers;
230 self
231 }
232
233 pub fn group_id(mut self, group: impl Into<String>) -> Self {
234 self.group_id = Some(group.into());
235 self
236 }
237
238 pub fn topics(mut self, topics: Vec<String>) -> Self {
239 self.topics = topics;
240 self
241 }
242
243 pub fn topic(mut self, topic: impl Into<String>) -> Self {
244 self.topics.push(topic.into());
245 self
246 }
247
248 pub fn assign(mut self, topic: impl Into<String>, partitions: Vec<u32>) -> Self {
250 self.partitions.insert(topic.into(), partitions);
251 self
252 }
253
254 pub fn max_poll_records(mut self, n: u32) -> Self {
255 self.max_poll_records = n;
256 self
257 }
258
259 pub fn max_poll_interval_ms(mut self, ms: u64) -> Self {
260 self.max_poll_interval_ms = ms;
261 self
262 }
263
264 pub fn auto_commit_interval(mut self, interval: Option<Duration>) -> Self {
265 self.auto_commit_interval = interval;
266 self
267 }
268
269 pub fn enable_auto_commit(mut self, enabled: bool) -> Self {
270 if enabled {
271 self.auto_commit_interval = Some(Duration::from_secs(5));
272 } else {
273 self.auto_commit_interval = None;
274 }
275 self
276 }
277
278 pub fn isolation_level(mut self, level: u8) -> Self {
279 self.isolation_level = level;
280 self
281 }
282
283 pub fn read_committed(mut self) -> Self {
285 self.isolation_level = 1;
286 self
287 }
288
289 pub fn auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
290 self.auth = Some(ConsumerAuthConfig {
291 username: username.into(),
292 password: password.into(),
293 });
294 self
295 }
296
297 pub fn metadata_refresh_interval(mut self, interval: Duration) -> Self {
299 self.metadata_refresh_interval = interval;
300 self
301 }
302
303 pub fn reconnect_backoff_ms(mut self, ms: u64) -> Self {
305 self.reconnect_backoff_ms = ms;
306 self
307 }
308
309 pub fn reconnect_backoff_max_ms(mut self, ms: u64) -> Self {
311 self.reconnect_backoff_max_ms = ms;
312 self
313 }
314
315 pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
317 self.max_reconnect_attempts = attempts;
318 self
319 }
320
321 pub fn session_timeout_ms(mut self, ms: u32) -> Self {
323 self.session_timeout_ms = ms;
324 self
325 }
326
327 pub fn rebalance_timeout_ms(mut self, ms: u32) -> Self {
329 self.rebalance_timeout_ms = ms;
330 self
331 }
332
333 pub fn heartbeat_interval_ms(mut self, ms: u64) -> Self {
335 self.heartbeat_interval_ms = ms;
336 self
337 }
338
339 #[cfg(feature = "tls")]
341 pub fn tls(
342 mut self,
343 tls_config: rivven_core::tls::TlsConfig,
344 server_name: impl Into<String>,
345 ) -> Self {
346 self.tls_config = Some(tls_config);
347 self.tls_server_name = Some(server_name.into());
348 self
349 }
350
351 pub fn build(self) -> ConsumerConfig {
352 let max_heartbeat = (self.session_timeout_ms as u64) / 3;
356 let heartbeat_interval_ms = if self.heartbeat_interval_ms > max_heartbeat {
357 tracing::warn!(
358 configured = self.heartbeat_interval_ms,
359 clamped_to = max_heartbeat,
360 session_timeout_ms = self.session_timeout_ms,
361 "heartbeat_interval_ms exceeds 1/3 of session_timeout_ms, clamping"
362 );
363 max_heartbeat
364 } else {
365 self.heartbeat_interval_ms
366 };
367
368 ConsumerConfig {
369 bootstrap_servers: self.bootstrap_servers,
370 group_id: self.group_id.unwrap_or_else(|| "default-group".into()),
371 topics: self.topics,
372 partitions: self.partitions,
373 max_poll_records: self.max_poll_records,
374 max_poll_interval_ms: self.max_poll_interval_ms,
375 auto_commit_interval: self.auto_commit_interval,
376 isolation_level: self.isolation_level,
377 auth: self.auth,
378 metadata_refresh_interval: self.metadata_refresh_interval,
379 reconnect_backoff_ms: self.reconnect_backoff_ms,
380 reconnect_backoff_max_ms: self.reconnect_backoff_max_ms,
381 max_reconnect_attempts: self.max_reconnect_attempts,
382 session_timeout_ms: self.session_timeout_ms,
383 rebalance_timeout_ms: self.rebalance_timeout_ms,
384 heartbeat_interval_ms,
385 #[cfg(feature = "tls")]
386 tls_config: self.tls_config,
387 #[cfg(feature = "tls")]
388 tls_server_name: self.tls_server_name,
389 }
390 }
391}
392
393impl Default for ConsumerConfigBuilder {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399impl ConsumerConfig {
400 pub fn builder() -> ConsumerConfigBuilder {
401 ConsumerConfigBuilder::new()
402 }
403}
404
/// One message returned from [`Consumer::poll`], tagged with its origin.
#[derive(Debug, Clone)]
pub struct ConsumerRecord {
    /// Topic the record was fetched from.
    pub topic: Arc<str>,
    /// Partition the record was fetched from.
    pub partition: u32,
    /// The record's offset within its partition.
    pub offset: u64,
    /// The decoded message payload.
    pub data: MessageData,
}
417
/// High-level consumer: owns the broker connection, tracks partition
/// assignments (group-coordinated or static), fetch offsets, commits, and a
/// background heartbeat task.
pub struct Consumer {
    // Active broker connection used for fetches, commits, and group RPCs.
    client: Client,
    config: ConsumerConfig,
    // Next offset to fetch per (topic, partition).
    offsets: HashMap<(Arc<str>, u32), u64>,
    // Current assignment: topic -> owned partitions.
    assignments: HashMap<String, Vec<u32>>,
    // Flattened view of `assignments`, used for pipelined fetches and the
    // long-poll rotation.
    assignment_list: Vec<(Arc<str>, u32)>,
    // When offsets were last committed (drives auto-commit).
    last_commit: Instant,
    // When assignments were last refreshed (drives metadata refresh).
    last_discovery: Instant,
    // Whether the first assignment discovery completed successfully.
    initialized: bool,
    // Member id issued by the group coordinator (empty until joined).
    member_id: String,
    // Group generation, bumped by the coordinator on each rebalance.
    generation_id: u32,
    // Whether this member computed the group's assignments.
    is_leader: bool,
    // Updated after a successful join; NOTE(review): not currently read
    // elsewhere in this file — confirm intended use.
    last_heartbeat: Instant,
    // True when using group coordination (no explicit partitions configured).
    uses_coordination: bool,
    // Set by the heartbeat task or fetch errors to request a group rejoin.
    needs_rejoin: Arc<AtomicBool>,
    // Optional rebalance callbacks.
    rebalance_listener: Option<Arc<dyn RebalanceListener>>,
    // Background heartbeat task; aborted on rebalance, reconnect, and close.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
}
457
impl Consumer {
    /// Connects to the first reachable bootstrap server, authenticates when
    /// credentials are configured, and performs the initial partition
    /// discovery (joining the consumer group unless explicit partitions were
    /// assigned via `ConsumerConfigBuilder::assign`).
    ///
    /// # Errors
    /// Returns `Error::ConnectionError` when no bootstrap server is
    /// configured or reachable; propagates authentication and discovery
    /// failures.
    pub async fn new(config: ConsumerConfig) -> Result<Self> {
        let servers = &config.bootstrap_servers;
        if servers.is_empty() {
            return Err(Error::ConnectionError(
                "No bootstrap servers configured".to_string(),
            ));
        }

        // Try each bootstrap server in order, keeping the last error so the
        // failure message is meaningful when none accepts a connection.
        let mut last_error = None;
        let mut client = None;
        for server in servers {
            // Use TLS when both a TLS config and an SNI name are present.
            #[cfg(feature = "tls")]
            let connect_result = if let (Some(ref tls_cfg), Some(ref sni)) =
                (&config.tls_config, &config.tls_server_name)
            {
                Client::connect_tls(server, tls_cfg, sni).await
            } else {
                Client::connect(server).await
            };
            #[cfg(not(feature = "tls"))]
            let connect_result = Client::connect(server).await;

            match connect_result {
                Ok(c) => {
                    client = Some(c);
                    break;
                }
                Err(e) => {
                    warn!(server = %server, error = %e, "Failed to connect to bootstrap server");
                    last_error = Some(e);
                }
            }
        }
        let mut client = client.ok_or_else(|| {
            last_error.unwrap_or_else(|| {
                Error::ConnectionError("No bootstrap servers available".to_string())
            })
        })?;

        if let Some(ref auth) = config.auth {
            client
                .authenticate_scram(&auth.username, &auth.password)
                .await?;
        }

        // Explicit partition assignments opt out of group coordination.
        let uses_coordination = config.partitions.is_empty();

        let mut consumer = Self {
            client,
            config,
            offsets: HashMap::new(),
            assignments: HashMap::new(),
            assignment_list: Vec::new(),
            last_commit: Instant::now(),
            last_discovery: Instant::now(),
            initialized: false,
            member_id: String::new(),
            generation_id: 0,
            is_leader: false,
            last_heartbeat: Instant::now(),
            uses_coordination,
            needs_rejoin: Arc::new(AtomicBool::new(false)),
            rebalance_listener: None,
            heartbeat_handle: None,
        };

        consumer.discover_assignments().await?;

        info!(
            group_id = %consumer.config.group_id,
            topics = ?consumer.config.topics,
            partitions = ?consumer.assignments,
            "Consumer initialized"
        );

        Ok(consumer)
    }
552
553 pub fn set_rebalance_listener(&mut self, listener: Arc<dyn RebalanceListener>) {
559 self.rebalance_listener = Some(listener);
560 }
561
562 async fn spawn_heartbeat_task(&mut self) {
573 if let Some(handle) = self.heartbeat_handle.take() {
575 handle.abort();
576 }
577
578 if self.member_id.is_empty() || !self.uses_coordination {
580 return;
581 }
582
583 let group_id = self.config.group_id.clone();
584 let member_id = self.member_id.clone();
585 let generation_id = self.generation_id;
586 let interval = Duration::from_millis(self.config.heartbeat_interval_ms);
587 let needs_rejoin = self.needs_rejoin.clone();
588 let servers = self.config.bootstrap_servers.clone();
589 let auth = self.config.auth.clone();
590
591 self.heartbeat_handle = Some(tokio::spawn(async move {
592 let mut hb_client = None;
594 for server in &servers {
595 match Client::connect(server).await {
596 Ok(mut c) => {
597 if let Some(ref auth) = auth {
598 if let Err(e) =
599 c.authenticate_scram(&auth.username, &auth.password).await
600 {
601 warn!(
602 server = %server,
603 error = %e,
604 "Heartbeat connection auth failed, trying next server"
605 );
606 continue;
607 }
608 }
609 hb_client = Some(c);
610 break;
611 }
612 Err(e) => {
613 warn!(
614 server = %server,
615 error = %e,
616 "Heartbeat connection failed, trying next server"
617 );
618 }
619 }
620 }
621
622 let Some(mut client) = hb_client else {
623 warn!("Could not establish heartbeat connection to any server, signaling rejoin");
624 needs_rejoin.store(true, Ordering::Release);
625 return;
626 };
627
628 let mut ticker = tokio::time::interval(interval);
629 ticker.tick().await; loop {
632 ticker.tick().await;
633 match client.heartbeat(&group_id, generation_id, &member_id).await {
634 Ok(27) => {
635 info!(
637 group_id = %group_id,
638 "Background heartbeat: rebalance in progress, signaling rejoin"
639 );
640 needs_rejoin.store(true, Ordering::Release);
641 }
642 Ok(_) => {
643 }
645 Err(e) => {
646 warn!(
647 group_id = %group_id,
648 error = %e,
649 "Background heartbeat failed, signaling rejoin"
650 );
651 needs_rejoin.store(true, Ordering::Release);
652 }
656 }
657 }
658 }));
659 }
660
    /// Refreshes this consumer's partition assignment.
    ///
    /// Sequence: notify the listener that the old partitions are revoked,
    /// obtain the new assignment (via group coordination or broker metadata),
    /// prune and seed `offsets` from committed positions, rebuild the
    /// flattened `assignment_list`, then notify the listener of the new
    /// assignment.
    async fn discover_assignments(&mut self) -> Result<()> {
        // Snapshot the current ownership before it is replaced, for the
        // revocation callback.
        let old_tps: Vec<TopicPartition> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| TopicPartition {
                    topic: arc.clone(),
                    partition: p,
                })
            })
            .collect();

        if !old_tps.is_empty() {
            if let Some(ref listener) = self.rebalance_listener {
                listener.on_partitions_revoked(&old_tps).await;
            }
        }

        if self.uses_coordination {
            self.discover_via_coordination().await?;
            // Membership (member id / generation) may have changed; restart
            // the heartbeat loop for the new generation.
            self.spawn_heartbeat_task().await;
        } else {
            self.discover_via_metadata().await?;
        }

        // Drop cached offsets for partitions we no longer own.
        let owned_keys: std::collections::HashSet<(Arc<str>, u32)> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();
        self.offsets.retain(|k, _| owned_keys.contains(k));

        // Seed offsets for newly acquired partitions from the group's
        // committed positions, falling back to 0.
        for (topic, partitions) in &self.assignments {
            for &partition in partitions {
                let key: (Arc<str>, u32) = (Arc::from(topic.as_str()), partition);
                if self.offsets.contains_key(&key) {
                    continue;
                }
                match self
                    .client
                    .get_offset(&self.config.group_id, topic, partition)
                    .await
                {
                    Ok(Some(offset)) => {
                        debug!(
                            topic = %topic,
                            partition,
                            offset,
                            "Resumed from committed offset"
                        );
                        self.offsets.insert(key, offset);
                    }
                    Ok(None) => {
                        // No committed offset for this partition yet.
                        self.offsets.insert(key, 0);
                    }
                    Err(e) => {
                        debug!(
                            topic = %topic,
                            partition,
                            error = %e,
                            "Failed to load committed offset, starting from 0"
                        );
                        self.offsets.insert(key, 0);
                    }
                }
            }
        }

        self.initialized = true;

        // Flattened view used by the fetch paths in `poll_inner`.
        self.assignment_list = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();

        if let Some(ref listener) = self.rebalance_listener {
            let new_tps: Vec<TopicPartition> = self
                .assignment_list
                .iter()
                .map(|(t, p)| TopicPartition {
                    topic: t.clone(),
                    partition: *p,
                })
                .collect();
            if !new_tps.is_empty() {
                listener.on_partitions_assigned(&new_tps).await;
            }
        }

        Ok(())
    }
778
    /// Joins (or rejoins) the consumer group and syncs to obtain this
    /// member's partition assignment.
    ///
    /// When this member is elected leader it computes range assignments for
    /// the whole group and submits them with the sync request; followers
    /// submit an empty plan and receive their share from the coordinator.
    async fn discover_via_coordination(&mut self) -> Result<()> {
        let (generation_id, _protocol_type, member_id, leader_id, members) = self
            .client
            .join_group(
                &self.config.group_id,
                // Empty on first join; the coordinator assigns a member id.
                &self.member_id,
                self.config.session_timeout_ms,
                self.config.rebalance_timeout_ms,
                "consumer",
                self.config.topics.clone(),
            )
            .await?;

        self.member_id = member_id.clone();
        self.generation_id = generation_id;
        self.is_leader = member_id == leader_id;

        info!(
            group_id = %self.config.group_id,
            member_id = %self.member_id,
            generation_id,
            is_leader = self.is_leader,
            member_count = members.len(),
            "Joined consumer group"
        );

        // Only the leader computes the group-wide plan.
        let group_assignments = if self.is_leader {
            self.compute_range_assignments(&members).await?
        } else {
            Vec::new()
        };

        let my_assignments = self
            .client
            .sync_group(
                &self.config.group_id,
                generation_id,
                &self.member_id,
                group_assignments,
            )
            .await?;

        self.assignments.clear();
        for (topic, partitions) in my_assignments {
            debug!(
                topic = %topic,
                partitions = ?partitions,
                "Received partition assignment"
            );
            self.assignments.insert(topic, partitions);
        }

        self.last_heartbeat = Instant::now();

        Ok(())
    }
846
847 async fn compute_range_assignments(
853 &mut self,
854 members: &[(String, Vec<String>)],
855 ) -> Result<Vec<(String, Vec<(String, Vec<u32>)>)>> {
856 let mut all_topics: Vec<String> = members
858 .iter()
859 .flat_map(|(_, subs)| subs.iter().cloned())
860 .collect();
861 all_topics.sort();
862 all_topics.dedup();
863
864 let mut result_map: HashMap<String, Vec<(String, Vec<u32>)>> = members
866 .iter()
867 .map(|(mid, _)| (mid.clone(), Vec::new()))
868 .collect();
869
870 for topic in &all_topics {
871 let mut subscribed: Vec<&str> = members
873 .iter()
874 .filter(|(_, subs)| subs.iter().any(|s| s == topic))
875 .map(|(mid, _)| mid.as_str())
876 .collect();
877 subscribed.sort(); let partition_count = match self.client.get_metadata(topic.as_str()).await {
880 Ok((_name, count)) => count,
881 Err(e) => {
882 warn!(topic = %topic, error = %e, "Failed to get metadata for assignment");
883 continue;
884 }
885 };
886
887 if subscribed.is_empty() || partition_count == 0 {
888 continue;
889 }
890
891 let n_members = subscribed.len() as u32;
893 let per_member = partition_count / n_members;
894 let remainder = partition_count % n_members;
895
896 let mut offset = 0u32;
897 for (i, mid) in subscribed.iter().enumerate() {
898 let extra = if (i as u32) < remainder { 1 } else { 0 };
899 let count = per_member + extra;
900 let partitions: Vec<u32> = (offset..offset + count).collect();
901 offset += count;
902
903 if let Some(entry) = result_map.get_mut(*mid) {
904 entry.push((topic.clone(), partitions));
905 }
906 }
907 }
908
909 Ok(result_map.into_iter().collect())
910 }
911
912 async fn discover_via_metadata(&mut self) -> Result<()> {
917 for topic in &self.config.topics {
918 if let Some(explicit) = self.config.partitions.get(topic) {
919 self.assignments.insert(topic.clone(), explicit.clone());
921 } else {
922 match self.client.get_metadata(topic.as_str()).await {
924 Ok((_name, partition_count)) => {
925 let partitions: Vec<u32> = (0..partition_count).collect();
926 self.assignments.insert(topic.clone(), partitions);
927 }
928 Err(e) => {
929 warn!(
930 topic = %topic,
931 error = %e,
932 "Failed to discover partitions, will retry on next poll"
933 );
934 }
935 }
936 }
937 }
938
939 Ok(())
940 }
941
942 pub async fn poll(&mut self) -> Result<Vec<ConsumerRecord>> {
948 match self.poll_inner().await {
949 Ok(records) => Ok(records),
950 Err(e) if Self::is_connection_error(&e) => {
951 warn!(error = %e, "Connection error during poll, attempting reconnect");
952 self.reconnect().await?;
953 self.poll_inner().await
954 }
955 Err(e) => Err(e),
956 }
957 }
958
    /// One poll cycle: ensure assignments are current, issue a pipelined
    /// fetch across every owned partition, optionally fall back to a single
    /// long-poll when nothing was fetched, then run auto-commit if due.
    async fn poll_inner(&mut self) -> Result<Vec<ConsumerRecord>> {
        // First call (or a previously failed discovery): establish assignments.
        if !self.initialized {
            self.discover_assignments().await?;
        }

        // The heartbeat task (or a fetch error) asked us to rejoin the group.
        if self.needs_rejoin.load(Ordering::Acquire) && self.uses_coordination {
            info!(
                group_id = %self.config.group_id,
                "Rejoining group due to rebalance signal"
            );
            self.discover_assignments().await?;
            self.needs_rejoin.store(false, Ordering::Release);
        }

        // Periodic refresh; failure here is non-fatal — keep the current
        // assignment and try again next interval.
        if self.last_discovery.elapsed() >= self.config.metadata_refresh_interval {
            if let Err(e) = self.discover_assignments().await {
                warn!(error = %e, "Failed to re-discover assignments, continuing with existing");
            }
            self.last_discovery = Instant::now();
        }

        let mut records = Vec::new();
        // 0 means read-uncommitted and is simply omitted from the request.
        let isolation_level = if self.config.isolation_level > 0 {
            Some(self.config.isolation_level)
        } else {
            None
        };

        if !self.assignment_list.is_empty() {
            // Build one fetch descriptor per owned partition, starting from
            // the next tracked offset (0 if untracked).
            let fetches: Vec<(&str, u32, u64, u32, Option<u8>)> = self
                .assignment_list
                .iter()
                .map(|(topic, partition)| {
                    let key = (topic.clone(), *partition);
                    let offset = self.offsets.get(&key).copied().unwrap_or(0);
                    (
                        &**topic,
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                    )
                })
                .collect();

            // Single pipelined round-trip; results come back in request order.
            let results = self.client.consume_pipelined(&fetches).await?;

            for (i, result) in results.into_iter().enumerate() {
                let (topic, partition) = &self.assignment_list[i];
                match result {
                    Ok(messages) if !messages.is_empty() => {
                        // Advance the fetch position past the highest offset
                        // seen in this batch.
                        let key = (topic.clone(), *partition);
                        let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(0);
                        self.offsets.insert(key, max_offset + 1);

                        records.extend(messages.into_iter().map(|data| ConsumerRecord {
                            topic: topic.clone(),
                            partition: *partition,
                            offset: data.offset,
                            data,
                        }));
                    }
                    Err(e) => {
                        // Group-membership errors are detected by substring
                        // match on the error text and trigger a rejoin; other
                        // fetch errors skip just this partition.
                        let err_str = e.to_string();
                        if err_str.contains("UNKNOWN_MEMBER_ID")
                            || err_str.contains("ILLEGAL_GENERATION")
                            || err_str.contains("REBALANCE_IN_PROGRESS")
                        {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Rebalance signal in fetch response, will rejoin group"
                            );
                            self.needs_rejoin.store(true, Ordering::Release);
                        } else {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Pipelined fetch error, skipping partition"
                            );
                        }
                    }
                    // Ok with an empty batch: nothing available here.
                    _ => {}
                }
            }
        }

        // Fast path produced nothing: block on one partition with a long
        // poll, rotating the list so each partition gets a turn over time.
        if records.is_empty() && self.config.max_poll_interval_ms > 0 {
            if !self.assignment_list.is_empty() {
                self.assignment_list.rotate_left(1);
            }
            if let Some((topic, partition)) = self.assignment_list.first() {
                let key = (topic.clone(), *partition);
                let offset = self.offsets.get(&key).copied().unwrap_or(0);

                // Under coordination, cap the wait below the heartbeat
                // interval (with a 500ms floor) so rejoin signals are not
                // starved by a long blocking poll.
                let max_wait = if self.uses_coordination {
                    self.config.max_poll_interval_ms.min(
                        self.config
                            .heartbeat_interval_ms
                            .saturating_sub(500)
                            .max(500),
                    )
                } else {
                    self.config.max_poll_interval_ms
                };

                let messages = self
                    .client
                    .consume_long_poll(
                        topic.to_string(),
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                        max_wait,
                    )
                    .await?;

                if !messages.is_empty() {
                    let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(offset);
                    self.offsets.insert(key, max_offset + 1);

                    records.extend(messages.into_iter().map(|data| ConsumerRecord {
                        topic: topic.clone(),
                        partition: *partition,
                        offset: data.offset,
                        data,
                    }));
                }
            }
        }

        // Auto-commit when the configured interval has elapsed; failures are
        // logged, not surfaced, so polling continues.
        if let Some(interval) = self.config.auto_commit_interval {
            if self.last_commit.elapsed() >= interval {
                if let Err(e) = self.commit_inner().await {
                    warn!(error = %e, "Auto-commit failed");
                }
            }
        }

        Ok(records)
    }
1121
1122 pub async fn commit(&mut self) -> Result<()> {
1126 match self.commit_inner().await {
1127 Ok(()) => Ok(()),
1128 Err(e) if Self::is_connection_error(&e) => {
1129 warn!(error = %e, "Connection error during commit, attempting reconnect");
1130 self.reconnect().await?;
1131 self.commit_inner().await
1132 }
1133 Err(e) => Err(e),
1134 }
1135 }
1136
    /// Commits every tracked offset in one pipelined request.
    ///
    /// Returns `Ok(())` when there is nothing to commit. When some partitions
    /// fail, the first error is returned after all results are inspected.
    async fn commit_inner(&mut self) -> Result<()> {
        if self.offsets.is_empty() {
            return Ok(());
        }

        let commits: Vec<(String, u32, u64)> = self
            .offsets
            .iter()
            .map(|((topic, partition), offset)| (topic.to_string(), *partition, *offset))
            .collect();

        let mut errors = Vec::new();

        // A desynchronized stream would pair responses with the wrong
        // requests; refuse to commit until the caller reconnects.
        if self.client.is_poisoned() {
            return Err(Error::ConnectionError(
                "Client stream is desynchronized — reconnect required".into(),
            ));
        }

        {
            let results = self
                .client
                .commit_offsets_pipelined(&self.config.group_id, &commits)
                .await;

            match results {
                Ok(per_partition) => {
                    // Per-partition results arrive in the same order as
                    // `commits`; log and collect each failure.
                    for (i, result) in per_partition.into_iter().enumerate() {
                        if let Err(e) = result {
                            let (topic, partition, offset) = &commits[i];
                            warn!(
                                topic = %topic, partition, offset, error = %e,
                                "Failed to commit offset"
                            );
                            errors.push(e);
                        }
                    }
                }
                Err(e) => {
                    errors.push(e);
                }
            }
        }

        // NOTE(review): the commit timestamp is refreshed even when some
        // commits failed, so auto-commit will not retry before the next full
        // interval — presumably intentional back-off; confirm.
        self.last_commit = Instant::now();

        if errors.is_empty() {
            debug!(
                group_id = %self.config.group_id,
                partitions = self.offsets.len(),
                "Offsets committed"
            );
            Ok(())
        } else {
            Err(errors.into_iter().next().expect("errors is non-empty"))
        }
    }
1204
1205 pub fn seek(&mut self, topic: impl Into<String>, partition: u32, offset: u64) {
1209 let arc: Arc<str> = Arc::from(topic.into());
1210 self.offsets.insert((arc, partition), offset);
1211 }
1212
1213 pub fn seek_to_beginning(&mut self, topic: &str) {
1215 if let Some(partitions) = self.assignments.get(topic) {
1216 let arc: Arc<str> = Arc::from(topic);
1217 for &p in partitions {
1218 self.offsets.insert((arc.clone(), p), 0);
1219 }
1220 }
1221 }
1222
1223 pub fn position(&self, topic: &str, partition: u32) -> Option<u64> {
1225 self.offsets
1226 .get(&(Arc::<str>::from(topic), partition))
1227 .copied()
1228 }
1229
    /// Current assignment as a topic -> partitions map.
    pub fn assignments(&self) -> &HashMap<String, Vec<u32>> {
        &self.assignments
    }
1234
    /// The consumer-group id this consumer commits offsets under.
    pub fn group_id(&self) -> &str {
        &self.config.group_id
    }
1239
1240 async fn reconnect(&mut self) -> Result<()> {
1248 if let Some(handle) = self.heartbeat_handle.take() {
1251 handle.abort();
1252 }
1253
1254 let mut backoff = Duration::from_millis(self.config.reconnect_backoff_ms);
1255 let max_backoff = Duration::from_millis(self.config.reconnect_backoff_max_ms);
1256 let servers = &self.config.bootstrap_servers;
1257
1258 if servers.is_empty() {
1259 return Err(Error::ConnectionError(
1260 "No bootstrap servers configured".to_string(),
1261 ));
1262 }
1263
1264 for attempt in 1..=self.config.max_reconnect_attempts {
1265 let server = &servers[(attempt as usize - 1) % servers.len()];
1267 info!(
1268 attempt,
1269 server = %server,
1270 "Attempting to reconnect"
1271 );
1272 match Client::connect(server).await {
1273 Ok(mut new_client) => {
1274 if let Some(ref auth) = self.config.auth {
1276 if let Err(e) = new_client
1277 .authenticate_scram(&auth.username, &auth.password)
1278 .await
1279 {
1280 warn!(error = %e, attempt, "Re-authentication failed during reconnect");
1281 tokio::time::sleep(backoff).await;
1282 backoff = (backoff * 2).min(max_backoff);
1283 continue;
1284 }
1285 }
1286 self.client = new_client;
1287 info!("Consumer reconnected successfully to {}", server);
1288
1289 if self.uses_coordination {
1293 if let Err(e) = self.discover_assignments().await {
1294 warn!(error = %e, "Failed to rejoin group after reconnect");
1295 }
1296 }
1297
1298 return Ok(());
1299 }
1300 Err(e) => {
1301 warn!(error = %e, attempt, server = %server, "Reconnect attempt failed");
1302 tokio::time::sleep(backoff).await;
1303 backoff = (backoff * 2).min(max_backoff);
1304 }
1305 }
1306 }
1307 Err(Error::ConnectionError(format!(
1308 "Failed to reconnect to any of {:?} after {} attempts",
1309 servers, self.config.max_reconnect_attempts
1310 )))
1311 }
1312
1313 fn is_connection_error(e: &Error) -> bool {
1315 matches!(
1316 e,
1317 Error::ConnectionError(_)
1318 | Error::IoError(_, _)
1319 | Error::Timeout
1320 | Error::TimeoutWithMessage(_)
1321 | Error::ProtocolError(_)
1322 | Error::ResponseTooLarge(_, _)
1323 )
1324 }
1325
    /// Gracefully shuts the consumer down: stops the heartbeat task, flushes
    /// offsets when auto-commit is enabled, and leaves the consumer group
    /// (best effort) when group coordination is in use.
    ///
    /// # Errors
    /// Propagates a failure of the final commit; a failed group leave is only
    /// logged. NOTE(review): offsets are flushed here only when auto-commit
    /// is configured — manual-commit users must commit before closing.
    pub async fn close(mut self) -> Result<()> {
        if let Some(handle) = self.heartbeat_handle.take() {
            handle.abort();
        }

        if self.config.auto_commit_interval.is_some() {
            self.commit().await?;
        }

        if self.uses_coordination && !self.member_id.is_empty() {
            if let Err(e) = self
                .client
                .leave_group(&self.config.group_id, &self.member_id)
                .await
            {
                warn!(
                    error = %e,
                    group_id = %self.config.group_id,
                    member_id = %self.member_id,
                    "Failed to leave group gracefully"
                );
            } else {
                info!(
                    group_id = %self.config.group_id,
                    member_id = %self.member_id,
                    "Left consumer group"
                );
            }
        }

        info!(group_id = %self.config.group_id, "Consumer closed");
        Ok(())
    }
}