Skip to main content

rrq_protocol/
lib.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fmt;
6
/// Protocol version assumed for requests and cancellations that omit one.
pub const PROTOCOL_VERSION: &str = "2";
/// Byte length of the frame header: a single big-endian `u32` body length.
pub const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();

/// Serde default used when an incoming message has no `protocol_version` field.
fn default_protocol_version() -> String {
    String::from(PROTOCOL_VERSION)
}
13
/// Job metadata carried in every [`ExecutionRequest`] (its `context` field).
///
/// `Option` fields carry no `skip_serializing_if`, so they are emitted as
/// `null` when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
    /// Identifier of the job being executed.
    pub job_id: String,
    /// Attempt counter for this job. NOTE(review): tests use `1` for a first
    /// attempt; confirm the numbering base with the producer.
    pub attempt: u32,
    /// Time the job was enqueued, in UTC.
    pub enqueue_time: DateTime<Utc>,
    /// Name of the job's queue.
    pub queue_name: String,
    /// Optional deadline in UTC. NOTE(review): nothing in this crate enforces it.
    pub deadline: Option<DateTime<Utc>>,
    /// Optional string-to-string map, `None` when the field is missing.
    /// Presumably trace-propagation headers — confirm with the producer.
    #[serde(default)]
    pub trace_context: Option<HashMap<String, String>>,
    /// Optional identifier of the worker handling the job.
    pub worker_id: Option<String>,
}
25
/// Top-level message carried in each frame (see `encode_frame` / `decode_frame`).
///
/// Serialized as an internally tagged JSON object whose `"type"` field is the
/// snake_case variant name, e.g. `{"type": "request", "payload": { ... }}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RunnerMessage {
    /// `"request"`: asks for a function to be executed.
    Request { payload: ExecutionRequest },
    /// `"response"`: reports the outcome of an execution.
    Response { payload: ExecutionOutcome },
    /// `"cancel"`: asks for a job to be cancelled.
    Cancel { payload: CancelRequest },
}
33
/// Payload of [`RunnerMessage::Cancel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelRequest {
    /// Sender's protocol version; defaults to [`PROTOCOL_VERSION`] when missing.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: String,
    /// Job to cancel.
    pub job_id: String,
    /// Optional specific request to cancel; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// Defaults to `false` when missing. NOTE(review): presumably requests
    /// forceful termination instead of cooperative cancellation — confirm.
    #[serde(default)]
    pub hard_kill: bool,
}
44
/// Payload of [`RunnerMessage::Request`]: one invocation of a named function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionRequest {
    /// Sender's protocol version; defaults to [`PROTOCOL_VERSION`] when missing.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: String,
    /// Identifier of this request.
    pub request_id: String,
    /// Identifier of the job this request executes.
    pub job_id: String,
    /// Name of the function/handler to invoke.
    pub function_name: String,
    /// Named JSON parameters; an empty map when the field is missing.
    #[serde(default)]
    pub params: HashMap<String, Value>,
    /// Job metadata for this execution.
    pub context: ExecutionContext,
}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
58#[serde(rename_all = "snake_case")]
59pub enum OutcomeStatus {
60    Success,
61    Retry,
62    Timeout,
63    Error,
64}
65
/// Payload of [`RunnerMessage::Response`]: the result of one execution.
///
/// Every `Option` field is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionOutcome {
    /// Job the outcome belongs to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub job_id: Option<String>,
    /// Request the outcome answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// Overall status of the execution.
    pub status: OutcomeStatus,
    /// JSON result; set by the `success` constructor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error details; set by the non-success constructors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ExecutionError>,
    /// Set only via the `retry` constructor. NOTE(review): presumably a
    /// suggested delay before the next attempt — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_after_seconds: Option<f64>,
}
80
/// Error details attached to a non-success [`ExecutionOutcome`].
///
/// Every `Option` field is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionError {
    /// Human-readable error message.
    pub message: String,
    /// Machine-readable error category, serialized under the JSON key `"type"`
    /// (e.g. `"handler_not_found"`).
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub error_type: Option<String>,
    /// Optional error code. NOTE(review): no constructor in this file sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Optional free-form JSON details. NOTE(review): no constructor in this file sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}
