Skip to main content

rivetkit_client/
handle.rs

1use crate::{
2	common::{EncodingKind, RawWebSocket, TransportKind, HEADER_CONN_PARAMS, HEADER_ENCODING},
3	connection::{start_connection, ActorConnection, ActorConnectionInner},
4	protocol::{codec, query::*},
5	remote_manager::RemoteManager,
6};
7use anyhow::{anyhow, Result};
8use bytes::Bytes;
9use reqwest::{
10	header::{HeaderMap, HeaderValue},
11	Method, Response,
12};
13use serde::Serialize;
14use serde_json::Value as JsonValue;
15use std::{
16	ops::Deref,
17	sync::{Arc, Mutex},
18	time::Duration,
19};
20
21pub use crate::protocol::codec::{QueueSendResult, QueueSendStatus};
22
/// Options for a fire-and-forget queue send.
///
/// Currently carries no settings; it exists so options can be added later
/// without changing the `send` signature.
#[derive(Clone, Copy, Debug, Default)]
pub struct SendOpts {}

/// Options for a queue send that waits for its result.
#[derive(Clone, Copy, Debug, Default)]
pub struct SendAndWaitOpts {
	/// Optional timeout, forwarded with the queue request in whole
	/// milliseconds; `None` leaves it unset.
	pub timeout: Option<Duration>,
}

