Skip to main content

ntex_amqp/
connection.rs

1use std::{fmt, future::Future, ops, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use ntex_io::{IoConfig, IoRef};
4use ntex_service::cfg::Cfg;
5use ntex_util::channel::{condition::Condition, condition::Waiter, oneshot};
6use ntex_util::{HashMap, future::Ready};
7
8use crate::codec::protocol::{self as codec, Begin, Close, End, Error, Frame, Role};
9use crate::codec::{AmqpCodec, AmqpFrame, types};
10use crate::control::ControlQueue;
11use crate::session::{INITIAL_NEXT_OUTGOING_ID, Session, SessionInner};
12use crate::sndlink::{SenderLink, SenderLinkInner};
13use crate::{
14    AmqpServiceConfig, RemoteServiceConfig, cell::Cell, error::AmqpProtocolError, types::Action,
15};
16
17pub struct Connection(ConnectionRef);
18
19#[derive(Clone)]
20pub struct ConnectionRef(pub(crate) Cell<ConnectionInner>);
21
22#[derive(Debug)]
23pub(crate) struct ConnectionInner {
24    io: IoRef,
25    state: ConnectionState,
26    codec: AmqpCodec<AmqpFrame>,
27    control_queue: Rc<ControlQueue>,
28    pub(crate) sessions: slab::Slab<SessionState>,
29    pub(crate) sessions_map: HashMap<u16, usize>,
30    pub(crate) on_close: Condition,
31    pub(crate) error: Option<AmqpProtocolError>,
32    channel_max: u16,
33    pub(crate) max_frame_size: u32,
34}
35
36#[derive(Debug)]
37pub(crate) enum SessionState {
38    Opening(Option<oneshot::Sender<Session>>, Cell<ConnectionInner>),
39    Established(Cell<SessionInner>),
40    Closing(Cell<SessionInner>),
41}
42
43impl SessionState {
44    fn is_opening(&self) -> bool {
45        matches!(self, SessionState::Opening(_, _))
46    }
47}
48
/// Lifecycle state of a connection.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum ConnectionState {
    /// Connection is open.
    Normal,
    /// A locally initiated close is in progress; an incoming `Close` is treated
    /// as the peer's confirmation rather than a new remote close.
    Closing,
    /// The peer sent `Close` first and it has been answered.
    RemoteClose,
    /// The connection was force-closed locally.
    Drop,
}
56
57impl Connection {
58    pub(crate) fn new(
59        io: IoRef,
60        local_config: &Cfg<AmqpServiceConfig>,
61        remote_config: &RemoteServiceConfig,
62    ) -> Connection {
63        Connection(ConnectionRef(Cell::new(ConnectionInner {
64            io,
65            codec: AmqpCodec::new(),
66            state: ConnectionState::Normal,
67            sessions: slab::Slab::with_capacity(8),
68            sessions_map: HashMap::default(),
69            control_queue: Rc::default(),
70            error: None,
71            on_close: Condition::new(),
72            channel_max: local_config.channel_max,
73            max_frame_size: remote_config.max_frame_size,
74        })))
75    }
76
77    pub fn get_ref(&self) -> ConnectionRef {
78        self.0.clone()
79    }
80}
81
82impl AsRef<ConnectionRef> for Connection {
83    #[inline]
84    fn as_ref(&self) -> &ConnectionRef {
85        &self.0
86    }
87}
88
89impl ops::Deref for Connection {
90    type Target = ConnectionRef;
91
92    #[inline]
93    fn deref(&self) -> &Self::Target {
94        &self.0
95    }
96}
97
98impl Drop for Connection {
99    fn drop(&mut self) {
100        self.0.force_close();
101    }
102}
103
104impl ConnectionRef {
105    #[inline]
106    /// Get io tag for current connection
107    pub fn tag(&self) -> &'static str {
108        self.0.get_ref().io.tag()
109    }
110
111    #[inline]
112    /// Get io configuration for current connection
113    pub fn config(&self) -> &IoConfig {
114        self.0.get_ref().io.cfg()
115    }
116
117    #[inline]
118    /// Force close connection
119    pub fn force_close(&self) {
120        let inner = self.0.get_mut();
121        inner.state = ConnectionState::Drop;
122        inner.io.force_close();
123        inner.set_error(AmqpProtocolError::ConnectionDropped);
124    }
125
126    #[inline]
127    /// Check connection state
128    pub fn is_opened(&self) -> bool {
129        let inner = self.0.get_mut();
130        if inner.state != ConnectionState::Normal {
131            return false;
132        }
133        inner.error.is_none() && !inner.io.is_closed()
134    }
135
136    /// Get waiter for `on_close` event
137    pub fn on_close(&self) -> Waiter {
138        self.0.get_ref().on_close.wait()
139    }
140
141    /// Get connection error
142    pub fn get_error(&self) -> Option<AmqpProtocolError> {
143        self.0.get_ref().error.clone()
144    }
145
146    /// Get existing session by local channel id
147    pub fn get_session_by_local_id(&self, channel: u16) -> Option<Session> {
148        if let Some(SessionState::Established(inner)) =
149            self.0.get_ref().sessions.get(channel as usize)
150        {
151            Some(Session::new(inner.clone()))
152        } else {
153            None
154        }
155    }
156
157    /// Gracefully close connection
158    pub fn close(&self) -> impl Future<Output = Result<(), AmqpProtocolError>> {
159        let inner = self.0.get_mut();
160        inner.post_frame(AmqpFrame::new(0, Frame::Close(Close { error: None })));
161        inner.io.close();
162        Ready::Ok(())
163    }
164
165    /// Close connection with error
166    pub fn close_with_error<E>(&self, err: E) -> impl Future<Output = Result<(), AmqpProtocolError>>
167    where
168        Error: From<E>,
169    {
170        let inner = self.0.get_mut();
171        inner.post_frame(AmqpFrame::new(
172            0,
173            Frame::Close(Close {
174                error: Some(err.into()),
175            }),
176        ));
177        inner.io.close();
178        Ready::Ok(())
179    }
180
181    /// Opens the session
182    pub fn open_session(&self) -> OpenSession {
183        OpenSession::new(self.0.clone())
184    }
185
186    pub(crate) fn close_session(&self, id: usize) {
187        if let Some(state) = self.0.get_mut().sessions.get_mut(id)
188            && let SessionState::Established(inner) = state
189        {
190            *state = SessionState::Closing(inner.clone());
191        }
192    }
193
194    pub(crate) fn post_frame(&self, frame: AmqpFrame) {
195        let inner = self.0.get_mut();
196
197        #[cfg(feature = "frame-trace")]
198        log::trace!("{}: outgoing: {:#?}", inner.io.tag(), frame);
199
200        if let Err(e) = inner.io.encode(frame, &inner.codec) {
201            inner.set_error(e.into());
202        }
203    }
204
205    pub(crate) fn set_error(&self, err: AmqpProtocolError) {
206        self.0.get_mut().set_error(err);
207    }
208
209    pub(crate) fn get_control_queue(&self) -> &Rc<ControlQueue> {
210        &self.0.get_ref().control_queue
211    }
212
213    pub(crate) fn handle_frame(&self, frame: AmqpFrame) -> Result<Action, AmqpProtocolError> {
214        self.0.get_mut().handle_frame(frame, &self.0)
215    }
216}
217
218impl ConnectionInner {
219    pub(crate) fn set_error(&mut self, err: AmqpProtocolError) {
220        log::trace!("{}: Set connection error: {:?}", self.io.tag(), err);
221        for (_, channel) in &mut self.sessions {
222            match channel {
223                SessionState::Opening(_, _) | SessionState::Closing(_) => (),
224                SessionState::Established(ses) => {
225                    ses.get_mut().set_error(err.clone());
226                }
227            }
228        }
229        self.sessions.clear();
230        self.sessions_map.clear();
231
232        if self.error.is_none() {
233            self.error = Some(err);
234        }
235        self.on_close.notify_and_lock_readiness();
236    }
237
238    pub(crate) fn post_frame(&mut self, frame: AmqpFrame) {
239        #[cfg(feature = "frame-trace")]
240        log::trace!("{}: outgoing: {:#?}", self.io.tag(), frame);
241
242        if let Err(e) = self.io.encode(frame, &self.codec) {
243            self.set_error(e.into());
244        }
245    }
246
247    pub(crate) fn register_remote_session(
248        &mut self,
249        remote_channel_id: u16,
250        begin: Begin,
251        cell: &Cell<ConnectionInner>,
252    ) -> Result<(), AmqpProtocolError> {
253        log::trace!(
254            "{}: Remote session opened: {:?}",
255            self.io.tag(),
256            remote_channel_id
257        );
258
259        let entry = self.sessions.vacant_entry();
260        let local_token = entry.key();
261        let outgoing_window = begin.incoming_window();
262
263        let session = Cell::new(SessionInner::new(
264            local_token,
265            false,
266            ConnectionRef(cell.clone()),
267            remote_channel_id,
268            begin,
269        ));
270        entry.insert(SessionState::Established(session));
271        self.sessions_map.insert(remote_channel_id, local_token);
272
273        let begin = Begin(Box::new(codec::BeginInner {
274            outgoing_window,
275            remote_channel: Some(remote_channel_id),
276            next_outgoing_id: 1,
277            incoming_window: u32::MAX,
278            handle_max: u32::MAX,
279            offered_capabilities: None,
280            desired_capabilities: None,
281            properties: None,
282        }));
283
284        self.io
285            .encode(
286                AmqpFrame::new(local_token as u16, begin.into()),
287                &self.codec,
288            )
289            .map_err(AmqpProtocolError::Codec)
290    }
291
292    pub(crate) fn complete_session_creation(
293        &mut self,
294        local_channel_id: u16,
295        remote_channel_id: u16,
296        begin: Begin,
297    ) {
298        log::trace!(
299            "{}: Begin response received: local {:?} remote {:?}",
300            self.io.tag(),
301            local_channel_id,
302            remote_channel_id,
303        );
304
305        let local_token = local_channel_id as usize;
306
307        if let Some(channel) = self.sessions.get_mut(local_token) {
308            if channel.is_opening() {
309                if let SessionState::Opening(tx, cell) = channel {
310                    let session = Cell::new(SessionInner::new(
311                        local_token,
312                        true,
313                        ConnectionRef(cell.clone()),
314                        remote_channel_id,
315                        begin,
316                    ));
317                    self.sessions_map.insert(remote_channel_id, local_token);
318
319                    // TODO: send end session if `tx` is None
320                    tx.take()
321                        .and_then(|tx| tx.send(Session::new(session.clone())).err());
322                    *channel = SessionState::Established(session);
323
324                    log::trace!(
325                        "{}: Session established: local {:?} remote {:?}",
326                        self.io.tag(),
327                        local_channel_id,
328                        remote_channel_id,
329                    );
330                }
331            } else {
332                // TODO: send error response
333                log::warn!(
334                    "{}: Begin received for channel not in opening state. local channel: {} (remote channel: {})",
335                    self.io.tag(),
336                    local_channel_id,
337                    remote_channel_id
338                );
339            }
340        } else {
341            // TODO: rogue begin right now - do nothing. in future might indicate incoming attach
342            log::warn!(
343                "{}: Begin received for unknown local channel: {} (remote channel: {})",
344                self.io.tag(),
345                local_channel_id,
346                remote_channel_id
347            );
348        }
349    }
350
351    fn handle_frame(
352        &mut self,
353        frame: AmqpFrame,
354        inner: &Cell<ConnectionInner>,
355    ) -> Result<Action, AmqpProtocolError> {
356        let (channel_id, frame) = frame.into_parts();
357
358        match frame {
359            Frame::Empty => Ok(Action::None),
360            Frame::Close(close) => {
361                if self.state == ConnectionState::Closing {
362                    log::trace!("{}: Connection closed: {:?}", self.io.tag(), close);
363                    self.set_error(AmqpProtocolError::Disconnected);
364                    Ok(Action::None)
365                } else {
366                    log::trace!("{}: Connection closed remotely: {:?}", self.io.tag(), close);
367                    let err = AmqpProtocolError::Closed(close.error);
368                    self.set_error(err.clone());
369                    let close = Close { error: None };
370                    self.post_frame(AmqpFrame::new(0, close.into()));
371                    self.state = ConnectionState::RemoteClose;
372                    Ok(Action::RemoteClose(err))
373                }
374            }
375            Frame::Begin(begin) => {
376                // response Begin for open session
377                // the remote-channel property in the frame is the local channel id
378                // we previously sent to the remote
379                if let Some(local_channel_id) = begin.remote_channel() {
380                    self.complete_session_creation(local_channel_id, channel_id, begin);
381                } else {
382                    self.register_remote_session(channel_id, begin, inner)?;
383                }
384                Ok(Action::None)
385            }
386            _ => {
387                if self.error.is_some() {
388                    log::error!(
389                        "{}: Connection closed but new framed is received: {:?}",
390                        self.io.tag(),
391                        frame
392                    );
393                    return Ok(Action::None);
394                }
395
396                // get local session id
397                let state = if let Some(token) = self.sessions_map.get(&channel_id) {
398                    if let Some(state) = self.sessions.get_mut(*token) {
399                        state
400                    } else {
401                        log::error!("{}: Inconsistent internal state", self.io.tag());
402                        return Err(AmqpProtocolError::UnknownSession(frame));
403                    }
404                } else {
405                    return Err(AmqpProtocolError::UnknownSession(frame));
406                };
407
408                // handle session frames
409                match state {
410                    SessionState::Opening(_, _) => {
411                        log::error!(
412                            "{}: Unexpected opening state: {}",
413                            self.io.tag(),
414                            channel_id
415                        );
416                        Err(AmqpProtocolError::UnexpectedOpeningState(frame))
417                    }
418                    SessionState::Established(session) => match frame {
419                        Frame::Attach(attach) => {
420                            let cell = session.clone();
421                            if session.get_mut().handle_attach(&attach, cell) {
422                                Ok(Action::None)
423                            } else {
424                                match attach.0.role {
425                                    Role::Receiver => {
426                                        // remotly opened sender link
427                                        let (id, response) =
428                                            session.get_mut().new_remote_sender(&attach);
429                                        let link = SenderLink::new(Cell::new(
430                                            SenderLinkInner::with(id, &attach, session.clone()),
431                                        ));
432                                        Ok(Action::AttachSender(link, attach, response))
433                                    }
434                                    Role::Sender => {
435                                        // receiver link
436                                        let (response, link) = session
437                                            .get_mut()
438                                            .attach_remote_receiver_link(session.clone(), &attach);
439                                        Ok(Action::AttachReceiver(link, attach, response))
440                                    }
441                                }
442                            }
443                        }
444                        Frame::End(remote_end) => {
445                            log::trace!("{}: Remote session end: {}", self.io.tag(), channel_id);
446                            let id = session.get_mut().id();
447                            let action = session
448                                .get_mut()
449                                .end(AmqpProtocolError::SessionEnded(remote_end.error));
450                            if let Some(token) = self.sessions_map.remove(&channel_id) {
451                                self.sessions.remove(token);
452                            }
453                            self.post_frame(AmqpFrame::new(id, End { error: None }.into()));
454                            Ok(action)
455                        }
456                        _ => session.get_mut().handle_frame(frame),
457                    },
458                    SessionState::Closing(session) => match frame {
459                        Frame::End(frm) => {
460                            log::trace!("{}: Session end is confirmed: {:?}", self.io.tag(), frm);
461                            let _ = session
462                                .get_mut()
463                                .end(AmqpProtocolError::SessionEnded(frm.error));
464                            if let Some(token) = self.sessions_map.remove(&channel_id) {
465                                self.sessions.remove(token);
466                            }
467                            Ok(Action::None)
468                        }
469                        frm => {
470                            log::trace!(
471                                "{}: Got frame after initiated session end: {:?}",
472                                self.io.tag(),
473                                frm
474                            );
475                            Ok(Action::None)
476                        }
477                    },
478                }
479            }
480        }
481    }
482}
483
484impl fmt::Debug for ConnectionRef {
485    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
486        fmt.debug_struct("ConnectionRef").finish()
487    }
488}
489
490/// Open new session
491pub struct OpenSession {
492    con: Cell<ConnectionInner>,
493    fut: Option<Pin<Box<dyn Future<Output = Result<Session, AmqpProtocolError>>>>>,
494    props: Option<HashMap<types::Symbol, types::Variant>>,
495    offered_capabilities: Option<codec::Symbols>,
496    desired_capabilities: Option<codec::Symbols>,
497}
498
499impl OpenSession {
500    pub(crate) fn new(con: Cell<ConnectionInner>) -> Self {
501        Self {
502            con,
503            fut: None,
504            props: None,
505            offered_capabilities: None,
506            desired_capabilities: None,
507        }
508    }
509
510    #[must_use]
511    /// Set session offered capabilities
512    pub fn offered_capabilities(mut self, caps: codec::Symbols) -> Self {
513        self.offered_capabilities = Some(caps);
514        self
515    }
516
517    #[must_use]
518    /// Set session desired capabilities
519    pub fn desired_capabilities(mut self, caps: codec::Symbols) -> Self {
520        self.desired_capabilities = Some(caps);
521        self
522    }
523
524    #[must_use]
525    #[allow(clippy::missing_panics_doc)]
526    /// Set session property
527    pub fn property<K, V>(mut self, key: K, value: V) -> Self
528    where
529        K: Into<types::Symbol>,
530        V: Into<types::Variant>,
531    {
532        if self.props.is_none() {
533            self.props = Some(HashMap::default());
534        }
535        self.props
536            .as_mut()
537            .unwrap()
538            .insert(key.into(), value.into());
539        self
540    }
541
542    /// Attach session
543    pub async fn attach(self) -> Result<Session, AmqpProtocolError> {
544        open_session(
545            self.con,
546            self.offered_capabilities,
547            self.desired_capabilities,
548            self.props,
549        )
550        .await
551    }
552}
553
554impl Future for OpenSession {
555    type Output = Result<Session, AmqpProtocolError>;
556
557    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
558        let mut slf = self.as_mut();
559
560        if slf.fut.is_none() {
561            slf.fut = Some(Box::pin(open_session(
562                slf.con.clone(),
563                slf.offered_capabilities.take(),
564                slf.desired_capabilities.take(),
565                slf.props.take(),
566            )));
567        }
568
569        Pin::new(slf.fut.as_mut().unwrap()).poll(cx)
570    }
571}
572
573async fn open_session(
574    con: Cell<ConnectionInner>,
575    offered_capabilities: Option<codec::Symbols>,
576    desired_capabilities: Option<codec::Symbols>,
577    properties: Option<HashMap<types::Symbol, types::Variant>>,
578) -> Result<Session, AmqpProtocolError> {
579    let inner = con.get_mut();
580
581    if let Some(ref e) = inner.error {
582        log::error!("{}: Connection is in error state: {:?}", inner.io.tag(), e);
583        Err(e.clone())
584    } else {
585        let (tx, rx) = oneshot::channel();
586
587        let entry = inner.sessions.vacant_entry();
588        let token = entry.key();
589
590        if token >= inner.channel_max as usize {
591            log::trace!("{}: Too many channels: {:?}", inner.io.tag(), token);
592            Err(AmqpProtocolError::TooManyChannels)
593        } else {
594            entry.insert(SessionState::Opening(Some(tx), con.clone()));
595
596            let begin = Begin(Box::new(codec::BeginInner {
597                offered_capabilities,
598                desired_capabilities,
599                properties,
600                remote_channel: None,
601                next_outgoing_id: INITIAL_NEXT_OUTGOING_ID,
602                incoming_window: u32::MAX,
603                outgoing_window: u32::MAX,
604                handle_max: u32::MAX,
605            }));
606            inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
607            let _ = inner;
608
609            rx.await.map_err(|_| AmqpProtocolError::Disconnected)
610        }
611    }
612}