1use crate::frame;
2use crate::{FromServer, Message, Result, ToServer};
3use anyhow::{anyhow, bail};
4use bytes::{Buf, BytesMut};
5use futures::prelude::*;
6use futures::sink::SinkExt;
7use rustls::pki_types::ServerName;
8use std::fmt;
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tokio::net::TcpStream;
14use tokio_rustls::TlsConnector;
15use tokio_rustls::client::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17use typed_builder::TypedBuilder;
18use winnow::Partial;
19use winnow::error::ErrMode;
20use winnow::stream::Offset;
21
/// A framed STOMP client connection: a [`TransportStream`] (plain TCP or TLS)
/// driven by [`ClientCodec`], which encodes `Message<ToServer>` values and
/// decodes `Message<FromServer>` values.
pub type ClientTransport = Framed<TransportStream, ClientCodec>;
27
/// The byte stream underneath a [`ClientTransport`]: either a plain TCP
/// connection or a TLS client session layered over TCP.
// NOTE(review): `Tls` wraps the same `TcpStream` plus the TLS session state, so it
// is presumably far larger than `Plain`, which is what trips
// `clippy::large_enum_variant`. Boxing it would change the public variant type,
// so the lint is silenced instead — confirm this trade-off is still intended.
#[allow(clippy::large_enum_variant)]
pub enum TransportStream {
    /// An unencrypted TCP connection.
    Plain(TcpStream),
    /// A TLS client session running over a TCP connection.
    Tls(TlsStream<TcpStream>),
}
39
40impl tokio::io::AsyncRead for TransportStream {
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut tokio::io::ReadBuf<'_>,
46 ) -> Poll<std::io::Result<()>> {
47 match self.get_mut() {
49 TransportStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
50 TransportStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
51 }
52 }
53}
54
55impl tokio::io::AsyncWrite for TransportStream {
57 fn poll_write(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &[u8],
61 ) -> Poll<std::io::Result<usize>> {
62 match self.get_mut() {
64 TransportStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65 TransportStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
66 }
67 }
68
69 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
70 match self.get_mut() {
72 TransportStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
73 TransportStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
74 }
75 }
76
77 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
78 match self.get_mut() {
80 TransportStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
81 TransportStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
82 }
83 }
84}
85
86impl fmt::Debug for TransportStream {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 TransportStream::Plain(_) => write!(f, "Plain TCP connection"),
91 TransportStream::Tls(_) => write!(f, "TLS connection"),
92 }
93 }
94}
95
/// Configuration for opening a STOMP client connection.
///
/// Create one with `Connector::builder()`. The generated builder's own build
/// method is private (renamed `__build`). Once the required `server` and
/// `virtualhost` fields are set, the builder instead exposes `connect()` to open
/// a [`ClientTransport`] and `msg()` to produce the CONNECT message without
/// opening a connection.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Connector<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> {
    /// Server address, passed to `TcpStream::connect`.
    server: S,
    /// Virtual host; becomes the `host` field of the `ToServer::Connect` message.
    virtualhost: V,
    /// Optional login; becomes the `login` field of the `ToServer::Connect` message.
    #[builder(default, setter(strip_option))]
    login: Option<String>,
    /// Optional passcode; becomes the `passcode` field of the `ToServer::Connect` message.
    #[builder(default, setter(strip_option))]
    passcode: Option<String>,
    /// Extra `(name, value)` headers attached to the CONNECT message, in order.
    #[builder(default)]
    headers: Vec<(String, String)>,
    /// Whether to wrap the TCP connection in TLS. Defaults to `false`.
    #[builder(default = false)]
    use_tls: bool,
    /// Name the server's TLS certificate is verified against (only used when
    /// `use_tls` is set). When unset, the connected peer's IP address is used.
    #[builder(default, setter(strip_option))]
    tls_server_name: Option<String>,
}
141
/// Terminal methods on the generated `ConnectorBuilder`.
///
/// This impl applies only once the required `server` and `virtualhost` fields
/// have been set (the `(S,)` and `(V,)` slots of the builder's field tuple).
/// Each optional field slot (`__login`, ..., `__tls_server_name`) may be set or
/// unset. The `NextFieldDefault` bounds resolve every slot to a value of the
/// field's type, falling back to the field's `#[builder(default ...)]` when it
/// is unset, which is what the private `__build` method needs.
// The `__field` type parameter names mirror typed_builder's generated generics
// and are not CamelCase, hence the allow.
#[allow(non_camel_case_types)]
impl<
    S: tokio::net::ToSocketAddrs + Clone,
    V: Into<String> + Clone,
    __login,
    __passcode,
    __headers,
    __use_tls,
    __tls_server_name,
