1use std::collections::{HashMap, HashSet};
71use std::net::SocketAddr;
72use std::path::{Path, PathBuf};
73use std::sync::Arc;
74use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
75use std::time::{Duration, Instant};
76
77use tokio::net::{TcpListener, TcpStream};
78use tokio::sync::{RwLock, broadcast};
79
80use crate::client::{ClientConfig, LanceClient};
81use crate::consumer::{PollResult, SeekPosition};
82use crate::error::{ClientError, Result};
83use crate::offset::{LockFileOffsetStore, MemoryOffsetStore, OffsetStore};
84use crate::standalone::{StandaloneConfig, StandaloneConsumer};
85
/// Strategy used by the coordinator to spread topics across group workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AssignmentStrategy {
    /// Deal topics out one at a time across workers (default).
    #[default]
    RoundRobin,
    /// Give each worker a contiguous range of topics.
    Range,
    /// Keep a worker's previous topics where possible to minimize churn.
    Sticky,
}
97
/// Configuration for a consumer group coordinator.
#[derive(Debug, Clone)]
pub struct GroupConfig {
    /// Logical name of the consumer group.
    pub group_id: String,
    /// Topic ids to distribute; if empty, `GroupCoordinator::start`
    /// discovers all topics from the server.
    pub topics: Vec<u32>,
    /// How topics are divided among workers.
    pub assignment_strategy: AssignmentStrategy,
    /// Expected interval between worker heartbeats.
    pub heartbeat_interval: Duration,
    /// A worker that misses heartbeats for this long is evicted.
    pub session_timeout: Duration,
    /// Address the coordinator listens on for worker connections.
    pub coordinator_addr: SocketAddr,
    /// Address of the message server (used for topic discovery).
    pub server_addr: String,
    /// Optional directory for persistent offset storage.
    pub offset_dir: Option<PathBuf>,
}
118
119impl GroupConfig {
120 pub fn new(group_id: impl Into<String>) -> Self {
122 Self {
123 group_id: group_id.into(),
124 topics: Vec::new(),
125 assignment_strategy: AssignmentStrategy::RoundRobin,
126 heartbeat_interval: Duration::from_secs(3),
127 session_timeout: Duration::from_secs(30),
128 coordinator_addr: "127.0.0.1:19920"
129 .parse()
130 .unwrap_or_else(|_| SocketAddr::from(([127, 0, 0, 1], 19920))),
131 server_addr: "127.0.0.1:1992".to_string(),
132 offset_dir: None,
133 }
134 }
135
136 pub fn with_topics(mut self, topics: Vec<u32>) -> Self {
138 self.topics = topics;
139 self
140 }
141
142 pub fn with_assignment_strategy(mut self, strategy: AssignmentStrategy) -> Self {
144 self.assignment_strategy = strategy;
145 self
146 }
147
148 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
150 self.heartbeat_interval = interval;
151 self
152 }
153
154 pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
156 self.session_timeout = timeout;
157 self
158 }
159
160 pub fn with_coordinator_addr(mut self, addr: SocketAddr) -> Self {
162 self.coordinator_addr = addr;
163 self
164 }
165
166 pub fn with_server_addr(mut self, addr: impl Into<String>) -> Self {
168 self.server_addr = addr.into();
169 self
170 }
171
172 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
174 self.offset_dir = Some(dir.to_path_buf());
175 self
176 }
177}
178
/// Requests a worker can send to the coordinator.
/// NOTE(review): currently unused (the wire format is plain
/// "VERB:worker_id" text, parsed in `process_message`); kept as the
/// intended typed protocol.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorMessage {
    /// Join the group and receive an assignment.
    Join { worker_id: String },
    /// Leave the group, triggering a rebalance.
    Leave { worker_id: String },
    /// Liveness signal; resets the session timeout.
    Heartbeat { worker_id: String },
    /// Fetch the worker's current topic assignment.
    GetAssignments { worker_id: String },
}
192
/// Replies from the coordinator. Serialized to the wire via `{:?}`
/// (`Debug`) in `handle_worker_connection` and parsed back by the
/// `parse_*_response` helpers in `GroupedConsumer`, so variant and
/// field names are effectively part of the wire format.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorResponse {
    /// Join succeeded; carries the new generation and initial topics.
    JoinAccepted {
        worker_id: String,
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Leave processed.
    LeaveAcknowledged,
    /// Heartbeat accepted; reports the current group generation.
    HeartbeatAck { generation: u64 },
    /// Current assignment for the requesting worker.
    Assignments {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Pushed assignment after a rebalance (variant currently unsent).
    Rebalance {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Request failed; human-readable reason.
    Error { message: String },
}
220
/// Coordinator-side bookkeeping for one group member.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct WorkerState {
    /// Identifier the worker joined with.
    worker_id: String,
    /// Last heartbeat time; compared against the session timeout.
    last_heartbeat: Instant,
    /// Topic ids currently assigned to this worker.
    assignments: Vec<u32>,
    /// Group generation at which this worker joined.
    generation: u64,
}
230
/// Coordinates a consumer group: accepts worker connections over TCP,
/// tracks liveness via heartbeats, and rebalances topic assignments
/// when membership changes. Shared state is `Arc`-cloned into the
/// accept-loop and health-check tasks spawned by `start`.
pub struct GroupCoordinator {
    #[allow(dead_code)]
    config: GroupConfig,
    /// Members keyed by worker id; shared with background tasks.
    workers: Arc<RwLock<HashMap<String, WorkerState>>>,
    /// Monotonic group generation, bumped on every membership change.
    generation: Arc<AtomicU64>,
    /// Cleared by `stop` so the health-check loop exits.
    running: Arc<AtomicBool>,
    /// Broadcast used to signal both background tasks to shut down.
    shutdown_tx: broadcast::Sender<()>,
    /// Actual bound address (resolves port 0 to the real port).
    join_addr: SocketAddr,
}
247
248impl GroupCoordinator {
    /// Starts a coordinator: binds the listener, spawns the accept loop
    /// and the once-per-second worker health check, and returns a handle.
    ///
    /// `server_addr` overrides `config.server_addr`. If `config.topics`
    /// is empty the topic list is discovered from the server.
    ///
    /// # Errors
    /// Fails if topic discovery or binding `config.coordinator_addr` fails.
    pub async fn start(server_addr: &str, mut config: GroupConfig) -> Result<Self> {
        config.server_addr = server_addr.to_string();

        // No explicit topics: ask the server for all of them.
        if config.topics.is_empty() {
            let client_config = ClientConfig::new(&config.server_addr);
            let mut client = LanceClient::connect(client_config).await?;
            let topics = client.list_topics().await?;
            config.topics = topics.iter().map(|t| t.id).collect();
        }

        let workers = Arc::new(RwLock::new(HashMap::new()));
        let generation = Arc::new(AtomicU64::new(0));
        let running = Arc::new(AtomicBool::new(true));
        let (shutdown_tx, _) = broadcast::channel(1);

        let listener = TcpListener::bind(config.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        // Record the real bound address (important when port 0 was used).
        let join_addr = listener.local_addr().map_err(ClientError::IoError)?;

        let coordinator = Self {
            config: config.clone(),
            workers: workers.clone(),
            generation: generation.clone(),
            running: running.clone(),
            shutdown_tx: shutdown_tx.clone(),
            join_addr,
        };

        let workers_clone = workers.clone();
        let generation_clone = generation.clone();
        let _running_clone = running.clone();
        let config_clone = config.clone();
        let mut shutdown_rx = shutdown_tx.subscribe();

        // Task 1: accept loop — one spawned handler per worker connection.
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((stream, addr)) => {
                                let workers = workers_clone.clone();
                                let generation = generation_clone.clone();
                                let config = config_clone.clone();
                                tokio::spawn(async move {
                                    if let Err(e) = Self::handle_worker_connection(
                                        stream, addr, workers, generation, config
                                    ).await {
                                        tracing::warn!("Worker connection error: {}", e);
                                    }
                                });
                            }
                            Err(e) => {
                                tracing::error!("Accept error: {}", e);
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::info!("Coordinator shutting down");
                        break;
                    }
                }
            }
        });

        let workers_checker = workers.clone();
        let generation_checker = generation.clone();
        let running_checker = running.clone();
        let session_timeout = config.session_timeout;
        let mut shutdown_rx2 = shutdown_tx.subscribe();

        // Task 2: health check — evicts timed-out workers every second.
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(1));
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        if !running_checker.load(Ordering::Relaxed) {
                            break;
                        }
                        Self::check_worker_health(
                            &workers_checker,
                            &generation_checker,
                            session_timeout,
                        ).await;
                    }
                    _ = shutdown_rx2.recv() => {
                        break;
                    }
                }
            }
        });

        Ok(coordinator)
    }
