Skip to main content

aws_ssm_bridge/
protocol.rs

1//! SSM Protocol message types and framing
2//!
3//! This module implements the AWS SSM Session Manager WebSocket protocol,
4//! including message framing, sequencing, and channel multiplexing.
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use uuid::Uuid;
10
11use crate::errors::{ProtocolError, Result};
12
/// Version string of the message framing implemented by this module.
pub const PROTOCOL_VERSION: &str = "1.0";
15
16/// Message type identifiers
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum MessageType {
20    /// Input data (stdin)
21    InputStreamData,
22    /// Output data (stdout)
23    OutputStreamData,
24    /// Acknowledge message
25    Acknowledge,
26    /// Channel closed
27    ChannelClosed,
28    /// Start publication
29    StartPublication,
30    /// Pause publication
31    PausePublication,
32}
33
34impl MessageType {
35    /// Returns the string representation of the message type
36    pub fn as_str(&self) -> &'static str {
37        match self {
38            MessageType::InputStreamData => "input_stream_data",
39            MessageType::OutputStreamData => "output_stream_data",
40            MessageType::Acknowledge => "acknowledge",
41            MessageType::ChannelClosed => "channel_closed",
42            MessageType::StartPublication => "start_publication",
43            MessageType::PausePublication => "pause_publication",
44        }
45    }
46}
47
48impl fmt::Display for MessageType {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "{}", self.as_str())
51    }
52}
53
54/// Session type
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "PascalCase")]
57#[derive(Default)]
58pub enum SessionType {
59    /// Standard shell session
60    #[serde(rename = "Standard_Stream")]
61    #[default]
62    StandardStream,
63    /// Port forwarding session
64    #[serde(rename = "Port")]
65    Port,
66    /// Interactive commands (AWS-StartInteractiveCommand)
67    #[serde(rename = "InteractiveCommands")]
68    InteractiveCommands,
69}
70
/// Logical stream carried over a multiplexed session.
///
/// The explicit discriminants are the values accepted by this type's
/// `TryFrom<u32>` conversion.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ChannelType {
    /// Standard input stream.
    Stdin = 0,
    /// Standard output stream.
    Stdout = 1,
    /// Standard error stream.
    Stderr = 2,
    /// Control messages.
    Control = 3,
}
83
84impl TryFrom<u32> for ChannelType {
85    type Error = crate::errors::Error;
86
87    fn try_from(value: u32) -> Result<Self> {
88        match value {
89            0 => Ok(ChannelType::Stdin),
90            1 => Ok(ChannelType::Stdout),
91            2 => Ok(ChannelType::Stderr),
92            3 => Ok(ChannelType::Control),
93            _ => Err(
94                ProtocolError::InvalidMessage(format!("Invalid channel type: {}", value)).into(),
95            ),
96        }
97    }
98}
99
100/// Agent message sent over WebSocket
101#[derive(Debug, Clone, Serialize, Deserialize)]
102#[serde(rename_all = "PascalCase")]
103pub struct AgentMessage {
104    /// Message type
105    pub message_type: MessageType,
106    /// Schema version
107    pub schema_version: u32,
108    /// Message creation timestamp (Unix milliseconds)
109    pub created_date: u64,
110    /// Sequence number
111    pub sequence_number: i64,
112    /// Flags
113    pub flags: u64,
114    /// Message ID
115    pub message_id: Uuid,
116    /// Payload digest (SHA-256)
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub payload_digest: Option<String>,
119    /// Payload type
120    pub payload_type: u32,
121    /// Payload length
122    pub payload_length: u32,
123    /// Payload data
124    #[serde(skip)]
125    pub payload: Bytes,
126}
127
128impl AgentMessage {
129    /// Create a new agent message
130    pub fn new(message_type: MessageType, sequence_number: i64, payload: Bytes) -> Self {
131        Self {
132            message_type,
133            schema_version: 1,
134            // Use unwrap_or_default to handle edge case of system time before UNIX epoch
135            created_date: std::time::SystemTime::now()
136                .duration_since(std::time::UNIX_EPOCH)
137                .unwrap_or_default()
138                .as_millis() as u64,
139            sequence_number,
140            flags: 0,
141            message_id: Uuid::new_v4(),
142            payload_digest: None,
143            payload_type: 1,
144            payload_length: payload.len() as u32,
145            payload,
146        }
147    }
148
149    /// Serialize message to bytes (header + payload)
150    pub fn to_bytes(&self) -> Result<Bytes> {
151        // Serialize header to JSON
152        let header_json = serde_json::to_vec(self)?;
153        let header_len = header_json.len() as u32;
154
155        // Build frame: [header_len(4)][header][payload]
156        let total_len = 4 + header_len as usize + self.payload.len();
157        let mut buf = BytesMut::with_capacity(total_len);
158
159        buf.put_u32(header_len);
160        buf.put_slice(&header_json);
161        buf.put_slice(&self.payload);
162
163        Ok(buf.freeze())
164    }
165
166    /// Deserialize message from bytes
167    pub fn from_bytes(mut data: Bytes) -> Result<Self> {
168        if data.len() < 4 {
169            return Err(ProtocolError::Framing("Message too short".to_string()).into());
170        }
171
172        // Read header length with validation
173        let header_len = data.get_u32() as usize;
174
175        // Security: Validate header length is reasonable (max 1MB for JSON header)
176        const MAX_HEADER_SIZE: usize = 1024 * 1024;
177        if header_len > MAX_HEADER_SIZE {
178            return Err(ProtocolError::Framing(format!(
179                "Header length {} exceeds maximum {}",
180                header_len, MAX_HEADER_SIZE
181            ))
182            .into());
183        }
184
185        if data.len() < header_len {
186            return Err(ProtocolError::Framing(format!(
187                "Incomplete header: expected {}, got {}",
188                header_len,
189                data.len()
190            ))
191            .into());
192        }
193
194        // Parse header
195        let header_bytes = data.split_to(header_len);
196        let mut msg: AgentMessage = serde_json::from_slice(&header_bytes)?;
197
198        // Validate payload length
199        if data.len() != msg.payload_length as usize {
200            return Err(ProtocolError::Framing(format!(
201                "Payload length mismatch: expected {}, got {}",
202                msg.payload_length,
203                data.len()
204            ))
205            .into());
206        }
207
208        // Attach payload
209        msg.payload = data;
210
211        Ok(msg)
212    }
213}
214
215/// Payload for input/output stream data
216#[derive(Debug, Clone, Serialize, Deserialize)]
217#[serde(rename_all = "PascalCase")]
218pub struct StreamDataPayload {
219    /// Data content (base64 encoded)
220    pub data: String,
221}
222
223impl StreamDataPayload {
224    /// Create new stream data payload
225    pub fn new(data: &[u8]) -> Self {
226        Self {
227            data: base64::Engine::encode(&base64::engine::general_purpose::STANDARD, data),
228        }
229    }
230
231    /// Decode data from base64
232    pub fn decode(&self) -> Result<Vec<u8>> {
233        base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &self.data).map_err(
234            |e| ProtocolError::InvalidMessage(format!("Base64 decode error: {}", e)).into(),
235        )
236    }
237}
238
239/// Acknowledge payload
240#[derive(Debug, Clone, Serialize, Deserialize)]
241#[serde(rename_all = "PascalCase")]
242pub struct AcknowledgePayload {
243    /// Message ID being acknowledged
244    pub acknowledged_message_id: Uuid,
245    /// Sequence number
246    pub sequence_number: i64,
247    /// Is sequence number valid
248    pub is_sequential_message: bool,
249}
250
251/// Channel closed payload
252#[derive(Debug, Clone, Serialize, Deserialize)]
253#[serde(rename_all = "PascalCase")]
254pub struct ChannelClosedPayload {
255    /// Output from the session (base64 encoded)
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub output: Option<String>,
258    /// Session ID
259    pub session_id: String,
260    /// Message ID
261    pub message_id: Uuid,
262    /// Exit code
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub exit_code: Option<i32>,
265}
266
#[cfg(test)]
mod tests {
    //! Unit tests for framing, enum wire names and payload helpers.

