futures_loco_protocol/
session.rs

1use std::{
2    fmt::{self, Debug, Display},
3    io, mem,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use flume::{r#async::RecvStream, Receiver, Sender};
9use futures_core::{ready, Future, Stream};
10use futures_io::{AsyncRead, AsyncWrite};
11use loco_protocol::command::Method;
12use nohash_hasher::IntMap;
13
14use crate::{BoxedCommand, LocoClient};
15
16#[derive(Debug, Clone)]
17pub struct LocoSession {
18    sender: Sender<Request>,
19}
20
21impl LocoSession {
22    pub fn new<T>(client: LocoClient<T>) -> (Self, LocoSessionStream<T>) {
23        let (sender, receiver) = flume::bounded(16);
24
25        (Self { sender }, LocoSessionStream::new(receiver, client))
26    }
27
28    pub async fn request(&self, method: Method, data: Vec<u8>) -> Result<CommandRequest, Error> {
29        let (sender, receiver) = oneshot::channel();
30
31        self.sender
32            .send_async(Request {
33                method,
34                data,
35                response_sender: sender,
36            })
37            .await
38            .map_err(|_| Error::SessionClosed)?;
39
40        Ok(CommandRequest { inner: receiver })
41    }
42}
43
44pin_project_lite::pin_project!(
45    pub struct LocoSessionStream<T> {
46        #[pin]
47        request_stream: RecvStream<'static, Request>,
48
49        response_map: IntMap<u32, oneshot::Sender<BoxedCommand>>,
50
51        state: SessionState,
52
53        #[pin]
54        client: LocoClient<T>,
55    }
56);
57
58impl<T> LocoSessionStream<T> {
59    fn new(request_receiver: Receiver<Request>, client: LocoClient<T>) -> Self {
60        Self {
61            request_stream: request_receiver.into_stream(),
62            response_map: IntMap::default(),
63
64            state: SessionState::Pending,
65
66            client,
67        }
68    }
69}
70
71impl<T: AsyncRead + AsyncWrite> Stream for LocoSessionStream<T> {
72    type Item = io::Result<BoxedCommand>;
73
74    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
75        let mut this = self.project();
76
77        loop {
78            match mem::replace(this.state, SessionState::Done) {
79                SessionState::Pending => {
80                    while let Poll::Ready(read) = this.client.as_mut().poll_read(cx) {
81                        let read = read?;
82
83                        if let Some(sender) = this.response_map.remove(&read.header.id) {
84                            let _ = sender.send(read);
85                        } else {
86                            *this.state = SessionState::Pending;
87                            return Poll::Ready(Some(Ok(read)));
88                        }
89                    }
90
91                    let mut receiver_read = false;
92                    while let Poll::Ready(Some(request)) =
93                        this.request_stream.as_mut().poll_next(cx)
94                    {
95                        let id = this.client.as_mut().write(request.method, &request.data);
96                        this.response_map.insert(id, request.response_sender);
97
98                        if !receiver_read {
99                            receiver_read = true;
100                        }
101                    }
102
103                    if receiver_read {
104                        *this.state = SessionState::Write;
105                    } else {
106                        *this.state = SessionState::Pending;
107                        return Poll::Pending;
108                    }
109                }
110
111                SessionState::Write => {
112                    if this.client.as_mut().poll_flush(cx)?.is_ready() {
113                        *this.state = SessionState::Pending;
114                    } else {
115                        *this.state = SessionState::Write;
116                        return Poll::Pending;
117                    };
118                }
119
120                SessionState::Done => return Poll::Ready(None),
121            }
122        }
123    }
124}
125
/// Step of the [`LocoSessionStream`] polling state machine.
#[derive(Clone, Copy, Debug)]
enum SessionState {
    /// Reading incoming commands and accepting newly queued requests.
    Pending,
    /// At least one request was written; the client is flushed before reading resumes.
    Write,
    /// Terminal state entered after an I/O error; polling yields `None`.
    /// It is also the placeholder while `poll_next` handles the current state.
    Done,
}
132
133#[derive(Debug)]
134struct Request {
135    method: Method,
136    data: Vec<u8>,
137    response_sender: oneshot::Sender<BoxedCommand>,
138}
139
140pin_project_lite::pin_project! {
141    #[derive(Debug)]
142    pub struct CommandRequest {
143        #[pin]
144        inner: oneshot::Receiver<BoxedCommand>,
145    }
146}
147
148impl Future for CommandRequest {
149    type Output = Result<BoxedCommand, Error>;
150
151    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        let command = ready!(self
153            .project()
154            .inner
155            .poll(cx)
156            .map_err(|_| Error::SessionClosed))?;
157
158        Poll::Ready(Ok(command))
159    }
160}
161
/// Errors reported by [`LocoSession`] and [`CommandRequest`].
#[derive(Debug)]
pub enum Error {
    /// The session stream is gone, so a request could not be queued or answered.
    SessionClosed,
}

impl Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::SessionClosed => "session closed",
        };

        f.write_str(message)
    }
}