Skip to main content

aeron_rpc/
client.rs

1use std::{
2    sync::{Arc, atomic::AtomicU32},
3    time::Duration,
4};
5
6use tokio::sync::oneshot;
7
8use crate::{
9    FromBytes, ToBusinessId,
10    err::{ReceiveError, SendError},
11    protocol::{Client2MultiplexerSender, Request, Response, SendPacket},
12};
13
14pub struct RpcClient {
15    sender: Client2MultiplexerSender,
16    pub request_id: Arc<AtomicU32>,
17}
18
19impl RpcClient {
20    pub fn new(sender: Client2MultiplexerSender) -> Self {
21        Self {
22            sender,
23            request_id: Arc::new(AtomicU32::new(1)),
24        }
25    }
26
27    fn fetch_request_id(&self) -> u64 {
28        self.request_id
29            .fetch_add(1, std::sync::atomic::Ordering::Relaxed) as u64
30    }
31
32    /// Send a request without waiting for a response.
33    pub async fn report(
34        &self,
35        business_id: &impl ToBusinessId,
36        data: impl Into<Vec<u8>>,
37    ) -> Result<(), SendError> {
38        let request_id = self.fetch_request_id();
39        let req = Request::new(request_id, business_id.to_business_id(), data.into());
40        let (tx, rx) = oneshot::channel();
41        self.sender
42            .send(SendPacket {
43                request: req,
44                resp_sender: None,
45                timeout: Duration::from_secs(60),
46                send_signal: tx,
47            })
48            .await
49            .expect("Channel closed unexpectedly");
50
51        rx.await.expect("Channel closed unexpectedly")?;
52
53        Ok(())
54    }
55
56    /// Send a request and wait for a response.
57    pub async fn send<T>(
58        &self,
59        business_id: &impl ToBusinessId,
60        data: impl Into<Vec<u8>>,
61        timeout: Duration,
62    ) -> Result<T, SendError>
63    where
64        T: FromBytes,
65    {
66        let request_id = self.fetch_request_id();
67        let req = Request::new(request_id, business_id.to_business_id(), data.into());
68
69        let (tx, mut rx) = tokio::sync::mpsc::channel(1 << 10);
70
71        let (signal_tx, signal_rx) = oneshot::channel();
72        self.sender
73            .send(SendPacket {
74                request: req,
75                resp_sender: Some(tx),
76                timeout,
77                send_signal: signal_tx,
78            })
79            .await
80            .expect("Channel closed unexpectedly");
81
82        signal_rx.await.expect("Channel closed unexpectedly")?;
83
84        match tokio::time::timeout(timeout, rx.recv()).await {
85            Ok(resp) => {
86                let resp = resp.expect("Channel closed unexpectedly");
87                match resp {
88                    Ok(r) => Ok(T::from_bytes(r.data)?),
89                    Err(_) => Err(SendError::Timeout),
90                }
91            }
92            Err(_) => Err(SendError::Timeout),
93        }
94    }
95
96    /// Send a request and receive a stream of responses.
97    pub async fn send_stream<T>(
98        &self,
99        business_id: &impl ToBusinessId,
100        data: impl Into<Vec<u8>>,
101    ) -> Result<Stream<T>, SendError> {
102        let request_id = self.fetch_request_id();
103        let req = Request::new(request_id, business_id.to_business_id(), data.into());
104
105        let (tx, rx) = tokio::sync::mpsc::channel(1 << 10);
106
107        let (signal_tx, signal_rx) = oneshot::channel();
108
109        self.sender
110            .send(SendPacket {
111                request: req,
112                timeout: Duration::from_secs(60),
113                resp_sender: Some(tx),
114                send_signal: signal_tx,
115            })
116            .await
117            .expect("Channel closed unexpectedly");
118
119        signal_rx.await.expect("Channel closed unexpectedly")?;
120
121        Ok(Stream::new(rx))
122    }
123}
124
125pub struct Stream<T> {
126    rx: tokio::sync::mpsc::Receiver<Result<Response, ReceiveError>>,
127    _marker: std::marker::PhantomData<T>,
128}
129
130impl<T> Stream<T> {
131    pub fn new(rx: tokio::sync::mpsc::Receiver<Result<Response, ReceiveError>>) -> Self {
132        Self {
133            rx,
134            _marker: std::marker::PhantomData,
135        }
136    }
137}
138
139impl<T> Stream<T>
140where
141    T: FromBytes,
142{
143    pub async fn next(&mut self) -> Option<Result<T, ReceiveError>> {
144        self.rx.recv().await.map(|r| {
145            log::trace!("Received response: {:?}", r);
146            Ok(T::from_bytes(r?.data)?)
147        })
148    }
149}