Skip to main content

async_stomp/
client.rs

1use crate::frame;
2use crate::{FromServer, Message, Result, ToServer};
3use anyhow::{anyhow, bail};
4use bytes::{Buf, BytesMut};
5use futures::prelude::*;
6use futures::sink::SinkExt;
7use rustls::pki_types::ServerName;
8use std::fmt;
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tokio::net::TcpStream;
14use tokio_rustls::TlsConnector;
15use tokio_rustls::client::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17use typed_builder::TypedBuilder;
18use winnow::Partial;
19use winnow::error::ErrMode;
20use winnow::stream::Offset;
21
/// The primary transport type used by STOMP clients
///
/// This is a [`Framed`] instance that pairs a [`TransportStream`] (plain TCP or
/// TLS) with the [`ClientCodec`]: the codec decodes incoming bytes into
/// `Message<FromServer>` values and encodes outgoing `Message<ToServer>` values.
pub type ClientTransport = Framed<TransportStream, ClientCodec>;
27
/// The byte stream underneath a STOMP connection: plain TCP or TLS over TCP
///
/// Wrapping both cases in one type gives the rest of the library a single stream
/// type to work with. The `AsyncRead` and `AsyncWrite` implementations in this
/// file forward every call to whichever inner stream is present.
// NOTE(review): the lint is silenced instead of boxing a variant, presumably
// because the TLS stream is considerably larger than a bare `TcpStream` and the
// extra indirection was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
pub enum TransportStream {
    /// A plain, unencrypted TCP connection
    Plain(TcpStream),
    /// A secure TLS connection over TCP
    Tls(TlsStream<TcpStream>),
}
39
40// Implement AsyncRead for TransportStream to allow reading data from either connection type
41impl tokio::io::AsyncRead for TransportStream {
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut tokio::io::ReadBuf<'_>,
46    ) -> Poll<std::io::Result<()>> {
47        // Delegate to the appropriate inner stream type
48        match self.get_mut() {
49            TransportStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
50            TransportStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
51        }
52    }
53}
54
55// Implement AsyncWrite for TransportStream to allow writing data to either connection type
56impl tokio::io::AsyncWrite for TransportStream {
57    fn poll_write(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &[u8],
61    ) -> Poll<std::io::Result<usize>> {
62        // Delegate to the appropriate inner stream type
63        match self.get_mut() {
64            TransportStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65            TransportStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
66        }
67    }
68
69    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
70        // Delegate to the appropriate inner stream type
71        match self.get_mut() {
72            TransportStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
73            TransportStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
74        }
75    }
76
77    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
78        // Delegate to the appropriate inner stream type
79        match self.get_mut() {
80            TransportStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
81            TransportStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
82        }
83    }
84}
85
86// Debug implementation for TransportStream that provides a human-readable representation
87impl fmt::Debug for TransportStream {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            TransportStream::Plain(_) => write!(f, "Plain TCP connection"),
91            TransportStream::Tls(_) => write!(f, "TLS connection"),
92        }
93    }
94}
95
/// A builder for creating and establishing STOMP connections to a server
///
/// `Connector::builder()` collects the connection parameters. Finishing the
/// chain with `.connect()` opens the connection and performs the STOMP
/// handshake; finishing it with `.msg()` only produces the CONNECT message.
///
/// # Examples
///
/// ```rust,no_run
/// use async_stomp::client::Connector;
///
/// #[tokio::main]
/// async fn main() {
///   let connection = Connector::builder()
///     .server("stomp.example.com:61613")
///     .virtualhost("stomp.example.com")
///     .login("guest".to_string())
///     .passcode("guest".to_string())
///     .connect()
///     .await;
/// }
/// ```
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Connector<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> {
    /// The address of the STOMP server, in any form accepted by
    /// `tokio::net::ToSocketAddrs` (string forms need a port, e.g. `"host:61613"`)
    server: S,
    /// Virtualhost, sent as the `host` of the CONNECT frame. If no specific
    /// virtualhost is desired, it is recommended to set this to the host name
    /// that the socket connects to.
    virtualhost: V,
    /// Username to use for optional authentication to the server
    #[builder(default, setter(strip_option))]
    login: Option<String>,
    /// Passcode to use for optional authentication to the server
    #[builder(default, setter(strip_option))]
    passcode: Option<String>,
    /// Custom headers to be sent to the server
    #[builder(default)]
    headers: Vec<(String, String)>,
    /// Whether to use TLS for this connection
    #[builder(default = false)]
    use_tls: bool,
    /// Optional server name (DNS name or IP address) to verify in the TLS
    /// certificate. When not specified, `connect` verifies the certificate
    /// against the IP address of the connected peer, so set this whenever the
    /// certificate is issued for a host name.
    #[builder(default, setter(strip_option))]
    tls_server_name: Option<String>,
}
141
142/// Implementation of the builder connect method to allow the builder to directly connect
143#[allow(non_camel_case_types)]
144impl<
145    S: tokio::net::ToSocketAddrs + Clone,
146    V: Into<String> + Clone,
147    __login,
148    __passcode,
149    __headers,
150    __use_tls,
151    __tls_server_name,
152>
153    ConnectorBuilder<
154        S,
155        V,
156        (
157            (S,),
158            (V,),
159            __login,
160            __passcode,
161            __headers,
162            __use_tls,
163            __tls_server_name,
164        ),
165    >
166where
167    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
168            (
169                &'__typed_builder_lifetime_for_default S,
170                &'__typed_builder_lifetime_for_default V,
171                __login,
172            ),
173            Output = Option<String>,
174        >,
175    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
176            (
177                &'__typed_builder_lifetime_for_default S,
178                &'__typed_builder_lifetime_for_default V,
179                &'__typed_builder_lifetime_for_default Option<String>,
180                __passcode,
181            ),
182            Output = Option<String>,
183        >,
184    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
185            (
186                &'__typed_builder_lifetime_for_default S,
187                &'__typed_builder_lifetime_for_default V,
188                &'__typed_builder_lifetime_for_default Option<String>,
189                &'__typed_builder_lifetime_for_default Option<String>,
190                __headers,
191            ),
192            Output = Vec<(String, String)>,
193        >,
194    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
195            (
196                &'__typed_builder_lifetime_for_default S,
197                &'__typed_builder_lifetime_for_default V,
198                &'__typed_builder_lifetime_for_default Option<String>,
199                &'__typed_builder_lifetime_for_default Option<String>,
200                &'__typed_builder_lifetime_for_default Vec<(String, String)>,
201                __use_tls,
202            ),
203            Output = bool,
204        >,
205    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
206            (
207                &'__typed_builder_lifetime_for_default S,
208                &'__typed_builder_lifetime_for_default V,
209                &'__typed_builder_lifetime_for_default Option<String>,
210                &'__typed_builder_lifetime_for_default Option<String>,
211                &'__typed_builder_lifetime_for_default Vec<(String, String)>,
212                &'__typed_builder_lifetime_for_default bool,
213                __tls_server_name,
214            ),
215            Output = Option<String>,
216        >,
217{
218    /// Connect to the STOMP server using the configured parameters
219    ///
220    /// This method finalizes the builder and attempts to establish a connection
221    /// to the STOMP server. If successful, it returns a ClientTransport that can
222    /// be used to send and receive messages.
223    pub async fn connect(self) -> Result<ClientTransport> {
224        let connector: Connector<S, V> = self.__build();
225        connector.connect().await
226    }
227
228    /// Create a Message for connection without actually connecting
229    ///
230    /// This can be used when you want to handle the connection process manually
231    /// or need access to the raw connection message.
232    pub fn msg(self) -> Message<ToServer> {
233        let connector = self.__build();
234        connector.msg()
235    }
236}
237
238impl<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> Connector<S, V> {
239    /// Creates a TLS connector with default trust anchors
240    ///
241    /// This method configures a TLS connector with the system's default trust anchors
242    /// for certificate verification.
243    async fn create_tls_connector(&self) -> Result<TlsConnector> {
244        // Create a root certificate store with webpki's built-in roots
245        let root_store = rustls::RootCertStore {
246            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
247        };
248
249        // Create a TLS client configuration with the root certificates
250        let config = rustls::ClientConfig::builder()
251            .with_root_certificates(root_store)
252            .with_no_client_auth();
253
254        Ok(TlsConnector::from(Arc::new(config)))
255    }
256
257    /// Connect to the STOMP server using the configured parameters
258    ///
259    /// This method establishes a connection to the STOMP server and performs
260    /// the STOMP protocol handshake. If successful, it returns a ClientTransport
261    /// that can be used to send and receive STOMP messages.
262    pub async fn connect(self) -> Result<ClientTransport> {
263        // First establish a TCP connection to the server
264        let tcp = TcpStream::connect(self.server.clone()).await?;
265
266        // Determine whether to use plain TCP or wrap with TLS
267        let transport_stream = if self.use_tls {
268            // Extract server name for TLS verification
269            let server_name = if let Some(name) = &self.tls_server_name {
270                name.clone()
271            } else {
272                // Extract the hostname from the server address
273                let server_addr = tcp.peer_addr()?;
274                let hostname = server_addr.ip().to_string();
275                if hostname.is_empty() {
276                    return Err(anyhow!(
277                        "Could not determine server hostname for TLS verification"
278                    ));
279                }
280                hostname
281            };
282
283            // Create TLS connector
284            let tls_connector = self.create_tls_connector().await?;
285
286            // Create a copy of server_name to avoid the borrow after move issue
287            let server_name_copy = server_name.clone();
288
289            // Try to parse the server name as an IP address first
290            let dns_name = if let Ok(ip_addr) = server_name_copy.parse::<IpAddr>() {
291                // Handle IP address
292                match ip_addr {
293                    IpAddr::V4(ipv4) => ServerName::IpAddress(ipv4.into()),
294                    IpAddr::V6(ipv6) => ServerName::IpAddress(ipv6.into()),
295                }
296            } else {
297                // Handle DNS name
298                ServerName::DnsName(
299                    server_name_copy
300                        .try_into()
301                        .map_err(|_| anyhow!("Invalid DNS name: {}", server_name))?,
302                )
303            };
304
305            // Connect with TLS
306            let tls_stream = tls_connector.connect(dns_name, tcp).await?;
307            TransportStream::Tls(tls_stream)
308        } else {
309            // Use plain TCP
310            TransportStream::Plain(tcp)
311        };
312
313        // Create a framed transport with the STOMP codec
314        let mut transport = ClientCodec.framed(transport_stream);
315
316        // Perform the STOMP protocol handshake
317        client_handshake(
318            &mut transport,
319            self.virtualhost.into(),
320            self.login,
321            self.passcode,
322            self.headers,
323        )
324        .await?;
325
326        Ok(transport)
327    }
328
329    /// Create a CONNECT message without actually connecting
330    ///
331    /// This method creates a STOMP CONNECT message using the configured parameters
332    /// which can be used to establish a connection manually.
333    pub fn msg(self) -> Message<ToServer> {
334        // Convert custom headers to the binary format expected by the protocol
335        let extra_headers = self
336            .headers
337            .into_iter()
338            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
339            .collect();
340
341        // Create the CONNECT message
342        Message {
343            content: ToServer::Connect {
344                accept_version: "1.2".into(),
345                host: self.virtualhost.into(),
346                login: self.login,
347                passcode: self.passcode,
348                heartbeat: None,
349            },
350            extra_headers,
351        }
352    }
353}
354
355/// Performs the STOMP protocol handshake with the server
356///
357/// This function sends a CONNECT frame to the server and waits for
358/// a CONNECTED response. If the server responds with anything else,
359/// the handshake is considered failed.
360async fn client_handshake(
361    transport: &mut ClientTransport,
362    virtualhost: String,
363    login: Option<String>,
364    passcode: Option<String>,
365    headers: Vec<(String, String)>,
366) -> Result<()> {
367    // Convert custom headers to the binary format expected by the protocol
368    let extra_headers = headers
369        .iter()
370        .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
371        .collect();
372
373    // Create the CONNECT message
374    let connect = Message {
375        content: ToServer::Connect {
376            accept_version: "1.2".into(),
377            host: virtualhost,
378            login,
379            passcode,
380            heartbeat: None,
381        },
382        extra_headers,
383    };
384
385    // Send the message to the server
386    transport.send(connect).await?;
387
388    // Receive and process the server's reply
389    let msg = transport.next().await.transpose()?;
390
391    // Check if the reply is a CONNECTED frame
392    if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
393        Ok(())
394    } else {
395        Err(anyhow!("unexpected reply: {:?}", msg))
396    }
397}
398
/// Builder to create a Subscribe message with optional custom headers
///
/// `Subscriber::builder()` collects the subscription parameters; finishing the
/// chain with `.subscribe()` produces the SUBSCRIBE message to send to a STOMP
/// server.
///
/// # Examples
///
/// ```rust,no_run
/// use futures::prelude::*;
/// use async_stomp::client::Connector;
/// use async_stomp::client::Subscriber;
///
///
/// #[tokio::main]
/// async fn main() -> Result<(), anyhow::Error> {
///   let mut connection = Connector::builder()
///     .server("stomp.example.com:61613")
///     .virtualhost("stomp.example.com")
///     .login("guest".to_string())
///     .passcode("guest".to_string())
///     .headers(vec![("client-id".to_string(), "ClientTest".to_string())])
///     .connect()
///     .await.expect("Client connection");
///
///   let subscribe_msg = Subscriber::builder()
///     .destination("queue.test")
///     .id("custom-subscriber-id")
///     .subscribe();
///
///   connection.send(subscribe_msg).await?;
///   Ok(())
/// }
/// ```
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Subscriber<S: Into<String>, I: Into<String>> {
    /// The destination to subscribe to (e.g., queue or topic name)
    destination: S,
    /// The subscription ID used to identify this subscription
    id: I,
    /// Custom headers to be included in the SUBSCRIBE frame (empty by default)
    #[builder(default)]
    headers: Vec<(String, String)>,
}
443
444/// Implementation of the builder subscribe method to allow direct subscription creation
445#[allow(non_camel_case_types)]
446impl<S: Into<String>, I: Into<String>, __headers> SubscriberBuilder<S, I, ((S,), (I,), __headers)>
447where
448    Subscriber<S, I>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
449            (
450                &'__typed_builder_lifetime_for_default S,
451                &'__typed_builder_lifetime_for_default I,
452                __headers,
453            ),
454            Output = Vec<(String, String)>,
455        >,
456{
457    /// Creates a SUBSCRIBE message using the configured parameters
458    ///
459    /// This method finalizes the builder and returns a STOMP SUBSCRIBE message
460    /// that can be sent to a server to create a subscription.
461    pub fn subscribe(self) -> Message<ToServer> {
462        let subscriber = self.__build();
463        subscriber.subscribe()
464    }
465}
466
467impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
468    /// Creates a SUBSCRIBE message using the configured parameters
469    ///
470    /// This method returns a STOMP SUBSCRIBE message that can be sent to a server
471    /// to create a subscription with the configured destination, ID, and headers.
472    pub fn subscribe(self) -> Message<ToServer> {
473        // Create the basic Subscribe message
474        let mut msg: Message<ToServer> = ToServer::Subscribe {
475            destination: self.destination.into(),
476            id: self.id.into(),
477            ack: None,
478        }
479        .into();
480
481        // Add any custom headers
482        msg.extra_headers = self
483            .headers
484            .iter()
485            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
486            .collect();
487
488        msg
489    }
490}
491
/// Codec for encoding/decoding STOMP protocol frames for client usage
///
/// This codec converts between raw bytes and STOMP messages through its
/// `tokio_util::codec::Decoder` and `Encoder` implementations. It is a
/// stateless unit type, so it is freely copyable and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCodec;
497
498impl Decoder for ClientCodec {
499    type Item = Message<FromServer>;
500    type Error = anyhow::Error;
501
502    /// Decodes bytes from the server into STOMP messages
503    ///
504    /// This method attempts to parse a complete STOMP frame from the input buffer.
505    /// If a complete frame is available, it returns the parsed Message.
506    /// If more data is needed, it returns None.
507    /// If parsing fails, it returns an error.
508    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
509        // Create a partial view of the buffer for parsing
510        let buf = &mut Partial::new(src.chunk());
511
512        // Attempt to parse a frame from the buffer
513        let item = match frame::parse_frame(buf) {
514            Ok(frame) => Message::<FromServer>::from_frame(frame),
515            Err(ErrMode::Incomplete(_)) => return Ok(None), // Need more data
516            Err(e) => bail!("Parse failed: {:?}", e),       // Parsing error
517        };
518
519        // Calculate how many bytes were consumed
520        let len = buf.offset_from(&Partial::new(src.chunk()));
521
522        // Advance the buffer past the consumed bytes
523        src.advance(len);
524
525        // Return the parsed message (or error)
526        item.map(Some)
527    }
528}
529
530impl Encoder<Message<ToServer>> for ClientCodec {
531    type Error = anyhow::Error;
532
533    /// Encodes STOMP messages for sending to the server
534    ///
535    /// This method serializes a STOMP message into bytes to be sent over the network.
536    fn encode(
537        &mut self,
538        item: Message<ToServer>,
539        dst: &mut BytesMut,
540    ) -> std::result::Result<(), Self::Error> {
541        // Convert the message to a frame and serialize it into the buffer
542        item.to_frame().serialize(dst);
543        Ok(())
544    }
545}
546
#[cfg(test)]
mod tests {

