futures_loco_protocol/
session.rs1use std::{
2 fmt::{self, Debug, Display},
3 io, mem,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use flume::{r#async::RecvStream, Receiver, Sender};
9use futures_core::{ready, Future, Stream};
10use futures_io::{AsyncRead, AsyncWrite};
11use loco_protocol::command::Method;
12use nohash_hasher::IntMap;
13
14use crate::{BoxedCommand, LocoClient};
15
16#[derive(Debug, Clone)]
17pub struct LocoSession {
18 sender: Sender<Request>,
19}
20
21impl LocoSession {
22 pub fn new<T>(client: LocoClient<T>) -> (Self, LocoSessionStream<T>) {
23 let (sender, receiver) = flume::bounded(16);
24
25 (Self { sender }, LocoSessionStream::new(receiver, client))
26 }
27
28 pub async fn request(&self, method: Method, data: Vec<u8>) -> Result<CommandRequest, Error> {
29 let (sender, receiver) = oneshot::channel();
30
31 self.sender
32 .send_async(Request {
33 method,
34 data,
35 response_sender: sender,
36 })
37 .await
38 .map_err(|_| Error::SessionClosed)?;
39
40 Ok(CommandRequest { inner: receiver })
41 }
42}
43
44pin_project_lite::pin_project!(
45 pub struct LocoSessionStream<T> {
46 #[pin]
47 request_stream: RecvStream<'static, Request>,
48
49 response_map: IntMap<u32, oneshot::Sender<BoxedCommand>>,
50
51 state: SessionState,
52
53 #[pin]
54 client: LocoClient<T>,
55 }
56);
57
58impl<T> LocoSessionStream<T> {
59 fn new(request_receiver: Receiver<Request>, client: LocoClient<T>) -> Self {
60 Self {
61 request_stream: request_receiver.into_stream(),
62 response_map: IntMap::default(),
63
64 state: SessionState::Pending,
65
66 client,
67 }
68 }
69}
70
71impl<T: AsyncRead + AsyncWrite> Stream for LocoSessionStream<T> {
72 type Item = io::Result<BoxedCommand>;
73
74 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
75 let mut this = self.project();
76
77 loop {
78 match mem::replace(this.state, SessionState::Done) {
79 SessionState::Pending => {
80 while let Poll::Ready(read) = this.client.as_mut().poll_read(cx) {
81 let read = read?;
82
83 if let Some(sender) = this.response_map.remove(&read.header.id) {
84 let _ = sender.send(read);
85 } else {
86 *this.state = SessionState::Pending;
87 return Poll::Ready(Some(Ok(read)));
88 }
89 }
90
91 let mut receiver_read = false;
92 while let Poll::Ready(Some(request)) =
93 this.request_stream.as_mut().poll_next(cx)
94 {
95 let id = this.client.as_mut().write(request.method, &request.data);
96 this.response_map.insert(id, request.response_sender);
97
98 if !receiver_read {
99 receiver_read = true;
100 }
101 }
102
103 if receiver_read {
104 *this.state = SessionState::Write;
105 } else {
106 *this.state = SessionState::Pending;
107 return Poll::Pending;
108 }
109 }
110
111 SessionState::Write => {
112 if this.client.as_mut().poll_flush(cx)?.is_ready() {
113 *this.state = SessionState::Pending;
114 } else {
115 *this.state = SessionState::Write;
116 return Poll::Pending;
117 };
118 }
119
120 SessionState::Done => return Poll::Ready(None),
121 }
122 }
123 }
124}
125
/// Step of the [`LocoSessionStream`] polling state machine.
#[derive(Clone, Copy, Debug)]
enum SessionState {
    /// Reading commands and accepting newly queued requests.
    Pending,
    /// Requests were written; waiting for the client to flush them.
    Write,
    /// Terminal state, also held while a step runs; polling here yields `None`.
    Done,
}
132
133#[derive(Debug)]
134struct Request {
135 method: Method,
136 data: Vec<u8>,
137 response_sender: oneshot::Sender<BoxedCommand>,
138}
139
140pin_project_lite::pin_project! {
141 #[derive(Debug)]
142 pub struct CommandRequest {
143 #[pin]
144 inner: oneshot::Receiver<BoxedCommand>,
145 }
146}
147
148impl Future for CommandRequest {
149 type Output = Result<BoxedCommand, Error>;
150
151 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152 let command = ready!(self
153 .project()
154 .inner
155 .poll(cx)
156 .map_err(|_| Error::SessionClosed))?;
157
158 Poll::Ready(Ok(command))
159 }
160}
161
/// Error returned when a request cannot be completed through its session.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The session stream is no longer available, so the request could not be queued or
    /// its response will not be delivered.
    SessionClosed,
}
166
167impl Display for Error {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 f.write_str("session closed")
170 }
171}