1use serde::{Deserialize, Serialize};
6
7use super::Capabilities;
8use crate::codec::Algorithm;
9
/// Kind of a protocol [`Message`], stored in the message's `"type"` JSON field.
///
/// Serialized as the upper-cased variant name, e.g. `Hello` → `"HELLO"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum MessageType {
    /// Built by [`Message::hello`]; carries the sender's [`Capabilities`] and no session id.
    Hello,
    /// Built by [`Message::accept`]; carries a session id and [`Capabilities`].
    Accept,
    /// Built by [`Message::reject`]; carries a [`RejectionInfo`] and no session id.
    Reject,
    /// Built by [`Message::data`] / [`Message::data_with_security`]; carries a [`DataPayload`].
    Data,
    /// Built by [`Message::ping`]; carries an empty payload.
    Ping,
    /// Built by [`Message::pong`]; carries an empty payload.
    Pong,
    /// Built by [`Message::close`]; carries an empty payload.
    Close,
}
29
/// A protocol message, converted to and from JSON by [`Message::to_json`] and
/// [`Message::from_json`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Kind of message; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub msg_type: MessageType,
    /// Session identifier; `None` for `HELLO` and `REJECT` messages built here.
    /// Omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Kind-specific body. Every constructor in this file sets it; omitted from the
    /// JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<MessagePayload>,
    /// Milliseconds since the Unix epoch, taken from `current_timestamp` when the message
    /// is constructed; a deserialized message keeps whatever value the JSON carried.
    pub timestamp: u64,
}
45
/// Body of a [`Message`]; the variant used depends on the message kind.
///
/// Serialized `untagged`: the JSON contains only the inner value, with no variant name.
/// When deserializing, serde tries the variants in declaration order and keeps the first
/// one that parses, so the order below is significant.
///
/// NOTE(review): serde ignores unknown fields by default, so `Empty {}` accepts any JSON
/// object and must stay last. Likewise, if `Capabilities` can be built from an object
/// that lacks its own fields (e.g. all fields defaulted), it would also capture rejection,
/// data and `{}` payloads on deserialization. Confirm against `Capabilities`' serde
/// attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessagePayload {
    /// Used by [`Message::hello`] and [`Message::accept`].
    Capabilities(Capabilities),
    /// Used by [`Message::reject`].
    Rejection(RejectionInfo),
    /// Used by [`Message::data`] and [`Message::data_with_security`].
    Data(DataPayload),
    /// Used by [`Message::ping`], [`Message::pong`] and [`Message::close`]; serializes as `{}`.
    Empty {},
}
59
/// Payload of a `REJECT` message, built by [`Message::reject`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RejectionInfo {
    /// Machine-readable rejection reason.
    pub code: RejectionCode,
    /// Free-form explanation supplied by the caller of [`Message::reject`].
    pub message: String,
}
68
/// Machine-readable reason carried in a [`RejectionInfo`].
///
/// Serialized in SCREAMING_SNAKE_CASE, e.g. `VersionMismatch` → `"VERSION_MISMATCH"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RejectionCode {
    /// Presumably: the peer's protocol version is not supported.
    VersionMismatch,
    /// Presumably: no `crate::codec::Algorithm` is supported by both sides.
    NoCommonAlgorithm,
    /// Presumably: refused on security-policy grounds.
    SecurityPolicy,
    /// Presumably: refused because the peer is being rate limited.
    RateLimited,
    /// Presumably: catch-all for reasons not covered above.
    Unknown,
}
84
/// Payload of a `DATA` message, built by [`Message::data`] or [`Message::data_with_security`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPayload {
    /// Algorithm identifier from `crate::codec` associated with `content`
    /// (presumably the one used to encode it).
    pub algorithm: Algorithm,
    /// Message body as a string; its encoding is not established in this module.
    pub content: String,
    /// Optional pre-encoding size; omitted from the JSON output when `None`.
    /// The constructors in this module always leave it `None`.
    /// NOTE(review): unit (bytes?) is not established here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub original_size: Option<usize>,
    /// Optional security-scan outcome; omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub security_status: Option<SecurityStatus>,
}
99
/// Security-scan outcome that can be attached to a [`DataPayload`].
///
/// Among this module's constructors, only [`Message::data_with_security`] attaches one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityStatus {
    /// Whether the content was scanned.
    pub scanned: bool,
    /// Whether the content was judged safe.
    /// NOTE(review): meaning when `scanned` is `false` is not established here.
    pub safe: bool,
    /// Optional threat label; omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threat_type: Option<String>,
    /// Optional confidence score; omitted from the JSON output when `None`.
    /// NOTE(review): scale (e.g. 0.0–1.0) is not established here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f32>,
}
114
115impl Message {
116 pub fn hello(capabilities: Capabilities) -> Self {
118 Self {
119 msg_type: MessageType::Hello,
120 session_id: None,
121 payload: Some(MessagePayload::Capabilities(capabilities)),
122 timestamp: current_timestamp(),
123 }
124 }
125
126 pub fn accept(session_id: &str, capabilities: Capabilities) -> Self {
128 Self {
129 msg_type: MessageType::Accept,
130 session_id: Some(session_id.to_string()),
131 payload: Some(MessagePayload::Capabilities(capabilities)),
132 timestamp: current_timestamp(),
133 }
134 }
135
136 pub fn reject(code: RejectionCode, message: &str) -> Self {
138 Self {
139 msg_type: MessageType::Reject,
140 session_id: None,
141 payload: Some(MessagePayload::Rejection(RejectionInfo {
142 code,
143 message: message.to_string(),
144 })),
145 timestamp: current_timestamp(),
146 }
147 }
148
149 pub fn data(session_id: &str, algorithm: Algorithm, content: String) -> Self {
151 Self {
152 msg_type: MessageType::Data,
153 session_id: Some(session_id.to_string()),
154 payload: Some(MessagePayload::Data(DataPayload {
155 algorithm,
156 content,
157 original_size: None,
158 security_status: None,
159 })),
160 timestamp: current_timestamp(),
161 }
162 }
163
164 pub fn data_with_security(
166 session_id: &str,
167 algorithm: Algorithm,
168 content: String,
169 security: SecurityStatus,
170 ) -> Self {
171 Self {
172 msg_type: MessageType::Data,
173 session_id: Some(session_id.to_string()),
174 payload: Some(MessagePayload::Data(DataPayload {
175 algorithm,
176 content,
177 original_size: None,
178 security_status: Some(security),
179 })),
180 timestamp: current_timestamp(),
181 }
182 }
183
184 pub fn ping(session_id: &str) -> Self {
186 Self {
187 msg_type: MessageType::Ping,
188 session_id: Some(session_id.to_string()),
189 payload: Some(MessagePayload::Empty {}),
190 timestamp: current_timestamp(),
191 }
192 }
193
194 pub fn pong(session_id: &str) -> Self {
196 Self {
197 msg_type: MessageType::Pong,
198 session_id: Some(session_id.to_string()),
199 payload: Some(MessagePayload::Empty {}),
200 timestamp: current_timestamp(),
201 }
202 }
203
204 pub fn close(session_id: &str) -> Self {
206 Self {
207 msg_type: MessageType::Close,
208 session_id: Some(session_id.to_string()),
209 payload: Some(MessagePayload::Empty {}),
210 timestamp: current_timestamp(),
211 }
212 }
213
214 pub fn to_json(&self) -> Result<String, serde_json::Error> {
216 serde_json::to_string(self)
217 }
218
219 pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
221 serde_json::to_string(self)
222 }
223
224 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
226 serde_json::from_str(json)
227 }
228
229 pub fn get_capabilities(&self) -> Option<&Capabilities> {
231 match &self.payload {
232 Some(MessagePayload::Capabilities(caps)) => Some(caps),
233 _ => None,
234 }
235 }
236
237 pub fn get_data(&self) -> Option<&DataPayload> {
239 match &self.payload {
240 Some(MessagePayload::Data(data)) => Some(data),
241 _ => None,
242 }
243 }
244
245 pub fn get_rejection(&self) -> Option<&RejectionInfo> {
247 match &self.payload {
248 Some(MessagePayload::Rejection(info)) => Some(info),
249 _ => None,
250 }
251 }
252}
253
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The `u128`
/// millisecond count is clamped to `u64::MAX` rather than silently truncated.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| {
            // Clamp before narrowing so the cast below is lossless.
            elapsed.as_millis().min(u128::from(u64::MAX)) as u64
        })
        .unwrap_or(0)
}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hello_message() {
        let caps = Capabilities::default();
        let msg = Message::hello(caps);

        assert_eq!(msg.msg_type, MessageType::Hello);
        assert!(msg.session_id.is_none());
        assert!(msg.get_capabilities().is_some());

        let json = msg.to_json().unwrap();
        let parsed = Message::from_json(&json).unwrap();
        assert_eq!(parsed.msg_type, MessageType::Hello);
    }

    #[test]
    fn test_accept_message() {
        let caps = Capabilities::default();
        let msg = Message::accept("session-123", caps);

        assert_eq!(msg.msg_type, MessageType::Accept);
        assert_eq!(msg.session_id, Some("session-123".to_string()));
    }

    #[test]
    fn test_reject_message() {
        let msg = Message::reject(RejectionCode::VersionMismatch, "Version 4.0 not supported");

        assert_eq!(msg.msg_type, MessageType::Reject);
        let rejection = msg.get_rejection().unwrap();
        assert_eq!(rejection.code, RejectionCode::VersionMismatch);
    }

    #[test]
    fn test_data_message() {
        let msg = Message::data("session-123", Algorithm::M2M, "#M2M|1|...".to_string());

        assert_eq!(msg.msg_type, MessageType::Data);
        let data = msg.get_data().unwrap();
        assert_eq!(data.algorithm, Algorithm::M2M);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let caps = Capabilities::new("test-agent").with_extension("custom", "value");
        let msg = Message::hello(caps);

        let json = msg.to_json().unwrap();
        let parsed = Message::from_json(&json).unwrap();

        let caps = parsed.get_capabilities().unwrap();
        assert_eq!(caps.agent_type, "test-agent");
        assert_eq!(caps.extensions.get("custom"), Some(&"value".to_string()));
    }

    // Ping/pong/close share a shape: a session id and an empty payload that none of the
    // typed accessors should expose.
    #[test]
    fn test_empty_payload_messages() {
        let cases = [
            (Message::ping("session-123"), MessageType::Ping),
            (Message::pong("session-123"), MessageType::Pong),
            (Message::close("session-123"), MessageType::Close),
        ];
        for (msg, expected) in &cases {
            assert_eq!(msg.msg_type, *expected);
            assert_eq!(msg.session_id.as_deref(), Some("session-123"));
            assert!(msg.get_capabilities().is_none());
            assert!(msg.get_data().is_none());
            assert!(msg.get_rejection().is_none());
            assert!(msg.timestamp > 0);
        }
    }

    #[test]
    fn test_ping_wire_format() {
        let json = Message::ping("session-123").to_json().unwrap();
        assert!(json.contains(r#""type":"PING""#));
        assert!(json.contains(r#""session_id":"session-123""#));
        assert!(json.contains(r#""payload":{}"#));
    }

    #[test]
    fn test_reject_wire_format_omits_session_id() {
        let json = Message::reject(RejectionCode::NoCommonAlgorithm, "nothing in common")
            .to_json()
            .unwrap();
        assert!(json.contains(r#""type":"REJECT""#));
        assert!(json.contains(r#""code":"NO_COMMON_ALGORITHM""#));
        assert!(!json.contains("session_id"));
    }

    #[test]
    fn test_data_with_security_message() {
        let security = SecurityStatus {
            scanned: true,
            safe: true,
            threat_type: None,
            confidence: None,
        };
        let msg = Message::data_with_security(
            "session-123",
            Algorithm::M2M,
            "#M2M|1|...".to_string(),
            security,
        );

        assert_eq!(msg.msg_type, MessageType::Data);
        let data = msg.get_data().unwrap();
        assert!(data.original_size.is_none());
        let status = data.security_status.as_ref().unwrap();
        assert!(status.scanned && status.safe);

        // `None` optional fields must be omitted from the wire format.
        let json = msg.to_json().unwrap();
        assert!(json.contains(r#""security_status":{"scanned":true,"safe":true}"#));
        assert!(!json.contains("original_size"));
    }
}