// rivetkit_client/protocol/codec.rs

1use anyhow::{anyhow, Context, Result};
2use rivetkit_client_protocol as wire;
3use serde::Serialize;
4use serde_json::{json, Value as JsonValue};
5use vbare::OwnedVersionedData;
6
7use crate::EncodingKind;
8
9use super::{to_client, to_server};
10
11pub fn encode_to_server(encoding: EncodingKind, value: &to_server::ToServer) -> Result<Vec<u8>> {
12	match encoding {
13		EncodingKind::Json => Ok(serde_json::to_vec(&to_server_json_value(value)?)?),
14		EncodingKind::Cbor => Ok(serde_cbor::to_vec(&to_server_json_value(value)?)?),
15		EncodingKind::Bare => encode_to_server_bare(value),
16	}
17}
18
19pub fn decode_to_client(encoding: EncodingKind, payload: &[u8]) -> Result<to_client::ToClient> {
20	match encoding {
21		EncodingKind::Json => {
22			let value: JsonValue =
23				serde_json::from_slice(payload).context("decode actor websocket json response")?;
24			to_client_from_json_value(&value)
25		}
26		EncodingKind::Cbor => {
27			let value: JsonValue =
28				serde_cbor::from_slice(payload).context("decode actor websocket cbor response")?;
29			to_client_from_json_value(&value)
30		}
31		EncodingKind::Bare => decode_to_client_bare(payload),
32	}
33}
34
35pub fn encode_http_action_request(encoding: EncodingKind, args: &[JsonValue]) -> Result<Vec<u8>> {
36	match encoding {
37		EncodingKind::Json => Ok(serde_json::to_vec(&json!({ "args": args }))?),
38		EncodingKind::Cbor => Ok(serde_cbor::to_vec(&json!({ "args": args }))?),
39		EncodingKind::Bare => {
40			wire::versioned::HttpActionRequest::wrap_latest(wire::HttpActionRequest {
41				args: serde_cbor::to_vec(&args.to_vec())?,
42			})
43			.serialize_with_embedded_version(wire::PROTOCOL_VERSION)
44		}
45	}
46}
47
48pub fn decode_http_action_response(encoding: EncodingKind, payload: &[u8]) -> Result<JsonValue> {
49	match encoding {
50		EncodingKind::Json => {
51			let value: JsonValue = serde_json::from_slice(payload)?;
52			value
53				.get("output")
54				.cloned()
55				.ok_or_else(|| anyhow!("action response missing output"))
56		}
57		EncodingKind::Cbor => {
58			let value: JsonValue = serde_cbor::from_slice(payload)?;
59			value
60				.get("output")
61				.cloned()
62				.ok_or_else(|| anyhow!("action response missing output"))
63		}
64		EncodingKind::Bare => {
65			let response =
66                <wire::versioned::HttpActionResponse as OwnedVersionedData>::deserialize_with_embedded_version(
67                    payload,
68                )
69                .context("decode bare action response")?;
70			Ok(serde_cbor::from_slice(&response.output)?)
71		}
72	}
73}
74
75pub fn encode_http_queue_request<T: Serialize>(
76	encoding: EncodingKind,
77	name: &str,
78	body: &T,
79	wait: bool,
80	timeout: Option<u64>,
81) -> Result<Vec<u8>> {
82	#[derive(Serialize)]
83	struct JsonQueueRequest<'a, T: Serialize + ?Sized> {
84		name: &'a str,
85		body: &'a T,
86		wait: bool,
87		#[serde(skip_serializing_if = "Option::is_none")]
88		timeout: Option<u64>,
89	}
90
91	let request = JsonQueueRequest {
92		name,
93		body,
94		wait,
95		timeout,
96	};
97
98	match encoding {
99		EncodingKind::Json => Ok(serde_json::to_vec(&request)?),
100		EncodingKind::Cbor => Ok(serde_cbor::to_vec(&request)?),
101		EncodingKind::Bare => {
102			wire::versioned::HttpQueueSendRequest::wrap_latest(wire::HttpQueueSendRequest {
103				body: serde_cbor::to_vec(body)?,
104				name: Some(name.to_owned()),
105				wait: Some(wait),
106				timeout,
107			})
108			.serialize_with_embedded_version(wire::PROTOCOL_VERSION)
109		}
110	}
111}
112
/// Status of an HTTP queue send, as decoded from the response's `status` string.
///
/// Recognised wire values get dedicated variants; any other string is kept verbatim in
/// [`QueueSendStatus::Other`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum QueueSendStatus {
	/// Wire status `"completed"`.
	Completed,
	/// Wire status `"timedOut"`.
	TimedOut,
	/// Any status string this client does not recognise.
	Other(String),
}
119
120#[derive(Debug, Clone)]
121pub struct QueueSendResult {
122	pub status: QueueSendStatus,
123	pub response: Option<JsonValue>,
124}
125
126pub fn decode_http_queue_response(
127	encoding: EncodingKind,
128	payload: &[u8],
129) -> Result<QueueSendResult> {
130	let (status, response) = match encoding {
131		EncodingKind::Json => {
132			let value: JsonValue = serde_json::from_slice(payload)?;
133			let status = value
134				.get("status")
135				.and_then(JsonValue::as_str)
136				.ok_or_else(|| anyhow!("queue response missing status"))?
137				.to_owned();
138			let response = value.get("response").cloned();
139			(status, response)
140		}
141		EncodingKind::Cbor => {
142			let value: JsonValue = serde_cbor::from_slice(payload)?;
143			let status = value
144				.get("status")
145				.and_then(JsonValue::as_str)
146				.ok_or_else(|| anyhow!("queue response missing status"))?
147				.to_owned();
148			let response = value.get("response").cloned();
149			(status, response)
150		}
151		EncodingKind::Bare => {
152			let response =
153                <wire::versioned::HttpQueueSendResponse as OwnedVersionedData>::deserialize_with_embedded_version(
154                    payload,
155                )
156                .context("decode bare queue response")?;
157			let body = response
158				.response
159				.map(|payload| serde_cbor::from_slice(&payload))
160				.transpose()?;
161			(response.status, body)
162		}
163	};
164
165	let status = match status.as_str() {
166		"completed" => QueueSendStatus::Completed,
167		"timedOut" => QueueSendStatus::TimedOut,
168		_ => QueueSendStatus::Other(status),
169	};
170
171	Ok(QueueSendResult { status, response })
172}
173
174pub fn decode_http_error(
175	encoding: EncodingKind,
176	payload: &[u8],
177) -> Result<(String, String, String, Option<JsonValue>)> {
178	match encoding {
179		EncodingKind::Json => {
180			let value: JsonValue = serde_json::from_slice(payload)?;
181			error_from_json_value(&value)
182		}
183		EncodingKind::Cbor => {
184			let value: JsonValue = serde_cbor::from_slice(payload)?;
185			error_from_json_value(&value)
186		}
187		EncodingKind::Bare => {
188			let error =
189                <wire::versioned::HttpResponseError as OwnedVersionedData>::deserialize_with_embedded_version(
190                    payload,
191                )
192                .context("decode bare http error")?;
193			let metadata = error
194				.metadata
195				.map(|payload| serde_cbor::from_slice(&payload))
196				.transpose()?;
197			Ok((error.group, error.code, error.message, metadata))
198		}
199	}
200}
201
202fn to_server_json_value(value: &to_server::ToServer) -> Result<JsonValue> {
203	let body = match &value.body {
204		to_server::ToServerBody::ActionRequest(request) => json!({
205			"tag": "ActionRequest",
206			"val": {
207				"id": request.id,
208				"name": request.name,
209				"args": serde_cbor::from_slice::<JsonValue>(&request.args)
210					.context("decode websocket action args for json/cbor transport")?,
211			},
212		}),
213		to_server::ToServerBody::SubscriptionRequest(request) => json!({
214			"tag": "SubscriptionRequest",
215			"val": {
216				"eventName": request.event_name,
217				"subscribe": request.subscribe,
218			},
219		}),
220	};
221	Ok(json!({ "body": body }))
222}
223
224fn to_client_from_json_value(value: &JsonValue) -> Result<to_client::ToClient> {
225	let body = value
226		.get("body")
227		.and_then(JsonValue::as_object)
228		.ok_or_else(|| anyhow!("actor websocket response missing body"))?;
229	let tag = body
230		.get("tag")
231		.and_then(JsonValue::as_str)
232		.ok_or_else(|| anyhow!("actor websocket response missing tag"))?;
233	let value = body
234		.get("val")
235		.and_then(JsonValue::as_object)
236		.ok_or_else(|| anyhow!("actor websocket response missing val"))?;
237
238	let body = match tag {
239		"Init" => to_client::ToClientBody::Init(to_client::Init {
240			actor_id: json_string(value, "actorId")?,
241			connection_id: json_string(value, "connectionId")?,
242			connection_token: value
243				.get("connectionToken")
244				.and_then(JsonValue::as_str)
245				.map(ToOwned::to_owned),
246		}),
247		"Error" => to_client::ToClientBody::Error(to_client::Error {
248			group: json_string(value, "group")?,
249			code: json_string(value, "code")?,
250			message: json_string(value, "message")?,
251			metadata: value.get("metadata").map(serde_cbor::to_vec).transpose()?,
252			action_id: value.get("actionId").map(parse_json_u64).transpose()?,
253		}),
254		"ActionResponse" => to_client::ToClientBody::ActionResponse(to_client::ActionResponse {
255			id: parse_json_u64(
256				value
257					.get("id")
258					.ok_or_else(|| anyhow!("action response missing id"))?,
259			)?,
260			output: serde_cbor::to_vec(
261				value
262					.get("output")
263					.ok_or_else(|| anyhow!("action response missing output"))?,
264			)?,
265		}),
266		"Event" => to_client::ToClientBody::Event(to_client::Event {
267			name: json_string(value, "name")?,
268			args: serde_cbor::to_vec(
269				value
270					.get("args")
271					.ok_or_else(|| anyhow!("event response missing args"))?,
272			)?,
273		}),
274		other => return Err(anyhow!("unknown actor websocket response tag `{other}`")),
275	};
276
277	Ok(to_client::ToClient { body })
278}
279
280fn encode_to_server_bare(value: &to_server::ToServer) -> Result<Vec<u8>> {
281	let body = match &value.body {
282		to_server::ToServerBody::ActionRequest(request) => {
283			wire::ToServerBody::ActionRequest(wire::ActionRequest {
284				id: serde_bare::Uint(request.id),
285				name: request.name.clone(),
286				args: request.args.clone(),
287			})
288		}
289		to_server::ToServerBody::SubscriptionRequest(request) => {
290			wire::ToServerBody::SubscriptionRequest(wire::SubscriptionRequest {
291				event_name: request.event_name.clone(),
292				subscribe: request.subscribe,
293			})
294		}
295	};
296
297	wire::versioned::ToServer::wrap_latest(wire::ToServer { body })
298		.serialize_with_embedded_version(wire::PROTOCOL_VERSION)
299}
300
301fn decode_to_client_bare(payload: &[u8]) -> Result<to_client::ToClient> {
302	let message =
303		<wire::versioned::ToClient as OwnedVersionedData>::deserialize_with_embedded_version(
304			payload,
305		)
306		.context("decode bare actor websocket response")?;
307
308	let body = match message.body {
309		wire::ToClientBody::Init(init) => to_client::ToClientBody::Init(to_client::Init {
310			actor_id: init.actor_id,
311			connection_id: init.connection_id,
312			connection_token: None,
313		}),
314		wire::ToClientBody::Error(error) => to_client::ToClientBody::Error(to_client::Error {
315			group: error.group,
316			code: error.code,
317			message: error.message,
318			metadata: error.metadata,
319			action_id: error.action_id.map(|id| id.0),
320		}),
321		wire::ToClientBody::ActionResponse(response) => {
322			to_client::ToClientBody::ActionResponse(to_client::ActionResponse {
323				id: response.id.0,
324				output: response.output,
325			})
326		}
327		wire::ToClientBody::Event(event) => to_client::ToClientBody::Event(to_client::Event {
328			name: event.name,
329			args: event.args,
330		}),
331	};
332
333	Ok(to_client::ToClient { body })
334}
335
336fn json_string(value: &serde_json::Map<String, JsonValue>, key: &str) -> Result<String> {
337	value
338		.get(key)
339		.and_then(JsonValue::as_str)
340		.map(ToOwned::to_owned)
341		.ok_or_else(|| anyhow!("json object missing string field `{key}`"))
342}
343
344fn parse_json_u64(value: &JsonValue) -> Result<u64> {
345	match value {
346		JsonValue::Number(number) => number
347			.as_u64()
348			.ok_or_else(|| anyhow!("json number is not an unsigned integer")),
349		JsonValue::Array(values) if values.len() == 2 => {
350			let tag = values[0]
351				.as_str()
352				.ok_or_else(|| anyhow!("json bigint tag is not a string"))?;
353			let raw = values[1]
354				.as_str()
355				.ok_or_else(|| anyhow!("json bigint value is not a string"))?;
356			if tag != "$BigInt" {
357				return Err(anyhow!("unsupported json bigint tag `{tag}`"));
358			}
359			raw.parse::<u64>().context("parse json bigint")
360		}
361		_ => Err(anyhow!("invalid json unsigned integer")),
362	}
363}
364
365fn error_from_json_value(value: &JsonValue) -> Result<(String, String, String, Option<JsonValue>)> {
366	let value = value
367		.as_object()
368		.ok_or_else(|| anyhow!("http error response is not an object"))?;
369	Ok((
370		json_string(value, "group")?,
371		json_string(value, "code")?,
372		json_string(value, "message")?,
373		value.get("metadata").cloned(),
374	))
375}
376
#[cfg(test)]
mod tests {
	use serde_json::json;

