lapin/
connection.rs

1use crate::{
2    ConnectionProperties, Error, ErrorKind, Event, Promise, Result,
3    channel::Channel,
4    channels::Channels,
5    configuration::Configuration,
6    connection_closer::ConnectionCloser,
7    connection_status::ConnectionStatus,
8    events::Events,
9    frames::Frames,
10    heartbeat::Heartbeat,
11    internal_rpc::{InternalRPC, InternalRPCHandle},
12    io_loop::IoLoop,
13    socket_state::SocketState,
14    tcp::{AMQPUriTcpExt, HandshakeResult, OwnedTLSConfig},
15    thread::ThreadHandle,
16    types::ReplyCode,
17    uri::AMQPUri,
18};
19use amq_protocol::frame::{AMQPFrame, ProtocolVersion};
20use async_trait::async_trait;
21use futures_core::Stream;
22use std::{fmt, io, sync::Arc};
23use tracing::{Level, level_enabled, trace};
24
25/// A TCP connection to the AMQP server.
26///
27/// To connect to the server, one of the [`connect`] methods has to be called.
28///
29/// Afterwards, create a [`Channel`] by calling [`create_channel`].
30///
31/// Also see the RabbitMQ documentation on [connections](https://www.rabbitmq.com/connections.html).
32///
33/// [`connect`]: ./struct.Connection.html#method.connect
34/// [`Channel`]: ./struct.Channel.html
35/// [`create_channel`]: ./struct.Connection.html#method.create_channel
36pub struct Connection {
37    configuration: Configuration,
38    status: ConnectionStatus,
39    channels: Channels,
40    events: Events,
41    io_loop: ThreadHandle,
42    closer: Arc<ConnectionCloser>,
43}
44
45impl Connection {
46    fn new(
47        configuration: Configuration,
48        status: ConnectionStatus,
49        channels: Channels,
50        internal_rpc: InternalRPCHandle,
51        events: Events,
52    ) -> Self {
53        let closer = Arc::new(ConnectionCloser::new(status.clone(), internal_rpc));
54        Self {
55            configuration,
56            status,
57            channels,
58            events,
59            io_loop: ThreadHandle::default(),
60            closer,
61        }
62    }
63
64    pub(crate) fn for_reconnect(
65        configuration: Configuration,
66        status: ConnectionStatus,
67        channels: Channels,
68        internal_rpc: InternalRPCHandle,
69        events: Events,
70    ) -> Self {
71        let conn = Self::new(configuration, status, channels, internal_rpc, events);
72        conn.closer.noop();
73        conn
74    }
75
76    /// Connect to an AMQP Server.
77    ///
78    /// The URI must be in the following format:
79    ///
80    /// * `amqp://127.0.0.1:5672` will connect to the default virtual host `/`.
81    /// * `amqp://127.0.0.1:5672/` will connect to the virtual host `""` (empty string).
82    /// * `amqp://127.0.0.1:5672/%2f` will connect to the default virtual host `/`.
83    ///
84    /// Note that the virtual host has to be escaped with
85    /// [URL encoding](https://en.wikipedia.org/wiki/Percent-encoding).
86    pub async fn connect(uri: &str, options: ConnectionProperties) -> Result<Connection> {
87        Connect::connect(uri, options, OwnedTLSConfig::default()).await
88    }
89
90    /// Connect to an AMQP Server.
91    pub async fn connect_with_config(
92        uri: &str,
93        options: ConnectionProperties,
94        config: OwnedTLSConfig,
95    ) -> Result<Connection> {
96        Connect::connect(uri, options, config).await
97    }
98
99    /// Connect to an AMQP Server.
100    pub async fn connect_uri(uri: AMQPUri, options: ConnectionProperties) -> Result<Connection> {
101        Connect::connect(uri, options, OwnedTLSConfig::default()).await
102    }
103
104    /// Connect to an AMQP Server
105    pub async fn connect_uri_with_config(
106        uri: AMQPUri,
107        options: ConnectionProperties,
108        config: OwnedTLSConfig,
109    ) -> Result<Connection> {
110        Connect::connect(uri, options, config).await
111    }
112
113    /// Creates a new [`Channel`] on this connection.
114    ///
115    /// This method is only successful if the client is connected.
116    /// Otherwise, [`InvalidConnectionState`] error is returned.
117    ///
118    /// [`Channel`]: ./struct.Channel.html
119    /// [`InvalidConnectionState`]: ./enum.Error.html#variant.InvalidConnectionState
120    pub async fn create_channel(&self) -> Result<Channel> {
121        if !self.status.connected() {
122            return Err(ErrorKind::InvalidConnectionState(self.status.state()).into());
123        }
124        let channel = self.channels.create(self.closer.clone())?;
125        // FIXME: make sure we have a notifier on error+reconnect
126        channel.clone().channel_open(channel).await
127    }
128
129    /// Get a Stream of connection Events
130    pub fn events_listener(&self) -> impl Stream<Item = Event> + Send + 'static {
131        self.events.listener()
132    }
133
134    /// Block current thread while the connection is still active.
135    /// This is useful when you only have a consumer and nothing else keeping your application
136    /// "alive".
137    pub fn run(self) -> Result<()> {
138        let io_loop = self.io_loop.clone();
139        drop(self);
140        io_loop.wait("io loop")
141    }
142
143    #[deprecated(note = "Please use Connection::events_listener instead")]
144    pub fn on_error<E: FnMut(Error) + Send + 'static>(&self, handler: E) {
145        self.channels.set_error_handler(handler);
146    }
147
148    pub fn configuration(&self) -> &Configuration {
149        &self.configuration
150    }
151
152    pub fn status(&self) -> &ConnectionStatus {
153        &self.status
154    }
155
156    /// Request a connection close.
157    ///
158    /// This method is only successful if the connection is in the connected state,
159    /// otherwise an [`InvalidConnectionState`] error is returned.
160    ///
161    /// [`InvalidConnectionState`]: ./enum.Error.html#variant.InvalidConnectionState
162    pub async fn close(&self, reply_code: ReplyCode, reply_text: &str) -> Result<()> {
163        if !self.status.connected() {
164            return Err(ErrorKind::InvalidConnectionState(self.status.state()).into());
165        }
166
167        self.channels.set_connection_closing();
168        self.channels
169            .channel0()
170            .connection_close(reply_code, reply_text, 0, 0)
171            .await
172    }
173
174    /// Block all consumers and publishers on this connection
175    pub async fn block(&self, reason: &str) -> Result<()> {
176        self.channels.channel0().connection_blocked(reason).await
177    }
178
179    /// Unblock all consumers and publishers on this connection
180    pub async fn unblock(&self) -> Result<()> {
181        self.channels.channel0().connection_unblocked().await
182    }
183
184    /// Update the secret used by some authentication module such as OAuth2
185    pub async fn update_secret(&self, new_secret: &str, reason: &str) -> Result<()> {
186        self.channels
187            .channel0()
188            .connection_update_secret(new_secret, reason)
189            .await
190    }
191
192    pub async fn connector(
193        uri: AMQPUri,
194        connect: Box<dyn Fn(&AMQPUri) -> HandshakeResult + Send + Sync>,
195        options: ConnectionProperties,
196    ) -> Result<Connection> {
197        let executor = options.executor()?;
198        let reactor = options.reactor()?;
199        let configuration = Configuration::new(&uri);
200        let status = ConnectionStatus::new(&uri);
201        let frames = Frames::default();
202        let socket_state = SocketState::default();
203        let internal_rpc = InternalRPC::new(executor.clone(), socket_state.handle());
204        let heartbeat = Heartbeat::new(status.clone(), executor.clone(), reactor.clone());
205        let events = Events::new();
206        let channels = Channels::new(
207            configuration.clone(),
208            status.clone(),
209            socket_state.handle(),
210            internal_rpc.handle(),
211            frames.clone(),
212            heartbeat.clone(),
213            executor,
214            uri.clone(),
215            options.clone(),
216            events.clone(),
217        );
218        let conn = Connection::new(
219            configuration,
220            status,
221            channels,
222            internal_rpc.handle(),
223            events,
224        );
225        let io_loop = IoLoop::new(
226            conn.status.clone(),
227            conn.configuration.clone(),
228            conn.channels.clone(),
229            internal_rpc.handle(),
230            frames,
231            socket_state,
232            connect.into(),
233            options.backoff,
234            uri.clone(),
235            heartbeat,
236        );
237
238        internal_rpc.start(conn.channels.clone());
239        conn.io_loop.register(io_loop.start(reactor)?);
240        conn.start(uri, options).await
241    }
242
243    pub(crate) async fn start(
244        self,
245        uri: AMQPUri,
246        options: ConnectionProperties,
247    ) -> Result<Connection> {
248        let (promise_out, resolver_out) = Promise::new();
249        let (promise_in, resolver_in) = Promise::new();
250        if level_enabled!(Level::TRACE) {
251            promise_out.set_marker("ProtocolHeader".into());
252            promise_in.set_marker("ProtocolHeader.Ok".into());
253        }
254        let channel0 = self.channels.channel0();
255
256        trace!("Set connection as connecting");
257        self.status.clone().set_connecting(
258            resolver_out.clone(),
259            resolver_in,
260            self,
261            uri.authority.userinfo.into(),
262            uri.query.auth_mechanism.unwrap_or_default(),
263            options,
264        )?;
265
266        trace!("Sending protocol header to server");
267        channel0.send_frame(
268            AMQPFrame::ProtocolHeader(ProtocolVersion::amqp_0_9_1()),
269            resolver_out,
270            None,
271        );
272
273        promise_out.await?;
274        trace!("Sent protocol header to server, waiting for connection flow");
275        promise_in.await
276    }
277}
278
279impl fmt::Debug for Connection {
280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281        f.debug_struct("Connection")
282            .field("configuration", &self.configuration)
283            .field("status", &self.status)
284            .field("channels", &self.channels)
285            .finish()
286    }
287}
288
289/// Trait providing a method to connect to an AMQP server
290#[async_trait]
291pub trait Connect {
292    /// connect to an AMQP server
293    async fn connect(
294        self,
295        options: ConnectionProperties,
296        config: OwnedTLSConfig,
297    ) -> Result<Connection>;
298}
299
300#[async_trait]
301impl Connect for AMQPUri {
302    async fn connect(
303        self,
304        options: ConnectionProperties,
305        config: OwnedTLSConfig,
306    ) -> Result<Connection> {
307        Connection::connector(
308            self,
309            Box::new(move |uri| AMQPUriTcpExt::connect_with_config(uri, config.as_ref())),
310            options,
311        )
312        .await
313    }
314}
315
316#[async_trait]
317impl Connect for &str {
318    async fn connect(
319        self,
320        options: ConnectionProperties,
321        config: OwnedTLSConfig,
322    ) -> Result<Connection> {
323        match self.parse::<AMQPUri>() {
324            Ok(uri) => Connect::connect(uri, options, config).await,
325            Err(err) => Err(io::Error::other(err).into()),
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BasicProperties;
    use crate::channel_receiver_state::{ChannelReceiverState, DeliveryCause};
    use crate::channel_status::ChannelState;
    use crate::connection_status::ConnectionState;
    use crate::consumer::Consumer;
    use crate::options::BasicConsumeOptions;
    use crate::types::{ChannelId, FieldTable, ShortString};
    use amq_protocol::frame::AMQPContentHeader;
    use amq_protocol::protocol::{AMQPClass, basic};
    use executor_trait::FullExecutor;

    /// Builds a `Connection` from in-memory components without starting an
    /// I/O loop, and forces its status to `Connected` so that channels can be
    /// created and fed frames directly.
    fn create_connection(executor: Arc<dyn FullExecutor + Send + Sync>) -> Connection {
        let uri = AMQPUri::default();
        let reactor = Arc::new(async_reactor_trait::AsyncIo);
        let configuration = Configuration::new(&uri);
        let status = ConnectionStatus::new(&uri);
        let frames = Frames::default();
        let socket_state = SocketState::default();
        let internal_rpc = InternalRPC::new(executor.clone(), socket_state.handle());
        let heartbeat = Heartbeat::new(status.clone(), executor.clone(), reactor);
        let events = Events::new();
        let channels = Channels::new(
            configuration.clone(),
            status.clone(),
            socket_state.handle(),
            internal_rpc.handle(),
            frames.clone(),
            heartbeat.clone(),
            executor,
            uri.clone(),
            ConnectionProperties::default(),
            events.clone(),
        );
        let conn = Connection::new(
            configuration,
            status,
            channels,
            internal_rpc.handle(),
            events,
        );
        conn.status.set_state(ConnectionState::Connected);
        conn
    }

    /// Bootstraps a connected channel with a consumer and its queue registered.
    ///
    /// Returns the connection (kept alive by the caller for the duration of the
    /// test), the channel, the consumer tag and the queue name.
    fn consuming_channel() -> (Connection, Channel, ShortString, ShortString) {
        let executor = Arc::new(async_global_executor_trait::AsyncGlobalExecutor);
        let conn = create_connection(executor.clone());
        conn.configuration.set_channel_max(2047);
        let channel = conn.channels.create(conn.closer.clone()).unwrap();
        channel.set_state(ChannelState::Connected);
        let queue_name = ShortString::from("consumed");
        let consumer_tag = ShortString::from("consumer-tag");
        let consumer = Consumer::new(
            consumer_tag.clone(),
            executor,
            None,
            queue_name.clone(),
            BasicConsumeOptions::default(),
            FieldTable::default(),
        );
        // Fail loudly rather than silently skipping registration: without the
        // consumer and queue, the state-machine assertions would be meaningless.
        let registered = conn
            .channels
            .get(channel.id())
            .expect("freshly created channel should be registered");
        registered.register_consumer(consumer_tag.clone(), consumer);
        registered.register_queue(queue_name.clone(), Default::default(), Default::default());
        (conn, channel, consumer_tag, queue_name)
    }

    /// Feeds a `basic.deliver` frame addressed to `consumer_tag` and asserts
    /// that the channel then expects the content of that delivery.
    fn deliver_and_expect_content(
        conn: &Connection,
        channel: &Channel,
        consumer_tag: &ShortString,
        queue_name: ShortString,
    ) {
        let method = AMQPClass::Basic(basic::AMQPMethod::Deliver(basic::Deliver {
            consumer_tag: consumer_tag.clone(),
            delivery_tag: 1,
            redelivered: false,
            exchange: "".into(),
            routing_key: queue_name,
        }));
        let class_id = method.get_amqp_class_id();
        let deliver_frame = AMQPFrame::Method(channel.id(), method);
        conn.channels.handle_frame(deliver_frame).unwrap();
        let channel_state = channel.status().receiver_state();
        let expected_state = ChannelReceiverState::WillReceiveContent(
            class_id,
            DeliveryCause::Consume(consumer_tag.clone()),
        );
        assert_eq!(channel_state, expected_state);
    }

    #[test]
    fn channel_limit() {
        let _ = tracing_subscriber::fmt::try_init();

        // Bootstrap connection state to a consuming state
        let executor = Arc::new(async_global_executor_trait::AsyncGlobalExecutor);
        let conn = create_connection(executor.clone());
        conn.configuration.set_channel_max(ChannelId::MAX);
        for _ in 1..=ChannelId::MAX {
            conn.channels.create(conn.closer.clone()).unwrap();
        }

        assert_eq!(
            conn.channels.create(conn.closer.clone()),
            Err(ErrorKind::ChannelsLimitReached.into())
        );
    }

    #[test]
    fn basic_consume_small_payload() {
        let _ = tracing_subscriber::fmt::try_init();

        // Bootstrap connection state to a consuming state
        let (conn, channel, consumer_tag, queue_name) = consuming_channel();

        // Now test the state machine behaviour
        deliver_and_expect_content(&conn, &channel, &consumer_tag, queue_name);
        {
            let header_frame = AMQPFrame::Header(
                channel.id(),
                60,
                Box::new(AMQPContentHeader {
                    class_id: 60,
                    body_size: 2,
                    properties: BasicProperties::default(),
                }),
            );
            conn.channels.handle_frame(header_frame).unwrap();
            let channel_state = channel.status().receiver_state();
            let expected_state =
                ChannelReceiverState::ReceivingContent(DeliveryCause::Consume(consumer_tag), 2);
            assert_eq!(channel_state, expected_state);
        }
        {
            // The announced 2-byte body is sent in a single frame.
            let body_frame = AMQPFrame::Body(channel.id(), b"{}".to_vec());
            conn.channels.handle_frame(body_frame).unwrap();
            let channel_state = channel.status().state();
            let expected_state = ChannelState::Connected;
            assert_eq!(channel_state, expected_state);
        }
    }

    #[test]
    fn basic_consume_empty_payload() {
        let _ = tracing_subscriber::fmt::try_init();

        // Bootstrap connection state to a consuming state
        let (conn, channel, consumer_tag, queue_name) = consuming_channel();

        // Now test the state machine behaviour
        deliver_and_expect_content(&conn, &channel, &consumer_tag, queue_name);
        {
            // A zero-length body: no body frame follows the header.
            let header_frame = AMQPFrame::Header(
                channel.id(),
                60,
                Box::new(AMQPContentHeader {
                    class_id: 60,
                    body_size: 0,
                    properties: BasicProperties::default(),
                }),
            );
            conn.channels.handle_frame(header_frame).unwrap();
            let channel_state = channel.status().state();
            let expected_state = ChannelState::Connected;
            assert_eq!(channel_state, expected_state);
        }
    }
}