capnweb_core/protocol/
message.rs

1use super::expression::Expression;
2use super::ids::{ExportId, ImportId};
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5
6/// Cap'n Web protocol messages
7/// Messages are represented as JSON arrays with the message type as the first element
8#[derive(Debug, Clone, PartialEq)]
9pub enum Message {
10    /// ["push", expression] - Evaluate an expression and assign it an import ID
11    Push(Expression),
12
13    /// ["pull", importId] - Request resolution of an import
14    Pull(ImportId),
15
16    /// ["resolve", exportId, expression] - Resolve an export with a value
17    Resolve(ExportId, Expression),
18
19    /// ["reject", exportId, expression] - Reject an export with an error
20    Reject(ExportId, Expression),
21
22    /// ["release", importId, refcount] - Release an import
23    Release(ImportId, u32),
24
25    /// ["abort", expression] - Terminate the session with an error
26    Abort(Expression),
27}
28
29impl Message {
30    /// Parse a message from a JSON value
31    pub fn from_json(value: &JsonValue) -> Result<Self, MessageError> {
32        let arr = value.as_array().ok_or(MessageError::NotAnArray)?;
33
34        if arr.is_empty() {
35            return Err(MessageError::EmptyMessage);
36        }
37
38        let msg_type = arr[0].as_str().ok_or(MessageError::InvalidMessageType)?;
39
40        match msg_type {
41            "push" => {
42                if arr.len() != 2 {
43                    return Err(MessageError::InvalidPush);
44                }
45                let expr = Expression::from_json(&arr[1])?;
46                Ok(Message::Push(expr))
47            }
48
49            "pull" => {
50                if arr.len() != 2 {
51                    return Err(MessageError::InvalidPull);
52                }
53                let import_id = arr[1].as_i64().ok_or(MessageError::InvalidImportId)?;
54                Ok(Message::Pull(ImportId(import_id)))
55            }
56
57            "resolve" => {
58                if arr.len() != 3 {
59                    return Err(MessageError::InvalidResolve);
60                }
61                let export_id = arr[1].as_i64().ok_or(MessageError::InvalidExportId)?;
62                let expr = Expression::from_json(&arr[2])?;
63                Ok(Message::Resolve(ExportId(export_id), expr))
64            }
65
66            "reject" => {
67                if arr.len() != 3 {
68                    return Err(MessageError::InvalidReject);
69                }
70                let export_id = arr[1].as_i64().ok_or(MessageError::InvalidExportId)?;
71                let expr = Expression::from_json(&arr[2])?;
72                Ok(Message::Reject(ExportId(export_id), expr))
73            }
74
75            "release" => {
76                if arr.len() != 3 {
77                    return Err(MessageError::InvalidRelease);
78                }
79                let import_id = arr[1].as_i64().ok_or(MessageError::InvalidImportId)?;
80                let refcount = arr[2].as_u64().ok_or(MessageError::InvalidRefcount)? as u32;
81                Ok(Message::Release(ImportId(import_id), refcount))
82            }
83
84            "abort" => {
85                if arr.len() != 2 {
86                    return Err(MessageError::InvalidAbort);
87                }
88                let expr = Expression::from_json(&arr[1])?;
89                Ok(Message::Abort(expr))
90            }
91
92            _ => Err(MessageError::UnknownMessageType(msg_type.to_string())),
93        }
94    }
95
96    /// Convert the message to a JSON value
97    pub fn to_json(&self) -> JsonValue {
98        match self {
99            Message::Push(expr) => {
100                serde_json::json!(["push", expr.to_json()])
101            }
102            Message::Pull(import_id) => {
103                serde_json::json!(["pull", import_id.0])
104            }
105            Message::Resolve(export_id, expr) => {
106                serde_json::json!(["resolve", export_id.0, expr.to_json()])
107            }
108            Message::Reject(export_id, expr) => {
109                serde_json::json!(["reject", export_id.0, expr.to_json()])
110            }
111            Message::Release(import_id, refcount) => {
112                serde_json::json!(["release", import_id.0, refcount])
113            }
114            Message::Abort(expr) => {
115                serde_json::json!(["abort", expr.to_json()])
116            }
117        }
118    }
119}
120
121/// Custom serialization for Message
122impl Serialize for Message {
123    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124    where
125        S: serde::Serializer,
126    {
127        self.to_json().serialize(serializer)
128    }
129}
130
131/// Custom deserialization for Message
132impl<'de> Deserialize<'de> for Message {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: serde::Deserializer<'de>,
136    {
137        let value = JsonValue::deserialize(deserializer)?;
138        Message::from_json(&value).map_err(|e| serde::de::Error::custom(e.to_string()))
139    }
140}
141
142#[derive(Debug, thiserror::Error)]
143pub enum MessageError {
144    #[error("Message must be a JSON array")]
145    NotAnArray,
146
147    #[error("Message array cannot be empty")]
148    EmptyMessage,
149
150    #[error("Message type must be a string")]
151    InvalidMessageType,
152
153    #[error("Invalid push message format")]
154    InvalidPush,
155
156    #[error("Invalid pull message format")]
157    InvalidPull,
158
159    #[error("Invalid resolve message format")]
160    InvalidResolve,
161
162    #[error("Invalid reject message format")]
163    InvalidReject,
164
165    #[error("Invalid release message format")]
166    InvalidRelease,
167
168    #[error("Invalid abort message format")]
169    InvalidAbort,
170
171    #[error("Invalid import ID")]
172    InvalidImportId,
173
174    #[error("Invalid export ID")]
175    InvalidExportId,
176
177    #[error("Invalid refcount")]
178    InvalidRefcount,
179
180    #[error("Unknown message type: {0}")]
181    UnknownMessageType(String),
182
183    #[error("Expression error: {0}")]
184    ExpressionError(#[from] super::expression::ExpressionError),
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses `json`, asserts that re-encoding reproduces it exactly, and returns
    /// the parsed message for further checks.
    fn roundtrip(json: &JsonValue) -> Message {
        let msg = Message::from_json(json).expect("message should parse");
        assert_eq!(&msg.to_json(), json, "re-encoding should reproduce the input");
        msg
    }

    /// Parses `json`, expecting it to be rejected, and returns the error.
    fn parse_err(json: JsonValue) -> MessageError {
        Message::from_json(&json).expect_err("message should be rejected")
    }

    #[test]
    fn test_push_message() {
        let msg = roundtrip(&json!(["push", "hello"]));
        assert_eq!(msg, Message::Push(Expression::String("hello".to_string())));
    }

    #[test]
    fn test_pull_message() {
        let msg = roundtrip(&json!(["pull", 42]));
        assert_eq!(msg, Message::Pull(ImportId(42)));
    }

    #[test]
    fn test_resolve_message() {
        let msg = roundtrip(&json!(["resolve", -1, "result"]));
        assert_eq!(
            msg,
            Message::Resolve(ExportId(-1), Expression::String("result".to_string()))
        );
    }

    #[test]
    fn test_reject_message() {
        let msg = roundtrip(&json!(["reject", 7, "boom"]));
        assert_eq!(
            msg,
            Message::Reject(ExportId(7), Expression::String("boom".to_string()))
        );
    }

    #[test]
    fn test_release_message() {
        let msg = roundtrip(&json!(["release", 3, 2]));
        assert_eq!(msg, Message::Release(ImportId(3), 2));
    }

    #[test]
    fn test_abort_message() {
        let msg = roundtrip(&json!(["abort", "fatal"]));
        assert_eq!(msg, Message::Abort(Expression::String("fatal".to_string())));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = Message::Push(Expression::Number(serde_json::Number::from(42)));
        let json = serde_json::to_value(&original).unwrap();
        let deserialized: Message = serde_json::from_value(json.clone()).unwrap();

        assert_eq!(original, deserialized);
        assert_eq!(json, json!(["push", 42]));
    }

    #[test]
    fn test_structural_errors() {
        assert!(matches!(parse_err(json!({"push": 1})), MessageError::NotAnArray));
        assert!(matches!(parse_err(json!([])), MessageError::EmptyMessage));
        assert!(matches!(parse_err(json!([1, 2])), MessageError::InvalidMessageType));
        assert!(matches!(
            parse_err(json!(["bogus"])),
            MessageError::UnknownMessageType(t) if t == "bogus"
        ));
    }

    #[test]
    fn test_arity_errors() {
        assert!(matches!(parse_err(json!(["push"])), MessageError::InvalidPush));
        assert!(matches!(parse_err(json!(["pull", 1, 2])), MessageError::InvalidPull));
        assert!(matches!(parse_err(json!(["resolve", 1])), MessageError::InvalidResolve));
        assert!(matches!(parse_err(json!(["reject"])), MessageError::InvalidReject));
        assert!(matches!(parse_err(json!(["release", 1])), MessageError::InvalidRelease));
        assert!(matches!(parse_err(json!(["abort", 1, 2])), MessageError::InvalidAbort));
    }

    #[test]
    fn test_argument_errors() {
        assert!(matches!(parse_err(json!(["pull", "x"])), MessageError::InvalidImportId));
        assert!(matches!(
            parse_err(json!(["resolve", 1.5, "x"])),
            MessageError::InvalidExportId
        ));
        assert!(matches!(
            parse_err(json!(["release", 1, -1])),
            MessageError::InvalidRefcount
        ));
    }
}