349
    /// Address workers should connect to (the actual bound address).
    pub fn join_address(&self) -> SocketAddr {
        self.join_addr
    }
354
    /// Current group generation (bumped on every membership change).
    pub fn generation(&self) -> u64 {
        self.generation.load(Ordering::Relaxed)
    }
359
    /// Number of workers currently registered in the group.
    pub async fn worker_count(&self) -> usize {
        self.workers.read().await.len()
    }
364
365 pub async fn get_assignments(&self) -> HashMap<String, Vec<u32>> {
367 self.workers
368 .read()
369 .await
370 .iter()
371 .map(|(id, state)| (id.clone(), state.assignments.clone()))
372 .collect()
373 }
374
    /// Signals both background tasks (accept loop and health checker)
    /// to exit. Idempotent; also invoked from `Drop`.
    pub fn stop(&self) {
        self.running.store(false, Ordering::Relaxed);
        // Ignore the error: it only means no task is subscribed anymore.
        let _ = self.shutdown_tx.send(());
    }
380
    /// Serves one worker connection: reads a request, dispatches it to
    /// `process_message`, and writes back the `Debug` rendering of the
    /// response. Loops until the peer closes the connection (read of 0).
    ///
    /// NOTE(review): there is no message framing — each `read` is assumed
    /// to deliver exactly one whole "VERB:worker_id" message (and requests
    /// are assumed to fit in 4096 bytes). TCP does not guarantee this;
    /// confirm against the client side, which sends one small message per
    /// connection.
    async fn handle_worker_connection(
        stream: TcpStream,
        _addr: SocketAddr,
        workers: Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: Arc<AtomicU64>,
        config: GroupConfig,
    ) -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let (mut reader, mut writer) = stream.into_split();
        let mut buf = vec![0u8; 4096];

        loop {
            let n = reader.read(&mut buf).await?;
            // Zero bytes means the worker closed its end.
            if n == 0 {
                break; }

            let msg_str = std::str::from_utf8(&buf[..n])
                .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

            let response = Self::process_message(msg_str, &workers, &generation, &config).await;

            // Responses are the Debug text of CoordinatorResponse; the
            // client parses fields back out of this string.
            let response_bytes = format!("{:?}", response);
            writer.write_all(response_bytes.as_bytes()).await?;
        }

        Ok(())
    }
