// async_stomp/client.rs

1use crate::frame;
2use crate::{FromServer, Message, Result, ToServer};
3use anyhow::{anyhow, bail};
4use bytes::{Buf, BytesMut};
5use futures::prelude::*;
6use futures::sink::SinkExt;
7use rustls::pki_types::ServerName;
8use std::fmt;
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tokio::net::TcpStream;
14use tokio_rustls::TlsConnector;
15use tokio_rustls::client::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17use typed_builder::TypedBuilder;
18use winnow::Partial;
19use winnow::error::ErrMode;
20use winnow::stream::Offset;
21
22/// The primary transport type used by STOMP clients
23///
24/// This is a `Framed` instance that handles encoding and decoding of STOMP frames
25/// over either a plain TCP connection or a TLS connection.
26pub type ClientTransport = Framed<TransportStream, ClientCodec>;
27
28/// Enum representing the transport stream, which can be either a plain TCP connection or a TLS connection
29///
30/// This type abstracts over the two possible connection types to provide a uniform interface
31/// for the rest of the library. It implements AsyncRead and AsyncWrite to handle all IO operations.
32#[allow(clippy::large_enum_variant)]
33pub enum TransportStream {
34    /// A plain, unencrypted TCP connection
35    Plain(TcpStream),
36    /// A secure TLS connection over TCP
37    Tls(TlsStream<TcpStream>),
38}
39
40// Implement AsyncRead for TransportStream to allow reading data from either connection type
41impl tokio::io::AsyncRead for TransportStream {
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut tokio::io::ReadBuf<'_>,
46    ) -> Poll<std::io::Result<()>> {
47        // Delegate to the appropriate inner stream type
48        match self.get_mut() {
49            TransportStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
50            TransportStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
51        }
52    }
53}
54
55// Implement AsyncWrite for TransportStream to allow writing data to either connection type
56impl tokio::io::AsyncWrite for TransportStream {
57    fn poll_write(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &[u8],
61    ) -> Poll<std::io::Result<usize>> {
62        // Delegate to the appropriate inner stream type
63        match self.get_mut() {
64            TransportStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65            TransportStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
66        }
67    }
68
69    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
70        // Delegate to the appropriate inner stream type
71        match self.get_mut() {
72            TransportStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
73            TransportStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
74        }
75    }
76
77    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
78        // Delegate to the appropriate inner stream type
79        match self.get_mut() {
80            TransportStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
81            TransportStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
82        }
83    }
84}
85
86// Debug implementation for TransportStream that provides a human-readable representation
87impl fmt::Debug for TransportStream {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            TransportStream::Plain(_) => write!(f, "Plain TCP connection"),
91            TransportStream::Tls(_) => write!(f, "TLS connection"),
92        }
93    }
94}
95
96/// A builder for creating and establishing STOMP connections to a server
97///
98/// This struct provides a builder pattern for configuring the connection
99/// parameters and then connecting to a STOMP server.
100///
101/// # Examples
102///
103/// ```rust,no_run
104/// use async_stomp::client::Connector;
105///
106///#[tokio::main]
107/// async fn main() {
108///   let connection = Connector::builder()
109///     .server("stomp.example.com")
110///     .virtualhost("stomp.example.com")
111///     .login("guest".to_string())
112///     .passcode("guest".to_string())
113///     .connect()
114///     .await;
115///}
116/// ```
117#[derive(TypedBuilder)]
118#[builder(build_method(vis="", name=__build))]
119pub struct Connector<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> {
120    /// The address to the stomp server
121    server: S,
122    /// Virtualhost, if no specific virtualhost is desired, it is recommended
123    /// to set this to the same as the host name that the socket
124    virtualhost: V,
125    /// Username to use for optional authentication to the server
126    #[builder(default, setter(strip_option))]
127    login: Option<String>,
128    /// Passcode to use for optional authentication to the server
129    #[builder(default, setter(strip_option))]
130    passcode: Option<String>,
131    /// Custom headers to be sent to the server
132    #[builder(default)]
133    headers: Vec<(String, String)>,
134    /// Whether to use TLS for this connection
135    #[builder(default = false)]
136    use_tls: bool,
137    /// Optional server name to verify in TLS certificate (defaults to hostname from server if not specified)
138    #[builder(default, setter(strip_option))]
139    tls_server_name: Option<String>,
140}
141
142/// Implementation of the builder connect method to allow the builder to directly connect
143#[allow(non_camel_case_types)]
144impl<
145    S: tokio::net::ToSocketAddrs + Clone,
146    V: Into<String> + Clone,
147    __login: ::typed_builder::Optional<Option<String>>,
148    __passcode: ::typed_builder::Optional<Option<String>>,
149    __headers: ::typed_builder::Optional<Vec<(String, String)>>,
150    __use_tls: ::typed_builder::Optional<bool>,
151    __tls_server_name: ::typed_builder::Optional<Option<String>>,
152>
153    ConnectorBuilder<
154        S,
155        V,
156        (
157            (S,),
158            (V,),
159            __login,
160            __passcode,
161            __headers,
162            __use_tls,
163            __tls_server_name,
164        ),
165    >
166{
167    /// Connect to the STOMP server using the configured parameters
168    ///
169    /// This method finalizes the builder and attempts to establish a connection
170    /// to the STOMP server. If successful, it returns a ClientTransport that can
171    /// be used to send and receive messages.
172    pub async fn connect(self) -> Result<ClientTransport> {
173        let connector = self.__build();
174        connector.connect().await
175    }
176
177    /// Create a Message for connection without actually connecting
178    ///
179    /// This can be used when you want to handle the connection process manually
180    /// or need access to the raw connection message.
181    pub fn msg(self) -> Message<ToServer> {
182        let connector = self.__build();
183        connector.msg()
184    }
185}
186
187impl<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> Connector<S, V> {
188    /// Creates a TLS connector with default trust anchors
189    ///
190    /// This method configures a TLS connector with the system's default trust anchors
191    /// for certificate verification.
192    async fn create_tls_connector(&self) -> Result<TlsConnector> {
193        // Create a root certificate store with webpki's built-in roots
194        let root_store = rustls::RootCertStore {
195            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
196        };
197
198        // Create a TLS client configuration with the root certificates
199        let config = rustls::ClientConfig::builder()
200            .with_root_certificates(root_store)
201            .with_no_client_auth();
202
203        Ok(TlsConnector::from(Arc::new(config)))
204    }
205
206    /// Connect to the STOMP server using the configured parameters
207    ///
208    /// This method establishes a connection to the STOMP server and performs
209    /// the STOMP protocol handshake. If successful, it returns a ClientTransport
210    /// that can be used to send and receive STOMP messages.
211    pub async fn connect(self) -> Result<ClientTransport> {
212        // First establish a TCP connection to the server
213        let tcp = TcpStream::connect(self.server.clone()).await?;
214
215        // Determine whether to use plain TCP or wrap with TLS
216        let transport_stream = if self.use_tls {
217            // Extract server name for TLS verification
218            let server_name = if let Some(name) = &self.tls_server_name {
219                name.clone()
220            } else {
221                // Extract the hostname from the server address
222                let server_addr = tcp.peer_addr()?;
223                let hostname = server_addr.ip().to_string();
224                if hostname.is_empty() {
225                    return Err(anyhow!(
226                        "Could not determine server hostname for TLS verification"
227                    ));
228                }
229                hostname
230            };
231
232            // Create TLS connector
233            let tls_connector = self.create_tls_connector().await?;
234
235            // Create a copy of server_name to avoid the borrow after move issue
236            let server_name_copy = server_name.clone();
237
238            // Try to parse the server name as an IP address first
239            let dns_name = if let Ok(ip_addr) = server_name_copy.parse::<IpAddr>() {
240                // Handle IP address
241                match ip_addr {
242                    IpAddr::V4(ipv4) => ServerName::IpAddress(ipv4.into()),
243                    IpAddr::V6(ipv6) => ServerName::IpAddress(ipv6.into()),
244                }
245            } else {
246                // Handle DNS name
247                ServerName::DnsName(
248                    server_name_copy
249                        .try_into()
250                        .map_err(|_| anyhow!("Invalid DNS name: {}", server_name))?,
251                )
252            };
253
254            // Connect with TLS
255            let tls_stream = tls_connector.connect(dns_name, tcp).await?;
256            TransportStream::Tls(tls_stream)
257        } else {
258            // Use plain TCP
259            TransportStream::Plain(tcp)
260        };
261
262        // Create a framed transport with the STOMP codec
263        let mut transport = ClientCodec.framed(transport_stream);
264
265        // Perform the STOMP protocol handshake
266        client_handshake(
267            &mut transport,
268            self.virtualhost.into(),
269            self.login,
270            self.passcode,
271            self.headers,
272        )
273        .await?;
274
275        Ok(transport)
276    }
277
278    /// Create a CONNECT message without actually connecting
279    ///
280    /// This method creates a STOMP CONNECT message using the configured parameters
281    /// which can be used to establish a connection manually.
282    pub fn msg(self) -> Message<ToServer> {
283        // Convert custom headers to the binary format expected by the protocol
284        let extra_headers = self
285            .headers
286            .into_iter()
287            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
288            .collect();
289
290        // Create the CONNECT message
291        Message {
292            content: ToServer::Connect {
293                accept_version: "1.2".into(),
294                host: self.virtualhost.into(),
295                login: self.login,
296                passcode: self.passcode,
297                heartbeat: None,
298            },
299            extra_headers,
300        }
301    }
302}
303
304/// Performs the STOMP protocol handshake with the server
305///
306/// This function sends a CONNECT frame to the server and waits for
307/// a CONNECTED response. If the server responds with anything else,
308/// the handshake is considered failed.
309async fn client_handshake(
310    transport: &mut ClientTransport,
311    virtualhost: String,
312    login: Option<String>,
313    passcode: Option<String>,
314    headers: Vec<(String, String)>,
315) -> Result<()> {
316    // Convert custom headers to the binary format expected by the protocol
317    let extra_headers = headers
318        .iter()
319        .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
320        .collect();
321
322    // Create the CONNECT message
323    let connect = Message {
324        content: ToServer::Connect {
325            accept_version: "1.2".into(),
326            host: virtualhost,
327            login,
328            passcode,
329            heartbeat: None,
330        },
331        extra_headers,
332    };
333
334    // Send the message to the server
335    transport.send(connect).await?;
336
337    // Receive and process the server's reply
338    let msg = transport.next().await.transpose()?;
339
340    // Check if the reply is a CONNECTED frame
341    if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
342        Ok(())
343    } else {
344        Err(anyhow!("unexpected reply: {:?}", msg))
345    }
346}
347
348/// Builder to create a Subscribe message with optional custom headers
349///
350/// This struct provides a builder pattern for configuring subscription parameters
351/// and creating a SUBSCRIBE message to send to a STOMP server.
352///
353/// # Examples
354///
355/// ```rust,no_run
356/// use futures::prelude::*;
357/// use async_stomp::client::Connector;
358/// use async_stomp::client::Subscriber;
359///
360///
361/// #[tokio::main]
362/// async fn main() -> Result<(), anyhow::Error> {
363///   let mut connection = Connector::builder()
364///     .server("stomp.example.com")
365///     .virtualhost("stomp.example.com")
366///     .login("guest".to_string())
367///     .passcode("guest".to_string())
368///     .headers(vec![("client-id".to_string(), "ClientTest".to_string())])
369///     .connect()
370///     .await.expect("Client connection");
371///   
372///   let subscribe_msg = Subscriber::builder()
373///     .destination("queue.test")
374///     .id("custom-subscriber-id")
375///     .subscribe();
376///
377///   connection.send(subscribe_msg).await?;
378///   Ok(())
379/// }
380/// ```
381#[derive(TypedBuilder)]
382#[builder(build_method(vis="", name=__build))]
383pub struct Subscriber<S: Into<String>, I: Into<String>> {
384    /// The destination to subscribe to (e.g., queue or topic name)
385    destination: S,
386    /// The subscription ID used to identify this subscription
387    id: I,
388    /// Custom headers to be included in the SUBSCRIBE frame
389    #[builder(default)]
390    headers: Vec<(String, String)>,
391}
392
393/// Implementation of the builder subscribe method to allow direct subscription creation
394#[allow(non_camel_case_types)]
395impl<S: Into<String>, I: Into<String>, __headers: ::typed_builder::Optional<Vec<(String, String)>>>
396    SubscriberBuilder<S, I, ((S,), (I,), __headers)>
397{
398    /// Creates a SUBSCRIBE message using the configured parameters
399    ///
400    /// This method finalizes the builder and returns a STOMP SUBSCRIBE message
401    /// that can be sent to a server to create a subscription.
402    pub fn subscribe(self) -> Message<ToServer> {
403        let subscriber = self.__build();
404        subscriber.subscribe()
405    }
406}
407
408impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
409    /// Creates a SUBSCRIBE message using the configured parameters
410    ///
411    /// This method returns a STOMP SUBSCRIBE message that can be sent to a server
412    /// to create a subscription with the configured destination, ID, and headers.
413    pub fn subscribe(self) -> Message<ToServer> {
414        // Create the basic Subscribe message
415        let mut msg: Message<ToServer> = ToServer::Subscribe {
416            destination: self.destination.into(),
417            id: self.id.into(),
418            ack: None,
419        }
420        .into();
421
422        // Add any custom headers
423        msg.extra_headers = self
424            .headers
425            .iter()
426            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
427            .collect();
428
429        msg
430    }
431}
432
/// Codec for encoding/decoding STOMP protocol frames for client usage
///
/// It converts between STOMP wire frames and [`Message`] values by implementing the
/// `tokio_util::codec::Encoder` and `Decoder` traits. The codec holds no state, so it is
/// freely copyable and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCodec;
438
439impl Decoder for ClientCodec {
440    type Item = Message<FromServer>;
441    type Error = anyhow::Error;
442
443    /// Decodes bytes from the server into STOMP messages
444    ///
445    /// This method attempts to parse a complete STOMP frame from the input buffer.
446    /// If a complete frame is available, it returns the parsed Message.
447    /// If more data is needed, it returns None.
448    /// If parsing fails, it returns an error.
449    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
450        // Create a partial view of the buffer for parsing
451        let buf = &mut Partial::new(src.chunk());
452
453        // Attempt to parse a frame from the buffer
454        let item = match frame::parse_frame(buf) {
455            Ok(frame) => Message::<FromServer>::from_frame(frame),
456            Err(ErrMode::Incomplete(_)) => return Ok(None), // Need more data
457            Err(e) => bail!("Parse failed: {:?}", e),       // Parsing error
458        };
459
460        // Calculate how many bytes were consumed
461        let len = buf.offset_from(&Partial::new(src.chunk()));
462
463        // Advance the buffer past the consumed bytes
464        src.advance(len);
465
466        // Return the parsed message (or error)
467        item.map(Some)
468    }
469}
470
471impl Encoder<Message<ToServer>> for ClientCodec {
472    type Error = anyhow::Error;
473
474    /// Encodes STOMP messages for sending to the server
475    ///
476    /// This method serializes a STOMP message into bytes to be sent over the network.
477    fn encode(
478        &mut self,
479        item: Message<ToServer>,
480        dst: &mut BytesMut,
481    ) -> std::result::Result<(), Self::Error> {
482        // Convert the message to a frame and serialize it into the buffer
483        item.to_frame().serialize(dst);
484        Ok(())
485    }
486}
487
#[cfg(test)]
mod tests {
    use crate::{
        Message, ToServer,
        client::{Connector, Subscriber},
    };
    use bytes::BytesMut;