    use crate::{
        Message, ToServer,
        client::{Connector, Subscriber},
    };
    use bytes::BytesMut;

    /// Serializes a message's frame into a fresh buffer for byte-level comparison
    fn wire_bytes(msg: Message<ToServer>) -> BytesMut {
        let mut out = BytesMut::new();
        msg.to_frame().serialize(&mut out);
        out
    }

    /// Checks the SUBSCRIBE message produced by the `Subscriber` builder
    ///
    /// The builder output must serialize to exactly the same bytes as a
    /// hand-built `ToServer::Subscribe` message with the same destination, ID,
    /// and custom headers. A failure means the builder no longer produces the
    /// expected SUBSCRIBE frame.
    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];

        let built = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();

        let mut manual: Message<ToServer> = ToServer::Subscribe {
            destination: "queue.test".to_string(),
            id: "custom-subscriber-id".to_string(),
            ack: None,
        }
        .into();
        manual.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
            .collect();

        let expected_buffer = wire_bytes(manual);
        let actual_buffer = wire_bytes(built);

        assert_eq!(expected_buffer, actual_buffer);
    }

    /// Checks the CONNECT message produced by the `Connector` builder
    ///
    /// The builder output must serialize to exactly the same bytes as a
    /// hand-built `ToServer::Connect` message with the same virtualhost, login
    /// credentials, and custom headers. A failure means the builder no longer
    /// produces the expected CONNECT frame.
    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];

        let built = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let mut manual: Message<ToServer> = ToServer::Connect {
            accept_version: "1.2".into(),
            host: "virtual.stomp.example.com".into(),
            login: Some("guest_login".to_string()),
            passcode: Some("guest_passcode".to_string()),
            heartbeat: None,
        }
        .into();
        manual.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
            .collect();

        let expected_buffer = wire_bytes(manual);
        let actual_buffer = wire_bytes(built);

        assert_eq!(expected_buffer, actual_buffer);
    }
}