1use crate::{
2 ConnectionProperties, Error, ErrorKind, Event, Promise, Result,
3 channel::Channel,
4 channels::Channels,
5 configuration::Configuration,
6 connection_closer::ConnectionCloser,
7 connection_status::ConnectionStatus,
8 events::Events,
9 frames::Frames,
10 heartbeat::Heartbeat,
11 internal_rpc::{InternalRPC, InternalRPCHandle},
12 io_loop::IoLoop,
13 socket_state::SocketState,
14 tcp::{AMQPUriTcpExt, HandshakeResult, OwnedTLSConfig},
15 thread::ThreadHandle,
16 types::ReplyCode,
17 uri::AMQPUri,
18};
19use amq_protocol::frame::{AMQPFrame, ProtocolVersion};
20use async_trait::async_trait;
21use futures_core::Stream;
22use std::{fmt, io, sync::Arc};
23use tracing::{Level, level_enabled, trace};
24
25pub struct Connection {
37 configuration: Configuration,
38 status: ConnectionStatus,
39 channels: Channels,
40 events: Events,
41 io_loop: ThreadHandle,
42 closer: Arc<ConnectionCloser>,
43}
44
45impl Connection {
46 fn new(
47 configuration: Configuration,
48 status: ConnectionStatus,
49 channels: Channels,
50 internal_rpc: InternalRPCHandle,
51 events: Events,
52 ) -> Self {
53 let closer = Arc::new(ConnectionCloser::new(status.clone(), internal_rpc));
54 Self {
55 configuration,
56 status,
57 channels,
58 events,
59 io_loop: ThreadHandle::default(),
60 closer,
61 }
62 }
63
64 pub(crate) fn for_reconnect(
65 configuration: Configuration,
66 status: ConnectionStatus,
67 channels: Channels,
68 internal_rpc: InternalRPCHandle,
69 events: Events,
70 ) -> Self {
71 let conn = Self::new(configuration, status, channels, internal_rpc, events);
72 conn.closer.noop();
73 conn
74 }
75
76 pub async fn connect(uri: &str, options: ConnectionProperties) -> Result<Connection> {
87 Connect::connect(uri, options, OwnedTLSConfig::default()).await
88 }
89
90 pub async fn connect_with_config(
92 uri: &str,
93 options: ConnectionProperties,
94 config: OwnedTLSConfig,
95 ) -> Result<Connection> {
96 Connect::connect(uri, options, config).await
97 }
98
99 pub async fn connect_uri(uri: AMQPUri, options: ConnectionProperties) -> Result<Connection> {
101 Connect::connect(uri, options, OwnedTLSConfig::default()).await
102 }
103
104 pub async fn connect_uri_with_config(
106 uri: AMQPUri,
107 options: ConnectionProperties,
108 config: OwnedTLSConfig,
109 ) -> Result<Connection> {
110 Connect::connect(uri, options, config).await
111 }
112
113 pub async fn create_channel(&self) -> Result<Channel> {
121 if !self.status.connected() {
122 return Err(ErrorKind::InvalidConnectionState(self.status.state()).into());
123 }
124 let channel = self.channels.create(self.closer.clone())?;
125 channel.clone().channel_open(channel).await
127 }
128
129 pub fn events_listener(&self) -> impl Stream<Item = Event> + Send + 'static {
131 self.events.listener()
132 }
133
134 pub fn run(self) -> Result<()> {
138 let io_loop = self.io_loop.clone();
139 drop(self);
140 io_loop.wait("io loop")
141 }
142
143 #[deprecated(note = "Please use Connection::events_listener instead")]
144 pub fn on_error<E: FnMut(Error) + Send + 'static>(&self, handler: E) {
145 self.channels.set_error_handler(handler);
146 }
147
148 pub fn configuration(&self) -> &Configuration {
149 &self.configuration
150 }
151
152 pub fn status(&self) -> &ConnectionStatus {
153 &self.status
154 }
155
156 pub async fn close(&self, reply_code: ReplyCode, reply_text: &str) -> Result<()> {
163 if !self.status.connected() {
164 return Err(ErrorKind::InvalidConnectionState(self.status.state()).into());
165 }
166
167 self.channels.set_connection_closing();
168 self.channels
169 .channel0()
170 .connection_close(reply_code, reply_text, 0, 0)
171 .await
172 }
173
174 pub async fn block(&self, reason: &str) -> Result<()> {
176 self.channels.channel0().connection_blocked(reason).await
177 }
178
179 pub async fn unblock(&self) -> Result<()> {
181 self.channels.channel0().connection_unblocked().await
182 }
183
184 pub async fn update_secret(&self, new_secret: &str, reason: &str) -> Result<()> {
186 self.channels
187 .channel0()
188 .connection_update_secret(new_secret, reason)
189 .await
190 }
191
192 pub async fn connector(
193 uri: AMQPUri,
194 connect: Box<dyn Fn(&AMQPUri) -> HandshakeResult + Send + Sync>,
195 options: ConnectionProperties,
196 ) -> Result<Connection> {
197 let executor = options.executor()?;
198 let reactor = options.reactor()?;
199 let configuration = Configuration::new(&uri);
200 let status = ConnectionStatus::new(&uri);
201 let frames = Frames::default();
202 let socket_state = SocketState::default();
203 let internal_rpc = InternalRPC::new(executor.clone(), socket_state.handle());
204 let heartbeat = Heartbeat::new(status.clone(), executor.clone(), reactor.clone());
205 let events = Events::new();
206 let channels = Channels::new(
207 configuration.clone(),
208 status.clone(),
209 socket_state.handle(),
210 internal_rpc.handle(),
211 frames.clone(),
212 heartbeat.clone(),
213 executor,
214 uri.clone(),
215 options.clone(),
216 events.clone(),
217 );
218 let conn = Connection::new(
219 configuration,
220 status,
221 channels,
222 internal_rpc.handle(),
223 events,
224 );
225 let io_loop = IoLoop::new(
226 conn.status.clone(),
227 conn.configuration.clone(),
228 conn.channels.clone(),
229 internal_rpc.handle(),
230 frames,
231 socket_state,
232 connect.into(),
233 options.backoff,
234 uri.clone(),
235 heartbeat,
236 );
237
238 internal_rpc.start(conn.channels.clone());
239 conn.io_loop.register(io_loop.start(reactor)?);
240 conn.start(uri, options).await
241 }
242
243 pub(crate) async fn start(
244 self,
245 uri: AMQPUri,
246 options: ConnectionProperties,
247 ) -> Result<Connection> {
248 let (promise_out, resolver_out) = Promise::new();
249 let (promise_in, resolver_in) = Promise::new();
250 if level_enabled!(Level::TRACE) {
251 promise_out.set_marker("ProtocolHeader".into());
252 promise_in.set_marker("ProtocolHeader.Ok".into());
253 }
254 let channel0 = self.channels.channel0();
255
256 trace!("Set connection as connecting");
257 self.status.clone().set_connecting(
258 resolver_out.clone(),
259 resolver_in,
260 self,
261 uri.authority.userinfo.into(),
262 uri.query.auth_mechanism.unwrap_or_default(),
263 options,
264 )?;
265
266 trace!("Sending protocol header to server");
267 channel0.send_frame(
268 AMQPFrame::ProtocolHeader(ProtocolVersion::amqp_0_9_1()),
269 resolver_out,
270 None,
271 );
272
273 promise_out.await?;
274 trace!("Sent protocol header to server, waiting for connection flow");
275 promise_in.await
276 }
277}
278
279impl fmt::Debug for Connection {
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 f.debug_struct("Connection")
282 .field("configuration", &self.configuration)
283 .field("status", &self.status)
284 .field("channels", &self.channels)
285 .finish()
286 }
287}
288
289#[async_trait]
291pub trait Connect {
292 async fn connect(
294 self,
295 options: ConnectionProperties,
296 config: OwnedTLSConfig,
297 ) -> Result<Connection>;
298}
299
300#[async_trait]
301impl Connect for AMQPUri {
302 async fn connect(
303 self,
304 options: ConnectionProperties,
305 config: OwnedTLSConfig,
306 ) -> Result<Connection> {
307 Connection::connector(
308 self,
309 Box::new(move |uri| AMQPUriTcpExt::connect_with_config(uri, config.as_ref())),
310 options,
311 )
312 .await
313 }
314}
315
316#[async_trait]
317impl Connect for &str {
318 async fn connect(
319 self,
320 options: ConnectionProperties,
321 config: OwnedTLSConfig,
322 ) -> Result<Connection> {
323 match self.parse::<AMQPUri>() {
324 Ok(uri) => Connect::connect(uri, options, config).await,
325 Err(err) => Err(io::Error::other(err).into()),
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BasicProperties;
    use crate::channel_receiver_state::{ChannelReceiverState, DeliveryCause};
    use crate::channel_status::ChannelState;
    use crate::connection_status::ConnectionState;
    use crate::consumer::Consumer;
    use crate::options::BasicConsumeOptions;
    use crate::types::{ChannelId, FieldTable, ShortString};
    use amq_protocol::frame::AMQPContentHeader;
    use amq_protocol::protocol::{AMQPClass, basic};
    use executor_trait::FullExecutor;

    /// Builds a `Connection` for the default URI without starting any I/O loop, and
    /// forces its status to `Connected` so frames can be fed to it by hand.
    fn create_connection(executor: Arc<dyn FullExecutor + Send + Sync>) -> Connection {
        let uri = AMQPUri::default();
        let reactor = Arc::new(async_reactor_trait::AsyncIo);
        let configuration = Configuration::new(&uri);
        let status = ConnectionStatus::new(&uri);
        let frames = Frames::default();
        let socket_state = SocketState::default();
        let internal_rpc = InternalRPC::new(executor.clone(), socket_state.handle());
        let heartbeat = Heartbeat::new(status.clone(), executor.clone(), reactor);
        let events = Events::new();
        let channels = Channels::new(
            configuration.clone(),
            status.clone(),
            socket_state.handle(),
            internal_rpc.handle(),
            frames.clone(),
            heartbeat.clone(),
            executor,
            uri.clone(),
            ConnectionProperties::default(),
            events.clone(),
        );
        let connection = Connection::new(
            configuration,
            status,
            channels,
            internal_rpc.handle(),
            events,
        );
        connection.status.set_state(ConnectionState::Connected);
        connection
    }

    /// With `channel_max` set to `ChannelId::MAX`, exactly that many channels can be
    /// created; one more yields `ChannelsLimitReached`.
    #[test]
    fn channel_limit() {
        let _ = tracing_subscriber::fmt::try_init();

        let conn = create_connection(Arc::new(async_global_executor_trait::AsyncGlobalExecutor));
        conn.configuration.set_channel_max(ChannelId::MAX);
        (1..=ChannelId::MAX).for_each(|_| {
            conn.channels.create(conn.closer.clone()).unwrap();
        });

        assert_eq!(
            conn.channels.create(conn.closer.clone()),
            Err(ErrorKind::ChannelsLimitReached.into())
        );
    }

    /// Feeds a deliver method, a content header announcing 2 bytes and a 2-byte body,
    /// checking the receiver state after each of the first two frames and the channel
    /// state after the body.
    #[test]
    fn basic_consume_small_payload() {
        let _ = tracing_subscriber::fmt::try_init();

        let executor = Arc::new(async_global_executor_trait::AsyncGlobalExecutor);
        let conn = create_connection(executor.clone());
        conn.configuration.set_channel_max(2047);
        let channel = conn.channels.create(conn.closer.clone()).unwrap();
        channel.set_state(ChannelState::Connected);

        let queue = ShortString::from("consumed");
        let tag = ShortString::from("consumer-tag");
        let consumer = Consumer::new(
            tag.clone(),
            executor,
            None,
            queue.clone(),
            BasicConsumeOptions::default(),
            FieldTable::default(),
        );
        if let Some(registered) = conn.channels.get(channel.id()) {
            registered.register_consumer(tag.clone(), consumer);
            registered.register_queue(queue.clone(), Default::default(), Default::default());
        }

        {
            // Deliver method: the channel should now expect content for this consumer.
            let deliver = AMQPClass::Basic(basic::AMQPMethod::Deliver(basic::Deliver {
                consumer_tag: tag.clone(),
                delivery_tag: 1,
                redelivered: false,
                exchange: "".into(),
                routing_key: queue,
            }));
            // Read the class id before `deliver` is moved into the frame.
            let class_id = deliver.get_amqp_class_id();
            conn.channels
                .handle_frame(AMQPFrame::Method(channel.id(), deliver))
                .unwrap();
            assert_eq!(
                channel.status().receiver_state(),
                ChannelReceiverState::WillReceiveContent(
                    class_id,
                    DeliveryCause::Consume(tag.clone()),
                )
            );
        }
        {
            // Content header announcing a 2-byte body.
            let header = Box::new(AMQPContentHeader {
                class_id: 60,
                body_size: 2,
                properties: BasicProperties::default(),
            });
            conn.channels
                .handle_frame(AMQPFrame::Header(channel.id(), 60, header))
                .unwrap();
            assert_eq!(
                channel.status().receiver_state(),
                ChannelReceiverState::ReceivingContent(DeliveryCause::Consume(tag), 2)
            );
        }
        {
            // Body matching the announced size: the channel must still be connected.
            let body = AMQPFrame::Body(channel.id(), b"{}".to_vec());
            conn.channels.handle_frame(body).unwrap();
            assert_eq!(channel.status().state(), ChannelState::Connected);
        }
    }

    /// Feeds a deliver method followed by a content header announcing an empty body,
    /// and checks that the channel is still connected afterwards.
    #[test]
    fn basic_consume_empty_payload() {
        let _ = tracing_subscriber::fmt::try_init();

        let executor = Arc::new(async_global_executor_trait::AsyncGlobalExecutor);
        let conn = create_connection(executor.clone());
        conn.configuration.set_channel_max(2047);
        let channel = conn.channels.create(conn.closer.clone()).unwrap();
        channel.set_state(ChannelState::Connected);

        let queue = ShortString::from("consumed");
        let tag = ShortString::from("consumer-tag");
        let consumer = Consumer::new(
            tag.clone(),
            executor,
            None,
            queue.clone(),
            BasicConsumeOptions::default(),
            FieldTable::default(),
        );
        if let Some(registered) = conn.channels.get(channel.id()) {
            registered.register_consumer(tag.clone(), consumer);
            registered.register_queue(queue.clone(), Default::default(), Default::default());
        }

        {
            // Deliver method: the channel should now expect content for this consumer.
            let deliver = AMQPClass::Basic(basic::AMQPMethod::Deliver(basic::Deliver {
                consumer_tag: tag.clone(),
                delivery_tag: 1,
                redelivered: false,
                exchange: "".into(),
                routing_key: queue,
            }));
            // Read the class id before `deliver` is moved into the frame.
            let class_id = deliver.get_amqp_class_id();
            conn.channels
                .handle_frame(AMQPFrame::Method(channel.id(), deliver))
                .unwrap();
            assert_eq!(
                channel.status().receiver_state(),
                ChannelReceiverState::WillReceiveContent(class_id, DeliveryCause::Consume(tag))
            );
        }
        {
            // Content header with a zero-length body: no body frame follows.
            let header = Box::new(AMQPContentHeader {
                class_id: 60,
                body_size: 0,
                properties: BasicProperties::default(),
            });
            conn.channels
                .handle_frame(AMQPFrame::Header(channel.id(), 60, header))
                .unwrap();
            assert_eq!(channel.status().state(), ChannelState::Connected);
        }
    }
}