    /// Replaces `msg`'s extra headers with `headers`, converted to their byte form.
    fn with_headers(mut msg: Message<ToServer>, headers: Vec<(String, String)>) -> Message<ToServer> {
        msg.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
            .collect();
        msg
    }

    /// Serializes `msg` to the exact bytes that would be written to the wire.
    fn wire_bytes(msg: Message<ToServer>) -> BytesMut {
        let mut buffer = BytesMut::new();
        msg.to_frame().serialize(&mut buffer);
        buffer
    }

    /// The `Subscriber` builder must produce exactly the bytes of a hand-built SUBSCRIBE
    /// message with the same destination, ID and custom headers.
    ///
    /// A failure means the builder constructs SUBSCRIBE frames incorrectly, which would
    /// break or alter client subscriptions against a real STOMP server.
    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];

        let actual = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();

        let expected = with_headers(
            ToServer::Subscribe {
                destination: "queue.test".to_string(),
                id: "custom-subscriber-id".to_string(),
                ack: None,
            }
            .into(),
            headers,
        );

        assert_eq!(wire_bytes(expected), wire_bytes(actual));
    }

    /// The `Connector` builder must produce exactly the bytes of a hand-built CONNECT
    /// message with the same virtual host, credentials and custom headers.
    ///
    /// A failure means the builder constructs CONNECT frames incorrectly, which would make
    /// connections to a real STOMP server fail.
    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];

        let actual = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let expected = with_headers(
            ToServer::Connect {
                accept_version: "1.2".into(),
                host: "virtual.stomp.example.com".into(),
                login: Some("guest_login".to_string()),
                passcode: Some("guest_passcode".to_string()),
                heartbeat: None,
            }
            .into(),
            headers,
        );

        assert_eq!(wire_bytes(expected), wire_bytes(actual));
    }
}