rivetkit_client/protocol/
codec.rs1use anyhow::{anyhow, Context, Result};
2use rivetkit_client_protocol as wire;
3use serde::Serialize;
4use serde_json::{json, Value as JsonValue};
5use vbare::OwnedVersionedData;
6
7use crate::EncodingKind;
8
9use super::{to_client, to_server};
10
11pub fn encode_to_server(encoding: EncodingKind, value: &to_server::ToServer) -> Result<Vec<u8>> {
12 match encoding {
13 EncodingKind::Json => Ok(serde_json::to_vec(&to_server_json_value(value)?)?),
14 EncodingKind::Cbor => Ok(serde_cbor::to_vec(&to_server_json_value(value)?)?),
15 EncodingKind::Bare => encode_to_server_bare(value),
16 }
17}
18
19pub fn decode_to_client(encoding: EncodingKind, payload: &[u8]) -> Result<to_client::ToClient> {
20 match encoding {
21 EncodingKind::Json => {
22 let value: JsonValue =
23 serde_json::from_slice(payload).context("decode actor websocket json response")?;
24 to_client_from_json_value(&value)
25 }
26 EncodingKind::Cbor => {
27 let value: JsonValue =
28 serde_cbor::from_slice(payload).context("decode actor websocket cbor response")?;
29 to_client_from_json_value(&value)
30 }
31 EncodingKind::Bare => decode_to_client_bare(payload),
32 }
33}
34
35pub fn encode_http_action_request(encoding: EncodingKind, args: &[JsonValue]) -> Result<Vec<u8>> {
36 match encoding {
37 EncodingKind::Json => Ok(serde_json::to_vec(&json!({ "args": args }))?),
38 EncodingKind::Cbor => Ok(serde_cbor::to_vec(&json!({ "args": args }))?),
39 EncodingKind::Bare => {
40 wire::versioned::HttpActionRequest::wrap_latest(wire::HttpActionRequest {
41 args: serde_cbor::to_vec(&args.to_vec())?,
42 })
43 .serialize_with_embedded_version(wire::PROTOCOL_VERSION)
44 }
45 }
46}
47
48pub fn decode_http_action_response(encoding: EncodingKind, payload: &[u8]) -> Result<JsonValue> {
49 match encoding {
50 EncodingKind::Json => {
51 let value: JsonValue = serde_json::from_slice(payload)?;
52 value
53 .get("output")
54 .cloned()
55 .ok_or_else(|| anyhow!("action response missing output"))
56 }
57 EncodingKind::Cbor => {
58 let value: JsonValue = serde_cbor::from_slice(payload)?;
59 value
60 .get("output")
61 .cloned()
62 .ok_or_else(|| anyhow!("action response missing output"))
63 }
64 EncodingKind::Bare => {
65 let response =
66 <wire::versioned::HttpActionResponse as OwnedVersionedData>::deserialize_with_embedded_version(
67 payload,
68 )
69 .context("decode bare action response")?;
70 Ok(serde_cbor::from_slice(&response.output)?)
71 }
72 }
73}
74
75pub fn encode_http_queue_request<T: Serialize>(
76 encoding: EncodingKind,
77 name: &str,
78 body: &T,
79 wait: bool,
80 timeout: Option<u64>,
81) -> Result<Vec<u8>> {
82 #[derive(Serialize)]
83 struct JsonQueueRequest<'a, T: Serialize + ?Sized> {
84 name: &'a str,
85 body: &'a T,
86 wait: bool,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 timeout: Option<u64>,
89 }
90
91 let request = JsonQueueRequest {
92 name,
93 body,
94 wait,
95 timeout,
96 };
97
98 match encoding {
99 EncodingKind::Json => Ok(serde_json::to_vec(&request)?),
100 EncodingKind::Cbor => Ok(serde_cbor::to_vec(&request)?),
101 EncodingKind::Bare => {
102 wire::versioned::HttpQueueSendRequest::wrap_latest(wire::HttpQueueSendRequest {
103 body: serde_cbor::to_vec(body)?,
104 name: Some(name.to_owned()),
105 wait: Some(wait),
106 timeout,
107 })
108 .serialize_with_embedded_version(wire::PROTOCOL_VERSION)
109 }
110 }
111}
112
/// Outcome reported by the server for an HTTP queue send.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueueSendStatus {
    /// The server reported the status string `"completed"`.
    Completed,
    /// The server reported the status string `"timedOut"`.
    TimedOut,
    /// Any other status string, preserved verbatim.
    Other(String),
}
119
120#[derive(Debug, Clone)]
121pub struct QueueSendResult {
122 pub status: QueueSendStatus,
123 pub response: Option<JsonValue>,
124}
125
126pub fn decode_http_queue_response(
127 encoding: EncodingKind,
128 payload: &[u8],
129) -> Result<QueueSendResult> {
130 let (status, response) = match encoding {
131 EncodingKind::Json => {
132 let value: JsonValue = serde_json::from_slice(payload)?;
133 let status = value
134 .get("status")
135 .and_then(JsonValue::as_str)
136 .ok_or_else(|| anyhow!("queue response missing status"))?
137 .to_owned();
138 let response = value.get("response").cloned();
139 (status, response)
140 }
141 EncodingKind::Cbor => {
142 let value: JsonValue = serde_cbor::from_slice(payload)?;
143 let status = value
144 .get("status")
145 .and_then(JsonValue::as_str)
146 .ok_or_else(|| anyhow!("queue response missing status"))?
147 .to_owned();
148 let response = value.get("response").cloned();
149 (status, response)
150 }
151 EncodingKind::Bare => {
152 let response =
153 <wire::versioned::HttpQueueSendResponse as OwnedVersionedData>::deserialize_with_embedded_version(
154 payload,
155 )
156 .context("decode bare queue response")?;
157 let body = response
158 .response
159 .map(|payload| serde_cbor::from_slice(&payload))
160 .transpose()?;
161 (response.status, body)
162 }
163 };
164
165 let status = match status.as_str() {
166 "completed" => QueueSendStatus::Completed,
167 "timedOut" => QueueSendStatus::TimedOut,
168 _ => QueueSendStatus::Other(status),
169 };
170
171 Ok(QueueSendResult { status, response })
172}
173
174pub fn decode_http_error(
175 encoding: EncodingKind,
176 payload: &[u8],
177) -> Result<(String, String, String, Option<JsonValue>)> {
178 match encoding {
179 EncodingKind::Json => {
180 let value: JsonValue = serde_json::from_slice(payload)?;
181 error_from_json_value(&value)
182 }
183 EncodingKind::Cbor => {
184 let value: JsonValue = serde_cbor::from_slice(payload)?;
185 error_from_json_value(&value)
186 }
187 EncodingKind::Bare => {
188 let error =
189 <wire::versioned::HttpResponseError as OwnedVersionedData>::deserialize_with_embedded_version(
190 payload,
191 )
192 .context("decode bare http error")?;
193 let metadata = error
194 .metadata
195 .map(|payload| serde_cbor::from_slice(&payload))
196 .transpose()?;
197 Ok((error.group, error.code, error.message, metadata))
198 }
199 }
200}
201
202fn to_server_json_value(value: &to_server::ToServer) -> Result<JsonValue> {
203 let body = match &value.body {
204 to_server::ToServerBody::ActionRequest(request) => json!({
205 "tag": "ActionRequest",
206 "val": {
207 "id": request.id,
208 "name": request.name,
209 "args": serde_cbor::from_slice::<JsonValue>(&request.args)
210 .context("decode websocket action args for json/cbor transport")?,
211 },
212 }),
213 to_server::ToServerBody::SubscriptionRequest(request) => json!({
214 "tag": "SubscriptionRequest",
215 "val": {
216 "eventName": request.event_name,
217 "subscribe": request.subscribe,
218 },
219 }),
220 };
221 Ok(json!({ "body": body }))
222}
223
224fn to_client_from_json_value(value: &JsonValue) -> Result<to_client::ToClient> {
225 let body = value
226 .get("body")
227 .and_then(JsonValue::as_object)
228 .ok_or_else(|| anyhow!("actor websocket response missing body"))?;
229 let tag = body
230 .get("tag")
231 .and_then(JsonValue::as_str)
232 .ok_or_else(|| anyhow!("actor websocket response missing tag"))?;
233 let value = body
234 .get("val")
235 .and_then(JsonValue::as_object)
236 .ok_or_else(|| anyhow!("actor websocket response missing val"))?;
237
238 let body = match tag {
239 "Init" => to_client::ToClientBody::Init(to_client::Init {
240 actor_id: json_string(value, "actorId")?,
241 connection_id: json_string(value, "connectionId")?,
242 connection_token: value
243 .get("connectionToken")
244 .and_then(JsonValue::as_str)
245 .map(ToOwned::to_owned),
246 }),
247 "Error" => to_client::ToClientBody::Error(to_client::Error {
248 group: json_string(value, "group")?,
249 code: json_string(value, "code")?,
250 message: json_string(value, "message")?,
251 metadata: value.get("metadata").map(serde_cbor::to_vec).transpose()?,
252 action_id: value.get("actionId").map(parse_json_u64).transpose()?,
253 }),
254 "ActionResponse" => to_client::ToClientBody::ActionResponse(to_client::ActionResponse {
255 id: parse_json_u64(
256 value
257 .get("id")
258 .ok_or_else(|| anyhow!("action response missing id"))?,
259 )?,
260 output: serde_cbor::to_vec(
261 value
262 .get("output")
263 .ok_or_else(|| anyhow!("action response missing output"))?,
264 )?,
265 }),
266 "Event" => to_client::ToClientBody::Event(to_client::Event {
267 name: json_string(value, "name")?,
268 args: serde_cbor::to_vec(
269 value
270 .get("args")
271 .ok_or_else(|| anyhow!("event response missing args"))?,
272 )?,
273 }),
274 other => return Err(anyhow!("unknown actor websocket response tag `{other}`")),
275 };
276
277 Ok(to_client::ToClient { body })
278}
279
280fn encode_to_server_bare(value: &to_server::ToServer) -> Result<Vec<u8>> {
281 let body = match &value.body {
282 to_server::ToServerBody::ActionRequest(request) => {
283 wire::ToServerBody::ActionRequest(wire::ActionRequest {
284 id: serde_bare::Uint(request.id),
285 name: request.name.clone(),
286 args: request.args.clone(),
287 })
288 }
289 to_server::ToServerBody::SubscriptionRequest(request) => {
290 wire::ToServerBody::SubscriptionRequest(wire::SubscriptionRequest {
291 event_name: request.event_name.clone(),
292 subscribe: request.subscribe,
293 })
294 }
295 };
296
297 wire::versioned::ToServer::wrap_latest(wire::ToServer { body })
298 .serialize_with_embedded_version(wire::PROTOCOL_VERSION)
299}
300
301fn decode_to_client_bare(payload: &[u8]) -> Result<to_client::ToClient> {
302 let message =
303 <wire::versioned::ToClient as OwnedVersionedData>::deserialize_with_embedded_version(
304 payload,
305 )
306 .context("decode bare actor websocket response")?;
307
308 let body = match message.body {
309 wire::ToClientBody::Init(init) => to_client::ToClientBody::Init(to_client::Init {
310 actor_id: init.actor_id,
311 connection_id: init.connection_id,
312 connection_token: None,
313 }),
314 wire::ToClientBody::Error(error) => to_client::ToClientBody::Error(to_client::Error {
315 group: error.group,
316 code: error.code,
317 message: error.message,
318 metadata: error.metadata,
319 action_id: error.action_id.map(|id| id.0),
320 }),
321 wire::ToClientBody::ActionResponse(response) => {
322 to_client::ToClientBody::ActionResponse(to_client::ActionResponse {
323 id: response.id.0,
324 output: response.output,
325 })
326 }
327 wire::ToClientBody::Event(event) => to_client::ToClientBody::Event(to_client::Event {
328 name: event.name,
329 args: event.args,
330 }),
331 };
332
333 Ok(to_client::ToClient { body })
334}
335
336fn json_string(value: &serde_json::Map<String, JsonValue>, key: &str) -> Result<String> {
337 value
338 .get(key)
339 .and_then(JsonValue::as_str)
340 .map(ToOwned::to_owned)
341 .ok_or_else(|| anyhow!("json object missing string field `{key}`"))
342}
343
344fn parse_json_u64(value: &JsonValue) -> Result<u64> {
345 match value {
346 JsonValue::Number(number) => number
347 .as_u64()
348 .ok_or_else(|| anyhow!("json number is not an unsigned integer")),
349 JsonValue::Array(values) if values.len() == 2 => {
350 let tag = values[0]
351 .as_str()
352 .ok_or_else(|| anyhow!("json bigint tag is not a string"))?;
353 let raw = values[1]
354 .as_str()
355 .ok_or_else(|| anyhow!("json bigint value is not a string"))?;
356 if tag != "$BigInt" {
357 return Err(anyhow!("unsupported json bigint tag `{tag}`"));
358 }
359 raw.parse::<u64>().context("parse json bigint")
360 }
361 _ => Err(anyhow!("invalid json unsigned integer")),
362 }
363}
364
365fn error_from_json_value(value: &JsonValue) -> Result<(String, String, String, Option<JsonValue>)> {
366 let value = value
367 .as_object()
368 .ok_or_else(|| anyhow!("http error response is not an object"))?;
369 Ok((
370 json_string(value, "group")?,
371 json_string(value, "code")?,
372 json_string(value, "message")?,
373 value.get("metadata").cloned(),
374 ))
375}
376
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn bare_action_response_round_trips() {
        let payload = wire::versioned::HttpActionResponse::wrap_latest(wire::HttpActionResponse {
            output: serde_cbor::to_vec(&json!({ "ok": true })).unwrap(),
        })
        .serialize_with_embedded_version(wire::PROTOCOL_VERSION)
        .unwrap();

        let output = decode_http_action_response(EncodingKind::Bare, &payload).unwrap();
        assert_eq!(output, json!({ "ok": true }));
    }

    #[test]
    fn bare_queue_request_has_embedded_version() {
        let payload = encode_http_queue_request(
            EncodingKind::Bare,
            "jobs",
            &json!({ "id": 1 }),
            true,
            Some(50),
        )
        .unwrap();
        assert_eq!(
            u16::from_le_bytes([payload[0], payload[1]]),
            wire::PROTOCOL_VERSION
        );
    }

    #[test]
    fn json_and_cbor_action_responses_extract_output() {
        let envelope = json!({ "output": { "ok": true } });
        let json_payload = serde_json::to_vec(&envelope).unwrap();
        let cbor_payload = serde_cbor::to_vec(&envelope).unwrap();

        assert_eq!(
            decode_http_action_response(EncodingKind::Json, &json_payload).unwrap(),
            json!({ "ok": true })
        );
        assert_eq!(
            decode_http_action_response(EncodingKind::Cbor, &cbor_payload).unwrap(),
            json!({ "ok": true })
        );
        assert!(decode_http_action_response(EncodingKind::Json, br#"{}"#).is_err());
    }

    #[test]
    fn json_queue_response_maps_statuses() {
        let timed_out =
            decode_http_queue_response(EncodingKind::Json, br#"{"status":"timedOut"}"#).unwrap();
        assert_eq!(timed_out.status, QueueSendStatus::TimedOut);
        assert_eq!(timed_out.response, None);

        let completed = decode_http_queue_response(
            EncodingKind::Json,
            br#"{"status":"completed","response":{"n":1}}"#,
        )
        .unwrap();
        assert_eq!(completed.status, QueueSendStatus::Completed);
        assert_eq!(completed.response, Some(json!({ "n": 1 })));

        let other =
            decode_http_queue_response(EncodingKind::Json, br#"{"status":"queued"}"#).unwrap();
        assert_eq!(other.status, QueueSendStatus::Other("queued".to_owned()));

        assert!(decode_http_queue_response(EncodingKind::Json, br#"{}"#).is_err());
    }

    #[test]
    fn json_error_body_is_decoded() {
        let (group, code, message, metadata) = decode_http_error(
            EncodingKind::Json,
            br#"{"group":"actor","code":"not_found","message":"missing","metadata":{"k":"v"}}"#,
        )
        .unwrap();
        assert_eq!(group, "actor");
        assert_eq!(code, "not_found");
        assert_eq!(message, "missing");
        assert_eq!(metadata, Some(json!({ "k": "v" })));
    }

    #[test]
    fn parse_json_u64_accepts_numbers_and_bigint_pairs() {
        assert_eq!(parse_json_u64(&json!(7)).unwrap(), 7);
        assert_eq!(
            parse_json_u64(&json!(["$BigInt", "18446744073709551615"])).unwrap(),
            u64::MAX
        );
        assert!(parse_json_u64(&json!(-1)).is_err());
        assert!(parse_json_u64(&json!(["$Other", "1"])).is_err());
        assert!(parse_json_u64(&json!("7")).is_err());
    }

    #[test]
    fn json_action_response_frame_accepts_bigint_id() {
        let frame = json!({
            "body": {
                "tag": "ActionResponse",
                "val": { "id": ["$BigInt", "9"], "output": { "ok": true } },
            },
        });
        let payload = serde_json::to_vec(&frame).unwrap();

        match decode_to_client(EncodingKind::Json, &payload).unwrap().body {
            to_client::ToClientBody::ActionResponse(response) => {
                assert_eq!(response.id, 9);
                let output: JsonValue = serde_cbor::from_slice(&response.output).unwrap();
                assert_eq!(output, json!({ "ok": true }));
            }
            _ => panic!("expected an ActionResponse frame"),
        }
    }
}