1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fmt;
6
/// Wire-protocol version assumed for `protocol_version` fields that are
/// missing from incoming JSON (see `default_protocol_version`).
pub const PROTOCOL_VERSION: &str = "1";

/// Size in bytes of the frame length prefix: a single big-endian `u32`
/// written by `encode_frame` and read back by `decode_frame`.
pub const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();

/// Serde default for `protocol_version` fields: an owned copy of
/// [`PROTOCOL_VERSION`].
fn default_protocol_version() -> String {
    String::from(PROTOCOL_VERSION)
}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionContext {
16 pub job_id: String,
17 pub attempt: u32,
18 pub enqueue_time: DateTime<Utc>,
19 pub queue_name: String,
20 pub deadline: Option<DateTime<Utc>>,
21 #[serde(default)]
22 pub trace_context: Option<HashMap<String, String>>,
23 pub worker_id: Option<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "type", rename_all = "snake_case")]
28pub enum RunnerMessage {
29 Request { payload: ExecutionRequest },
30 Response { payload: ExecutionOutcome },
31 Cancel { payload: CancelRequest },
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CancelRequest {
36 #[serde(default = "default_protocol_version")]
37 pub protocol_version: String,
38 pub job_id: String,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub request_id: Option<String>,
41 #[serde(default)]
42 pub hard_kill: bool,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ExecutionRequest {
47 #[serde(default = "default_protocol_version")]
48 pub protocol_version: String,
49 pub request_id: String,
50 pub job_id: String,
51 pub function_name: String,
52 #[serde(default)]
53 pub args: Vec<Value>,
54 #[serde(default)]
55 pub kwargs: HashMap<String, Value>,
56 pub context: ExecutionContext,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "snake_case")]
61pub enum OutcomeStatus {
62 Success,
63 Retry,
64 Timeout,
65 Error,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExecutionOutcome {
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub job_id: Option<String>,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub request_id: Option<String>,
74 pub status: OutcomeStatus,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 pub result: Option<Value>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 pub error: Option<ExecutionError>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 pub retry_after_seconds: Option<f64>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ExecutionError {
85 pub message: String,
86 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
87 pub error_type: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub code: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub details: Option<Value>,
92}
93
94impl ExecutionOutcome {
95 pub fn success<T: Serialize>(
96 job_id: impl Into<String>,
97 request_id: impl Into<String>,
98 result: T,
99 ) -> Self {
100 let value = serde_json::to_value(result).unwrap_or(Value::Null);
101 Self {
102 job_id: Some(job_id.into()),
103 request_id: Some(request_id.into()),
104 status: OutcomeStatus::Success,
105 result: Some(value),
106 error: None,
107 retry_after_seconds: None,
108 }
109 }
110
111 pub fn retry(
112 job_id: impl Into<String>,
113 request_id: impl Into<String>,
114 message: impl Into<String>,
115 retry_after_seconds: Option<f64>,
116 ) -> Self {
117 Self {
118 job_id: Some(job_id.into()),
119 request_id: Some(request_id.into()),
120 status: OutcomeStatus::Retry,
121 result: None,
122 error: Some(ExecutionError {
123 message: message.into(),
124 error_type: None,
125 code: None,
126 details: None,
127 }),
128 retry_after_seconds,
129 }
130 }
131
132 pub fn timeout(
133 job_id: impl Into<String>,
134 request_id: impl Into<String>,
135 message: impl Into<String>,
136 ) -> Self {
137 Self {
138 job_id: Some(job_id.into()),
139 request_id: Some(request_id.into()),
140 status: OutcomeStatus::Timeout,
141 result: None,
142 error: Some(ExecutionError {
143 message: message.into(),
144 error_type: None,
145 code: None,
146 details: None,
147 }),
148 retry_after_seconds: None,
149 }
150 }
151
152 pub fn error(
153 job_id: impl Into<String>,
154 request_id: impl Into<String>,
155 message: impl Into<String>,
156 ) -> Self {
157 Self {
158 job_id: Some(job_id.into()),
159 request_id: Some(request_id.into()),
160 status: OutcomeStatus::Error,
161 result: None,
162 error: Some(ExecutionError {
163 message: message.into(),
164 error_type: None,
165 code: None,
166 details: None,
167 }),
168 retry_after_seconds: None,
169 }
170 }
171
172 pub fn handler_not_found(
173 job_id: impl Into<String>,
174 request_id: impl Into<String>,
175 message: impl Into<String>,
176 ) -> Self {
177 Self {
178 job_id: Some(job_id.into()),
179 request_id: Some(request_id.into()),
180 status: OutcomeStatus::Error,
181 result: None,
182 error: Some(ExecutionError {
183 message: message.into(),
184 error_type: Some("handler_not_found".to_string()),
185 code: None,
186 details: None,
187 }),
188 retry_after_seconds: None,
189 }
190 }
191}
192
193#[derive(Debug)]
194pub enum FrameError {
195 InvalidLength,
196 Json(serde_json::Error),
197}
198
199impl fmt::Display for FrameError {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 match self {
202 Self::InvalidLength => write!(f, "invalid frame length"),
203 Self::Json(err) => write!(f, "json decode error: {err}"),
204 }
205 }
206}
207
208impl std::error::Error for FrameError {}
209
210impl From<serde_json::Error> for FrameError {
211 fn from(err: serde_json::Error) -> Self {
212 Self::Json(err)
213 }
214}
215
216pub fn encode_frame(message: &RunnerMessage) -> Result<Vec<u8>, FrameError> {
217 let payload = serde_json::to_vec(message)?;
218 let length = u32::try_from(payload.len()).map_err(|_| FrameError::InvalidLength)?;
219 let mut framed = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
220 framed.extend_from_slice(&length.to_be_bytes());
221 framed.extend_from_slice(&payload);
222 Ok(framed)
223}
224
225pub fn decode_frame(frame: &[u8]) -> Result<RunnerMessage, FrameError> {
226 if frame.len() < FRAME_HEADER_LEN {
227 return Err(FrameError::InvalidLength);
228 }
229 let mut header = [0u8; FRAME_HEADER_LEN];
230 header.copy_from_slice(&frame[..FRAME_HEADER_LEN]);
231 let length = u32::from_be_bytes(header) as usize;
232 if frame.len() - FRAME_HEADER_LEN != length {
233 return Err(FrameError::InvalidLength);
234 }
235 Ok(serde_json::from_slice(&frame[FRAME_HEADER_LEN..])?)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn execution_request_defaults_protocol_version() {
        let payload = json!({
            "job_id": "job-1",
            "request_id": "req-1",
            "function_name": "echo",
            "args": [],
            "kwargs": {},
            "context": {
                "job_id": "job-1",
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        let request: ExecutionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.protocol_version, PROTOCOL_VERSION);
    }

    #[test]
    fn handler_not_found_sets_error_type() {
        let outcome = ExecutionOutcome::handler_not_found("job-1", "req-1", "missing handler");
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|err| err.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    #[test]
    fn outcome_omits_unset_optional_fields() {
        // `None` fields (result, retry_after_seconds, error.type/code/details)
        // must be absent from the wire format, not serialized as null.
        let outcome = ExecutionOutcome::error("job-1", "req-1", "boom");
        let value = serde_json::to_value(&outcome).unwrap();
        assert_eq!(
            value,
            json!({
                "job_id": "job-1",
                "request_id": "req-1",
                "status": "error",
                "error": { "message": "boom" }
            })
        );
    }

    #[test]
    fn runner_message_round_trip() {
        let request = ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: "echo".to_string(),
            args: Vec::new(),
            kwargs: HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                worker_id: None,
            },
        };
        let msg = RunnerMessage::Request { payload: request };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Request { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, "req-1");
    }

    #[test]
    fn cancel_request_round_trip() {
        let cancel = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: true,
        };
        let msg = RunnerMessage::Cancel { payload: cancel };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id.as_deref(), Some("req-1"));
        assert!(payload.hard_kill);
    }

    #[test]
    fn cancel_request_applies_defaults() {
        let raw = json!({
            "type": "cancel",
            "payload": { "job_id": "job-1" }
        });
        let decoded: RunnerMessage = serde_json::from_value(raw).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, None);
        assert!(!payload.hard_kill);
    }

    #[test]
    fn frame_round_trip() {
        let outcome = ExecutionOutcome::success("job-1", "req-1", json!({"ok": true}));
        let message = RunnerMessage::Response { payload: outcome };
        let framed = encode_frame(&message).expect("frame encode failed");
        let decoded = decode_frame(&framed).expect("frame decode failed");
        let RunnerMessage::Response { payload } = decoded else {
            panic!("unexpected message variant")
        };
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.job_id.as_deref(), Some("job-1"));
    }

    #[test]
    fn encode_frame_prefixes_big_endian_payload_length() {
        let message = RunnerMessage::Response {
            payload: ExecutionOutcome::error("job-1", "req-1", "boom"),
        };
        let framed = encode_frame(&message).expect("frame encode failed");
        let expected_body = serde_json::to_vec(&message).unwrap();
        let declared = u32::from_be_bytes(framed[..FRAME_HEADER_LEN].try_into().unwrap());
        assert_eq!(declared as usize, expected_body.len());
        assert_eq!(&framed[FRAME_HEADER_LEN..], expected_body.as_slice());
    }

    #[test]
    fn decode_frame_rejects_truncated_header() {
        assert!(matches!(decode_frame(&[]), Err(FrameError::InvalidLength)));
        assert!(matches!(decode_frame(&[0, 0, 0]), Err(FrameError::InvalidLength)));
    }

    #[test]
    fn decode_frame_rejects_length_mismatch() {
        // Header declares 10 bytes but only 2 follow.
        let too_short = [0, 0, 0, 10, b'{', b'}'];
        assert!(matches!(decode_frame(&too_short), Err(FrameError::InvalidLength)));
        // Header declares 1 byte but 2 follow.
        let too_long = [0, 0, 0, 1, b'{', b'}'];
        assert!(matches!(decode_frame(&too_long), Err(FrameError::InvalidLength)));
    }

    #[test]
    fn decode_frame_surfaces_json_errors() {
        let mut frame = 3u32.to_be_bytes().to_vec();
        frame.extend_from_slice(b"abc");
        assert!(matches!(decode_frame(&frame), Err(FrameError::Json(_))));
    }
}