leaf_rpc_client/
lib.rs

1use std::{
2    collections::HashMap,
3    convert::Infallible,
4    future::Future,
5    sync::{
6        atomic::{AtomicU64, Ordering::SeqCst},
7        Arc,
8    },
9};
10
11use fastwebsockets::{FragmentCollectorRead, Frame};
12use futures::{future::Either, pin_mut, StreamExt};
13use hyper::{
14    header::{CONNECTION, UPGRADE},
15    Request,
16};
17use leaf_rpc_proto::{Req, ReqKind, Resp, RespKind};
18use tokio::{
19    net::TcpStream,
20    sync::{mpsc, oneshot, Mutex},
21};
22
23pub use hyper::Uri;
24pub use leaf_protocol;
25
26use leaf_protocol::prelude::*;
27use tokio_stream::wrappers::ReceiverStream;
28
29#[derive(Clone)]
30pub struct RpcClient {
31    index: Arc<AtomicU64>,
32    frame_writer: mpsc::Sender<Frame<'static>>,
33    pending_reqs: Arc<Mutex<HashMap<u64, oneshot::Sender<Resp>>>>,
34}
35
36// TODO: Implement graceful shutdown of RPC client.
37impl Drop for RpcClient {
38    fn drop(&mut self) {
39        tracing::warn!("TODO: implement graceful shutdown of RPC client.");
40    }
41}
42
43impl RpcClient {
44    pub async fn connect(uri: Uri, auth_token: Option<&str>) -> anyhow::Result<Self> {
45        let host = uri.host().unwrap();
46        let port = uri.port().unwrap();
47        let socket = format!("{host}:{port}");
48        let stream = TcpStream::connect(socket).await?;
49
50        let req = Request::builder()
51            .method("GET")
52            .uri(&uri)
53            .header("Host", host)
54            .header(UPGRADE, "websocket")
55            .header(CONNECTION, "upgrade")
56            .header(
57                "Sec-Websocket-Key",
58                fastwebsockets::handshake::generate_key(),
59            )
60            .header("Sec-Websocket-Version", "13")
61            .body(String::new())?;
62
63        let pending_reqs = Arc::new(Mutex::new(HashMap::<u64, oneshot::Sender<Resp>>::default()));
64        let pending_reqs_ = pending_reqs.clone();
65
66        let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await?;
67        let (ws_read, mut ws_write) = ws.split(tokio::io::split);
68        let mut ws_read = FragmentCollectorRead::new(ws_read);
69
70        let (client_frame_send, client_frame_recv) = mpsc::channel(10);
71
72        tokio::spawn(async move {
73            let read_frame_from_server = async_stream::stream! {
74                loop {
75                    yield ws_read.read_frame::<_, Infallible>(&mut |_| async { panic!("obligated send not implemented") }).await;
76                }
77            }
78            .map(Either::Left);
79            let recv_frame_to_send = ReceiverStream::new(client_frame_recv).map(Either::Right);
80
81            let stream = futures::stream::select(read_frame_from_server, recv_frame_to_send);
82            pin_mut!(stream);
83
84            loop {
85                let Some(event) = stream.next().await else {
86                    break;
87                };
88
89                match event {
90                    Either::Left(frame_from_server) => match frame_from_server {
91                        Ok(frame) => {
92                            if frame.opcode == fastwebsockets::OpCode::Binary {
93                                let mut data = &frame.payload[..];
94                                let resp = Resp::deserialize(&mut data);
95                                match resp {
96                                    Ok(resp) => {
97                                        let mut pending_reqs = pending_reqs_.lock().await;
98                                        let Some(sender) = pending_reqs.remove(&resp.id) else {
99                                            tracing::warn!(
100                                                "Got response for request that is not pending"
101                                            );
102                                            continue;
103                                        };
104                                        sender.send(resp).ok();
105                                    }
106                                    Err(e) => tracing::error!(
107                                        "Error deserializing response from server: {e}"
108                                    ),
109                                }
110                            }
111                        }
112                        Err(e) => tracing::error!("Error reading message from server: {e}"),
113                    },
114                    Either::Right(frame_to_send) => {
115                        if let Err(e) = ws_write.write_frame(frame_to_send).await {
116                            tracing::warn!("Could not send request to server: {e}");
117                        }
118                    }
119                }
120            }
121        });
122
123        let client = RpcClient {
124            index: Arc::new(0.into()),
125            frame_writer: client_frame_send,
126            pending_reqs,
127        };
128
129        if let Some(auth_token) = auth_token {
130            let resp = client
131                .send_req(ReqKind::Authenticate(auth_token.into()))
132                .await?;
133            match resp.result {
134                Ok(RespKind::Authenticated) => (),
135                Ok(_) => anyhow::bail!("Unexpected response when authenticating"),
136                Err(e) => anyhow::bail!("Authentication error: {e}"),
137            }
138        }
139
140        Ok(client)
141    }
142
143    async fn send_req(&self, kind: ReqKind) -> anyhow::Result<Resp> {
144        let id = self.index.fetch_add(1, SeqCst);
145        let req = Req { id, kind };
146
147        let mut req_bytes = Vec::new();
148        req.serialize(&mut req_bytes)?;
149
150        let (resp_sender, resp_receiver) = oneshot::channel();
151        {
152            let mut pending_reqs = self.pending_reqs.lock().await;
153            pending_reqs.insert(id, resp_sender);
154        }
155
156        self.frame_writer
157            .send(Frame::binary(fastwebsockets::Payload::Owned(req_bytes)))
158            .await?;
159
160        let resp = resp_receiver.await?;
161        assert_eq!(resp.id, id, "Invalid RPC ID in response");
162
163        Ok(resp)
164    }
165
166    pub async fn read_entity<L: Into<ExactLink>>(
167        &self,
168        link: L,
169    ) -> anyhow::Result<Option<(Digest, Entity)>> {
170        let link = link.into();
171        let resp = self.send_req(ReqKind::ReadEntity(link)).await?;
172        let RespKind::ReadEntity(entity) = resp
173            .result
174            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
175        else {
176            anyhow::bail!(INVALID_RPC_RESP_MSG);
177        };
178        Ok(entity)
179    }
180
181    pub async fn del_entity<L: Into<ExactLink>>(&self, link: L) -> anyhow::Result<()> {
182        let link = link.into();
183        let resp = self.send_req(ReqKind::DelEntity(link)).await?;
184        let RespKind::DelEntity = resp
185            .result
186            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
187        else {
188            anyhow::bail!(INVALID_RPC_RESP_MSG);
189        };
190        Ok(())
191    }
192
193    pub async fn list_entities<L: Into<ExactLink>>(
194        &self,
195        link: L,
196    ) -> anyhow::Result<Vec<ExactLink>> {
197        let link = link.into();
198        let resp = self.send_req(ReqKind::ListEntities(link)).await?;
199        let RespKind::ListEntities(entities) = resp
200            .result
201            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
202        else {
203            anyhow::bail!(INVALID_RPC_RESP_MSG);
204        };
205        Ok(entities)
206    }
207
208    // TODO: Support Operating on Multiple Components at a Time.
209    pub async fn del_components<C: Component, L: Into<ExactLink>>(
210        &self,
211        link: L,
212    ) -> anyhow::Result<Option<Digest>> {
213        let link = link.into();
214
215        let resp = self
216            .send_req(ReqKind::DelComponentsBySchema {
217                link,
218                schemas: vec![C::schema_id()],
219            })
220            .await?;
221        let RespKind::DelComponentBySchema(new_digest) = resp
222            .result
223            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
224        else {
225            anyhow::bail!(INVALID_RPC_RESP_MSG);
226        };
227        Ok(new_digest)
228    }
229
230    // TODO: Support Operating on Multiple Components at a Time.
231    pub async fn add_component<C: Component, L: Into<ExactLink>>(
232        &self,
233        link: L,
234        component: C,
235        replace_existing: bool,
236    ) -> anyhow::Result<Digest> {
237        let link = link.into();
238        let component_data = component.make_data()?;
239
240        let resp = self
241            .send_req(ReqKind::AddComponents {
242                link,
243                components: vec![component_data],
244                replace_existing,
245            })
246            .await?;
247        let RespKind::AddComponents(entity_id) = resp
248            .result
249            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
250        else {
251            anyhow::bail!(INVALID_RPC_RESP_MSG);
252        };
253        Ok(entity_id)
254    }
255
256    // TODO: implement way to get multiple components at a time.
257    pub async fn get_components<C: Component, L: Into<ExactLink>>(
258        &self,
259        _link: L,
260    ) -> anyhow::Result<Option<(Digest, Vec<C>)>> {
261        unimplemented!("get_components() needs a better way to get multiple components at a time.");
262        // let link = link.into();
263        // let schema = C::schema_id();
264
265        // let resp = self
266        //     .send_req(ReqKind::GetComponentsBySchema {
267        //         link,
268        //         schemas: vec![schema],
269        //     })
270        //     .await?;
271        // let RespKind::GetComponentBySchema(components) = resp
272        //     .result
273        //     .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
274        // else {
275        //     anyhow::bail!(INVALID_RPC_RESP_MSG);
276        // };
277        // let components = components.map(|data| {
278        //     let components = data
279        //         .components
280        //         .into_iter()
281        //         .map(|(schema, components_data)| {
282        //             let ComponentKind::Unencrypted(comp) =
283        //                 ComponentKind::deserialize(&mut &data[..])?
284        //             else {
285        //                 anyhow::bail!("Encrypted components not supported.");
286        //             };
287        //             assert_eq!(comp.schema, schema);
288        //             let data = C::deserialize(&mut &comp.data[..])?;
289        //             Ok::<_, anyhow::Error>(data)
290        //         })
291        //         .collect::<Result<Vec<C>, _>>();
292
293        //     (digest, components)
294        // });
295        // match components {
296        //     Some((digest, components)) => Ok(Some((digest, components?))),
297        //     None => Ok(None),
298        // }
299    }
300
301    pub async fn create_namespace(&self) -> anyhow::Result<NamespaceId> {
302        let resp = self.send_req(ReqKind::CreateNamespace).await?;
303        let RespKind::CreateNamespace(id) = resp
304            .result
305            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
306        else {
307            anyhow::bail!(INVALID_RPC_RESP_MSG);
308        };
309        Ok(id)
310    }
311    pub async fn import_namespace_secret(
312        &self,
313        namespace: NamespaceSecretKey,
314    ) -> anyhow::Result<NamespaceId> {
315        let resp = self
316            .send_req(ReqKind::ImportNamespaceSecret(namespace))
317            .await?;
318        let RespKind::ImportNamespaceSecret(id) = resp
319            .result
320            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
321        else {
322            anyhow::bail!(INVALID_RPC_RESP_MSG);
323        };
324        Ok(id)
325    }
326    pub async fn get_namespace_secret(
327        &self,
328        namespace: NamespaceSecretKey,
329    ) -> anyhow::Result<Option<NamespaceSecretKey>> {
330        let resp = self
331            .send_req(ReqKind::GetNamespaceSecret(namespace))
332            .await?;
333        let RespKind::GetNamespaceSecret(id) = resp
334            .result
335            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
336        else {
337            anyhow::bail!(INVALID_RPC_RESP_MSG);
338        };
339        Ok(id)
340    }
341
342    pub async fn create_subspace(&self) -> anyhow::Result<SubspaceId> {
343        let resp = self.send_req(ReqKind::CreateSubspace).await?;
344        let RespKind::CreateSubspace(id) = resp
345            .result
346            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
347        else {
348            anyhow::bail!(INVALID_RPC_RESP_MSG);
349        };
350        Ok(id)
351    }
352    pub async fn import_subspace_secret(
353        &self,
354        subspace: SubspaceSecretKey,
355    ) -> anyhow::Result<SubspaceId> {
356        let resp = self
357            .send_req(ReqKind::ImportSubspaceSecret(subspace))
358            .await?;
359        let RespKind::ImportSubspaceSecret(id) = resp
360            .result
361            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
362        else {
363            anyhow::bail!(INVALID_RPC_RESP_MSG);
364        };
365        Ok(id)
366    }
367    pub async fn get_subspace_secret(
368        &self,
369        subspace: SubspaceSecretKey,
370    ) -> anyhow::Result<Option<SubspaceSecretKey>> {
371        let resp = self.send_req(ReqKind::GetSubspaceSecret(subspace)).await?;
372        let RespKind::GetSubspaceSecret(id) = resp
373            .result
374            .map_err(|s| anyhow::format_err!("Error from Leaf RPC endpoint: {s}"))?
375        else {
376            anyhow::bail!(INVALID_RPC_RESP_MSG);
377        };
378        Ok(id)
379    }
380}
/// Error message returned when the server answers a request with a response kind that does not
/// correspond to the request that was sent.
const INVALID_RPC_RESP_MSG: &str = "Invalid response kind from RPC endpoint";
382
383struct SpawnExecutor;
384
385impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
386where
387    Fut: Future + Send + 'static,
388    Fut::Output: Send + 'static,
389{
390    fn execute(&self, fut: Fut) {
391        tokio::task::spawn(fut);
392    }
393}