little_stomper/asynchronous/
client.rs

1use crate::asynchronous::delayable_stream::ResettableTimer;
2use crate::client::{Client, ClientFactory};
3use crate::destinations::{
4    DestinationId, Destinations, InboundMessage, MessageId, OutboundMessage, Sender, Subscriber,
5    SubscriptionId,
6};
7use crate::error::StomperError;
8
9use either::Either;
10use futures::future::BoxFuture;
11use futures::sink::Sink;
12use futures::stream::{once, Stream};
13use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
14
15use futures::stream::select_all::select_all;
16use log::info;
17use std::convert::TryFrom;
18use std::future::ready;
19use std::future::Future;
20use std::pin::Pin;
21use std::time::Duration;
22use stomp_parser::client::*;
23use stomp_parser::headers::*;
24use stomp_parser::server::*;
25use tokio::time::sleep;
26use tokio_stream::wrappers::UnboundedReceiverStream;
27
28use tokio::sync::mpsc::{self, UnboundedSender};
29
30use std::collections::HashMap;
31
32use super::delayable_stream::ResettableTimerResetter;
33
/// Bytes written to the client as a server heartbeat: a single line feed.
const EOL: &[u8; 1] = b"\n";
/// Time, in milliseconds, the outgoing stream lingers after the response stream has ended
/// before it is finally closed.
const LINGER_TIME: u64 = 1000;
/// Tolerance, in percent, added on top of the client's heartbeat interval before a client
/// heartbeat is considered missed.
const HEARTBEAT_BUFFER_PERCENT: u32 = 20;

/// Something to write to the client: either a frame still to be serialised (`Left`) or raw
/// bytes (`Right`, used for heartbeats).
type ServerMessage = Either<ServerFrame, Vec<u8>>;
39
/// Indicates or changes the current state of the client.
///
/// Every handled event reports one of these states alongside its optional response; the
/// session's result stream is cut off as soon as a `Dead` state is produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// The session continues; further events will be processed.
    Alive,
    /// The session is over; no further events should be processed.
    Dead,
}
45
/// An event which an `AsyncStompClient` session can receive.
///
/// Events originate from the client's byte stream, from the server side (via the channel
/// wrapped by `AsyncStompClient`) and from the heartbeat timers.
#[derive(Debug)]
enum ClientEvent {
    /// The connection was validated; carries the heartbeat intervals to announce in the
    /// CONNECTED frame.
    Connected(HeartBeatIntervalls),

    /// A Frame from the client if received and parsed correctly, or Err if there was an error
    ClientFrame(Result<ClientFrame, StomperError>),

    /// The client sent a heartbeat (a bare end-of-line) rather than a frame.
    ClientHeartbeat,

    /// A Message the server wishes to send to the client, specifying the (client's) subscription id, as well as the message itself
    ServerMessage(SubscriptionId, OutboundMessage),

    /// A callback indicating the result of an attempt to subscribe the client to a destination
    Subscribed(
        DestinationId,
        SubscriptionId,
        Result<SubscriptionId, StomperError>,
    ),

    /// A callback indicating the result of an attempt to unsubscribe the client from a destination
    Unsubscribed(SubscriptionId, Result<SubscriptionId, StomperError>),

    /// An error that should be communicated to the client
    Error(String),

    /// Send a heartbeat to the client
    Heartbeat,

    /// The session should end: emitted once the client's stream is exhausted, and after an
    /// error frame has been produced for the client.
    Close,
}
78
// The `ClientProxy` trait below is commented out and kept for reference only; the comment
// lines preceding it are its original documentation.
//
// A proxy for a client which can subscribe to destinations, receive messages and send messages.
//
// Note that a client must also implement [destinations::Subscriber](crate::destinations::Subscriber) and [destinations::Sender](crate::destinations::Sender),
// which define the bulk of the API.
// trait ClientProxy: Subscriber + Sender + Sync + Send {
//     /// Allows error messages to be sent to the client
//     fn error(&self, message: &str);

