1use std::collections::VecDeque;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use actix_codec::{AsyncRead, AsyncWrite, Framed};
8use actix_utils::oneshot;
9use actix_utils::task::LocalWaker;
10use actix_utils::time::LowResTimeService;
11use futures::future::{err, Either};
12use futures::{future, Sink, Stream};
13use fxhash::FxHashMap;
14
15use amqp_codec::protocol::{Begin, Close, End, Error, Frame};
16use amqp_codec::{AmqpCodec, AmqpCodecError, AmqpFrame};
17
18use crate::cell::{Cell, WeakCell};
19use crate::errors::AmqpTransportError;
20use crate::hb::{Heartbeat, HeartbeatAction};
21use crate::session::{Session, SessionInner};
22use crate::Configuration;
23
/// AMQP connection driver.
///
/// Owns the framed transport and the heartbeat tracker, and shares the connection
/// state with any `ConnectionController` handed out by `controller()`.
pub struct Connection<T: AsyncRead + AsyncWrite> {
    /// Shared connection state (sessions, outgoing frame queue, last error).
    inner: Cell<ConnectionInner>,
    /// Framed transport that AMQP frames are read from and written to.
    framed: Framed<T, AmqpCodec<AmqpFrame>>,
    /// Heartbeat state, polled at the start of every `Future::poll` call.
    hb: Heartbeat,
}
29
/// State of one slot in the connection's session slab.
pub(crate) enum ChannelState {
    /// `open_session` has posted a `Begin` frame and the reply is still outstanding.
    /// Holds the sender used to hand the resulting `Session` to the caller and a weak
    /// handle to the connection state.
    Opening(Option<oneshot::Sender<Session>>, WeakCell<ConnectionInner>),
    /// The session is established and frames are routed to it.
    Established(Cell<SessionInner>),
    /// Holds an optional sender that is completed with `Ok(())` once an `End` frame
    /// arrives for this channel.
    Closing(Option<oneshot::Sender<Result<(), AmqpTransportError>>>),
}
35
36impl ChannelState {
37 fn is_opening(&self) -> bool {
38 match self {
39 ChannelState::Opening(_, _) => true,
40 _ => false,
41 }
42 }
43}
44
/// Connection state shared between `Connection`, `ConnectionController` and sessions.
pub(crate) struct ConnectionInner {
    /// Configuration of this side of the connection.
    local: Configuration,
    /// Configuration of the peer.
    remote: Configuration,
    /// Frames queued by `post_frame` and drained by `pop_next_frame`.
    write_queue: VecDeque<AmqpFrame>,
    /// Waker that is woken whenever a frame is posted.
    write_task: LocalWaker,
    /// Session slots; the slab key is used as the outgoing channel number.
    sessions: slab::Slab<ChannelState>,
    /// Maps the channel id found on incoming frames to the slab key of the session.
    sessions_map: FxHashMap<u16, usize>,
    /// Error most recently recorded by `set_error`, if any.
    error: Option<AmqpTransportError>,
    /// Current lifecycle state of the connection.
    state: State,
}
55
/// Lifecycle state of the connection, consulted by `poll_incoming` and `poll_outgoing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Initial state of a freshly created connection.
    Normal,
    /// Compared against when a `Close` frame arrives; in that case the sessions are
    /// cleared and the incoming stream ends.
    Closing,
    /// The peer sent `Close` and a `Close` reply has been queued.
    RemoteClose,
    /// Set by `ConnectionController::drop_connection`.
    Drop,
}
63
64impl<T: AsyncRead + AsyncWrite> Connection<T> {
65 pub fn new(
66 framed: Framed<T, AmqpCodec<AmqpFrame>>,
67 local: Configuration,
68 remote: Configuration,
69 time: Option<LowResTimeService>,
70 ) -> Connection<T> {
71 Connection {
72 framed,
73 hb: Heartbeat::new(
74 local.timeout().unwrap(),
75 remote.timeout(),
76 time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
77 ),
78 inner: Cell::new(ConnectionInner::new(local, remote)),
79 }
80 }
81
82 pub(crate) fn new_server(
83 framed: Framed<T, AmqpCodec<AmqpFrame>>,
84 inner: Cell<ConnectionInner>,
85 time: Option<LowResTimeService>,
86 ) -> Connection<T> {
87 let l_timeout = inner.get_ref().local.timeout().unwrap();
88 let r_timeout = inner.get_ref().remote.timeout();
89 Connection {
90 framed,
91 inner,
92 hb: Heartbeat::new(
93 l_timeout,
94 r_timeout,
95 time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
96 ),
97 }
98 }
99
100 pub fn controller(&self) -> ConnectionController {
102 ConnectionController(self.inner.clone())
103 }
104
105 pub fn remote_config(&self) -> &Configuration {
107 &self.inner.get_ref().remote
108 }
109
110 pub fn close(&mut self) -> impl Future<Output = Result<(), AmqpTransportError>> {
112 future::ok(())
113 }
114
115 pub fn close_with_error(
118 &mut self,
119 _err: Error,
120 ) -> impl Future<Output = Result<(), AmqpTransportError>> {
121 future::ok(())
122 }
123
124 pub fn open_session(&mut self) -> impl Future<Output = Result<Session, AmqpTransportError>> {
126 let cell = self.inner.downgrade();
127 let inner = self.inner.clone();
128
129 async move {
130 let inner = inner.get_mut();
131
132 if let Some(ref e) = inner.error {
133 Err(e.clone())
134 } else {
135 let (tx, rx) = oneshot::channel();
136
137 let entry = inner.sessions.vacant_entry();
138 let token = entry.key();
139
140 if token >= inner.local.channel_max {
141 Err(AmqpTransportError::TooManyChannels)
142 } else {
143 entry.insert(ChannelState::Opening(Some(tx), cell));
144
145 let begin = Begin {
146 remote_channel: None,
147 next_outgoing_id: 1,
148 incoming_window: std::u32::MAX,
149 outgoing_window: std::u32::MAX,
150 handle_max: std::u32::MAX,
151 offered_capabilities: None,
152 desired_capabilities: None,
153 properties: None,
154 };
155 inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
156
157 rx.await.map_err(|_| AmqpTransportError::Disconnected)
158 }
159 }
160 }
161 }
162
163 pub(crate) fn get_session(&self, id: usize) -> Cell<SessionInner> {
165 if let Some(channel) = self.inner.get_ref().sessions.get(id) {
166 if let ChannelState::Established(ref session) = channel {
167 return session.clone();
168 }
169 }
170 panic!("Session not found: {}", id);
171 }
172
173 pub(crate) fn register_remote_session(&mut self, channel_id: u16, begin: &Begin) {
174 trace!("remote session opened: {:?}", channel_id);
175
176 let cell = self.inner.clone();
177 let inner = self.inner.get_mut();
178 let entry = inner.sessions.vacant_entry();
179 let token = entry.key();
180
181 let session = Cell::new(SessionInner::new(
182 token,
183 false,
184 ConnectionController(cell),
185 token as u16,
186 begin.next_outgoing_id(),
187 begin.incoming_window(),
188 begin.outgoing_window(),
189 ));
190 entry.insert(ChannelState::Established(session));
191 inner.sessions_map.insert(channel_id, token);
192
193 let begin = Begin {
194 remote_channel: Some(channel_id),
195 next_outgoing_id: 1,
196 incoming_window: std::u32::MAX,
197 outgoing_window: begin.incoming_window(),
198 handle_max: std::u32::MAX,
199 offered_capabilities: None,
200 desired_capabilities: None,
201 properties: None,
202 };
203 inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
204 }
205
206 pub(crate) fn send_frame(&mut self, frame: AmqpFrame) {
207 self.inner.get_mut().post_frame(frame)
208 }
209
210 pub(crate) fn register_write_task(&self, cx: &mut Context) {
211 self.inner.write_task.register(cx.waker());
212 }
213
214 pub(crate) fn poll_outgoing(&mut self, cx: &mut Context) -> Poll<Result<(), AmqpCodecError>> {
215 let inner = self.inner.get_mut();
216 let mut update = false;
217 loop {
218 while !self.framed.is_write_buf_full() {
219 if let Some(frame) = inner.pop_next_frame() {
220 trace!("outgoing: {:#?}", frame);
221 update = true;
222 if let Err(e) = self.framed.write(frame) {
223 inner.set_error(e.clone().into());
224 return Poll::Ready(Err(e));
225 }
226 } else {
227 break;
228 }
229 }
230
231 if !self.framed.is_write_buf_empty() {
232 match self.framed.flush(cx) {
233 Poll::Pending => break,
234 Poll::Ready(Err(e)) => {
235 trace!("error sending data: {}", e);
236 inner.set_error(e.clone().into());
237 return Poll::Ready(Err(e));
238 }
239 Poll::Ready(_) => (),
240 }
241 } else {
242 break;
243 }
244 }
245 self.hb.update_remote(update);
246
247 if inner.state == State::Drop {
248 Poll::Ready(Ok(()))
249 } else if inner.state == State::RemoteClose
250 && inner.write_queue.is_empty()
251 && self.framed.is_write_buf_empty()
252 {
253 Poll::Ready(Ok(()))
254 } else {
255 Poll::Pending
256 }
257 }
258
259 pub(crate) fn poll_incoming(
260 &mut self,
261 cx: &mut Context,
262 ) -> Poll<Option<Result<AmqpFrame, AmqpCodecError>>> {
263 let inner = self.inner.get_mut();
264
265 let mut update = false;
266 loop {
267 match Pin::new(&mut self.framed).poll_next(cx) {
268 Poll::Ready(Some(Ok(frame))) => {
269 trace!("incoming: {:#?}", frame);
270
271 update = true;
272
273 if let Frame::Empty = frame.performative() {
274 self.hb.update_local(update);
275 continue;
276 }
277
278 if let Frame::Close(ref close) = frame.performative() {
280 inner.set_error(AmqpTransportError::Closed(close.error.clone()));
281
282 if inner.state == State::Closing {
283 inner.sessions.clear();
284 return Poll::Ready(None);
285 } else {
286 let close = Close { error: None };
287 inner.post_frame(AmqpFrame::new(0, close.into()));
288 inner.state = State::RemoteClose;
289 }
290 }
291
292 if inner.error.is_some() {
293 error!("connection closed but new framed is received: {:?}", frame);
294 return Poll::Ready(None);
295 }
296
297 let channel_id =
299 if let Some(token) = inner.sessions_map.get(&frame.channel_id()) {
300 *token
301 } else {
302 if let Frame::Begin(ref begin) = frame.performative() {
304 if begin.remote_channel().is_some() {
305 inner.complete_session_creation(frame.channel_id(), begin);
306 } else {
307 return Poll::Ready(Some(Ok(frame)));
308 }
309 } else {
310 warn!("Unexpected frame: {:#?}", frame);
311 }
312 continue;
313 };
314
315 if let Some(channel) = inner.sessions.get_mut(channel_id) {
317 match channel {
318 ChannelState::Opening(_, _) => {
319 error!("Unexpected opening state: {}", channel_id);
320 }
321 ChannelState::Established(ref mut session) => {
322 match frame.performative() {
323 Frame::Attach(attach) => {
324 let cell = session.clone();
325 if !session.get_mut().handle_attach(attach, cell) {
326 return Poll::Ready(Some(Ok(frame)));
327 }
328 }
329 Frame::Flow(_) | Frame::Detach(_) => {
330 return Poll::Ready(Some(Ok(frame)));
331 }
332 Frame::End(remote_end) => {
333 trace!("Remote session end: {}", frame.channel_id());
334 let end = End { error: None };
335 session.get_mut().set_error(
336 AmqpTransportError::SessionEnded(
337 remote_end.error.clone(),
338 ),
339 );
340 let id = session.get_mut().id();
341 inner.post_frame(AmqpFrame::new(id, end.into()));
342 inner.sessions.remove(channel_id);
343 inner.sessions_map.remove(&frame.channel_id());
344 }
345 _ => session.get_mut().handle_frame(frame.into_parts().1),
346 }
347 }
348 ChannelState::Closing(ref mut tx) => match frame.performative() {
349 Frame::End(_) => {
350 if let Some(tx) = tx.take() {
351 let _ = tx.send(Ok(()));
352 }
353 inner.sessions.remove(channel_id);
354 inner.sessions_map.remove(&frame.channel_id());
355 }
356 frm => trace!("Got frame after initiated session end: {:?}", frm),
357 },
358 }
359 } else {
360 error!("Can not find channel: {}", channel_id);
361 continue;
362 }
363 }
364 Poll::Ready(None) => {
365 inner.set_error(AmqpTransportError::Disconnected);
366 return Poll::Ready(None);
367 }
368 Poll::Pending => {
369 self.hb.update_local(update);
370 break;
371 }
372 Poll::Ready(Some(Err(e))) => {
373 trace!("error reading: {:?}", e);
374 inner.set_error(e.clone().into());
375 return Poll::Ready(Some(Err(e.into())));
376 }
377 }
378 }
379
380 Poll::Pending
381 }
382}
383
384impl<T: AsyncRead + AsyncWrite> Drop for Connection<T> {
385 fn drop(&mut self) {
386 self.inner
387 .get_mut()
388 .set_error(AmqpTransportError::Disconnected);
389 }
390}
391
392impl<T: AsyncRead + AsyncWrite> Future for Connection<T> {
393 type Output = Result<(), AmqpCodecError>;
394
395 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
396 match self.hb.poll(cx) {
398 Ok(act) => match act {
399 HeartbeatAction::None => (),
400 HeartbeatAction::Close => {
401 self.inner.get_mut().set_error(AmqpTransportError::Timeout);
402 return Poll::Ready(Ok(()));
403 }
404 HeartbeatAction::Heartbeat => {
405 self.inner
406 .get_mut()
407 .write_queue
408 .push_back(AmqpFrame::new(0, Frame::Empty));
409 }
410 },
411 Err(e) => {
412 self.inner.get_mut().set_error(e);
413 return Poll::Ready(Ok(()));
414 }
415 }
416
417 loop {
418 match self.poll_incoming(cx) {
419 Poll::Ready(None) => return Poll::Ready(Ok(())),
420 Poll::Ready(Some(Ok(frame))) => {
421 if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
422 if let ChannelState::Established(ref session) = channel {
423 session.get_mut().handle_frame(frame.into_parts().1);
424 continue;
425 }
426 }
427 warn!("Unexpected frame: {:?}", frame);
428 }
429 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
430 Poll::Pending => break,
431 }
432 }
433 let _ = self.poll_outgoing(cx)?;
434 self.register_write_task(cx);
435
436 match self.poll_incoming(cx) {
437 Poll::Ready(None) => return Poll::Ready(Ok(())),
438 Poll::Ready(Some(Ok(frame))) => {
439 if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
440 if let ChannelState::Established(ref session) = channel {
441 session.get_mut().handle_frame(frame.into_parts().1);
442 return Poll::Pending;
443 }
444 }
445 warn!("Unexpected frame: {:?}", frame);
446 }
447 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
448 Poll::Pending => (),
449 }
450
451 Poll::Pending
452 }
453}
454
/// Cloneable handle to the shared connection state.
///
/// Used to post frames, read the peer's configuration and request that the connection
/// be dropped.
#[derive(Clone)]
pub struct ConnectionController(pub(crate) Cell<ConnectionInner>);
457
458impl ConnectionController {
459 pub(crate) fn new(local: Configuration) -> ConnectionController {
460 ConnectionController(Cell::new(ConnectionInner {
461 local,
462 remote: Configuration::default(),
463 write_queue: VecDeque::new(),
464 write_task: LocalWaker::new(),
465 sessions: slab::Slab::with_capacity(8),
466 sessions_map: FxHashMap::default(),
467 error: None,
468 state: State::Normal,
469 }))
470 }
471
472 pub(crate) fn set_remote(&mut self, remote: Configuration) {
473 self.0.get_mut().remote = remote;
474 }
475
476 #[inline]
477 pub fn remote_config(&self) -> &Configuration {
479 &self.0.get_ref().remote
480 }
481
482 #[inline]
483 pub fn drop_connection(&mut self) {
485 let inner = self.0.get_mut();
486 inner.state = State::Drop;
487 inner.write_task.wake()
488 }
489
490 pub(crate) fn post_frame(&mut self, frame: AmqpFrame) {
491 self.0.get_mut().post_frame(frame)
492 }
493
494 pub(crate) fn drop_session_copy(&mut self, _id: usize) {}
495}
496
497impl ConnectionInner {
498 pub(crate) fn new(local: Configuration, remote: Configuration) -> ConnectionInner {
499 ConnectionInner {
500 local,
501 remote,
502 write_queue: VecDeque::new(),
503 write_task: LocalWaker::new(),
504 sessions: slab::Slab::with_capacity(8),
505 sessions_map: FxHashMap::default(),
506 error: None,
507 state: State::Normal,
508 }
509 }
510
511 fn set_error(&mut self, err: AmqpTransportError) {
512 for (_, channel) in self.sessions.iter_mut() {
513 match channel {
514 ChannelState::Opening(_, _) | ChannelState::Closing(_) => (),
515 ChannelState::Established(ref mut ses) => {
516 ses.get_mut().set_error(err.clone());
517 }
518 }
519 }
520 self.sessions.clear();
521 self.sessions_map.clear();
522
523 self.error = Some(err);
524 }
525
526 fn pop_next_frame(&mut self) -> Option<AmqpFrame> {
527 self.write_queue.pop_front()
528 }
529
530 fn post_frame(&mut self, frame: AmqpFrame) {
531 self.write_queue.push_back(frame);
533 self.write_task.wake();
534 }
535
536 fn complete_session_creation(&mut self, channel_id: u16, begin: &Begin) {
537 trace!(
538 "session opened: {:?} {:?}",
539 channel_id,
540 begin.remote_channel()
541 );
542
543 let id = begin.remote_channel().unwrap() as usize;
544
545 if let Some(channel) = self.sessions.get_mut(id) {
546 if channel.is_opening() {
547 if let ChannelState::Opening(tx, cell) = channel {
548 let cell = cell.upgrade().unwrap();
549 let session = Cell::new(SessionInner::new(
550 id,
551 true,
552 ConnectionController(cell),
553 channel_id,
554 begin.next_outgoing_id(),
555 begin.incoming_window(),
556 begin.outgoing_window(),
557 ));
558 self.sessions_map.insert(channel_id, id);
559
560 if tx
561 .take()
562 .unwrap()
563 .send(Session::new(session.clone()))
564 .is_err()
565 {
566 }
568 *channel = ChannelState::Established(session)
569 }
570 } else {
571 }
573 } else {
574 }
576 }
577}