1use std::{
2 collections::HashMap,
3 convert::Infallible,
4 future::Future,
5 sync::{
6 atomic::{AtomicU64, Ordering::SeqCst},
7 Arc,
8 },
9};
10
11use fastwebsockets::{FragmentCollectorRead, Frame};
12use futures::{future::Either, pin_mut, StreamExt};
13use hyper::{
14 header::{CONNECTION, UPGRADE},
15 Request,
16};
17use leaf_rpc_proto::{Req, ReqKind, Resp, RespKind};
18use tokio::{
19 net::TcpStream,
20 sync::{mpsc, oneshot, Mutex},
21};
22
23pub use hyper::Uri;
24pub use leaf_protocol;
25
26use leaf_protocol::prelude::*;
27use tokio_stream::wrappers::ReceiverStream;
28
29#[derive(Clone)]
30pub struct RpcClient {
31 index: Arc<AtomicU64>,
32 frame_writer: mpsc::Sender<Frame<'static>>,
33 pending_reqs: Arc<Mutex<HashMap<u64, oneshot::Sender<Resp>>>>,
34}
35
36impl Drop for RpcClient {
38 fn drop(&mut self) {
39 tracing::warn!("TODO: implement graceful shutdown of RPC client.");
40 }
41}
42
43impl RpcClient {
44 pub async fn connect(uri: Uri, auth_token: Option<&str>) -> anyhow::Result<Self> {
45 let host = uri.host().unwrap();
46 let port = uri.port().unwrap();
47 let socket = format!("{host}:{port}");
48 let stream = TcpStream::connect(socket).await?;
49
50 let req = Request::builder()
51 .method("GET")
52 .uri(&uri)
53 .header("Host", host)
54 .header(UPGRADE, "websocket")
55 .header(CONNECTION, "upgrade")
56 .header(
57 "Sec-Websocket-Key",
58 fastwebsockets::handshake::generate_key(),
59 )
60 .header("Sec-Websocket-Version", "13")
61 .body(String::new())?;
62
63 let pending_reqs = Arc::new(Mutex::new(HashMap::<u64, oneshot::Sender<Resp>>::default()));
64 let pending_reqs_ = pending_reqs.clone();
65
66 let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await?;
67 let (ws_read, mut ws_write) = ws.split(tokio::io::split);
68 let mut ws_read = FragmentCollectorRead::new(ws_read);
69
70 let (client_frame_send, client_frame_recv) = mpsc::channel(10);
71
72 tokio::spawn(async move {
73 let read_frame_from_server = async_stream::stream! {
74 loop {
75 yield ws_read.read_frame::<_, Infallible>(&mut |_| async { panic!("obligated send not implemented") }).await;
76 }
77 }
78 .map(Either::Left);
79 let recv_frame_to_send = ReceiverStream::new(client_frame_recv).map(Either::Right);
80
81 let stream = futures::stream::select(read_frame_from_server, recv_frame_to_send);
82 pin_mut!(stream);
83
84 loop {
85 let Some(event) = stream.next().await else {
86 break;
87 };
88
89 match event {
90 Either::Left(frame_from_server) => match frame_from_server {
91 Ok(frame) => {
92 if frame.opcode == fastwebsockets::OpCode::Binary {
93 let mut data = &frame.payload[..];
94 let resp = Resp::deserialize(&mut data);
95 match resp {
96 Ok(resp) => {
97 let mut pending_reqs = pending_reqs_.lock().await;
98 let Some(sender) = pending_reqs.remove(&resp.id) else {
99 tracing::warn!(
100 "Got response for request that is not pending"
101 );
102 continue;
103 };
104 sender.send(resp).ok();
105 }
106 Err(e) => tracing::error!(
107 "Error deserializing response from server: {e}"
108 ),
109 }
110 }
111 }
112 Err(e) => tracing::error!("Error reading message from server: {e}"),
113 },
114 Either::Right(frame_to_send) => {
115 if let Err(e) = ws_write.write_frame(frame_to_send).await {
116 tracing::warn!("Could not send request to server: {e}");
117 }
118 }
119 }
120 }
121 });
122
123 let client = RpcClient {
124 index: Arc::new(0.into()),
125 frame_writer: client_frame_send,
126 pending_reqs,
127 };
128
129 if let Some(auth_token) = auth_token {
130 let resp = client
131 .send_req(ReqKind::Authenticate(auth_token.into()))
132 .await?;
133 match resp.result {
134 Ok(RespKind::Authenticated) => (),
135 Ok(_) => anyhow::bail!("Unexpected response when authenticating"),
136 Err(e) => anyhow::bail!("Authentication error: {e}"),
137 }
138 }
139
140 Ok(client)
141 }
142
143 async fn send_req(&self, kind: ReqKind) -> anyhow::Result<Resp> {
144 let id = self.index.fetch_add(1, SeqCst);
145 let req = Req { id, kind };
146
147 let mut req_bytes = Vec::new();
148 req.serialize(&mut req_bytes)?;
149
150 let (resp_sender, resp_receiver) = oneshot::channel();
151 {
152 let mut pending_reqs = self.pending_reqs.lock().await;
153 pending_reqs.insert(id, resp_sender);
154 }
155
156 self.frame_writer
157 .send(Frame::binary(fastwebsockets::Payload::Owned(req_bytes)))
158 .await?;
159
160 let resp = resp_receiver.await?;
161 assert_eq!(resp.id, id, "Invalid RPC ID in response");
162
163 Ok(resp)
164 }
165
166 pub async fn read_entity<L: Into<ExactLink>>(
167 &self,
168 link: L,
169 ) -> anyhow::Result<Option<(Digest, Entity)>> {
170 let link = link.into();
171 let resp = self.send_req(ReqKind::ReadEntity(link)).await?;
172 let RespKind::ReadEntity(entity) = resp
173 .result
174 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
175 else {
176 anyhow::bail!(INVALID_RPC_RESP_MSG);
177 };
178 Ok(entity)
179 }
180
181 pub async fn del_entity<L: Into<ExactLink>>(&self, link: L) -> anyhow::Result<()> {
182 let link = link.into();
183 let resp = self.send_req(ReqKind::DelEntity(link)).await?;
184 let RespKind::DelEntity = resp
185 .result
186 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
187 else {
188 anyhow::bail!(INVALID_RPC_RESP_MSG);
189 };
190 Ok(())
191 }
192
193 pub async fn list_entities<L: Into<ExactLink>>(
194 &self,
195 link: L,
196 ) -> anyhow::Result<Vec<ExactLink>> {
197 let link = link.into();
198 let resp = self.send_req(ReqKind::ListEntities(link)).await?;
199 let RespKind::ListEntities(entities) = resp
200 .result
201 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
202 else {
203 anyhow::bail!(INVALID_RPC_RESP_MSG);
204 };
205 Ok(entities)
206 }
207
208 pub async fn del_components<C: Component, L: Into<ExactLink>>(
210 &self,
211 link: L,
212 ) -> anyhow::Result<Option<Digest>> {
213 let link = link.into();
214
215 let resp = self
216 .send_req(ReqKind::DelComponentsBySchema {
217 link,
218 schemas: vec![C::schema_id()],
219 })
220 .await?;
221 let RespKind::DelComponentBySchema(new_digest) = resp
222 .result
223 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
224 else {
225 anyhow::bail!(INVALID_RPC_RESP_MSG);
226 };
227 Ok(new_digest)
228 }
229
230 pub async fn add_component<C: Component, L: Into<ExactLink>>(
232 &self,
233 link: L,
234 component: C,
235 replace_existing: bool,
236 ) -> anyhow::Result<Digest> {
237 let link = link.into();
238 let component_data = component.make_data()?;
239
240 let resp = self
241 .send_req(ReqKind::AddComponents {
242 link,
243 components: vec![component_data],
244 replace_existing,
245 })
246 .await?;
247 let RespKind::AddComponents(entity_id) = resp
248 .result
249 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
250 else {
251 anyhow::bail!(INVALID_RPC_RESP_MSG);
252 };
253 Ok(entity_id)
254 }
255
256 pub async fn get_components<C: Component, L: Into<ExactLink>>(
258 &self,
259 _link: L,
260 ) -> anyhow::Result<Option<(Digest, Vec<C>)>> {
261 unimplemented!("get_components() needs a better way to get multiple components at a time.");
262 }
300
301 pub async fn create_namespace(&self) -> anyhow::Result<NamespaceId> {
302 let resp = self.send_req(ReqKind::CreateNamespace).await?;
303 let RespKind::CreateNamespace(id) = resp
304 .result
305 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
306 else {
307 anyhow::bail!(INVALID_RPC_RESP_MSG);
308 };
309 Ok(id)
310 }
311 pub async fn import_namespace_secret(
312 &self,
313 namespace: NamespaceSecretKey,
314 ) -> anyhow::Result<NamespaceId> {
315 let resp = self
316 .send_req(ReqKind::ImportNamespaceSecret(namespace))
317 .await?;
318 let RespKind::ImportNamespaceSecret(id) = resp
319 .result
320 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
321 else {
322 anyhow::bail!(INVALID_RPC_RESP_MSG);
323 };
324 Ok(id)
325 }
326 pub async fn get_namespace_secret(
327 &self,
328 namespace: NamespaceSecretKey,
329 ) -> anyhow::Result<Option<NamespaceSecretKey>> {
330 let resp = self
331 .send_req(ReqKind::GetNamespaceSecret(namespace))
332 .await?;
333 let RespKind::GetNamespaceSecret(id) = resp
334 .result
335 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
336 else {
337 anyhow::bail!(INVALID_RPC_RESP_MSG);
338 };
339 Ok(id)
340 }
341
342 pub async fn create_subspace(&self) -> anyhow::Result<SubspaceId> {
343 let resp = self.send_req(ReqKind::CreateSubspace).await?;
344 let RespKind::CreateSubspace(id) = resp
345 .result
346 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
347 else {
348 anyhow::bail!(INVALID_RPC_RESP_MSG);
349 };
350 Ok(id)
351 }
352 pub async fn import_subspace_secret(
353 &self,
354 subspace: SubspaceSecretKey,
355 ) -> anyhow::Result<SubspaceId> {
356 let resp = self
357 .send_req(ReqKind::ImportSubspaceSecret(subspace))
358 .await?;
359 let RespKind::ImportSubspaceSecret(id) = resp
360 .result
361 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
362 else {
363 anyhow::bail!(INVALID_RPC_RESP_MSG);
364 };
365 Ok(id)
366 }
367 pub async fn get_subspace_secret(
368 &self,
369 subspace: SubspaceSecretKey,
370 ) -> anyhow::Result<Option<SubspaceSecretKey>> {
371 let resp = self.send_req(ReqKind::GetSubspaceSecret(subspace)).await?;
372 let RespKind::GetSubspaceSecret(id) = resp
373 .result
374 .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
375 else {
376 anyhow::bail!(INVALID_RPC_RESP_MSG);
377 };
378 Ok(id)
379 }
380}
/// Error text used when the server answers a request with a response variant
/// that does not belong to that request.
const INVALID_RPC_RESP_MSG: &str = "Invalid response kind from RPC endpoint";
382
383struct SpawnExecutor;
384
385impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
386where
387 Fut: Future + Send + 'static,
388 Fut::Output: Send + 'static,
389{
390 fn execute(&self, fut: Fut) {
391 tokio::task::spawn(fut);
392 }
393}