//     /// Exposes self as a Sender.
//     fn into_sender(self: Arc<Self>) -> Arc<dyn Sender>;

//     /// Exposes self as a Subscriber.
//     fn into_subscriber(self: Arc<Self>) -> Arc<dyn Subscriber>;

//     fn send_heartbeat(&self);
// }

/// Cloneable handle through which server-side code queues events for a client's session.
///
/// It wraps the sending half of an unbounded channel of `ClientEvent`s and implements
/// [`Subscriber`] and [`Sender`], so it can be handed to destinations as the callback target.
#[derive(Debug, Clone)]
pub struct AsyncStompClient {
    /// Sending half of the channel on which events for this client are queued.
    sender: UnboundedSender<ClientEvent>,
}
99
100impl AsyncStompClient {
101    fn send_event(&self, event: ClientEvent) {
102        if self.sender.send(event).is_err() {
103            info!("Unable to send ClientEvent, channel closed?");
104        }
105    }
106
107    fn unwrap_subscriber_sub_id(subscriber_sub_id: Option<SubscriptionId>) -> SubscriptionId {
108        subscriber_sub_id
109            .expect("STOMP requires subscriptions to have a client-provided identifier")
110    }
111}
112
113impl Subscriber for AsyncStompClient {
114    fn subscribe_callback(
115        &self,
116        destination_id: DestinationId,
117        client_subscription_id: Option<SubscriptionId>,
118        subscribe_result: Result<SubscriptionId, StomperError>,
119    ) {
120        self.send_event(ClientEvent::Subscribed(
121            destination_id,
122            AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
123            subscribe_result,
124        ));
125    }
126    fn unsubscribe_callback(
127        &self,
128        client_subscription_id: Option<SubscriptionId>,
129        unsubscribe_result: std::result::Result<SubscriptionId, StomperError>,
130    ) {
131        self.send_event(ClientEvent::Unsubscribed(
132            AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
133            unsubscribe_result,
134        ));
135    }
136
137    fn send(
138        &self,
139        _: SubscriptionId,
140        client_subscription_id: Option<SubscriptionId>,
141        message: OutboundMessage,
142    ) -> Result<(), StomperError> {
143        self.sender
144            .send(ClientEvent::ServerMessage(
145                AsyncStompClient::unwrap_subscriber_sub_id(client_subscription_id),
146                message,
147            ))
148            .map_err(|_| StomperError::new("Unable to send message, client channel closed"))
149    }
150}
151
impl Sender for AsyncStompClient {
    /// Callback reporting the outcome of a message sent on behalf of this client.
    ///
    /// Both arguments are ignored, so the outcome is not reported back to the client.
    fn send_callback(&self, _: Option<MessageId>, _: Result<MessageId, StomperError>) {
        //don't really care (for now?)
    }
}
157
158// impl ClientProxy for AsyncStompClient {
159//     fn into_sender(self: Arc<Self>) -> Arc<(dyn Sender + 'static)> {
160//         self
161//     }
162//     fn into_subscriber(self: Arc<Self>) -> Arc<(dyn Subscriber + 'static)> {
163//         self
164//     }
165
166//     fn error(&self, message: &str) {
167//         self.send_event(ClientEvent::Error(message.to_owned()));
168//     }
169
170//     fn send_heartbeat(&self) {
171//         self.send_event(ClientEvent::Heartbeat)
172//     }
173// }
174
175impl AsyncStompClient {
176    fn create(sender: UnboundedSender<ClientEvent>) -> Self {
177        AsyncStompClient { sender }
178    }
179}
180
181type ResultType = Pin<
182    Box<
183        dyn Future<
184                Output = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>,
185            > + Send
186            + 'static,
187    >,
188>;
189
190trait ResultStream:
191    Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
192    + Send
193    + 'static
194{
195}
196
197impl<
198        T: Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
199            + Send
200            + 'static,
201    > ResultStream for T
202{
203}
204
205type RawClientStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>, StomperError>> + Send + 'static>>;
206trait ClientStream: Stream<Item = ClientEvent> + Send + Unpin + 'static {}
207
208impl<T: Stream<Item = ClientEvent> + Send + Unpin + 'static> ClientStream for T {}
209
210fn frame_result(frame: ServerFrame) -> ResultType {
211    state_frame_result(ClientState::Alive, frame)
212}
213
214fn state_frame_result(
215    state: ClientState,
216    frame: ServerFrame,
217) -> BoxFuture<'static, Result<(ClientState, Option<ServerMessage>), StomperError>> {
218    ready(Ok((state, Some(Either::Left(frame))))).boxed()
219}
220
/// State of a single connected client: translates incoming [`ClientEvent`]s into calls on the
/// destinations and into frames for the client.
pub struct ClientSession<T>
where
    T: Destinations + 'static,
{
    /// The destinations this client subscribes to, unsubscribes from and sends to.
    destinations: T,
    /// Handle to this session's event channel; clones are handed to `destinations` as callback
    /// targets.
    client_proxy: AsyncStompClient,
    /// Active subscriptions keyed by the client-provided id, mapping to the destination and the
    /// subscription id returned by the subscribe callback.
    active_subscriptions_by_client_id: HashMap<SubscriptionId, (DestinationId, SubscriptionId)>,
    /// Resets the server heartbeat timer; reset whenever a server message is sent to the client.
    server_heartbeat_resetter: ResettableTimerResetter,
    /// Resets the client heartbeat timeout; reset whenever a frame or heartbeat arrives.
    client_heartbeat_resetter: ResettableTimerResetter,
    /// The client object created by the `ClientFactory`; passed along on every destination call.
    client: T::Client,
}
232
233impl<T> ClientSession<T>
234where
235    T: Destinations + 'static,
236{
237    fn new(
238        destinations: T,
239        client_proxy: AsyncStompClient,
240        server_heartbeat_resetter: ResettableTimerResetter,
241
242        client_heartbeat_resetter: ResettableTimerResetter,
243        client: T::Client,
244    ) -> ClientSession<T> {
245        ClientSession {
246            destinations,
247            client_proxy,
248            active_subscriptions_by_client_id: HashMap::new(),
249            server_heartbeat_resetter,
250            client_heartbeat_resetter,
251            client,
252        }
253    }
254
255    fn unsubscribe(&mut self, client_subscription_id: SubscriptionId) -> ResultType {
256        match self
257            .active_subscriptions_by_client_id
258            .get(&client_subscription_id)
259        {
260            None => self.error(&format!(
261                "Attempt to unsubscribe from unknown subscription: {}",
262                client_subscription_id
263            )),
264            Some((destination_id, destination_sub_id)) => {
265                self.destinations.unsubscribe(
266                    destination_id.clone(),
267                    destination_sub_id.clone(),
268                    Box::new(self.client_proxy.clone()),
269                    &self.client,
270                );
271                ready(Ok((ClientState::Alive, None))).boxed()
272            }
273        }
274    }
275
276    fn client_frame(&mut self, frame: Result<ClientFrame, StomperError>) -> ResultType {
277        match frame {
278            Err(err) => self.error(&format!("Error processing client message: {:?}", err)),
279            Ok(frame) => self.handle(frame).boxed(),
280        }
281    }
282
283    fn subscribed(
284        &mut self,
285        destination: DestinationId,
286        client_subscription_id: SubscriptionId,
287        result: Result<SubscriptionId, StomperError>,
288    ) -> ResultType {
289        if let Ok(destination_sub_id) = result {
290            self.active_subscriptions_by_client_id
291                .insert(client_subscription_id, (destination, destination_sub_id));
292        }
293        ready(Ok((ClientState::Alive, None))).boxed()
294    }
295
296    fn unsubscribed(
297        &mut self,
298        client_subscription_id: SubscriptionId,
299        result: Result<SubscriptionId, StomperError>,
300    ) -> ResultType {
301        if result.is_ok() {
302            self.active_subscriptions_by_client_id
303                .remove(&client_subscription_id);
304        }
305        ready(Ok((ClientState::Alive, None))).boxed()
306    }
307
308    fn server_message(
309        &mut self,
310        client_subscription_id: SubscriptionId,
311        message: OutboundMessage,
312    ) -> ResultType {
313        let raw_body = message.body;
314
315        let message_frame = MessageFrameBuilder::new(
316            message.message_id.into(),
317            message.destination.into(),
318            client_subscription_id.into(),
319        )
320        .content_type("text/plain".to_owned())
321        .content_length(raw_body.len() as u32)
322        .body(raw_body)
323        .build();
324
325        frame_result(ServerFrame::Message(message_frame))
326    }
327
328    fn error(&mut self, message: &str) -> ResultType {
329        let client = self.client_proxy.clone();
330
331        frame_result(ServerFrame::Error(ErrorFrame::from_message(message)))
332            .inspect(move |_| client.send_event(ClientEvent::Close))
333            .boxed()
334    }
335
336    fn send_heartbeat(&self) -> ResultType {
337        println!("Sending heartbeat");
338        ready(Ok((ClientState::Alive, Some(Either::Right(EOL.to_vec()))))).boxed()
339    }
340
341    fn client_message_received(&mut self) {
342        if let Err(err) = self.client_heartbeat_resetter.reset() {
343            log::error!("Error resetting client heartbeat timeout: {:?}", err);
344        }
345    }
346
347    fn handle_event(&mut self, event: ClientEvent) -> ResultType {
348        match event {
349            ClientEvent::Connected(heartbeat) => {
350                let mut builder =
351                    ConnectedFrameBuilder::new(StompVersion::V1_2).heartbeat(heartbeat);
352
353                if let Some(session) = self.client.session() {
354                    builder = builder.session(session);
355                }
356
357                if let Some(server) = self.client.server() {
358                    builder = builder.server(server);
359                }
360
361                let frame = builder.build();
362
363                frame_result(ServerFrame::Connected(frame))
364            }
365            ClientEvent::Close => ready(Ok((ClientState::Dead, None))).boxed(),
366            ClientEvent::ClientFrame(result) => {
367                self.client_message_received();
368                self.client_frame(result)
369            }
370            ClientEvent::ClientHeartbeat => {
371                self.client_message_received();
372                ready(Ok((ClientState::Alive, None))).boxed()
373            }
374            ClientEvent::ServerMessage(client_subscription_id, message) => {
375                self.server_heartbeat_resetter
376                    .reset()
377                    .expect("Unexpected error");
378                self.server_message(client_subscription_id, message).boxed()
379            }
380            ClientEvent::Subscribed(destination, client_subscription_id, result) => {
381                self.subscribed(destination, client_subscription_id, result)
382            }
383            ClientEvent::Unsubscribed(client_subscription_id, result) => {
384                self.unsubscribed(client_subscription_id, result)
385            }
386            ClientEvent::Error(message) => self.error(&message),
387            ClientEvent::Heartbeat => self.send_heartbeat(),
388        }
389        .boxed()
390    }
391
392    async fn parse_client_message(bytes: Vec<u8>) -> Result<Option<ClientFrame>, StomperError> {
393        if is_heartbeat(&*bytes) {
394            Ok(None)
395        } else {
396            Some(ClientFrame::try_from(bytes).map_err(|err| err.into())).transpose()
397        }
398    }
399
    /// Logs, at error level, an error produced while handling an event.
    fn log_error(error: &StomperError) {
        log::error!("Error handling event: {}", error);
    }
403
404    fn not_dead<Q>(result: &Result<(ClientState, Q), StomperError>) -> impl Future<Output = bool> {
405        ready(!matches!(result, Ok((ClientState::Dead, _))))
406    }
407
408    fn into_opt_ok_of_bytes(
409        result: Result<(ClientState, Option<ServerMessage>), StomperError>,
410    ) -> impl Future<Output = Option<Result<Vec<u8>, StomperError>>> {
411        ready(
412            // Drop the ClientState, already handled
413            result
414                .map(|(_, opt_frame)| {
415                    // serialize the frame
416                    opt_frame.map(|either| match either {
417                        Either::Left(frame) => frame.into(),
418                        Either::Right(bytes) => bytes,
419                    })
420                })
421                // drop errors
422                .or(Ok(None))
423                // cause only Some(Ok(bytes)) values to be passed on
424                .transpose(),
425        )
426    }
    /// Runs a complete client session on a spawned tokio task.
    ///
    /// `stream` supplies the raw byte chunks received from the client and `server_frame_sink`
    /// receives the bytes to send back. The first message must be a CONNECT frame;
    /// `client_factory` then creates the client object, after which events are processed
    /// against `destinations` until the session ends.
    ///
    /// # Errors
    ///
    /// The returned future only fails when the spawned task cannot be joined; the task's own
    /// result (including errors from forwarding to the sink) is discarded.
    pub fn process_stream<F: ClientFactory<T::Client> + 'static>(
        stream: RawClientStream,
        server_frame_sink: Pin<
            Box<dyn Sink<Vec<u8>, Error = StomperError> + Sync + Send + 'static>,
        >,
        destinations: T,
        client_factory: F,
    ) -> impl Future<Output = Result<(), StomperError>> + Send + 'static {
        // Closes this session; will be chained to client stream to run after that ends
        let close_stream = futures::stream::once(async { ClientEvent::Close }).boxed();

        // Parse each chunk: heartbeats become `ClientHeartbeat`, everything else becomes a
        // `ClientFrame` event carrying either the parsed frame or the parse error.
        let stream_from_client = stream
            .and_then(|bytes| Self::parse_client_message(bytes).boxed())
            .inspect(|frame| log::debug!("Frame: {:?}", frame))
            .map(|opt_frame| {
                opt_frame
                    .transpose()
                    .map(ClientEvent::ClientFrame)
                    .unwrap_or(ClientEvent::ClientHeartbeat)
            })
            .chain(close_stream);

        // the first message must be a connect frame
        tokio::task::spawn(
            stream_from_client
                .into_future() // Split off the first message for individual handling
                .then(|(first_message, stream_from_client)| {
                    Self::validate_and_connect(first_message, client_factory).map(
                        move |validation_result| {
                            Self::handle_connection_validation_result(
                                validation_result,
                                destinations,
                                stream_from_client,
                            )
                        },
                    )
                })
                // Serialise the resulting responses and forward them to the client.
                .then(move |stream| Self::process_response_stream(stream, server_frame_sink)),
        )
        .inspect(|_| info!("Client completing"))
        .map_ok(|_| ()) // ignore the result from the forward.
        .map_err(|_| StomperError::new("Unable to join response task"))
    }
