1use std::collections::{HashMap, HashSet};
71use std::net::SocketAddr;
72use std::path::{Path, PathBuf};
73use std::sync::Arc;
74use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
75use std::time::{Duration, Instant};
76
77use tokio::net::{TcpListener, TcpStream};
78use tokio::sync::{RwLock, broadcast};
79
80use crate::client::{ClientConfig, LanceClient};
81use crate::consumer::{PollResult, SeekPosition};
82use crate::error::{ClientError, Result};
83use crate::offset::{LockFileOffsetStore, MemoryOffsetStore, OffsetStore};
84use crate::standalone::{StandaloneConfig, StandaloneConsumer};
85
/// Strategy the coordinator uses to distribute topic IDs across workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AssignmentStrategy {
    /// Topics are dealt one at a time to each worker in turn.
    #[default]
    RoundRobin,
    /// Each worker receives a contiguous slice of the topic list,
    /// with earlier workers absorbing any remainder.
    Range,
    /// Existing assignments are kept where still valid; only
    /// unassigned topics are redistributed to the least-loaded workers.
    Sticky,
}
97
/// Configuration for a consumer group and its coordinator.
#[derive(Debug, Clone)]
pub struct GroupConfig {
    /// Logical name of the consumer group.
    pub group_id: String,
    /// Topic IDs the group consumes; if empty, the coordinator subscribes
    /// to every topic listed by the server at start time.
    pub topics: Vec<u32>,
    /// How topics are distributed across workers.
    pub assignment_strategy: AssignmentStrategy,
    /// Cadence at which workers are expected to heartbeat (informational
    /// here; eviction is driven by `session_timeout`).
    pub heartbeat_interval: Duration,
    /// A worker whose last heartbeat is older than this is evicted.
    pub session_timeout: Duration,
    /// Address the coordinator listens on for worker connections.
    pub coordinator_addr: SocketAddr,
    /// Address of the message server workers consume from.
    pub server_addr: String,
    /// Directory for persisted offsets; `None` keeps offsets in memory.
    pub offset_dir: Option<PathBuf>,
}
118
119impl GroupConfig {
120 pub fn new(group_id: impl Into<String>) -> Self {
122 Self {
123 group_id: group_id.into(),
124 topics: Vec::new(),
125 assignment_strategy: AssignmentStrategy::RoundRobin,
126 heartbeat_interval: Duration::from_secs(3),
127 session_timeout: Duration::from_secs(30),
128 coordinator_addr: "127.0.0.1:19920"
129 .parse()
130 .unwrap_or_else(|_| SocketAddr::from(([127, 0, 0, 1], 19920))),
131 server_addr: "127.0.0.1:1992".to_string(),
132 offset_dir: None,
133 }
134 }
135
136 pub fn with_topics(mut self, topics: Vec<u32>) -> Self {
138 self.topics = topics;
139 self
140 }
141
142 pub fn with_assignment_strategy(mut self, strategy: AssignmentStrategy) -> Self {
144 self.assignment_strategy = strategy;
145 self
146 }
147
148 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
150 self.heartbeat_interval = interval;
151 self
152 }
153
154 pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
156 self.session_timeout = timeout;
157 self
158 }
159
160 pub fn with_coordinator_addr(mut self, addr: SocketAddr) -> Self {
162 self.coordinator_addr = addr;
163 self
164 }
165
166 pub fn with_server_addr(mut self, addr: impl Into<String>) -> Self {
168 self.server_addr = addr.into();
169 self
170 }
171
172 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
174 self.offset_dir = Some(dir.to_path_buf());
175 self
176 }
177}
178
/// Requests a worker can send to the coordinator.
///
/// NOTE(review): the wire protocol is plain text (`JOIN:<id>`, …); this enum
/// mirrors it but is not itself serialized, hence `dead_code`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorMessage {
    /// Worker asks to join the group.
    Join { worker_id: String },
    /// Worker announces it is leaving the group.
    Leave { worker_id: String },
    /// Liveness signal; refreshes the worker's session deadline.
    Heartbeat { worker_id: String },
    /// Worker asks for its current topic assignments.
    GetAssignments { worker_id: String },
}
192
/// Responses the coordinator sends back to workers. Replies are written to
/// the socket as the `Debug` rendering of these variants; the worker side
/// parses generation/assignments back out of that text.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum CoordinatorResponse {
    /// Join succeeded; carries the new generation and initial assignments.
    JoinAccepted {
        worker_id: String,
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Leave was processed and a rebalance was triggered.
    LeaveAcknowledged,
    /// Heartbeat accepted; echoes the current generation so the worker can
    /// detect that a rebalance happened.
    HeartbeatAck { generation: u64 },
    /// Current assignments for the requesting worker.
    Assignments {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// Pushed assignment change (never constructed in this module).
    Rebalance {
        generation: u64,
        assignments: Vec<u32>,
    },
    /// The request could not be processed.
    Error { message: String },
}
220
/// Coordinator-side bookkeeping for a single group member.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct WorkerState {
    /// Identifier the worker supplied on JOIN.
    worker_id: String,
    /// Time of the most recent heartbeat; drives session-timeout eviction.
    last_heartbeat: Instant,
    /// Topic IDs currently assigned to this worker.
    assignments: Vec<u32>,
    /// Group generation at the time the worker joined.
    generation: u64,
}
230
/// Coordinates a consumer group: accepts worker joins over TCP, distributes
/// topics according to the configured [`AssignmentStrategy`], tracks
/// heartbeats, and rebalances on every membership change.
///
/// Uses a simple text protocol (`JOIN:<id>`, `LEAVE:<id>`, `HEARTBEAT:<id>`,
/// `GET_ASSIGNMENTS:<id>`) with `Debug`-formatted replies.
pub struct GroupCoordinator {
    #[allow(dead_code)]
    config: GroupConfig,
    /// Live members keyed by worker ID.
    workers: Arc<RwLock<HashMap<String, WorkerState>>>,
    /// Monotonic counter bumped on every join, leave, or eviction.
    generation: Arc<AtomicU64>,
    /// Cleared by `stop()`; checked by the health-check loop.
    running: Arc<AtomicBool>,
    /// Broadcast channel used to terminate the background tasks.
    shutdown_tx: broadcast::Sender<()>,
    /// Actual bound listener address (relevant when port 0 was requested).
    join_addr: SocketAddr,
}
247
248impl GroupCoordinator {
249 pub async fn start(server_addr: &str, mut config: GroupConfig) -> Result<Self> {
251 config.server_addr = server_addr.to_string();
252
253 if config.topics.is_empty() {
255 let socket_addr: SocketAddr = config.server_addr.parse().map_err(|e| {
256 ClientError::ProtocolError(format!("Invalid server address: {}", e))
257 })?;
258 let client_config = ClientConfig::new(socket_addr);
259 let mut client = LanceClient::connect(client_config).await?;
260 let topics = client.list_topics().await?;
261 config.topics = topics.iter().map(|t| t.id).collect();
262 }
263
264 let workers = Arc::new(RwLock::new(HashMap::new()));
265 let generation = Arc::new(AtomicU64::new(0));
266 let running = Arc::new(AtomicBool::new(true));
267 let (shutdown_tx, _) = broadcast::channel(1);
268
269 let listener = TcpListener::bind(config.coordinator_addr)
270 .await
271 .map_err(ClientError::ConnectionFailed)?;
272
273 let join_addr = listener.local_addr().map_err(ClientError::IoError)?;
274
275 let coordinator = Self {
276 config: config.clone(),
277 workers: workers.clone(),
278 generation: generation.clone(),
279 running: running.clone(),
280 shutdown_tx: shutdown_tx.clone(),
281 join_addr,
282 };
283
284 let workers_clone = workers.clone();
286 let generation_clone = generation.clone();
287 let _running_clone = running.clone();
288 let config_clone = config.clone();
289 let mut shutdown_rx = shutdown_tx.subscribe();
290
291 tokio::spawn(async move {
293 loop {
294 tokio::select! {
295 accept_result = listener.accept() => {
296 match accept_result {
297 Ok((stream, addr)) => {
298 let workers = workers_clone.clone();
299 let generation = generation_clone.clone();
300 let config = config_clone.clone();
301 tokio::spawn(async move {
302 if let Err(e) = Self::handle_worker_connection(
303 stream, addr, workers, generation, config
304 ).await {
305 tracing::warn!("Worker connection error: {}", e);
306 }
307 });
308 }
309 Err(e) => {
310 tracing::error!("Accept error: {}", e);
311 }
312 }
313 }
314 _ = shutdown_rx.recv() => {
315 tracing::info!("Coordinator shutting down");
316 break;
317 }
318 }
319 }
320 });
321
322 let workers_checker = workers.clone();
324 let generation_checker = generation.clone();
325 let running_checker = running.clone();
326 let session_timeout = config.session_timeout;
327 let mut shutdown_rx2 = shutdown_tx.subscribe();
328
329 tokio::spawn(async move {
330 let mut interval = tokio::time::interval(Duration::from_secs(1));
331 loop {
332 tokio::select! {
333 _ = interval.tick() => {
334 if !running_checker.load(Ordering::Relaxed) {
335 break;
336 }
337 Self::check_worker_health(
338 &workers_checker,
339 &generation_checker,
340 session_timeout,
341 ).await;
342 }
343 _ = shutdown_rx2.recv() => {
344 break;
345 }
346 }
347 }
348 });
349
350 Ok(coordinator)
351 }
352
353 pub fn join_address(&self) -> SocketAddr {
355 self.join_addr
356 }
357
358 pub fn generation(&self) -> u64 {
360 self.generation.load(Ordering::Relaxed)
361 }
362
363 pub async fn worker_count(&self) -> usize {
365 self.workers.read().await.len()
366 }
367
368 pub async fn get_assignments(&self) -> HashMap<String, Vec<u32>> {
370 self.workers
371 .read()
372 .await
373 .iter()
374 .map(|(id, state)| (id.clone(), state.assignments.clone()))
375 .collect()
376 }
377
378 pub fn stop(&self) {
380 self.running.store(false, Ordering::Relaxed);
381 let _ = self.shutdown_tx.send(());
382 }
383
384 async fn handle_worker_connection(
386 stream: TcpStream,
387 _addr: SocketAddr,
388 workers: Arc<RwLock<HashMap<String, WorkerState>>>,
389 generation: Arc<AtomicU64>,
390 config: GroupConfig,
391 ) -> Result<()> {
392 use tokio::io::{AsyncReadExt, AsyncWriteExt};
395
396 let (mut reader, mut writer) = stream.into_split();
397 let mut buf = vec![0u8; 4096];
398
399 loop {
400 let n = reader.read(&mut buf).await?;
402 if n == 0 {
403 break; }
405
406 let msg_str = std::str::from_utf8(&buf[..n])
408 .map_err(|e| ClientError::ProtocolError(e.to_string()))?;
409
410 let response = Self::process_message(msg_str, &workers, &generation, &config).await;
411
412 let response_bytes = format!("{:?}", response);
414 writer.write_all(response_bytes.as_bytes()).await?;
415 }
416
417 Ok(())
418 }
419
420 async fn process_message(
422 msg: &str,
423 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
424 generation: &Arc<AtomicU64>,
425 config: &GroupConfig,
426 ) -> CoordinatorResponse {
427 if msg.starts_with("JOIN:") {
429 let worker_id = msg
430 .strip_prefix("JOIN:")
431 .unwrap_or("unknown")
432 .trim()
433 .to_string();
434 Self::handle_join(worker_id, workers, generation, config).await
435 } else if msg.starts_with("LEAVE:") {
436 let worker_id = msg
437 .strip_prefix("LEAVE:")
438 .unwrap_or("unknown")
439 .trim()
440 .to_string();
441 Self::handle_leave(worker_id, workers, generation, config).await
442 } else if msg.starts_with("HEARTBEAT:") {
443 let worker_id = msg
444 .strip_prefix("HEARTBEAT:")
445 .unwrap_or("unknown")
446 .trim()
447 .to_string();
448 Self::handle_heartbeat(worker_id, workers, generation).await
449 } else if msg.starts_with("GET_ASSIGNMENTS:") {
450 let worker_id = msg
451 .strip_prefix("GET_ASSIGNMENTS:")
452 .unwrap_or("unknown")
453 .trim()
454 .to_string();
455 Self::handle_get_assignments(worker_id, workers, generation).await
456 } else {
457 CoordinatorResponse::Error {
458 message: format!("Unknown message: {}", msg),
459 }
460 }
461 }
462
463 async fn handle_join(
465 worker_id: String,
466 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
467 generation: &Arc<AtomicU64>,
468 config: &GroupConfig,
469 ) -> CoordinatorResponse {
470 let new_gen = generation.fetch_add(1, Ordering::SeqCst) + 1;
471
472 {
473 let mut workers_lock = workers.write().await;
474 workers_lock.insert(
475 worker_id.clone(),
476 WorkerState {
477 worker_id: worker_id.clone(),
478 last_heartbeat: Instant::now(),
479 assignments: Vec::new(),
480 generation: new_gen,
481 },
482 );
483 }
484
485 let _assignments = Self::rebalance(workers, config).await;
487
488 let worker_assignments = {
490 let workers_lock = workers.read().await;
491 workers_lock
492 .get(&worker_id)
493 .map(|w| w.assignments.clone())
494 .unwrap_or_default()
495 };
496
497 CoordinatorResponse::JoinAccepted {
498 worker_id,
499 generation: new_gen,
500 assignments: worker_assignments,
501 }
502 }
503
504 async fn handle_leave(
506 worker_id: String,
507 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
508 generation: &Arc<AtomicU64>,
509 config: &GroupConfig,
510 ) -> CoordinatorResponse {
511 {
512 let mut workers_lock = workers.write().await;
513 workers_lock.remove(&worker_id);
514 }
515
516 generation.fetch_add(1, Ordering::SeqCst);
517
518 Self::rebalance(workers, config).await;
520
521 CoordinatorResponse::LeaveAcknowledged
522 }
523
524 async fn handle_heartbeat(
526 worker_id: String,
527 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
528 generation: &Arc<AtomicU64>,
529 ) -> CoordinatorResponse {
530 let current_gen = generation.load(Ordering::Relaxed);
531
532 let mut workers_lock = workers.write().await;
533 if let Some(worker) = workers_lock.get_mut(&worker_id) {
534 worker.last_heartbeat = Instant::now();
535 CoordinatorResponse::HeartbeatAck {
536 generation: current_gen,
537 }
538 } else {
539 CoordinatorResponse::Error {
540 message: "Worker not found".to_string(),
541 }
542 }
543 }
544
545 async fn handle_get_assignments(
547 worker_id: String,
548 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
549 generation: &Arc<AtomicU64>,
550 ) -> CoordinatorResponse {
551 let current_gen = generation.load(Ordering::Relaxed);
552 let workers_lock = workers.read().await;
553
554 if let Some(worker) = workers_lock.get(&worker_id) {
555 CoordinatorResponse::Assignments {
556 generation: current_gen,
557 assignments: worker.assignments.clone(),
558 }
559 } else {
560 CoordinatorResponse::Error {
561 message: "Worker not found".to_string(),
562 }
563 }
564 }
565
566 async fn check_worker_health(
568 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
569 generation: &Arc<AtomicU64>,
570 session_timeout: Duration,
571 ) {
572 let now = Instant::now();
573 let mut dead_workers = Vec::new();
574
575 {
576 let workers_lock = workers.read().await;
577 for (id, state) in workers_lock.iter() {
578 if now.duration_since(state.last_heartbeat) > session_timeout {
579 dead_workers.push(id.clone());
580 }
581 }
582 }
583
584 if !dead_workers.is_empty() {
585 let mut workers_lock = workers.write().await;
586 for id in dead_workers {
587 tracing::warn!("Worker {} timed out, removing from group", id);
588 workers_lock.remove(&id);
589 }
590 generation.fetch_add(1, Ordering::SeqCst);
591 }
593 }
594
595 async fn rebalance(
597 workers: &Arc<RwLock<HashMap<String, WorkerState>>>,
598 config: &GroupConfig,
599 ) -> HashMap<String, Vec<u32>> {
600 let mut workers_lock = workers.write().await;
601 let worker_ids: Vec<String> = workers_lock.keys().cloned().collect();
602
603 if worker_ids.is_empty() {
604 return HashMap::new();
605 }
606
607 let assignments = match config.assignment_strategy {
608 AssignmentStrategy::RoundRobin => Self::assign_round_robin(&config.topics, &worker_ids),
609 AssignmentStrategy::Range => Self::assign_range(&config.topics, &worker_ids),
610 AssignmentStrategy::Sticky => {
611 let existing: HashMap<String, Vec<u32>> = workers_lock
613 .iter()
614 .map(|(id, state)| (id.clone(), state.assignments.clone()))
615 .collect();
616 Self::assign_sticky(&config.topics, &worker_ids, &existing)
617 },
618 };
619
620 for (worker_id, topics) in &assignments {
622 if let Some(worker) = workers_lock.get_mut(worker_id) {
623 worker.assignments = topics.clone();
624 }
625 }
626
627 assignments
628 }
629
630 fn assign_round_robin(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
632 let mut assignments: HashMap<String, Vec<u32>> =
633 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
634
635 for (i, topic) in topics.iter().enumerate() {
636 let worker = &workers[i % workers.len()];
637 if let Some(topics) = assignments.get_mut(worker) {
638 topics.push(*topic);
639 }
640 }
641
642 assignments
643 }
644
645 fn assign_range(topics: &[u32], workers: &[String]) -> HashMap<String, Vec<u32>> {
647 let mut assignments: HashMap<String, Vec<u32>> =
648 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
649
650 let topics_per_worker = topics.len() / workers.len();
651 let remainder = topics.len() % workers.len();
652
653 let mut topic_idx = 0;
654 for (worker_idx, worker) in workers.iter().enumerate() {
655 let extra = if worker_idx < remainder { 1 } else { 0 };
656 let count = topics_per_worker + extra;
657
658 if let Some(worker_topics) = assignments.get_mut(worker) {
659 for _ in 0..count {
660 if topic_idx < topics.len() {
661 worker_topics.push(topics[topic_idx]);
662 topic_idx += 1;
663 }
664 }
665 }
666 }
667
668 assignments
669 }
670
671 fn assign_sticky(
673 topics: &[u32],
674 workers: &[String],
675 existing: &HashMap<String, Vec<u32>>,
676 ) -> HashMap<String, Vec<u32>> {
677 let mut assignments: HashMap<String, Vec<u32>> =
678 workers.iter().map(|w| (w.clone(), Vec::new())).collect();
679
680 let topic_set: HashSet<u32> = topics.iter().copied().collect();
681 let mut assigned: HashSet<u32> = HashSet::new();
682
683 for (worker, old_topics) in existing {
685 if assignments.contains_key(worker) {
686 for topic in old_topics {
687 if topic_set.contains(topic) && !assigned.contains(topic) {
688 if let Some(worker_topics) = assignments.get_mut(worker) {
689 worker_topics.push(*topic);
690 assigned.insert(*topic);
691 }
692 }
693 }
694 }
695 }
696
697 let unassigned: Vec<u32> = topics
699 .iter()
700 .filter(|t| !assigned.contains(t))
701 .copied()
702 .collect();
703
704 for topic in unassigned {
706 let min_worker = assignments
707 .iter()
708 .min_by_key(|(_, topics)| topics.len())
709 .map(|(w, _)| w.clone());
710
711 if let Some(worker) = min_worker {
712 if let Some(worker_topics) = assignments.get_mut(&worker) {
713 worker_topics.push(topic);
714 }
715 }
716 }
717
718 assignments
719 }
720}
721
impl Drop for GroupCoordinator {
    // Ensure the background accept/health-check tasks are told to shut
    // down even if `stop` was never called explicitly.
    fn drop(&mut self) {
        self.stop();
    }
}
727
/// Configuration for a single group member (worker).
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Unique identifier for this worker within the group.
    pub worker_id: String,
    /// Maximum bytes fetched per poll from each assigned topic.
    pub max_fetch_bytes: u32,
    /// Suggested heartbeat cadence; the caller is responsible for driving
    /// `GroupedConsumer::heartbeat` at this rate.
    pub heartbeat_interval: Duration,
    /// Directory for persisted offsets; `None` keeps offsets in memory.
    pub offset_dir: Option<PathBuf>,
    /// Where consumption starts for newly assigned topics.
    pub start_position: SeekPosition,
}
742
743impl WorkerConfig {
744 pub fn new(worker_id: impl Into<String>) -> Self {
746 Self {
747 worker_id: worker_id.into(),
748 max_fetch_bytes: 1_048_576,
749 heartbeat_interval: Duration::from_secs(3),
750 offset_dir: None,
751 start_position: SeekPosition::Beginning,
752 }
753 }
754
755 pub fn with_max_fetch_bytes(mut self, bytes: u32) -> Self {
757 self.max_fetch_bytes = bytes;
758 self
759 }
760
761 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
763 self.heartbeat_interval = interval;
764 self
765 }
766
767 pub fn with_offset_dir(mut self, dir: &Path) -> Self {
769 self.offset_dir = Some(dir.to_path_buf());
770 self
771 }
772
773 pub fn with_start_position(mut self, position: SeekPosition) -> Self {
775 self.start_position = position;
776 self
777 }
778}
779
/// A group member: joins a [`GroupCoordinator`], receives topic
/// assignments, and owns one [`StandaloneConsumer`] per assigned topic.
pub struct GroupedConsumer {
    /// Worker identity and consumption settings.
    config: WorkerConfig,
    /// Address of the message server consumers connect to.
    server_addr: String,
    /// Address of the group coordinator.
    coordinator_addr: SocketAddr,
    /// Topic IDs currently assigned to this worker.
    assignments: Vec<u32>,
    /// Group generation last observed from the coordinator.
    generation: u64,
    /// Per-topic consumers, keyed by topic ID.
    consumers: HashMap<u32, StandaloneConsumer>,
    /// Offset store chosen at join time (file-backed or in-memory).
    #[allow(dead_code)]
    offset_store: Arc<dyn OffsetStore>,
    /// Set to false once `leave` has completed.
    running: bool,
}
792
793impl GroupedConsumer {
794 pub async fn join(
796 server_addr: &str,
797 coordinator_addr: SocketAddr,
798 config: WorkerConfig,
799 ) -> Result<Self> {
800 let mut stream = TcpStream::connect(coordinator_addr)
802 .await
803 .map_err(ClientError::ConnectionFailed)?;
804
805 use tokio::io::{AsyncReadExt, AsyncWriteExt};
806
807 let join_msg = format!("JOIN:{}", config.worker_id);
809 stream.write_all(join_msg.as_bytes()).await?;
810
811 let mut buf = vec![0u8; 4096];
813 let n = stream.read(&mut buf).await?;
814 let response = std::str::from_utf8(&buf[..n])
815 .map_err(|e| ClientError::ProtocolError(e.to_string()))?;
816
817 let (generation, assignments) = Self::parse_join_response(response)?;
819
820 let offset_store: Arc<dyn OffsetStore> = if let Some(ref dir) = config.offset_dir {
822 Arc::new(LockFileOffsetStore::open(dir, &config.worker_id)?)
823 } else {
824 Arc::new(MemoryOffsetStore::new())
825 };
826
827 let mut consumers = HashMap::new();
829 for topic_id in &assignments {
830 let standalone_config = StandaloneConfig::new(&config.worker_id, *topic_id)
831 .with_max_fetch_bytes(config.max_fetch_bytes)
832 .with_start_position(config.start_position);
833
834 if let Some(ref dir) = config.offset_dir {
835 let consumer = StandaloneConsumer::connect(
836 server_addr,
837 standalone_config.with_offset_dir(dir),
838 )
839 .await?;
840 consumers.insert(*topic_id, consumer);
841 } else {
842 let consumer = StandaloneConsumer::connect(server_addr, standalone_config).await?;
843 consumers.insert(*topic_id, consumer);
844 }
845 }
846
847 Ok(Self {
848 config,
849 server_addr: server_addr.to_string(),
850 coordinator_addr,
851 assignments,
852 generation,
853 consumers,
854 offset_store,
855 running: true,
856 })
857 }
858
859 pub fn assignments(&self) -> &[u32] {
861 &self.assignments
862 }
863
864 pub fn generation(&self) -> u64 {
866 self.generation
867 }
868
869 pub async fn poll(&mut self, topic_id: u32) -> Result<Option<PollResult>> {
871 if let Some(consumer) = self.consumers.get_mut(&topic_id) {
872 consumer.poll().await
873 } else {
874 Err(ClientError::InvalidResponse(format!(
875 "Topic {} not assigned to this worker",
876 topic_id
877 )))
878 }
879 }
880
881 pub async fn commit(&mut self, topic_id: u32) -> Result<()> {
883 if let Some(consumer) = self.consumers.get_mut(&topic_id) {
884 consumer.commit().await
885 } else {
886 Err(ClientError::InvalidResponse(format!(
887 "Topic {} not assigned to this worker",
888 topic_id
889 )))
890 }
891 }
892
893 pub async fn commit_all(&mut self) -> Result<()> {
895 for consumer in self.consumers.values_mut() {
896 consumer.commit().await?;
897 }
898 Ok(())
899 }
900
901 pub async fn heartbeat(&mut self) -> Result<u64> {
903 let mut stream = TcpStream::connect(self.coordinator_addr)
904 .await
905 .map_err(ClientError::ConnectionFailed)?;
906
907 use tokio::io::{AsyncReadExt, AsyncWriteExt};
908
909 let msg = format!("HEARTBEAT:{}", self.config.worker_id);
910 stream.write_all(msg.as_bytes()).await?;
911
912 let mut buf = vec![0u8; 1024];
913 let n = stream.read(&mut buf).await?;
914 let response = std::str::from_utf8(&buf[..n])
915 .map_err(|e| ClientError::ProtocolError(e.to_string()))?;
916
917 let new_gen = Self::parse_heartbeat_response(response)?;
919
920 if new_gen > self.generation {
921 self.refresh_assignments().await?;
923 }
924
925 Ok(self.generation)
926 }
927
928 async fn refresh_assignments(&mut self) -> Result<()> {
930 let mut stream = TcpStream::connect(self.coordinator_addr)
931 .await
932 .map_err(ClientError::ConnectionFailed)?;
933
934 use tokio::io::{AsyncReadExt, AsyncWriteExt};
935
936 let msg = format!("GET_ASSIGNMENTS:{}", self.config.worker_id);
937 stream.write_all(msg.as_bytes()).await?;
938
939 let mut buf = vec![0u8; 4096];
940 let n = stream.read(&mut buf).await?;
941 let response = std::str::from_utf8(&buf[..n])
942 .map_err(|e| ClientError::ProtocolError(e.to_string()))?;
943
944 let (generation, new_assignments) = Self::parse_assignments_response(response)?;
945
946 let old_set: HashSet<u32> = self.assignments.iter().copied().collect();
948 let new_set: HashSet<u32> = new_assignments.iter().copied().collect();
949
950 for topic_id in old_set.difference(&new_set) {
952 if let Some(consumer) = self.consumers.remove(topic_id) {
953 let _ = consumer.close().await;
954 }
955 }
956
957 for topic_id in new_set.difference(&old_set) {
959 let standalone_config = StandaloneConfig::new(&self.config.worker_id, *topic_id)
960 .with_max_fetch_bytes(self.config.max_fetch_bytes)
961 .with_start_position(self.config.start_position);
962
963 let consumer = if let Some(ref dir) = self.config.offset_dir {
964 StandaloneConsumer::connect(
965 &self.server_addr,
966 standalone_config.with_offset_dir(dir),
967 )
968 .await?
969 } else {
970 StandaloneConsumer::connect(&self.server_addr, standalone_config).await?
971 };
972
973 self.consumers.insert(*topic_id, consumer);
974 }
975
976 self.assignments = new_assignments;
977 self.generation = generation;
978
979 Ok(())
980 }
981
982 pub async fn leave(mut self) -> Result<()> {
984 self.commit_all().await?;
986
987 let mut stream = TcpStream::connect(self.coordinator_addr)
989 .await
990 .map_err(ClientError::ConnectionFailed)?;
991
992 use tokio::io::AsyncWriteExt;
993
994 let msg = format!("LEAVE:{}", self.config.worker_id);
995 stream.write_all(msg.as_bytes()).await?;
996
997 for (_, consumer) in self.consumers.drain() {
999 let _ = consumer.close().await;
1000 }
1001
1002 self.running = false;
1003 Ok(())
1004 }
1005
1006 fn parse_join_response(response: &str) -> Result<(u64, Vec<u32>)> {
1008 let generation = response
1012 .find("generation: ")
1013 .and_then(|i| {
1014 let start = i + 12;
1015 let end = response[start..].find(',')?;
1016 response[start..start + end].parse().ok()
1017 })
1018 .unwrap_or(0);
1019
1020 let assignments = response
1021 .find("assignments: [")
1022 .map(|i| {
1023 let start = i + 14;
1024 let end = response[start..].find(']').unwrap_or(0);
1025 response[start..start + end]
1026 .split(',')
1027 .filter_map(|s| s.trim().parse().ok())
1028 .collect()
1029 })
1030 .unwrap_or_default();
1031
1032 Ok((generation, assignments))
1033 }
1034
1035 fn parse_heartbeat_response(response: &str) -> Result<u64> {
1037 response
1038 .find("generation: ")
1039 .and_then(|i| {
1040 let start = i + 12;
1041 let end = response[start..]
1042 .find([',', ' ', '}'])
1043 .unwrap_or(response.len() - start);
1044 response[start..start + end].parse().ok()
1045 })
1046 .ok_or_else(|| ClientError::ProtocolError("Invalid heartbeat response".to_string()))
1047 }
1048
1049 fn parse_assignments_response(response: &str) -> Result<(u64, Vec<u32>)> {
1051 Self::parse_join_response(response)
1052 }
1053}
1054
impl std::fmt::Debug for GroupedConsumer {
    // Manual impl so the per-topic consumers and the offset store need not
    // implement `Debug`; only the lightweight fields are rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedConsumer")
            .field("worker_id", &self.config.worker_id)
            .field("generation", &self.generation)
            .field("assignments", &self.assignments)
            .field("running", &self.running)
            .finish()
    }
}
1065
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an owned worker-ID list from string literals.
    fn workers(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| (*n).to_string()).collect()
    }

    #[test]
    fn test_group_config_defaults() {
        let cfg = GroupConfig::new("test-group");

        assert_eq!(cfg.group_id, "test-group");
        assert!(cfg.topics.is_empty());
        assert_eq!(cfg.assignment_strategy, AssignmentStrategy::RoundRobin);
    }

    #[test]
    fn test_worker_config_defaults() {
        let cfg = WorkerConfig::new("worker-1");

        assert_eq!(cfg.worker_id, "worker-1");
        assert_eq!(cfg.max_fetch_bytes, 1_048_576);
    }

    #[test]
    fn test_round_robin_assignment() {
        let ws = workers(&["w1", "w2", "w3"]);

        let got = GroupCoordinator::assign_round_robin(&[1, 2, 3, 4, 5, 6], &ws);

        // Topics are dealt in turn: w1 gets indices 0 and 3, and so on.
        assert_eq!(got.get("w1"), Some(&vec![1, 4]));
        assert_eq!(got.get("w2"), Some(&vec![2, 5]));
        assert_eq!(got.get("w3"), Some(&vec![3, 6]));
    }

    #[test]
    fn test_range_assignment() {
        let ws = workers(&["w1", "w2", "w3"]);

        let got = GroupCoordinator::assign_range(&[1, 2, 3, 4, 5, 6, 7], &ws);

        // 7 topics over 3 workers: the first worker absorbs the remainder.
        assert_eq!(got.get("w1").map(Vec::len), Some(3));
        assert_eq!(got.get("w2").map(Vec::len), Some(2));
        assert_eq!(got.get("w3").map(Vec::len), Some(2));
    }

    #[test]
    fn test_sticky_assignment_preserves_existing() {
        let ws = workers(&["w1", "w2"]);
        let previous: HashMap<String, Vec<u32>> = [
            ("w1".to_string(), vec![1, 2]),
            ("w2".to_string(), vec![3, 4]),
        ]
        .into_iter()
        .collect();

        let got = GroupCoordinator::assign_sticky(&[1, 2, 3, 4], &ws, &previous);

        // Every topic is still covered by its previous owner.
        assert_eq!(got.get("w1"), Some(&vec![1, 2]));
        assert_eq!(got.get("w2"), Some(&vec![3, 4]));
    }

    #[test]
    fn test_assignment_strategy_default() {
        assert_eq!(
            AssignmentStrategy::default(),
            AssignmentStrategy::RoundRobin
        );
    }
}