	use super::*;

	#[test]
	fn bare_action_response_round_trips() {
		let payload = wire::versioned::HttpActionResponse::wrap_latest(wire::HttpActionResponse {
			output: serde_cbor::to_vec(&json!({ "ok": true })).unwrap(),
		})
		.serialize_with_embedded_version(wire::PROTOCOL_VERSION)
		.unwrap();

		let output = decode_http_action_response(EncodingKind::Bare, &payload).unwrap();
		assert_eq!(output, json!({ "ok": true }));
	}

	#[test]
	fn bare_queue_request_has_embedded_version() {
		let payload = encode_http_queue_request(
			EncodingKind::Bare,
			"jobs",
			&json!({ "id": 1 }),
			true,
			Some(50),
		)
		.unwrap();
		assert_eq!(
			u16::from_le_bytes([payload[0], payload[1]]),
			wire::PROTOCOL_VERSION
		);
	}

	#[test]
	fn json_action_request_wraps_args() {
		let payload =
			encode_http_action_request(EncodingKind::Json, &[json!(1), json!("a")]).unwrap();
		let decoded: JsonValue = serde_json::from_slice(&payload).unwrap();
		assert_eq!(decoded, json!({ "args": [1, "a"] }));
	}

	#[test]
	fn cbor_action_request_wraps_args() {
		let payload =
			encode_http_action_request(EncodingKind::Cbor, &[json!(1), json!("a")]).unwrap();
		let decoded: JsonValue = serde_cbor::from_slice(&payload).unwrap();
		assert_eq!(decoded, json!({ "args": [1, "a"] }));
	}

