use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::RwLock;
use crate::consumer_groups::ConsumerGroupManager;
use crate::fixture_executor::FixtureRuntime;
use crate::metrics::KafkaMetrics;
use crate::partitions::KafkaMessage;
use crate::protocol::{KafkaProtocolHandler, KafkaRequest, KafkaRequestType, KafkaResponse};
use crate::spec_registry::KafkaSpecRegistry;
use crate::topics::Topic;
use mockforge_core::config::KafkaConfig;
use mockforge_core::Result;
use std::sync::OnceLock;
/// Returns the `(host, port)` pair clients should be told to connect to.
///
/// `advertised_host` and `advertised_port` override the bind address
/// independently: whichever override is unset falls back to `host` / `port`.
/// This lets the broker bind to one address (e.g. `0.0.0.0`) while telling
/// clients to reach it through another.
pub(crate) fn resolve_advertised_endpoint(config: &KafkaConfig) -> (String, i32) {
    let advertised_host = match &config.advertised_host {
        Some(override_host) => override_host.clone(),
        None => config.host.clone(),
    };
    let advertised_port = match config.advertised_port {
        Some(override_port) => override_port as i32,
        None => config.port as i32,
    };
    (advertised_host, advertised_port)
}
/// In-process mock of a single Kafka broker.
///
/// Every piece of mutable state sits behind an `Arc`, so `clone()` produces a
/// handle that shares topics, group state, the spec registry, the fixture
/// runtime slot and metrics with the original; only `config` is copied.
#[derive(Clone)]
// NOTE(review): presumably silences warnings for fields not read in every
// build configuration — confirm and narrow to the affected field(s).
#[allow(dead_code)]
pub struct KafkaMockBroker {
/// Broker configuration (bind address, optional advertised endpoint, ...).
config: KafkaConfig,
/// Topic name → topic. The same `Arc` is handed to the spec registry in `new`.
pub topics: Arc<RwLock<HashMap<String, Topic>>>,
/// Consumer-group state; in this file it is only used by the `test_*` helpers.
pub consumer_groups: Arc<RwLock<ConsumerGroupManager>>,
/// Group state used by the wire-level JoinGroup, SyncGroup, Heartbeat,
/// LeaveGroup, OffsetCommit and OffsetFetch handlers.
pub group_coordinator: Arc<RwLock<crate::group_coordinator::GroupCoordinator>>,
/// Registry built in `new`; supplies fixtures, relationships, state machines
/// and scenarios to `install_fixture_runtime`.
spec_registry: Arc<KafkaSpecRegistry>,
/// Filled at most once by `install_fixture_runtime`.
fixture_runtime: Arc<OnceLock<Arc<FixtureRuntime>>>,
/// Connection, request, latency and error counters.
metrics: Arc<KafkaMetrics>,
}
impl KafkaMockBroker {
/// Builds a broker from `config` without binding any socket.
///
/// Creates empty topic and group state and builds the [`KafkaSpecRegistry`],
/// which receives a handle to the shared topic map. The fixture runtime slot
/// is left empty until [`Self::install_fixture_runtime`] runs.
///
/// # Errors
///
/// Propagates any error returned by `KafkaSpecRegistry::new`.
pub async fn new(config: KafkaConfig) -> Result<Self> {
    let shared_topics = Arc::new(RwLock::new(HashMap::new()));
    let group_manager = Arc::new(RwLock::new(ConsumerGroupManager::new()));
    let registry = KafkaSpecRegistry::new(config.clone(), Arc::clone(&shared_topics)).await?;
    let counters = Arc::new(KafkaMetrics::new());
    let coordinator = Arc::new(RwLock::new(
        crate::group_coordinator::GroupCoordinator::new(),
    ));
    let broker = Self {
        config,
        topics: shared_topics,
        consumer_groups: group_manager,
        group_coordinator: coordinator,
        spec_registry: Arc::new(registry),
        fixture_runtime: Arc::new(OnceLock::new()),
        metrics: counters,
    };
    Ok(broker)
}
pub async fn start(&self) -> Result<()> {
let addr = format!("{}:{}", self.config.host, self.config.port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!("Starting Kafka mock broker on {}", addr);
self.install_fixture_runtime().await;
loop {
let (socket, _) = listener.accept().await?;
let broker = Arc::new(self.clone());
tokio::spawn(async move {
if let Err(e) = broker.handle_connection(socket).await {
tracing::error!("Error handling connection: {}", e);
}
});
}
}
/// Builds a fixture runtime from the spec registry's fixtures,
/// relationships, state machines and scenarios, then stores it.
///
/// Returns immediately if a runtime is already stored. The emptiness check
/// and the final `OnceLock::set` are not one atomic step. If two calls race,
/// both run `fixture_executor::install`, only the first `set` wins, and the
/// other runtime is dropped.
pub async fn install_fixture_runtime(&self) {
    if self.fixture_runtime.get().is_none() {
        let registry = &self.spec_registry;
        let fixture_defs = registry.all_fixtures().to_vec();
        let relationship_defs = registry.relationships().to_vec();
        let machine_defs = registry.state_machines().to_vec();
        let scenario_defs = registry.scenarios().to_vec();
        let broker_handle = Arc::new(self.clone());
        let installed = crate::fixture_executor::install(
            Arc::clone(&broker_handle),
            &fixture_defs,
            &machine_defs,
            &scenario_defs,
            &relationship_defs,
        )
        .await;
        // Losing a race with a concurrent installer is harmless; ignore it.
        let _ = self.fixture_runtime.set(installed);
    }
}
/// Returns a shared handle to the installed fixture runtime, or `None` if
/// [`Self::install_fixture_runtime`] has not stored one yet.
pub fn fixture_runtime(&self) -> Option<Arc<FixtureRuntime>> {
    self.fixture_runtime.get().map(Arc::clone)
}
async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
let topics: Vec<crate::protocol::TopicMetadata> = {
let guard = self.topics.read().await;
guard
.iter()
.map(|(name, topic)| crate::protocol::TopicMetadata {
name: name.clone(),
partitions: (topic.partitions.len() as i32).max(1),
})
.collect()
};
let (advertised_host, advertised_port) = resolve_advertised_endpoint(&self.config);
let protocol_handler =
KafkaProtocolHandler::with_topology(advertised_host, advertised_port, topics);
self.metrics.record_connection();
let _guard = ConnectionGuard {
metrics: Arc::clone(&self.metrics),
};
loop {
let mut size_buf = [0u8; 4];
match tokio::time::timeout(
std::time::Duration::from_secs(30),
socket.read_exact(&mut size_buf),
)
.await
{
Ok(Ok(_)) => {
let message_size = i32::from_be_bytes(size_buf) as usize;
if message_size > 10 * 1024 * 1024 {
self.metrics.record_error();
tracing::warn!("Message size too large: {} bytes", message_size);
continue;
}
let mut message_buf = vec![0u8; message_size];
if let Err(e) = tokio::time::timeout(
std::time::Duration::from_secs(10),
socket.read_exact(&mut message_buf),
)
.await
{
self.metrics.record_error();
tracing::error!("Timeout reading message: {}", e);
break;
}
let request = match protocol_handler.parse_request(&message_buf) {
Ok(req) => req,
Err(e) => {
self.metrics.record_error();
tracing::error!("Failed to parse request: {}", e);
continue;
}
};
let correlation_id = request.correlation_id;
let request_api_version = request.api_version;
self.metrics.record_request(get_api_key_from_request(&request));
let start_time = std::time::Instant::now();
let response = match self.handle_request(&message_buf, request).await {
Ok(resp) => resp,
Err(e) => {
self.metrics.record_error();
tracing::error!("Failed to handle request: {}", e);
continue;
}
};
let latency = start_time.elapsed().as_micros() as u64;
self.metrics.record_request_latency(latency);
self.metrics.record_response();
let response_data = match protocol_handler.serialize_response(
&response,
correlation_id,
request_api_version,
) {
Ok(data) => data,
Err(e) => {
self.metrics.record_error();
tracing::error!("Failed to serialize response: {}", e);
continue;
}
};
let response_size = (response_data.len() as i32).to_be_bytes();
if let Err(e) = socket.write_all(&response_size).await {
self.metrics.record_error();
tracing::error!("Failed to write response size: {}", e);
break;
}
if let Err(e) = socket.write_all(&response_data).await {
self.metrics.record_error();
tracing::error!("Failed to write response: {}", e);
break;
}
}
Ok(Err(e)) => {
self.metrics.record_error();
tracing::error!("Failed to read message size: {}", e);
break;
}
Err(_) => {
continue;
}
}
}
Ok(())
}
/// Routes a parsed request to the handler for its API.
///
/// `message_buf` is the raw request frame. Handlers that decode their own
/// body slice it from `request.body_offset`; the remaining handlers ignore
/// both arguments.
async fn handle_request(
    &self,
    message_buf: &[u8],
    request: KafkaRequest,
) -> Result<KafkaResponse> {
    let frame = message_buf;
    let req = &request;
    match req.request_type {
        // APIs answered without looking at the request body.
        KafkaRequestType::Metadata => self.handle_metadata().await,
        KafkaRequestType::ApiVersions => self.handle_api_versions().await,
        KafkaRequestType::ListGroups => self.handle_list_groups().await,
        KafkaRequestType::DescribeGroups => self.handle_describe_groups().await,
        KafkaRequestType::CreateTopics => self.handle_create_topics().await,
        KafkaRequestType::DeleteTopics => self.handle_delete_topics().await,
        KafkaRequestType::DescribeConfigs => self.handle_describe_configs().await,
        // Data-plane APIs.
        KafkaRequestType::Produce => self.handle_produce(frame, req).await,
        KafkaRequestType::Fetch => self.handle_fetch(frame, req).await,
        KafkaRequestType::ListOffsets => self.handle_list_offsets(frame, req).await,
        // Consumer-group APIs.
        KafkaRequestType::FindCoordinator => self.handle_find_coordinator(frame, req).await,
        KafkaRequestType::JoinGroup => self.handle_join_group(frame, req).await,
        KafkaRequestType::SyncGroup => self.handle_sync_group(frame, req).await,
        KafkaRequestType::Heartbeat => self.handle_heartbeat(frame, req).await,
        KafkaRequestType::LeaveGroup => self.handle_leave_group(frame, req).await,
        KafkaRequestType::OffsetCommit => self.handle_offset_commit(frame, req).await,
        KafkaRequestType::OffsetFetch => self.handle_offset_fetch(frame, req).await,
    }
}
/// Answers Metadata with the `KafkaResponse::Metadata` marker. The wire
/// bytes come from `KafkaProtocolHandler::serialize_response`, which
/// `handle_connection` calls on every response.
async fn handle_metadata(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::Metadata;
    Ok(response)
}
/// Handles Produce v3–v9 by decoding the record sets and appending them to
/// the in-memory topics.
///
/// v9 uses the flexible codec and v3–v8 the non-flexible one. Any other
/// version receives an empty v9-shaped response and a logged warning. A topic
/// named in the request is created with `TopicConfig::default()` if it does
/// not exist yet.
///
/// Each partition in the request is handled as follows:
/// * no records → success with `base_offset` -1;
/// * partition index not present in the topic → `UNKNOWN_TOPIC_OR_PARTITION` (3);
/// * otherwise every record is appended, and `base_offset` is the offset
///   returned for the first one.
///
/// After each topic is processed, the records accepted for it are passed to
/// the fixture runtime via `fixture_executor::on_produced_records`, if a
/// runtime is installed.
///
/// # Errors
///
/// Returns an error if `request.body_offset` lies past the end of
/// `message_buf`, if the body fails to parse, or if `Topic::produce` fails.
/// In the last case, records appended before the failure remain appended.
async fn handle_produce(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::produce_codec::{
parse_produce_v9, serialize_produce_v9_response, PartitionProduceResult,
TopicProduceResult,
};
use crate::produce_nonflex::{parse_produce_v3_v8, serialize_produce_v3_v8_response};
const ERR_UNKNOWN_TOPIC_OR_PARTITION: i16 = 3;
let version = request.api_version;
let is_flexible = version == 9;
let is_nonflex = (3..=8).contains(&version);
if !is_flexible && !is_nonflex {
let body = serialize_produce_v9_response(request.correlation_id, &[]);
tracing::warn!("rejecting Produce v{version} (supported: 3..=9)");
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("produce request body_offset past end of buffer")
})?;
let parsed = if is_flexible {
parse_produce_v9(body_slice).map_err(|e| {
mockforge_core::Error::internal(format!("failed to parse Produce v9: {e}"))
})?
} else {
parse_produce_v3_v8(body_slice).map_err(|e| {
mockforge_core::Error::internal(format!("failed to parse Produce v{version}: {e}"))
})?
};
// A single wall-clock timestamp is reported as the append time for the whole request.
let append_time_ms = chrono::Utc::now().timestamp_millis();
let mut topic_results = Vec::with_capacity(parsed.topics.len());
for topic_data in parsed.topics {
let mut partition_results = Vec::with_capacity(topic_data.partitions.len());
// Copies of appended records, later handed to the fixture runtime.
let mut accepted_for_relationships: Vec<KafkaMessage> = Vec::new();
let topic_name = topic_data.name.clone();
for part in topic_data.partitions {
// The write lock is taken per partition and released at the end of each
// iteration (including the `continue` paths), so it is not held while the
// fixture hook below runs.
let mut topics_guard = self.topics.write().await;
// Auto-create the topic; this happens even for partitions with no records.
let topic_entry =
topics_guard.entry(topic_data.name.clone()).or_insert_with(|| {
Topic::new(topic_data.name.clone(), crate::topics::TopicConfig::default())
});
if part.records.is_empty() {
partition_results.push(PartitionProduceResult {
partition_index: part.partition_index,
error_code: 0,
base_offset: -1,
log_append_time_ms: append_time_ms,
log_start_offset: 0,
});
continue;
}
if topic_entry.get_partition(part.partition_index).is_none() {
partition_results.push(PartitionProduceResult {
partition_index: part.partition_index,
error_code: ERR_UNKNOWN_TOPIC_OR_PARTITION,
base_offset: -1,
log_append_time_ms: -1,
log_start_offset: 0,
});
continue;
}
let mut base_offset: i64 = -1;
for (i, rec) in part.records.into_iter().enumerate() {
// `offset: 0` is a placeholder. Presumably `Topic::produce` assigns the
// real offset (confirm); the copy kept for fixtures still holds 0.
let msg = KafkaMessage {
offset: 0, timestamp: rec.timestamp_ms,
key: rec.key,
value: rec.value,
headers: rec.headers,
};
accepted_for_relationships.push(msg.clone());
let offset = topic_entry.produce(part.partition_index, msg).await?;
if i == 0 {
base_offset = offset;
}
}
partition_results.push(PartitionProduceResult {
partition_index: part.partition_index,
error_code: 0,
base_offset,
log_append_time_ms: append_time_ms,
log_start_offset: 0,
});
}
topic_results.push(TopicProduceResult {
name: topic_data.name,
partitions: partition_results,
});
// Notify fixtures once per topic, after the topic lock has been released.
if !accepted_for_relationships.is_empty() {
if let Some(runtime) = self.fixture_runtime() {
let broker_arc = Arc::new(self.clone());
crate::fixture_executor::on_produced_records(
&broker_arc,
&runtime,
&topic_name,
&accepted_for_relationships,
)
.await;
}
}
}
let body = if is_flexible {
serialize_produce_v9_response(request.correlation_id, &topic_results)
} else {
serialize_produce_v3_v8_response(request.correlation_id, version, &topic_results)
};
Ok(KafkaResponse::Preserialized(body))
}
/// Handles Fetch v4–v12 by reading records directly from the in-memory
/// partitions.
///
/// v12 uses the flexible codec and v4–v11 the non-flexible one. Any other
/// version receives an empty v12-shaped response and a logged warning.
///
/// Each requested partition is handled as follows:
/// * unknown topic or partition → `UNKNOWN_TOPIC_OR_PARTITION` (3);
/// * `fetch_offset` above the high watermark → `OFFSET_OUT_OF_RANGE` (1);
/// * otherwise, records with `offset >= fetch_offset` are encoded as one
///   v2 record batch, limited by an approximate `partition_max_bytes` budget.
///
/// The topics read lock is held while all partition responses are built.
/// `parsed.session_id` is passed through to the response serializer.
///
/// # Errors
///
/// Returns an error if `request.body_offset` lies past the end of
/// `message_buf` or the request body fails to parse.
async fn handle_fetch(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::fetch_codec::{
parse_fetch_v12, serialize_fetch_v12_response, serialize_record_batch_v2,
FetchPartitionResponse, FetchTopicResponse,
};
use crate::fetch_nonflex::{parse_fetch_v4_v11, serialize_fetch_v4_v11_response};
const ERR_UNKNOWN_TOPIC_OR_PARTITION: i16 = 3;
const ERR_OFFSET_OUT_OF_RANGE: i16 = 1;
let version = request.api_version;
let is_flexible = version == 12;
let is_nonflex = (4..=11).contains(&version);
if !is_flexible && !is_nonflex {
let body = serialize_fetch_v12_response(request.correlation_id, 0, &[]);
tracing::warn!("rejecting Fetch v{version} (supported: 4..=12)");
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("fetch request body_offset past end of buffer")
})?;
let parsed = if is_flexible {
parse_fetch_v12(body_slice).map_err(|e| {
mockforge_core::Error::internal(format!("failed to parse Fetch v12: {e}"))
})?
} else {
parse_fetch_v4_v11(version, body_slice).map_err(|e| {
mockforge_core::Error::internal(format!("failed to parse Fetch v{version}: {e}"))
})?
};
let topics_guard = self.topics.read().await;
let mut topic_responses = Vec::with_capacity(parsed.topics.len());
for t in &parsed.topics {
let mut partition_responses = Vec::with_capacity(t.partitions.len());
let topic = topics_guard.get(&t.topic);
for p in &t.partitions {
let Some(topic) = topic else {
partition_responses.push(FetchPartitionResponse {
partition_index: p.partition_index,
error_code: ERR_UNKNOWN_TOPIC_OR_PARTITION,
high_watermark: -1,
log_start_offset: -1,
records: Vec::new(),
});
continue;
};
let Some(part) = topic.get_partition(p.partition_index) else {
partition_responses.push(FetchPartitionResponse {
partition_index: p.partition_index,
error_code: ERR_UNKNOWN_TOPIC_OR_PARTITION,
high_watermark: -1,
log_start_offset: -1,
records: Vec::new(),
});
continue;
};
// Only the upper bound is checked. An offset below `log_start_offset` falls
// through and is served from the first retained message.
if p.fetch_offset > part.high_watermark {
partition_responses.push(FetchPartitionResponse {
partition_index: p.partition_index,
error_code: ERR_OFFSET_OUT_OF_RANGE,
high_watermark: part.high_watermark,
log_start_offset: part.log_start_offset,
records: Vec::new(),
});
continue;
}
let max_bytes = p.partition_max_bytes.max(0) as usize;
let mut selected: Vec<&crate::partitions::KafkaMessage> = Vec::new();
let mut estimated_size: usize = 0;
// Linear scan from the start of the partition on every fetch.
// NOTE(review): assumes messages are stored in ascending offset order,
// which keeps the budget `break` below contiguous — confirm.
for msg in &part.messages {
if msg.offset < p.fetch_offset {
continue;
}
// Rough per-record size (key + value + headers + fixed overhead), not the
// exact encoded size.
let headers_size: usize =
msg.headers.iter().map(|(k, v)| k.len() + v.len() + 8).sum();
let record_size = msg.key.as_ref().map_or(0, |k| k.len())
+ msg.value.len()
+ headers_size
+ 16;
// The first matching record is always included, so a record larger than
// `max_bytes` can still be consumed.
if !selected.is_empty() && estimated_size + record_size > max_bytes {
break;
}
estimated_size += record_size;
selected.push(msg);
}
let records_blob = if selected.is_empty() {
Vec::new()
} else {
serialize_record_batch_v2(&selected)
};
partition_responses.push(FetchPartitionResponse {
partition_index: p.partition_index,
error_code: 0,
high_watermark: part.high_watermark,
log_start_offset: part.log_start_offset,
records: records_blob,
});
}
topic_responses.push(FetchTopicResponse {
topic: t.topic.clone(),
partitions: partition_responses,
});
}
let body = if is_flexible {
serialize_fetch_v12_response(
request.correlation_id,
parsed.session_id,
&topic_responses,
)
} else {
serialize_fetch_v4_v11_response(
request.correlation_id,
version,
parsed.session_id,
&topic_responses,
)
};
Ok(KafkaResponse::Preserialized(body))
}
/// Handles ListOffsets v7 by resolving each requested
/// (topic, partition, timestamp) to an offset in the in-memory log.
///
/// Special timestamps:
/// * `-2` (earliest) → the partition's `log_start_offset`, timestamp `-1`;
/// * `-1` (latest) → the partition's `high_watermark`, timestamp `-1`;
/// * `-3` (max timestamp, the lookup ListOffsets v7 was introduced for) →
///   the earliest record carrying the largest timestamp, or offset `-1` /
///   timestamp `-1` when the partition is empty.
///
/// Any other value is a timestamp search. It returns the first record whose
/// timestamp is `>=` the target, or `high_watermark` with timestamp `-1` if
/// no such record exists. Unknown topics or partitions get
/// `UNKNOWN_TOPIC_OR_PARTITION` (3). Other request versions receive an empty
/// v7-shaped response and a logged warning.
///
/// # Errors
///
/// Returns an error if `request.body_offset` lies past the end of
/// `message_buf` or the v7 body fails to parse.
async fn handle_list_offsets(
    &self,
    message_buf: &[u8],
    request: &KafkaRequest,
) -> Result<KafkaResponse> {
    use crate::listoffsets_codec::{
        parse_listoffsets_v7, serialize_listoffsets_v7_response, ListOffsetsPartitionResponse,
        ListOffsetsTopicResponse,
    };
    const ERR_UNKNOWN_TOPIC_OR_PARTITION: i16 = 3;
    const TS_EARLIEST: i64 = -2;
    const TS_LATEST: i64 = -1;
    const TS_MAX_TIMESTAMP: i64 = -3;
    if request.api_version != 7 {
        let body = serialize_listoffsets_v7_response(request.correlation_id, &[]);
        tracing::warn!("rejecting ListOffsets v{} (only v7 supported)", request.api_version);
        return Ok(KafkaResponse::Preserialized(body));
    }
    let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
        mockforge_core::Error::internal("listoffsets body_offset past end of buffer")
    })?;
    let parsed = parse_listoffsets_v7(body_slice).map_err(|e| {
        mockforge_core::Error::internal(format!("failed to parse ListOffsets v7: {e}"))
    })?;
    let topics_guard = self.topics.read().await;
    let mut topic_responses = Vec::with_capacity(parsed.topics.len());
    for t in &parsed.topics {
        let mut partition_responses = Vec::with_capacity(t.partitions.len());
        let topic = topics_guard.get(&t.topic);
        for p in &t.partitions {
            let Some(topic) = topic else {
                partition_responses.push(ListOffsetsPartitionResponse {
                    partition_index: p.partition_index,
                    error_code: ERR_UNKNOWN_TOPIC_OR_PARTITION,
                    timestamp: -1,
                    offset: -1,
                });
                continue;
            };
            let Some(part) = topic.get_partition(p.partition_index) else {
                partition_responses.push(ListOffsetsPartitionResponse {
                    partition_index: p.partition_index,
                    error_code: ERR_UNKNOWN_TOPIC_OR_PARTITION,
                    timestamp: -1,
                    offset: -1,
                });
                continue;
            };
            let (offset, ts) = match p.timestamp {
                TS_EARLIEST => (part.log_start_offset, -1),
                TS_LATEST => (part.high_watermark, -1),
                // This arm must come before the generic search: treated as a
                // plain `>=` lookup, -3 would match the very first record.
                TS_MAX_TIMESTAMP => part
                    .messages
                    .iter()
                    // Strict `>` keeps the earliest record among equal maxima.
                    .reduce(|best, m| if m.timestamp > best.timestamp { m } else { best })
                    .map_or((-1, -1), |m| (m.offset, m.timestamp)),
                needle => {
                    let found = part.messages.iter().find(|m| m.timestamp >= needle);
                    match found {
                        Some(m) => (m.offset, m.timestamp),
                        None => (part.high_watermark, -1),
                    }
                }
            };
            partition_responses.push(ListOffsetsPartitionResponse {
                partition_index: p.partition_index,
                error_code: 0,
                timestamp: ts,
                offset,
            });
        }
        topic_responses.push(ListOffsetsTopicResponse {
            topic: t.topic.clone(),
            partitions: partition_responses,
        });
    }
    let body = serialize_listoffsets_v7_response(request.correlation_id, &topic_responses);
    Ok(KafkaResponse::Preserialized(body))
}
async fn handle_find_coordinator(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::group_codec::{
parse_find_coordinator_v2, serialize_find_coordinator_v2_response,
};
if request.api_version != 2 {
let body = serialize_find_coordinator_v2_response(
request.correlation_id,
&self.config.host,
self.config.port as i32,
);
tracing::warn!(
"rejecting FindCoordinator v{} (only v2 supported)",
request.api_version
);
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("find_coordinator body_offset past buffer end")
})?;
let _parsed = parse_find_coordinator_v2(body_slice).map_err(|e| {
mockforge_core::Error::internal(format!("FindCoordinator v2 parse: {e}"))
})?;
let body = serialize_find_coordinator_v2_response(
request.correlation_id,
&self.config.host,
self.config.port as i32,
);
Ok(KafkaResponse::Preserialized(body))
}
async fn handle_join_group(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::group_codec::{
parse_join_group_v5, serialize_join_group_v5_response, JoinGroupResponseMember,
};
if request.api_version != 5 {
let body =
serialize_join_group_v5_response(request.correlation_id, 0, "range", "", "", &[]);
tracing::warn!("rejecting JoinGroup v{} (only v5 supported)", request.api_version);
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("join_group body_offset past buffer end")
})?;
let parsed = parse_join_group_v5(body_slice)
.map_err(|e| mockforge_core::Error::internal(format!("JoinGroup v5 parse: {e}")))?;
let protocols: Vec<(String, Vec<u8>)> =
parsed.protocols.iter().map(|p| (p.name.clone(), p.metadata.clone())).collect();
let outcome = self.group_coordinator.write().await.join_group(
&parsed.group_id,
&parsed.member_id,
&protocols,
);
let members: Vec<JoinGroupResponseMember> = outcome
.members
.iter()
.map(|m| JoinGroupResponseMember {
member_id: m.member_id.clone(),
metadata: m.metadata.clone(),
})
.collect();
let body = serialize_join_group_v5_response(
request.correlation_id,
outcome.generation_id,
&outcome.protocol_name,
&outcome.leader_id,
&outcome.member_id,
&members,
);
Ok(KafkaResponse::Preserialized(body))
}
async fn handle_sync_group(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::group_codec::{parse_sync_group_v3, serialize_sync_group_v3_response};
if request.api_version != 3 {
let body = serialize_sync_group_v3_response(request.correlation_id, &[]);
tracing::warn!("rejecting SyncGroup v{} (only v3 supported)", request.api_version);
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("sync_group body_offset past buffer end")
})?;
let parsed = parse_sync_group_v3(body_slice)
.map_err(|e| mockforge_core::Error::internal(format!("SyncGroup v3 parse: {e}")))?;
let pairs: Vec<(String, Vec<u8>)> =
parsed.assignments.into_iter().map(|a| (a.member_id, a.assignment)).collect();
let assignment = self
.group_coordinator
.write()
.await
.sync_group(&parsed.group_id, &parsed.member_id, pairs)
.unwrap_or_default();
let body = serialize_sync_group_v3_response(request.correlation_id, &assignment);
Ok(KafkaResponse::Preserialized(body))
}
async fn handle_heartbeat(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::group_codec::{parse_heartbeat_v3, serialize_heartbeat_v3_response};
if request.api_version != 3 {
let body = serialize_heartbeat_v3_response(request.correlation_id, 0);
tracing::warn!("rejecting Heartbeat v{} (only v3 supported)", request.api_version);
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("heartbeat body_offset past buffer end")
})?;
let parsed = parse_heartbeat_v3(body_slice)
.map_err(|e| mockforge_core::Error::internal(format!("Heartbeat v3 parse: {e}")))?;
let err = self
.group_coordinator
.write()
.await
.heartbeat(&parsed.group_id, parsed.generation_id, &parsed.member_id)
.err()
.unwrap_or(0);
let body = serialize_heartbeat_v3_response(request.correlation_id, err);
Ok(KafkaResponse::Preserialized(body))
}
async fn handle_leave_group(
&self,
message_buf: &[u8],
request: &KafkaRequest,
) -> Result<KafkaResponse> {
use crate::group_codec::{parse_leave_group, serialize_leave_group_response};
if !(0..=3).contains(&request.api_version) {
let body = serialize_leave_group_response(3, request.correlation_id, &[]);
tracing::warn!(
"rejecting LeaveGroup v{} (only v0..=v3 supported)",
request.api_version
);
return Ok(KafkaResponse::Preserialized(body));
}
let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
mockforge_core::Error::internal("leave_group body_offset past buffer end")
})?;
let parsed = parse_leave_group(request.api_version, body_slice).map_err(|e| {
mockforge_core::Error::internal(format!(
"LeaveGroup v{} parse: {e}",
request.api_version
))
})?;
{
let mut coord = self.group_coordinator.write().await;
for m in &parsed.members {
coord.leave_group(&parsed.group_id, &m.member_id);
}
}
let body = serialize_leave_group_response(
request.api_version,
request.correlation_id,
&parsed.members,
);
Ok(KafkaResponse::Preserialized(body))
}
/// Handles OffsetCommit v7 by recording each (topic, partition) offset and
/// its metadata with the group coordinator.
///
/// The response is built from the request's topic list. Other versions
/// receive an empty v7-shaped response and a logged warning.
///
/// # Errors
///
/// Returns an error if `request.body_offset` lies past the end of
/// `message_buf` or the v7 body fails to parse.
async fn handle_offset_commit(
    &self,
    message_buf: &[u8],
    request: &KafkaRequest,
) -> Result<KafkaResponse> {
    use crate::group_codec::{parse_offset_commit_v7, serialize_offset_commit_v7_response};
    let correlation_id = request.correlation_id;
    if request.api_version != 7 {
        let body = serialize_offset_commit_v7_response(correlation_id, &[]);
        tracing::warn!("rejecting OffsetCommit v{} (only v7 supported)", request.api_version);
        return Ok(KafkaResponse::Preserialized(body));
    }
    let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
        mockforge_core::Error::internal("offset_commit body_offset past buffer end")
    })?;
    let commit = parse_offset_commit_v7(body_slice)
        .map_err(|e| mockforge_core::Error::internal(format!("OffsetCommit v7 parse: {e}")))?;

    let mut coordinator = self.group_coordinator.write().await;
    let entries = commit
        .topics
        .iter()
        .flat_map(|topic| topic.partitions.iter().map(move |partition| (topic, partition)));
    for (topic, partition) in entries {
        coordinator.commit_offset(
            &commit.group_id,
            &topic.name,
            partition.partition_index,
            partition.committed_offset,
            partition.committed_metadata.clone(),
        );
    }
    drop(coordinator);

    let body = serialize_offset_commit_v7_response(correlation_id, &commit.topics);
    Ok(KafkaResponse::Preserialized(body))
}
/// Handles OffsetFetch v5 by reporting the committed offset for every
/// requested (topic, partition) of a group.
///
/// Partitions with no recorded commit are reported with offset `-1` and no
/// metadata. Other versions receive an empty v5-shaped response and a logged
/// warning.
///
/// # Errors
///
/// Returns an error if `request.body_offset` lies past the end of
/// `message_buf` or the v5 body fails to parse.
async fn handle_offset_fetch(
    &self,
    message_buf: &[u8],
    request: &KafkaRequest,
) -> Result<KafkaResponse> {
    use crate::group_codec::{
        parse_offset_fetch_v5, serialize_offset_fetch_v5_response,
        OffsetFetchPartitionResponse, OffsetFetchTopicResponse,
    };
    let correlation_id = request.correlation_id;
    if request.api_version != 5 {
        let body = serialize_offset_fetch_v5_response(correlation_id, &[]);
        tracing::warn!("rejecting OffsetFetch v{} (only v5 supported)", request.api_version);
        return Ok(KafkaResponse::Preserialized(body));
    }
    let body_slice = message_buf.get(request.body_offset..).ok_or_else(|| {
        mockforge_core::Error::internal("offset_fetch body_offset past buffer end")
    })?;
    let query = parse_offset_fetch_v5(body_slice)
        .map_err(|e| mockforge_core::Error::internal(format!("OffsetFetch v5 parse: {e}")))?;

    let coordinator = self.group_coordinator.read().await;
    let mut topic_responses: Vec<OffsetFetchTopicResponse> = Vec::new();
    for topic in query.topics.iter() {
        let mut partitions = Vec::new();
        for &partition_index in topic.partition_indexes.iter() {
            let entry = if let Some(committed) =
                coordinator.fetch_offset(&query.group_id, &topic.name, partition_index)
            {
                OffsetFetchPartitionResponse {
                    partition_index,
                    committed_offset: committed.offset,
                    committed_metadata: committed.metadata,
                }
            } else {
                OffsetFetchPartitionResponse {
                    partition_index,
                    committed_offset: -1,
                    committed_metadata: None,
                }
            };
            partitions.push(entry);
        }
        topic_responses.push(OffsetFetchTopicResponse {
            name: topic.name.clone(),
            partitions,
        });
    }
    let body = serialize_offset_fetch_v5_response(correlation_id, &topic_responses);
    Ok(KafkaResponse::Preserialized(body))
}
/// Answers ApiVersions with the `KafkaResponse::ApiVersions` marker. The wire
/// bytes come from `KafkaProtocolHandler::serialize_response` in
/// `handle_connection`.
async fn handle_api_versions(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::ApiVersions;
    Ok(response)
}
/// Answers ListGroups with the `KafkaResponse::ListGroups` marker; group
/// state is not consulted here.
async fn handle_list_groups(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::ListGroups;
    Ok(response)
}
/// Answers DescribeGroups with the `KafkaResponse::DescribeGroups` marker;
/// group state is not consulted here.
async fn handle_describe_groups(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::DescribeGroups;
    Ok(response)
}
/// Handles CreateTopics by adding one topic with the default configuration.
///
/// The request body is not inspected. The first call creates
/// `default-topic`. Later calls create `topic-N`, starting at
/// `topic-{len + 1}` and counting upward until a name is free. This ensures
/// `HashMap::insert` never replaces an existing topic and its messages.
async fn handle_create_topics(&self) -> Result<KafkaResponse> {
    let mut topics = self.topics.write().await;
    let topic_name = if topics.contains_key("default-topic") {
        // `topic-{len + 1}` alone can collide: with topics
        // {"default-topic", "topic-3"} it would pick "topic-3" again.
        (topics.len() + 1..)
            .map(|n| format!("topic-{n}"))
            .find(|candidate| !topics.contains_key(candidate))
            .expect("an unbounded counter always reaches an unused topic name")
    } else {
        "default-topic".to_string()
    };
    let topic_config = crate::topics::TopicConfig::default();
    let topic = Topic::new(topic_name.clone(), topic_config);
    topics.insert(topic_name, topic);
    Ok(KafkaResponse::CreateTopics)
}
/// Answers DeleteTopics with the `KafkaResponse::DeleteTopics` marker; no
/// topic is removed from the broker's state.
async fn handle_delete_topics(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::DeleteTopics;
    Ok(response)
}
/// Answers DescribeConfigs with the `KafkaResponse::DescribeConfigs` marker.
async fn handle_describe_configs(&self) -> Result<KafkaResponse> {
    let response = KafkaResponse::DescribeConfigs;
    Ok(response)
}
pub async fn test_commit_offsets(
&self,
group_id: &str,
offsets: HashMap<(String, i32), i64>,
) -> Result<()> {
let mut consumer_groups = self.consumer_groups.write().await;
consumer_groups
.commit_offsets(group_id, offsets)
.await
.map_err(|e| mockforge_core::Error::from(e.to_string()))
}
/// Test helper: returns the offsets the consumer-group manager holds for
/// `group_id`, keyed by `(topic, partition)`.
pub async fn test_get_committed_offsets(&self, group_id: &str) -> HashMap<(String, i32), i64> {
    let groups = self.consumer_groups.read().await;
    groups.get_committed_offsets(group_id)
}
/// Test helper: creates (or replaces) the topic `name` with `config`.
pub async fn test_create_topic(&self, name: &str, config: crate::topics::TopicConfig) {
    use crate::topics::Topic;
    let owned_name = name.to_string();
    let fresh_topic = Topic::new(owned_name.clone(), config);
    self.topics.write().await.insert(owned_name, fresh_topic);
}
pub async fn test_join_group(
&self,
group_id: &str,
member_id: &str,
client_id: &str,
) -> Result<()> {
let mut consumer_groups = self.consumer_groups.write().await;
consumer_groups
.join_group(group_id, member_id, client_id)
.await
.map_err(|e| mockforge_core::Error::from(e.to_string()))?;
Ok(())
}
pub async fn test_sync_group(
&self,
group_id: &str,
assignments: Vec<crate::consumer_groups::PartitionAssignment>,
) -> Result<()> {
let topics = self.topics.read().await;
let mut consumer_groups = self.consumer_groups.write().await;
consumer_groups
.sync_group(group_id, assignments, &topics)
.await
.map_err(|e| mockforge_core::Error::from(e.to_string()))?;
Ok(())
}
/// Test helper: returns a copy of the partition assignment of `member_id` in
/// `group_id`, or an empty list if the group or the member is unknown.
pub async fn test_get_assignments(
    &self,
    group_id: &str,
    member_id: &str,
) -> Vec<crate::consumer_groups::PartitionAssignment> {
    let groups = self.consumer_groups.read().await;
    let assignment = groups
        .groups()
        .get(group_id)
        .and_then(|group| group.members.get(member_id))
        .map(|member| member.assignment.clone())
        .unwrap_or_default();
    assignment
}
/// Test helper: asks the consumer-group manager to simulate `lag` for
/// `group_id` on `topic`. Always returns `Ok(())`.
pub async fn test_simulate_lag(&self, group_id: &str, topic: &str, lag: i64) -> Result<()> {
    // Same lock order as the other helpers: topics first, then groups.
    let topic_map = self.topics.read().await;
    let mut groups = self.consumer_groups.write().await;
    groups.simulate_lag(group_id, topic, lag, &topic_map).await;
    drop(groups);
    drop(topic_map);
    Ok(())
}
/// Test helper: asks the consumer-group manager to reset the offsets of
/// `group_id` on `topic` according to `to`. Always returns `Ok(())`.
pub async fn test_reset_offsets(&self, group_id: &str, topic: &str, to: &str) -> Result<()> {
    // Same lock order as the other helpers: topics first, then groups.
    let topic_map = self.topics.read().await;
    let mut groups = self.consumer_groups.write().await;
    groups.reset_offsets(group_id, topic, to, &topic_map).await;
    drop(groups);
    drop(topic_map);
    Ok(())
}
pub fn metrics(&self) -> &Arc<KafkaMetrics> {
&self.metrics
}
}
/// A simple owned record: optional key, value bytes and an ordered header
/// list. `KafkaMockBroker`'s request handlers use `KafkaMessage` instead.
#[derive(Debug, Clone)]
pub struct Record {
/// Record key; `None` for key-less records.
pub key: Option<Vec<u8>>,
/// Record payload.
pub value: Vec<u8>,
/// Header name/value pairs, kept in insertion order.
pub headers: Vec<(String, Vec<u8>)>,
}
/// Simple per-partition produce result. The wire-level Produce handler
/// builds `produce_codec::PartitionProduceResult` values instead.
#[derive(Debug)]
pub struct ProduceResponse {
/// Partition the records were written to.
pub partition: i32,
/// Kafka error code; 0 means success.
pub error_code: i16,
/// Offset assigned to the record; -1 is used alongside an error code.
pub offset: i64,
}
/// Simple per-partition fetch result. The wire-level Fetch handler builds
/// `fetch_codec::FetchPartitionResponse` values instead.
#[derive(Debug)]
pub struct FetchResponse {
/// Partition the records were read from.
pub partition: i32,
/// Kafka error code; 0 means success.
pub error_code: i16,
/// Partition high watermark at the time of the fetch.
pub high_watermark: i64,
/// Records returned by the fetch.
pub records: Vec<Record>,
}
/// Drop guard that reports a connection as closed.
///
/// `handle_connection` creates one right after `record_connection()`, so
/// every exit from that function — normal return, early `break`, or an
/// unwinding panic — is paired with `record_connection_closed()`.
struct ConnectionGuard {
    metrics: Arc<KafkaMetrics>,
}
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let counters = &self.metrics;
        counters.record_connection_closed();
    }
}
fn get_api_key_from_request(request: &KafkaRequest) -> i16 {
request.api_key
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn advertised_falls_back_to_bind_when_unset() {
let config = KafkaConfig {
host: "127.0.0.1".to_string(),
port: 19092,
advertised_host: None,
advertised_port: None,
..KafkaConfig::default()
};
assert_eq!(resolve_advertised_endpoint(&config), ("127.0.0.1".to_string(), 19092));
}
#[test]
fn advertised_overrides_bind_when_host_set() {
let config = KafkaConfig {
host: "0.0.0.0".to_string(),
port: 9092,
advertised_host: Some("hosted-mock-abc.fly.dev".to_string()),
advertised_port: None,
..KafkaConfig::default()
};
assert_eq!(
resolve_advertised_endpoint(&config),
("hosted-mock-abc.fly.dev".to_string(), 9092)
);
}
#[test]
fn advertised_overrides_bind_when_port_set() {
let config = KafkaConfig {
host: "0.0.0.0".to_string(),
port: 9092,
advertised_host: None,
advertised_port: Some(443),
..KafkaConfig::default()
};
assert_eq!(resolve_advertised_endpoint(&config), ("0.0.0.0".to_string(), 443));
}
#[test]
fn test_record_creation_with_all_fields() {
let record = Record {
key: Some(b"test-key".to_vec()),
value: b"test-value".to_vec(),
headers: vec![("header1".to_string(), b"value1".to_vec())],
};
assert_eq!(record.key, Some(b"test-key".to_vec()));
assert_eq!(record.value, b"test-value".to_vec());
assert_eq!(record.headers.len(), 1);
assert_eq!(record.headers[0].0, "header1");
}
#[test]
fn test_record_creation_without_key() {
let record = Record {
key: None,
value: b"message body".to_vec(),
headers: vec![],
};
assert!(record.key.is_none());
assert_eq!(record.value, b"message body".to_vec());
assert!(record.headers.is_empty());
}
#[test]
fn test_record_with_multiple_headers() {
let record = Record {
key: Some(b"key".to_vec()),
value: b"value".to_vec(),
headers: vec![
("content-type".to_string(), b"application/json".to_vec()),
("correlation-id".to_string(), b"12345".to_vec()),
("source".to_string(), b"test-producer".to_vec()),
],
};
assert_eq!(record.headers.len(), 3);
assert_eq!(record.headers[0].0, "content-type");
assert_eq!(record.headers[1].0, "correlation-id");
assert_eq!(record.headers[2].0, "source");
}
#[test]
fn test_record_clone() {
let original = Record {
key: Some(b"key".to_vec()),
value: b"value".to_vec(),
headers: vec![("h".to_string(), b"v".to_vec())],
};
let cloned = original.clone();
assert_eq!(original.key, cloned.key);
assert_eq!(original.value, cloned.value);
assert_eq!(original.headers, cloned.headers);
}
#[test]
fn test_record_debug() {
let record = Record {
key: Some(b"key".to_vec()),
value: b"value".to_vec(),
headers: vec![],
};
let debug_str = format!("{:?}", record);
assert!(debug_str.contains("Record"));
assert!(debug_str.contains("key"));
assert!(debug_str.contains("value"));
}
#[test]
fn test_record_empty_value() {
let record = Record {
key: None,
value: vec![],
headers: vec![],
};
assert!(record.key.is_none());
assert!(record.value.is_empty());
assert!(record.headers.is_empty());
}
#[test]
fn test_record_binary_data() {
let binary_data: Vec<u8> = vec![0x00, 0xFF, 0x80, 0x7F, 0xFE];
let record = Record {
key: Some(binary_data.clone()),
value: binary_data.clone(),
headers: vec![],
};
assert_eq!(record.key.as_ref().unwrap().len(), 5);
assert_eq!(record.value.len(), 5);
assert_eq!(record.value[0], 0x00);
assert_eq!(record.value[1], 0xFF);
}
#[test]
fn test_produce_response_success() {
let response = ProduceResponse {
partition: 0,
error_code: 0,
offset: 100,
};
assert_eq!(response.partition, 0);
assert_eq!(response.error_code, 0);
assert_eq!(response.offset, 100);
}
/// A produce response carrying a non-zero error code keeps that code and the
/// `-1` offset it was built with.
///
/// NOTE(review): error code 3 presumably mirrors Kafka's
/// UNKNOWN_TOPIC_OR_PARTITION; confirm against the protocol constants in use.
#[test]
fn test_produce_response_with_error() {
    let response = ProduceResponse {
        partition: 1,
        error_code: 3,
        offset: -1,
    };
    let ProduceResponse {
        partition,
        error_code,
        offset,
    } = &response;
    assert_eq!(*partition, 1);
    assert_eq!(*error_code, 3);
    assert_eq!(*offset, -1);
}
/// The offset field stores `i64::MAX` unchanged.
#[test]
fn test_produce_response_high_offset() {
    let max_offset = i64::MAX;
    let response = ProduceResponse {
        partition: 5,
        error_code: 0,
        offset: max_offset,
    };
    assert_eq!(response.partition, 5);
    assert_eq!(response.offset, max_offset);
}
/// The `Debug` rendering of a produce response names the type and all three fields.
#[test]
fn test_produce_response_debug() {
    let response = ProduceResponse {
        partition: 0,
        error_code: 0,
        offset: 42,
    };
    let rendered = format!("{response:?}");
    let expected = ["ProduceResponse", "partition", "error_code", "offset"];
    assert!(expected.iter().all(|&needle| rendered.contains(needle)));
}
/// A fetch response with no records still reports its partition, error code and
/// high watermark.
#[test]
fn test_fetch_response_empty() {
    let response = FetchResponse {
        partition: 0,
        error_code: 0,
        high_watermark: 100,
        records: Vec::new(),
    };
    let FetchResponse {
        partition,
        error_code,
        high_watermark,
        records,
    } = &response;
    assert_eq!(*partition, 0);
    assert_eq!(*error_code, 0);
    assert_eq!(*high_watermark, 100);
    assert_eq!(records.len(), 0);
}
/// A fetch response keeps its records in the order they were supplied.
#[test]
fn test_fetch_response_with_records() {
    let records: Vec<Record> = [("key1", "value1"), ("key2", "value2")]
        .iter()
        .map(|&(k, v)| Record {
            key: Some(k.as_bytes().to_vec()),
            value: v.as_bytes().to_vec(),
            headers: Vec::new(),
        })
        .collect();
    let response = FetchResponse {
        records,
        high_watermark: 50,
        error_code: 0,
        partition: 0,
    };
    assert_eq!(response.records.len(), 2);
    let (first, second) = (&response.records[0], &response.records[1]);
    assert_eq!(first.key, Some(b"key1".to_vec()));
    assert_eq!(second.value, b"value2".to_vec());
}
/// A fetch response carrying an error code keeps that code and a zero high watermark.
///
/// NOTE(review): error code 1 presumably mirrors Kafka's OFFSET_OUT_OF_RANGE;
/// confirm against the protocol constants in use.
#[test]
fn test_fetch_response_with_error() {
    let response = FetchResponse {
        records: vec![],
        high_watermark: 0,
        error_code: 1,
        partition: 0,
    };
    assert_eq!((response.error_code, response.high_watermark), (1, 0));
}
/// The `Debug` rendering of a fetch response names the type and its `high_watermark` field.
#[test]
fn test_fetch_response_debug() {
    let rendered = format!(
        "{:?}",
        FetchResponse {
            partition: 2,
            error_code: 0,
            high_watermark: 1000,
            records: vec![],
        }
    );
    for needle in ["FetchResponse", "high_watermark"] {
        assert!(rendered.contains(needle));
    }
}
/// `get_api_key_from_request` yields 0 for a request built with `api_key: 0`
/// and `KafkaRequestType::Produce`.
#[test]
fn test_get_api_key_produce() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::Produce,
        client_id: String::from("test-client"),
        correlation_id: 1,
        api_version: 7,
        api_key: 0,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 0);
}
/// `get_api_key_from_request` yields 1 for a request built with `api_key: 1`
/// and `KafkaRequestType::Fetch`.
#[test]
fn test_get_api_key_fetch() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::Fetch,
        client_id: String::from("consumer"),
        correlation_id: 2,
        api_version: 11,
        api_key: 1,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 1);
}
/// `get_api_key_from_request` yields 3 for a request built with `api_key: 3`
/// and `KafkaRequestType::Metadata`.
#[test]
fn test_get_api_key_metadata() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::Metadata,
        client_id: String::from("admin"),
        correlation_id: 3,
        api_version: 9,
        api_key: 3,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 3);
}
/// `get_api_key_from_request` yields 18 for a request built with `api_key: 18`
/// and `KafkaRequestType::ApiVersions`.
#[test]
fn test_get_api_key_api_versions() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::ApiVersions,
        client_id: String::from("client"),
        correlation_id: 100,
        api_version: 3,
        api_key: 18,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 18);
}
/// `get_api_key_from_request` yields 16 for a request built with `api_key: 16`
/// and `KafkaRequestType::ListGroups`.
#[test]
fn test_get_api_key_list_groups() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::ListGroups,
        client_id: String::from("admin-client"),
        correlation_id: 5,
        api_version: 4,
        api_key: 16,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 16);
}
/// `get_api_key_from_request` yields 19 for a request built with `api_key: 19`
/// and `KafkaRequestType::CreateTopics`.
#[test]
fn test_get_api_key_create_topics() {
    let request = KafkaRequest {
        request_type: KafkaRequestType::CreateTopics,
        client_id: String::from("admin"),
        correlation_id: 10,
        api_version: 5,
        api_key: 19,
        body_offset: 0,
    };
    let api_key = get_api_key_from_request(&request);
    assert_eq!(api_key, 19);
}
/// A `KafkaRequest` exposes the header values it was constructed with.
#[test]
fn test_kafka_request_fields() {
    let request = KafkaRequest {
        client_id: "my-producer".to_owned(),
        correlation_id: 12345,
        api_version: 8,
        api_key: 0,
        request_type: KafkaRequestType::Produce,
        body_offset: 0,
    };
    let KafkaRequest {
        api_key,
        api_version,
        correlation_id,
        client_id,
        ..
    } = &request;
    assert_eq!(*api_key, 0);
    assert_eq!(*api_version, 8);
    assert_eq!(*correlation_id, 12345);
    assert_eq!(client_id, "my-producer");
}
/// A request can be built with an empty client id.
#[test]
fn test_kafka_request_empty_client_id() {
    let request = KafkaRequest {
        client_id: "".to_owned(),
        request_type: KafkaRequestType::Metadata,
        api_key: 3,
        api_version: 9,
        correlation_id: 1,
        body_offset: 0,
    };
    assert_eq!(request.client_id.len(), 0);
}
/// The correlation id field holds `i32::MAX` unchanged.
#[test]
fn test_kafka_request_max_correlation_id() {
    let correlation_id = i32::MAX;
    let request = KafkaRequest {
        correlation_id,
        api_key: 0,
        api_version: 0,
        client_id: "test".to_owned(),
        request_type: KafkaRequestType::Produce,
        body_offset: 0,
    };
    assert_eq!(request.correlation_id, correlation_id);
}
/// Each listed `KafkaRequestType` variant is constructible and matches its own pattern.
#[test]
fn test_request_type_variants() {
    assert!(matches!(KafkaRequestType::Metadata, KafkaRequestType::Metadata));
    assert!(matches!(KafkaRequestType::Produce, KafkaRequestType::Produce));
    assert!(matches!(KafkaRequestType::Fetch, KafkaRequestType::Fetch));
    assert!(matches!(KafkaRequestType::ListGroups, KafkaRequestType::ListGroups));
    assert!(matches!(
        KafkaRequestType::DescribeGroups,
        KafkaRequestType::DescribeGroups
    ));
    assert!(matches!(KafkaRequestType::ApiVersions, KafkaRequestType::ApiVersions));
    assert!(matches!(KafkaRequestType::CreateTopics, KafkaRequestType::CreateTopics));
    assert!(matches!(KafkaRequestType::DeleteTopics, KafkaRequestType::DeleteTopics));
    assert!(matches!(
        KafkaRequestType::DescribeConfigs,
        KafkaRequestType::DescribeConfigs
    ));
}
/// Checks that 10 MiB equals 10_485_760 bytes.
///
/// NOTE(review): this only evaluates a local expression and never references the
/// broker's real message-size limit, so it cannot detect a change to that limit.
#[test]
fn test_message_size_limit_constant() {
    const TEN_MIB: usize = 10 << 20;
    assert_eq!(TEN_MIB, 10_485_760);
}
/// A 1 MiB message is within a 10 MiB limit.
///
/// NOTE(review): both sizes are local values; the broker's own limit is not exercised.
#[test]
fn test_message_size_under_limit() {
    const LIMIT: usize = 10 << 20;
    let one_mib: usize = 1 << 20;
    assert!(one_mib <= LIMIT);
}
/// An 11 MiB message exceeds a 10 MiB limit.
///
/// NOTE(review): both sizes are local values; the broker's own limit is not exercised.
#[test]
fn test_message_size_over_limit() {
    const LIMIT: usize = 10 << 20;
    let eleven_mib: usize = 11 << 20;
    assert!(LIMIT < eleven_mib);
}
/// A response length encodes to four big-endian bytes and decodes back unchanged.
#[test]
fn test_response_size_serialization() {
    let encoded: [u8; 4] = 1000_i32.to_be_bytes();
    assert_eq!(encoded.len(), std::mem::size_of::<i32>());
    let decoded = i32::from_be_bytes(encoded);
    assert_eq!(decoded, 1000);
}
/// `i32::MAX` survives a big-endian encode/decode round trip through four bytes.
#[test]
fn test_response_size_max_value() {
    let encoded = i32::MAX.to_be_bytes();
    assert_eq!(encoded.len(), std::mem::size_of::<i32>());
    assert_eq!(i32::from_be_bytes(encoded), i32::MAX);
}
/// A zero length encodes to four zero bytes.
#[test]
fn test_response_size_zero() {
    assert_eq!(0_i32.to_be_bytes(), [0_u8; 4]);
}
/// Builds a raw Produce v9 request by hand and checks three things:
/// - `parse_produce_v9` decodes its body;
/// - `KafkaProtocolHandler::parse_request` decodes its header;
/// - `KafkaMockBroker::handle_produce` stores the single record in topic `prod-target`.
#[tokio::test]
async fn test_handle_produce_v9_writes_records_to_topic() {
use crate::produce_codec::{parse_produce_v9, PartitionProduceData, TopicProduceData};
let broker = KafkaMockBroker::new(KafkaConfig::default()).await.expect("broker");
// One-record batch (key "key-1", value "hello-produce") from the codec's test helper.
let record_batch =
crate::produce_codec::one_record_batch_for_testing(Some(b"key-1"), b"hello-produce");
let mut msg = Vec::new();
// ---- Request header ----
// api_key = 0 (checked below against `request.api_key`).
msg.extend_from_slice(&0i16.to_be_bytes());
// api_version = 9 (checked below against `request.api_version`).
msg.extend_from_slice(&9i16.to_be_bytes());
// correlation_id = 777; the response's first four bytes must echo it.
msg.extend_from_slice(&777i32.to_be_bytes());
// client_id: i16 length prefix (1) followed by the single byte 't'.
msg.extend_from_slice(&1i16.to_be_bytes());
msg.push(b't');
// Last header byte. It counts toward `body_offset` (asserted against the parser
// below). Presumably an empty tagged-field section; confirm against
// `parse_request`.
msg.push(0);
// ---- Request body (the slice handed to `parse_produce_v9`) ----
// Presumably transactional_id = null compact string (0), acks = -1 and
// timeout_ms = 30_000. Field names are assumed from the Kafka Produce v9
// schema; confirm against the codec.
msg.push(0); msg.extend_from_slice(&(-1i16).to_be_bytes()); msg.extend_from_slice(&30_000i32.to_be_bytes());
// Topic array length, written as N + 1 (2 => one topic).
msg.push(2);
let topic_name = b"prod-target";
// Topic name length, written as len + 1 in a single byte (short name), then the raw bytes.
msg.push((topic_name.len() as u8) + 1);
msg.extend_from_slice(topic_name);
// Partition array length as N + 1 (2 => one partition), then partition index 0 as i32.
msg.push(2);
msg.extend_from_slice(&0i32.to_be_bytes());
// Record-batch length, written as len + 1 in an unsigned varint: 7 bits per byte,
// lowest group first, with 0x80 set on every byte except the last. The `< 128`
// branch emits the same bytes the loop would for a single-byte value.
let rb_len_plus_one = (record_batch.len() as u32) + 1;
if rb_len_plus_one < 128 {
msg.push(rb_len_plus_one as u8);
} else {
let mut v = rb_len_plus_one;
while (v & !0x7F) != 0 {
msg.push(((v & 0x7F) | 0x80) as u8);
v >>= 7;
}
msg.push(v as u8);
}
msg.extend_from_slice(&record_batch);
// Three trailing zero bytes. Presumably empty tagged-field sections for the
// partition, the topic and the request body; confirm against the codec.
msg.push(0); msg.push(0); msg.push(0);
// Header size: 2 (api_key) + 2 (api_version) + 4 (correlation_id)
// + 2 (client_id length) = 10, plus 1 client_id byte, plus the 1 trailing header byte.
let body_offset = 10 + 1 + 1 ;
// Decode the body with the codec directly, before the broker is involved.
let parsed = parse_produce_v9(&msg[body_offset..]).expect("codec parse");
assert_eq!(parsed.topics[0].name, "prod-target");
assert_eq!(parsed.topics[0].partitions[0].records[0].value, b"hello-produce");
// The header parser must agree on api_key, api_version and where the body starts.
let handler = crate::protocol::KafkaProtocolHandler::new();
let request = handler.parse_request(&msg).expect("parse header");
assert_eq!(request.api_key, 0);
assert_eq!(request.api_version, 9);
assert_eq!(request.body_offset, body_offset);
// The broker must return pre-serialized bytes that begin with the correlation id.
let response = broker.handle_produce(&msg, &request).await.expect("produce");
match response {
KafkaResponse::Preserialized(bytes) => {
assert_eq!(&bytes[0..4], &777i32.to_be_bytes());
}
other => panic!("unexpected response variant: {other:?}"),
}
// The topic must now exist and hold exactly one message across all partitions.
// That message sits in partition 0, the index targeted above.
let topics = broker.topics.read().await;
let topic = topics.get("prod-target").expect("auto-created topic");
let record_count: usize = topic.partitions.iter().map(|p| p.messages.len()).sum();
assert_eq!(record_count, 1);
let stored = topic.partitions[0].messages.front().unwrap();
assert_eq!(stored.value, b"hello-produce");
assert_eq!(stored.key.as_deref(), Some(b"key-1".as_ref()));
// Builds and immediately discards the imported codec types, which has no runtime
// effect. NOTE(review): this seems to exist only so the `TopicProduceData` /
// `PartitionProduceData` imports are used. Consider dropping both the values and
// the imports.
let _ = TopicProduceData {
name: String::new(),
partitions: vec![],
};
let _ = PartitionProduceData {
partition_index: 0,
records: vec![],
compression_codec: 0,
};
}
/// Calls `handle_create_topics` twice, then checks that the broker knows both
/// `default-topic` and at least one topic whose name starts with `topic-`.
/// Together these show two distinct names after the two calls.
#[tokio::test]
async fn test_handle_create_topics_creates_unique_topic_names() {
    let broker = KafkaMockBroker::new(KafkaConfig::default())
        .await
        .expect("broker");
    for label in ["create1", "create2"] {
        let _ = broker.handle_create_topics().await.expect(label);
    }
    let topics = broker.topics.read().await;
    let has_default = topics.contains_key("default-topic");
    let has_generated = topics.keys().any(|name| name.starts_with("topic-"));
    assert!(has_default);
    assert!(has_generated);
}
}