1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fmt;
6
/// Protocol version stamped on requests and cancellations whose input omits one.
pub const PROTOCOL_VERSION: &str = "2";

/// Size in bytes of the big-endian `u32` length prefix that precedes every frame payload.
pub const FRAME_HEADER_LEN: usize = 4;

/// Serde default for `protocol_version` fields: an owned copy of [`PROTOCOL_VERSION`].
fn default_protocol_version() -> String {
    String::from(PROTOCOL_VERSION)
}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionContext {
16 pub job_id: String,
17 pub attempt: u32,
18 pub enqueue_time: DateTime<Utc>,
19 pub queue_name: String,
20 pub deadline: Option<DateTime<Utc>>,
21 #[serde(default)]
22 pub trace_context: Option<HashMap<String, String>>,
23 #[serde(default)]
24 pub correlation_context: Option<HashMap<String, String>>,
25 pub worker_id: Option<String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[serde(tag = "type", rename_all = "snake_case")]
30pub enum RunnerMessage {
31 Request { payload: ExecutionRequest },
32 Response { payload: ExecutionOutcome },
33 Cancel { payload: CancelRequest },
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CancelRequest {
38 #[serde(default = "default_protocol_version")]
39 pub protocol_version: String,
40 pub job_id: String,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub request_id: Option<String>,
43 #[serde(default)]
44 pub hard_kill: bool,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ExecutionRequest {
49 #[serde(default = "default_protocol_version")]
50 pub protocol_version: String,
51 pub request_id: String,
52 pub job_id: String,
53 pub function_name: String,
54 #[serde(default)]
55 pub params: HashMap<String, Value>,
56 pub context: ExecutionContext,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "snake_case")]
61pub enum OutcomeStatus {
62 Success,
63 Retry,
64 Timeout,
65 Error,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExecutionOutcome {
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub job_id: Option<String>,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub request_id: Option<String>,
74 pub status: OutcomeStatus,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 pub result: Option<Value>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 pub error: Option<ExecutionError>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 pub retry_after_seconds: Option<f64>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ExecutionError {
85 pub message: String,
86 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
87 pub error_type: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub code: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub details: Option<Value>,
92}
93
94impl ExecutionOutcome {
95 pub fn success<T: Serialize>(
96 job_id: impl Into<String>,
97 request_id: impl Into<String>,
98 result: T,
99 ) -> Self {
100 let value = serde_json::to_value(result).unwrap_or(Value::Null);
101 Self {
102 job_id: Some(job_id.into()),
103 request_id: Some(request_id.into()),
104 status: OutcomeStatus::Success,
105 result: Some(value),
106 error: None,
107 retry_after_seconds: None,
108 }
109 }
110
111 pub fn retry(
112 job_id: impl Into<String>,
113 request_id: impl Into<String>,
114 message: impl Into<String>,
115 retry_after_seconds: Option<f64>,
116 ) -> Self {
117 Self {
118 job_id: Some(job_id.into()),
119 request_id: Some(request_id.into()),
120 status: OutcomeStatus::Retry,
121 result: None,
122 error: Some(ExecutionError {
123 message: message.into(),
124 error_type: None,
125 code: None,
126 details: None,
127 }),
128 retry_after_seconds,
129 }
130 }
131
132 pub fn timeout(
133 job_id: impl Into<String>,
134 request_id: impl Into<String>,
135 message: impl Into<String>,
136 ) -> Self {
137 Self {
138 job_id: Some(job_id.into()),
139 request_id: Some(request_id.into()),
140 status: OutcomeStatus::Timeout,
141 result: None,
142 error: Some(ExecutionError {
143 message: message.into(),
144 error_type: None,
145 code: None,
146 details: None,
147 }),
148 retry_after_seconds: None,
149 }
150 }
151
152 pub fn error(
153 job_id: impl Into<String>,
154 request_id: impl Into<String>,
155 message: impl Into<String>,
156 ) -> Self {
157 Self {
158 job_id: Some(job_id.into()),
159 request_id: Some(request_id.into()),
160 status: OutcomeStatus::Error,
161 result: None,
162 error: Some(ExecutionError {
163 message: message.into(),
164 error_type: None,
165 code: None,
166 details: None,
167 }),
168 retry_after_seconds: None,
169 }
170 }
171
172 pub fn handler_not_found(
173 job_id: impl Into<String>,
174 request_id: impl Into<String>,
175 message: impl Into<String>,
176 ) -> Self {
177 Self {
178 job_id: Some(job_id.into()),
179 request_id: Some(request_id.into()),
180 status: OutcomeStatus::Error,
181 result: None,
182 error: Some(ExecutionError {
183 message: message.into(),
184 error_type: Some("handler_not_found".to_string()),
185 code: None,
186 details: None,
187 }),
188 retry_after_seconds: None,
189 }
190 }
191}
192
193#[derive(Debug)]
194pub enum FrameError {
195 InvalidLength,
196 Json(serde_json::Error),
197}
198
199impl fmt::Display for FrameError {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 match self {
202 Self::InvalidLength => write!(f, "invalid frame length"),
203 Self::Json(err) => write!(f, "json decode error: {err}"),
204 }
205 }
206}
207
208impl std::error::Error for FrameError {}
209
210impl From<serde_json::Error> for FrameError {
211 fn from(err: serde_json::Error) -> Self {
212 Self::Json(err)
213 }
214}
215
216pub fn encode_frame(message: &RunnerMessage) -> Result<Vec<u8>, FrameError> {
217 let payload = serde_json::to_vec(message)?;
218 let length = u32::try_from(payload.len()).map_err(|_| FrameError::InvalidLength)?;
219 let mut framed = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
220 framed.extend_from_slice(&length.to_be_bytes());
221 framed.extend_from_slice(&payload);
222 Ok(framed)
223}
224
225pub fn decode_frame(frame: &[u8]) -> Result<RunnerMessage, FrameError> {
226 if frame.len() < FRAME_HEADER_LEN {
227 return Err(FrameError::InvalidLength);
228 }
229 let mut header = [0u8; FRAME_HEADER_LEN];
230 header.copy_from_slice(&frame[..FRAME_HEADER_LEN]);
231 let length = u32::from_be_bytes(header) as usize;
232 if frame.len() - FRAME_HEADER_LEN != length {
233 return Err(FrameError::InvalidLength);
234 }
235 Ok(serde_json::from_slice(&frame[FRAME_HEADER_LEN..])?)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn execution_request_defaults_protocol_version() {
        let payload = json!({
            "job_id": "job-1",
            "request_id": "req-1",
            "function_name": "echo",
            "params": {},
            "context": {
                "job_id": "job-1",
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        let request: ExecutionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.protocol_version, PROTOCOL_VERSION);
    }

    #[test]
    fn handler_not_found_sets_error_type() {
        let outcome = ExecutionOutcome::handler_not_found("job-1", "req-1", "missing handler");
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|err| err.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    #[test]
    fn runner_message_round_trip() {
        let request = ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: "echo".to_string(),
            params: HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                correlation_context: None,
                worker_id: None,
            },
        };
        let msg = RunnerMessage::Request { payload: request };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Request { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, "req-1");
    }

    #[test]
    fn cancel_request_round_trip() {
        let cancel = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: true,
        };
        let msg = RunnerMessage::Cancel { payload: cancel };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id.as_deref(), Some("req-1"));
        assert!(payload.hard_kill);
    }

    #[test]
    fn cancel_request_defaults_optional_fields() {
        let decoded: RunnerMessage = serde_json::from_value(json!({
            "type": "cancel",
            "payload": { "job_id": "job-1" }
        }))
        .unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.job_id, "job-1");
        assert!(payload.request_id.is_none());
        assert!(!payload.hard_kill);
    }

    #[test]
    fn frame_round_trip() {
        let outcome = ExecutionOutcome::success("job-1", "req-1", json!({"ok": true}));
        let message = RunnerMessage::Response { payload: outcome };
        let framed = encode_frame(&message).expect("frame encode failed");

        // The header must be the big-endian length of everything after it.
        let payload_len = u32::try_from(framed.len() - FRAME_HEADER_LEN).unwrap();
        assert_eq!(
            &framed[..FRAME_HEADER_LEN],
            &payload_len.to_be_bytes()[..]
        );

        let decoded = decode_frame(&framed).expect("frame decode failed");
        let RunnerMessage::Response { payload } = decoded else {
            panic!("unexpected message variant")
        };
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.job_id.as_deref(), Some("job-1"));
    }

    #[test]
    fn decode_frame_rejects_short_header() {
        assert!(matches!(decode_frame(b""), Err(FrameError::InvalidLength)));
        assert!(matches!(decode_frame(&[0, 0, 1]), Err(FrameError::InvalidLength)));
    }

    #[test]
    fn decode_frame_rejects_length_mismatch() {
        let message = RunnerMessage::Response {
            payload: ExecutionOutcome::error("job-1", "req-1", "boom"),
        };
        let framed = encode_frame(&message).expect("frame encode failed");

        let mut padded = framed.clone();
        padded.push(b' ');
        assert!(matches!(decode_frame(&padded), Err(FrameError::InvalidLength)));

        let truncated = &framed[..framed.len() - 1];
        assert!(matches!(decode_frame(truncated), Err(FrameError::InvalidLength)));
    }

    #[test]
    fn decode_frame_surfaces_json_errors() {
        let body = b"{not json";
        let mut framed = u32::try_from(body.len()).unwrap().to_be_bytes().to_vec();
        framed.extend_from_slice(body);
        assert!(matches!(decode_frame(&framed), Err(FrameError::Json(_))));
    }
}