use bytes::{Bytes, BytesMut};
use kafka_protocol::messages::fetch_request::FetchRequest;
use kafka_protocol::messages::produce_request::ProduceRequest;
use kafka_protocol::messages::request_header::RequestHeader;
use kafka_protocol::messages::ApiKey;
use kafka_protocol::protocol::buf::ByteBuf;
use kafka_protocol::protocol::Decodable;
use kafka_protocol::records::{
Compression, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
};
use minicbor::encode::Encoder;
use ockam_node::Context;
use std::convert::TryFrom;
use std::io::{Error, ErrorKind};
use tracing::warn;
use crate::kafka::portal_worker::InterceptError;
use crate::kafka::protocol_aware::utils::{decode_body, encode_request};
use crate::kafka::protocol_aware::{InletInterceptorImpl, MessageWrapper, RequestInfo};
impl InletInterceptorImpl {
pub(crate) async fn intercept_request_impl(
&self,
context: &mut Context,
mut original: BytesMut,
) -> Result<BytesMut, InterceptError> {
let mut buffer = original.peek_bytes(0..original.len());
let api_key_num = buffer
.peek_bytes(0..2)
.try_get_i16()
.map_err(|_| InterceptError::Io(Error::from(ErrorKind::InvalidData)))?;
let api_key = ApiKey::try_from(api_key_num).map_err(|_| {
warn!("unknown request api: {api_key_num}");
InterceptError::Io(Error::from(ErrorKind::InvalidData))
})?;
let version = buffer
.peek_bytes(2..4)
.try_get_i16()
.map_err(|_| InterceptError::Io(Error::from(ErrorKind::InvalidData)))?;
let result = RequestHeader::decode(&mut buffer, api_key.request_header_version(version));
let header = match result {
Ok(header) => header,
Err(_) => {
warn!("cannot decode request kafka header");
return Err(InterceptError::Io(Error::from(ErrorKind::InvalidData)));
}
};
debug!(
"request: length: {}, correlation {}, version {}, api {:?}",
buffer.len(),
header.correlation_id,
header.request_api_version,
api_key
);
match api_key {
ApiKey::ProduceKey => {
return self
.handle_produce_request(context, &mut buffer, &header)
.await;
}
ApiKey::FetchKey => {
self.handle_fetch_request(context, &mut buffer, &header)
.await?;
}
ApiKey::MetadataKey | ApiKey::FindCoordinatorKey => {
self.request_map.lock().unwrap().insert(
header.correlation_id,
RequestInfo {
request_api_key: api_key,
request_api_version: header.request_api_version,
},
);
}
ApiKey::LeaderAndIsrKey => {
warn!("leader and isr key not supported! closing connection");
return Err(InterceptError::Io(Error::from(ErrorKind::InvalidData)));
}
ApiKey::UpdateMetadataKey => {
warn!("update metadata not supported! closing connection");
return Err(InterceptError::Io(Error::from(ErrorKind::InvalidData)));
}
_ => {}
}
Ok(original)
}
async fn handle_fetch_request(
&self,
context: &mut Context,
buffer: &mut Bytes,
header: &RequestHeader,
) -> Result<(), InterceptError> {
let request: FetchRequest = decode_body(buffer, header.request_api_version)?;
for topic in &request.topics {
let topic_id = if header.request_api_version <= 12 {
topic.topic.0.to_string()
} else {
let topic_id = topic.topic_id.to_string();
self.uuid_to_name
.lock()
.unwrap()
.get(&topic_id)
.cloned()
.ok_or_else(|| {
warn!("missing map from uuid {topic_id} to name");
InterceptError::Io(Error::from(ErrorKind::InvalidData))
})?
};
let partitions: Vec<i32> = topic
.partitions
.iter()
.map(|partition| partition.partition)
.collect();
self.secure_channel_controller
.start_relays_for(context, &topic_id, partitions)
.await
.map_err(InterceptError::Ockam)?
}
self.request_map.lock().unwrap().insert(
header.correlation_id,
RequestInfo {
request_api_key: ApiKey::FetchKey,
request_api_version: header.request_api_version,
},
);
Ok(())
}
async fn handle_produce_request(
&self,
context: &mut Context,
buffer: &mut Bytes,
header: &RequestHeader,
) -> Result<BytesMut, InterceptError> {
let mut request: ProduceRequest = decode_body(buffer, header.request_api_version)?;
for (topic_name, topic) in request.topic_data.iter_mut() {
for data in &mut topic.partition_data {
if let Some(content) = data.records.take() {
let mut content = BytesMut::from(content.as_ref());
let mut records = RecordBatchDecoder::decode(&mut content)
.map_err(|_| InterceptError::Io(Error::from(ErrorKind::InvalidData)))?;
for record in records.iter_mut() {
if let Some(record_value) = record.value.take() {
let encrypted_content = self
.secure_channel_controller
.encrypt_content_for(
context,
topic_name,
data.index,
record_value.to_vec(),
)
.await
.map_err(InterceptError::Ockam)?;
let wrapper = MessageWrapper {
consumer_decryptor_address: encrypted_content
.consumer_decryptor_address,
content: encrypted_content.content,
};
let mut write_buffer = Vec::with_capacity(1024);
let mut encoder = Encoder::new(&mut write_buffer);
encoder.encode(wrapper).map_err(|_err| {
InterceptError::Io(Error::from(ErrorKind::InvalidData))
})?;
record.value = Some(write_buffer.into());
}
}
let mut encoded = BytesMut::new();
RecordBatchEncoder::encode(
&mut encoded,
records.iter(),
&RecordEncodeOptions {
version: 2,
compression: Compression::None,
},
)
.map_err(|_| InterceptError::Io(Error::from(ErrorKind::InvalidData)))?;
data.records = Some(encoded.freeze());
}
}
}
encode_request(
header,
&request,
header.request_api_version,
ApiKey::ProduceKey,
)
}
}