/// Alternative name for [`SendAndWaitOpts`].
pub type QueueSendOptions = SendAndWaitOpts;
32
/// Handle for issuing stateless requests (actions, queue sends, raw HTTP,
/// raw WebSockets) to a single actor.
///
/// It holds no connection of its own: each request snapshots `query` and
/// resolves the actor ID through the `RemoteManager`.
pub struct ActorHandleStateless {
	// Client used to resolve actor IDs and send requests through the gateway.
	remote_manager: RemoteManager,
	// Optional connection params, forwarded JSON-encoded with requests.
	params: Option<JsonValue>,
	// Codec used to encode request bodies and decode responses.
	encoding_kind: EncodingKind,
	// Mutex (not RefCell) so the handle is `Sync` and `&handle` futures
	// remain `Send` — required to call `.action(...)` from within axum
	// middleware that needs `Send` futures. `resolve()` may replace the
	// stored query with a `GetForId` query pinned to the resolved actor.
	query: Mutex<ActorQuery>,
}
42
43impl ActorHandleStateless {
44	pub fn new(
45		remote_manager: RemoteManager,
46		params: Option<JsonValue>,
47		encoding_kind: EncodingKind,
48		query: ActorQuery,
49	) -> Self {
50		Self {
51			remote_manager,
52			params,
53			encoding_kind,
54			query: Mutex::new(query),
55		}
56	}
57
58	pub async fn action(&self, name: &str, args: Vec<JsonValue>) -> Result<JsonValue> {
59		// Resolve actor ID
60		let query = self.query.lock().expect("query lock poisoned").clone();
61		let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
62
63		let body = codec::encode_http_action_request(self.encoding_kind, &args)?;
64
65		let headers = self.protocol_headers()?;
66
67		// Send request via gateway
68		let path = format!("/action/{}", urlencoding::encode(name));
69		let res = self
70			.remote_manager
71			.send_request(
72				&actor_id,
73				&path,
74				Method::POST,
75				headers,
76				Some(Bytes::from(body)),
77			)
78			.await?;
79
80		if !res.status().is_success() {
81			let status = res.status();
82			let body = res.bytes().await?;
83			if let Ok((group, code, message, metadata)) =
84				codec::decode_http_error(self.encoding_kind, &body)
85			{
86				return Err(anyhow!(
87					"action failed ({group}/{code}): {message}, metadata={metadata:?}"
88				));
89			}
90			return Err(anyhow!("action failed: {status}"));
91		}
92
93		// Decode response
94		let output = res.bytes().await?;
95		codec::decode_http_action_response(self.encoding_kind, &output)
96	}
97
98	pub async fn send(&self, name: &str, body: impl Serialize, _opts: SendOpts) -> Result<()> {
99		self.send_queue(name, &body, false, None).await.map(|_| ())
100	}
101
102	pub async fn send_and_wait(
103		&self,
104		name: &str,
105		body: impl Serialize,
106		opts: SendAndWaitOpts,
107	) -> Result<QueueSendResult> {
108		let result = self.send_queue(name, &body, true, opts.timeout).await?;
109		result.ok_or_else(|| anyhow!("queue wait response missing"))
110	}
111
112	async fn send_queue<T: Serialize>(
113		&self,
114		name: &str,
115		body: &T,
116		wait: bool,
117		timeout: Option<Duration>,
118	) -> Result<Option<QueueSendResult>> {
119		let query = self.query.lock().expect("query lock poisoned").clone();
120		let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
121		let timeout_ms =
122			timeout.map(|duration| u64::try_from(duration.as_millis()).unwrap_or(u64::MAX));
123		let request_body =
124			codec::encode_http_queue_request(self.encoding_kind, name, body, wait, timeout_ms)?;
125
126		let headers = self.protocol_headers()?;
127
128		let path = format!("/queue/{}", urlencoding::encode(name));
129		let res = self
130			.remote_manager
131			.send_request(
132				&actor_id,
133				&path,
134				Method::POST,
135				headers,
136				Some(Bytes::from(request_body)),
137			)
138			.await?;
139
140		if !res.status().is_success() {
141			let status = res.status();
142			let body = res.bytes().await?;
143			if let Ok((group, code, message, metadata)) =
144				codec::decode_http_error(self.encoding_kind, &body)
145			{
146				return Err(anyhow!(
147					"queue send failed ({group}/{code}): {message}, metadata={metadata:?}"
148				));
149			}
150			return Err(anyhow!("queue send failed: {status}"));
151		}
152
153		let body = res.bytes().await?;
154		let result = codec::decode_http_queue_response(self.encoding_kind, &body)?;
155		Ok(wait.then_some(result))
156	}
157
158	pub async fn fetch(
159		&self,
160		path: &str,
161		method: Method,
162		headers: HeaderMap,
163		body: Option<Bytes>,
164	) -> Result<Response> {
165		let query = self.query.lock().expect("query lock poisoned").clone();
166		let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
167		let path = normalize_fetch_path(path);
168		self.remote_manager
169			.send_request(&actor_id, &path, method, headers, body)
170			.await
171	}
172
173	pub async fn web_socket(
174		&self,
175		path: &str,
176		protocols: Option<Vec<String>>,
177	) -> Result<RawWebSocket> {
178		let query = self.query.lock().expect("query lock poisoned").clone();
179		let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
180		self.remote_manager
181			.open_raw_websocket(&actor_id, path, self.params.clone(), protocols)
182			.await
183	}
184
185	pub fn gateway_url(&self) -> Result<String> {
186		let query = self.query.lock().expect("query lock poisoned").clone();
187		self.remote_manager.gateway_url(&query)
188	}
189
190	pub fn get_gateway_url(&self) -> Result<String> {
191		self.gateway_url()
192	}
193
194	pub async fn reload(&self) -> Result<()> {
195		let query = self.query.lock().expect("query lock poisoned").clone();
196		let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
197		let res = self
198			.remote_manager
199			.send_request(
200				&actor_id,
201				"/dynamic/reload",
202				Method::PUT,
203				HeaderMap::new(),
204				None,
205			)
206			.await?;
207		if !res.status().is_success() {
208			let status = res.status();
209			let body = res.text().await.unwrap_or_default();
210			return Err(anyhow!("reload failed with status {status}: {body}"));
211		}
212		Ok(())
213	}
214
215	pub async fn resolve(&self) -> Result<String> {
216		let query = {
217			let Ok(query) = self.query.lock() else {
218				return Err(anyhow!("Failed to lock actor query"));
219			};
220			query.clone()
221		};
222
223		match query {
224			ActorQuery::Create { .. } => Err(anyhow!("actor query cannot be create")),
225			ActorQuery::GetForId { get_for_id } => Ok(get_for_id.actor_id.clone()),
226			_ => {
227				let actor_id = self.remote_manager.resolve_actor_id(&query).await?;
228
229				// Get name from the original query
230				let name = match &query {
231					ActorQuery::GetForKey { get_for_key } => get_for_key.name.clone(),
232					ActorQuery::GetOrCreateForKey {
233						get_or_create_for_key,
234					} => get_or_create_for_key.name.clone(),
235					_ => return Err(anyhow!("unexpected query type")),
236				};
237
238				{
239					let Ok(mut query_mut) = self.query.lock() else {
240						return Err(anyhow!("Failed to lock actor query mutably"));
241					};
242
243					*query_mut = ActorQuery::GetForId {
244						get_for_id: GetForIdRequest {
245							name,
246							actor_id: actor_id.clone(),
247						},
248					};
249				}
250
251				Ok(actor_id)
252			}
253		}
254	}
255
256	fn protocol_headers(&self) -> Result<HeaderMap> {
257		let mut headers = HeaderMap::new();
258		headers.insert(
259			HEADER_ENCODING,
260			HeaderValue::from_str(self.encoding_kind.as_str())?,
261		);
262
263		if let Some(params) = &self.params {
264			headers.insert(
265				HEADER_CONN_PARAMS,
266				HeaderValue::from_str(&serde_json::to_string(params)?)?,
267			);
268		}
269
270		Ok(headers)
271	}
272}
273
/// Maps a caller-supplied fetch path onto the actor's raw request namespace.
///
/// All leading slashes are stripped and the remainder is appended to
/// `/request/`; an empty (or all-slash) path maps to `/request` itself.
fn normalize_fetch_path(path: &str) -> String {
	match path.trim_start_matches('/') {
		"" => String::from("/request"),
		rest => format!("/request/{rest}"),
	}
}
282
283pub struct ActorHandle {
284	handle: ActorHandleStateless,
285	remote_manager: RemoteManager,
286	params: Option<JsonValue>,
287	query: ActorQuery,
288	client_shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
289	transport_kind: crate::TransportKind,
290	encoding_kind: EncodingKind,
291}
292
293impl ActorHandle {
294	pub fn new(
295		remote_manager: RemoteManager,
296		params: Option<JsonValue>,
297		query: ActorQuery,
298		client_shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
299		transport_kind: TransportKind,
300		encoding_kind: EncodingKind,
301	) -> Self {
302		let handle = ActorHandleStateless::new(
303			remote_manager.clone(),
304			params.clone(),
305			encoding_kind,
306			query.clone(),
307		);
308
309		Self {
310			handle,
311			remote_manager,
312			params,
313			query,
314			client_shutdown_tx,
315			transport_kind,
316			encoding_kind,
317		}
318	}
319
320	pub fn connect(&self) -> ActorConnection {
321		let conn = ActorConnectionInner::new(
322			self.remote_manager.clone(),
323			self.query.clone(),
324			self.transport_kind,
325			self.encoding_kind,
326			self.params.clone(),
327		);
328
329		let rx = self.client_shutdown_tx.subscribe();
330		start_connection(&conn, rx);
331
332		conn
333	}
334}
335
336impl Deref for ActorHandle {
337	type Target = ActorHandleStateless;
338
339	fn deref(&self) -> &Self::Target {
340		&self.handle
341	}
342}