1use std::{
38 collections::{HashMap, VecDeque},
39 io,
40 marker::PhantomData,
41 pin::Pin,
42 task::{Context, Poll},
43};
44
45use futures::{
46 sink::SinkExt,
47 stream::{Stream, StreamExt},
48};
49use serde::{Deserialize, Serialize};
50use tokio_util::codec::Framed;
51
52use stream_codec::StreamCodec;
53
54mod stream_codec;
55
56pub struct Client<T> {
58 inner: Framed<T, StreamCodec>,
59 id: usize,
60}
61
62impl<T> Client<T>
63where
64 T: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin,
65{
66 pub fn new(io: T) -> Client<T> {
68 let inner = Framed::new(io, StreamCodec::stream_incoming());
69 Client { inner, id: 0 }
70 }
71
72 pub async fn subscribe<F: for<'de> serde::de::Deserialize<'de>>(
74 mut self,
75 name: &str,
76 ) -> io::Result<Handle<T, F>> {
77 let mut topic_list = HashMap::default();
78 let mut pending_recv = VecDeque::new();
79
80 subscribe(
81 &mut self.inner,
82 self.id,
83 name,
84 &mut topic_list,
85 &mut pending_recv,
86 )
87 .await?;
88 self.id = self.id.wrapping_add(1);
89
90 Ok(Handle {
91 inner: self.inner,
92 topic_list,
93 output: PhantomData,
94 rpc_id: self.id,
95 pending_recv,
96 })
97 }
98
99 pub async fn subscribe_list<
101 F: for<'de> serde::de::Deserialize<'de>,
102 I: Iterator<Item = H>,
103 H: AsRef<str>,
104 >(
105 mut self,
106 name_list: I,
107 ) -> io::Result<Handle<T, F>> {
108 let mut topic_list = HashMap::default();
109 let mut pending_recv = VecDeque::new();
110
111 for topic in name_list {
112 subscribe(
113 &mut self.inner,
114 self.id,
115 topic,
116 &mut topic_list,
117 &mut pending_recv,
118 )
119 .await?;
120 self.id = self.id.wrapping_add(1);
121 }
122
123 Ok(Handle {
124 inner: self.inner,
125 topic_list,
126 output: PhantomData,
127 rpc_id: self.id,
128 pending_recv,
129 })
130 }
131}
132
133pub struct Handle<T, F> {
135 inner: Framed<T, StreamCodec>,
136 topic_list: HashMap<String, String>,
137 output: PhantomData<F>,
138 rpc_id: usize,
139 pending_recv: VecDeque<bytes::BytesMut>,
140}
141
142impl<T, F> Handle<T, F>
143where
144 T: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin,
145{
146 pub fn ids(&self) -> impl Iterator<Item = &String> {
148 self.topic_list.keys()
149 }
150
151 pub fn topics(&self) -> impl Iterator<Item = &String> {
153 self.topic_list.values()
154 }
155
156 #[allow(clippy::result_large_err)]
158 pub fn try_into(self) -> Result<Client<T>, Self> {
159 if self.topic_list.is_empty() {
160 Ok(Client {
161 inner: self.inner,
162 id: self.rpc_id,
163 })
164 } else {
165 Err(self)
166 }
167 }
168
169 pub async fn subscribe(mut self, topic: &str) -> io::Result<Self> {
170 if self.topic_list.iter().any(|(_, v)| *v == topic) {
171 return Ok(self);
172 }
173
174 subscribe(
175 &mut self.inner,
176 self.rpc_id,
177 topic,
178 &mut self.topic_list,
179 &mut self.pending_recv,
180 )
181 .await?;
182 self.rpc_id = self.rpc_id.wrapping_add(1);
183
184 Ok(self)
185 }
186
187 pub async fn unsubscribe(&mut self, topic: &str) -> io::Result<()> {
189 let id = {
190 let id = self
191 .topic_list
192 .iter()
193 .find_map(|(k, v)| if v == topic { Some(k) } else { None })
194 .cloned();
195 if id.is_none() {
196 return Ok(());
197 }
198 id.unwrap()
199 };
200 let req_json = format!(
201 r#"{{"id": {}, "jsonrpc": "2.0", "method": "unsubscribe", "params": ["{}"]}}"#,
202 self.rpc_id, id
203 );
204 self.rpc_id = self.rpc_id.wrapping_add(1);
205
206 self.inner.send(req_json).await?;
207
208 let output = loop {
209 let resp = self.inner.next().await;
210
211 let resp = resp.ok_or_else::<io::Error, _>(|| io::ErrorKind::BrokenPipe.into())??;
212
213 match serde_json::from_slice::<jsonrpc_core::response::Output>(&resp) {
216 Ok(output) => break output,
217 Err(_) => self.pending_recv.push_back(resp),
218 }
219 };
220
221 match output {
222 jsonrpc_core::response::Output::Success(_) => {
223 self.topic_list.remove(&id);
224 Ok(())
225 }
226 jsonrpc_core::response::Output::Failure(e) => {
227 Err(io::Error::new(io::ErrorKind::InvalidData, e.error))
228 }
229 }
230 }
231
232 pub async fn unsubscribe_all(mut self) -> io::Result<Client<T>> {
234 for topic in self.topic_list.clone().values() {
235 self.unsubscribe(topic).await?
236 }
237 Ok(Client {
238 inner: self.inner,
239 id: self.rpc_id,
240 })
241 }
242}
243
244impl<T, F> Stream for Handle<T, F>
245where
246 F: for<'de> serde::de::Deserialize<'de> + Unpin,
247 T: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin,
248{
249 type Item = io::Result<(String, F)>;
250
251 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
252 let parse = |data: bytes::BytesMut,
253 topic_list: &HashMap<String, String>|
254 -> io::Result<(String, F)> {
255 let output = serde_json::from_slice::<jsonrpc_core::request::Notification>(&data)
256 .expect("must parse to notification");
257 let message = output
258 .params
259 .parse::<Message>()
260 .expect("must parse to message");
261 serde_json::from_str::<F>(&message.result)
262 .map(|r| (topic_list.get(&message.subscription).cloned().unwrap(), r))
263 .map_err(|_| io::ErrorKind::InvalidData.into())
264 };
265
266 if let Some(data) = self.pending_recv.pop_front() {
267 return Poll::Ready(Some(parse(data, &self.topic_list)));
268 }
269 match self.inner.poll_next_unpin(cx) {
270 Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(parse(frame, &self.topic_list))),
271 Poll::Ready(None) => Poll::Ready(None),
272 Poll::Pending => Poll::Pending,
273 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
274 }
275 }
276}
277
278#[derive(Deserialize, Serialize, Debug)]
279struct Message {
280 result: String,
281 subscription: String,
282}
283
284async fn subscribe<T: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin>(
285 io: &mut Framed<T, StreamCodec>,
286 id: usize,
287 topic: impl AsRef<str>,
288 topic_list: &mut HashMap<String, String>,
289 pending_recv: &mut VecDeque<bytes::BytesMut>,
290) -> io::Result<()> {
291 let req_json = format!(
302 r#"{{"id": {}, "jsonrpc": "2.0", "method": "subscribe", "params": ["{}"]}}"#,
303 id,
304 topic.as_ref()
305 );
306
307 io.send(req_json).await?;
308
309 loop {
311 let resp = io.next().await;
312 let resp = resp.ok_or_else::<io::Error, _>(|| io::ErrorKind::BrokenPipe.into())??;
313 match serde_json::from_slice::<jsonrpc_core::response::Output>(&resp) {
314 Ok(output) => match output {
315 jsonrpc_core::response::Output::Success(success) => {
316 let res = serde_json::from_value::<String>(success.result).unwrap();
317 topic_list.insert(res, topic.as_ref().to_owned());
318 break Ok(());
319 }
320 jsonrpc_core::response::Output::Failure(e) => {
321 return Err(io::Error::new(io::ErrorKind::InvalidData, e.error))
322 }
323 },
324 Err(_) => pending_recv.push_back(resp),
326 }
327 }
328}