1use std::collections::{HashMap, HashSet};
71use std::net::SocketAddr;
72use std::path::{Path, PathBuf};
73use std::sync::Arc;
74use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
75use std::time::{Duration, Instant};
76
77use tokio::net::{TcpListener, TcpStream};
78use tokio::sync::{RwLock, broadcast};
79
80use crate::client::{ClientConfig, LanceClient};
81use crate::consumer::{PollResult, SeekPosition};
82use crate::error::{ClientError, Result};
83use crate::offset::{LockFileOffsetStore, MemoryOffsetStore, OffsetStore};
84use crate::standalone::{StandaloneConfig, StandaloneConsumer};
85
/// Strategy the coordinator uses to distribute topic ids across workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AssignmentStrategy {
    /// Deal topics out one at a time across workers in turn (default).
    #[default]
    RoundRobin,
    /// Give each worker a contiguous slice of the topic list.
    Range,
    /// Keep existing worker→topic pairings where possible; only new or
    /// orphaned topics are redistributed.
    Sticky,
}
97
/// Configuration for a consumer group and its coordinator.
#[derive(Debug, Clone)]
pub struct GroupConfig {
    // Identifier of the consumer group.
    pub group_id: String,
    // Topic ids the group consumes; when empty, the coordinator discovers
    // all topics from the server at startup (see `GroupCoordinator::start`).
    pub topics: Vec<u32>,
    // How topics are partitioned across workers during rebalance.
    pub assignment_strategy: AssignmentStrategy,
    // Interval at which workers are expected to send heartbeats.
    pub heartbeat_interval: Duration,
    // A worker whose last heartbeat is older than this is evicted.
    pub session_timeout: Duration,
    // Address the coordinator's TCP listener binds to.
    pub coordinator_addr: SocketAddr,
    // Address of the broker/server workers consume from.
    pub server_addr: String,
    // Directory for persisted offsets; `None` keeps offsets in memory.
    pub offset_dir: Option<PathBuf>,
}
118
119impl GroupConfig {
120 pub fn new(group_id: impl Into<String>) -> Self {
122 Self {
123 group_id: group_id.into(),
124 topics: Vec::new(),
125 assignment_strategy: AssignmentStrategy::RoundRobin,
126 heartbeat_interval: Duration::from_secs(3),
127 session_timeout: Duration::from_secs(30),
128 coordinator_addr: "127.0.0.1:19920"
129 .parse()
130 .unwrap_or_else(|_| SocketAddr::from(([127, 0, 0, 1], 19920))),
131 server_addr: "127.0.0.1:1992".to_string(),
132 offset_dir: None,
133 }
134 }
135
136 pub fn with_topics(mut self, topics: Vec<u32>) -> Self {
138 self.topics = topics;
139 self
140 }
141
142 pub fn with_assignment_strategy(mut self, strategy: AssignmentStrategy) -> Self {
144 self.assignment_strategy = strategy;
145 self
146 }
147
148 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
150 self.heartbeat_interval = interval;
151 self
152 }
153
154 pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
156 self.session_timeout = timeout;
157 self
158 }
159
160 pub fn with_coordinator_addr(mut self, addr: SocketAddr) -> Self {
162 self.coordinator_addr = addr;
163 self
164 }
165
166 pub fn with_server_addr(mut self, addr: impl Into<String>) -> Self {
168 self.server_addr = addr.into();
169 self
170 }
171
172 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
174 self.offset_dir = Some(dir.to_path_buf());
175 self
176 }
177}
178
/// Requests a worker can send to the coordinator.
///
/// NOTE(review): the wire protocol currently sends plain strings
/// ("JOIN:<id>", "LEAVE:<id>", …) — see `process_message`. This enum
/// mirrors those messages but is not used on the wire (hence `dead_code`).
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorMessage {
    /// Worker requests membership in the group.
    Join { worker_id: String },
    /// Worker announces a graceful departure.
    Leave { worker_id: String },
    /// Liveness ping; resets the worker's session timer.
    Heartbeat { worker_id: String },
    /// Worker asks for its current topic assignments.
    GetAssignments { worker_id: String },
}
192
/// Replies the coordinator sends back to workers.
///
/// These are serialized with `format!("{:?}", response)` (Debug output)
/// in `handle_worker_connection` and scraped back out of that text by the
/// `GroupedConsumer::parse_*` helpers — keep variant/field names in sync
/// with those parsers.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorResponse {
    /// Join succeeded; carries the new generation and initial assignments.
    JoinAccepted {
        worker_id: String,
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Graceful leave acknowledged.
    LeaveAcknowledged,
    /// Heartbeat accepted; reports the coordinator's current generation.
    HeartbeatAck { generation: u64 },
    /// Current assignments for the requesting worker.
    Assignments {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Rebalance notification (currently unsent; kept for protocol parity).
    Rebalance {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Any failure (unknown message, unknown worker, …).
    Error { message: String },
}
220
/// Coordinator-side record of one group member.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct WorkerState {
    // Identifier the worker joined with.
    worker_id: String,
    // Last time a JOIN or HEARTBEAT was seen; used for session timeout.
    last_heartbeat: Instant,
    // Topic ids currently assigned to this worker.
    assignments: Vec<u32>,
    // Group generation at the time this worker joined.
    generation: u64,
}
230
/// Group membership coordinator: accepts worker connections over TCP,
/// tracks heartbeats, and rebalances topic assignments when membership
/// changes. Created via [`GroupCoordinator::start`].
pub struct GroupCoordinator {
    #[allow(dead_code)]
    config: GroupConfig,
    // Live members keyed by worker id; shared with the background tasks.
    workers: Arc<RwLock<HashMap<String, WorkerState>>>,
    // Monotonic group generation; bumped on every membership change.
    generation: Arc<AtomicU64>,
    // Cleared by `stop()`; checked by the health-check task.
    running: Arc<AtomicBool>,
    // Broadcast channel used to tell both background tasks to exit.
    shutdown_tx: broadcast::Sender<()>,
    // Actual bound listener address (resolves port 0 to a real port).
    join_addr: SocketAddr,
}
247
impl GroupCoordinator {
    /// Starts a coordinator for the group described by `config`.
    ///
    /// If `config.topics` is empty, the full topic list is first discovered
    /// from the server. A TCP listener is bound on
    /// `config.coordinator_addr`, then two background tasks are spawned:
    /// an accept loop handling worker connections, and a once-per-second
    /// health checker that evicts workers whose heartbeats have lapsed.
    ///
    /// # Errors
    /// Returns an error if topic discovery or binding the listener fails.
    pub async fn start(server_addr: &str, mut config: GroupConfig) -> Result<Self> {
        config.server_addr = server_addr.to_string();

        // No explicit topics configured: consume everything the server has.
        if config.topics.is_empty() {
            let client_config = ClientConfig::new(&config.server_addr);
            let mut client = LanceClient::connect(client_config).await?;
            let topics = client.list_topics().await?;
            config.topics = topics.iter().map(|t| t.id).collect();
        }

        let workers = Arc::new(RwLock::new(HashMap::new()));
        let generation = Arc::new(AtomicU64::new(0));
        let running = Arc::new(AtomicBool::new(true));
        let (shutdown_tx, _) = broadcast::channel(1);

        let listener = TcpListener::bind(config.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        // Record the actual bound address (matters when port 0 was requested).
        let join_addr = listener.local_addr().map_err(ClientError::IoError)?;

        let coordinator = Self {
            config: config.clone(),
            workers: workers.clone(),
            generation: generation.clone(),
            running: running.clone(),
            shutdown_tx: shutdown_tx.clone(),
            join_addr,
        };

        // --- Accept loop: one spawned task per incoming worker connection.
        let workers_clone = workers.clone();
        let generation_clone = generation.clone();
        let _running_clone = running.clone();
        let config_clone = config.clone();
        let mut shutdown_rx = shutdown_tx.subscribe();

        tokio::spawn(async move {
            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((stream, addr)) => {
                                let workers = workers_clone.clone();
                                let generation = generation_clone.clone();
                                let config = config_clone.clone();
                                tokio::spawn(async move {
                                    if let Err(e) = Self::handle_worker_connection(
                                        stream, addr, workers, generation, config
                                    ).await {
                                        tracing::warn!("Worker connection error: {}", e);
                                    }
                                });
                            }
                            Err(e) => {
                                tracing::error!("Accept error: {}", e);
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::info!("Coordinator shutting down");
                        break;
                    }
                }
            }
        });

        // --- Health checker: scans the member table once a second and
        // removes workers whose session has timed out.
        let workers_checker = workers.clone();
        let generation_checker = generation.clone();
        let running_checker = running.clone();
        let session_timeout = config.session_timeout;
        let mut shutdown_rx2 = shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(1));
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        // `stop()` clears the flag; exit even if the
                        // shutdown broadcast was missed.
                        if !running_checker.load(Ordering::Relaxed) {
                            break;
                        }
                        Self::check_worker_health(
                            &workers_checker,
                            &generation_checker,
                            session_timeout,
                        ).await;
                    }
                    _ = shutdown_rx2.recv() => {
                        break;
                    }
                }
            }
        });

        Ok(coordinator)
    }

    /// The address workers should connect to (the listener's bound address).
    pub fn join_address(&self) -> SocketAddr {
        self.join_addr
    }

    /// Current group generation (bumped on every membership change).
    pub fn generation(&self) -> u64 {
        self.generation.load(Ordering::Relaxed)
    }

    /// Number of currently registered workers.
    pub async fn worker_count(&self) -> usize {
        self.workers.read().await.len()
    }

    /// Snapshot of worker id → assigned topic ids.
    pub async fn get_assignments(&self) -> HashMap<String, Vec<u32>> {
        self.workers
            .read()
            .await
            .iter()
            .map(|(id, state)| (id.clone(), state.assignments.clone()))
            .collect()
    }

    /// Signals both background tasks to shut down. Idempotent; the send
    /// result is ignored because the tasks may already have exited.
    pub fn stop(&self) {
        self.running.store(false, Ordering::Relaxed);
        let _ = self.shutdown_tx.send(());
    }

    /// Serves one worker connection: reads a text command, dispatches it,
    /// and writes the Debug-formatted response back, until the peer
    /// closes the connection (read returns 0).
    ///
    /// NOTE(review): each `read` is assumed to deliver exactly one whole
    /// message and messages are assumed to fit in 4096 bytes — there is no
    /// length framing on this protocol. TCP does not guarantee either;
    /// confirm this is acceptable for the deployment, or add framing.
    async fn handle_worker_connection(
        stream: TcpStream,
        _addr: SocketAddr,
        workers: Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: Arc<AtomicU64>,
        config: GroupConfig,
    ) -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let (mut reader, mut writer) = stream.into_split();
        let mut buf = vec![0u8; 4096];

        loop {
            let n = reader.read(&mut buf).await?;
            // 0 bytes read means the peer closed the connection.
            if n == 0 {
                break;
            }

            let msg_str = std::str::from_utf8(&buf[..n])
                .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

            let response = Self::process_message(msg_str, &workers, &generation, &config).await;

            // Responses go out as Debug text; the worker-side parse_*
            // helpers scrape fields back out of this representation.
            let response_bytes = format!("{:?}", response);
            writer.write_all(response_bytes.as_bytes()).await?;
        }

        Ok(())
    }

    /// Dispatches one text command ("JOIN:<id>", "LEAVE:<id>",
    /// "HEARTBEAT:<id>", "GET_ASSIGNMENTS:<id>") to its handler.
    /// Unrecognized input yields `CoordinatorResponse::Error`.
    async fn process_message(
        msg: &str,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        config: &GroupConfig,
    ) -> CoordinatorResponse {
        if msg.starts_with("JOIN:") {
            let worker_id = msg
                .strip_prefix("JOIN:")
                .unwrap_or("unknown")
                .trim()
                .to_string();
            Self::handle_join(worker_id, workers, generation, config).await
        } else if msg.starts_with("LEAVE:") {
            let worker_id = msg
                .strip_prefix("LEAVE:")
                .unwrap_or("unknown")
                .trim()
                .to_string();
            Self::handle_leave(worker_id, workers, generation, config).await
        } else if msg.starts_with("HEARTBEAT:") {
            let worker_id = msg
                .strip_prefix("HEARTBEAT:")
                .unwrap_or("unknown")
                .trim()
                .to_string();
            Self::handle_heartbeat(worker_id, workers, generation).await
        } else if msg.starts_with("GET_ASSIGNMENTS:") {
            let worker_id = msg
                .strip_prefix("GET_ASSIGNMENTS:")
                .unwrap_or("unknown")
                .trim()
                .to_string();
            Self::handle_get_assignments(worker_id, workers, generation).await
        } else {
            CoordinatorResponse::Error {
                message: format!("Unknown message: {}", msg),
            }
        }
    }

    /// Registers (or re-registers) a worker, bumps the generation, and
    /// triggers a full rebalance. Replies with the topics assigned to the
    /// joining worker under the new generation.
    async fn handle_join(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        config: &GroupConfig,
    ) -> CoordinatorResponse {
        // fetch_add returns the previous value, so the new generation is +1.
        let new_gen = generation.fetch_add(1, Ordering::SeqCst) + 1;

        {
            // Scope the write lock: `rebalance` below takes it again.
            let mut workers_lock = workers.write().await;
            workers_lock.insert(
                worker_id.clone(),
                WorkerState {
                    worker_id: worker_id.clone(),
                    last_heartbeat: Instant::now(),
                    assignments: Vec::new(),
                    generation: new_gen,
                },
            );
        }

        let _assignments = Self::rebalance(workers, config).await;

        // Re-read this worker's slice of the rebalanced assignments.
        let worker_assignments = {
            let workers_lock = workers.read().await;
            workers_lock
                .get(&worker_id)
                .map(|w| w.assignments.clone())
                .unwrap_or_default()
        };

        CoordinatorResponse::JoinAccepted {
            worker_id,
            generation: new_gen,
            assignments: worker_assignments,
        }
    }

    /// Removes a worker, bumps the generation, and rebalances the
    /// remaining members. Unknown worker ids are acknowledged silently.
    async fn handle_leave(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        config: &GroupConfig,
    ) -> CoordinatorResponse {
        {
            let mut workers_lock = workers.write().await;
            workers_lock.remove(&worker_id);
        }

        generation.fetch_add(1, Ordering::SeqCst);

        Self::rebalance(workers, config).await;

        CoordinatorResponse::LeaveAcknowledged
    }

    /// Refreshes the worker's heartbeat timestamp. Workers that were
    /// evicted (or never joined) get an error, prompting them to re-join.
    async fn handle_heartbeat(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
    ) -> CoordinatorResponse {
        let current_gen = generation.load(Ordering::Relaxed);

        let mut workers_lock = workers.write().await;
        if let Some(worker) = workers_lock.get_mut(&worker_id) {
            worker.last_heartbeat = Instant::now();
            CoordinatorResponse::HeartbeatAck {
                generation: current_gen,
            }
        } else {
            CoordinatorResponse::Error {
                message: "Worker not found".to_string(),
            }
        }
    }

    /// Returns the worker's current assignments and the group generation.
    async fn handle_get_assignments(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
    ) -> CoordinatorResponse {
        let current_gen = generation.load(Ordering::Relaxed);
        let workers_lock = workers.read().await;

        if let Some(worker) = workers_lock.get(&worker_id) {
            CoordinatorResponse::Assignments {
                generation: current_gen,
                assignments: worker.assignments.clone(),
            }
        } else {
            CoordinatorResponse::Error {
                message: "Worker not found".to_string(),
            }
        }
    }

    /// Evicts every worker whose last heartbeat is older than
    /// `session_timeout` and bumps the generation if anything changed.
    ///
    /// Scans under the read lock first and only takes the write lock when
    /// there is actually something to remove, keeping write contention low.
    /// Note that eviction alone does not re-run `rebalance`; orphaned
    /// topics are redistributed on the next join/leave.
    async fn check_worker_health(
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        session_timeout: Duration,
    ) {
        let now = Instant::now();
        let mut dead_workers = Vec::new();

        {
            let workers_lock = workers.read().await;
            for (id, state) in workers_lock.iter() {
                if now.duration_since(state.last_heartbeat) > session_timeout {
                    dead_workers.push(id.clone());
                }
            }
        }

        if !dead_workers.is_empty() {
            let mut workers_lock = workers.write().await;
            for id in dead_workers {
                tracing::warn!("Worker {} timed out, removing from group", id);
                workers_lock.remove(&id);
            }
            generation.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Recomputes topic assignments for the current membership using the
    /// configured strategy and stores them on each `WorkerState`. Returns
    /// the full worker→topics map (empty if there are no workers).
    async fn rebalance(
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        config: &GroupConfig,
    ) -> HashMap<String, Vec<u32>> {
        let mut workers_lock = workers.write().await;
        let worker_ids: Vec<String> = workers_lock.keys().cloned().collect();

        // Guard: the assignment helpers below index/divide by the worker
        // count and would panic on an empty membership.
        if worker_ids.is_empty() {
            return HashMap::new();
        }

        let assignments = match config.assignment_strategy {
            AssignmentStrategy::RoundRobin => Self::assign_round_robin(&config.topics, &worker_ids),
            AssignmentStrategy::Range => Self::assign_range(&config.topics, &worker_ids),
            AssignmentStrategy::Sticky => {
                // Sticky needs the previous assignments to preserve them.
                let existing: HashMap<String, Vec<u32>> = workers_lock
                    .iter()
                    .map(|(id, state)| (id.clone(), state.assignments.clone()))
                    .collect();
                Self::assign_sticky(&config.topics, &worker_ids, &existing)
            },
        };

        // Persist the new assignment on each worker's state record.
        for (worker_id, topics) in &assignments {
            if let Some(worker) = workers_lock.get_mut(worker_id) {
                worker.assignments = topics.clone();
            }
        }

        assignments
    }

    /// Deals topic `i` to worker `i % workers.len()`.
    ///
    /// Precondition: `workers` must be non-empty when `topics` is non-empty
    /// (otherwise the modulo index panics); `rebalance` guarantees this.
    fn assign_round_robin(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
        let mut assignments: HashMap<String, Vec<u32>> =
            workers.iter().map(|w| (w.clone(), Vec::new())).collect();

        for (i, topic) in topics.iter().enumerate() {
            let worker = &workers[i % workers.len()];
            if let Some(topics) = assignments.get_mut(worker) {
                topics.push(*topic);
            }
        }

        assignments
    }

    /// Gives each worker a contiguous run of `len/n` topics, with the first
    /// `len % n` workers receiving one extra topic each.
    ///
    /// Precondition: `workers` must be non-empty (division by worker
    /// count); `rebalance` guarantees this.
    fn assign_range(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
        let mut assignments: HashMap<String, Vec<u32>> =
            workers.iter().map(|w| (w.clone(), Vec::new())).collect();

        let topics_per_worker = topics.len() / workers.len();
        let remainder = topics.len() % workers.len();

        let mut topic_idx = 0;
        for (worker_idx, worker) in workers.iter().enumerate() {
            // Earlier workers absorb the remainder, one topic each.
            let extra = if worker_idx < remainder { 1 } else { 0 };
            let count = topics_per_worker + extra;

            if let Some(worker_topics) = assignments.get_mut(worker) {
                for _ in 0..count {
                    if topic_idx < topics.len() {
                        worker_topics.push(topics[topic_idx]);
                        topic_idx += 1;
                    }
                }
            }
        }

        assignments
    }

    /// Preserves each surviving worker's previous topics (restricted to the
    /// current topic set), then spreads the remaining topics onto whichever
    /// worker currently has the fewest.
    ///
    /// NOTE(review): both the `existing` iteration and the `min_by_key`
    /// tie-break follow HashMap iteration order, so placement of new topics
    /// is not deterministic across runs — confirm callers don't rely on a
    /// stable assignment for equal-load workers.
    fn assign_sticky(
        topics: &[u32],
        workers: &[String],
        existing: &HashMap<String, Vec<u32>>,
    ) -> HashMap<String, Vec<u32>> {
        let mut assignments: HashMap<String, Vec<u32>> =
            workers.iter().map(|w| (w.clone(), Vec::new())).collect();

        let topic_set: HashSet<u32> = topics.iter().copied().collect();
        let mut assigned: HashSet<u32> = HashSet::new();

        // Phase 1: keep prior pairings for workers that are still members,
        // dropping topics that no longer exist and de-duplicating.
        for (worker, old_topics) in existing {
            if assignments.contains_key(worker) {
                for topic in old_topics {
                    if topic_set.contains(topic) && !assigned.contains(topic) {
                        if let Some(worker_topics) = assignments.get_mut(worker) {
                            worker_topics.push(*topic);
                            assigned.insert(*topic);
                        }
                    }
                }
            }
        }

        // Phase 2: place every still-unassigned topic on the least-loaded
        // worker at that moment.
        let unassigned: Vec<u32> = topics
            .iter()
            .filter(|t| !assigned.contains(t))
            .copied()
            .collect();

        for topic in unassigned {
            let min_worker = assignments
                .iter()
                .min_by_key(|(_, topics)| topics.len())
                .map(|(w, _)| w.clone());

            if let Some(worker) = min_worker {
                if let Some(worker_topics) = assignments.get_mut(&worker) {
                    worker_topics.push(topic);
                }
            }
        }

        assignments
    }
}
718
impl Drop for GroupCoordinator {
    fn drop(&mut self) {
        // Best-effort shutdown of the background tasks if the coordinator
        // is dropped without an explicit `stop()` call.
        self.stop();
    }
}
724
/// Configuration for a single group worker (one `GroupedConsumer`).
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    // Identifier this worker joins the group with.
    pub worker_id: String,
    // Maximum bytes fetched per request, forwarded to each per-topic consumer.
    pub max_fetch_bytes: u32,
    // How often the worker intends to heartbeat the coordinator.
    pub heartbeat_interval: Duration,
    // Directory for persisted offsets; `None` keeps offsets in memory.
    pub offset_dir: Option<PathBuf>,
    // Where consumption starts for newly assigned topics.
    pub start_position: SeekPosition,
}
739
740impl WorkerConfig {
741 pub fn new(worker_id: impl Into<String>) -> Self {
743 Self {
744 worker_id: worker_id.into(),
745 max_fetch_bytes: 1_048_576,
746 heartbeat_interval: Duration::from_secs(3),
747 offset_dir: None,
748 start_position: SeekPosition::Beginning,
749 }
750 }
751
752 pub fn with_max_fetch_bytes(mut self, bytes: u32) -> Self {
754 self.max_fetch_bytes = bytes;
755 self
756 }
757
758 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
760 self.heartbeat_interval = interval;
761 self
762 }
763
764 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
766 self.offset_dir = Some(dir.to_path_buf());
767 self
768 }
769
770 pub fn with_start_position(mut self, position: SeekPosition) -> Self {
772 self.start_position = position;
773 self
774 }
775}
776
/// A group member: joins the coordinator over TCP, then runs one
/// `StandaloneConsumer` per assigned topic. Created via
/// [`GroupedConsumer::join`].
pub struct GroupedConsumer {
    config: WorkerConfig,
    // Broker address, kept for (re)connecting per-topic consumers.
    server_addr: String,
    // Coordinator address; a fresh TCP connection is opened per request.
    coordinator_addr: SocketAddr,
    // Topic ids currently assigned by the coordinator.
    assignments: Vec<u32>,
    // Group generation this worker last synchronized with.
    generation: u64,
    // One standalone consumer per assigned topic id.
    consumers: HashMap<u32, StandaloneConsumer>,
    #[allow(dead_code)]
    offset_store: Arc<dyn OffsetStore>,
    // Cleared by `leave()`.
    running: bool,
}
789
impl GroupedConsumer {
    /// Joins the group at `coordinator_addr` and opens one standalone
    /// consumer per assigned topic against `server_addr`.
    ///
    /// NOTE(review): the coordinator reply is consumed with a single
    /// `read` into a 4 KiB buffer — the protocol has no length framing, so
    /// this assumes the whole response arrives in one segment and fits the
    /// buffer; confirm for large assignment lists.
    ///
    /// # Errors
    /// Fails if the coordinator is unreachable, the response cannot be
    /// parsed, the offset store cannot be opened, or any per-topic
    /// consumer fails to connect.
    pub async fn join(
        server_addr: &str,
        coordinator_addr: SocketAddr,
        config: WorkerConfig,
    ) -> Result<Self> {
        let mut stream = TcpStream::connect(coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let join_msg = format!("JOIN:{}", config.worker_id);
        stream.write_all(join_msg.as_bytes()).await?;

        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let (generation, assignments) = Self::parse_join_response(response)?;

        // File-backed offsets when a directory is configured, otherwise
        // in-memory (offsets lost on restart).
        let offset_store: Arc<dyn OffsetStore> = if let Some(ref dir) = config.offset_dir {
            Arc::new(LockFileOffsetStore::open(dir, &config.worker_id)?)
        } else {
            Arc::new(MemoryOffsetStore::new())
        };

        // One standalone consumer per assigned topic.
        let mut consumers = HashMap::new();
        for topic_id in &assignments {
            let standalone_config = StandaloneConfig::new(&config.worker_id, *topic_id)
                .with_max_fetch_bytes(config.max_fetch_bytes)
                .with_start_position(config.start_position);

            if let Some(ref dir) = config.offset_dir {
                let consumer = StandaloneConsumer::connect(
                    server_addr,
                    standalone_config.with_offset_dir(dir),
                )
                .await?;
                consumers.insert(*topic_id, consumer);
            } else {
                let consumer = StandaloneConsumer::connect(server_addr, standalone_config).await?;
                consumers.insert(*topic_id, consumer);
            }
        }

        Ok(Self {
            config,
            server_addr: server_addr.to_string(),
            coordinator_addr,
            assignments,
            generation,
            consumers,
            offset_store,
            running: true,
        })
    }

    /// Topic ids currently assigned to this worker.
    pub fn assignments(&self) -> &[u32] {
        &self.assignments
    }

    /// Group generation this worker last synchronized with.
    pub fn generation(&self) -> u64 {
        self.generation
    }

    /// Fetches the next batch from one assigned topic.
    ///
    /// # Errors
    /// Returns `InvalidResponse` if `topic_id` is not assigned to this
    /// worker, or propagates the underlying consumer error.
    pub async fn next_batch(&mut self, topic_id: u32) -> Result<Option<PollResult>> {
        if let Some(consumer) = self.consumers.get_mut(&topic_id) {
            consumer.next_batch().await
        } else {
            Err(ClientError::InvalidResponse(format!(
                "Topic {} not assigned to this worker",
                topic_id
            )))
        }
    }

    /// Alias for [`Self::next_batch`].
    #[inline]
    pub async fn consume(&mut self, topic_id: u32) -> Result<Option<PollResult>> {
        self.next_batch(topic_id).await
    }

    /// Alias for [`Self::next_batch`].
    #[inline]
    pub async fn poll(&mut self, topic_id: u32) -> Result<Option<PollResult>> {
        self.next_batch(topic_id).await
    }

    /// Commits the current offset for one assigned topic.
    ///
    /// # Errors
    /// Returns `InvalidResponse` if `topic_id` is not assigned to this
    /// worker, or propagates the underlying commit error.
    pub async fn commit(&mut self, topic_id: u32) -> Result<()> {
        if let Some(consumer) = self.consumers.get_mut(&topic_id) {
            consumer.commit().await
        } else {
            Err(ClientError::InvalidResponse(format!(
                "Topic {} not assigned to this worker",
                topic_id
            )))
        }
    }

    /// Commits offsets for every assigned topic, stopping at the first error.
    pub async fn commit_all(&mut self) -> Result<()> {
        for consumer in self.consumers.values_mut() {
            consumer.commit().await?;
        }
        Ok(())
    }

    /// Sends one heartbeat to the coordinator (over a fresh connection).
    /// If the acknowledged generation is newer than ours, assignments are
    /// refreshed before returning. Returns the (possibly updated)
    /// generation.
    pub async fn heartbeat(&mut self) -> Result<u64> {
        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let msg = format!("HEARTBEAT:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        let mut buf = vec![0u8; 1024];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let new_gen = Self::parse_heartbeat_response(response)?;

        // A newer generation means the group rebalanced: resync.
        if new_gen > self.generation {
            self.refresh_assignments().await?;
        }

        Ok(self.generation)
    }

    /// Re-fetches assignments from the coordinator and reconciles local
    /// consumers: closes consumers for revoked topics, connects consumers
    /// for newly granted ones, and adopts the new generation.
    async fn refresh_assignments(&mut self) -> Result<()> {
        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let msg = format!("GET_ASSIGNMENTS:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let (generation, new_assignments) = Self::parse_assignments_response(response)?;

        let old_set: HashSet<u32> = self.assignments.iter().copied().collect();
        let new_set: HashSet<u32> = new_assignments.iter().copied().collect();

        // Revoked topics: close their consumers (best-effort).
        for topic_id in old_set.difference(&new_set) {
            if let Some(consumer) = self.consumers.remove(topic_id) {
                let _ = consumer.close().await;
            }
        }

        // Newly granted topics: connect a fresh consumer each.
        for topic_id in new_set.difference(&old_set) {
            let standalone_config = StandaloneConfig::new(&self.config.worker_id, *topic_id)
                .with_max_fetch_bytes(self.config.max_fetch_bytes)
                .with_start_position(self.config.start_position);

            let consumer = if let Some(ref dir) = self.config.offset_dir {
                StandaloneConsumer::connect(
                    &self.server_addr,
                    standalone_config.with_offset_dir(dir),
                )
                .await?
            } else {
                StandaloneConsumer::connect(&self.server_addr, standalone_config).await?
            };

            self.consumers.insert(*topic_id, consumer);
        }

        self.assignments = new_assignments;
        self.generation = generation;

        Ok(())
    }

    /// Gracefully exits the group: commits all offsets, notifies the
    /// coordinator, and closes every per-topic consumer. Consumes `self`.
    ///
    /// # Errors
    /// Fails if committing offsets or contacting the coordinator fails;
    /// consumer close errors are ignored (best-effort).
    pub async fn leave(mut self) -> Result<()> {
        self.commit_all().await?;

        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::AsyncWriteExt;

        let msg = format!("LEAVE:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        for (_, consumer) in self.consumers.drain() {
            let _ = consumer.close().await;
        }

        self.running = false;
        Ok(())
    }

    /// Extracts `(generation, assignments)` from the coordinator's
    /// Debug-formatted reply, e.g.
    /// `JoinAccepted { worker_id: "w", generation: 3, assignments: [1, 2] }`.
    /// Missing fields default to generation 0 / empty assignments.
    fn parse_join_response(response: &str) -> Result<(u64, Vec<u32>)> {
        let generation = response
            .find("generation: ")
            .and_then(|i| {
                // 12 == "generation: ".len(); value runs up to the next ','.
                let start = i + 12;
                let end = response[start..].find(',')?;
                response[start..start + end].parse().ok()
            })
            .unwrap_or(0);

        let assignments = response
            .find("assignments: [")
            .map(|i| {
                // 14 == "assignments: [".len(); list runs up to the ']'.
                let start = i + 14;
                let end = response[start..].find(']').unwrap_or(0);
                response[start..start + end]
                    .split(',')
                    .filter_map(|s| s.trim().parse().ok())
                    .collect()
            })
            .unwrap_or_default();

        Ok((generation, assignments))
    }

    /// Extracts the generation from a `HeartbeatAck { generation: N }`
    /// Debug-formatted reply.
    ///
    /// # Errors
    /// Returns `ProtocolError` when no parsable generation is found
    /// (including coordinator `Error { .. }` replies).
    fn parse_heartbeat_response(response: &str) -> Result<u64> {
        response
            .find("generation: ")
            .and_then(|i| {
                // 12 == "generation: ".len(); value ends at ',', ' ' or '}'.
                let start = i + 12;
                let end = response[start..]
                    .find([',', ' ', '}'])
                    .unwrap_or(response.len() - start);
                response[start..start + end].parse().ok()
            })
            .ok_or_else(|| ClientError::ProtocolError("Invalid heartbeat response".to_string()))
    }

    /// `Assignments { .. }` replies share the field layout of
    /// `JoinAccepted { .. }`, so the same parser applies.
    fn parse_assignments_response(response: &str) -> Result<(u64, Vec<u32>)> {
        Self::parse_join_response(response)
    }
}
1063
1064impl std::fmt::Debug for GroupedConsumer {
1065 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1066 f.debug_struct("GroupedConsumer")
1067 .field("worker_id", &self.config.worker_id)
1068 .field("generation", &self.generation)
1069 .field("assignments", &self.assignments)
1070 .field("running", &self.running)
1071 .finish()
1072 }
1073}
1074
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Defaults advertised by `GroupConfig::new`.
    #[test]
    fn test_group_config_defaults() {
        let config = GroupConfig::new("test-group");

        assert_eq!(config.group_id, "test-group");
        assert!(config.topics.is_empty());
        assert_eq!(config.assignment_strategy, AssignmentStrategy::RoundRobin);
    }

    // Defaults advertised by `WorkerConfig::new` (1 MiB fetch size).
    #[test]
    fn test_worker_config_defaults() {
        let config = WorkerConfig::new("worker-1");

        assert_eq!(config.worker_id, "worker-1");
        assert_eq!(config.max_fetch_bytes, 1_048_576);
    }

    // Round-robin deals topic i to worker i % n, preserving topic order.
    #[test]
    fn test_round_robin_assignment() {
        let topics = vec![1, 2, 3, 4, 5, 6];
        let workers = vec!["w1".to_string(), "w2".to_string(), "w3".to_string()];

        let assignments = GroupCoordinator::assign_round_robin(&topics, &workers);

        assert_eq!(assignments.get("w1"), Some(&vec![1, 4]));
        assert_eq!(assignments.get("w2"), Some(&vec![2, 5]));
        assert_eq!(assignments.get("w3"), Some(&vec![3, 6]));
    }

    // Range assignment: 7 topics over 3 workers → 3/2/2, remainder to the
    // earliest workers.
    #[test]
    fn test_range_assignment() {
        let topics = vec![1, 2, 3, 4, 5, 6, 7];
        let workers = vec!["w1".to_string(), "w2".to_string(), "w3".to_string()];

        let assignments = GroupCoordinator::assign_range(&topics, &workers);

        assert_eq!(assignments.get("w1").map(|v| v.len()), Some(3));
        assert_eq!(assignments.get("w2").map(|v| v.len()), Some(2));
        assert_eq!(assignments.get("w3").map(|v| v.len()), Some(2));
    }

    // Sticky assignment keeps prior pairings when membership is unchanged.
    #[test]
    fn test_sticky_assignment_preserves_existing() {
        let topics = vec![1, 2, 3, 4];
        let workers = vec!["w1".to_string(), "w2".to_string()];

        let mut existing = HashMap::new();
        existing.insert("w1".to_string(), vec![1, 2]);
        existing.insert("w2".to_string(), vec![3, 4]);

        let assignments = GroupCoordinator::assign_sticky(&topics, &workers, &existing);

        assert_eq!(assignments.get("w1"), Some(&vec![1, 2]));
        assert_eq!(assignments.get("w2"), Some(&vec![3, 4]));
    }

    // `#[default]` attribute on the enum must pick RoundRobin.
    #[test]
    fn test_assignment_strategy_default() {
        assert_eq!(
            AssignmentStrategy::default(),
            AssignmentStrategy::RoundRobin
        );
    }
}