>
    ConnectorBuilder<
        S,
        V,
        (
            (S,),
            (V,),
            __login,
            __passcode,
            __headers,
            __use_tls,
            __tls_server_name,
        ),
    >
where
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            __login,
        ),
        Output = Option<String>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            __passcode,
        ),
        Output = Option<String>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            __headers,
        ),
        Output = Vec<(String, String)>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Vec<(String, String)>,
            __use_tls,
        ),
        Output = bool,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Vec<(String, String)>,
            &'__typed_builder_lifetime_for_default bool,
            __tls_server_name,
        ),
        Output = Option<String>,
    >,
{
    /// Finishes the builder and opens a connection; see [`Connector::connect`].
    pub async fn connect(self) -> Result<ClientTransport> {
        let connector: Connector<S, V> = self.__build();
        connector.connect().await
    }

    /// Finishes the builder and returns the CONNECT message without connecting;
    /// see [`Connector::msg`].
    pub fn msg(self) -> Message<ToServer> {
        let connector = self.__build();
        connector.msg()
    }
}
237
238impl<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> Connector<S, V> {
239 async fn create_tls_connector(&self) -> Result<TlsConnector> {
244 let root_store = rustls::RootCertStore {
246 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
247 };
248
249 let config = rustls::ClientConfig::builder()
251 .with_root_certificates(root_store)
252 .with_no_client_auth();
253
254 Ok(TlsConnector::from(Arc::new(config)))
255 }
256
257 pub async fn connect(self) -> Result<ClientTransport> {
263 let tcp = TcpStream::connect(self.server.clone()).await?;
265
266 let transport_stream = if self.use_tls {
268 let server_name = if let Some(name) = &self.tls_server_name {
270 name.clone()
271 } else {
272 let server_addr = tcp.peer_addr()?;
274 let hostname = server_addr.ip().to_string();
275 if hostname.is_empty() {
276 return Err(anyhow!(
277 "Could not determine server hostname for TLS verification"
278 ));
279 }
280 hostname
281 };
282
283 let tls_connector = self.create_tls_connector().await?;
285
286 let server_name_copy = server_name.clone();
288
289 let dns_name = if let Ok(ip_addr) = server_name_copy.parse::<IpAddr>() {
291 match ip_addr {
293 IpAddr::V4(ipv4) => ServerName::IpAddress(ipv4.into()),
294 IpAddr::V6(ipv6) => ServerName::IpAddress(ipv6.into()),
295 }
296 } else {
297 ServerName::DnsName(
299 server_name_copy
300 .try_into()
301 .map_err(|_| anyhow!("Invalid DNS name: {}", server_name))?,
302 )
303 };
304
305 let tls_stream = tls_connector.connect(dns_name, tcp).await?;
307 TransportStream::Tls(tls_stream)
308 } else {
309 TransportStream::Plain(tcp)
311 };
312
313 let mut transport = ClientCodec.framed(transport_stream);
315
316 client_handshake(
318 &mut transport,
319 self.virtualhost.into(),
320 self.login,
321 self.passcode,
322 self.headers,
323 )
324 .await?;
325
326 Ok(transport)
327 }
328
329 pub fn msg(self) -> Message<ToServer> {
334 let extra_headers = self
336 .headers
337 .into_iter()
338 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
339 .collect();
340
341 Message {
343 content: ToServer::Connect {
344 accept_version: "1.2".into(),
345 host: self.virtualhost.into(),
346 login: self.login,
347 passcode: self.passcode,
348 heartbeat: None,
349 },
350 extra_headers,
351 }
352 }
353}
354
355async fn client_handshake(
361 transport: &mut ClientTransport,
362 virtualhost: String,
363 login: Option<String>,
364 passcode: Option<String>,
365 headers: Vec<(String, String)>,
366) -> Result<()> {
367 let extra_headers = headers
369 .iter()
370 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
371 .collect();
372
373 let connect = Message {
375 content: ToServer::Connect {
376 accept_version: "1.2".into(),
377 host: virtualhost,
378 login,
379 passcode,
380 heartbeat: None,
381 },
382 extra_headers,
383 };
384
385 transport.send(connect).await?;
387
388 let msg = transport.next().await.transpose()?;
390
391 if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
393 Ok(())
394 } else {
395 Err(anyhow!("unexpected reply: {:?}", msg))
396 }
397}
398
/// Parameters for a STOMP SUBSCRIBE request.
///
/// Create one with `Subscriber::builder()`. The generated builder's own build
/// method is private (renamed `__build`). Once `destination` and `id` are set,
/// the builder exposes `subscribe()`, which returns the SUBSCRIBE message.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Subscriber<S: Into<String>, I: Into<String>> {
    /// Destination to subscribe to (e.g. `"queue.test"`).
    destination: S,
    /// Subscription identifier; becomes the `id` field of `ToServer::Subscribe`.
    id: I,
    /// Extra `(name, value)` headers attached to the SUBSCRIBE message, in order.
    #[builder(default)]
    headers: Vec<(String, String)>,
}
443
/// Terminal method on the generated `SubscriberBuilder`.
///
/// This impl applies only once `destination` and `id` have been set (the `(S,)`
/// and `(I,)` slots of the builder's field tuple). The `headers` slot may be set
/// or unset. The `NextFieldDefault` bound resolves it to a
/// `Vec<(String, String)>`, defaulting when unset, as the private `__build`
/// method requires.
// `__headers` mirrors typed_builder's generated generic naming and is not
// CamelCase, hence the allow.
#[allow(non_camel_case_types)]
impl<S: Into<String>, I: Into<String>, __headers> SubscriberBuilder<S, I, ((S,), (I,), __headers)>
where
    Subscriber<S, I>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default I,
            __headers,
        ),
        Output = Vec<(String, String)>,
    >,
{
    /// Finishes the builder and returns the SUBSCRIBE message; see
    /// [`Subscriber::subscribe`].
    pub fn subscribe(self) -> Message<ToServer> {
        let subscriber = self.__build();
        subscriber.subscribe()
    }
}
466
467impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
468 pub fn subscribe(self) -> Message<ToServer> {
473 let mut msg: Message<ToServer> = ToServer::Subscribe {
475 destination: self.destination.into(),
476 id: self.id.into(),
477 ack: None,
478 }
479 .into();
480
481 msg.extra_headers = self
483 .headers
484 .iter()
485 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
486 .collect();
487
488 msg
489 }
490}
491
/// Client-side STOMP codec: decodes `Message<FromServer>` values and encodes
/// `Message<ToServer>` values. It is a unit struct and carries no state.
pub struct ClientCodec;
497
498impl Decoder for ClientCodec {
499 type Item = Message<FromServer>;
500 type Error = anyhow::Error;
501
502 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
509 let buf = &mut Partial::new(src.chunk());
511
512 let item = match frame::parse_frame(buf) {
514 Ok(frame) => Message::<FromServer>::from_frame(frame),
515 Err(ErrMode::Incomplete(_)) => return Ok(None), Err(e) => bail!("Parse failed: {:?}", e), };
518
519 let len = buf.offset_from(&Partial::new(src.chunk()));
521
522 src.advance(len);
524
525 item.map(Some)
527 }
528}
529
530impl Encoder<Message<ToServer>> for ClientCodec {
531 type Error = anyhow::Error;
532
533 fn encode(
537 &mut self,
538 item: Message<ToServer>,
539 dst: &mut BytesMut,
540 ) -> std::result::Result<(), Self::Error> {
541 item.to_frame().serialize(dst);
543 Ok(())
544 }
545}
546
#[cfg(test)]
mod tests {
    use crate::{
        Message, ToServer,
        client::{Connector, Subscriber},
    };
    use bytes::BytesMut;