416
417 async fn process_message(
419 msg: &str,
420 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
421 generation: &Arc<AtomicU64>,
422 config: &GroupConfig,
423 ) -> CoordinatorResponse {
424 if msg.starts_with("JOIN:") {
426 let worker_id = msg
427 .strip_prefix("JOIN:")
428 .unwrap_or("unknown")
429 .trim()
430 .to_string();
431 Self::handle_join(worker_id, workers, generation, config).await
432 } else if msg.starts_with("LEAVE:") {
433 let worker_id = msg
434 .strip_prefix("LEAVE:")
435 .unwrap_or("unknown")
436 .trim()
437 .to_string();
438 Self::handle_leave(worker_id, workers, generation, config).await
439 } else if msg.starts_with("HEARTBEAT:") {
440 let worker_id = msg
441 .strip_prefix("HEARTBEAT:")
442 .unwrap_or("unknown")
443 .trim()
444 .to_string();
445 Self::handle_heartbeat(worker_id, workers, generation).await
446 } else if msg.starts_with("GET_ASSIGNMENTS:") {
447 let worker_id = msg
448 .strip_prefix("GET_ASSIGNMENTS:")
449 .unwrap_or("unknown")
450 .trim()
451 .to_string();
452 Self::handle_get_assignments(worker_id, workers, generation).await
453 } else {
454 CoordinatorResponse::Error {
455 message: format!("Unknown message: {}", msg),
456 }
457 }
458 }
459
    /// Registers (or re-registers) a worker, bumps the generation,
    /// rebalances the whole group, and replies with the worker's share.
    async fn handle_join(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        config: &GroupConfig,
    ) -> CoordinatorResponse {
        // fetch_add returns the old value; +1 yields the new generation.
        let new_gen = generation.fetch_add(1, Ordering::SeqCst) + 1;

        {
            // Scoped so the write lock is released before rebalancing.
            let mut workers_lock = workers.write().await;
            workers_lock.insert(
                worker_id.clone(),
                WorkerState {
                    worker_id: worker_id.clone(),
                    last_heartbeat: Instant::now(),
                    assignments: Vec::new(),
                    generation: new_gen,
                },
            );
        }

        // Rebalance writes the per-worker assignments back into `workers`.
        let _assignments = Self::rebalance(workers, config).await;

        // Re-read this worker's share after the rebalance.
        let worker_assignments = {
            let workers_lock = workers.read().await;
            workers_lock
                .get(&worker_id)
                .map(|w| w.assignments.clone())
                .unwrap_or_default()
        };

        CoordinatorResponse::JoinAccepted {
            worker_id,
            generation: new_gen,
            assignments: worker_assignments,
        }
    }
