Skip to main content

cargo_lambda_watch/
eventstream.rs

1use aws_smithy_eventstream::frame::write_message_to;
2use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
3use axum::{body::Body, http::response::Builder, response::Response};
4use bytes::{Bytes, BytesMut};
5use http_body_util::BodyExt;
6use serde::Serialize;
7
8use crate::error::ServerError;
9
10/// Encodes a chunk of data as an EventStream PayloadChunk event
11pub fn encode_payload_chunk(chunk_data: Bytes) -> Result<Bytes, ServerError> {
12    let message = Message::new(chunk_data)
13        .add_header(Header::new(
14            ":event-type",
15            HeaderValue::String("PayloadChunk".into()),
16        ))
17        .add_header(Header::new(
18            ":content-type",
19            HeaderValue::String("application/octet-stream".into()),
20        ));
21
22    let mut buf = BytesMut::new();
23    write_message_to(&message, &mut buf).map_err(ServerError::EventStreamEncodingError)?;
24
25    Ok(buf.freeze())
26}
27
28/// Encodes an InvokeComplete event with optional error information
29pub fn encode_invoke_complete(
30    error_code: Option<String>,
31    error_details: Option<String>,
32) -> Result<Bytes, ServerError> {
33    #[derive(Serialize)]
34    #[serde(rename_all = "PascalCase")]
35    struct InvokeCompletePayload {
36        #[serde(skip_serializing_if = "Option::is_none")]
37        error_code: Option<String>,
38        #[serde(skip_serializing_if = "Option::is_none")]
39        error_details: Option<String>,
40        #[serde(skip_serializing_if = "Option::is_none")]
41        log_result: Option<String>,
42    }
43
44    let payload = InvokeCompletePayload {
45        error_code,
46        error_details,
47        log_result: None,
48    };
49
50    let payload_json = serde_json::to_vec(&payload).map_err(ServerError::SerializationError)?;
51
52    let message = Message::new(Bytes::from(payload_json))
53        .add_header(Header::new(
54            ":event-type",
55            HeaderValue::String("InvokeComplete".into()),
56        ))
57        .add_header(Header::new(
58            ":content-type",
59            HeaderValue::String("application/json".into()),
60        ));
61
62    let mut buf = BytesMut::new();
63    write_message_to(&message, &mut buf).map_err(ServerError::EventStreamEncodingError)?;
64
65    Ok(buf.freeze())
66}
67
68/// Transforms a Lambda streaming response into an EventStream response
69pub async fn create_eventstream_response(
70    builder: Builder,
71    body: &mut Body,
72) -> Result<Response<Body>, ServerError> {
73    // Collect all frames from the body
74    let mut eventstream_chunks = Vec::new();
75
76    // Process each chunk and convert to EventStream PayloadChunk events
77    while let Some(frame) = body.frame().await {
78        let frame = frame.map_err(ServerError::DataDeserialization)?;
79
80        if let Ok(data) = frame.into_data() {
81            if !data.is_empty() {
82                let eventstream_chunk = encode_payload_chunk(data)?;
83                eventstream_chunks.push(eventstream_chunk);
84            }
85        }
86    }
87
88    // Add InvokeComplete event at the end
89    let invoke_complete = encode_invoke_complete(None, None)?;
90    eventstream_chunks.push(invoke_complete);
91
92    // Combine all chunks into a single body
93    let combined_body = eventstream_chunks
94        .into_iter()
95        .flat_map(|chunk| chunk.to_vec())
96        .collect::<Vec<u8>>();
97
98    let response = builder
99        .header("Content-Type", "application/vnd.amazon.eventstream")
100        .body(Body::from(combined_body))
101        .map_err(ServerError::ResponseBuild)?;
102
103    Ok(response)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_eventstream::frame::read_message_from;

    /// Size of the EventStream prelude: total length, headers length and
    /// prelude CRC, each a big-endian `u32`.
    const PRELUDE_LEN: usize = 12;
    /// Size of the trailing message CRC (a big-endian `u32`).
    const MESSAGE_CRC_LEN: usize = 4;

    /// Decodes one EventStream message from `data` and returns its
    /// `:event-type` header value together with its payload.
    fn decode_eventstream_message(
        data: &[u8],
    ) -> Result<(String, Bytes), Box<dyn std::error::Error>> {
        let message = read_message_from(data)?;

        let event_type = message
            .headers()
            .iter()
            .find(|h| h.name().as_str() == ":event-type")
            .and_then(|h| match h.value() {
                HeaderValue::String(s) => Some(s.as_str().to_string()),
                _ => None,
            })
            .ok_or("Missing :event-type header")?;

        Ok((event_type, message.payload().clone()))
    }

    #[test]
    fn test_encode_payload_chunk() {
        let test_data = Bytes::from("Hello, streaming world!");

        let encoded =
            encode_payload_chunk(test_data.clone()).expect("Failed to encode payload chunk");

        // Verify the encoded message round-trips through the decoder.
        let (event_type, payload) =
            decode_eventstream_message(&encoded).expect("Failed to decode EventStream message");

        assert_eq!(event_type, "PayloadChunk");
        assert_eq!(payload, test_data);
    }

    #[test]
    fn test_encode_invoke_complete_success() {
        let encoded = encode_invoke_complete(None, None).expect("Failed to encode InvokeComplete");

        let (event_type, payload) =
            decode_eventstream_message(&encoded).expect("Failed to decode EventStream message");

        assert_eq!(event_type, "InvokeComplete");

        let payload_json: serde_json::Value =
            serde_json::from_slice(&payload).expect("Failed to parse InvokeComplete payload");

        // No error fields may be present (or they must be null).
        assert!(payload_json.get("ErrorCode").is_none() || payload_json["ErrorCode"].is_null());
        assert!(
            payload_json.get("ErrorDetails").is_none() || payload_json["ErrorDetails"].is_null()
        );
    }

    #[test]
    fn test_encode_invoke_complete_with_error() {
        let error_code = Some("RuntimeError".to_string());
        let error_details = Some("Function execution failed".to_string());

        let encoded = encode_invoke_complete(error_code.clone(), error_details.clone())
            .expect("Failed to encode InvokeComplete with error");

        let (event_type, payload) =
            decode_eventstream_message(&encoded).expect("Failed to decode EventStream message");

        assert_eq!(event_type, "InvokeComplete");

        let payload_json: serde_json::Value =
            serde_json::from_slice(&payload).expect("Failed to parse InvokeComplete payload");

        // Both error fields must be present with the supplied values.
        assert_eq!(payload_json["ErrorCode"].as_str(), error_code.as_deref());
        assert_eq!(
            payload_json["ErrorDetails"].as_str(),
            error_details.as_deref()
        );
    }

    #[test]
    fn test_eventstream_message_structure() {
        let test_data = Bytes::from("test data");
        let encoded = encode_payload_chunk(test_data).expect("Failed to encode payload chunk");

        // A message is: 12-byte prelude, headers, payload, 4-byte message CRC.
        // Prelude + CRC is therefore the smallest possible valid message.
        assert!(
            encoded.len() >= PRELUDE_LEN + MESSAGE_CRC_LEN,
            "Message too short to be valid EventStream"
        );

        // The first two prelude fields are big-endian u32 lengths.
        let total_length = u32::from_be_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        let headers_length = u32::from_be_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);

        assert_eq!(
            total_length as usize,
            encoded.len(),
            "Message length mismatch"
        );

        // The headers block must be non-empty and fit between prelude and CRC.
        assert!(headers_length > 0, "PayloadChunk must carry headers");
        assert!(
            PRELUDE_LEN + headers_length as usize + MESSAGE_CRC_LEN <= encoded.len(),
            "Headers length exceeds message bounds"
        );
    }

    #[test]
    fn test_multiple_payload_chunks() {
        // Encode several chunks as a real stream would, then terminate it.
        let chunks = vec![
            Bytes::from("chunk 1"),
            Bytes::from("chunk 2"),
            Bytes::from("chunk 3"),
        ];

        let mut encoded_messages: Vec<Bytes> = chunks
            .iter()
            .map(|chunk| {
                encode_payload_chunk(chunk.clone()).expect("Failed to encode payload chunk")
            })
            .collect();

        encoded_messages
            .push(encode_invoke_complete(None, None).expect("Failed to encode InvokeComplete"));

        // 3 chunks + 1 InvokeComplete
        assert_eq!(encoded_messages.len(), 4);

        // Each chunk must decode back to its own payload, in order, and the
        // final message must be the InvokeComplete terminator.
        for (i, encoded) in encoded_messages.iter().enumerate() {
            let (event_type, payload) =
                decode_eventstream_message(encoded).expect("Failed to decode EventStream message");

            match chunks.get(i) {
                Some(expected) => {
                    assert_eq!(event_type, "PayloadChunk");
                    assert_eq!(&payload, expected);
                }
                None => assert_eq!(event_type, "InvokeComplete"),
            }
        }
    }
}