470
471    fn process_response_stream<S: ResultStream>(
472        response_stream: S,
473        server_frame_sink: Pin<
474            Box<dyn Sink<Vec<u8>, Error = StomperError> + Sync + Send + 'static>,
475        >,
476    ) -> impl Future<Output = Result<(), StomperError>> {
477        response_stream
478            .chain(futures::stream::once(async {
479                sleep(Duration::from_millis(LINGER_TIME)).await;
480                Err(StomperError::new("Closing stream"))
481            }))
482            .filter_map(Self::into_opt_ok_of_bytes)
483            .forward(server_frame_sink)
484    }
485
486    fn validate_and_connect<F: ClientFactory<T::Client> + 'static>(
487        first_message: Option<ClientEvent>,
488        client_factory: F,
489    ) -> BoxFuture<'static, Result<(HeartBeatIntervalls, T::Client), StomperError>> {
490        match first_message {
491            Some(ClientEvent::ClientFrame(Ok(ClientFrame::Connect(connect_frame)))) => {
492                if !connect_frame
493                    .accept_version()
494                    .value()
495                    .contains(&StompVersion::V1_2)
496                {
497                    ready(Err(StomperError::new("Only STOMP 1.2 is supported"))).boxed()
498                } else {
499                    let login: Option<String> = connect_frame
500                        .login()
501                        .map(|login_value| login_value.value().to_owned());
502                    let passcode: Option<String> = connect_frame
503                        .passcode()
504                        .map(|passcode_value| passcode_value.value().to_owned());
505                    let heartbeat = connect_frame.heartbeat().value().clone();
506
507                    client_factory
508                        .create(login, passcode.as_ref())
509                        .map_ok(move |client| (heartbeat, client))
510                        .boxed()
511                }
512            }
513            _ => ready(Err(StomperError::new(
514                "First message must be a CONNECT frame",
515            )))
516            .boxed(),
517        }
518    }
519
520    fn handle_connection_validation_result<S: ClientStream>(
521        first_result: Result<(HeartBeatIntervalls, T::Client), StomperError>,
522        destinations: T,
523        stream_from_client: S,
524    ) -> impl Stream<Item = Result<(ClientState, Option<Either<ServerFrame, Vec<u8>>>), StomperError>>
525           + Send {
526        if let Err(error) = first_result {
527            //todo!("Send error response")
528            once(ready(Ok((
529                ClientState::Dead,
530                Some(Either::Left(ServerFrame::Error(ErrorFrame::from_message(
531                    &error.message,
532                )))),
533            ))))
534            .left_stream()
535        } else {
536            let (tx, rx) = mpsc::unbounded_channel();
537            let client_proxy = AsyncStompClient::create(tx);
538
539            let (heartbeat_requested, client) = first_result.unwrap();
540
541            // This is the stream of events generated by processing on the server-side, rather than directly from the client
542            let stream_from_server = UnboundedReceiverStream::new(rx);
543
544            let (server_heartbeat_stream, server_heartbeat_resetter) = if heartbeat_requested
545                .expected
546                > 0
547            {
548                ResettableTimer::create(Duration::from_millis(heartbeat_requested.expected as u64))
549            } else {
550                ResettableTimer::default()
551            };
552
553            let all_events = once(ready(ClientEvent::Connected(HeartBeatIntervalls::new(
554                heartbeat_requested.expected,
555                heartbeat_requested.supplied,
556            ))))
557            .chain(select_all(vec![
558                stream_from_client.boxed(),
559                stream_from_server.boxed(),
560                server_heartbeat_stream
561                    .map(|_| ClientEvent::Heartbeat)
562                    .boxed(),
563            ]))
564            .inspect(|event| log::debug!("ClientEvent: {:?}", event));
565
566            let (client_heartbeat_stream, client_heartbeat_resetter) =
567                if heartbeat_requested.supplied > 0 {
568                    let heartbeat_with_buffer =
569                        heartbeat_requested.supplied * (HEARTBEAT_BUFFER_PERCENT + 100) / 100;
570                    ResettableTimer::create(Duration::from_millis(heartbeat_with_buffer as u64))
571                } else {
572                    ResettableTimer::default()
573                };
574
575            let event_handler = {
576                let mut client_session = ClientSession::new(
577                    destinations,
578                    client_proxy,
579                    server_heartbeat_resetter,
580                    client_heartbeat_resetter,
581                    client,
582                );
583
584                client_session.start_heartbeat_listener(client_heartbeat_stream);
585
586                move |event| client_session.handle_event(event)
587            };
588
589            all_events
590                .then(event_handler)
591                .inspect_ok(|(_, message)| {
592                    log::debug!("Message to client: {:?}", message);
593                })
594                .inspect_err(Self::log_error)
595                .take_while(Self::not_dead)
596                .right_stream()
597        }
598    }
599
600    fn start_heartbeat_listener(&mut self, mut timer: ResettableTimer) {
601        tokio::task::spawn({
602            let client = self.client_proxy.clone();
603
604            async move {
605                timer
606                    .next()
607                    .inspect(|_| {
608                        client.send_event(ClientEvent::Error("Missed heartbeat".to_owned()));
609                    })
610                    .await
611            }
612        });
613    }
614
615    fn handle(&mut self, frame: ClientFrame) -> ResultType {
616        match frame {
617            ClientFrame::Connect(_) => self.error("Already connected."),
618
619            ClientFrame::Subscribe(frame) => {
620                self.destinations.subscribe(
621                    DestinationId(frame.destination().value().to_owned()),
622                    Some(SubscriptionId::from(frame.id().value())),
623                    Box::new(self.client_proxy.clone()),
624                    &self.client,
625                );
626                ready(Ok((ClientState::Alive, None))).boxed()
627            }
628
629            ClientFrame::Send(frame) => {
630                self.destinations.send(
631                    DestinationId(frame.destination().value().to_owned()),
632                    InboundMessage {
633                        sender_message_id: None,
634                        body: frame.body().unwrap().to_owned(),
635                    },
636                    Box::new(self.client_proxy.clone()),
637                    &self.client,
638                );
639                ready(Ok((ClientState::Alive, None))).boxed()
640            }
641
642            ClientFrame::Disconnect(_frame) => {
643                info!("Client Disconnecting");
644                ready(Ok((ClientState::Dead, None))).boxed()
645            }
646            ClientFrame::Unsubscribe(frame) => {
647                self.unsubscribe(SubscriptionId(frame.id().value().to_owned()))
648            }
649
650            ClientFrame::Abort(_frame) => {
651                todo!()
652            }
653
654            ClientFrame::Ack(_frame) => {
655                todo!()
656            }
657
658            ClientFrame::Begin(_frame) => {
659                todo!()
660            }
661
662            ClientFrame::Commit(_frame) => {
663                todo!()
664            }
665
666            ClientFrame::Nack(_frame) => {
667                todo!()
668            }
669        }
670    }
671}
672
/// Returns `true` when `bytes` is exactly one end-of-line: either `\n` or `\r\n`.
fn is_heartbeat(bytes: &[u8]) -> bool {
    match bytes {
        [b'\n'] | [b'\r', b'\n'] => true,
        _ => false,
    }
}
676
#[cfg(test)]
mod tests {
    use super::{AsyncStompClient, ClientEvent};
    use crate::destinations::{
        DestinationId, MessageId, OutboundMessage, Subscriber, SubscriptionId,
    };
    use tokio::sync::mpsc;