500
    /// Removes a worker, bumps the generation, and redistributes its
    /// topics among the remaining members. Acknowledges even if the
    /// worker was not registered.
    async fn handle_leave(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
        config: &GroupConfig,
    ) -> CoordinatorResponse {
        {
            // Scoped write lock: released before rebalance re-locks.
            let mut workers_lock = workers.write().await;
            workers_lock.remove(&worker_id);
        }

        generation.fetch_add(1, Ordering::SeqCst);

        Self::rebalance(workers, config).await;

        CoordinatorResponse::LeaveAcknowledged
    }
520
    /// Refreshes a worker's liveness timestamp and reports the current
    /// generation (the client uses a generation bump as its cue to
    /// re-fetch assignments). Unknown workers — e.g. already evicted by
    /// the health checker — get an `Error` response.
    async fn handle_heartbeat(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
    ) -> CoordinatorResponse {
        let current_gen = generation.load(Ordering::Relaxed);

        let mut workers_lock = workers.write().await;
        if let Some(worker) = workers_lock.get_mut(&worker_id) {
            worker.last_heartbeat = Instant::now();
            CoordinatorResponse::HeartbeatAck {
                generation: current_gen,
            }
        } else {
            CoordinatorResponse::Error {
                message: "Worker not found".to_string(),
            }
        }
    }
541
    /// Returns the requesting worker's current topic list together with
    /// the generation it corresponds to; `Error` if the worker is not
    /// (or no longer) a member.
    async fn handle_get_assignments(
        worker_id: String,
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        generation: &Arc<AtomicU64>,
    ) -> CoordinatorResponse {
        let current_gen = generation.load(Ordering::Relaxed);
        let workers_lock = workers.read().await;

        if let Some(worker) = workers_lock.get(&worker_id) {
            CoordinatorResponse::Assignments {
                generation: current_gen,
                assignments: worker.assignments.clone(),
            }
        } else {
            CoordinatorResponse::Error {
                message: "Worker not found".to_string(),
            }
        }
    }
562
563 async fn check_worker_health(
565 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
566 generation: &Arc<AtomicU64>,
567 session_timeout: Duration,
568 ) {
569 let now = Instant::now();
570 let mut dead_workers = Vec::new();
571
572 {
573 let workers_lock = workers.read().await;
574 for (id, state) in workers_lock.iter() {
575 if now.duration_since(state.last_heartbeat) > session_timeout {
576 dead_workers.push(id.clone());
577 }
578 }
579 }
580
581 if !dead_workers.is_empty() {
582 let mut workers_lock = workers.write().await;
583 for id in dead_workers {
584 tracing::warn!("Worker {} timed out, removing from group", id);
585 workers_lock.remove(&id);
586 }
587 generation.fetch_add(1, Ordering::SeqCst);
588 }
590 }
591
    /// Recomputes the topic distribution for the current membership using
    /// the configured strategy, stores each worker's share in its
    /// `WorkerState`, and returns the full assignment map. Holds the
    /// write lock for the whole computation so membership cannot change
    /// mid-rebalance.
    async fn rebalance(
        workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
        config: &GroupConfig,
    ) -> HashMap<String, Vec<u32>> {
        let mut workers_lock = workers.write().await;
        let worker_ids: Vec<String> = workers_lock.keys().cloned().collect();

        // Nobody to assign to; the assign_* helpers also rely on this.
        if worker_ids.is_empty() {
            return HashMap::new();
        }

        let assignments = match config.assignment_strategy {
            AssignmentStrategy::RoundRobin => Self::assign_round_robin(&config.topics, &worker_ids),
            AssignmentStrategy::Range => Self::assign_range(&config.topics, &worker_ids),
            AssignmentStrategy::Sticky => {
                // Sticky needs the previous distribution to preserve it.
                let existing: HashMap<String, Vec<u32>> = workers_lock
                    .iter()
                    .map(|(id, state)| (id.clone(), state.assignments.clone()))
                    .collect();
                Self::assign_sticky(&config.topics, &worker_ids, &existing)
            },
        };

        // Persist each worker's new share back into its state.
        for (worker_id, topics) in &assignments {
            if let Some(worker) = workers_lock.get_mut(worker_id) {
                worker.assignments = topics.clone();
            }
        }

        assignments
    }