    /// Serializes a message into a fresh buffer so two messages can be compared byte for byte.
    fn serialized(msg: Message<ToServer>) -> BytesMut {
        let mut buffer = BytesMut::new();
        msg.to_frame().serialize(&mut buffer);
        buffer
    }

    /// Converts string header pairs into the byte pairs stored in `extra_headers`.
    fn byte_headers<C>(headers: Vec<(String, String)>) -> C
    where
        C: FromIterator<(Vec<u8>, Vec<u8>)>,
    {
        headers
            .into_iter()
            .map(|(name, value)| (name.into_bytes(), value.into_bytes()))
            .collect()
    }

    /// The subscriber builder produces the same frame as a hand-built SUBSCRIBE.
    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];

        let actual = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();

        let mut expected: Message<ToServer> = ToServer::Subscribe {
            destination: "queue.test".to_string(),
            id: "custom-subscriber-id".to_string(),
            ack: None,
        }
        .into();
        expected.extra_headers = byte_headers(headers);

        assert_eq!(serialized(expected), serialized(actual));
    }

    /// The connector builder produces the same frame as a hand-built CONNECT.
    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];

        let actual = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let mut expected: Message<ToServer> = ToServer::Connect {
            accept_version: "1.2".into(),
            host: "virtual.stomp.example.com".into(),
            login: Some("guest_login".to_string()),
            passcode: Some("guest_passcode".to_string()),
            heartbeat: None,
        }
        .into();
        expected.extra_headers = byte_headers(headers);

        assert_eq!(serialized(expected), serialized(actual));
    }
}