use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::RwLock;

use crate::consumer_groups::ConsumerGroupManager;
use crate::metrics::KafkaMetrics;
use crate::protocol::{KafkaProtocolHandler, KafkaRequest, KafkaRequestType, KafkaResponse};
use crate::spec_registry::KafkaSpecRegistry;
use crate::topics::Topic;
use mockforge_core::config::KafkaConfig;
use mockforge_core::Result;

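/// A mock Kafka broker that accepts TCP connections and serves a simplified
/// subset of the Kafka wire protocol from in-memory state.
///
/// A minimal usage sketch (assumes a `KafkaConfig` populated with the
/// `host`/`port` fields used by `start` below):
///
/// ```ignore
/// let broker = KafkaMockBroker::new(config).await?;
/// broker.start().await?; // serves connections until the task is dropped
/// ```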
#[derive(Clone)]
#[allow(dead_code)]
pub struct KafkaMockBroker {
    /// Broker configuration (bind host/port, etc.).
    config: KafkaConfig,
    /// Topics keyed by name, shared across connection tasks.
    pub topics: Arc<RwLock<HashMap<String, Topic>>>,
    /// Consumer-group state: members, assignments, committed offsets.
    pub consumer_groups: Arc<RwLock<ConsumerGroupManager>>,
    /// Spec registry constructed from the config with access to the topic map.
    spec_registry: Arc<KafkaSpecRegistry>,
    /// Counters and gauges for connections, requests, and errors.
    metrics: Arc<KafkaMetrics>,
}

impl KafkaMockBroker {
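    /// Builds a broker with empty topic and consumer-group state and a spec
    /// registry initialized from the given config.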
    pub async fn new(config: KafkaConfig) -> Result<Self> {
        let topics = Arc::new(RwLock::new(HashMap::new()));
        let consumer_groups = Arc::new(RwLock::new(ConsumerGroupManager::new()));
        let spec_registry = KafkaSpecRegistry::new(config.clone(), Arc::clone(&topics)).await?;
        let metrics = Arc::new(KafkaMetrics::new());

        Ok(Self {
            config,
            topics,
            consumer_groups,
            spec_registry: Arc::new(spec_registry),
            metrics,
        })
    }

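    /// Binds a TCP listener on `host:port` and serves each accepted
    /// connection in its own spawned task. The accept loop runs until an
    /// accept error occurs; it never returns `Ok` on its own.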
    pub async fn start(&self) -> Result<()> {
        let addr = format!("{}:{}", self.config.host, self.config.port);
        let listener = TcpListener::bind(&addr).await?;

        tracing::info!("Starting Kafka mock broker on {}", addr);

        loop {
            let (socket, _) = listener.accept().await?;
            // The broker is cheap to clone (all shared state is behind Arcs),
            // so each connection task gets its own handle.
            let broker = self.clone();

            tokio::spawn(async move {
                if let Err(e) = broker.handle_connection(socket).await {
                    tracing::error!("Error handling connection: {}", e);
                }
            });
        }
    }

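    /// Serves one client connection. Each Kafka message is framed as a 4-byte
    /// big-endian length prefix followed by the request body; the loop reads
    /// a frame, dispatches it, and writes back a framed response.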
    async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
        let protocol_handler = KafkaProtocolHandler::new();
        self.metrics.record_connection();

        // Decrements the connection gauge when this function returns.
        let _guard = ConnectionGuard {
            metrics: Arc::clone(&self.metrics),
        };

        loop {
            let mut size_buf = [0u8; 4];
            match tokio::time::timeout(
                std::time::Duration::from_secs(30),
                socket.read_exact(&mut size_buf),
            )
            .await
            {
                Ok(Ok(_)) => {
                    let message_size = i32::from_be_bytes(size_buf) as usize;

                    if message_size > 10 * 1024 * 1024 {
                        self.metrics.record_error();
                        tracing::warn!("Message size too large: {} bytes", message_size);
                        // The oversized payload was never read, so the stream is
                        // no longer aligned on a frame boundary; drop the
                        // connection rather than misparse the following bytes.
                        break;
                    }

                    let mut message_buf = vec![0u8; message_size];
                    match tokio::time::timeout(
                        std::time::Duration::from_secs(10),
                        socket.read_exact(&mut message_buf),
                    )
                    .await
                    {
                        Ok(Ok(_)) => {}
                        Ok(Err(e)) => {
                            self.metrics.record_error();
                            tracing::error!("Failed to read message body: {}", e);
                            break;
                        }
                        Err(e) => {
                            self.metrics.record_error();
                            tracing::error!("Timeout reading message: {}", e);
                            break;
                        }
                    }

                    let request = match protocol_handler.parse_request(&message_buf) {
                        Ok(req) => req,
                        Err(e) => {
                            self.metrics.record_error();
                            tracing::error!("Failed to parse request: {}", e);
                            continue;
                        }
                    };

                    self.metrics.record_request(get_api_key_from_request(&request));

                    // Echo the client's correlation id back in the response frame.
                    let correlation_id = request.correlation_id;
                    let start_time = std::time::Instant::now();

                    let response = match self.handle_request(request).await {
                        Ok(resp) => resp,
                        Err(e) => {
                            self.metrics.record_error();
                            tracing::error!("Failed to handle request: {}", e);
                            continue;
                        }
                    };

                    let latency = start_time.elapsed().as_micros() as u64;
                    self.metrics.record_request_latency(latency);
                    self.metrics.record_response();

                    let response_data =
                        match protocol_handler.serialize_response(&response, correlation_id) {
                            Ok(data) => data,
                            Err(e) => {
                                self.metrics.record_error();
                                tracing::error!("Failed to serialize response: {}", e);
                                continue;
                            }
                        };

                    let response_size = (response_data.len() as i32).to_be_bytes();
                    if let Err(e) = socket.write_all(&response_size).await {
                        self.metrics.record_error();
                        tracing::error!("Failed to write response size: {}", e);
                        break;
                    }

                    if let Err(e) = socket.write_all(&response_data).await {
                        self.metrics.record_error();
                        tracing::error!("Failed to write response: {}", e);
                        break;
                    }
                }
                Ok(Err(e)) => {
                    self.metrics.record_error();
                    tracing::error!("Failed to read message size: {}", e);
                    break;
                }
                Err(_) => {
                    // Idle timeout while waiting for the next frame; poll again.
                    continue;
                }
            }
        }

        Ok(())
    }

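    /// Dispatches a parsed request to the handler for its request type.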
    async fn handle_request(&self, request: KafkaRequest) -> Result<KafkaResponse> {
        match request.request_type {
            KafkaRequestType::Metadata => self.handle_metadata().await,
            KafkaRequestType::Produce => self.handle_produce().await,
            KafkaRequestType::Fetch => self.handle_fetch().await,
            KafkaRequestType::ListGroups => self.handle_list_groups().await,
            KafkaRequestType::DescribeGroups => self.handle_describe_groups().await,
            KafkaRequestType::ApiVersions => self.handle_api_versions().await,
            KafkaRequestType::CreateTopics => self.handle_create_topics().await,
            KafkaRequestType::DeleteTopics => self.handle_delete_topics().await,
            KafkaRequestType::DescribeConfigs => self.handle_describe_configs().await,
        }
    }

    async fn handle_metadata(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::Metadata)
    }

    async fn handle_produce(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::Produce)
    }

    async fn handle_fetch(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::Fetch)
    }

    async fn handle_api_versions(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::ApiVersions)
    }

    async fn handle_list_groups(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::ListGroups)
    }

    async fn handle_describe_groups(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::DescribeGroups)
    }

    async fn handle_create_topics(&self) -> Result<KafkaResponse> {
        // Simplification: the topic name is not parsed from the request;
        // a fixed default topic is registered instead.
        let topic_name = "default-topic".to_string();
        let topic_config = crate::topics::TopicConfig::default();
        let topic = crate::topics::Topic::new(topic_name.clone(), topic_config);

        let mut topics = self.topics.write().await;
        topics.insert(topic_name, topic);

        Ok(KafkaResponse::CreateTopics)
    }

    async fn handle_delete_topics(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::DeleteTopics)
    }

    async fn handle_describe_configs(&self) -> Result<KafkaResponse> {
        Ok(KafkaResponse::DescribeConfigs)
    }

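    /// Test helper: commits `offsets` (keyed by topic and partition) for the
    /// given consumer group.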
    pub async fn test_commit_offsets(
        &self,
        group_id: &str,
        offsets: std::collections::HashMap<(String, i32), i64>,
    ) -> Result<()> {
        let mut consumer_groups = self.consumer_groups.write().await;
        consumer_groups
            .commit_offsets(group_id, offsets)
            .await
            .map_err(|e| mockforge_core::Error::from(e.to_string()))
    }

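    /// Test helper: returns the committed offsets recorded for `group_id`.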
    pub async fn test_get_committed_offsets(
        &self,
        group_id: &str,
    ) -> std::collections::HashMap<(String, i32), i64> {
        let consumer_groups = self.consumer_groups.read().await;
        consumer_groups.get_committed_offsets(group_id)
    }

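    /// Test helper: registers a topic under `name` with the given config.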
    pub async fn test_create_topic(&self, name: &str, config: crate::topics::TopicConfig) {
        let topic = Topic::new(name.to_string(), config);
        let mut topics = self.topics.write().await;
        topics.insert(name.to_string(), topic);
    }

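    /// Test helper: joins `member_id` (with the given client id) to the
    /// consumer group `group_id`.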
    pub async fn test_join_group(
        &self,
        group_id: &str,
        member_id: &str,
        client_id: &str,
    ) -> Result<()> {
        let mut consumer_groups = self.consumer_groups.write().await;
        consumer_groups
            .join_group(group_id, member_id, client_id)
            .await
            .map_err(|e| mockforge_core::Error::from(e.to_string()))?;
        Ok(())
    }

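    /// Test helper: applies partition `assignments` to the group, as a
    /// SyncGroup round would.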
    pub async fn test_sync_group(
        &self,
        group_id: &str,
        assignments: Vec<crate::consumer_groups::PartitionAssignment>,
    ) -> Result<()> {
        let topics = self.topics.read().await;
        let mut consumer_groups = self.consumer_groups.write().await;
        consumer_groups
            .sync_group(group_id, assignments, &topics)
            .await
            .map_err(|e| mockforge_core::Error::from(e.to_string()))?;
        Ok(())
    }

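    /// Test helper: returns the partition assignment for one group member,
    /// or an empty list if the group or member is unknown.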
    pub async fn test_get_assignments(
        &self,
        group_id: &str,
        member_id: &str,
    ) -> Vec<crate::consumer_groups::PartitionAssignment> {
        let consumer_groups = self.consumer_groups.read().await;
        if let Some(group) = consumer_groups.groups().get(group_id) {
            if let Some(member) = group.members.get(member_id) {
                return member.assignment.clone();
            }
        }
        vec![]
    }

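    /// Test helper: adjusts the group's state through the consumer-group
    /// manager to simulate `lag` messages of consumer lag on `topic`.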
    pub async fn test_simulate_lag(&self, group_id: &str, topic: &str, lag: i64) -> Result<()> {
        let topics = self.topics.read().await;
        let mut consumer_groups = self.consumer_groups.write().await;
        consumer_groups.simulate_lag(group_id, topic, lag, &topics).await;
        Ok(())
    }

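    /// Test helper: resets the group's offsets on `topic` to the position
    /// named by `to`.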
    pub async fn test_reset_offsets(&self, group_id: &str, topic: &str, to: &str) -> Result<()> {
        let topics = self.topics.read().await;
        let mut consumer_groups = self.consumer_groups.write().await;
        consumer_groups.reset_offsets(group_id, topic, to, &topics).await;
        Ok(())
    }

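    /// Returns a handle to the broker's metrics collector.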
    pub fn metrics(&self) -> &Arc<KafkaMetrics> {
        &self.metrics
    }
}

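/// A single Kafka record: optional key, value bytes, and header pairs.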
#[derive(Debug, Clone)]
pub struct Record {
    pub key: Option<Vec<u8>>,
    pub value: Vec<u8>,
    pub headers: Vec<(String, Vec<u8>)>,
}

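/// Per-partition result of a produce request (`error_code` 0 means success).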
#[derive(Debug)]
pub struct ProduceResponse {
    pub partition: i32,
    pub error_code: i16,
    pub offset: i64,
}

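/// Per-partition result of a fetch request, including the high watermark and
/// any records returned.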
#[derive(Debug)]
pub struct FetchResponse {
    pub partition: i32,
    pub error_code: i16,
    pub high_watermark: i64,
    pub records: Vec<Record>,
}

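/// RAII guard that records a closed connection when a connection task ends,
/// including early exits on error.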
struct ConnectionGuard {
    metrics: Arc<KafkaMetrics>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.metrics.record_connection_closed();
    }
}

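/// Extracts the Kafka API key from a parsed request, for metrics labeling.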
fn get_api_key_from_request(request: &KafkaRequest) -> i16 {
    request.api_key
}

#[cfg(test)]
mod tests {
    use super::*;

    // --- Record tests ---

    #[test]
    fn test_record_creation_with_all_fields() {
        let record = Record {
            key: Some(b"test-key".to_vec()),
            value: b"test-value".to_vec(),
            headers: vec![("header1".to_string(), b"value1".to_vec())],
        };

        assert_eq!(record.key, Some(b"test-key".to_vec()));
        assert_eq!(record.value, b"test-value".to_vec());
        assert_eq!(record.headers.len(), 1);
        assert_eq!(record.headers[0].0, "header1");
    }

    #[test]
    fn test_record_creation_without_key() {
        let record = Record {
            key: None,
            value: b"message body".to_vec(),
            headers: vec![],
        };

        assert!(record.key.is_none());
        assert_eq!(record.value, b"message body".to_vec());
        assert!(record.headers.is_empty());
    }

    #[test]
    fn test_record_with_multiple_headers() {
        let record = Record {
            key: Some(b"key".to_vec()),
            value: b"value".to_vec(),
            headers: vec![
                ("content-type".to_string(), b"application/json".to_vec()),
                ("correlation-id".to_string(), b"12345".to_vec()),
                ("source".to_string(), b"test-producer".to_vec()),
            ],
        };

        assert_eq!(record.headers.len(), 3);
        assert_eq!(record.headers[0].0, "content-type");
        assert_eq!(record.headers[1].0, "correlation-id");
        assert_eq!(record.headers[2].0, "source");
    }

    #[test]
    fn test_record_clone() {
        let original = Record {
            key: Some(b"key".to_vec()),
            value: b"value".to_vec(),
            headers: vec![("h".to_string(), b"v".to_vec())],
        };

        let cloned = original.clone();

        assert_eq!(original.key, cloned.key);
        assert_eq!(original.value, cloned.value);
        assert_eq!(original.headers, cloned.headers);
    }

    #[test]
    fn test_record_debug() {
        let record = Record {
            key: Some(b"key".to_vec()),
            value: b"value".to_vec(),
            headers: vec![],
        };

        let debug_str = format!("{:?}", record);
        assert!(debug_str.contains("Record"));
        assert!(debug_str.contains("key"));
        assert!(debug_str.contains("value"));
    }

    #[test]
    fn test_record_empty_value() {
        let record = Record {
            key: None,
            value: vec![],
            headers: vec![],
        };

        assert!(record.key.is_none());
        assert!(record.value.is_empty());
        assert!(record.headers.is_empty());
    }

    #[test]
    fn test_record_binary_data() {
        // Arbitrary non-UTF-8 bytes; records carry raw binary payloads.
        let binary_data: Vec<u8> = vec![0x00, 0xFF, 0x80, 0x7F, 0xFE];
        let record = Record {
            key: Some(binary_data.clone()),
            value: binary_data.clone(),
            headers: vec![],
        };

        assert_eq!(record.key.as_ref().unwrap().len(), 5);
        assert_eq!(record.value.len(), 5);
        assert_eq!(record.value[0], 0x00);
        assert_eq!(record.value[1], 0xFF);
    }

    // --- ProduceResponse tests ---

    #[test]
    fn test_produce_response_success() {
        let response = ProduceResponse {
            partition: 0,
            error_code: 0,
            offset: 100,
        };

        assert_eq!(response.partition, 0);
        assert_eq!(response.error_code, 0);
        assert_eq!(response.offset, 100);
    }

    #[test]
    fn test_produce_response_with_error() {
        let response = ProduceResponse {
            partition: 1,
            error_code: 3, // UNKNOWN_TOPIC_OR_PARTITION
            offset: -1,
        };

        assert_eq!(response.partition, 1);
        assert_eq!(response.error_code, 3);
        assert_eq!(response.offset, -1);
    }

    #[test]
    fn test_produce_response_high_offset() {
        let response = ProduceResponse {
            partition: 5,
            error_code: 0,
            offset: i64::MAX,
        };

        assert_eq!(response.partition, 5);
        assert_eq!(response.offset, i64::MAX);
    }

    #[test]
    fn test_produce_response_debug() {
        let response = ProduceResponse {
            partition: 0,
            error_code: 0,
            offset: 42,
        };

        let debug_str = format!("{:?}", response);
        assert!(debug_str.contains("ProduceResponse"));
        assert!(debug_str.contains("partition"));
        assert!(debug_str.contains("error_code"));
        assert!(debug_str.contains("offset"));
    }

    // --- FetchResponse tests ---

    #[test]
    fn test_fetch_response_empty() {
        let response = FetchResponse {
            partition: 0,
            error_code: 0,
            high_watermark: 100,
            records: vec![],
        };

        assert_eq!(response.partition, 0);
        assert_eq!(response.error_code, 0);
        assert_eq!(response.high_watermark, 100);
        assert!(response.records.is_empty());
    }

    #[test]
    fn test_fetch_response_with_records() {
        let records = vec![
            Record {
                key: Some(b"key1".to_vec()),
                value: b"value1".to_vec(),
                headers: vec![],
            },
            Record {
                key: Some(b"key2".to_vec()),
                value: b"value2".to_vec(),
                headers: vec![],
            },
        ];

        let response = FetchResponse {
            partition: 0,
            error_code: 0,
            high_watermark: 50,
            records,
        };

        assert_eq!(response.records.len(), 2);
        assert_eq!(response.records[0].key, Some(b"key1".to_vec()));
        assert_eq!(response.records[1].value, b"value2".to_vec());
    }

    #[test]
    fn test_fetch_response_with_error() {
        let response = FetchResponse {
            partition: 0,
            error_code: 1, // OFFSET_OUT_OF_RANGE
            high_watermark: 0,
            records: vec![],
        };

        assert_eq!(response.error_code, 1);
        assert_eq!(response.high_watermark, 0);
    }

    #[test]
    fn test_fetch_response_debug() {
        let response = FetchResponse {
            partition: 2,
            error_code: 0,
            high_watermark: 1000,
            records: vec![],
        };

        let debug_str = format!("{:?}", response);
        assert!(debug_str.contains("FetchResponse"));
        assert!(debug_str.contains("high_watermark"));
    }

    // --- API key extraction tests ---

    #[test]
    fn test_get_api_key_produce() {
        let request = KafkaRequest {
            api_key: 0, // Produce
            api_version: 7,
            correlation_id: 1,
            client_id: "test-client".to_string(),
            request_type: KafkaRequestType::Produce,
        };

        assert_eq!(get_api_key_from_request(&request), 0);
    }

    #[test]
    fn test_get_api_key_fetch() {
        let request = KafkaRequest {
            api_key: 1, // Fetch
            api_version: 11,
            correlation_id: 2,
            client_id: "consumer".to_string(),
            request_type: KafkaRequestType::Fetch,
        };

        assert_eq!(get_api_key_from_request(&request), 1);
    }

    #[test]
    fn test_get_api_key_metadata() {
        let request = KafkaRequest {
            api_key: 3, // Metadata
            api_version: 9,
            correlation_id: 3,
            client_id: "admin".to_string(),
            request_type: KafkaRequestType::Metadata,
        };

        assert_eq!(get_api_key_from_request(&request), 3);
    }

    #[test]
    fn test_get_api_key_api_versions() {
        let request = KafkaRequest {
            api_key: 18, // ApiVersions
            api_version: 3,
            correlation_id: 100,
            client_id: "client".to_string(),
            request_type: KafkaRequestType::ApiVersions,
        };

        assert_eq!(get_api_key_from_request(&request), 18);
    }

    #[test]
    fn test_get_api_key_list_groups() {
        let request = KafkaRequest {
            api_key: 16, // ListGroups
            api_version: 4,
            correlation_id: 5,
            client_id: "admin-client".to_string(),
            request_type: KafkaRequestType::ListGroups,
        };

        assert_eq!(get_api_key_from_request(&request), 16);
    }

    #[test]
    fn test_get_api_key_create_topics() {
        let request = KafkaRequest {
            api_key: 19, // CreateTopics
            api_version: 5,
            correlation_id: 10,
            client_id: "admin".to_string(),
            request_type: KafkaRequestType::CreateTopics,
        };

        assert_eq!(get_api_key_from_request(&request), 19);
    }
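
    // A hedged extension of the api-key tests above: 20 and 32 are the
    // standard Kafka protocol API keys for DeleteTopics and DescribeConfigs.
    // These tests exercise only `get_api_key_from_request`; they assume
    // nothing about how `parse_request` assigns the field.
    #[test]
    fn test_get_api_key_delete_topics() {
        let request = KafkaRequest {
            api_key: 20, // DeleteTopics
            api_version: 4,
            correlation_id: 11,
            client_id: "admin".to_string(),
            request_type: KafkaRequestType::DeleteTopics,
        };

        assert_eq!(get_api_key_from_request(&request), 20);
    }

    #[test]
    fn test_get_api_key_describe_configs() {
        let request = KafkaRequest {
            api_key: 32, // DescribeConfigs
            api_version: 2,
            correlation_id: 12,
            client_id: "admin".to_string(),
            request_type: KafkaRequestType::DescribeConfigs,
        };

        assert_eq!(get_api_key_from_request(&request), 32);
    }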

    // --- KafkaRequest construction tests ---

    #[test]
    fn test_kafka_request_fields() {
        let request = KafkaRequest {
            api_key: 0,
            api_version: 8,
            correlation_id: 12345,
            client_id: "my-producer".to_string(),
            request_type: KafkaRequestType::Produce,
        };

        assert_eq!(request.api_key, 0);
        assert_eq!(request.api_version, 8);
        assert_eq!(request.correlation_id, 12345);
        assert_eq!(request.client_id, "my-producer");
    }

    #[test]
    fn test_kafka_request_empty_client_id() {
        let request = KafkaRequest {
            api_key: 3,
            api_version: 9,
            correlation_id: 1,
            client_id: String::new(),
            request_type: KafkaRequestType::Metadata,
        };

        assert!(request.client_id.is_empty());
    }

    #[test]
    fn test_kafka_request_max_correlation_id() {
        let request = KafkaRequest {
            api_key: 0,
            api_version: 0,
            correlation_id: i32::MAX,
            client_id: "test".to_string(),
            request_type: KafkaRequestType::Produce,
        };

        assert_eq!(request.correlation_id, i32::MAX);
    }

    // --- Request type variant tests ---

    #[test]
    fn test_request_type_variants() {
        let metadata = KafkaRequestType::Metadata;
        let produce = KafkaRequestType::Produce;
        let fetch = KafkaRequestType::Fetch;
        let list_groups = KafkaRequestType::ListGroups;
        let describe_groups = KafkaRequestType::DescribeGroups;
        let api_versions = KafkaRequestType::ApiVersions;
        let create_topics = KafkaRequestType::CreateTopics;
        let delete_topics = KafkaRequestType::DeleteTopics;
        let describe_configs = KafkaRequestType::DescribeConfigs;

        assert!(matches!(metadata, KafkaRequestType::Metadata));
        assert!(matches!(produce, KafkaRequestType::Produce));
        assert!(matches!(fetch, KafkaRequestType::Fetch));
        assert!(matches!(list_groups, KafkaRequestType::ListGroups));
        assert!(matches!(describe_groups, KafkaRequestType::DescribeGroups));
        assert!(matches!(api_versions, KafkaRequestType::ApiVersions));
        assert!(matches!(create_topics, KafkaRequestType::CreateTopics));
        assert!(matches!(delete_topics, KafkaRequestType::DeleteTopics));
        assert!(matches!(describe_configs, KafkaRequestType::DescribeConfigs));
    }

    // --- Message size limit tests ---

    #[test]
    fn test_message_size_limit_constant() {
        let max_message_size: usize = 10 * 1024 * 1024;
        assert_eq!(max_message_size, 10_485_760);
    }

    #[test]
    fn test_message_size_under_limit() {
        let message_size: usize = 1024 * 1024; // 1 MiB
        let limit: usize = 10 * 1024 * 1024; // 10 MiB
        assert!(message_size <= limit);
    }

    #[test]
    fn test_message_size_over_limit() {
        let message_size: usize = 11 * 1024 * 1024; // just over the limit
        let limit: usize = 10 * 1024 * 1024;
        assert!(message_size > limit);
    }

    // --- Response size framing tests ---

    #[test]
    fn test_response_size_serialization() {
        let response_len: i32 = 1000;
        let size_bytes = response_len.to_be_bytes();

        assert_eq!(size_bytes.len(), 4);
        assert_eq!(i32::from_be_bytes(size_bytes), 1000);
    }

    #[test]
    fn test_response_size_max_value() {
        let response_len: i32 = i32::MAX;
        let size_bytes = response_len.to_be_bytes();

        assert_eq!(size_bytes.len(), 4);
        assert_eq!(i32::from_be_bytes(size_bytes), i32::MAX);
    }

    #[test]
    fn test_response_size_zero() {
        let response_len: i32 = 0;
        let size_bytes = response_len.to_be_bytes();

        assert_eq!(size_bytes, [0, 0, 0, 0]);
    }
}