1use std::{fmt, future::Future, ops, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use ntex_io::{IoConfig, IoRef};
4use ntex_service::cfg::Cfg;
5use ntex_util::channel::{condition::Condition, condition::Waiter, oneshot};
6use ntex_util::{HashMap, future::Ready};
7
8use crate::codec::protocol::{self as codec, Begin, Close, End, Error, Frame, Role};
9use crate::codec::{AmqpCodec, AmqpFrame, types};
10use crate::control::ControlQueue;
11use crate::session::{INITIAL_NEXT_OUTGOING_ID, Session, SessionInner};
12use crate::sndlink::{SenderLink, SenderLinkInner};
13use crate::{
14 AmqpServiceConfig, RemoteServiceConfig, cell::Cell, error::AmqpProtocolError, types::Action,
15};
16
17pub struct Connection(ConnectionRef);
18
19#[derive(Clone)]
20pub struct ConnectionRef(pub(crate) Cell<ConnectionInner>);
21
22#[derive(Debug)]
23pub(crate) struct ConnectionInner {
24 io: IoRef,
25 state: ConnectionState,
26 codec: AmqpCodec<AmqpFrame>,
27 control_queue: Rc<ControlQueue>,
28 pub(crate) sessions: slab::Slab<SessionState>,
29 pub(crate) sessions_map: HashMap<u16, usize>,
30 pub(crate) on_close: Condition,
31 pub(crate) error: Option<AmqpProtocolError>,
32 channel_max: u16,
33 pub(crate) max_frame_size: u32,
34}
35
36#[derive(Debug)]
37pub(crate) enum SessionState {
38 Opening(Option<oneshot::Sender<Session>>, Cell<ConnectionInner>),
39 Established(Cell<SessionInner>),
40 Closing(Cell<SessionInner>),
41}
42
43impl SessionState {
44 fn is_opening(&self) -> bool {
45 matches!(self, SessionState::Opening(_, _))
46 }
47}
48
/// Lifecycle of a connection.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum ConnectionState {
    /// Connection is usable (initial state).
    Normal,
    /// The local side initiated the close; a `Close` from the peer is then treated
    /// as confirmation instead of being answered with another `Close`.
    Closing,
    /// The peer initiated the close and it has been answered.
    RemoteClose,
    /// Connection was force-closed locally.
    Drop,
}
56
57impl Connection {
58 pub(crate) fn new(
59 io: IoRef,
60 local_config: &Cfg<AmqpServiceConfig>,
61 remote_config: &RemoteServiceConfig,
62 ) -> Connection {
63 Connection(ConnectionRef(Cell::new(ConnectionInner {
64 io,
65 codec: AmqpCodec::new(),
66 state: ConnectionState::Normal,
67 sessions: slab::Slab::with_capacity(8),
68 sessions_map: HashMap::default(),
69 control_queue: Rc::default(),
70 error: None,
71 on_close: Condition::new(),
72 channel_max: local_config.channel_max,
73 max_frame_size: remote_config.max_frame_size,
74 })))
75 }
76
77 pub fn get_ref(&self) -> ConnectionRef {
78 self.0.clone()
79 }
80}
81
82impl AsRef<ConnectionRef> for Connection {
83 #[inline]
84 fn as_ref(&self) -> &ConnectionRef {
85 &self.0
86 }
87}
88
89impl ops::Deref for Connection {
90 type Target = ConnectionRef;
91
92 #[inline]
93 fn deref(&self) -> &Self::Target {
94 &self.0
95 }
96}
97
98impl Drop for Connection {
99 fn drop(&mut self) {
100 self.0.force_close();
101 }
102}
103
104impl ConnectionRef {
105 #[inline]
106 pub fn tag(&self) -> &'static str {
108 self.0.get_ref().io.tag()
109 }
110
111 #[inline]
112 pub fn config(&self) -> &IoConfig {
114 self.0.get_ref().io.cfg()
115 }
116
117 #[inline]
118 pub fn force_close(&self) {
120 let inner = self.0.get_mut();
121 inner.state = ConnectionState::Drop;
122 inner.io.force_close();
123 inner.set_error(AmqpProtocolError::ConnectionDropped);
124 }
125
126 #[inline]
127 pub fn is_opened(&self) -> bool {
129 let inner = self.0.get_mut();
130 if inner.state != ConnectionState::Normal {
131 return false;
132 }
133 inner.error.is_none() && !inner.io.is_closed()
134 }
135
136 pub fn on_close(&self) -> Waiter {
138 self.0.get_ref().on_close.wait()
139 }
140
141 pub fn get_error(&self) -> Option<AmqpProtocolError> {
143 self.0.get_ref().error.clone()
144 }
145
146 pub fn get_session_by_local_id(&self, channel: u16) -> Option<Session> {
148 if let Some(SessionState::Established(inner)) =
149 self.0.get_ref().sessions.get(channel as usize)
150 {
151 Some(Session::new(inner.clone()))
152 } else {
153 None
154 }
155 }
156
157 pub fn close(&self) -> impl Future<Output = Result<(), AmqpProtocolError>> {
159 let inner = self.0.get_mut();
160 inner.post_frame(AmqpFrame::new(0, Frame::Close(Close { error: None })));
161 inner.io.close();
162 Ready::Ok(())
163 }
164
165 pub fn close_with_error<E>(&self, err: E) -> impl Future<Output = Result<(), AmqpProtocolError>>
167 where
168 Error: From<E>,
169 {
170 let inner = self.0.get_mut();
171 inner.post_frame(AmqpFrame::new(
172 0,
173 Frame::Close(Close {
174 error: Some(err.into()),
175 }),
176 ));
177 inner.io.close();
178 Ready::Ok(())
179 }
180
181 pub fn open_session(&self) -> OpenSession {
183 OpenSession::new(self.0.clone())
184 }
185
186 pub(crate) fn close_session(&self, id: usize) {
187 if let Some(state) = self.0.get_mut().sessions.get_mut(id)
188 && let SessionState::Established(inner) = state
189 {
190 *state = SessionState::Closing(inner.clone());
191 }
192 }
193
194 pub(crate) fn post_frame(&self, frame: AmqpFrame) {
195 let inner = self.0.get_mut();
196
197 #[cfg(feature = "frame-trace")]
198 log::trace!("{}: outgoing: {:#?}", inner.io.tag(), frame);
199
200 if let Err(e) = inner.io.encode(frame, &inner.codec) {
201 inner.set_error(e.into());
202 }
203 }
204
205 pub(crate) fn set_error(&self, err: AmqpProtocolError) {
206 self.0.get_mut().set_error(err);
207 }
208
209 pub(crate) fn get_control_queue(&self) -> &Rc<ControlQueue> {
210 &self.0.get_ref().control_queue
211 }
212
213 pub(crate) fn handle_frame(&self, frame: AmqpFrame) -> Result<Action, AmqpProtocolError> {
214 self.0.get_mut().handle_frame(frame, &self.0)
215 }
216}
217
218impl ConnectionInner {
219 pub(crate) fn set_error(&mut self, err: AmqpProtocolError) {
220 log::trace!("{}: Set connection error: {:?}", self.io.tag(), err);
221 for (_, channel) in &mut self.sessions {
222 match channel {
223 SessionState::Opening(_, _) | SessionState::Closing(_) => (),
224 SessionState::Established(ses) => {
225 ses.get_mut().set_error(err.clone());
226 }
227 }
228 }
229 self.sessions.clear();
230 self.sessions_map.clear();
231
232 if self.error.is_none() {
233 self.error = Some(err);
234 }
235 self.on_close.notify_and_lock_readiness();
236 }
237
238 pub(crate) fn post_frame(&mut self, frame: AmqpFrame) {
239 #[cfg(feature = "frame-trace")]
240 log::trace!("{}: outgoing: {:#?}", self.io.tag(), frame);
241
242 if let Err(e) = self.io.encode(frame, &self.codec) {
243 self.set_error(e.into());
244 }
245 }
246
247 pub(crate) fn register_remote_session(
248 &mut self,
249 remote_channel_id: u16,
250 begin: Begin,
251 cell: &Cell<ConnectionInner>,
252 ) -> Result<(), AmqpProtocolError> {
253 log::trace!(
254 "{}: Remote session opened: {:?}",
255 self.io.tag(),
256 remote_channel_id
257 );
258
259 let entry = self.sessions.vacant_entry();
260 let local_token = entry.key();
261 let outgoing_window = begin.incoming_window();
262
263 let session = Cell::new(SessionInner::new(
264 local_token,
265 false,
266 ConnectionRef(cell.clone()),
267 remote_channel_id,
268 begin,
269 ));
270 entry.insert(SessionState::Established(session));
271 self.sessions_map.insert(remote_channel_id, local_token);
272
273 let begin = Begin(Box::new(codec::BeginInner {
274 outgoing_window,
275 remote_channel: Some(remote_channel_id),
276 next_outgoing_id: 1,
277 incoming_window: u32::MAX,
278 handle_max: u32::MAX,
279 offered_capabilities: None,
280 desired_capabilities: None,
281 properties: None,
282 }));
283
284 self.io
285 .encode(
286 AmqpFrame::new(local_token as u16, begin.into()),
287 &self.codec,
288 )
289 .map_err(AmqpProtocolError::Codec)
290 }
291
292 pub(crate) fn complete_session_creation(
293 &mut self,
294 local_channel_id: u16,
295 remote_channel_id: u16,
296 begin: Begin,
297 ) {
298 log::trace!(
299 "{}: Begin response received: local {:?} remote {:?}",
300 self.io.tag(),
301 local_channel_id,
302 remote_channel_id,
303 );
304
305 let local_token = local_channel_id as usize;
306
307 if let Some(channel) = self.sessions.get_mut(local_token) {
308 if channel.is_opening() {
309 if let SessionState::Opening(tx, cell) = channel {
310 let session = Cell::new(SessionInner::new(
311 local_token,
312 true,
313 ConnectionRef(cell.clone()),
314 remote_channel_id,
315 begin,
316 ));
317 self.sessions_map.insert(remote_channel_id, local_token);
318
319 tx.take()
321 .and_then(|tx| tx.send(Session::new(session.clone())).err());
322 *channel = SessionState::Established(session);
323
324 log::trace!(
325 "{}: Session established: local {:?} remote {:?}",
326 self.io.tag(),
327 local_channel_id,
328 remote_channel_id,
329 );
330 }
331 } else {
332 log::warn!(
334 "{}: Begin received for channel not in opening state. local channel: {} (remote channel: {})",
335 self.io.tag(),
336 local_channel_id,
337 remote_channel_id
338 );
339 }
340 } else {
341 log::warn!(
343 "{}: Begin received for unknown local channel: {} (remote channel: {})",
344 self.io.tag(),
345 local_channel_id,
346 remote_channel_id
347 );
348 }
349 }
350
351 fn handle_frame(
352 &mut self,
353 frame: AmqpFrame,
354 inner: &Cell<ConnectionInner>,
355 ) -> Result<Action, AmqpProtocolError> {
356 let (channel_id, frame) = frame.into_parts();
357
358 match frame {
359 Frame::Empty => Ok(Action::None),
360 Frame::Close(close) => {
361 if self.state == ConnectionState::Closing {
362 log::trace!("{}: Connection closed: {:?}", self.io.tag(), close);
363 self.set_error(AmqpProtocolError::Disconnected);
364 Ok(Action::None)
365 } else {
366 log::trace!("{}: Connection closed remotely: {:?}", self.io.tag(), close);
367 let err = AmqpProtocolError::Closed(close.error);
368 self.set_error(err.clone());
369 let close = Close { error: None };
370 self.post_frame(AmqpFrame::new(0, close.into()));
371 self.state = ConnectionState::RemoteClose;
372 Ok(Action::RemoteClose(err))
373 }
374 }
375 Frame::Begin(begin) => {
376 if let Some(local_channel_id) = begin.remote_channel() {
380 self.complete_session_creation(local_channel_id, channel_id, begin);
381 } else {
382 self.register_remote_session(channel_id, begin, inner)?;
383 }
384 Ok(Action::None)
385 }
386 _ => {
387 if self.error.is_some() {
388 log::error!(
389 "{}: Connection closed but new framed is received: {:?}",
390 self.io.tag(),
391 frame
392 );
393 return Ok(Action::None);
394 }
395
396 let state = if let Some(token) = self.sessions_map.get(&channel_id) {
398 if let Some(state) = self.sessions.get_mut(*token) {
399 state
400 } else {
401 log::error!("{}: Inconsistent internal state", self.io.tag());
402 return Err(AmqpProtocolError::UnknownSession(frame));
403 }
404 } else {
405 return Err(AmqpProtocolError::UnknownSession(frame));
406 };
407
408 match state {
410 SessionState::Opening(_, _) => {
411 log::error!(
412 "{}: Unexpected opening state: {}",
413 self.io.tag(),
414 channel_id
415 );
416 Err(AmqpProtocolError::UnexpectedOpeningState(frame))
417 }
418 SessionState::Established(session) => match frame {
419 Frame::Attach(attach) => {
420 let cell = session.clone();
421 if session.get_mut().handle_attach(&attach, cell) {
422 Ok(Action::None)
423 } else {
424 match attach.0.role {
425 Role::Receiver => {
426 let (id, response) =
428 session.get_mut().new_remote_sender(&attach);
429 let link = SenderLink::new(Cell::new(
430 SenderLinkInner::with(id, &attach, session.clone()),
431 ));
432 Ok(Action::AttachSender(link, attach, response))
433 }
434 Role::Sender => {
435 let (response, link) = session
437 .get_mut()
438 .attach_remote_receiver_link(session.clone(), &attach);
439 Ok(Action::AttachReceiver(link, attach, response))
440 }
441 }
442 }
443 }
444 Frame::End(remote_end) => {
445 log::trace!("{}: Remote session end: {}", self.io.tag(), channel_id);
446 let id = session.get_mut().id();
447 let action = session
448 .get_mut()
449 .end(AmqpProtocolError::SessionEnded(remote_end.error));
450 if let Some(token) = self.sessions_map.remove(&channel_id) {
451 self.sessions.remove(token);
452 }
453 self.post_frame(AmqpFrame::new(id, End { error: None }.into()));
454 Ok(action)
455 }
456 _ => session.get_mut().handle_frame(frame),
457 },
458 SessionState::Closing(session) => match frame {
459 Frame::End(frm) => {
460 log::trace!("{}: Session end is confirmed: {:?}", self.io.tag(), frm);
461 let _ = session
462 .get_mut()
463 .end(AmqpProtocolError::SessionEnded(frm.error));
464 if let Some(token) = self.sessions_map.remove(&channel_id) {
465 self.sessions.remove(token);
466 }
467 Ok(Action::None)
468 }
469 frm => {
470 log::trace!(
471 "{}: Got frame after initiated session end: {:?}",
472 self.io.tag(),
473 frm
474 );
475 Ok(Action::None)
476 }
477 },
478 }
479 }
480 }
481 }
482}
483
484impl fmt::Debug for ConnectionRef {
485 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
486 fmt.debug_struct("ConnectionRef").finish()
487 }
488}
489
490pub struct OpenSession {
492 con: Cell<ConnectionInner>,
493 fut: Option<Pin<Box<dyn Future<Output = Result<Session, AmqpProtocolError>>>>>,
494 props: Option<HashMap<types::Symbol, types::Variant>>,
495 offered_capabilities: Option<codec::Symbols>,
496 desired_capabilities: Option<codec::Symbols>,
497}
498
499impl OpenSession {
500 pub(crate) fn new(con: Cell<ConnectionInner>) -> Self {
501 Self {
502 con,
503 fut: None,
504 props: None,
505 offered_capabilities: None,
506 desired_capabilities: None,
507 }
508 }
509
510 #[must_use]
511 pub fn offered_capabilities(mut self, caps: codec::Symbols) -> Self {
513 self.offered_capabilities = Some(caps);
514 self
515 }
516
517 #[must_use]
518 pub fn desired_capabilities(mut self, caps: codec::Symbols) -> Self {
520 self.desired_capabilities = Some(caps);
521 self
522 }
523
524 #[must_use]
525 #[allow(clippy::missing_panics_doc)]
526 pub fn property<K, V>(mut self, key: K, value: V) -> Self
528 where
529 K: Into<types::Symbol>,
530 V: Into<types::Variant>,
531 {
532 if self.props.is_none() {
533 self.props = Some(HashMap::default());
534 }
535 self.props
536 .as_mut()
537 .unwrap()
538 .insert(key.into(), value.into());
539 self
540 }
541
542 pub async fn attach(self) -> Result<Session, AmqpProtocolError> {
544 open_session(
545 self.con,
546 self.offered_capabilities,
547 self.desired_capabilities,
548 self.props,
549 )
550 .await
551 }
552}
553
554impl Future for OpenSession {
555 type Output = Result<Session, AmqpProtocolError>;
556
557 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
558 let mut slf = self.as_mut();
559
560 if slf.fut.is_none() {
561 slf.fut = Some(Box::pin(open_session(
562 slf.con.clone(),
563 slf.offered_capabilities.take(),
564 slf.desired_capabilities.take(),
565 slf.props.take(),
566 )));
567 }
568
569 Pin::new(slf.fut.as_mut().unwrap()).poll(cx)
570 }
571}
572
573async fn open_session(
574 con: Cell<ConnectionInner>,
575 offered_capabilities: Option<codec::Symbols>,
576 desired_capabilities: Option<codec::Symbols>,
577 properties: Option<HashMap<types::Symbol, types::Variant>>,
578) -> Result<Session, AmqpProtocolError> {
579 let inner = con.get_mut();
580
581 if let Some(ref e) = inner.error {
582 log::error!("{}: Connection is in error state: {:?}", inner.io.tag(), e);
583 Err(e.clone())
584 } else {
585 let (tx, rx) = oneshot::channel();
586
587 let entry = inner.sessions.vacant_entry();
588 let token = entry.key();
589
590 if token >= inner.channel_max as usize {
591 log::trace!("{}: Too many channels: {:?}", inner.io.tag(), token);
592 Err(AmqpProtocolError::TooManyChannels)
593 } else {
594 entry.insert(SessionState::Opening(Some(tx), con.clone()));
595
596 let begin = Begin(Box::new(codec::BeginInner {
597 offered_capabilities,
598 desired_capabilities,
599 properties,
600 remote_channel: None,
601 next_outgoing_id: INITIAL_NEXT_OUTGOING_ID,
602 incoming_window: u32::MAX,
603 outgoing_window: u32::MAX,
604 handle_max: u32::MAX,
605 }));
606 inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
607 let _ = inner;
608
609 rx.await.map_err(|_| AmqpProtocolError::Disconnected)
610 }
611 }
612}