    /// The message both tests try to deliver.
    fn hello_world() -> OutboundMessage {
        OutboundMessage {
            message_id: MessageId::from("1"),
            destination: DestinationId::from("somedest"),
            body: "Hello, World".as_bytes().to_owned(),
        }
    }

    /// A message sent through the client must show up on the channel as a `ServerMessage`.
    #[tokio::test]
    async fn it_calls_sender() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let client = AsyncStompClient::create(tx);

        let result = client.send(
            SubscriptionId::from("Arbitrary"),
            Some(SubscriptionId::from("sub-1")),
            hello_world(),
        );
        assert!(result.is_ok(), "Send failed");

        match rx.recv().await {
            Some(ClientEvent::ServerMessage(_, message)) => {
                assert_eq!("Hello, World", std::str::from_utf8(&message.body).unwrap());
            }
            _ => panic!("No, or incorrect, message received"),
        }
    }

    /// Sending on a closed channel must be reported as an error to the caller.
    #[tokio::test]
    async fn returns_error_on_failure() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let client = AsyncStompClient::create(tx);

        rx.close();

        let result = client.send(
            SubscriptionId::from("Arbitrary"),
            Some(SubscriptionId::from("sub-1")),
            hello_world(),
        );

        match result {
            Err(error) => assert_eq!(
                "Unable to send message, client channel closed",
                error.message
            ),
            Ok(_) => panic!("No, or incorrect, error message received"),
        }
    }
}