1use bytes::{Buf, BytesMut};
2use futures::prelude::*;
3use futures::sink::SinkExt;
4
5use tokio::net::TcpStream;
6use tokio_util::codec::{Decoder, Encoder, Framed};
7use typed_builder::TypedBuilder;
8use winnow::error::ErrMode;
9use winnow::stream::Offset;
10use winnow::Partial;
11
/// A TCP connection to a STOMP server framed with [`ClientCodec`].
///
/// Through the codec it is a `Sink<Message<ToServer>>` for outgoing frames and a
/// `Stream` of decoded `Message<FromServer>` results for incoming frames.
pub type ClientTransport = Framed<TcpStream, ClientCodec>;
13
14use crate::frame;
15use crate::{FromServer, Message, Result, ToServer};
16use anyhow::{anyhow, bail};
17
/// Settings for opening a STOMP connection.
///
/// Created with `Connector::builder()`. The generated build method is private
/// (`__build`), so a populated builder is finished by calling its `connect` or `msg`
/// method rather than `build`.
// NOTE: the declaration order of the fields below fixes the typed_builder generic
// tuple `((S,), (V,), __login, __passcode, __headers)` that the hand-written
// `ConnectorBuilder` impl names explicitly; reordering fields breaks that impl.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Connector<S: tokio::net::ToSocketAddrs, V: Into<String>> {
    /// Server address, passed to `TcpStream::connect`.
    server: S,
    /// Virtual host, used as the `host` of the CONNECT message.
    virtualhost: V,
    /// Optional login passed through to the CONNECT message.
    #[builder(default, setter(strip_option))]
    login: Option<String>,
    /// Optional passcode passed through to the CONNECT message.
    #[builder(default, setter(strip_option))]
    passcode: Option<String>,
    /// Extra `(name, value)` headers added to the CONNECT message.
    #[builder(default)]
    headers: Vec<(String, String)>,
}
60
61#[allow(non_camel_case_types)]
62impl<
63 S: tokio::net::ToSocketAddrs,
64 V: Into<String>,
65 __login: ::typed_builder::Optional<Option<String>>,
66 __passcode: ::typed_builder::Optional<Option<String>>,
67 __headers: ::typed_builder::Optional<Vec<(String, String)>>,
68 > ConnectorBuilder<S, V, ((S,), (V,), __login, __passcode, __headers)>
69{
70 pub async fn connect(self) -> Result<ClientTransport> {
71 let connector = self.__build();
72 connector.connect().await
73 }
74
75 pub fn msg(self) -> Message<ToServer> {
76 let connector = self.__build();
77 connector.msg()
78 }
79}
80
81impl<S: tokio::net::ToSocketAddrs, V: Into<String>> Connector<S, V> {
82 pub async fn connect(self) -> Result<ClientTransport> {
83 let tcp = TcpStream::connect(self.server).await?;
84 let mut transport = ClientCodec.framed(tcp);
85 client_handshake(
86 &mut transport,
87 self.virtualhost.into(),
88 self.login,
89 self.passcode,
90 self.headers,
91 )
92 .await?;
93 Ok(transport)
94 }
95
96 pub fn msg(self) -> Message<ToServer> {
97 let extra_headers = self
98 .headers
99 .into_iter()
100 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
101 .collect();
102 Message {
103 content: ToServer::Connect {
104 accept_version: "1.2".into(),
105 host: self.virtualhost.into(),
106 login: self.login,
107 passcode: self.passcode,
108 heartbeat: None,
109 },
110 extra_headers,
111 }
112 }
113}
114
115async fn client_handshake(
116 transport: &mut ClientTransport,
117 virtualhost: String,
118 login: Option<String>,
119 passcode: Option<String>,
120 headers: Vec<(String, String)>,
121) -> Result<()> {
122 let extra_headers = headers
123 .iter()
124 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
125 .collect();
126 let connect = Message {
127 content: ToServer::Connect {
128 accept_version: "1.2".into(),
129 host: virtualhost,
130 login,
131 passcode,
132 heartbeat: None,
133 },
134 extra_headers,
135 };
136 transport.send(connect).await?;
138 let msg = transport.next().await.transpose()?;
140 if let Some(FromServer::Connected { .. }) = msg.as_ref().map(|m| &m.content) {
141 Ok(())
142 } else {
143 Err(anyhow!("unexpected reply: {:?}", msg))
144 }
145}
146
/// Settings for a SUBSCRIBE message.
///
/// Created with `Subscriber::builder()`. The generated build method is private
/// (`__build`), so a populated builder is finished with its `subscribe` method.
// NOTE: field declaration order fixes the typed_builder generic tuple
// `((S,), (I,), __headers)` named by the hand-written `SubscriberBuilder` impl;
// reordering fields breaks that impl.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Subscriber<S: Into<String>, I: Into<String>> {
    /// Destination to subscribe to.
    destination: S,
    /// Subscription id placed in the SUBSCRIBE message.
    id: I,
    /// Extra `(name, value)` headers added to the SUBSCRIBE message.
    #[builder(default)]
    headers: Vec<(String, String)>,
}
185
186#[allow(non_camel_case_types)]
187impl<
188 S: Into<String>,
189 I: Into<String>,
190 __headers: ::typed_builder::Optional<Vec<(String, String)>>,
191 > SubscriberBuilder<S, I, ((S,), (I,), __headers)>
192{
193 pub fn subscribe(self) -> Message<ToServer> {
194 let subscriber = self.__build();
195 subscriber.subscribe()
196 }
197}
198
199impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
200 pub fn subscribe(self) -> Message<ToServer> {
201 let mut msg: Message<ToServer> = ToServer::Subscribe {
202 destination: self.destination.into(),
203 id: self.id.into(),
204 ack: None,
205 }
206 .into();
207 msg.extra_headers = self
208 .headers
209 .iter()
210 .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
211 .collect();
212 msg
213 }
214}
215
/// Stateless codec that encodes `Message<ToServer>` and decodes
/// `Message<FromServer>` STOMP frames for use with `tokio_util::codec::Framed`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCodec;
217
218impl Decoder for ClientCodec {
219 type Item = Message<FromServer>;
220 type Error = anyhow::Error;
221
222 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
223 let buf = &mut Partial::new(src.chunk());
224 let item = match frame::parse_frame(buf) {
225 Ok(frame) => Message::<FromServer>::from_frame(frame),
226 Err(ErrMode::Incomplete(_)) => return Ok(None),
227 Err(e) => bail!("Parse failed: {:?}", e),
228 };
229 let len = buf.offset_from(&Partial::new(src.chunk()));
230 src.advance(len);
231 item.map(Some)
232 }
233}
234
235impl Encoder<Message<ToServer>> for ClientCodec {
236 type Error = anyhow::Error;
237
238 fn encode(
239 &mut self,
240 item: Message<ToServer>,
241 dst: &mut BytesMut,
242 ) -> std::result::Result<(), Self::Error> {
243 item.to_frame().serialize(dst);
244 Ok(())
245 }
246}
247
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bytes::BytesMut;
    use futures::{future::ok, SinkExt, StreamExt, TryStreamExt};

    use crate::{
        client::{Connector, Subscriber},
        FromServer, Message, ToServer,
    };

    /// Serializes `msg` to its wire bytes so two messages can be compared exactly.
    fn wire_bytes(msg: Message<ToServer>) -> BytesMut {
        let mut buf = BytesMut::new();
        msg.to_frame().serialize(&mut buf);
        buf
    }

    /// Prints a frame received from the server.
    ///
    /// MESSAGE bodies are printed as lossy UTF-8; a MESSAGE without a body prints as
    /// an empty string instead of panicking. Other frames are printed with `Debug`.
    fn print_received(item: &Message<FromServer>) {
        if let FromServer::Message { body, .. } = &item.content {
            println!(
                "Message received: {:?}",
                String::from_utf8_lossy(body.as_deref().unwrap_or_default())
            );
        } else {
            println!("{:?}", item);
        }
    }

    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];
        let subscribe_msg = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();
        let mut expected: Message<ToServer> = ToServer::Subscribe {
            destination: "queue.test".to_string(),
            id: "custom-subscriber-id".to_string(),
            ack: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
            .collect();

        assert_eq!(wire_bytes(expected), wire_bytes(subscribe_msg));
    }

    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];
        let connect_msg = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let mut expected: Message<ToServer> = ToServer::Connect {
            accept_version: "1.2".into(),
            host: "virtual.stomp.example.com".into(),
            login: Some("guest_login".to_string()),
            passcode: Some("guest_passcode".to_string()),
            heartbeat: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
            .collect();

        assert_eq!(wire_bytes(expected), wire_bytes(connect_msg));
    }

    #[tokio::test]
    // Needs a STOMP server on localhost:61613 that accepts the artemis/artemis login.
    #[ignore]
    async fn test_session() {
        let mut conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        let msg = Message {
            content: ToServer::Send {
                destination: "/test/a".to_string(),
                transaction: None,
                headers: Some(vec![("header-a".to_string(), "value-a".to_string())]),
                body: Some(b"This is a test message".to_vec()),
            },
            extra_headers: vec![],
        };
        conn.send(msg).await.expect("Send a");
        let msg = Message {
            content: ToServer::Send {
                destination: "/test/b".to_string(),
                transaction: None,
                headers: Some(vec![("header-b".to_string(), "value-b".to_string())]),
                body: Some(b"This is a another test message".to_vec()),
            },
            extra_headers: vec![],
        };
        conn.send(msg).await.expect("Send b");
    }

    #[tokio::test]
    // Needs a STOMP server on localhost:61613; runs until the server closes the stream.
    #[ignore]
    async fn test_subscribe() {
        let sub_msg = Subscriber::builder()
            .destination("/test/a")
            .id("tjo")
            .subscribe();

        let mut conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        conn.send(sub_msg).await.expect("Send subscribe");
        let (_sink, stream) = conn.split();

        let mut cnt = 0;
        // Stream errors are deliberately ignored: this is a manual, print-only test.
        let _ = stream
            .try_for_each(|item| {
                println!("==== {cnt}");
                cnt += 1;
                print_received(&item);
                ok(())
            })
            .await;
    }

    #[tokio::test]
    // Needs a STOMP server on localhost:61613.
    #[ignore]
    async fn test_send_subscribe() {
        let conn = Connector::builder()
            .server("localhost:61613")
            .virtualhost("/")
            .login("artemis".to_string())
            .passcode("artemis".to_string())
            .connect()
            .await
            .expect("Default connection to localhost");

        tokio::time::sleep(Duration::from_millis(200)).await;

        let (mut sink, stream) = conn.split();

        // Producer side: subscribe, send, unsubscribe, disconnect, pausing between steps.
        let fut1 = async move {
            let subscribe = Subscriber::builder()
                .destination("rusty")
                .id("myid")
                .subscribe();

            sink.send(subscribe).await?;
            println!("Subscribe sent");

            tokio::time::sleep(Duration::from_millis(200)).await;

            sink.send(
                ToServer::Send {
                    destination: "rusty".into(),
                    transaction: None,
                    headers: None,
                    body: Some(b"Hello there rustaceans!".to_vec()),
                }
                .into(),
            )
            .await?;
            println!("Message sent");

            tokio::time::sleep(Duration::from_millis(200)).await;

            sink.send(ToServer::Unsubscribe { id: "myid".into() }.into())
                .await?;
            println!("Unsubscribe sent");

            // Single pause before DISCONNECT to let any in-flight frames arrive.
            tokio::time::sleep(Duration::from_millis(1200)).await;
            sink.send(ToServer::Disconnect { receipt: None }.into())
                .await?;
            println!("Disconnect sent");

            Ok::<(), anyhow::Error>(())
        };

        // Consumer side: print every frame until the stream ends or errors.
        let fut2 = stream.try_for_each(|item| {
            print_received(&item);
            ok(())
        });

        futures::future::select(Box::pin(fut1), Box::pin(fut2))
            .await
            .factor_first()
            .0
            .expect("Select");
    }
}