626
627 fn assign_round_robin(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
629 let mut assignments: HashMap<String, Vec<u32>> =
630 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
631
632 for (i, topic) in topics.iter().enumerate() {
633 let worker = &workers[i % workers.len()];
634 if let Some(topics) = assignments.get_mut(worker) {
635 topics.push(*topic);
636 }
637 }
638
639 assignments
640 }
641
642 fn assign_range(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
644 let mut assignments: HashMap<String, Vec<u32>> =
645 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
646
647 let topics_per_worker = topics.len() / workers.len();
648 let remainder = topics.len() % workers.len();
649
650 let mut topic_idx = 0;
651 for (worker_idx, worker) in workers.iter().enumerate() {
652 let extra = if worker_idx < remainder { 1 } else { 0 };
653 let count = topics_per_worker + extra;
654
655 if let Some(worker_topics) = assignments.get_mut(worker) {
656 for _ in 0..count {
657 if topic_idx < topics.len() {
658 worker_topics.push(topics[topic_idx]);
659 topic_idx += 1;
660 }
661 }
662 }
663 }
664
665 assignments
666 }
667
668 fn assign_sticky(
670 topics: &[u32],
671 workers: &[String],
672 existing: &HashMap<String, Vec<u32>>,
673 ) -> HashMap<String, Vec<u32>> {
674 let mut assignments: HashMap<String, Vec<u32>> =
675 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
676
677 let topic_set: HashSet<u32> = topics.iter().copied().collect();
678 let mut assigned: HashSet<u32> = HashSet::new();
679
680 for (worker, old_topics) in existing {
682 if assignments.contains_key(worker) {
683 for topic in old_topics {
684 if topic_set.contains(topic) && !assigned.contains(topic) {
685 if let Some(worker_topics) = assignments.get_mut(worker) {
686 worker_topics.push(*topic);
687 assigned.insert(*topic);
688 }
689 }
690 }
691 }
692 }
693
694 let unassigned: Vec<u32> = topics
696 .iter()
697 .filter(|t| !assigned.contains(t))
698 .copied()
699 .collect();
700
701 for topic in unassigned {
703 let min_worker = assignments
704 .iter()
705 .min_by_key(|(_, topics)| topics.len())
706 .map(|(w, _)| w.clone());
707
708 if let Some(worker) = min_worker {
709 if let Some(worker_topics) = assignments.get_mut(&worker) {
710 worker_topics.push(topic);
711 }
712 }
713 }
714
715 assignments
716 }
717}
718
/// Best-effort shutdown of the background tasks when the handle is
/// dropped without an explicit `stop()`.
impl Drop for GroupCoordinator {
    fn drop(&mut self) {
        self.stop();
    }
}
724
/// Configuration for a single group worker (`GroupedConsumer`).
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Unique id this worker joins and heartbeats with.
    pub worker_id: String,
    /// Per-poll fetch size cap, forwarded to each StandaloneConsumer.
    pub max_fetch_bytes: u32,
    /// How often the worker should call `heartbeat()`.
    /// NOTE(review): stored but not read in this file — presumably used
    /// by the caller to drive its heartbeat loop.
    pub heartbeat_interval: Duration,
    /// Optional directory for lock-file offset persistence; `None`
    /// selects in-memory offsets.
    pub offset_dir: Option<PathBuf>,
    /// Where newly assigned topic consumers start reading.
    pub start_position: SeekPosition,
}
739
740impl WorkerConfig {
741 pub fn new(worker_id: impl Into<String>) -> Self {
743 Self {
744 worker_id: worker_id.into(),
745 max_fetch_bytes: 1_048_576,
746 heartbeat_interval: Duration::from_secs(3),
747 offset_dir: None,
748 start_position: SeekPosition::Beginning,
749 }
750 }
751
752 pub fn with_max_fetch_bytes(mut self, bytes: u32) -> Self {
754 self.max_fetch_bytes = bytes;
755 self
756 }
757
758 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
760 self.heartbeat_interval = interval;
761 self
762 }
763
764 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
766 self.offset_dir = Some(dir.to_path_buf());
767 self
768 }
769
770 pub fn with_start_position(mut self, position: SeekPosition) -> Self {
772 self.start_position = position;
773 self
774 }
775}
776
/// A group member: joins a `GroupCoordinator`, maintains one
/// `StandaloneConsumer` per assigned topic, and follows rebalances
/// signaled through heartbeat generation bumps.
pub struct GroupedConsumer {
    config: WorkerConfig,
    /// Message-server address used when (re)creating topic consumers.
    server_addr: String,
    /// Coordinator address; a fresh TCP connection is opened per request.
    coordinator_addr: SocketAddr,
    /// Topic ids currently assigned to this worker.
    assignments: Vec<u32>,
    /// Group generation the current assignments correspond to.
    generation: u64,
    /// One underlying consumer per assigned topic id.
    consumers: HashMap<u32, StandaloneConsumer>,
    #[allow(dead_code)]
    offset_store: Arc<dyn OffsetStore>,
    running: bool,
}
789
790impl GroupedConsumer {
    /// Joins the group at `coordinator_addr`: sends `JOIN:<worker_id>`,
    /// parses the assigned generation and topic list from the reply, and
    /// connects one `StandaloneConsumer` per assigned topic.
    ///
    /// # Errors
    /// Fails on coordinator/server connection errors, a non-UTF-8 reply,
    /// or offset-store initialization failure.
    pub async fn join(
        server_addr: &str,
        coordinator_addr: SocketAddr,
        config: WorkerConfig,
    ) -> Result<Self> {
        let mut stream = TcpStream::connect(coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let join_msg = format!("JOIN:{}", config.worker_id);
        stream.write_all(join_msg.as_bytes()).await?;

        // Single read: the Debug-formatted reply is assumed to arrive in
        // one chunk and fit in 4 KiB (no framing on this protocol).
        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let (generation, assignments) = Self::parse_join_response(response)?;

        // Durable lock-file store when an offset dir is configured,
        // otherwise a volatile in-memory store.
        let offset_store: Arc<dyn OffsetStore> = if let Some(ref dir) = config.offset_dir {
            Arc::new(LockFileOffsetStore::open(dir, &config.worker_id)?)
        } else {
            Arc::new(MemoryOffsetStore::new())
        };

        // One StandaloneConsumer per assigned topic.
        let mut consumers = HashMap::new();
        for topic_id in &assignments {
            let standalone_config = StandaloneConfig::new(&config.worker_id, *topic_id)
                .with_max_fetch_bytes(config.max_fetch_bytes)
                .with_start_position(config.start_position);

            if let Some(ref dir) = config.offset_dir {
                let consumer = StandaloneConsumer::connect(
                    server_addr,
                    standalone_config.with_offset_dir(dir),
                )
                .await?;
                consumers.insert(*topic_id, consumer);
            } else {
                let consumer = StandaloneConsumer::connect(server_addr, standalone_config).await?;
                consumers.insert(*topic_id, consumer);
            }
        }

        Ok(Self {
            config,
            server_addr: server_addr.to_string(),
            coordinator_addr,
            assignments,
            generation,
            consumers,
            offset_store,
            running: true,
        })
    }
855
    /// Topic ids currently assigned to this worker.
    pub fn assignments(&self) -> &[u32] {
        &self.assignments
    }
860
    /// Group generation the current assignments belong to.
    pub fn generation(&self) -> u64 {
        self.generation
    }
865
    /// Polls one assigned topic for messages; errors if `topic_id` is
    /// not in this worker's current assignment.
    pub async fn poll(&mut self, topic_id: u32) -> Result<Option<PollResult>> {
        if let Some(consumer) = self.consumers.get_mut(&topic_id) {
            consumer.poll().await
        } else {
            Err(ClientError::InvalidResponse(format!(
                "Topic {} not assigned to this worker",
                topic_id
            )))
        }
    }
877
    /// Commits the consumed offset for one assigned topic; errors if
    /// `topic_id` is not in this worker's current assignment.
    pub async fn commit(&mut self, topic_id: u32) -> Result<()> {
        if let Some(consumer) = self.consumers.get_mut(&topic_id) {
            consumer.commit().await
        } else {
            Err(ClientError::InvalidResponse(format!(
                "Topic {} not assigned to this worker",
                topic_id
            )))
        }
    }
889
    /// Commits offsets for every assigned topic, stopping at the first
    /// failure (later topics are then left uncommitted).
    pub async fn commit_all(&mut self) -> Result<()> {
        for consumer in self.consumers.values_mut() {
            consumer.commit().await?;
        }
        Ok(())
    }
897
    /// Sends a heartbeat to the coordinator over a fresh connection.
    /// If the acknowledged generation is newer than ours, a rebalance
    /// happened: re-fetch assignments before returning. Returns the
    /// (possibly updated) local generation.
    pub async fn heartbeat(&mut self) -> Result<u64> {
        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let msg = format!("HEARTBEAT:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        let mut buf = vec![0u8; 1024];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let new_gen = Self::parse_heartbeat_response(response)?;

        // Generation moved forward: membership changed, sync assignments.
        if new_gen > self.generation {
            self.refresh_assignments().await?;
        }

        Ok(self.generation)
    }
924
    /// Fetches the current assignment from the coordinator and reconciles
    /// local consumers with it: closes consumers for revoked topics and
    /// connects new ones for newly granted topics, then records the new
    /// generation.
    async fn refresh_assignments(&mut self) -> Result<()> {
        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let msg = format!("GET_ASSIGNMENTS:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await?;
        let response = std::str::from_utf8(&buf[..n])
            .map_err(|e| ClientError::ProtocolError(e.to_string()))?;

        let (generation, new_assignments) = Self::parse_assignments_response(response)?;

        // Diff old vs new topic sets to find what to close and what to open.
        let old_set: HashSet<u32> = self.assignments.iter().copied().collect();
        let new_set: HashSet<u32> = new_assignments.iter().copied().collect();

        // Revoked: tear down consumers we no longer own (close errors ignored).
        for topic_id in old_set.difference(&new_set) {
            if let Some(consumer) = self.consumers.remove(topic_id) {
                let _ = consumer.close().await;
            }
        }

        // Granted: spin up a consumer per newly assigned topic.
        for topic_id in new_set.difference(&old_set) {
            let standalone_config = StandaloneConfig::new(&self.config.worker_id, *topic_id)
                .with_max_fetch_bytes(self.config.max_fetch_bytes)
                .with_start_position(self.config.start_position);

            let consumer = if let Some(ref dir) = self.config.offset_dir {
                StandaloneConsumer::connect(
                    &self.server_addr,
                    standalone_config.with_offset_dir(dir),
                )
                .await?
            } else {
                StandaloneConsumer::connect(&self.server_addr, standalone_config).await?
            };

            self.consumers.insert(*topic_id, consumer);
        }

        self.assignments = new_assignments;
        self.generation = generation;

        Ok(())
    }
978
    /// Gracefully leaves the group: commits all offsets first (failing
    /// early if that fails), notifies the coordinator, then closes every
    /// topic consumer. Consumes `self`; the coordinator's LEAVE reply is
    /// not awaited.
    pub async fn leave(mut self) -> Result<()> {
        self.commit_all().await?;

        let mut stream = TcpStream::connect(self.coordinator_addr)
            .await
            .map_err(ClientError::ConnectionFailed)?;

        use tokio::io::AsyncWriteExt;

        let msg = format!("LEAVE:{}", self.config.worker_id);
        stream.write_all(msg.as_bytes()).await?;

        // Best-effort close of each topic consumer; errors are ignored.
        for (_, consumer) in self.consumers.drain() {
            let _ = consumer.close().await;
        }

        self.running = false;
        Ok(())
    }
1002
1003 fn parse_join_response(response: &str) -> Result<(u64, Vec<u32>)> {
1005 let generation = response
1009 .find("generation: ")
1010 .and_then(|i| {
1011 let start = i + 12;
1012 let end = response[start..].find(',')?;
1013 response[start..start + end].parse().ok()
1014 })
1015 .unwrap_or(0);
1016
1017 let assignments = response
1018 .find("assignments: [")
1019 .map(|i| {
1020 let start = i + 14;
1021 let end = response[start..].find(']').unwrap_or(0);
1022 response[start..start + end]
1023 .split(',')
1024 .filter_map(|s| s.trim().parse().ok())
1025 .collect()
1026 })
1027 .unwrap_or_default();
1028
1029 Ok((generation, assignments))
1030 }
1031
    /// Extracts the generation from a Debug-formatted `HeartbeatAck`
    /// reply. The number ends at the first `,`, ` `, or `}` after the
    /// "generation: " marker (or at end of string).
    ///
    /// # Errors
    /// `ProtocolError` if the marker is absent or the value isn't a u64.
    fn parse_heartbeat_response(response: &str) -> Result<u64> {
        response
            .find("generation: ")
            .and_then(|i| {
                // 12 == "generation: ".len()
                let start = i + 12;
                let end = response[start..]
                    .find([',', ' ', '}'])
                    .unwrap_or(response.len() - start);
                response[start..start + end].parse().ok()
            })
            .ok_or_else(|| ClientError::ProtocolError("Invalid heartbeat response".to_string()))
    }
1045
    /// The `Assignments` reply uses the same field layout as
    /// `JoinAccepted`, so the join parser handles it unchanged.
    fn parse_assignments_response(response: &str) -> Result<(u64, Vec<u32>)> {
        Self::parse_join_response(response)
    }
1050}
1051
/// Manual Debug: `StandaloneConsumer` and `dyn OffsetStore` are not
/// Debug, so only the identifying/summary fields are printed.
impl std::fmt::Debug for GroupedConsumer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedConsumer")
            .field("worker_id", &self.config.worker_id)
            .field("generation", &self.generation)
            .field("assignments", &self.assignments)
            .field("running", &self.running)
            .finish()
    }
}
1062
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Defaults: empty topic list (discover at start), round-robin.
    #[test]
    fn test_group_config_defaults() {
        let config = GroupConfig::new("test-group");

        assert_eq!(config.group_id, "test-group");
        assert!(config.topics.is_empty());
        assert_eq!(config.assignment_strategy, AssignmentStrategy::RoundRobin);
    }

    // Default fetch cap is 1 MiB.
    #[test]
    fn test_worker_config_defaults() {
        let config = WorkerConfig::new("worker-1");

        assert_eq!(config.worker_id, "worker-1");
        assert_eq!(config.max_fetch_bytes, 1_048_576);
    }

    // Round-robin deals topics out in order: topic i -> worker i % n.
    #[test]
    fn test_round_robin_assignment() {
        let topics = vec![1, 2, 3, 4, 5, 6];
        let workers = vec!["w1".to_string(), "w2".to_string(), "w3".to_string()];

        let assignments = GroupCoordinator::assign_round_robin(&topics, &workers);

        assert_eq!(assignments.get("w1"), Some(&vec![1, 4]));
        assert_eq!(assignments.get("w2"), Some(&vec![2, 5]));
        assert_eq!(assignments.get("w3"), Some(&vec![3, 6]));
    }

    // 7 topics / 3 workers: the first worker absorbs the remainder.
    #[test]
    fn test_range_assignment() {
        let topics = vec![1, 2, 3, 4, 5, 6, 7];
        let workers = vec!["w1".to_string(), "w2".to_string(), "w3".to_string()];

        let assignments = GroupCoordinator::assign_range(&topics, &workers);

        assert_eq!(assignments.get("w1").map(|v| v.len()), Some(3));
        assert_eq!(assignments.get("w2").map(|v| v.len()), Some(2));
        assert_eq!(assignments.get("w3").map(|v| v.len()), Some(2));
    }

    // Sticky keeps an unchanged prior distribution exactly as-is.
    #[test]
    fn test_sticky_assignment_preserves_existing() {
        let topics = vec![1, 2, 3, 4];
        let workers = vec!["w1".to_string(), "w2".to_string()];

        let mut existing = HashMap::new();
        existing.insert("w1".to_string(), vec![1, 2]);
        existing.insert("w2".to_string(), vec![3, 4]);

        let assignments = GroupCoordinator::assign_sticky(&topics, &workers, &existing);

        assert_eq!(assignments.get("w1"), Some(&vec![1, 2]));
        assert_eq!(assignments.get("w2"), Some(&vec![3, 4]));
    }

    // #[default] attribute on RoundRobin drives the Default derive.
    #[test]
    fn test_assignment_strategy_default() {
        assert_eq!(
            AssignmentStrategy::default(),
            AssignmentStrategy::RoundRobin
        );
    }
}