use std::{collections::VecDeque, future::Future, hash, pin::Pin, task::Context, task::Poll};
use ntex::util::{poll_fn, ByteString, BytesMut, PoolRef, Stream};
use ntex::{channel::oneshot, task::LocalWaker};
use ntex_amqp_codec::protocol::{
self as codec, Attach, DeliveryNumber, Disposition, Error, Handle, LinkError,
ReceiverSettleMode, Role, SenderSettleMode, Source, TerminusDurability, TerminusExpiryPolicy,
Transfer, TransferBody,
};
use ntex_amqp_codec::types::{Symbol, Variant};
use ntex_amqp_codec::Encode;
use crate::session::{Session, SessionInner};
use crate::{cell::Cell, error::AmqpProtocolError, types::Action};
/// Handle to the receiving end of an AMQP link.
///
/// Equality and hashing compare the address of the shared `ReceiverLinkInner`,
/// so two handles are equal only when they refer to the same link state.
#[derive(Clone, Debug)]
pub struct ReceiverLink {
    /// Shared link state.
    pub(crate) inner: Cell<ReceiverLinkInner>,
}
impl Eq for ReceiverLink {}
impl PartialEq<ReceiverLink> for ReceiverLink {
fn eq(&self, other: &ReceiverLink) -> bool {
std::ptr::eq(self.inner.get_ref(), other.inner.get_ref())
}
}
impl hash::Hash for ReceiverLink {
    /// Hash by the address of the shared link state, matching `PartialEq`.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        std::ptr::hash(self.inner.get_ref(), state);
    }
}
impl ReceiverLink {
    /// Wrap shared link state in a handle.
    pub(crate) fn new(inner: Cell<ReceiverLinkInner>) -> ReceiverLink {
        ReceiverLink { inner }
    }

    /// Link name, taken from the `Attach` frame this link was created from.
    pub fn name(&self) -> &ByteString {
        &self.inner.get_ref().name
    }

    /// Local handle of this link.
    pub fn handle(&self) -> Handle {
        self.inner.get_ref().handle as Handle
    }

    /// Remote handle of this link.
    pub fn remote_handle(&self) -> Handle {
        self.inner.get_ref().remote_handle as Handle
    }

    /// Remaining link credit as tracked locally by this link.
    pub fn credit(&self) -> u32 {
        self.inner.get_ref().credit
    }

    /// Session this link belongs to.
    pub fn session(&self) -> &Session {
        &self.inner.get_ref().session
    }

    /// `true` once the link has been closed locally or remotely.
    pub fn is_closed(&self) -> bool {
        self.inner.get_ref().closed
    }

    /// Error the link was closed with, if it has not been reported yet.
    pub fn error(&self) -> Option<&Error> {
        self.inner.get_ref().error.as_ref()
    }

    /// Confirm the attach of this receiver link on the session.
    ///
    /// A configured max message size of zero is treated as "not set" and
    /// passed on as `None`.
    pub(crate) fn confirm_receiver_link(&self, frm: &Attach) {
        let inner = self.inner.get_mut();
        // Read through `inner` instead of taking a second (shared) reference
        // to the cell while the mutable one is alive.
        let size = Some(inner.max_message_size).filter(|size| *size != 0);
        inner
            .session
            .inner
            .get_mut()
            .confirm_receiver_link(inner.handle, frm, size);
    }

    /// Add link credit and notify the peer with a flow frame.
    pub fn set_link_credit(&self, credit: u32) {
        self.inner.get_mut().set_link_credit(credit);
    }

    /// Set the max message size used when confirming the link attach.
    pub fn set_max_message_size(&self, size: u64) {
        self.inner.get_mut().max_message_size = size;
    }

    /// Set the upper bound for a payload reassembled from several transfer
    /// frames.
    pub fn set_max_partial_transfer_size(&self, size: usize) {
        self.inner.get_mut().set_max_partial_transfer(size);
    }

    /// `true` if at least one complete transfer is waiting to be taken.
    ///
    /// A transfer whose payload is still being reassembled is not counted.
    pub fn has_transfers(&self) -> bool {
        Self::ready_len(self.inner.get_ref()) > 0
    }

    /// Take the next complete transfer, if any.
    ///
    /// A transfer whose payload is still being reassembled is never returned;
    /// handing it out would lose its body and break reassembly of the
    /// continuation frames that follow.
    pub fn get_transfer(&self) -> Option<Transfer> {
        let inner = self.inner.get_mut();
        if Self::ready_len(inner) > 0 {
            inner.queue.pop_front()
        } else {
            None
        }
    }