	#[test]
	fn cbor_action_response_extracts_output() {
		let payload = serde_cbor::to_vec(&json!({ "output": [1, 2] })).unwrap();
		let output = decode_http_action_response(EncodingKind::Cbor, &payload).unwrap();
		assert_eq!(output, json!([1, 2]));
	}

	#[test]
	fn json_action_response_without_output_is_an_error() {
		let payload = serde_json::to_vec(&json!({})).unwrap();
		assert!(decode_http_action_response(EncodingKind::Json, &payload).is_err());
	}

	#[test]
	fn json_queue_response_maps_status() {
		let payload = serde_json::to_vec(&json!({ "status": "timedOut" })).unwrap();
		let result = decode_http_queue_response(EncodingKind::Json, &payload).unwrap();
		assert_eq!(result.status, QueueSendStatus::TimedOut);
		assert!(result.response.is_none());

		let payload =
			serde_json::to_vec(&json!({ "status": "queued", "response": 5 })).unwrap();
		let result = decode_http_queue_response(EncodingKind::Json, &payload).unwrap();
		assert_eq!(result.status, QueueSendStatus::Other("queued".to_owned()));
		assert_eq!(result.response, Some(json!(5)));
	}

	#[test]
	fn json_http_error_returns_fields() {
		let payload =
			serde_json::to_vec(&json!({ "group": "g", "code": "c", "message": "m" })).unwrap();
		let (group, code, message, metadata) =
			decode_http_error(EncodingKind::Json, &payload).unwrap();
		assert_eq!(
			(group.as_str(), code.as_str(), message.as_str()),
			("g", "c", "m")
		);
		assert!(metadata.is_none());
	}

	#[test]
	fn parse_json_u64_accepts_numbers_and_tagged_bigints() {
		assert_eq!(parse_json_u64(&json!(7)).unwrap(), 7);
		assert_eq!(
			parse_json_u64(&json!(["$BigInt", "18446744073709551615"])).unwrap(),
			u64::MAX
		);
		assert!(parse_json_u64(&json!(-1)).is_err());
		assert!(parse_json_u64(&json!(["$Other", "1"])).is_err());
		assert!(parse_json_u64(&json!("7")).is_err());
	}
}