async_stomp/
client.rs

1use bytes::{Buf, BytesMut};
2use futures::prelude::*;
3use futures::sink::SinkExt;
4
5use tokio::net::TcpStream;
6use tokio_util::codec::{Decoder, Encoder, Framed};
7use typed_builder::TypedBuilder;
8use winnow::error::ErrMode;
9use winnow::stream::Offset;
10use winnow::Partial;
11
12pub type ClientTransport = Framed<TcpStream, ClientCodec>;
13
14use crate::frame;
15use crate::{FromServer, Message, Result, ToServer};
16use anyhow::{anyhow, bail};
17
18/// Create a connection to a STOMP server via TCP, including the connection handshake.
19/// If successful, returns a tuple of a message stream and a sender,
20/// which may be used to receive and send messages respectively.
21///
22/// `virtualhost` If no specific virtualhost is desired, it is recommended
23/// to set this to the same as the host name that the socket
24/// was established against (i.e, the same as the server address).
25///
26/// # Examples
27///
28/// ```rust,no_run
29/// use async_stomp::client::Connector;
30///
31///#[tokio::main]
32/// async fn main() {
33///   let connection = Connector::builder()
34///     .server("stomp.example.com")
35///     .virtualhost("stomp.example.com")
36///     .login("guest".to_string())
37///     .passcode("guest".to_string())
38///     .connect()
39///     .await;
40///}
41/// ```
42#[derive(TypedBuilder)]
43#[builder(build_method(vis="", name=__build))]
44pub struct Connector<S: tokio::net::ToSocketAddrs, V: Into<String>> {
45    /// The address to the stomp server
46    server: S,
47    /// Virtualhost, if no specific virtualhost is desired, it is recommended
48    /// to set this to the same as the host name that the socket
49    virtualhost: V,
50    /// Username to use for optional authentication to the server
51    #[builder(default, setter(strip_option))]
52    login: Option<String>,
53    /// Passcode to use for optional authentication to the server
54    #[builder(default, setter(strip_option))]
55    passcode: Option<String>,
56    /// Custom headers to be sent to the server
57    #[builder(default)]
58    headers: Vec<(String, String)>,
59}
60
61#[allow(non_camel_case_types)]
62impl<
63        S: tokio::net::ToSocketAddrs,
64        V: Into<String>,
65        __login: ::typed_builder::Optional<Option<String>>,
66        __passcode: ::typed_builder::Optional<Option<String>>,
67        __headers: ::typed_builder::Optional<Vec<(String, String)>>,
68    > ConnectorBuilder<S, V, ((S,), (V,), __login, __passcode, __headers)>
69{
70    pub async fn connect(self) -> Result<ClientTransport> {
71        let connector = self.__build();
72        connector.connect().await
73    }
74
75    pub fn msg(self) -> Message<ToServer> {
76        let connector = self.__build();
77        connector.msg()
78    }
79}
80
81impl<S: tokio::net::ToSocketAddrs, V: Into<String>> Connector<S, V> {
82    pub async fn connect(self) -> Result<ClientTransport> {
83        let tcp = TcpStream::connect(self.server).await?;
84        let mut transport = ClientCodec.framed(tcp);
85        client_handshake(
86            &mut transport,
87            self.virtualhost.into(),
88            self.login,
89            self.passcode,
90            self.headers,
91        )
92        .await?;
93        Ok(transport)
94    }
95
96    pub fn msg(self) -> Message<ToServer> {
97        let extra_headers = self
98            .headers
99            .into_iter()
100            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
101            .collect();
102        Message {
103            content: ToServer::Connect {
104                accept_version: "1.2".into(),
105                host: self.virtualhost.into(),
106                login: self.login,
107                passcode: self.passcode,
108                heartbeat: None,
109            },
110            extra_headers,
111        }
112    }
113}
114
115async fn client_handshake(
116    transport: &mut ClientTransport,
117    virtualhost: String,
118    login: Option<String>,
119    passcode: Option<String>,
120    headers: Vec<(String, String)>,
121) -> Result<()> {
122    let extra_headers = headers
123        .iter()
124        .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
125        .collect();
126    let connect = Message {
127        content: ToServer::Connect {
128            accept_version: "1.2".into(),
129            host: virtualhost,
130            login,
131            passcode,
132            heartbeat: None,
133        },
134        extra_headers,
135    };
136    // Send the message
137    transport.send(connect).await?;
138    // Receive reply
139    let msg = transport.next().await.transpose()?;
140    if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
141        Ok(())
142    } else {
143        Err(anyhow!("unexpected reply: {:?}", msg))
144    }
145}
146
147/// Builder to create a Subscribe message with optional custom headers
148///
149/// # Examples
150///
151/// ```rust,no_run
152/// use futures::prelude::*;
153/// use async_stomp::client::Connector;
154/// use async_stomp::client::Subscriber;
155///
156///
157/// #[tokio::main]
158/// async fn main() -> Result<(), anyhow::Error> {
159///   let mut connection = Connector::builder()
160///     .server("stomp.example.com")
161///     .virtualhost("stomp.example.com")
162///     .login("guest".to_string())
163///     .passcode("guest".to_string())
164///     .headers(vec![("client-id".to_string(), "ClientTest".to_string())])
165///     .connect()
166///     .await.expect("Client connection");
167///   
168///   let subscribe_msg = Subscriber::builder()
169///     .destination("queue.test")
170///     .id("custom-subscriber-id")
171///     .subscribe();
172///
173///   connection.send(subscribe_msg).await?;
174///   Ok(())
175/// }
176/// ```
177#[derive(TypedBuilder)]
178#[builder(build_method(vis="", name=__build))]
179pub struct Subscriber<S: Into<String>, I: Into<String>> {
180    destination: S,
181    id: I,
182    #[builder(default)]
183    headers: Vec<(String, String)>,
184}
185
186#[allow(non_camel_case_types)]
187impl<
188        S: Into<String>,
189        I: Into<String>,
190        __headers: ::typed_builder::Optional<Vec<(String, String)>>,
191    > SubscriberBuilder<S, I, ((S,), (I,), __headers)>
192{
193    pub fn subscribe(self) -> Message<ToServer> {
194        let subscriber = self.__build();
195        subscriber.subscribe()
196    }
197}
198
199impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
200    pub fn subscribe(self) -> Message<ToServer> {
201        let mut msg: Message<ToServer> = ToServer::Subscribe {
202            destination: self.destination.into(),
203            id: self.id.into(),
204            ack: None,
205        }
206        .into();
207        msg.extra_headers = self
208            .headers
209            .iter()
210            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
211            .collect();
212        msg
213    }
214}
215
/// Client-side codec translating between raw bytes and STOMP messages.
///
/// Decodes incoming bytes into [`Message<FromServer>`](Message) and encodes
/// [`Message<ToServer>`](Message) into outgoing bytes. It holds no state, so values
/// can be freely copied or created with `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCodec;
217
218impl Decoder for ClientCodec {
219    type Item = Message<FromServer>;
220    type Error = anyhow::Error;
221
222    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
223        let buf = &mut Partial::new(src.chunk());
224        let item = match frame::parse_frame(buf) {
225            Ok(frame) => Message::<FromServer>::from_frame(frame),
226            Err(ErrMode::Incomplete(_)) => return Ok(None),
227            Err(e) => bail!("Parse failed: {:?}", e),
228        };
229        let len = buf.offset_from(&Partial::new(src.chunk()));
230        src.advance(len);
231        item.map(Some)
232    }
233}
234
235impl Encoder<Message<ToServer>> for ClientCodec {
236    type Error = anyhow::Error;
237
238    fn encode(
239        &mut self,
240        item: Message<ToServer>,
241        dst: &mut BytesMut,
242    ) -> std::result::Result<(), Self::Error> {
243        item.to_frame().serialize(dst);
244        Ok(())
245    }
246}
247
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bytes::BytesMut;
    use futures::{future::ok, SinkExt, StreamExt, TryStreamExt};

    use crate::{
        client::{Connector, Subscriber},
        FromServer, Message, ToServer,
    };

    /// Print a frame received from the server: the body (as lossy UTF-8) for
    /// `MESSAGE` frames, the debug representation for anything else.
    fn print_incoming(item: Message<FromServer>) {
        if let FromServer::Message { body, .. } = item.content {
            // The body is optional; print a missing one as empty instead of panicking.
            println!(
                "Message received: {:?}",
                String::from_utf8_lossy(body.as_deref().unwrap_or_default())
            );
        } else {
            println!("{:?}", item);
        }
    }

    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];
        let subscribe_msg = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();
        let mut expected: Message<ToServer> = ToServer::Subscribe {
            destination: "queue.test".to_string(),
            id: "custom-subscriber-id".to_string(),
            ack: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
            .collect();

        let mut expected_buffer = BytesMut::new();
        expected.to_frame().serialize(&mut expected_buffer);
        let mut actual_buffer = BytesMut::new();
        subscribe_msg.to_frame().serialize(&mut actual_buffer);

        assert_eq!(expected_buffer, actual_buffer);
    }

    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];
        let connect_msg = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let mut expected: Message<ToServer> = ToServer::Connect {
            accept_version: "1.2".into(),
            host: "virtual.stomp.example.com".into(),
            login: Some("guest_login".to_string()),
            passcode: Some("guest_passcode".to_string()),
            heartbeat: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
            .collect();

        let mut expected_buffer = BytesMut::new();
        expected.to_frame().serialize(&mut expected_buffer);
        let mut actual_buffer = BytesMut::new();
        connect_msg.to_frame().serialize(&mut actual_buffer);

        assert_eq!(expected_buffer, actual_buffer);
    }

    // Send two messages to a live broker.
    #[tokio::test]
    #[ignore = "requires a STOMP broker listening on localhost:61613"]
    async fn test_session() {
        let mut conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        let msg = crate::Message {
            content: ToServer::Send {
                destination: "/test/a".to_string(),
                transaction: None,
                headers: Some(vec![("header-a".to_string(), "value-a".to_string())]),
                body: Some("This is a test message".as_bytes().to_vec()),
            },
            extra_headers: vec![],
        };
        conn.send(msg).await.expect("Send a");
        let msg = crate::Message {
            content: ToServer::Send {
                destination: "/test/b".to_string(),
                transaction: None,
                headers: Some(vec![("header-b".to_string(), "value-b".to_string())]),
                body: Some("This is another test message".as_bytes().to_vec()),
            },
            extra_headers: vec![],
        };
        conn.send(msg).await.expect("Send b");
    }

    // Subscribe on a live broker and print every frame received.
    #[tokio::test]
    #[ignore = "requires a STOMP broker listening on localhost:61613"]
    async fn test_subscribe() {
        let sub_msg = Subscriber::builder()
            .destination("/test/a")
            .id("tjo")
            .subscribe();

        let mut conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        conn.send(sub_msg).await.expect("Send subscribe");
        let (_sink, stream) = conn.split();

        let mut cnt = 0;
        let _ = stream
            .try_for_each(|item| {
                println!("==== {cnt}");
                cnt += 1;
                print_incoming(item);
                ok(())
            })
            .await;
    }

    // Subscribe, send a message to ourselves, unsubscribe and disconnect on a live broker.
    #[tokio::test]
    #[ignore = "requires a STOMP broker listening on localhost:61613"]
    async fn test_send_subscribe() {
        let conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        tokio::time::sleep(Duration::from_millis(200)).await;

        let (mut sink, stream) = conn.split();

        let fut1 = async move {
            let subscribe = Subscriber::builder()
                .destination("rusty")
                .id("myid")
                .subscribe();

            sink.send(subscribe).await?;
            println!("Subscribe sent");

            tokio::time::sleep(Duration::from_millis(200)).await;

            sink.send(
                ToServer::Send {
                    destination: "rusty".into(),
                    transaction: None,
                    headers: None,
                    body: Some(b"Hello there rustaceans!".to_vec()),
                }
                .into(),
            )
            .await?;
            println!("Message sent");

            tokio::time::sleep(Duration::from_millis(200)).await;

            sink.send(ToServer::Unsubscribe { id: "myid".into() }.into())
                .await?;
            println!("Unsubscribe sent");

            // Pause before disconnecting (same total delay as the former 200 ms + 1 s pair).
            tokio::time::sleep(Duration::from_millis(1200)).await;
            sink.send(ToServer::Disconnect { receipt: None }.into())
                .await?;
            println!("Disconnect sent");

            Ok(())
        };

        // Print incoming frames while `fut1` drives the sending side. Both futures are
        // polled on this task by `select`, which finishes as soon as either completes:
        // normally `fut1`, right after DISCONNECT is sent, or `fut2` if the stream ends
        // or yields an error first.
        let fut2 = stream.try_for_each(|item| {
            print_incoming(item);
            ok(())
        });

        futures::future::select(Box::pin(fut1), Box::pin(fut2))
            .await
            .factor_first()
            .0
            .expect("Select");
    }
}