Skip to main content

rrq_protocol/
lib.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fmt;
6
/// Version string of the runner wire protocol implemented by this crate.
pub const PROTOCOL_VERSION: &str = "2";

/// Size in bytes of the length prefix written by `encode_frame`: one `u32`.
pub const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();

/// Serde default for `protocol_version` fields that are absent on the wire.
fn default_protocol_version() -> String {
    String::from(PROTOCOL_VERSION)
}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionContext {
16    pub job_id: String,
17    pub attempt: u32,
18    pub enqueue_time: DateTime<Utc>,
19    pub queue_name: String,
20    pub deadline: Option<DateTime<Utc>>,
21    #[serde(default)]
22    pub trace_context: Option<HashMap<String, String>>,
23    #[serde(default)]
24    pub correlation_context: Option<HashMap<String, String>>,
25    pub worker_id: Option<String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[serde(tag = "type", rename_all = "snake_case")]
30pub enum RunnerMessage {
31    Request { payload: ExecutionRequest },
32    Response { payload: ExecutionOutcome },
33    Cancel { payload: CancelRequest },
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CancelRequest {
38    #[serde(default = "default_protocol_version")]
39    pub protocol_version: String,
40    pub job_id: String,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub request_id: Option<String>,
43    #[serde(default)]
44    pub hard_kill: bool,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ExecutionRequest {
49    #[serde(default = "default_protocol_version")]
50    pub protocol_version: String,
51    pub request_id: String,
52    pub job_id: String,
53    pub function_name: String,
54    #[serde(default)]
55    pub params: HashMap<String, Value>,
56    pub context: ExecutionContext,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "snake_case")]
61pub enum OutcomeStatus {
62    Success,
63    Retry,
64    Timeout,
65    Error,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExecutionOutcome {
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub job_id: Option<String>,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub request_id: Option<String>,
74    pub status: OutcomeStatus,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub result: Option<Value>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub error: Option<ExecutionError>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub retry_after_seconds: Option<f64>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ExecutionError {
85    pub message: String,
86    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
87    pub error_type: Option<String>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub code: Option<String>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub details: Option<Value>,
92}
93
94impl ExecutionOutcome {
95    pub fn success<T: Serialize>(
96        job_id: impl Into<String>,
97        request_id: impl Into<String>,
98        result: T,
99    ) -> Self {
100        let value = serde_json::to_value(result).unwrap_or(Value::Null);
101        Self {
102            job_id: Some(job_id.into()),
103            request_id: Some(request_id.into()),
104            status: OutcomeStatus::Success,
105            result: Some(value),
106            error: None,
107            retry_after_seconds: None,
108        }
109    }
110
111    pub fn retry(
112        job_id: impl Into<String>,
113        request_id: impl Into<String>,
114        message: impl Into<String>,
115        retry_after_seconds: Option<f64>,
116    ) -> Self {
117        Self {
118            job_id: Some(job_id.into()),
119            request_id: Some(request_id.into()),
120            status: OutcomeStatus::Retry,
121            result: None,
122            error: Some(ExecutionError {
123                message: message.into(),
124                error_type: None,
125                code: None,
126                details: None,
127            }),
128            retry_after_seconds,
129        }
130    }
131
132    pub fn timeout(
133        job_id: impl Into<String>,
134        request_id: impl Into<String>,
135        message: impl Into<String>,
136    ) -> Self {
137        Self {
138            job_id: Some(job_id.into()),
139            request_id: Some(request_id.into()),
140            status: OutcomeStatus::Timeout,
141            result: None,
142            error: Some(ExecutionError {
143                message: message.into(),
144                error_type: None,
145                code: None,
146                details: None,
147            }),
148            retry_after_seconds: None,
149        }
150    }
151
152    pub fn error(
153        job_id: impl Into<String>,
154        request_id: impl Into<String>,
155        message: impl Into<String>,
156    ) -> Self {
157        Self {
158            job_id: Some(job_id.into()),
159            request_id: Some(request_id.into()),
160            status: OutcomeStatus::Error,
161            result: None,
162            error: Some(ExecutionError {
163                message: message.into(),
164                error_type: None,
165                code: None,
166                details: None,
167            }),
168            retry_after_seconds: None,
169        }
170    }
171
172    pub fn handler_not_found(
173        job_id: impl Into<String>,
174        request_id: impl Into<String>,
175        message: impl Into<String>,
176    ) -> Self {
177        Self {
178            job_id: Some(job_id.into()),
179            request_id: Some(request_id.into()),
180            status: OutcomeStatus::Error,
181            result: None,
182            error: Some(ExecutionError {
183                message: message.into(),
184                error_type: Some("handler_not_found".to_string()),
185                code: None,
186                details: None,
187            }),
188            retry_after_seconds: None,
189        }
190    }
191}
192
193#[derive(Debug)]
194pub enum FrameError {
195    InvalidLength,
196    Json(serde_json::Error),
197}
198
199impl fmt::Display for FrameError {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        match self {
202            Self::InvalidLength => write!(f, "invalid frame length"),
203            Self::Json(err) => write!(f, "json decode error: {err}"),
204        }
205    }
206}
207
208impl std::error::Error for FrameError {}
209
210impl From<serde_json::Error> for FrameError {
211    fn from(err: serde_json::Error) -> Self {
212        Self::Json(err)
213    }
214}
215
216pub fn encode_frame(message: &RunnerMessage) -> Result<Vec<u8>, FrameError> {
217    let payload = serde_json::to_vec(message)?;
218    let length = u32::try_from(payload.len()).map_err(|_| FrameError::InvalidLength)?;
219    let mut framed = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
220    framed.extend_from_slice(&length.to_be_bytes());
221    framed.extend_from_slice(&payload);
222    Ok(framed)
223}
224
225pub fn decode_frame(frame: &[u8]) -> Result<RunnerMessage, FrameError> {
226    if frame.len() < FRAME_HEADER_LEN {
227        return Err(FrameError::InvalidLength);
228    }
229    let mut header = [0u8; FRAME_HEADER_LEN];
230    header.copy_from_slice(&frame[..FRAME_HEADER_LEN]);
231    let length = u32::from_be_bytes(header) as usize;
232    if frame.len() - FRAME_HEADER_LEN != length {
233        return Err(FrameError::InvalidLength);
234    }
235    Ok(serde_json::from_slice(&frame[FRAME_HEADER_LEN..])?)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn execution_request_defaults_protocol_version() {
        let payload = json!({
            "job_id": "job-1",
            "request_id": "req-1",
            "function_name": "echo",
            "params": {},
            "context": {
                "job_id": "job-1",
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        let request: ExecutionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.protocol_version, PROTOCOL_VERSION);
    }

    // Every serde default on `CancelRequest` applies when only the required key is sent.
    #[test]
    fn cancel_request_defaults_optional_fields() {
        let payload = json!({ "job_id": "job-1" });
        let cancel: CancelRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(cancel.protocol_version, PROTOCOL_VERSION);
        assert_eq!(cancel.request_id, None);
        assert!(!cancel.hard_kill);
    }

    #[test]
    fn handler_not_found_sets_error_type() {
        let outcome = ExecutionOutcome::handler_not_found("job-1", "req-1", "missing handler");
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|err| err.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    #[test]
    fn runner_message_round_trip() {
        let request = ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: "echo".to_string(),
            params: HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                correlation_context: None,
                worker_id: None,
            },
        };
        let msg = RunnerMessage::Request { payload: request };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Request { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, "req-1");
    }

    #[test]
    fn cancel_request_round_trip() {
        let cancel = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: true,
        };
        let msg = RunnerMessage::Cancel { payload: cancel };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id.as_deref(), Some("req-1"));
        assert!(payload.hard_kill);
    }

    #[test]
    fn frame_round_trip() {
        let outcome = ExecutionOutcome::success("job-1", "req-1", json!({"ok": true}));
        let message = RunnerMessage::Response { payload: outcome };
        let framed = encode_frame(&message).expect("frame encode failed");
        let decoded = decode_frame(&framed).expect("frame decode failed");
        let RunnerMessage::Response { payload } = decoded else {
            panic!("unexpected message variant")
        };
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.job_id.as_deref(), Some("job-1"));
    }

    // Inputs shorter than the 4-byte length prefix must be rejected, not panic.
    #[test]
    fn decode_frame_rejects_truncated_header() {
        assert!(matches!(decode_frame(&[]), Err(FrameError::InvalidLength)));
        assert!(matches!(
            decode_frame(&[0, 0, 0]),
            Err(FrameError::InvalidLength)
        ));
    }

    // Both trailing and missing payload bytes disagree with the header length.
    #[test]
    fn decode_frame_rejects_length_mismatch() {
        let message = RunnerMessage::Response {
            payload: ExecutionOutcome::error("job-1", "req-1", "boom"),
        };
        let mut framed = encode_frame(&message).expect("frame encode failed");
        framed.push(b' ');
        assert!(matches!(
            decode_frame(&framed),
            Err(FrameError::InvalidLength)
        ));
        framed.truncate(framed.len() - 2);
        assert!(matches!(
            decode_frame(&framed),
            Err(FrameError::InvalidLength)
        ));
    }

    // A correctly sized but malformed payload surfaces as a JSON error.
    #[test]
    fn decode_frame_reports_invalid_json() {
        let mut framed = 2u32.to_be_bytes().to_vec();
        framed.extend_from_slice(b"{x");
        assert!(matches!(decode_frame(&framed), Err(FrameError::Json(_))));

        let empty_payload = 0u32.to_be_bytes();
        assert!(matches!(
            decode_frame(&empty_payload),
            Err(FrameError::Json(_))
        ));
    }
}