use aws_smithy_eventstream::frame::write_message_to;
use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
use axum::{body::Body, http::response::Builder, response::Response};
use bytes::{Bytes, BytesMut};
use http_body_util::BodyExt;
use serde::Serialize;
use crate::error::ServerError;
/// Wraps a chunk of response data in a single EventStream `PayloadChunk` message.
///
/// The message uses `chunk_data` as its payload and carries two string headers,
/// added in this order: `:event-type` = `PayloadChunk` and
/// `:content-type` = `application/octet-stream`.
///
/// # Errors
///
/// Returns `ServerError::EventStreamEncodingError` when the message cannot be
/// written into the output buffer.
pub fn encode_payload_chunk(chunk_data: Bytes) -> Result<Bytes, ServerError> {
    let event_type = Header::new(":event-type", HeaderValue::String("PayloadChunk".into()));
    let content_type = Header::new(
        ":content-type",
        HeaderValue::String("application/octet-stream".into()),
    );
    let frame = Message::new(chunk_data)
        .add_header(event_type)
        .add_header(content_type);

    let mut encoded = BytesMut::new();
    write_message_to(&frame, &mut encoded).map_err(ServerError::EventStreamEncodingError)?;
    Ok(encoded.freeze())
}
/// Builds the EventStream `InvokeComplete` message.
///
/// The payload is a JSON object with PascalCase keys (`ErrorCode`,
/// `ErrorDetails`, `LogResult`); any field that is `None` is omitted entirely.
/// `LogResult` is always `None` here, so it never appears in the output.
/// The message carries the string headers `:event-type` = `InvokeComplete` and
/// `:content-type` = `application/json`, added in that order.
///
/// # Errors
///
/// Returns `ServerError::SerializationError` if the JSON payload cannot be
/// produced, or `ServerError::EventStreamEncodingError` if the message cannot
/// be written into the output buffer.
pub fn encode_invoke_complete(
    error_code: Option<String>,
    error_details: Option<String>,
) -> Result<Bytes, ServerError> {
    /// JSON body of the completion event; unset fields are skipped on output.
    #[derive(Serialize)]
    #[serde(rename_all = "PascalCase")]
    struct CompletionBody {
        #[serde(skip_serializing_if = "Option::is_none")]
        error_code: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error_details: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        log_result: Option<String>,
    }

    let body_json = serde_json::to_vec(&CompletionBody {
        error_code,
        error_details,
        log_result: None,
    })
    .map_err(ServerError::SerializationError)?;

    let event_type = Header::new(":event-type", HeaderValue::String("InvokeComplete".into()));
    let content_type = Header::new(
        ":content-type",
        HeaderValue::String("application/json".into()),
    );
    let frame = Message::new(Bytes::from(body_json))
        .add_header(event_type)
        .add_header(content_type);

    let mut encoded = BytesMut::new();
    write_message_to(&frame, &mut encoded).map_err(ServerError::EventStreamEncodingError)?;
    Ok(encoded.freeze())
}
pub async fn create_eventstream_response(
builder: Builder,
body: &mut Body,
) -> Result<Response<Body>, ServerError> {
let mut eventstream_chunks = Vec::new();
while let Some(frame) = body.frame().await {
let frame = frame.map_err(ServerError::DataDeserialization)?;
if let Ok(data) = frame.into_data() {
if !data.is_empty() {
let eventstream_chunk = encode_payload_chunk(data)?;
eventstream_chunks.push(eventstream_chunk);
}
}
}
let invoke_complete = encode_invoke_complete(None, None)?;
eventstream_chunks.push(invoke_complete);
let combined_body = eventstream_chunks
.into_iter()
.flat_map(|chunk| chunk.to_vec())
.collect::<Vec<u8>>();
let response = builder
.header("Content-Type", "application/vnd.amazon.eventstream")
.body(Body::from(combined_body))
.map_err(ServerError::ResponseBuild)?;
Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_eventstream::frame::read_message_from;

    /// Decodes a single EventStream message from `data` and returns the value
    /// of its `:event-type` header together with a copy of its payload.
    ///
    /// Fails if the message cannot be decoded, or if the first `:event-type`
    /// header is absent or is not a string value.
    fn decode_eventstream_message(
        data: &[u8],
    ) -> Result<(String, Bytes), Box<dyn std::error::Error>> {
        let decoded = read_message_from(data)?;
        let type_header = decoded
            .headers()
            .iter()
            .find(|header| header.name().as_str() == ":event-type");
        let event_type = match type_header.map(|header| header.value()) {
            Some(aws_smithy_types::event_stream::HeaderValue::String(text)) => {
                text.as_str().to_string()
            }
            _ => return Err("Missing :event-type header".into()),
        };
        Ok((event_type, decoded.payload().clone()))
    }

    /// A payload chunk round-trips with the right event type and payload.
    #[test]
    fn test_encode_payload_chunk() {
        let original = Bytes::from("Hello, streaming world!");
        let wire =
            encode_payload_chunk(original.clone()).expect("Failed to encode payload chunk");

        let (event_type, payload) =
            decode_eventstream_message(&wire).expect("Failed to decode EventStream message");

        assert_eq!(event_type, "PayloadChunk");
        assert_eq!(payload, original);
    }

    /// A successful completion event has no (or null) error fields.
    #[test]
    fn test_encode_invoke_complete_success() {
        let wire = encode_invoke_complete(None, None).expect("Failed to encode InvokeComplete");
        let (event_type, payload) =
            decode_eventstream_message(&wire).expect("Failed to decode EventStream message");
        assert_eq!(event_type, "InvokeComplete");

        let body: serde_json::Value =
            serde_json::from_slice(&payload).expect("Failed to parse InvokeComplete payload");
        let absent_or_null =
            |key: &str| body.get(key).map_or(true, serde_json::Value::is_null);
        assert!(absent_or_null("ErrorCode"));
        assert!(absent_or_null("ErrorDetails"));
    }

    /// Error code and details supplied by the caller appear in the JSON payload.
    #[test]
    fn test_encode_invoke_complete_with_error() {
        let code = Some("RuntimeError".to_string());
        let details = Some("Function execution failed".to_string());

        let wire = encode_invoke_complete(code.clone(), details.clone())
            .expect("Failed to encode InvokeComplete with error");
        let (event_type, payload) =
            decode_eventstream_message(&wire).expect("Failed to decode EventStream message");
        assert_eq!(event_type, "InvokeComplete");

        let body: serde_json::Value =
            serde_json::from_slice(&payload).expect("Failed to parse InvokeComplete payload");
        assert_eq!(body["ErrorCode"].as_str(), code.as_deref());
        assert_eq!(body["ErrorDetails"].as_str(), details.as_deref());
    }

    /// The first four bytes of an encoded message hold its total length (big-endian).
    #[test]
    fn test_eventstream_message_structure() {
        let wire =
            encode_payload_chunk(Bytes::from("test data")).expect("Failed to encode payload chunk");
        assert!(
            wire.len() >= 16,
            "Message too short to be valid EventStream"
        );

        let declared_len = u32::from_be_bytes([wire[0], wire[1], wire[2], wire[3]]);
        assert_eq!(
            declared_len as usize,
            wire.len(),
            "Message length mismatch"
        );
    }

    /// Several payload chunks followed by a completion event decode in order
    /// with the expected event types.
    #[test]
    fn test_multiple_payload_chunks() {
        let chunks = vec![
            Bytes::from("chunk 1"),
            Bytes::from("chunk 2"),
            Bytes::from("chunk 3"),
        ];

        let mut wire_messages: Vec<Bytes> = chunks
            .iter()
            .map(|chunk| {
                encode_payload_chunk(chunk.clone()).expect("Failed to encode payload chunk")
            })
            .collect();
        wire_messages
            .push(encode_invoke_complete(None, None).expect("Failed to encode InvokeComplete"));
        assert_eq!(wire_messages.len(), 4);

        for (index, wire) in wire_messages.iter().enumerate() {
            let (event_type, _payload) =
                decode_eventstream_message(wire).expect("Failed to decode EventStream message");
            let expected = if index < chunks.len() {
                "PayloadChunk"
            } else {
                "InvokeComplete"
            };
            assert_eq!(event_type, expected);
        }
    }
}