    /// Post a `Disposition` frame on this link's session.
    pub fn send_disposition(&self, disp: Disposition) {
        self.inner
            .get_mut()
            .session
            .inner
            .get_mut()
            .post_frame(disp.into());
    }

    /// Wait for a disposition for delivery `id`; delegates to the session.
    pub fn wait_disposition(
        &self,
        id: DeliveryNumber,
    ) -> impl Future<Output = Result<Disposition, AmqpProtocolError>> {
        self.inner.get_mut().session.wait_disposition(id)
    }

    /// Detach this link without an error.
    ///
    /// Resolves immediately with `Ok(())` if the link is already closed.
    pub fn close(&self) -> impl Future<Output = Result<(), AmqpProtocolError>> {
        self.inner.get_mut().close(None)
    }

    /// Detach this link with the given error.
    pub fn close_with_error<E>(
        &self,
        error: E,
    ) -> impl Future<Output = Result<(), AmqpProtocolError>>
    where
        Error: From<E>,
    {
        self.inner.get_mut().close(Some(error.into()))
    }

    /// Mark the link as closed by the peer and wake the reader.
    pub(crate) fn remote_closed(&self, error: Option<Error>) {
        let inner = self.inner.get_mut();
        trace!(
            "Receiver link has been closed remotely handle: {:?} name: {:?}",
            inner.remote_handle,
            inner.name
        );
        inner.closed = true;
        inner.error = error;
        inner.wake();
    }

    /// Receive the next complete transfer.
    ///
    /// Returns `None` once the link is closed (after the close error, if any,
    /// has been reported once as `Err(LinkDetached)`).
    pub async fn recv(&self) -> Option<Result<Transfer, AmqpProtocolError>> {
        poll_fn(|cx| self.poll_recv(cx)).await
    }

