1use crate::asynchronous::delayable_stream::ResettableTimer;
2use crate::client::{Client, ClientFactory};
3use crate::destinations::{
4 DestinationId, Destinations, InboundMessage, MessageId, OutboundMessage, Sender, Subscriber,
5 SubscriptionId,
6};
7use crate::error::StomperError;
8
9use either::Either;
10use futures::future::BoxFuture;
11use futures::sink::Sink;
12use futures::stream::{once, Stream};
13use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
14
15use futures::stream::select_all::select_all;
16use log::info;
17use std::convert::TryFrom;
18use std::future::ready;
19use std::future::Future;
20use std::pin::Pin;
21use std::time::Duration;
22use stomp_parser::client::*;
23use stomp_parser::headers::*;
24use stomp_parser::server::*;
25use tokio::time::sleep;
26use tokio_stream::wrappers::UnboundedReceiverStream;
27
28use tokio::sync::mpsc::{self, UnboundedSender};
29
30use std::collections::HashMap;
31
32use super::delayable_stream::ResettableTimerResetter;
33
/// Bytes written to the client as a server heartbeat (a single line feed).
const EOL: &[u8; 1] = &[b'\n'];
/// Milliseconds to keep the outgoing sink open after the session ends, so final frames flush.
const LINGER_TIME: u64 = 1_000;
/// Extra slack, in percent, granted on top of the client's announced heartbeat interval.
const HEARTBEAT_BUFFER_PERCENT: u32 = 20;
37
38type ServerMessage = Either<ServerFrame, Vec<u8>>;
39
/// Whether a client session should keep running after an event has been handled.
///
/// The session stops taking further events once a handler reports `Dead`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// Keep processing events.
    Alive,
    /// Stop processing; the session is over.
    Dead,
}
45
46#[derive(Debug)]
48enum ClientEvent {
49 Connected(HeartBeatIntervalls),
50
51 ClientFrame(Result<ClientFrame, StomperError>),
53
54 ClientHeartbeat,
55
56 ServerMessage(SubscriptionId, OutboundMessage),
58
59 Subscribed(
61 DestinationId,
62 SubscriptionId,
63 Result<SubscriptionId, StomperError>,
64 ),
65
66 Unsubscribed(SubscriptionId, Result<SubscriptionId, StomperError>),
68
69 Error(String),
71
72 Heartbeat,
74
75 Close,
77}
78
79#[derive(Debug, Clone)]
96pub struct AsyncStompClient {
97 sender: UnboundedSender<ClientEvent>,
98}
99
100impl AsyncStompClient {
101 fn send_event(&self, event: ClientEvent) {
102 if self.sender.send(event).is_err() {
103 info!("Unable to send ClientEvent, channel closed?");
104 }
105 }
106
107 fn unwrap_subscriber_sub_id(subscriber_sub_id: Option<SubscriptionId>) -> SubscriptionId {
108 subscriber_sub_id
109 .expect("STOMP requires subscriptions to have a client-provided identifier")
110 }
111}
112
113impl Subscriber for AsyncStompClient {
114 fn subscribe_callback(
115 &self,
116 destination_id: DestinationId,
117 client_subscription_id: Option<SubscriptionId>,
118 subscribe_result: Result<SubscriptionId, StomperError>,
119 ) {
120 self.send_event(ClientEvent::Subscribed(
121 destination_id,
122 AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
123 subscribe_result,
124 ));
125 }
126 fn unsubscribe_callback(
127 &self,
128 client_subscription_id: Option<SubscriptionId>,
129 unsubscribe_result: std::result::Result<SubscriptionId, StomperError>,
130 ) {
131 self.send_event(ClientEvent::Unsubscribed(
132 AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
133 unsubscribe_result,
134 ));
135 }
136
137 fn send(
138 &self,
139 _: SubscriptionId,
140 client_subscription_id: Option<SubscriptionId>,
141 message: OutboundMessage,
142 ) -> Result<(), StomperError> {
143 self.sender
144 .send(ClientEvent::ServerMessage(
145 AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
146 message,
147 ))
148 .map_err(|_| StomperError::new("Unable to send message, client channel closed"))
149 }
150}
151
152impl Sender for AsyncStompClient {
153 fn send_callback(&self, _: Option<MessageId>, _: Result<MessageId, StomperError>) {
154 }
156}
157
158impl AsyncStompClient {
176 fn create(sender: UnboundedSender<ClientEvent>) -> Self {
177 AsyncStompClient { sender }
178 }
179}
180
181type ResultType = Pin<
182 Box<
183 dyn Future<
184 Output = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>,
185 > + Send
186 + 'static,
187 >,
188>;
189
190trait ResultStream:
191 Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
192 + Send
193 + 'static
194{
195}
196
197impl<
198 T: Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
199 + Send
200 + 'static,
201 > ResultStream for T
202{
203}
204
205type RawClientStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>, StomperError>> + Send + 'static>>;
206trait ClientStream: Stream<Item = ClientEvent> + Send + Unpin + 'static {}
207
208impl<T: Stream<Item = ClientEvent> + Send + Unpin + 'static> ClientStream for T {}
209
210fn frame_result(frame: ServerFrame) -> ResultType {
211 state_frame_result(ClientState::Alive, frame)
212}
213
214fn state_frame_result(
215 state: ClientState,
216 frame: ServerFrame,
217) -> BoxFuture<'static, Result<(ClientState, Option<ServerMessage>), StomperError>> {
218 ready(Ok((state, Some(Either::Left(frame))))).boxed()
219}
220
221pub struct ClientSession<T>
222where
223 T: Destinations + 'static,
224{
225 destinations: T,
226 client_proxy: AsyncStompClient,
227 active_subscriptions_by_client_id: HashMap<SubscriptionId, (DestinationId, SubscriptionId)>,
228 server_heartbeat_resetter: ResettableTimerResetter,
229 client_heartbeat_resetter: ResettableTimerResetter,
230 client: T::Client,
231}
232
233impl<T> ClientSession<T>
234where
235 T: Destinations + 'static,
236{
237 fn new(
238 destinations: T,
239 client_proxy: AsyncStompClient,
240 server_heartbeat_resetter: ResettableTimerResetter,
241
242 client_heartbeat_resetter: ResettableTimerResetter,
243 client: T::Client,
244 ) -> ClientSession<T> {
245 ClientSession {
246 destinations,
247 client_proxy,
248 active_subscriptions_by_client_id: HashMap::new(),
249 server_heartbeat_resetter,
250 client_heartbeat_resetter,
251 client,
252 }
253 }
254
255 fn unsubscribe(&mut self, client_subscription_id: SubscriptionId) -> ResultType {
256 match self
257 .active_subscriptions_by_client_id
258 .get(&client_subscription_id)
259 {
260 None => self.error(&format!(
261 "Attempt to unsubscribe from unknown subscription: {}",
262 client_subscription_id
263 )),
264 Some((destination_id, destination_sub_id)) => {
265 self.destinations.unsubscribe(
266 destination_id.clone(),
267 destination_sub_id.clone(),
268 Box::new(self.client_proxy.clone()),
269 &self.client,
270 );
271 ready(Ok((ClientState::Alive, None))).boxed()
272 }
273 }
274 }
275
276 fn client_frame(&mut self, frame: Result<ClientFrame, StomperError>) -> ResultType {
277 match frame {
278 Err(err) => self.error(&format!("Error processing client message: {:?}", err)),
279 Ok(frame) => self.handle(frame).boxed(),
280 }
281 }
282
283 fn subscribed(
284 &mut self,
285 destination: DestinationId,
286 client_subscription_id: SubscriptionId,
287 result: Result<SubscriptionId, StomperError>,
288 ) -> ResultType {
289 if let Ok(destination_sub_id) = result {
290 self.active_subscriptions_by_client_id
291 .insert(client_subscription_id, (destination, destination_sub_id));
292 }
293 ready(Ok((ClientState::Alive, None))).boxed()
294 }
295
296 fn unsubscribed(
297 &mut self,
298 client_subscription_id: SubscriptionId,
299 result: Result<SubscriptionId, StomperError>,
300 ) -> ResultType {
301 if result.is_ok() {
302 self.active_subscriptions_by_client_id
303 .remove(&client_subscription_id);
304 }
305 ready(Ok((ClientState::Alive, None))).boxed()
306 }
307
308 fn server_message(
309 &mut self,
310 client_subscription_id: SubscriptionId,
311 message: OutboundMessage,
312 ) -> ResultType {
313 let raw_body = message.body;
314
315 let message_frame = MessageFrameBuilder::new(
316 message.message_id.into(),
317 message.destination.into(),
318 client_subscription_id.into(),
319 )
320 .content_type("text/plain".to_owned())
321 .content_length(raw_body.len() as u32)
322 .body(raw_body)
323 .build();
324
325 frame_result(ServerFrame::Message(message_frame))
326 }
327
328 fn error(&mut self, message: &str) -> ResultType {
329 let client = self.client_proxy.clone();
330
331 frame_result(ServerFrame::Error(ErrorFrame::from_message(message)))
332 .inspect(move |_| client.send_event(ClientEvent::Close))
333 .boxed()
334 }
335
336 fn send_heartbeat(&self) -> ResultType {
337 println!("Sending heartbeat");
338 ready(Ok((ClientState::Alive, Some(Either::Right(EOL.to_vec()))))).boxed()
339 }
340
341 fn client_message_received(&mut self) {
342 if let Err(err) = self.client_heartbeat_resetter.reset() {
343 log::error!("Error resetting client heartbeat timeout: {:?}", err);
344 }
345 }
346
347 fn handle_event(&mut self, event: ClientEvent) -> ResultType {
348 match event {
349 ClientEvent::Connected(heartbeat) => {
350 let mut builder =
351 ConnectedFrameBuilder::new(StompVersion::V1_2).heartbeat(heartbeat);
352
353 if let Some(session) = self.client.session() {
354 builder = builder.session(session);
355 }
356
357 if let Some(server) = self.client.server() {
358 builder = builder.server(server);
359 }
360
361 let frame = builder.build();
362
363 frame_result(ServerFrame::Connected(frame))
364 }
365 ClientEvent::Close => ready(Ok((ClientState::Dead, None))).boxed(),
366 ClientEvent::ClientFrame(result) => {
367 self.client_message_received();
368 self.client_frame(result)
369 }
370 ClientEvent::ClientHeartbeat => {
371 self.client_message_received();
372 ready(Ok((ClientState::Alive, None))).boxed()
373 }
374 ClientEvent::ServerMessage(client_subscription_id, message) => {
375 self.server_heartbeat_resetter
376 .reset()
377 .expect("Unexpected error");
378 self.server_message(client_subscription_id, message).boxed()
379 }
380 ClientEvent::Subscribed(destination, client_subscription_id, result) => {
381 self.subscribed(destination, client_subscription_id, result)
382 }
383 ClientEvent::Unsubscribed(client_subscription_id, result) => {
384 self.unsubscribed(client_subscription_id, result)
385 }
386 ClientEvent::Error(message) => self.error(&message),
387 ClientEvent::Heartbeat => self.send_heartbeat(),
388 }
389 .boxed()
390 }
391
392 async fn parse_client_message(bytes: Vec<u8>) -> Result<Option<ClientFrame>, StomperError> {
393 if is_heartbeat(&*bytes) {
394 Ok(None)
395 } else {
396 Some(ClientFrame::try_from(bytes).map_err(|err| err.into())).transpose()
397 }
398 }
399
400 fn log_error(error: &StomperError) {
401 log::error!("Error handling event: {}", error);
402 }
403
404 fn not_dead<Q>(result: &Result<(ClientState, Q), StomperError>) -> impl Future<Output = bool> {
405 ready(!matches!(result, Ok((ClientState::Dead, _))))
406 }
407
408 fn into_opt_ok_of_bytes(
409 result: Result<(ClientState, Option<ServerMessage>), StomperError>,
410 ) -> impl Future<Output = Option<Result<Vec<u8>, StomperError>>> {
411 ready(
412 result
414 .map(|(_, opt_frame)| {
415 opt_frame.map(|either| match either {
417 Either::Left(frame) => frame.into(),
418 Either::Right(bytes) => bytes,
419 })
420 })
421 .or(Ok(None))
423 .transpose(),
425 )
426 }
427 pub fn process_stream<F: ClientFactory<T::Client> + 'static>(
428 stream: RawClientStream,
429 server_frame_sink: Pin<
430 Box<dyn Sink<Vec<u8>, Error = StomperError> + Sync + Send + 'static>,
431 >,
432 destinations: T,
433 client_factory: F,
434 ) -> impl Future<Output = Result<(), StomperError>> + Send + 'static {
435 let close_stream = futures::stream::once(async { ClientEvent::Close }).boxed();
437
438 let stream_from_client = stream
439 .and_then(|bytes| Self::parse_client_message(bytes).boxed())
440 .inspect(|frame| log::debug!("Frame: {:?}", frame))
441 .map(|opt_frame| {
442 opt_frame
443 .transpose()
444 .map(ClientEvent::ClientFrame)
445 .unwrap_or(ClientEvent::ClientHeartbeat)
446 })
447 .chain(close_stream);
448
449 tokio::task::spawn(
451 stream_from_client
452 .into_future() .then(|(first_message, stream_from_client)| {
454 Self::validate_and_connect(first_message, client_factory).map(
455 move |validation_result| {
456 Self::handle_connection_validation_result(
457 validation_result,
458 destinations,
459 stream_from_client,
460 )
461 },
462 )
463 })
464 .then(move |stream| Self::process_response_stream(stream, server_frame_sink)),
465 )
466 .inspect(|_| info!("Client completing"))
467 .map_ok(|_| ()) .map_err(|_| StomperError::new("Unable to join response task"))
469 }
470
471 fn process_response_stream<S: ResultStream>(
472 response_stream: S,
473 server_frame_sink: Pin<
474 Box<dyn Sink<Vec<u8>, Error = StomperError> + Sync + Send + 'static>,
475 >,
476 ) -> impl Future<Output = Result<(), StomperError>> {
477 response_stream
478 .chain(futures::stream::once(async {
479 sleep(Duration::from_millis(LINGER_TIME)).await;
480 Err(StomperError::new("Closing stream"))
481 }))
482 .filter_map(Self::into_opt_ok_of_bytes)
483 .forward(server_frame_sink)
484 }
485
486 fn validate_and_connect<F: ClientFactory<T::Client> + 'static>(
487 first_message: Option<ClientEvent>,
488 client_factory: F,
489 ) -> BoxFuture<'static, Result<(HeartBeatIntervalls, T::Client), StomperError>> {
490 match first_message {
491 Some(ClientEvent::ClientFrame(Ok(ClientFrame::Connect(connect_frame)))) => {
492 if !connect_frame
493 .accept_version()
494 .value()
495 .contains(&StompVersion::V1_2)
496 {
497 ready(Err(StomperError::new("Only STOMP 1.2 is supported"))).boxed()
498 } else {
499 let login: Option<String> = connect_frame
500 .login()
501 .map(|login_value| login_value.value().to_owned());
502 let passcode: Option<String> = connect_frame
503 .passcode()
504 .map(|passcode_value| passcode_value.value().to_owned());
505 let heartbeat = connect_frame.heartbeat().value().clone();
506
507 client_factory
508 .create(login, passcode.as_ref())
509 .map_ok(move |client| (heartbeat, client))
510 .boxed()
511 }
512 }
513 _ => ready(Err(StomperError::new(
514 "First message must be a CONNECT frame",
515 )))
516 .boxed(),
517 }
518 }
519
520 fn handle_connection_validation_result<S: ClientStream>(
521 first_result: Result<(HeartBeatIntervalls, T::Client), StomperError>,
522 destinations: T,
523 stream_from_client: S,
524 ) -> impl Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
525 + Send {
526 if let Err(error) = first_result {
527 once(ready(Ok((
529 ClientState::Dead,
530 Some(Either::Left(ServerFrame::Error(ErrorFrame::from_message(
531 &error.message,
532 )))),
533 ))))
534 .left_stream()
535 } else {
536 let (tx, rx) = mpsc::unbounded_channel();
537 let client_proxy = AsyncStompClient::create(tx);
538
539 let (heartbeat_requested, client) = first_result.unwrap();
540
541 let stream_from_server = UnboundedReceiverStream::new(rx);
543
544 let (server_heartbeat_stream, server_heartbeat_resetter) = if heartbeat_requested
545 .expected
546 > 0
547 {
548 ResettableTimer::create(Duration::from_millis(heartbeat_requested.expected as u64))
549 } else {
550 ResettableTimer::default()
551 };
552
553 let all_events = once(ready(ClientEvent::Connected(HeartBeatIntervalls::new(
554 heartbeat_requested.expected,
555 heartbeat_requested.supplied,
556 ))))
557 .chain(select_all(vec![
558 stream_from_client.boxed(),
559 stream_from_server.boxed(),
560 server_heartbeat_stream
561 .map(|_| ClientEvent::Heartbeat)
562 .boxed(),
563 ]))
564 .inspect(|event| log::debug!("ClientEvent: {:?}", event));
565
566 let (client_heartbeat_stream, client_heartbeat_resetter) =
567 if heartbeat_requested.supplied > 0 {
568 let heartbeat_with_buffer =
569 heartbeat_requested.supplied * (HEARTBEAT_BUFFER_PERCENT + 100) / 100;
570 ResettableTimer::create(Duration::from_millis(heartbeat_with_buffer as u64))
571 } else {
572 ResettableTimer::default()
573 };
574
575 let event_handler = {
576 let mut client_session = ClientSession::new(
577 destinations,
578 client_proxy,
579 server_heartbeat_resetter,
580 client_heartbeat_resetter,
581 client,
582 );
583
584 client_session.start_heartbeat_listener(client_heartbeat_stream);
585
586 move |event| client_session.handle_event(event)
587 };
588
589 all_events
590 .then(event_handler)
591 .inspect_ok(|(_, message)| {
592 log::debug!("Message to client: {:?}", message);
593 })
594 .inspect_err(Self::log_error)
595 .take_while(Self::not_dead)
596 .right_stream()
597 }
598 }
599
600 fn start_heartbeat_listener(&mut self, mut timer: ResettableTimer) {
601 tokio::task::spawn({
602 let client = self.client_proxy.clone();
603
604 async move {
605 timer
606 .next()
607 .inspect(|_| {
608 client.send_event(ClientEvent::Error("Missed heartbeat".to_owned()));
609 })
610 .await
611 }
612 });
613 }
614
615 fn handle(&mut self, frame: ClientFrame) -> ResultType {
616 match frame {
617 ClientFrame::Connect(_) => self.error("Already connected."),
618
619 ClientFrame::Subscribe(frame) => {
620 self.destinations.subscribe(
621 DestinationId(frame.destination().value().to_owned()),
622 Some(SubscriptionId::from(frame.id().value())),
623 Box::new(self.client_proxy.clone()),
624 &self.client,
625 );
626 ready(Ok((ClientState::Alive, None))).boxed()
627 }
628
629 ClientFrame::Send(frame) => {
630 self.destinations.send(
631 DestinationId(frame.destination().value().to_owned()),
632 InboundMessage {
633 sender_message_id: None,
634 body: frame.body().unwrap().to_owned(),
635 },
636 Box::new(self.client_proxy.clone()),
637 &self.client,
638 );
639 ready(Ok((ClientState::Alive, None))).boxed()
640 }
641
642 ClientFrame::Disconnect(_frame) => {
643 info!("Client Disconnecting");
644 ready(Ok((ClientState::Dead, None))).boxed()
645 }
646 ClientFrame::Unsubscribe(frame) => {
647 self.unsubscribe(SubscriptionId(frame.id().value().to_owned()))
648 }
649
650 ClientFrame::Abort(_frame) => {
651 todo!()
652 }
653
654 ClientFrame::Ack(_frame) => {
655 todo!()
656 }
657
658 ClientFrame::Begin(_frame) => {
659 todo!()
660 }
661
662 ClientFrame::Commit(_frame) => {
663 todo!()
664 }
665
666 ClientFrame::Nack(_frame) => {
667 todo!()
668 }
669 }
670 }
671}
672
/// True when `bytes` is exactly one heartbeat EOL, either `\n` or `\r\n`.
fn is_heartbeat(bytes: &[u8]) -> bool {
    matches!(bytes, [b'\n'] | [b'\r', b'\n'])
}
676
#[cfg(test)]
mod tests {
    use super::{AsyncStompClient, ClientEvent};
    use crate::destinations::{
        DestinationId, MessageId, OutboundMessage, Subscriber, SubscriptionId,
    };
    use tokio::sync::mpsc;

