1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fmt;
6
/// Protocol version string; `default_protocol_version` hands out copies of it
/// as the serde default for `protocol_version` fields.
pub const PROTOCOL_VERSION: &str = "2";
/// Size in bytes of the frame header. `encode_frame`/`decode_frame` store the
/// payload length there as a big-endian `u32`, hence the width of a `u32`.
pub const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();
9
10fn default_protocol_version() -> String {
11 PROTOCOL_VERSION.to_string()
12}
13
/// Per-attempt metadata carried inside an [`ExecutionRequest`] as its
/// `context` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
    /// Identifier of the job this context belongs to.
    pub job_id: String,
    /// Attempt counter for the job.
    // NOTE(review): the tests in this file use `1` for a first attempt, so this
    // is presumably 1-based — confirm with the producer.
    pub attempt: u32,
    /// UTC timestamp at which the job was enqueued.
    pub enqueue_time: DateTime<Utc>,
    /// Name of the queue the job came from.
    pub queue_name: String,
    /// Optional UTC deadline for the job.
    // NOTE(review): how a missed deadline is enforced is not visible here.
    pub deadline: Option<DateTime<Utc>>,
    /// Optional string key/value map of trace metadata. `#[serde(default)]`
    /// makes it `None` when the key is absent from the input.
    #[serde(default)]
    pub trace_context: Option<HashMap<String, String>>,
    /// Optional identifier of the worker handling the job.
    pub worker_id: Option<String>,
}
25
/// Top-level message envelope that [`encode_frame`] and [`decode_frame`]
/// operate on.
///
/// The enum is internally tagged: each message is a JSON object whose `type`
/// field holds the snake_case variant name (`request`, `response`, `cancel`)
/// next to a `payload` field holding the variant's body.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RunnerMessage {
    /// Carries an [`ExecutionRequest`].
    Request { payload: ExecutionRequest },
    /// Carries an [`ExecutionOutcome`].
    Response { payload: ExecutionOutcome },
    /// Carries a [`CancelRequest`].
    Cancel { payload: CancelRequest },
}
33
/// Body of a [`RunnerMessage::Cancel`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelRequest {
    /// Protocol version of the sender; defaults to [`PROTOCOL_VERSION`] when
    /// the key is absent from the input.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: String,
    /// Identifier of the job to cancel.
    pub job_id: String,
    /// Optional request identifier; left out of the serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// Defaults to `false` when the key is absent from the input.
    // NOTE(review): presumably asks for forceful termination rather than a
    // cooperative stop; the receiver defines the exact semantics — confirm.
    #[serde(default)]
    pub hard_kill: bool,
}
44
/// Body of a [`RunnerMessage::Request`] message: which function to run, with
/// which parameters, for which job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionRequest {
    /// Protocol version of the sender; defaults to [`PROTOCOL_VERSION`] when
    /// the key is absent from the input.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: String,
    /// Identifier of this request.
    pub request_id: String,
    /// Identifier of the job being executed.
    pub job_id: String,
    /// Name of the function (handler) to invoke.
    pub function_name: String,
    /// Named JSON parameters for the function; an empty map when the key is
    /// absent from the input.
    #[serde(default)]
    pub params: HashMap<String, Value>,
    /// Metadata about this attempt.
    pub context: ExecutionContext,
}
56
/// Result category of an [`ExecutionOutcome`].
///
/// Serialized as the snake_case variant name: `success`, `retry`, `timeout`
/// or `error`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutcomeStatus {
    /// Set by [`ExecutionOutcome::success`].
    Success,
    /// Set by [`ExecutionOutcome::retry`].
    Retry,
    /// Set by [`ExecutionOutcome::timeout`].
    Timeout,
    /// Set by [`ExecutionOutcome::error`] and
    /// [`ExecutionOutcome::handler_not_found`].
    Error,
}
65
/// Body of a [`RunnerMessage::Response`] message: the result of one execution.
///
/// Every optional field is left out of the serialized output when it is
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionOutcome {
    /// Identifier of the job this outcome belongs to, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub job_id: Option<String>,
    /// Identifier of the request this outcome answers, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// Result category.
    pub status: OutcomeStatus,
    /// JSON result value; the constructors in this file set it only for
    /// `Success`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error description; the constructors in this file set it for every
    /// non-success status.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ExecutionError>,
    /// Optional retry delay in seconds; only [`ExecutionOutcome::retry`] can
    /// set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_after_seconds: Option<f64>,
}
80
/// Error details attached to a non-success [`ExecutionOutcome`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionError {
    /// Human-readable error message.
    pub message: String,
    /// Optional error category, carried on the wire under the key `type`
    /// (for example `handler_not_found`); omitted when `None`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub error_type: Option<String>,
    /// Optional error code; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Optional free-form JSON details; omitted from the serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}