91
92impl ExecutionOutcome {
93    pub fn success<T: Serialize>(
94        job_id: impl Into<String>,
95        request_id: impl Into<String>,
96        result: T,
97    ) -> Self {
98        let value = serde_json::to_value(result).unwrap_or(Value::Null);
99        Self {
100            job_id: Some(job_id.into()),
101            request_id: Some(request_id.into()),
102            status: OutcomeStatus::Success,
103            result: Some(value),
104            error: None,
105            retry_after_seconds: None,
106        }
107    }
108
109    pub fn retry(
110        job_id: impl Into<String>,
111        request_id: impl Into<String>,
112        message: impl Into<String>,
113        retry_after_seconds: Option<f64>,
114    ) -> Self {
115        Self {
116            job_id: Some(job_id.into()),
117            request_id: Some(request_id.into()),
118            status: OutcomeStatus::Retry,
119            result: None,
120            error: Some(ExecutionError {
121                message: message.into(),
122                error_type: None,
123                code: None,
124                details: None,
125            }),
126            retry_after_seconds,
127        }
128    }
129
130    pub fn timeout(
131        job_id: impl Into<String>,
132        request_id: impl Into<String>,
133        message: impl Into<String>,
134    ) -> Self {
135        Self {
136            job_id: Some(job_id.into()),
137            request_id: Some(request_id.into()),
138            status: OutcomeStatus::Timeout,
139            result: None,
140            error: Some(ExecutionError {
141                message: message.into(),
142                error_type: None,
143                code: None,
144                details: None,
145            }),
146            retry_after_seconds: None,
147        }
148    }
149
150    pub fn error(
151        job_id: impl Into<String>,
152        request_id: impl Into<String>,
153        message: impl Into<String>,
154    ) -> Self {
155        Self {
156            job_id: Some(job_id.into()),
157            request_id: Some(request_id.into()),
158            status: OutcomeStatus::Error,
159            result: None,
160            error: Some(ExecutionError {
161                message: message.into(),
162                error_type: None,
163                code: None,
164                details: None,
165            }),
166            retry_after_seconds: None,
167        }
168    }
169
170    pub fn handler_not_found(
171        job_id: impl Into<String>,
172        request_id: impl Into<String>,
173        message: impl Into<String>,
174    ) -> Self {
175        Self {
176            job_id: Some(job_id.into()),
177            request_id: Some(request_id.into()),
178            status: OutcomeStatus::Error,
179            result: None,
180            error: Some(ExecutionError {
181                message: message.into(),
182                error_type: Some("handler_not_found".to_string()),
183                code: None,
184                details: None,
185            }),
186            retry_after_seconds: None,
187        }
188    }
189}
190
/// Failure produced by [`encode_frame`] or [`decode_frame`].
#[derive(Debug)]
pub enum FrameError {
    /// The frame is shorter than the header, or its declared length does not
    /// match the body size. When encoding, it means the body is too large for
    /// a `u32` length prefix.
    InvalidLength,
    /// JSON serialization (encoding) or deserialization (decoding) of the body failed.
    Json(serde_json::Error),
}
196
197impl fmt::Display for FrameError {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        match self {
200            Self::InvalidLength => write!(f, "invalid frame length"),
201            Self::Json(err) => write!(f, "json decode error: {err}"),
202        }
203    }
204}
205
206impl std::error::Error for FrameError {}
207
208impl From<serde_json::Error> for FrameError {
209    fn from(err: serde_json::Error) -> Self {
210        Self::Json(err)
211    }
212}
213
214pub fn encode_frame(message: &RunnerMessage) -> Result<Vec<u8>, FrameError> {
215    let payload = serde_json::to_vec(message)?;
216    let length = u32::try_from(payload.len()).map_err(|_| FrameError::InvalidLength)?;
217    let mut framed = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
218    framed.extend_from_slice(&length.to_be_bytes());
219    framed.extend_from_slice(&payload);
220    Ok(framed)
221}
222
223pub fn decode_frame(frame: &[u8]) -> Result<RunnerMessage, FrameError> {
224    if frame.len() < FRAME_HEADER_LEN {
225        return Err(FrameError::InvalidLength);
226    }
227    let mut header = [0u8; FRAME_HEADER_LEN];
228    header.copy_from_slice(&frame[..FRAME_HEADER_LEN]);
229    let length = u32::from_be_bytes(header) as usize;
230    if frame.len() - FRAME_HEADER_LEN != length {
231        return Err(FrameError::InvalidLength);
232    }
233    Ok(serde_json::from_slice(&frame[FRAME_HEADER_LEN..])?)
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn execution_request_defaults_protocol_version() {
        let payload = json!({
            "job_id": "job-1",
            "request_id": "req-1",
            "function_name": "echo",
            "params": {},
            "context": {
                "job_id": "job-1",
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        let request: ExecutionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.protocol_version, PROTOCOL_VERSION);
    }

    #[test]
    fn handler_not_found_sets_error_type() {
        let outcome = ExecutionOutcome::handler_not_found("job-1", "req-1", "missing handler");
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|err| err.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    #[test]
    fn runner_message_round_trip() {
        let request = ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: "echo".to_string(),
            params: HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                worker_id: None,
            },
        };
        let msg = RunnerMessage::Request { payload: request };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Request { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, "req-1");
    }

    #[test]
    fn cancel_request_round_trip() {
        let cancel = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: true,
        };
        let msg = RunnerMessage::Cancel { payload: cancel };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id.as_deref(), Some("req-1"));
        assert!(payload.hard_kill);
    }

    /// Missing optional cancel fields fall back to their serde defaults.
    #[test]
    fn cancel_request_defaults() {
        let raw = json!({"type": "cancel", "payload": {"job_id": "job-1"}});
        let decoded: RunnerMessage = serde_json::from_value(raw).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert!(payload.request_id.is_none());
        assert!(!payload.hard_kill);
    }

    /// `None` fields are skipped entirely rather than emitted as `null`.
    #[test]
    fn outcome_omits_unset_optional_fields() {
        let outcome = ExecutionOutcome::error("job-1", "req-1", "boom");
        let value = serde_json::to_value(&outcome).unwrap();
        assert_eq!(
            value,
            json!({
                "job_id": "job-1",
                "request_id": "req-1",
                "status": "error",
                "error": {"message": "boom"}
            })
        );
    }

    #[test]
    fn frame_round_trip() {
        let outcome = ExecutionOutcome::success("job-1", "req-1", json!({"ok": true}));
        let message = RunnerMessage::Response { payload: outcome };
        let framed = encode_frame(&message).expect("frame encode failed");
        let decoded = decode_frame(&framed).expect("frame decode failed");
        let RunnerMessage::Response { payload } = decoded else {
            panic!("unexpected message variant")
        };
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.job_id.as_deref(), Some("job-1"));
    }

    /// Fewer bytes than the header must be rejected, not panic.
    #[test]
    fn decode_frame_rejects_short_frame() {
        assert!(matches!(decode_frame(&[]), Err(FrameError::InvalidLength)));
        assert!(matches!(
            decode_frame(&[0, 0, 1]),
            Err(FrameError::InvalidLength)
        ));
    }

    /// A body whose size differs from the declared length is rejected.
    #[test]
    fn decode_frame_rejects_length_mismatch() {
        let outcome = ExecutionOutcome::error("job-1", "req-1", "boom");
        let mut framed =
            encode_frame(&RunnerMessage::Response { payload: outcome }).expect("encode");
        framed.push(b' ');
        assert!(matches!(
            decode_frame(&framed),
            Err(FrameError::InvalidLength)
        ));
    }
}