async_stomp/client.rs
1use crate::frame;
2use crate::{FromServer, Message, Result, ToServer};
3use anyhow::{anyhow, bail};
4use bytes::{Buf, BytesMut};
5use futures::prelude::*;
6use futures::sink::SinkExt;
7use rustls::pki_types::ServerName;
8use std::fmt;
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tokio::net::TcpStream;
14use tokio_rustls::TlsConnector;
15use tokio_rustls::client::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17use typed_builder::TypedBuilder;
18use winnow::Partial;
19use winnow::error::ErrMode;
20use winnow::stream::Offset;
21
/// The primary transport type used by STOMP clients
///
/// This is a `Framed` instance that handles encoding and decoding of STOMP frames
/// (using [`ClientCodec`]) over either a plain TCP connection or a TLS connection
/// (see [`TransportStream`]).
pub type ClientTransport = Framed<TransportStream, ClientCodec>;

/// Enum representing the transport stream, which can be either a plain TCP connection or a TLS connection
///
/// This type abstracts over the two possible connection types to provide a uniform interface
/// for the rest of the library. It implements AsyncRead and AsyncWrite by delegating every
/// call to whichever stream the variant holds.
// clippy flags the size difference between the two variants (the TLS stream is the larger
// one — presumably because it carries the TLS session state; confirm with `size_of`).
// The lint is silenced rather than boxing the TLS variant, because boxing would change
// the public type of `TransportStream::Tls`.
#[allow(clippy::large_enum_variant)]
pub enum TransportStream {
    /// A plain, unencrypted TCP connection
    Plain(TcpStream),
    /// A secure TLS connection over TCP
    Tls(TlsStream<TcpStream>),
}
38}
39
40// Implement AsyncRead for TransportStream to allow reading data from either connection type
41impl tokio::io::AsyncRead for TransportStream {
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut tokio::io::ReadBuf<'_>,
46 ) -> Poll<std::io::Result<()>> {
47 // Delegate to the appropriate inner stream type
48 match self.get_mut() {
49 TransportStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
50 TransportStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
51 }
52 }
53}
54
55// Implement AsyncWrite for TransportStream to allow writing data to either connection type
56impl tokio::io::AsyncWrite for TransportStream {
57 fn poll_write(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &[u8],
61 ) -> Poll<std::io::Result<usize>> {
62 // Delegate to the appropriate inner stream type
63 match self.get_mut() {
64 TransportStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65 TransportStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
66 }
67 }
68
69 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
70 // Delegate to the appropriate inner stream type
71 match self.get_mut() {
72 TransportStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
73 TransportStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
74 }
75 }
76
77 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
78 // Delegate to the appropriate inner stream type
79 match self.get_mut() {
80 TransportStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
81 TransportStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
82 }
83 }
84}
85
86// Debug implementation for TransportStream that provides a human-readable representation
87impl fmt::Debug for TransportStream {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 TransportStream::Plain(_) => write!(f, "Plain TCP connection"),
91 TransportStream::Tls(_) => write!(f, "TLS connection"),
92 }
93 }
94}
95
96/// A builder for creating and establishing STOMP connections to a server
97///
98/// This struct provides a builder pattern for configuring the connection
99/// parameters and then connecting to a STOMP server.
100///
101/// # Examples
102///
103/// ```rust,no_run
104/// use async_stomp::client::Connector;
105///
106///#[tokio::main]
107/// async fn main() {
108/// let connection = Connector::builder()
109/// .server("stomp.example.com")
110/// .virtualhost("stomp.example.com")
111/// .login("guest".to_string())
112/// .passcode("guest".to_string())
113/// .connect()
114/// .await;
115///}
116/// ```
117#[derive(TypedBuilder)]
118#[builder(build_method(vis="", name=__build))]
119pub struct Connector<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> {
120 /// The address to the stomp server
121 server: S,
122 /// Virtualhost, if no specific virtualhost is desired, it is recommended
123 /// to set this to the same as the host name that the socket
124 virtualhost: V,
125 /// Username to use for optional authentication to the server
126 #[builder(default, setter(strip_option))]
127 login: Option<String>,
128 /// Passcode to use for optional authentication to the server
129 #[builder(default, setter(strip_option))]
130 passcode: Option<String>,
131 /// Custom headers to be sent to the server
132 #[builder(default)]
133 headers: Vec<(String, String)>,
134 /// Whether to use TLS for this connection
135 #[builder(default = false)]
136 use_tls: bool,
137 /// Optional server name to verify in TLS certificate (defaults to hostname from server if not specified)
138 #[builder(default, setter(strip_option))]
139 tls_server_name: Option<String>,
140}
141
142/// Implementation of the builder connect method to allow the builder to directly connect
143#[allow(non_camel_case_types)]
144impl<
145 S: tokio::net::ToSocketAddrs + Clone,
146 V: Into<String> + Clone,
147 __login: ::typed_builder::Optional<Option<String>>,
148 __passcode: ::typed_builder::Optional<Option<String>>,
149 __headers: ::typed_builder::Optional<Vec<(String, String)>>,
150 __use_tls: ::typed_builder::Optional<bool>,
151 __tls_server_name: ::typed_builder::Optional<Option<String>>,
152>
153 ConnectorBuilder<
154 S,
155 V,
156 (
157 (S,),
158 (V,),
159 __login,
160 __passcode,
161 __headers,
162 __use_tls,
163 __tls_server_name,
164 ),
165 >
166{
167 /// Connect to the STOMP server using the configured parameters
168 ///
169 /// This method finalizes the builder and attempts to establish a connection
170 /// to the STOMP server. If successful, it returns a ClientTransport that can
171 /// be used to send and receive messages.
172 pub async fn connect(self) -> Result<ClientTransport> {
173 let connector = self.__build();
174 connector.connect().await
175 }
176
177 /// Create a Message for connection without actually connecting
178 ///
179 /// This can be used when you want to handle the connection process manually
180 /// or need access to the raw connection message.
181 pub fn msg(self) -> Message<ToServer> {
182 let connector = self.__build();
183 connector.msg()
184 }
185}
186
187impl<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> Connector<S, V> {
188 /// Creates a TLS connector with default trust anchors
189 ///
190 /// This method configures a TLS connector with the system's default trust anchors
191 /// for certificate verification.
192 async fn create_tls_connector(&self) -> Result<TlsConnector> {
193 // Create a root certificate store with webpki's built-in roots
194 let root_store = rustls::RootCertStore {
195 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
196 };
197
198 // Create a TLS client configuration with the root certificates
199 let config = rustls::ClientConfig::builder()
200 .with_root_certificates(root_store)
201 .with_no_client_auth();
202
203 Ok(TlsConnector::from(Arc::new(config)))
204 }
205
206 /// Connect to the STOMP server using the configured parameters
207 ///
208 /// This method establishes a connection to the STOMP server and performs
209 /// the STOMP protocol handshake. If successful, it returns a ClientTransport
210 /// that can be used to send and receive STOMP messages.
211 pub async fn connect(self) -> Result<ClientTransport> {
212 // First establish a TCP connection to the server
213 let tcp = TcpStream::connect(self.server.clone()).await?;
214
215 // Determine whether to use plain TCP or wrap with TLS
216 let transport_stream = if self.use_tls {
217 // Extract server name for TLS verification
218 let server_name = if let Some(name) = &self.tls_server_name {
219 name.clone()
220 } else {
221 // Extract the hostname from the server address
222 let server_addr = tcp.peer_addr()?;
223 let hostname = server_addr.ip().to_string();
224 if hostname.is_empty() {
225 return Err(anyhow!(
226 "Could not determine server hostname for TLS verification"
227 ));
228 }
229 hostname
230 };
231
232 // Create TLS connector
233 let tls_connector = self.create_tls_connector().await?;
234
235 // Create a copy of server_name to avoid the borrow after move issue
236 let server_name_copy = server_name.clone();
237
238 // Try to parse the server name as an IP address first
239 let dns_name = if let Ok(ip_addr) = server_name_copy.parse::<IpAddr>() {
240 // Handle IP address
241 match ip_addr {
242 IpAddr::V4(ipv4) => ServerName::IpAddress(ipv4.into()),
243 IpAddr::V6(ipv6) => ServerName::IpAddress(ipv6.into()),
244 }
245 } else {
246 // Handle DNS name
247 ServerName::DnsName(
248 server_name_copy
249 .try_into()
250 .map_err(|_| anyhow!("Invalid DNS name: {}", server_name))?,
251 )
252 };
253
254 // Connect with TLS
255 let tls_stream = tls_connector.connect(dns_name, tcp).await?;
256 TransportStream::Tls(tls_stream)
257 } else {
258 // Use plain TCP
259 TransportStream::Plain(tcp)
260 };
261
262 // Create a framed transport with the STOMP codec
263 let mut transport = ClientCodec.framed(transport_stream);
264
265 // Perform the STOMP protocol handshake
266 client_handshake(
267 &mut transport,
268 self.virtualhost.into(),
269 self.login,
270 self.passcode,
271 self.headers,
272 )
273 .await?;
274
275 Ok(transport)
276 }
277
278 /// Create a CONNECT message without actually connecting
279 ///
280 /// This method creates a STOMP CONNECT message using the configured parameters
281 /// which can be used to establish a connection manually.
282 pub fn msg(self) -> Message<ToServer> {
283 // Convert custom headers to the binary format expected by the protocol
284 let extra_headers = self
285 .headers
286 .into_iter()
287 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
288 .collect();
289
290 // Create the CONNECT message
291 Message {
292 content: ToServer::Connect {
293 accept_version: "1.2".into(),
294 host: self.virtualhost.into(),
295 login: self.login,
296 passcode: self.passcode,
297 heartbeat: None,
298 },
299 extra_headers,
300 }
301 }
302}
303
304/// Performs the STOMP protocol handshake with the server
305///
306/// This function sends a CONNECT frame to the server and waits for
307/// a CONNECTED response. If the server responds with anything else,
308/// the handshake is considered failed.
309async fn client_handshake(
310 transport: &mut ClientTransport,
311 virtualhost: String,
312 login: Option<String>,
313 passcode: Option<String>,
314 headers: Vec<(String, String)>,
315) -> Result<()> {
316 // Convert custom headers to the binary format expected by the protocol
317 let extra_headers = headers
318 .iter()
319 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
320 .collect();
321
322 // Create the CONNECT message
323 let connect = Message {
324 content: ToServer::Connect {
325 accept_version: "1.2".into(),
326 host: virtualhost,
327 login,
328 passcode,
329 heartbeat: None,
330 },
331 extra_headers,
332 };
333
334 // Send the message to the server
335 transport.send(connect).await?;
336
337 // Receive and process the server's reply
338 let msg = transport.next().await.transpose()?;
339
340 // Check if the reply is a CONNECTED frame
341 if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
342 Ok(())
343 } else {
344 Err(anyhow!("unexpected reply: {:?}", msg))
345 }
346}
347
348/// Builder to create a Subscribe message with optional custom headers
349///
350/// This struct provides a builder pattern for configuring subscription parameters
351/// and creating a SUBSCRIBE message to send to a STOMP server.
352///
353/// # Examples
354///
355/// ```rust,no_run
356/// use futures::prelude::*;
357/// use async_stomp::client::Connector;
358/// use async_stomp::client::Subscriber;
359///
360///
361/// #[tokio::main]
362/// async fn main() -> Result<(), anyhow::Error> {
363/// let mut connection = Connector::builder()
364/// .server("stomp.example.com")
365/// .virtualhost("stomp.example.com")
366/// .login("guest".to_string())
367/// .passcode("guest".to_string())
368/// .headers(vec![("client-id".to_string(), "ClientTest".to_string())])
369/// .connect()
370/// .await.expect("Client connection");
371///
372/// let subscribe_msg = Subscriber::builder()
373/// .destination("queue.test")
374/// .id("custom-subscriber-id")
375/// .subscribe();
376///
377/// connection.send(subscribe_msg).await?;
378/// Ok(())
379/// }
380/// ```
381#[derive(TypedBuilder)]
382#[builder(build_method(vis="", name=__build))]
383pub struct Subscriber<S: Into<String>, I: Into<String>> {
384 /// The destination to subscribe to (e.g., queue or topic name)
385 destination: S,
386 /// The subscription ID used to identify this subscription
387 id: I,
388 /// Custom headers to be included in the SUBSCRIBE frame
389 #[builder(default)]
390 headers: Vec<(String, String)>,
391}
392
393/// Implementation of the builder subscribe method to allow direct subscription creation
394#[allow(non_camel_case_types)]
395impl<S: Into<String>, I: Into<String>, __headers: ::typed_builder::Optional<Vec<(String, String)>>>
396 SubscriberBuilder<S, I, ((S,), (I,), __headers)>
397{
398 /// Creates a SUBSCRIBE message using the configured parameters
399 ///
400 /// This method finalizes the builder and returns a STOMP SUBSCRIBE message
401 /// that can be sent to a server to create a subscription.
402 pub fn subscribe(self) -> Message<ToServer> {
403 let subscriber = self.__build();
404 subscriber.subscribe()
405 }
406}
407
408impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
409 /// Creates a SUBSCRIBE message using the configured parameters
410 ///
411 /// This method returns a STOMP SUBSCRIBE message that can be sent to a server
412 /// to create a subscription with the configured destination, ID, and headers.
413 pub fn subscribe(self) -> Message<ToServer> {
414 // Create the basic Subscribe message
415 let mut msg: Message<ToServer> = ToServer::Subscribe {
416 destination: self.destination.into(),
417 id: self.id.into(),
418 ack: None,
419 }
420 .into();
421
422 // Add any custom headers
423 msg.extra_headers = self
424 .headers
425 .iter()
426 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
427 .collect();
428
429 msg
430 }
431}
432
433/// Codec for encoding/decoding STOMP protocol frames for client usage
434///
435/// This codec handles the conversion between STOMP protocol frames and Rust types,
436/// implementing the tokio_util::codec::Encoder and Decoder traits.
437pub struct ClientCodec;
438
439impl Decoder for ClientCodec {
440 type Item = Message<FromServer>;
441 type Error = anyhow::Error;
442
443 /// Decodes bytes from the server into STOMP messages
444 ///
445 /// This method attempts to parse a complete STOMP frame from the input buffer.
446 /// If a complete frame is available, it returns the parsed Message.
447 /// If more data is needed, it returns None.
448 /// If parsing fails, it returns an error.
449 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
450 // Create a partial view of the buffer for parsing
451 let buf = &mut Partial::new(src.chunk());
452
453 // Attempt to parse a frame from the buffer
454 let item = match frame::parse_frame(buf) {
455 Ok(frame) => Message::<FromServer>::from_frame(frame),
456 Err(ErrMode::Incomplete(_)) => return Ok(None), // Need more data
457 Err(e) => bail!("Parse failed: {:?}", e), // Parsing error
458 };
459
460 // Calculate how many bytes were consumed
461 let len = buf.offset_from(&Partial::new(src.chunk()));
462
463 // Advance the buffer past the consumed bytes
464 src.advance(len);
465
466 // Return the parsed message (or error)
467 item.map(Some)
468 }
469}
470
471impl Encoder<Message<ToServer>> for ClientCodec {
472 type Error = anyhow::Error;
473
474 /// Encodes STOMP messages for sending to the server
475 ///
476 /// This method serializes a STOMP message into bytes to be sent over the network.
477 fn encode(
478 &mut self,
479 item: Message<ToServer>,
480 dst: &mut BytesMut,
481 ) -> std::result::Result<(), Self::Error> {
482 // Convert the message to a frame and serialize it into the buffer
483 item.to_frame().serialize(dst);
484 Ok(())
485 }
486}
487
488#[cfg(test)]
489mod tests {
490
491 use crate::{
492 Message, ToServer,
493 client::{Connector, Subscriber},
494 };
495 use bytes::BytesMut;
496
497 /// Tests the creation of a STOMP subscription message
498 ///
499 /// This test validates that a subscription message created using the Subscriber builder
500 /// contains the correct destination, ID, and custom headers. It verifies that the
501 /// subscription message serializes to the same byte sequence as a manually constructed
502 /// equivalent message.
503 ///
504 /// If this test fails, it means the Subscriber builder is not correctly constructing
505 /// STOMP SUBSCRIBE frames according to the protocol specification, which would cause
506 /// client subscriptions to fail or behave incorrectly when connecting to a STOMP server.
507 #[test]
508 fn subscription_message() {
509 let headers = vec![(
510 "activemq.subscriptionName".to_string(),
511 "ClientTest".to_string(),
512 )];
513 let subscribe_msg = Subscriber::builder()
514 .destination("queue.test")
515 .id("custom-subscriber-id")
516 .headers(headers.clone())
517 .subscribe();
518 let mut expected: Message<ToServer> = ToServer::Subscribe {
519 destination: "queue.test".to_string(),
520 id: "custom-subscriber-id".to_string(),
521 ack: None,
522 }
523 .into();
524 expected.extra_headers = headers
525 .into_iter()
526 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
527 .collect();
528
529 let mut expected_buffer = BytesMut::new();
530 expected.to_frame().serialize(&mut expected_buffer);
531 let mut actual_buffer = BytesMut::new();
532 subscribe_msg.to_frame().serialize(&mut actual_buffer);
533
534 assert_eq!(expected_buffer, actual_buffer);
535 }
536
537 /// Tests the creation of a STOMP connection message
538 ///
539 /// This test validates that a connection message created using the Connector builder
540 /// contains the correct server, virtualhost, login credentials, and custom headers.
541 /// It verifies that the connection message serializes to the same byte sequence as
542 /// a manually constructed equivalent message.
543 ///
544 /// If this test fails, it means the Connector builder is not correctly constructing
545 /// STOMP CONNECT frames according to the protocol specification, which would cause
546 /// client connections to fail when attempting to connect to a STOMP server.
547 #[test]
548 fn connection_message() {
549 let headers = vec![("client-id".to_string(), "ClientTest".to_string())];
550 let connect_msg = Connector::builder()
551 .server("stomp.example.com")
552 .virtualhost("virtual.stomp.example.com")
553 .login("guest_login".to_string())
554 .passcode("guest_passcode".to_string())
555 .headers(headers.clone())
556 .msg();
557
558 let mut expected: Message<ToServer> = ToServer::Connect {
559 accept_version: "1.2".into(),
560 host: "virtual.stomp.example.com".into(),
561 login: Some("guest_login".to_string()),
562 passcode: Some("guest_passcode".to_string()),
563 heartbeat: None,
564 }
565 .into();
566 expected.extra_headers = headers
567 .into_iter()
568 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
569 .collect();
570
571 let mut expected_buffer = BytesMut::new();
572 expected.to_frame().serialize(&mut expected_buffer);
573 let mut actual_buffer = BytesMut::new();
574 connect_msg.to_frame().serialize(&mut actual_buffer);
575
576 assert_eq!(expected_buffer, actual_buffer);
577 }
578}