91
92impl ExecutionOutcome {
93 pub fn success<T: Serialize>(
94 job_id: impl Into<String>,
95 request_id: impl Into<String>,
96 result: T,
97 ) -> Self {
98 let value = serde_json::to_value(result).unwrap_or(Value::Null);
99 Self {
100 job_id: Some(job_id.into()),
101 request_id: Some(request_id.into()),
102 status: OutcomeStatus::Success,
103 result: Some(value),
104 error: None,
105 retry_after_seconds: None,
106 }
107 }
108
109 pub fn retry(
110 job_id: impl Into<String>,
111 request_id: impl Into<String>,
112 message: impl Into<String>,
113 retry_after_seconds: Option<f64>,
114 ) -> Self {
115 Self {
116 job_id: Some(job_id.into()),
117 request_id: Some(request_id.into()),
118 status: OutcomeStatus::Retry,
119 result: None,
120 error: Some(ExecutionError {
121 message: message.into(),
122 error_type: None,
123 code: None,
124 details: None,
125 }),
126 retry_after_seconds,
127 }
128 }
129
130 pub fn timeout(
131 job_id: impl Into<String>,
132 request_id: impl Into<String>,
133 message: impl Into<String>,
134 ) -> Self {
135 Self {
136 job_id: Some(job_id.into()),
137 request_id: Some(request_id.into()),
138 status: OutcomeStatus::Timeout,
139 result: None,
140 error: Some(ExecutionError {
141 message: message.into(),
142 error_type: None,
143 code: None,
144 details: None,
145 }),
146 retry_after_seconds: None,
147 }
148 }
149
150 pub fn error(
151 job_id: impl Into<String>,
152 request_id: impl Into<String>,
153 message: impl Into<String>,
154 ) -> Self {
155 Self {
156 job_id: Some(job_id.into()),
157 request_id: Some(request_id.into()),
158 status: OutcomeStatus::Error,
159 result: None,
160 error: Some(ExecutionError {
161 message: message.into(),
162 error_type: None,
163 code: None,
164 details: None,
165 }),
166 retry_after_seconds: None,
167 }
168 }
169
170 pub fn handler_not_found(
171 job_id: impl Into<String>,
172 request_id: impl Into<String>,
173 message: impl Into<String>,
174 ) -> Self {
175 Self {
176 job_id: Some(job_id.into()),
177 request_id: Some(request_id.into()),
178 status: OutcomeStatus::Error,
179 result: None,
180 error: Some(ExecutionError {
181 message: message.into(),
182 error_type: Some("handler_not_found".to_string()),
183 code: None,
184 details: None,
185 }),
186 retry_after_seconds: None,
187 }
188 }
189}
190
/// Errors returned by [`encode_frame`] and [`decode_frame`].
#[derive(Debug)]
pub enum FrameError {
    /// The frame is shorter than its header, its declared payload length does
    /// not match the bytes supplied, or a payload is too long for the `u32`
    /// length header.
    InvalidLength,
    /// `serde_json` failed to serialize or deserialize the message.
    Json(serde_json::Error),
}
196
197impl fmt::Display for FrameError {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 match self {
200 Self::InvalidLength => write!(f, "invalid frame length"),
201 Self::Json(err) => write!(f, "json decode error: {err}"),
202 }
203 }
204}
205
206impl std::error::Error for FrameError {}
207
208impl From<serde_json::Error> for FrameError {
209 fn from(err: serde_json::Error) -> Self {
210 Self::Json(err)
211 }
212}
213
214pub fn encode_frame(message: &RunnerMessage) -> Result<Vec<u8>, FrameError> {
215 let payload = serde_json::to_vec(message)?;
216 let length = u32::try_from(payload.len()).map_err(|_| FrameError::InvalidLength)?;
217 let mut framed = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
218 framed.extend_from_slice(&length.to_be_bytes());
219 framed.extend_from_slice(&payload);
220 Ok(framed)
221}
222
223pub fn decode_frame(frame: &[u8]) -> Result<RunnerMessage, FrameError> {
224 if frame.len() < FRAME_HEADER_LEN {
225 return Err(FrameError::InvalidLength);
226 }
227 let mut header = [0u8; FRAME_HEADER_LEN];
228 header.copy_from_slice(&frame[..FRAME_HEADER_LEN]);
229 let length = u32::from_be_bytes(header) as usize;
230 if frame.len() - FRAME_HEADER_LEN != length {
231 return Err(FrameError::InvalidLength);
232 }
233 Ok(serde_json::from_slice(&frame[FRAME_HEADER_LEN..])?)
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A request without `protocol_version` falls back to the crate default.
    #[test]
    fn execution_request_defaults_protocol_version() {
        let payload = json!({
            "job_id": "job-1",
            "request_id": "req-1",
            "function_name": "echo",
            "params": {},
            "context": {
                "job_id": "job-1",
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        let request: ExecutionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.protocol_version, PROTOCOL_VERSION);
    }

    /// `handler_not_found` is an `Error` outcome tagged with its error type.
    #[test]
    fn handler_not_found_sets_error_type() {
        let outcome = ExecutionOutcome::handler_not_found("job-1", "req-1", "missing handler");
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|err| err.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    /// A `Request` message survives a JSON serialize/deserialize cycle.
    #[test]
    fn runner_message_round_trip() {
        let request = ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: "echo".to_string(),
            params: HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                worker_id: None,
            },
        };
        let msg = RunnerMessage::Request { payload: request };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Request { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id, "req-1");
    }

    /// A `Cancel` message survives a JSON serialize/deserialize cycle.
    #[test]
    fn cancel_request_round_trip() {
        let cancel = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: true,
        };
        let msg = RunnerMessage::Cancel { payload: cancel };
        let serialized = serde_json::to_string(&msg).unwrap();
        let decoded: RunnerMessage = serde_json::from_str(&serialized).unwrap();
        let RunnerMessage::Cancel { payload } = decoded else {
            panic!("unexpected message type")
        };
        assert_eq!(payload.protocol_version, PROTOCOL_VERSION);
        assert_eq!(payload.request_id.as_deref(), Some("req-1"));
        assert!(payload.hard_kill);
    }

    /// `encode_frame` output is accepted by `decode_frame` unchanged.
    #[test]
    fn frame_round_trip() {
        let outcome = ExecutionOutcome::success("job-1", "req-1", json!({"ok": true}));
        let message = RunnerMessage::Response { payload: outcome };
        let framed = encode_frame(&message).expect("frame encode failed");
        let decoded = decode_frame(&framed).expect("frame decode failed");
        let RunnerMessage::Response { payload } = decoded else {
            panic!("unexpected message variant")
        };
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.job_id.as_deref(), Some("job-1"));
    }

    /// Input shorter than the length header is rejected before any parsing.
    #[test]
    fn decode_frame_rejects_truncated_header() {
        assert!(matches!(decode_frame(b""), Err(FrameError::InvalidLength)));
        assert!(matches!(
            decode_frame(&[0u8, 0, 0]),
            Err(FrameError::InvalidLength)
        ));
    }

    /// The declared length must match the payload exactly: neither trailing
    /// bytes nor a short payload are tolerated.
    #[test]
    fn decode_frame_rejects_length_mismatch() {
        let message = RunnerMessage::Response {
            payload: ExecutionOutcome::error("job-1", "req-1", "boom"),
        };
        let mut framed = encode_frame(&message).expect("frame encode failed");

        // One byte too many.
        framed.push(b' ');
        assert!(matches!(
            decode_frame(&framed),
            Err(FrameError::InvalidLength)
        ));

        // One byte too few (drop the extra byte plus one payload byte).
        framed.truncate(framed.len() - 2);
        assert!(matches!(
            decode_frame(&framed),
            Err(FrameError::InvalidLength)
        ));
    }

    /// A well-framed payload that is not valid JSON surfaces as `Json`.
    #[test]
    fn decode_frame_reports_invalid_json() {
        let body = b"not json";
        let header = u32::try_from(body.len()).expect("test payload fits in u32");
        let mut framed = header.to_be_bytes().to_vec();
        framed.extend_from_slice(body);
        assert!(matches!(decode_frame(&framed), Err(FrameError::Json(_))));
    }
}