Skip to main content

hermod/protocol/
messages.rs

1//! Protocol messages for trace-forward protocol
2//!
3//! The protocol has three messages:
4//! - MsgTraceObjectsRequest: Acceptor requests N trace objects
5//! - MsgTraceObjectsReply: Forwarder replies with trace objects
6//! - MsgDone: Acceptor terminates the session
7
8use super::types::TraceObject;
9use pallas_codec::minicbor::{Decode, Decoder, Encode, Encoder, decode, encode};
10
/// Request for trace objects, sent by the acceptor to the forwarder.
///
/// Wire format: `array(3)[1, blocking: bool, array(2)[0, count: u16]]`
///
/// The inner `array(2)[0, count]` is how the Haskell `NumberOfTraceObjects` newtype
/// is serialised (constructor index 0 followed by the wrapped value).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MsgTraceObjectsRequest {
    /// Whether this is a blocking request.
    ///
    /// When `true`, the matching `MsgTraceObjectsReply` must contain at least one
    /// trace object.
    pub blocking: bool,
    /// Maximum number of trace objects the acceptor asks for.
    pub number_of_trace_objects: u16,
}
21
22/// Reply with trace objects from the forwarder
23///
24/// Wire format: `array(2)[3, trace_objects: [TraceObject]]`
25///
26/// Note: For blocking requests, the list must be non-empty
27#[derive(Debug, Clone)]
28pub struct MsgTraceObjectsReply {
29    /// The trace objects being sent
30    /// For blocking requests, this must be non-empty
31    pub trace_objects: Vec<TraceObject>,
32}
33
/// Termination message, sent by the acceptor to end the session.
///
/// Wire format: `array(1)[2]`
///
/// Note: the `Message::Done` variant is a unit variant and does not wrap this type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MsgDone;
39
40/// All possible messages in the protocol
41#[derive(Debug, Clone)]
42pub enum Message {
43    /// Request for trace objects
44    TraceObjectsRequest(MsgTraceObjectsRequest),
45    /// Reply with trace objects
46    TraceObjectsReply(MsgTraceObjectsReply),
47    /// Termination
48    Done,
49}
50
/// States of the trace-forward protocol state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// Waiting on the acceptor, which may send either a request or `MsgDone`.
    Idle,
    /// Waiting on the forwarder, which must send a reply.
    ///
    /// The flag is `true` when the outstanding request was blocking.
    Busy(bool),
    /// Terminal state; the session is over.
    Done,
}
62
63// CBOR encoding/decoding implementations
64impl Encode<()> for Message {
65    fn encode<W: encode::Write>(
66        &self,
67        e: &mut Encoder<W>,
68        ctx: &mut (),
69    ) -> Result<(), encode::Error<W::Error>> {
70        match self {
71            Message::TraceObjectsRequest(req) => {
72                // NumberOfTraceObjects is a Haskell newtype, encoded as [constructor_index, value]
73                e.array(3)?.u16(1)?.bool(req.blocking)?;
74                e.array(2)?.u16(0)?.u16(req.number_of_trace_objects)?;
75            }
76            Message::TraceObjectsReply(reply) => {
77                e.array(2)?.u16(3)?;
78                e.array(reply.trace_objects.len() as u64)?;
79                for trace_obj in &reply.trace_objects {
80                    e.encode_with(trace_obj, ctx)?;
81                }
82            }
83            Message::Done => {
84                e.array(1)?.u16(2)?;
85            }
86        }
87        Ok(())
88    }
89}
90
91impl<'b> Decode<'b, ()> for Message {
92    fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result<Self, decode::Error> {
93        d.array()?;
94        let tag = d.u16()?;
95
96        match tag {
97            1 => {
98                // MsgTraceObjectsRequest
99                let blocking = d.bool()?;
100                // NumberOfTraceObjects is a Haskell newtype; Generic Serialise encodes it as
101                // array(2)[constructor_index=0, value]
102                d.array()?;
103                let _constructor_idx = d.u16()?;
104                let number_of_trace_objects = d.u16()?;
105                Ok(Message::TraceObjectsRequest(MsgTraceObjectsRequest {
106                    blocking,
107                    number_of_trace_objects,
108                }))
109            }
110            2 => {
111                // MsgDone
112                Ok(Message::Done)
113            }
114            3 => {
115                // MsgTraceObjectsReply
116                // Haskell's Serialise [a] uses indefinite-length encoding for non-empty lists
117                let mut trace_objects = Vec::new();
118                for item in d.array_iter_with::<(), TraceObject>(ctx)? {
119                    trace_objects.push(item?);
120                }
121                Ok(Message::TraceObjectsReply(MsgTraceObjectsReply {
122                    trace_objects,
123                }))
124            }
125            _ => Err(decode::Error::message("unknown message tag")),
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;
    use pallas_codec::minicbor;

    fn encode_msg(msg: &Message) -> Vec<u8> {
        let mut buf = Vec::new();
        minicbor::encode_with(msg, &mut buf, &mut ()).unwrap();
        buf
    }

    fn decode_msg(buf: &[u8]) -> Message {
        minicbor::decode_with(buf, &mut ()).unwrap()
    }

    fn make_trace() -> TraceObject {
        TraceObject {
            to_human: Some("hello".to_string()),
            to_machine: r#"{"msg":"hello"}"#.to_string(),
            to_namespace: vec!["Test".to_string(), "Message".to_string()],
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: "localhost".to_string(),
            to_thread_id: "42".to_string(),
        }
    }

    // --- Framing ---

    #[test]
    fn unknown_tag_is_rejected() {
        // array(1)[7]: tag 7 is not part of the protocol
        let res: Result<Message, _> = minicbor::decode_with(&[0x81u8, 0x07][..], &mut ());
        assert!(res.is_err());
    }

    // --- MsgDone ---

    #[test]
    fn done_round_trip() {
        let buf = encode_msg(&Message::Done);
        assert!(matches!(decode_msg(&buf), Message::Done));
    }

    #[test]
    fn done_exact_bytes() {
        // array(1)[2] = 0x81, 0x02
        assert_eq!(encode_msg(&Message::Done), &[0x81, 0x02]);
    }

    // --- MsgTraceObjectsRequest ---

    #[test]
    fn request_exact_bytes() {
        // array(3)[1, true, array(2)[0, 100]]; 100 > 23 so it takes a one-byte
        // argument (0x18 0x64)
        let req = Message::TraceObjectsRequest(MsgTraceObjectsRequest {
            blocking: true,
            number_of_trace_objects: 100,
        });
        assert_eq!(
            encode_msg(&req),
            &[0x83, 0x01, 0xf5, 0x82, 0x00, 0x18, 0x64]
        );
    }

    #[test]
    fn request_blocking_round_trip() {
        let req = Message::TraceObjectsRequest(MsgTraceObjectsRequest {
            blocking: true,
            number_of_trace_objects: 100,
        });
        match decode_msg(&encode_msg(&req)) {
            Message::TraceObjectsRequest(r) => {
                assert!(r.blocking);
                assert_eq!(r.number_of_trace_objects, 100);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn request_non_blocking_round_trip() {
        let req = Message::TraceObjectsRequest(MsgTraceObjectsRequest {
            blocking: false,
            number_of_trace_objects: 10,
        });
        match decode_msg(&encode_msg(&req)) {
            Message::TraceObjectsRequest(r) => {
                assert!(!r.blocking);
                assert_eq!(r.number_of_trace_objects, 10);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn request_zero_count_round_trip() {
        let req = Message::TraceObjectsRequest(MsgTraceObjectsRequest {
            blocking: false,
            number_of_trace_objects: 0,
        });
        match decode_msg(&encode_msg(&req)) {
            Message::TraceObjectsRequest(r) => assert_eq!(r.number_of_trace_objects, 0),
            _ => panic!("wrong message type"),
        }
    }

    // --- MsgTraceObjectsReply ---

    #[test]
    fn reply_empty_exact_bytes() {
        // array(2)[3, array(0)[]]
        let reply = Message::TraceObjectsReply(MsgTraceObjectsReply {
            trace_objects: vec![],
        });
        assert_eq!(encode_msg(&reply), &[0x82, 0x03, 0x80]);
    }

    #[test]
    fn reply_decodes_indefinite_length_list() {
        // Haskell's Serialise [a] writes non-empty lists with an indefinite-length header
        // (0x9f ... 0xff). An empty indefinite list exercises the same decoding path
        // without needing a hand-encoded TraceObject.
        match decode_msg(&[0x82, 0x03, 0x9f, 0xff]) {
            Message::TraceObjectsReply(r) => assert!(r.trace_objects.is_empty()),
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn reply_empty_round_trip() {
        let reply = Message::TraceObjectsReply(MsgTraceObjectsReply {
            trace_objects: vec![],
        });
        match decode_msg(&encode_msg(&reply)) {
            Message::TraceObjectsReply(r) => assert!(r.trace_objects.is_empty()),
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn reply_with_trace_round_trip() {
        let trace = make_trace();
        let reply = Message::TraceObjectsReply(MsgTraceObjectsReply {
            trace_objects: vec![trace.clone()],
        });
        match decode_msg(&encode_msg(&reply)) {
            Message::TraceObjectsReply(r) => {
                assert_eq!(r.trace_objects.len(), 1);
                assert_eq!(r.trace_objects[0].to_machine, trace.to_machine);
                assert_eq!(r.trace_objects[0].to_namespace, trace.to_namespace);
                assert_eq!(r.trace_objects[0].to_human, trace.to_human);
                assert_eq!(r.trace_objects[0].to_severity, trace.to_severity);
                assert_eq!(r.trace_objects[0].to_hostname, trace.to_hostname);
                assert_eq!(r.trace_objects[0].to_thread_id, trace.to_thread_id);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn reply_with_multiple_traces_round_trip() {
        let traces: Vec<TraceObject> = (0..5).map(|_| make_trace()).collect();
        let reply = Message::TraceObjectsReply(MsgTraceObjectsReply {
            trace_objects: traces,
        });
        match decode_msg(&encode_msg(&reply)) {
            Message::TraceObjectsReply(r) => assert_eq!(r.trace_objects.len(), 5),
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn reply_trace_with_no_human_round_trip() {
        let mut trace = make_trace();
        trace.to_human = None;
        let reply = Message::TraceObjectsReply(MsgTraceObjectsReply {
            trace_objects: vec![trace],
        });
        match decode_msg(&encode_msg(&reply)) {
            Message::TraceObjectsReply(r) => assert!(r.trace_objects[0].to_human.is_none()),
            _ => panic!("wrong message type"),
        }
    }
}