1use bytes::{Buf, BufMut, Bytes, BytesMut};
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use uuid::Uuid;
10
11use crate::errors::{ProtocolError, Result};
12
/// Protocol version identifier exposed by this module.
///
/// NOTE(review): not referenced by the framing code in this file; confirm where it is
/// advertised or negotiated.
pub const PROTOCOL_VERSION: &str = "1.0";
15
/// Kind of an [`AgentMessage`], carried in its JSON header under the `MessageType` key.
///
/// Serialized in `snake_case` (e.g. `InputStreamData` -> `"input_stream_data"`);
/// [`MessageType::as_str`] returns the same strings.
///
/// NOTE(review): variant order is kept stable on purpose; reordering would change the
/// variant index used by non-self-describing serde formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    /// Stream data in the "input" direction (presumably toward the remote side; confirm).
    InputStreamData,
    /// Stream data in the "output" direction (presumably from the remote side; confirm).
    OutputStreamData,
    /// Acknowledgement message (payload presumably `AcknowledgePayload`; confirm).
    Acknowledge,
    /// Channel-closed notification (payload presumably `ChannelClosedPayload`; confirm).
    ChannelClosed,
    /// Start-publication control message (semantics not defined in this file).
    StartPublication,
    /// Pause-publication control message (semantics not defined in this file).
    PausePublication,
}
33
34impl MessageType {
35 pub fn as_str(&self) -> &'static str {
37 match self {
38 MessageType::InputStreamData => "input_stream_data",
39 MessageType::OutputStreamData => "output_stream_data",
40 MessageType::Acknowledge => "acknowledge",
41 MessageType::ChannelClosed => "channel_closed",
42 MessageType::StartPublication => "start_publication",
43 MessageType::PausePublication => "pause_publication",
44 }
45 }
46}
47
48impl fmt::Display for MessageType {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(f, "{}", self.as_str())
51 }
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "PascalCase")]
57#[derive(Default)]
58pub enum SessionType {
59 #[serde(rename = "Standard_Stream")]
61 #[default]
62 StandardStream,
63 #[serde(rename = "Port")]
65 Port,
66 #[serde(rename = "InteractiveCommands")]
68 InteractiveCommands,
69}
70
/// Channel identifiers with fixed numeric codes.
///
/// The explicit discriminants are the same codes accepted by `ChannelType::try_from(u32)`
/// in this module; keep the two in sync when adding a channel.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ChannelType {
    /// Standard input, code `0`.
    Stdin = 0,
    /// Standard output, code `1`.
    Stdout = 1,
    /// Standard error, code `2`.
    Stderr = 2,
    /// Control channel, code `3`.
    Control = 3,
}
83
84impl TryFrom<u32> for ChannelType {
85 type Error = crate::errors::Error;
86
87 fn try_from(value: u32) -> Result<Self> {
88 match value {
89 0 => Ok(ChannelType::Stdin),
90 1 => Ok(ChannelType::Stdout),
91 2 => Ok(ChannelType::Stderr),
92 3 => Ok(ChannelType::Control),
93 _ => Err(
94 ProtocolError::InvalidMessage(format!("Invalid channel type: {}", value)).into(),
95 ),
96 }
97 }
98}
99
/// A protocol message: a JSON header (every field below except `payload`) followed on
/// the wire by the raw payload bytes.
///
/// Header keys are serialized in `PascalCase` (e.g. `MessageType`, `SequenceNumber`).
/// See [`AgentMessage::to_bytes`] and [`AgentMessage::from_bytes`] for the framing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AgentMessage {
    /// Kind of message.
    pub message_type: MessageType,
    /// Header schema version; `AgentMessage::new` sets it to `1`.
    pub schema_version: u32,
    /// Creation timestamp; `AgentMessage::new` sets it to milliseconds since the Unix epoch.
    pub created_date: u64,
    /// Caller-supplied sequence number.
    pub sequence_number: i64,
    /// Flag bits; `AgentMessage::new` sets them to `0` (bit meanings not defined here).
    pub flags: u64,
    /// Message identifier; `AgentMessage::new` generates a random v4 UUID.
    pub message_id: Uuid,
    /// Optional payload digest, omitted from the JSON header when `None`.
    /// NOTE(review): `from_bytes` does not verify this digest.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload_digest: Option<String>,
    /// Payload type code; `AgentMessage::new` sets it to `1` (code meanings not defined here).
    pub payload_type: u32,
    /// Length of `payload` in bytes. `from_bytes` rejects frames where this disagrees with
    /// the bytes following the header, so keep it in sync if `payload` is replaced.
    pub payload_length: u32,
    /// Raw payload bytes. Excluded from the JSON header (`serde(skip)`); they follow the
    /// header in the encoded frame.
    #[serde(skip)]
    pub payload: Bytes,
}
127
128impl AgentMessage {
129 pub fn new(message_type: MessageType, sequence_number: i64, payload: Bytes) -> Self {
131 Self {
132 message_type,
133 schema_version: 1,
134 created_date: std::time::SystemTime::now()
136 .duration_since(std::time::UNIX_EPOCH)
137 .unwrap_or_default()
138 .as_millis() as u64,
139 sequence_number,
140 flags: 0,
141 message_id: Uuid::new_v4(),
142 payload_digest: None,
143 payload_type: 1,
144 payload_length: payload.len() as u32,
145 payload,
146 }
147 }
148
149 pub fn to_bytes(&self) -> Result<Bytes> {
151 let header_json = serde_json::to_vec(self)?;
153 let header_len = header_json.len() as u32;
154
155 let total_len = 4 + header_len as usize + self.payload.len();
157 let mut buf = BytesMut::with_capacity(total_len);
158
159 buf.put_u32(header_len);
160 buf.put_slice(&header_json);
161 buf.put_slice(&self.payload);
162
163 Ok(buf.freeze())
164 }
165
166 pub fn from_bytes(mut data: Bytes) -> Result<Self> {
168 if data.len() < 4 {
169 return Err(ProtocolError::Framing("Message too short".to_string()).into());
170 }
171
172 let header_len = data.get_u32() as usize;
174
175 const MAX_HEADER_SIZE: usize = 1024 * 1024;
177 if header_len > MAX_HEADER_SIZE {
178 return Err(ProtocolError::Framing(format!(
179 "Header length {} exceeds maximum {}",
180 header_len, MAX_HEADER_SIZE
181 ))
182 .into());
183 }
184
185 if data.len() < header_len {
186 return Err(ProtocolError::Framing(format!(
187 "Incomplete header: expected {}, got {}",
188 header_len,
189 data.len()
190 ))
191 .into());
192 }
193
194 let header_bytes = data.split_to(header_len);
196 let mut msg: AgentMessage = serde_json::from_slice(&header_bytes)?;
197
198 if data.len() != msg.payload_length as usize {
200 return Err(ProtocolError::Framing(format!(
201 "Payload length mismatch: expected {}, got {}",
202 msg.payload_length,
203 data.len()
204 ))
205 .into());
206 }
207
208 msg.payload = data;
210
211 Ok(msg)
212 }
213}
214
/// Payload carrying raw stream bytes as base64 text under the `Data` JSON key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct StreamDataPayload {
    /// Base64 text (standard alphabet, padded) as produced by `StreamDataPayload::new`;
    /// use `StreamDataPayload::decode` to recover the bytes.
    pub data: String,
}
222
223impl StreamDataPayload {
224 pub fn new(data: &[u8]) -> Self {
226 Self {
227 data: base64::Engine::encode(&base64::engine::general_purpose::STANDARD, data),
228 }
229 }
230
231 pub fn decode(&self) -> Result<Vec<u8>> {
233 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &self.data).map_err(
234 |e| ProtocolError::InvalidMessage(format!("Base64 decode error: {}", e)).into(),
235 )
236 }
237}
238
/// JSON body acknowledging a received message; keys are serialized in `PascalCase`.
///
/// NOTE(review): presumably carried by a `MessageType::Acknowledge` message; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AcknowledgePayload {
    /// Id of the acknowledged message (presumably its `AgentMessage::message_id`; confirm).
    pub acknowledged_message_id: Uuid,
    /// Sequence number of the acknowledged message (presumably; confirm with the sender).
    pub sequence_number: i64,
    /// Whether the acknowledged message was sequential (semantics not defined here).
    pub is_sequential_message: bool,
}
250
/// JSON body reporting a closed channel; keys are serialized in `PascalCase`.
///
/// NOTE(review): presumably carried by a `MessageType::ChannelClosed` message; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct ChannelClosedPayload {
    /// Optional output text; the `Output` key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Session identifier (format not defined in this file).
    pub session_id: String,
    /// Message identifier.
    pub message_id: Uuid,
    /// Optional exit code; the `ExitCode` key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exit_code: Option<i32>,
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg` and decodes it again, panicking on any framing error.
    fn roundtrip(msg: &AgentMessage) -> AgentMessage {
        let bytes = msg
            .to_bytes()
            .expect("encoding a freshly built message must succeed");
        AgentMessage::from_bytes(bytes).expect("decoding our own frame must succeed")
    }

    #[test]
    fn test_message_roundtrip() {
        let payload = Bytes::from("test payload");
        let msg = AgentMessage::new(MessageType::InputStreamData, 1, payload.clone());

        let decoded = roundtrip(&msg);

        assert_eq!(msg.message_type, decoded.message_type);
        assert_eq!(msg.sequence_number, decoded.sequence_number);
        assert_eq!(msg.message_id, decoded.message_id);
        assert_eq!(msg.created_date, decoded.created_date);
        assert_eq!(msg.payload_length, decoded.payload_length);
        assert_eq!(payload, decoded.payload);
    }

    #[test]
    fn test_message_roundtrip_empty_payload() {
        let msg = AgentMessage::new(MessageType::Acknowledge, 0, Bytes::new());
        let decoded = roundtrip(&msg);

        assert_eq!(decoded.payload_length, 0);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        assert!(AgentMessage::from_bytes(Bytes::new()).is_err());
        assert!(AgentMessage::from_bytes(Bytes::from_static(&[0, 0, 0])).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_oversized_header() {
        let mut buf = BytesMut::new();
        buf.put_u32(2 * 1024 * 1024);
        assert!(AgentMessage::from_bytes(buf.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_incomplete_header() {
        let mut buf = BytesMut::new();
        buf.put_u32(100);
        buf.put_slice(&[b'{'; 10]);
        assert!(AgentMessage::from_bytes(buf.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_invalid_json_header() {
        let mut buf = BytesMut::new();
        buf.put_u32(3);
        buf.put_slice(b"abc");
        assert!(AgentMessage::from_bytes(buf.freeze()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_payload_length_mismatch() {
        let msg = AgentMessage::new(MessageType::OutputStreamData, 2, Bytes::from("abc"));
        let frame = msg.to_bytes().unwrap();

        let truncated = frame.slice(..frame.len() - 1);
        assert!(AgentMessage::from_bytes(truncated).is_err());

        let mut extended = BytesMut::from(&frame[..]);
        extended.put_u8(0);
        assert!(AgentMessage::from_bytes(extended.freeze()).is_err());
    }

    #[test]
    fn test_message_type_as_str_matches_serde() {
        let all = [
            MessageType::InputStreamData,
            MessageType::OutputStreamData,
            MessageType::Acknowledge,
            MessageType::ChannelClosed,
            MessageType::StartPublication,
            MessageType::PausePublication,
        ];
        for ty in all {
            let json = serde_json::to_string(&ty).unwrap();
            assert_eq!(json, format!("\"{}\"", ty.as_str()));
            assert_eq!(ty.to_string(), ty.as_str());
        }
    }

    #[test]
    fn test_session_type_default_and_wire_names() {
        assert_eq!(SessionType::default(), SessionType::StandardStream);
        assert_eq!(
            serde_json::to_string(&SessionType::StandardStream).unwrap(),
            "\"Standard_Stream\""
        );
        assert_eq!(serde_json::to_string(&SessionType::Port).unwrap(), "\"Port\"");
        let parsed: SessionType = serde_json::from_str("\"InteractiveCommands\"").unwrap();
        assert_eq!(parsed, SessionType::InteractiveCommands);
    }

    #[test]
    fn test_stream_data_payload() {
        let data = b"hello world";
        let payload = StreamDataPayload::new(data);
        let decoded = payload.decode().unwrap();

        assert_eq!(data, decoded.as_slice());
    }

    #[test]
    fn test_stream_data_payload_empty_and_invalid() {
        let empty = StreamDataPayload::new(&[]);
        assert!(empty.data.is_empty());
        assert!(empty.decode().unwrap().is_empty());

        let garbage = StreamDataPayload {
            data: "!!!!".to_string(),
        };
        assert!(garbage.decode().is_err());
    }

    #[test]
    fn test_channel_type_conversion() {
        for (raw, expected) in [
            (0, ChannelType::Stdin),
            (1, ChannelType::Stdout),
            (2, ChannelType::Stderr),
            (3, ChannelType::Control),
        ] {
            assert_eq!(ChannelType::try_from(raw).unwrap(), expected);
            assert_eq!(expected as u32, raw);
        }
        assert!(ChannelType::try_from(4).is_err());
        assert!(ChannelType::try_from(99).is_err());
        assert!(ChannelType::try_from(u32::MAX).is_err());
    }

    #[test]
    fn test_payload_field_names_are_pascal_case() {
        let ack = AcknowledgePayload {
            acknowledged_message_id: Uuid::nil(),
            sequence_number: 7,
            is_sequential_message: true,
        };
        let value = serde_json::to_value(&ack).unwrap();
        assert_eq!(value["SequenceNumber"], 7);
        assert_eq!(value["IsSequentialMessage"], true);
        assert!(value.get("AcknowledgedMessageId").is_some());

        let closed = ChannelClosedPayload {
            output: None,
            session_id: "s-1".to_string(),
            message_id: Uuid::nil(),
            exit_code: None,
        };
        let value = serde_json::to_value(&closed).unwrap();
        assert_eq!(value["SessionId"], "s-1");
        assert!(value.get("Output").is_none());
        assert!(value.get("ExitCode").is_none());
    }
}