trapeze/client/
mod.rs

1use std::future::Future;
2use std::io::Result as IoResult;
3use std::ops::{Deref, DerefMut};
4use std::sync::Arc;
5
6use futures::FutureExt;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
9use tokio::sync::oneshot;
10use tokio::task::JoinSet;
11
12use crate::context::metadata::Metadata;
13use crate::context::timeout::Timeout;
14use crate::context::Context;
15use crate::io::{MessageIo, SendResult, StreamIo};
16use crate::transport::connect;
17use crate::types::encoding::Encodeable;
18use crate::types::frame::StreamFrame;
19use crate::types::message::Message;
20use crate::{Result, Status};
21
22pub mod request_handlers;
23
/// Boxed, one-shot request callback queued on the client's request channel.
///
/// The client loop invokes it with the `StreamIo` it obtained for the request
/// and with its own `JoinSet`, so the callback can spawn the task that drives
/// that request.
type RequestFnBox = Box<dyn FnOnce(StreamIo, &mut JoinSet<IoResult<()>>) + Send>;
25
/// Cloneable handle used to issue requests over a single connection.
///
/// Clones share the request channel and the background task set (both are
/// reference-counted or channel handles); the `Context` is cloned along with
/// the handle and is reachable through `Deref`/`DerefMut`.
#[derive(Clone)]
pub struct Client {
    /// Sends request callbacks to the `ClientInner::start` loop.
    tx: UnboundedSender<RequestFnBox>,
    /// Holds the task set on which the `ClientInner::start` loop was spawned.
    /// The field is never read; it is only kept so the task set lives as long
    /// as any clone of this client.
    _tasks: Arc<JoinSet<IoResult<()>>>,
    /// Context of this handle; `ClientExt` reads and writes its `metadata`
    /// and `timeout` fields.
    context: Context,
}
32
/// State owned by the background loop that serves one connection.
struct ClientInner {
    /// Id handed to the next stream; starts at 1 and advances by 2 per request.
    next_id: u32,
    /// Message I/O built from the connection; used to create streams and to
    /// receive messages that no stream claimed.
    io: MessageIo,
    /// Task set passed to `MessageIo::new` and to every request callback, and
    /// joined by `start`.
    tasks: JoinSet<IoResult<()>>,
}
38
/// Lets a `Client` be read as its `Context` (e.g. `client.metadata`).
impl Deref for Client {
    type Target = Context;
    /// Borrows the context stored in this client.
    fn deref(&self) -> &Self::Target {
        &self.context
    }
}
45
/// Lets the `Context` of a `Client` be modified through the client itself.
impl DerefMut for Client {
    /// Mutably borrows the context stored in this client.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.context
    }
}
51
52impl ClientInner {
53    pub fn new<C: AsyncRead + AsyncWrite + Send + 'static>(connection: C) -> Self {
54        let mut tasks = JoinSet::<IoResult<()>>::new();
55        let io = MessageIo::new(&mut tasks, connection);
56        let next_id = 1;
57
58        Self { next_id, io, tasks }
59    }
60
61    pub async fn start(&mut self, mut req_rx: UnboundedReceiver<RequestFnBox>) -> IoResult<()> {
62        loop {
63            tokio::select! {
64                Some(res) = self.tasks.join_next() => {
65                    res??;
66                },
67                Some(fcn) = req_rx.recv() => {
68                    let id = self.next_id;
69                    self.next_id += 2;
70                    let Some(stream) = self.io.stream(id) else {
71                        log::error!("Ran out of stream ids");
72                        continue;
73                    };
74                    fcn(stream, &mut self.tasks);
75                },
76                Some((id, _)) = self.io.rx.recv() => {
77                    log::error!("Received a message with an invalid stream id `{id}`");
78                },
79                else => {
80                    // no more messages to read, and no more taks to process
81                    // we are done
82                    break;
83                },
84            }
85        }
86        Ok(())
87    }
88}
89
90impl Client {
91    pub fn new<C: AsyncRead + AsyncWrite + Send + 'static>(connection: C) -> Self {
92        let (tx, rx) = unbounded_channel();
93        let mut tasks = JoinSet::<IoResult<()>>::new();
94        let context = Context::default();
95
96        let mut inner = ClientInner::new(connection);
97        tasks.spawn(async move { inner.start(rx).await });
98
99        let tasks = Arc::new(tasks);
100
101        Self {
102            tx,
103            _tasks: tasks,
104            context,
105        }
106    }
107
108    pub async fn connect(address: impl AsRef<str>) -> IoResult<Self> {
109        let conn = connect(address).await?;
110        Ok(Self::new(conn))
111    }
112
113    fn spawn_stream<Fut: Future<Output = Result<()>> + Send, Msg: Message + Encodeable>(
114        &self,
115        frame: impl Into<StreamFrame<Msg>> + Send + 'static,
116        f: impl FnOnce(SendResult, StreamIo) -> Fut + Send + 'static,
117    ) -> impl Future<Output = Result<()>> + Send {
118        let (tx, rx) = oneshot::channel();
119        let _ = self.tx.send(Box::new(move |stream, tasks| {
120            let res = stream.tx.send(frame);
121            tasks.spawn(async move {
122                let _ = tx.send(f(res, stream).await);
123                Ok(())
124            });
125        }));
126
127        async move {
128            let Ok(result) = rx.await else {
129                return Err(Status::channel_closed());
130            };
131            result
132        }
133        .fuse()
134    }
135}
136
137pub trait ClientExt: Clone + Deref<Target = Context> + DerefMut {
138    #[must_use]
139    fn with_metadata(&self, metadata: impl Into<Metadata>) -> Self {
140        let mut this = self.clone();
141        this.metadata = metadata.into();
142        this
143    }
144
145    #[must_use]
146    fn with_timeout(&self, timeout: impl Into<Timeout>) -> Self {
147        let mut this = self.clone();
148        this.timeout = timeout.into();
149        this
150    }
151
152    #[must_use]
153    fn with_context(&self, context: impl Into<Context>) -> Self {
154        let mut this = self.clone();
155        *this = context.into();
156        this
157    }
158}
159
// Blanket implementation: every cloneable type that mutably dereferences to a
// `Context` gets the `ClientExt` helpers without writing its own impl.
impl<T: Clone + Deref<Target = Context> + DerefMut> ClientExt for T {}
161
/// Identity conversion, so a `Client` can be passed where an
/// `AsRef<Client>` is expected.
impl AsRef<Client> for Client {
    /// Returns `self` unchanged.
    fn as_ref(&self) -> &Client {
        self
    }
}