    /// Poll for the next complete transfer; see [`ReceiverLink::recv`].
    pub fn poll_recv(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Transfer, AmqpProtocolError>>> {
        let inner = self.inner.get_mut();
        if Self::ready_len(inner) > 0 {
            // `ready_len > 0` implies the queue is non-empty.
            return Poll::Ready(inner.queue.pop_front().map(Ok));
        }
        if inner.closed {
            // Report the close error once; subsequent polls end the stream.
            Poll::Ready(
                inner
                    .error
                    .take()
                    .map(|err| Err(AmqpProtocolError::LinkDetached(Some(err)))),
            )
        } else {
            inner.reader_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Number of queued transfers that are complete and may be handed out.
    ///
    /// While a multi-frame delivery is being reassembled (`partial_body` is
    /// set) the last queued entry is that unfinished transfer, so it is
    /// excluded.
    fn ready_len(inner: &ReceiverLinkInner) -> usize {
        inner
            .queue
            .len()
            .saturating_sub(usize::from(inner.partial_body.is_some()))
    }
}
/// Exposes received transfers as a stream; ends when the link is closed.
impl Stream for ReceiverLink {
    type Item = Result<Transfer, AmqpProtocolError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ReceiverLink` is a shared handle, so polling only needs `&self`.
        Self::poll_recv(&self, cx)
    }
}
/// Mutable state shared by all `ReceiverLink` handles of one link.
#[derive(Debug)]
pub(crate) struct ReceiverLinkInner {
    /// Link name, cloned from the `Attach` frame.
    name: ByteString,
    /// Local link handle.
    handle: Handle,
    /// Remote link handle.
    remote_handle: Handle,
    /// Owning session; used to post frames and detach the link.
    session: Session,
    /// Set when the link is closed locally, remotely, or detached.
    closed: bool,
    /// Waker registered by `poll_recv`, woken via `wake()`.
    reader_task: LocalWaker,
    /// Received transfers. While `partial_body` is `Some`, the last entry is
    /// the transfer whose payload is still being reassembled.
    queue: VecDeque<Transfer>,
    /// Locally tracked link credit; decremented by `handle_transfer`.
    credit: u32,
    /// Delivery count sent to the peer in flow frames.
    delivery_count: u32,
    /// Close error to report to the reader (taken once by `poll_recv`).
    error: Option<Error>,
    /// Payload accumulated from a multi-frame transfer, if one is in progress.
    partial_body: Option<BytesMut>,
    /// Maximum size of a reassembled payload (default 262_144 bytes).
    partial_body_max: usize,
    /// Max message size passed on when confirming the attach; 0 means unset.
    max_message_size: u64,
    /// Memory pool used to allocate reassembly buffers.
    pool: PoolRef,
}
impl ReceiverLinkInner {
    /// Build receiver link state from an `Attach` frame.
    ///
    /// The link starts open with zero credit; the initial delivery count and
    /// max message size come from `frame` (0 when absent).
    pub(crate) fn new(
        session: Cell<SessionInner>,
        handle: Handle,
        remote_handle: Handle,
        frame: &Attach,
    ) -> ReceiverLinkInner {
        let pool = session.get_ref().memory_pool();
        let mut name = frame.name().clone();
        name.trimdown();
        ReceiverLinkInner {
            name,
            handle,
            remote_handle,
            pool,
            session: Session::new(session),
            closed: false,
            queue: VecDeque::with_capacity(4),
            credit: 0,
            error: None,
            partial_body: None,
            partial_body_max: 262_144,
            delivery_count: frame.initial_delivery_count().unwrap_or(0),
            max_message_size: frame.max_message_size().unwrap_or(0),
            reader_task: LocalWaker::new(),
        }
    }

    /// Wake the task waiting in `poll_recv`, if any.
    fn wake(&self) {
        self.reader_task.wake();
    }

    /// Link name.
    pub(crate) fn name(&self) -> &ByteString {
        &self.name
    }

    /// Mark the link as detached and drop all queued transfers.
    pub(crate) fn detached(&mut self) {
        self.queue.clear();
        self.closed = true;
    }

    /// Detach the link, optionally with an error, and wake the reader.
    ///
    /// The detach request is issued to the session synchronously; the returned
    /// future only reports its outcome (`Disconnected` if the reply channel is
    /// dropped). If the link is already closed it resolves with `Ok(())`.
    pub(crate) fn close(
        &mut self,
        error: Option<Error>,
    ) -> impl Future<Output = Result<(), AmqpProtocolError>> {
        let (tx, rx) = oneshot::channel();
        if self.closed {
            let _ = tx.send(Ok(()));
        } else {
            self.session
                .inner
                .get_mut()
                .detach_receiver_link(self.handle, true, error, tx);
        }
        self.wake();
        async move {
            match rx.await {
                Ok(Ok(_)) => Ok(()),
                Ok(Err(e)) => Err(e),
                Err(_) => Err(AmqpProtocolError::Disconnected),
            }
        }
    }

    /// Set the maximum size of a payload reassembled from multiple frames.
    fn set_max_partial_transfer(&mut self, size: usize) {
        self.partial_body_max = size;
    }

    /// Add `credit` to the local credit and send a flow frame for this link.
    pub(crate) fn set_link_credit(&mut self, credit: u32) {
        // Saturate instead of overflowing (which would panic in debug builds).
        self.credit = self.credit.saturating_add(credit);
        // NOTE(review): the flow frame carries only `credit` while the local
        // counter accumulates; AMQP 1.0 (2.6.7) link-credit appears to be an
        // absolute value — confirm which semantics callers expect.
        self.session
            .inner
            .get_mut()
            .rcv_link_flow(self.handle, self.delivery_count, credit);
    }

    /// Process an incoming transfer frame for this link.
    ///
    /// Single-frame deliveries are queued immediately. A frame with `more` set
    /// starts reassembly: its payload is buffered in `partial_body` and the
    /// transfer is queued; continuation frames append to the buffer until the
    /// final frame attaches the payload to the queued transfer.
    ///
    /// Protocol violations detach the link with an error and yield
    /// `Action::None`. `Action::Transfer` is returned whenever a delivery has
    /// become complete. This function never returns `Err`.
    pub(crate) fn handle_transfer(
        &mut self,
        transfer: Transfer,
        inner: &Cell<ReceiverLinkInner>,
    ) -> Result<Action, AmqpProtocolError> {
        if self.partial_body.is_some() {
            // Continuation frames belong to a delivery that already consumed
            // link credit; credit is accounted per delivery, not per frame.
            return self.continue_partial(transfer, inner);
        }

        if self.credit == 0 {
            self.detach_with_error(LinkError::TransferLimitExceeded, None);
            return Ok(Action::None);
        }
        self.credit -= 1;

        if transfer.more() {
            self.start_partial(transfer)
        } else {
            // Delivery counts are serial numbers and wrap around.
            self.delivery_count = self.delivery_count.wrapping_add(1);
            self.queue.push_back(transfer);
            if self.queue.len() == 1 {
                self.wake();
            }
            Ok(Action::Transfer(ReceiverLink {
                inner: inner.clone(),
            }))
        }
    }

    /// Begin reassembling a delivery from its first frame (`more` is set).
    fn start_partial(&mut self, mut transfer: Transfer) -> Result<Action, AmqpProtocolError> {
        if transfer.delivery_id().is_none() {
            self.detach_with_error(
                LinkError::DetachForced,
                Some(ByteString::from_static("delivery_id is required")),
            );
            return Ok(Action::None);
        }

        let body = match transfer.0.body.take() {
            Some(TransferBody::Data(data)) => BytesMut::copy_from_slice(&data),
            Some(TransferBody::Message(msg)) => {
                let mut buf = self.pool.buf_with_capacity(msg.encoded_size());
                msg.encode(&mut buf);
                buf
            }
            None => self.pool.buf_with_capacity(16),
        };
        self.partial_body = Some(body);
        // Queued without a body; the reader skips it until reassembly completes.
        self.queue.push_back(transfer);
        Ok(Action::None)
    }

    /// Append a continuation frame to the delivery being reassembled and, on
    /// its final frame, publish the completed transfer.
    fn continue_partial(
        &mut self,
        mut transfer: Transfer,
        inner: &Cell<ReceiverLinkInner>,
    ) -> Result<Action, AmqpProtocolError> {
        // A continuation frame may repeat the delivery id; if it does, it must
        // match the transfer being reassembled (the last queued one).
        if transfer.0.delivery_id.is_some()
            && self
                .queue
                .back()
                .map_or(true, |back| back.0.delivery_id != transfer.0.delivery_id)
        {
            self.detach_with_error(
                LinkError::DetachForced,
                Some(ByteString::from_static("delivery_id is wrong")),
            );
            return Ok(Action::None);
        }

        if let Some(transfer_body) = transfer.0.body.take() {
            if let Some(ref mut body) = self.partial_body {
                if body.len() + transfer_body.len() > self.partial_body_max {
                    self.detach_with_error(LinkError::MessageSizeExceeded, None);
                    return Ok(Action::None);
                }
                transfer_body.encode(body);
            }
        }

        if transfer.0.more {
            return Ok(Action::None);
        }

        // Final frame: the delivery is complete.
        self.delivery_count = self.delivery_count.wrapping_add(1);
        let completed = match (self.partial_body.take(), self.queue.back_mut()) {
            (Some(body), Some(last)) => {
                last.0.body = Some(TransferBody::Data(body.freeze()));
                true
            }
            _ => false,
        };
        if !completed {
            log::error!("Inconsistent state, bug");
            self.detach_with_error(
                LinkError::DetachForced,
                Some(ByteString::from_static("Internal error")),
            );
            return Ok(Action::None);
        }

        if self.queue.len() == 1 {
            self.wake();
        }
        Ok(Action::Transfer(ReceiverLink {
            inner: inner.clone(),
        }))
    }

    /// Detach the link with a link error built from `condition`/`description`.
    fn detach_with_error(&mut self, condition: LinkError, description: Option<ByteString>) {
        let err = Error(Box::new(codec::ErrorInner {
            condition: condition.into(),
            description,
            info: None,
        }));
        // `close` issues the detach eagerly; its future only reports the
        // outcome, so dropping it here is intentional.
        let _ = self.close(Some(err));
    }
}
/// Builder for attaching a new receiver link on a session.
pub struct ReceiverLinkBuilder {
    /// `Attach` frame that will be sent for the link.
    frame: Attach,
    /// Session the link is attached to.
    session: Cell<SessionInner>,
}
impl ReceiverLinkBuilder {
pub(crate) fn new(name: ByteString, address: ByteString, session: Cell<SessionInner>) -> Self {
let source = Source {
address: Some(address),
durable: TerminusDurability::None,
expiry_policy: TerminusExpiryPolicy::SessionEnd,
timeout: 0,
dynamic: false,
dynamic_node_properties: None,
distribution_mode: None,
filter: None,
default_outcome: None,
outcomes: None,
capabilities: None,
};
let frame = Attach(Box::new(codec::AttachInner {
name,
handle: 0_u32,
role: Role::Receiver,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: Some(source),
target: None,
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: None,
max_message_size: Some(65536 * 4),
offered_capabilities: None,
desired_capabilities: None,
properties: None,
}));
ReceiverLinkBuilder { frame, session }
}
pub fn max_message_size(mut self, size: u64) -> Self {
self.frame.0.max_message_size = Some(size);
self
}
pub fn property<K, V>(mut self, key: K, value: Option<V>) -> Self
where
Symbol: From<K>,
Variant: From<V>,
{
let key = key.into();
let props = self.frame.get_properties_mut();
match value {
Some(value) => props.insert(key, value.into()),
None => props.remove(&key),
};
self
}
pub async fn attach(self) -> Result<ReceiverLink, AmqpProtocolError> {
let cell = self.session.clone();
let res = self
.session
.get_mut()
.attach_local_receiver_link(cell, self.frame)
.await;
match res {
Ok(Ok(res)) => Ok(res),
Ok(Err(err)) => Err(err),
Err(_) => Err(AmqpProtocolError::Disconnected),
}
}
}