    use super::*;

    #[test]
    fn test_message_roundtrip() {
        let payload = Bytes::from("test payload");
        let msg = AgentMessage::new(MessageType::InputStreamData, 1, payload.clone());

        let bytes = msg.to_bytes().unwrap();
        let decoded = AgentMessage::from_bytes(bytes).unwrap();

        assert_eq!(msg.message_type, decoded.message_type);
        assert_eq!(msg.sequence_number, decoded.sequence_number);
        assert_eq!(msg.message_id, decoded.message_id);
        assert_eq!(msg.created_date, decoded.created_date);
        assert_eq!(msg.payload_length, decoded.payload_length);
        assert_eq!(msg.payload, decoded.payload);
    }

    #[test]
    fn test_empty_payload_roundtrip() {
        let msg = AgentMessage::new(MessageType::Acknowledge, 0, Bytes::new());
        let decoded = AgentMessage::from_bytes(msg.to_bytes().unwrap()).unwrap();

        assert_eq!(decoded.payload_length, 0);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_from_bytes_rejects_missing_length_prefix() {
        assert!(AgentMessage::from_bytes(Bytes::new()).is_err());
        assert!(AgentMessage::from_bytes(Bytes::from_static(b"\x00\x00\x00")).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_oversized_header() {
        let mut frame = BytesMut::new();
        frame.put_u32(u32::MAX);
        assert!(AgentMessage::from_bytes(frame.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_truncated_header() {
        let mut frame = BytesMut::new();
        frame.put_u32(64);
        frame.put_slice(b"{}");
        assert!(AgentMessage::from_bytes(frame.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_invalid_header_json() {
        let mut frame = BytesMut::new();
        frame.put_u32(2);
        frame.put_slice(b"{}");
        assert!(AgentMessage::from_bytes(frame.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_payload_length_mismatch() {
        let msg = AgentMessage::new(MessageType::OutputStreamData, 7, Bytes::from("abc"));
        let mut frame = BytesMut::from(&msg.to_bytes().unwrap()[..]);
        frame.put_slice(b"extra");
        assert!(AgentMessage::from_bytes(frame.freeze()).is_err());
    }

    #[test]
    fn test_message_type_names_match_serde() {
        let all = [
            MessageType::InputStreamData,
            MessageType::OutputStreamData,
            MessageType::Acknowledge,
            MessageType::ChannelClosed,
            MessageType::StartPublication,
            MessageType::PausePublication,
        ];
        for ty in all {
            let json = serde_json::to_string(&ty).unwrap();
            assert_eq!(json, format!("\"{}\"", ty.as_str()));
            assert_eq!(ty.to_string(), ty.as_str());
        }
    }

    #[test]
    fn test_session_type_wire_names() {
        assert_eq!(SessionType::default(), SessionType::StandardStream);
        assert_eq!(
            serde_json::to_string(&SessionType::StandardStream).unwrap(),
            "\"Standard_Stream\""
        );
        assert_eq!(serde_json::to_string(&SessionType::Port).unwrap(), "\"Port\"");
        assert_eq!(
            serde_json::to_string(&SessionType::InteractiveCommands).unwrap(),
            "\"InteractiveCommands\""
        );
    }

    #[test]
    fn test_stream_data_payload() {
        let data = b"hello world";
        let payload = StreamDataPayload::new(data);
        let decoded = payload.decode().unwrap();

        assert_eq!(data, decoded.as_slice());
    }

    #[test]
    fn test_stream_data_payload_rejects_invalid_base64() {
        let payload = StreamDataPayload {
            data: "not base64!".to_string(),
        };
        assert!(payload.decode().is_err());
    }

    #[test]
    fn test_channel_type_conversion() {
        assert_eq!(ChannelType::try_from(0).unwrap(), ChannelType::Stdin);
        assert_eq!(ChannelType::try_from(1).unwrap(), ChannelType::Stdout);
        assert_eq!(ChannelType::try_from(2).unwrap(), ChannelType::Stderr);
        assert_eq!(ChannelType::try_from(3).unwrap(), ChannelType::Control);
        assert!(ChannelType::try_from(4).is_err());
        assert!(ChannelType::try_from(99).is_err());
    }
}