    /// The message both tests push through the client proxy.
    fn hello_world_message() -> OutboundMessage {
        OutboundMessage {
            message_id: MessageId::from("1"),
            destination: DestinationId::from("somedest"),
            body: "Hello, World".as_bytes().to_owned(),
        }
    }

    #[tokio::test]
    async fn it_calls_sender() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let client = AsyncStompClient::create(tx);

        let outcome = client.send(
            SubscriptionId::from("Arbitrary"),
            Some(SubscriptionId::from("sub-1")),
            hello_world_message(),
        );
        assert!(outcome.is_ok(), "Send failed");

        match rx.recv().await {
            Some(ClientEvent::ServerMessage(_, message)) => {
                assert_eq!("Hello, World", std::str::from_utf8(&message.body).unwrap());
            }
            _ => panic!("No, or incorrect, message received"),
        }
    }

    #[tokio::test]
    async fn returns_error_on_failure() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let client = AsyncStompClient::create(tx);

        // With the receiving half closed, every send must fail.
        rx.close();

        let outcome = client.send(
            SubscriptionId::from("Arbitrary"),
            Some(SubscriptionId::from("sub-1")),
            hello_world_message(),
        );

        match outcome {
            Err(error) => assert_eq!(
                "Unable to send message, client channel closed",
                error.message
            ),
            Ok(_) => panic!("No, or incorrect, error message received"),
        }
    }
}