actix_amqp/
connection.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use actix_codec::{AsyncRead, AsyncWrite, Framed};
8use actix_utils::oneshot;
9use actix_utils::task::LocalWaker;
10use actix_utils::time::LowResTimeService;
11use futures::future::{err, Either};
12use futures::{future, Sink, Stream};
13use fxhash::FxHashMap;
14
15use amqp_codec::protocol::{Begin, Close, End, Error, Frame};
16use amqp_codec::{AmqpCodec, AmqpCodecError, AmqpFrame};
17
18use crate::cell::{Cell, WeakCell};
19use crate::errors::AmqpTransportError;
20use crate::hb::{Heartbeat, HeartbeatAction};
21use crate::session::{Session, SessionInner};
22use crate::Configuration;
23
/// AMQP connection that drives frames over a codec-framed transport `T`.
///
/// The connection is itself a `Future` (see the `Future` impl in this file):
/// polling it runs the heartbeat, reads incoming frames and flushes queued
/// outgoing frames.
24pub struct Connection<T: AsyncRead + AsyncWrite> {
    /// Connection state; clones of this cell are handed out through
    /// `ConnectionController`.
25    inner: Cell<ConnectionInner>,
    /// Framed transport that `AmqpFrame`s are read from and written to.
26    framed: Framed<T, AmqpCodec<AmqpFrame>>,
    /// Heartbeat tracker; polled first on every `Future::poll`.
27    hb: Heartbeat,
28}
29
/// Per-session slot stored in `ConnectionInner::sessions`.
30pub(crate) enum ChannelState {
    /// Local `Begin` was posted and the peer's `Begin` has not arrived yet.
    /// Holds the sender used to deliver the `Session` to the `open_session`
    /// caller and a weak handle to the connection state.
31    Opening(Option<oneshot::Sender<Session>>, WeakCell<ConnectionInner>),
    /// Session is open; incoming frames are routed to this session state.
32    Established(Cell<SessionInner>),
    /// Waiting for the peer's `End`; the optional sender receives `Ok(())`
    /// when that `End` frame is read in `poll_incoming`.
33    Closing(Option<oneshot::Sender<Result<(), AmqpTransportError>>>),
34}
35
36impl ChannelState {
37    fn is_opening(&self) -> bool {
38        match self {
39            ChannelState::Opening(_, _) => true,
40            _ => false,
41        }
42    }
43}
44
/// Connection state shared between `Connection` and `ConnectionController`.
45pub(crate) struct ConnectionInner {
    /// Local side configuration (its `channel_max` bounds `open_session`).
46    local: Configuration,
    /// Configuration of the remote peer.
47    remote: Configuration,
    /// Frames waiting to be written; drained by `Connection::poll_outgoing`.
48    write_queue: VecDeque<AmqpFrame>,
    /// Waker registered by `register_write_task`, woken when a frame is posted.
49    write_task: LocalWaker,
    /// Session slots; the slab key is used as the local channel number.
50    sessions: slab::Slab<ChannelState>,
    /// Maps the channel id carried by incoming frames to a key in `sessions`.
51    sessions_map: FxHashMap<u16, usize>,
    /// Error recorded by `set_error`; `None` until the first failure.
52    error: Option<AmqpTransportError>,
    /// Connection close state.
53    state: State,
54}
55
/// Close state of a connection.
56#[derive(PartialEq)]
57enum State {
    /// Initial state assigned by the constructors.
58    Normal,
    /// Checked when a `Close` frame arrives in `poll_incoming`.
    /// NOTE(review): nothing in this file assigns this variant; presumably
    /// reserved for a locally initiated close.
59    Closing,
    /// Peer sent `Close`; a `Close` reply has been queued.
60    RemoteClose,
    /// Set by `ConnectionController::drop_connection`.
61    Drop,
62}
63
64impl<T: AsyncRead + AsyncWrite> Connection<T> {
65    pub fn new(
66        framed: Framed<T, AmqpCodec<AmqpFrame>>,
67        local: Configuration,
68        remote: Configuration,
69        time: Option<LowResTimeService>,
70    ) -> Connection<T> {
71        Connection {
72            framed,
73            hb: Heartbeat::new(
74                local.timeout().unwrap(),
75                remote.timeout(),
76                time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
77            ),
78            inner: Cell::new(ConnectionInner::new(local, remote)),
79        }
80    }
81
82    pub(crate) fn new_server(
83        framed: Framed<T, AmqpCodec<AmqpFrame>>,
84        inner: Cell<ConnectionInner>,
85        time: Option<LowResTimeService>,
86    ) -> Connection<T> {
87        let l_timeout = inner.get_ref().local.timeout().unwrap();
88        let r_timeout = inner.get_ref().remote.timeout();
89        Connection {
90            framed,
91            inner,
92            hb: Heartbeat::new(
93                l_timeout,
94                r_timeout,
95                time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
96            ),
97        }
98    }
99
100    /// Connection controller
101    pub fn controller(&self) -> ConnectionController {
102        ConnectionController(self.inner.clone())
103    }
104
105    /// Get remote configuration
106    pub fn remote_config(&self) -> &Configuration {
107        &self.inner.get_ref().remote
108    }
109
110    /// Gracefully close connection
111    pub fn close(&mut self) -> impl Future<Output = Result<(), AmqpTransportError>> {
112        future::ok(())
113    }
114
115    // TODO: implement
116    /// Close connection with error
117    pub fn close_with_error(
118        &mut self,
119        _err: Error,
120    ) -> impl Future<Output = Result<(), AmqpTransportError>> {
121        future::ok(())
122    }
123
124    /// Opens the session
125    pub fn open_session(&mut self) -> impl Future<Output = Result<Session, AmqpTransportError>> {
126        let cell = self.inner.downgrade();
127        let inner = self.inner.clone();
128
129        async move {
130            let inner = inner.get_mut();
131
132            if let Some(ref e) = inner.error {
133                Err(e.clone())
134            } else {
135                let (tx, rx) = oneshot::channel();
136
137                let entry = inner.sessions.vacant_entry();
138                let token = entry.key();
139
140                if token >= inner.local.channel_max {
141                    Err(AmqpTransportError::TooManyChannels)
142                } else {
143                    entry.insert(ChannelState::Opening(Some(tx), cell));
144
145                    let begin = Begin {
146                        remote_channel: None,
147                        next_outgoing_id: 1,
148                        incoming_window: std::u32::MAX,
149                        outgoing_window: std::u32::MAX,
150                        handle_max: std::u32::MAX,
151                        offered_capabilities: None,
152                        desired_capabilities: None,
153                        properties: None,
154                    };
155                    inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
156
157                    rx.await.map_err(|_| AmqpTransportError::Disconnected)
158                }
159            }
160        }
161    }
162
163    /// Get session by id. This method panics if session does not exists or in opening/closing state.
164    pub(crate) fn get_session(&self, id: usize) -> Cell<SessionInner> {
165        if let Some(channel) = self.inner.get_ref().sessions.get(id) {
166            if let ChannelState::Established(ref session) = channel {
167                return session.clone();
168            }
169        }
170        panic!("Session not found: {}", id);
171    }
172
173    pub(crate) fn register_remote_session(&mut self, channel_id: u16, begin: &Begin) {
174        trace!("remote session opened: {:?}", channel_id);
175
176        let cell = self.inner.clone();
177        let inner = self.inner.get_mut();
178        let entry = inner.sessions.vacant_entry();
179        let token = entry.key();
180
181        let session = Cell::new(SessionInner::new(
182            token,
183            false,
184            ConnectionController(cell),
185            token as u16,
186            begin.next_outgoing_id(),
187            begin.incoming_window(),
188            begin.outgoing_window(),
189        ));
190        entry.insert(ChannelState::Established(session));
191        inner.sessions_map.insert(channel_id, token);
192
193        let begin = Begin {
194            remote_channel: Some(channel_id),
195            next_outgoing_id: 1,
196            incoming_window: std::u32::MAX,
197            outgoing_window: begin.incoming_window(),
198            handle_max: std::u32::MAX,
199            offered_capabilities: None,
200            desired_capabilities: None,
201            properties: None,
202        };
203        inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
204    }
205
206    pub(crate) fn send_frame(&mut self, frame: AmqpFrame) {
207        self.inner.get_mut().post_frame(frame)
208    }
209
210    pub(crate) fn register_write_task(&self, cx: &mut Context) {
211        self.inner.write_task.register(cx.waker());
212    }
213
214    pub(crate) fn poll_outgoing(&mut self, cx: &mut Context) -> Poll<Result<(), AmqpCodecError>> {
215        let inner = self.inner.get_mut();
216        let mut update = false;
217        loop {
218            while !self.framed.is_write_buf_full() {
219                if let Some(frame) = inner.pop_next_frame() {
220                    trace!("outgoing: {:#?}", frame);
221                    update = true;
222                    if let Err(e) = self.framed.write(frame) {
223                        inner.set_error(e.clone().into());
224                        return Poll::Ready(Err(e));
225                    }
226                } else {
227                    break;
228                }
229            }
230
231            if !self.framed.is_write_buf_empty() {
232                match self.framed.flush(cx) {
233                    Poll::Pending => break,
234                    Poll::Ready(Err(e)) => {
235                        trace!("error sending data: {}", e);
236                        inner.set_error(e.clone().into());
237                        return Poll::Ready(Err(e));
238                    }
239                    Poll::Ready(_) => (),
240                }
241            } else {
242                break;
243            }
244        }
245        self.hb.update_remote(update);
246
247        if inner.state == State::Drop {
248            Poll::Ready(Ok(()))
249        } else if inner.state == State::RemoteClose
250            && inner.write_queue.is_empty()
251            && self.framed.is_write_buf_empty()
252        {
253            Poll::Ready(Ok(()))
254        } else {
255            Poll::Pending
256        }
257    }
258
259    pub(crate) fn poll_incoming(
260        &mut self,
261        cx: &mut Context,
262    ) -> Poll<Option<Result<AmqpFrame, AmqpCodecError>>> {
263        let inner = self.inner.get_mut();
264
265        let mut update = false;
266        loop {
267            match Pin::new(&mut self.framed).poll_next(cx) {
268                Poll::Ready(Some(Ok(frame))) => {
269                    trace!("incoming: {:#?}", frame);
270
271                    update = true;
272
273                    if let Frame::Empty = frame.performative() {
274                        self.hb.update_local(update);
275                        continue;
276                    }
277
278                    // handle connection close
279                    if let Frame::Close(ref close) = frame.performative() {
280                        inner.set_error(AmqpTransportError::Closed(close.error.clone()));
281
282                        if inner.state == State::Closing {
283                            inner.sessions.clear();
284                            return Poll::Ready(None);
285                        } else {
286                            let close = Close { error: None };
287                            inner.post_frame(AmqpFrame::new(0, close.into()));
288                            inner.state = State::RemoteClose;
289                        }
290                    }
291
292                    if inner.error.is_some() {
293                        error!("connection closed but new framed is received: {:?}", frame);
294                        return Poll::Ready(None);
295                    }
296
297                    // get local session id
298                    let channel_id =
299                        if let Some(token) = inner.sessions_map.get(&frame.channel_id()) {
300                            *token
301                        } else {
302                            // we dont have channel info, only Begin frame is allowed on new channel
303                            if let Frame::Begin(ref begin) = frame.performative() {
304                                if begin.remote_channel().is_some() {
305                                    inner.complete_session_creation(frame.channel_id(), begin);
306                                } else {
307                                    return Poll::Ready(Some(Ok(frame)));
308                                }
309                            } else {
310                                warn!("Unexpected frame: {:#?}", frame);
311                            }
312                            continue;
313                        };
314
315                    // handle session frames
316                    if let Some(channel) = inner.sessions.get_mut(channel_id) {
317                        match channel {
318                            ChannelState::Opening(_, _) => {
319                                error!("Unexpected opening state: {}", channel_id);
320                            }
321                            ChannelState::Established(ref mut session) => {
322                                match frame.performative() {
323                                    Frame::Attach(attach) => {
324                                        let cell = session.clone();
325                                        if !session.get_mut().handle_attach(attach, cell) {
326                                            return Poll::Ready(Some(Ok(frame)));
327                                        }
328                                    }
329                                    Frame::Flow(_) | Frame::Detach(_) => {
330                                        return Poll::Ready(Some(Ok(frame)));
331                                    }
332                                    Frame::End(remote_end) => {
333                                        trace!("Remote session end: {}", frame.channel_id());
334                                        let end = End { error: None };
335                                        session.get_mut().set_error(
336                                            AmqpTransportError::SessionEnded(
337                                                remote_end.error.clone(),
338                                            ),
339                                        );
340                                        let id = session.get_mut().id();
341                                        inner.post_frame(AmqpFrame::new(id, end.into()));
342                                        inner.sessions.remove(channel_id);
343                                        inner.sessions_map.remove(&frame.channel_id());
344                                    }
345                                    _ => session.get_mut().handle_frame(frame.into_parts().1),
346                                }
347                            }
348                            ChannelState::Closing(ref mut tx) => match frame.performative() {
349                                Frame::End(_) => {
350                                    if let Some(tx) = tx.take() {
351                                        let _ = tx.send(Ok(()));
352                                    }
353                                    inner.sessions.remove(channel_id);
354                                    inner.sessions_map.remove(&frame.channel_id());
355                                }
356                                frm => trace!("Got frame after initiated session end: {:?}", frm),
357                            },
358                        }
359                    } else {
360                        error!("Can not find channel: {}", channel_id);
361                        continue;
362                    }
363                }
364                Poll::Ready(None) => {
365                    inner.set_error(AmqpTransportError::Disconnected);
366                    return Poll::Ready(None);
367                }
368                Poll::Pending => {
369                    self.hb.update_local(update);
370                    break;
371                }
372                Poll::Ready(Some(Err(e))) => {
373                    trace!("error reading: {:?}", e);
374                    inner.set_error(e.clone().into());
375                    return Poll::Ready(Some(Err(e.into())));
376                }
377            }
378        }
379
380        Poll::Pending
381    }
382}
383
384impl<T: AsyncRead + AsyncWrite> Drop for Connection<T> {
385    fn drop(&mut self) {
386        self.inner
387            .get_mut()
388            .set_error(AmqpTransportError::Disconnected);
389    }
390}
391
392impl<T: AsyncRead + AsyncWrite> Future for Connection<T> {
393    type Output = Result<(), AmqpCodecError>;
394
395    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
396        // connection heartbeat
397        match self.hb.poll(cx) {
398            Ok(act) => match act {
399                HeartbeatAction::None => (),
400                HeartbeatAction::Close => {
401                    self.inner.get_mut().set_error(AmqpTransportError::Timeout);
402                    return Poll::Ready(Ok(()));
403                }
404                HeartbeatAction::Heartbeat => {
405                    self.inner
406                        .get_mut()
407                        .write_queue
408                        .push_back(AmqpFrame::new(0, Frame::Empty));
409                }
410            },
411            Err(e) => {
412                self.inner.get_mut().set_error(e);
413                return Poll::Ready(Ok(()));
414            }
415        }
416
417        loop {
418            match self.poll_incoming(cx) {
419                Poll::Ready(None) => return Poll::Ready(Ok(())),
420                Poll::Ready(Some(Ok(frame))) => {
421                    if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
422                        if let ChannelState::Established(ref session) = channel {
423                            session.get_mut().handle_frame(frame.into_parts().1);
424                            continue;
425                        }
426                    }
427                    warn!("Unexpected frame: {:?}", frame);
428                }
429                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
430                Poll::Pending => break,
431            }
432        }
433        let _ = self.poll_outgoing(cx)?;
434        self.register_write_task(cx);
435
436        match self.poll_incoming(cx) {
437            Poll::Ready(None) => return Poll::Ready(Ok(())),
438            Poll::Ready(Some(Ok(frame))) => {
439                if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
440                    if let ChannelState::Established(ref session) = channel {
441                        session.get_mut().handle_frame(frame.into_parts().1);
442                        return Poll::Pending;
443                    }
444                }
445                warn!("Unexpected frame: {:?}", frame);
446            }
447            Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
448            Poll::Pending => (),
449        }
450
451        Poll::Pending
452    }
453}
454
/// Cloneable handle to a connection's state (`ConnectionInner`).
///
/// Used to post frames, read the remote configuration and request that the
/// connection be dropped.
455#[derive(Clone)]
456pub struct ConnectionController(pub(crate) Cell<ConnectionInner>);
457
458impl ConnectionController {
459    pub(crate) fn new(local: Configuration) -> ConnectionController {
460        ConnectionController(Cell::new(ConnectionInner {
461            local,
462            remote: Configuration::default(),
463            write_queue: VecDeque::new(),
464            write_task: LocalWaker::new(),
465            sessions: slab::Slab::with_capacity(8),
466            sessions_map: FxHashMap::default(),
467            error: None,
468            state: State::Normal,
469        }))
470    }
471
472    pub(crate) fn set_remote(&mut self, remote: Configuration) {
473        self.0.get_mut().remote = remote;
474    }
475
476    #[inline]
477    /// Get remote connection configuration
478    pub fn remote_config(&self) -> &Configuration {
479        &self.0.get_ref().remote
480    }
481
482    #[inline]
483    /// Drop connection
484    pub fn drop_connection(&mut self) {
485        let inner = self.0.get_mut();
486        inner.state = State::Drop;
487        inner.write_task.wake()
488    }
489
490    pub(crate) fn post_frame(&mut self, frame: AmqpFrame) {
491        self.0.get_mut().post_frame(frame)
492    }
493
494    pub(crate) fn drop_session_copy(&mut self, _id: usize) {}
495}
496
497impl ConnectionInner {
498    pub(crate) fn new(local: Configuration, remote: Configuration) -> ConnectionInner {
499        ConnectionInner {
500            local,
501            remote,
502            write_queue: VecDeque::new(),
503            write_task: LocalWaker::new(),
504            sessions: slab::Slab::with_capacity(8),
505            sessions_map: FxHashMap::default(),
506            error: None,
507            state: State::Normal,
508        }
509    }
510
511    fn set_error(&mut self, err: AmqpTransportError) {
512        for (_, channel) in self.sessions.iter_mut() {
513            match channel {
514                ChannelState::Opening(_, _) | ChannelState::Closing(_) => (),
515                ChannelState::Established(ref mut ses) => {
516                    ses.get_mut().set_error(err.clone());
517                }
518            }
519        }
520        self.sessions.clear();
521        self.sessions_map.clear();
522
523        self.error = Some(err);
524    }
525
526    fn pop_next_frame(&mut self) -> Option<AmqpFrame> {
527        self.write_queue.pop_front()
528    }
529
530    fn post_frame(&mut self, frame: AmqpFrame) {
531        // trace!("POST-FRAME: {:#?}", frame.performative());
532        self.write_queue.push_back(frame);
533        self.write_task.wake();
534    }
535
536    fn complete_session_creation(&mut self, channel_id: u16, begin: &Begin) {
537        trace!(
538            "session opened: {:?} {:?}",
539            channel_id,
540            begin.remote_channel()
541        );
542
543        let id = begin.remote_channel().unwrap() as usize;
544
545        if let Some(channel) = self.sessions.get_mut(id) {
546            if channel.is_opening() {
547                if let ChannelState::Opening(tx, cell) = channel {
548                    let cell = cell.upgrade().unwrap();
549                    let session = Cell::new(SessionInner::new(
550                        id,
551                        true,
552                        ConnectionController(cell),
553                        channel_id,
554                        begin.next_outgoing_id(),
555                        begin.incoming_window(),
556                        begin.outgoing_window(),
557                    ));
558                    self.sessions_map.insert(channel_id, id);
559
560                    if tx
561                        .take()
562                        .unwrap()
563                        .send(Session::new(session.clone()))
564                        .is_err()
565                    {
566                        // todo: send end session
567                    }
568                    *channel = ChannelState::Established(session)
569                }
570            } else {
571                // send error response
572            }
573        } else {
574            // todo: rogue begin right now - do nothing. in future might indicate incoming attach
575        }
576    }
577}