1use std::{
2 collections::HashMap,
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use byteorder::{BigEndian, ByteOrder};
9use bytes::Bytes;
10use futures_util::FutureExt;
11use protobuf::Message;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::{Error, packet::PacketType, protocol, util::SeqGenerator};
15
16mod types;
17pub use self::types::*;
18
19mod sender;
20pub use self::sender::MercurySender;
21
// Declares `MercuryManager` and its inner state `MercuryManagerInner` through the
// crate's `component!` macro. The inner state is what the `lock` closures in this
// file receive.
component! {
    MercuryManager : MercuryManagerInner {
        // Source of the sequence numbers handed out by `next_seq`.
        sequence: SeqGenerator<u64> = SeqGenerator::new(0),
        // Requests and partially received responses, keyed by the raw sequence bytes.
        pending: HashMap<Vec<u8>, MercuryPending> = HashMap::new(),
        // (URI prefix, channel) pairs; an event whose encoded URI starts with the
        // prefix is forwarded to the channel.
        subscriptions: Vec<(String, mpsc::UnboundedSender<MercuryResponse>)> = Vec::new(),
        // Set by `shutdown`; `request`, `subscribe` and `listen_for` skip
        // registration while it is set.
        invalid: bool = false,
    }
}
30
/// State kept for one Mercury sequence number while its response is being received.
pub struct MercuryPending {
    /// Parts received in full so far, in arrival order.
    parts: Vec<Vec<u8>>,
    /// Beginning of a part that `dispatch` held back; it is prepended to the next
    /// part received for the same sequence.
    partial: Option<Vec<u8>>,
    /// Channel used to hand the finished response to the requester. `None` for
    /// entries that `dispatch` creates for events without a matching request.
    callback: Option<oneshot::Sender<Result<MercuryResponse, Error>>>,
}
36
37pub struct MercuryFuture<T> {
38 receiver: oneshot::Receiver<Result<T, Error>>,
39}
40
41impl<T> Future for MercuryFuture<T> {
42 type Output = Result<T, Error>;
43
44 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45 self.receiver.poll_unpin(cx)?
46 }
47}
48
49impl MercuryManager {
50 fn next_seq(&self) -> Vec<u8> {
51 let mut seq = vec![0u8; 8];
52 BigEndian::write_u64(&mut seq, self.lock(|inner| inner.sequence.get()));
53 seq
54 }
55
56 fn request(&self, req: MercuryRequest) -> Result<MercuryFuture<MercuryResponse>, Error> {
57 let (tx, rx) = oneshot::channel();
58
59 let pending = MercuryPending {
60 parts: Vec::new(),
61 partial: None,
62 callback: Some(tx),
63 };
64
65 let seq = self.next_seq();
66 self.lock(|inner| {
67 if !inner.invalid {
68 inner.pending.insert(seq.clone(), pending);
69 }
70 });
71
72 let cmd = req.method.command();
73 let data = req.encode(&seq)?;
74
75 self.session().send_packet(cmd, data)?;
76 Ok(MercuryFuture { receiver: rx })
77 }
78
79 pub fn get<T: Into<String>>(&self, uri: T) -> Result<MercuryFuture<MercuryResponse>, Error> {
80 self.request(MercuryRequest {
81 method: MercuryMethod::Get,
82 uri: uri.into(),
83 content_type: None,
84 payload: Vec::new(),
85 })
86 }
87
88 pub fn send<T: Into<String>>(
89 &self,
90 uri: T,
91 data: Vec<u8>,
92 ) -> Result<MercuryFuture<MercuryResponse>, Error> {
93 self.request(MercuryRequest {
94 method: MercuryMethod::Send,
95 uri: uri.into(),
96 content_type: None,
97 payload: vec![data],
98 })
99 }
100
101 pub fn sender<T: Into<String>>(&self, uri: T) -> MercurySender {
102 MercurySender::new(self.clone(), uri.into())
103 }
104
105 pub fn subscribe<T: Into<String>>(
106 &self,
107 uri: T,
108 ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<MercuryResponse>, Error>> + 'static
109 {
110 let uri = uri.into();
111 let request = self.request(MercuryRequest {
112 method: MercuryMethod::Sub,
113 uri: uri.clone(),
114 content_type: None,
115 payload: Vec::new(),
116 });
117
118 let manager = self.clone();
119 async move {
120 let response = request?.await?;
121
122 let (tx, rx) = mpsc::unbounded_channel();
123
124 manager.lock(move |inner| {
125 if !inner.invalid {
126 debug!("subscribed uri={} count={}", uri, response.payload.len());
127 if !response.payload.is_empty() {
128 for sub in response.payload {
130 match protocol::pubsub::Subscription::parse_from_bytes(&sub) {
131 Ok(mut sub) => {
132 let sub_uri = sub.take_uri();
133
134 debug!("subscribed sub_uri={sub_uri}");
135
136 inner.subscriptions.push((sub_uri, tx.clone()));
137 }
138 Err(e) => {
139 error!("could not subscribe to {uri}: {e}");
140 }
141 }
142 }
143 } else {
144 inner.subscriptions.push((uri, tx));
146 }
147 }
148 });
149
150 Ok(rx)
151 }
152 }
153
154 pub fn listen_for<T: Into<String>>(
155 &self,
156 uri: T,
157 ) -> impl Future<Output = mpsc::UnboundedReceiver<MercuryResponse>> + 'static {
158 let uri = uri.into();
159
160 let manager = self.clone();
161 async move {
162 let (tx, rx) = mpsc::unbounded_channel();
163
164 manager.lock(move |inner| {
165 if !inner.invalid {
166 debug!("listening to uri={uri}");
167 inner.subscriptions.push((uri, tx));
168 }
169 });
170
171 rx
172 }
173 }
174
175 pub(crate) fn dispatch(&self, cmd: PacketType, mut data: Bytes) -> Result<(), Error> {
176 let seq_len = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
177 let seq = data.split_to(seq_len).as_ref().to_owned();
178
179 let flags = data.split_to(1).as_ref()[0];
180 let count = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
181
182 let pending = self.lock(|inner| inner.pending.remove(&seq));
183
184 let mut pending = match pending {
185 Some(pending) => pending,
186 None => {
187 if let PacketType::MercuryEvent = cmd {
188 MercuryPending {
189 parts: Vec::new(),
190 partial: None,
191 callback: None,
192 }
193 } else {
194 warn!("Ignore seq {:?} cmd {:x}", seq, cmd as u8);
195 return Err(MercuryError::Command(cmd).into());
196 }
197 }
198 };
199
200 for i in 0..count {
201 let mut part = Self::parse_part(&mut data);
202 if let Some(mut partial) = pending.partial.take() {
203 partial.extend_from_slice(&part);
204 part = partial;
205 }
206
207 if i == count - 1 && (flags == 2) {
208 pending.partial = Some(part)
209 } else {
210 pending.parts.push(part);
211 }
212 }
213
214 if flags == 0x1 {
215 self.complete_request(cmd, pending)?;
216 } else {
217 self.lock(move |inner| inner.pending.insert(seq, pending));
218 }
219
220 Ok(())
221 }
222
223 fn parse_part(data: &mut Bytes) -> Vec<u8> {
224 let size = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
225 data.split_to(size).as_ref().to_owned()
226 }
227
228 fn complete_request(&self, cmd: PacketType, mut pending: MercuryPending) -> Result<(), Error> {
229 let header_data = pending.parts.remove(0);
230 let header = protocol::mercury::Header::parse_from_bytes(&header_data)?;
231
232 let response = MercuryResponse {
233 uri: header.uri().to_string(),
234 status_code: header.status_code(),
235 payload: pending.parts,
236 };
237
238 let status_code = response.status_code;
239 if status_code >= 500 {
240 error!("error {} for uri {}", status_code, &response.uri);
241 Err(MercuryError::Response(response).into())
242 } else if status_code >= 400 {
243 error!("error {} for uri {}", status_code, &response.uri);
244 if let Some(cb) = pending.callback {
245 cb.send(Err(MercuryError::Response(response.clone()).into()))
246 .map_err(|_| MercuryError::Channel)?;
247 }
248 Err(MercuryError::Response(response).into())
249 } else if let PacketType::MercuryEvent = cmd {
250 let mut uri_split = response.uri.split('/');
254
255 let encoded_uri = std::iter::once(uri_split.next().unwrap_or_default().to_string())
256 .chain(uri_split.map(|component| {
257 form_urlencoded::byte_serialize(component.as_bytes()).collect::<String>()
258 }))
259 .collect::<Vec<String>>()
260 .join("/");
261
262 let mut found = false;
263
264 self.lock(|inner| {
265 inner.subscriptions.retain(|(prefix, sub)| {
266 if encoded_uri.starts_with(prefix) {
267 found = true;
268
269 sub.send(response.clone()).is_ok()
272 } else {
273 true
275 }
276 });
277 });
278
279 if found {
280 Ok(())
281 } else if self.session().dealer().handles(&response.uri) {
282 trace!("mercury response <{}> is handled by dealer", response.uri);
283 Ok(())
284 } else {
285 debug!("unknown subscription uri={}", &response.uri);
286 trace!("response pushed over Mercury: {response:?}");
287 Err(MercuryError::Response(response).into())
288 }
289 } else if let Some(cb) = pending.callback {
290 cb.send(Ok(response)).map_err(|_| MercuryError::Channel)?;
291 Ok(())
292 } else {
293 error!("can't handle Mercury response: {response:?}");
294 Err(MercuryError::Response(response).into())
295 }
296 }
297
298 pub(crate) fn shutdown(&self) {
299 self.lock(|inner| {
300 inner.invalid = true;
301 inner.pending.clear();
303 inner.subscriptions.clear();
304 });
305 }
306}