Skip to main content

rivetkit_client/
client.rs

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::Result;
4use serde_json::Value as JsonValue;
5
6use crate::{
7	common::{ActorKey, EncodingKind, TransportKind},
8	handle::ActorHandle,
9	protocol::query::*,
10	remote_manager::RemoteManager,
11};
12
13#[derive(Default)]
14pub struct GetWithIdOptions {
15	pub params: Option<JsonValue>,
16}
17
18#[derive(Default)]
19pub struct GetOptions {
20	pub params: Option<JsonValue>,
21}
22
23#[derive(Default)]
24pub struct GetOrCreateOptions {
25	pub params: Option<JsonValue>,
26	pub create_in_region: Option<String>,
27	pub create_with_input: Option<JsonValue>,
28}
29
30#[derive(Default)]
31pub struct CreateOptions {
32	pub params: Option<JsonValue>,
33	pub region: Option<String>,
34	pub input: Option<JsonValue>,
35}
36
37pub struct ClientConfig {
38	pub endpoint: String,
39	pub token: Option<String>,
40	pub namespace: Option<String>,
41	pub pool_name: Option<String>,
42	pub encoding: EncodingKind,
43	pub transport: TransportKind,
44	pub headers: Option<HashMap<String, String>>,
45	pub max_input_size: Option<usize>,
46	pub disable_metadata_lookup: bool,
47}
48
49impl ClientConfig {
50	pub fn new(endpoint: impl Into<String>) -> Self {
51		Self {
52			endpoint: endpoint.into(),
53			token: None,
54			namespace: None,
55			pool_name: None,
56			encoding: EncodingKind::Bare,
57			transport: TransportKind::WebSocket,
58			headers: None,
59			max_input_size: None,
60			disable_metadata_lookup: false,
61		}
62	}
63
64	pub fn token(mut self, token: impl Into<String>) -> Self {
65		self.token = Some(token.into());
66		self
67	}
68
69	pub fn token_opt(mut self, token: Option<String>) -> Self {
70		self.token = token;
71		self
72	}
73
74	pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
75		self.namespace = Some(namespace.into());
76		self
77	}
78
79	pub fn pool_name(mut self, pool_name: impl Into<String>) -> Self {
80		self.pool_name = Some(pool_name.into());
81		self
82	}
83
84	pub fn encoding(mut self, encoding: EncodingKind) -> Self {
85		self.encoding = encoding;
86		self
87	}
88
89	pub fn transport(mut self, transport: TransportKind) -> Self {
90		self.transport = transport;
91		self
92	}
93
94	pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
95		self.headers
96			.get_or_insert_with(HashMap::new)
97			.insert(key.into(), value.into());
98		self
99	}
100
101	pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
102		self.headers = Some(headers);
103		self
104	}
105
106	pub fn max_input_size(mut self, max_input_size: usize) -> Self {
107		self.max_input_size = Some(max_input_size);
108		self
109	}
110
111	pub fn disable_metadata_lookup(mut self, disable: bool) -> Self {
112		self.disable_metadata_lookup = disable;
113		self
114	}
115}
116
117pub struct Client {
118	remote_manager: RemoteManager,
119	encoding_kind: EncodingKind,
120	transport_kind: TransportKind,
121	shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
122}
123
124impl Clone for Client {
125	fn clone(&self) -> Self {
126		Self {
127			remote_manager: self.remote_manager.clone(),
128			encoding_kind: self.encoding_kind,
129			transport_kind: self.transport_kind,
130			shutdown_tx: self.shutdown_tx.clone(),
131		}
132	}
133}
134
135impl std::fmt::Debug for Client {
136	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137		f.debug_struct("Client")
138			.field("encoding_kind", &self.encoding_kind)
139			.field("transport_kind", &self.transport_kind)
140			.finish_non_exhaustive()
141	}
142}
143
144impl Client {
145	pub fn new(config: ClientConfig) -> Self {
146		let remote_manager = RemoteManager::from_config(
147			config.endpoint,
148			config.token,
149			config.namespace,
150			config.pool_name,
151			config.headers,
152			config.max_input_size,
153			config.disable_metadata_lookup,
154		);
155
156		Self {
157			remote_manager,
158			encoding_kind: config.encoding,
159			transport_kind: config.transport,
160			shutdown_tx: Arc::new(tokio::sync::broadcast::channel(1).0),
161		}
162	}
163
164	pub fn from_endpoint(endpoint: impl Into<String>) -> Self {
165		Self::new(ClientConfig::new(endpoint))
166	}
167
168	fn create_handle(&self, params: Option<JsonValue>, query: ActorQuery) -> ActorHandle {
169		let handle = ActorHandle::new(
170			self.remote_manager.clone(),
171			params,
172			query,
173			self.shutdown_tx.clone(),
174			self.transport_kind,
175			self.encoding_kind,
176		);
177
178		handle
179	}
180
181	pub fn get(&self, name: &str, key: ActorKey, opts: GetOptions) -> Result<ActorHandle> {
182		let actor_query = ActorQuery::GetForKey {
183			get_for_key: GetForKeyRequest {
184				name: name.to_string(),
185				key,
186			},
187		};
188
189		let handle = self.create_handle(opts.params, actor_query);
190
191		Ok(handle)
192	}
193
194	pub fn get_for_id(&self, name: &str, actor_id: &str, opts: GetOptions) -> Result<ActorHandle> {
195		let actor_query = ActorQuery::GetForId {
196			get_for_id: GetForIdRequest {
197				name: name.to_string(),
198				actor_id: actor_id.to_string(),
199			},
200		};
201
202		let handle = self.create_handle(opts.params, actor_query);
203
204		Ok(handle)
205	}
206
207	pub fn get_or_create(
208		&self,
209		name: &str,
210		key: ActorKey,
211		opts: GetOrCreateOptions,
212	) -> Result<ActorHandle> {
213		let input = opts.create_with_input;
214		let region = opts.create_in_region;
215
216		let actor_query = ActorQuery::GetOrCreateForKey {
217			get_or_create_for_key: GetOrCreateRequest {
218				name: name.to_string(),
219				key: key,
220				input,
221				region,
222			},
223		};
224
225		let handle = self.create_handle(opts.params, actor_query);
226
227		Ok(handle)
228	}
229
230	pub async fn create(
231		&self,
232		name: &str,
233		key: ActorKey,
234		opts: CreateOptions,
235	) -> Result<ActorHandle> {
236		let input = opts.input;
237		let _region = opts.region;
238
239		let actor_id = self.remote_manager.create_actor(name, &key, input).await?;
240
241		let get_query = ActorQuery::GetForId {
242			get_for_id: GetForIdRequest {
243				name: name.to_string(),
244				actor_id,
245			},
246		};
247
248		let handle = self.create_handle(opts.params, get_query);
249
250		Ok(handle)
251	}
252
253	pub fn disconnect(self) {
254		drop(self)
255	}
256
257	pub fn dispose(self) {
258		self.disconnect()
259	}
260}
261
262impl Drop for Client {
263	fn drop(&mut self) {
264		// Notify all subscribers to shutdown
265		let _ = self.shutdown_tx.send(());
266	}
267}