ntex-mqtt 0.2.0-beta.2

MQTT Client/Server framework for v5 and v3.1.1 protocols
Documentation
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, io, time};

use futures::future::{err, join, ok, LocalBoxFuture, Ready};
use futures::{ready, Future, FutureExt};
use ntex::codec::{AsyncRead, AsyncWrite, Framed};
use ntex::rt::time::{delay_for, Delay};
use ntex::service::{Service, ServiceFactory};

use crate::error::MqttError;
use crate::version::{ProtocolVersion, VersionCodec};
use crate::{v3, v5};

/// Mqtt Server
pub struct MqttServer<Io, V3, V5, Err, InitErr> {
    v3: V3,
    v5: V5,
    handshake_timeout: usize,
    _t: PhantomData<(Io, Err, InitErr)>,
}

impl<Io, Err, InitErr>
    MqttServer<
        Io,
        DefaultProtocolServer<Io, Err, InitErr, v3::codec::Codec>,
        DefaultProtocolServer<Io, Err, InitErr, v5::codec::Codec>,
        Err,
        InitErr,
    >
{
    /// Create mqtt protocol selector server
    pub fn new() -> Self {
        MqttServer {
            v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
            v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
            handshake_timeout: 0,
            _t: PhantomData,
        }
    }
}

impl<Io, V3, V5, Err, InitErr> MqttServer<Io, V3, V5, Err, InitErr> {
    /// Set handshake timeout in millis.
    ///
    /// Handshake includes `connect` packet.
    /// By default handshake timeuot is disabled.
    pub fn handshake_timeout(mut self, timeout: usize) -> Self {
        self.handshake_timeout = timeout;
        self
    }
}

impl<Io, V3, V5, Err, InitErr> MqttServer<Io, V3, V5, Err, InitErr>
where
    Io: AsyncRead + AsyncWrite + Unpin + 'static,
    V3: ServiceFactory<
        Config = (),
        Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
    V5: ServiceFactory<
        Config = (),
        Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
{
    /// Service to handle v3 protocol
    pub fn v3<St, C, Cn, P>(
        self,
        service: v3::MqttServer<Io, St, C, Cn, P>,
    ) -> MqttServer<
        Io,
        impl ServiceFactory<
            Config = (),
            Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
            Response = (),
            Error = MqttError<Err>,
            InitError = InitErr,
        >,
        V5,
        Err,
        InitErr,
    >
    where
        St: 'static,
        C: ServiceFactory<
                Config = (),
                Request = v3::Connect<Io>,
                Response = v3::ConnectAck<Io, St>,
                Error = Err,
                InitError = InitErr,
            > + 'static,
        Cn: ServiceFactory<
                Config = v3::Session<St>,
                Request = v3::ControlPacket,
                Response = v3::ControlResult,
            > + 'static,
        P: ServiceFactory<Config = v3::Session<St>, Request = v3::Publish, Response = ()>
            + 'static,
        C::Error: From<Cn::Error>
            + From<Cn::InitError>
            + From<P::Error>
            + From<P::InitError>
            + fmt::Debug,
    {
        MqttServer {
            v3: service.inner_finish(),
            v5: self.v5,
            handshake_timeout: self.handshake_timeout,
            _t: PhantomData,
        }
    }

    /// Service to handle v5 protocol
    pub fn v5<St, C, Cn, P>(
        self,
        service: v5::MqttServer<Io, St, C, Cn, P>,
    ) -> MqttServer<
        Io,
        V3,
        impl ServiceFactory<
            Config = (),
            Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
            Response = (),
            Error = MqttError<Err>,
            InitError = InitErr,
        >,
        Err,
        InitErr,
    >
    where
        St: 'static,
        C: ServiceFactory<
                Config = (),
                Request = v5::Connect<Io>,
                Response = v5::ConnectAck<Io, St>,
                Error = Err,
                InitError = InitErr,
            > + 'static,
        Cn: ServiceFactory<
                Config = v5::Session<St>,
                Request = v5::ControlPacket,
                Response = v5::ControlResult,
            > + 'static,
        P: ServiceFactory<
                Config = v5::Session<St>,
                Request = v5::Publish,
                Response = v5::PublishAck,
            > + 'static,
        C::Error: From<Cn::Error>
            + From<Cn::InitError>
            + From<P::Error>
            + From<P::InitError>
            + fmt::Debug,
    {
        MqttServer {
            v3: self.v3,
            v5: service.inner_finish(),
            handshake_timeout: self.handshake_timeout,
            _t: PhantomData,
        }
    }
}

impl<Io, V3, V5, Err, InitErr> ServiceFactory for MqttServer<Io, V3, V5, Err, InitErr>
where
    Io: AsyncRead + AsyncWrite + Unpin + 'static,
    V3: ServiceFactory<
        Config = (),
        Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
    V5: ServiceFactory<
        Config = (),
        Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
    V3::Future: 'static,
    V5::Future: 'static,
{
    type Config = ();
    type Request = Io;
    type Response = ();
    type Error = MqttError<Err>;
    type Service = MqttServerImpl<Io, V3::Service, V5::Service, Err>;
    type InitError = InitErr;
    type Future = LocalBoxFuture<
        'static,
        Result<MqttServerImpl<Io, V3::Service, V5::Service, Err>, InitErr>,
    >;

    fn new_service(&self, _: ()) -> Self::Future {
        let handshake_timeout = self.handshake_timeout;

        join(self.v3.new_service(()), self.v5.new_service(()))
            .map(move |(v3, v5)| {
                let v3 = v3?;
                let v5 = v5?;
                Ok(MqttServerImpl {
                    handlers: Rc::new((v3, v5)),
                    handshake_timeout,
                    _t: PhantomData,
                })
            })
            .boxed_local()
    }
}

/// Mqtt Server
pub struct MqttServerImpl<Io, V3, V5, Err> {
    handlers: Rc<(V3, V5)>,
    handshake_timeout: usize,
    _t: PhantomData<(Io, Err)>,
}

impl<Io, V3, V5, Err> Service for MqttServerImpl<Io, V3, V5, Err>
where
    Io: AsyncRead + AsyncWrite + Unpin + 'static,
    V3: Service<
        Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
    V5: Service<
        Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
{
    type Request = Io;
    type Response = ();
    type Error = MqttError<Err>;
    type Future = MqttServerImplResponse<Io, V3, V5, Err>;

    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let ready1 = self.handlers.0.poll_ready(cx)?.is_ready();
        let ready2 = self.handlers.1.poll_ready(cx)?.is_ready();

        if ready1 && ready2 {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    fn poll_shutdown(&self, cx: &mut Context<'_>, is_error: bool) -> Poll<()> {
        let ready1 = self.handlers.0.poll_shutdown(cx, is_error).is_ready();
        let ready2 = self.handlers.1.poll_shutdown(cx, is_error).is_ready();

        if ready1 && ready2 {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    fn call(&self, req: Io) -> Self::Future {
        let delay = if self.handshake_timeout > 0 {
            Some(delay_for(time::Duration::from_secs(
                self.handshake_timeout as u64,
            )))
        } else {
            None
        };

        MqttServerImplResponse {
            state: MqttServerImplState::Version(Some((
                Framed::new(req, VersionCodec),
                self.handlers.clone(),
                delay,
            ))),
        }
    }
}

#[pin_project::pin_project]
pub struct MqttServerImplResponse<Io, V3, V5, Err>
where
    Io: AsyncRead + AsyncWrite + Unpin + 'static,
    V3: Service<
        Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
    V5: Service<
        Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
{
    #[pin]
    state: MqttServerImplState<Io, V3, V5>,
}

#[pin_project::pin_project(project = MqttServerImplStateProject)]
pub(crate) enum MqttServerImplState<Io, V3: Service, V5: Service> {
    V3(#[pin] V3::Future),
    V5(#[pin] V5::Future),
    Version(Option<(Framed<Io, VersionCodec>, Rc<(V3, V5)>, Option<Delay>)>),
}

impl<Io, V3, V5, Err> Future for MqttServerImplResponse<Io, V3, V5, Err>
where
    Io: AsyncRead + AsyncWrite + Unpin + 'static,
    V3: Service<
        Request = (Framed<Io, v3::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
    V5: Service<
        Request = (Framed<Io, v5::codec::Codec>, Option<Delay>),
        Response = (),
        Error = MqttError<Err>,
    >,
{
    type Output = Result<(), MqttError<Err>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();

        loop {
            match this.state.project() {
                MqttServerImplStateProject::V3(fut) => return fut.poll(cx),
                MqttServerImplStateProject::V5(fut) => return fut.poll(cx),
                MqttServerImplStateProject::Version(ref mut item) => {
                    if let Some(ref mut delay) = item.as_mut().unwrap().2 {
                        match Pin::new(delay).poll(cx) {
                            Poll::Pending => (),
                            Poll::Ready(_) => {
                                return Poll::Ready(Err(MqttError::HandshakeTimeout))
                            }
                        }
                    };

                    if let Some(ver) = ready!(item.as_mut().unwrap().0.next_item(cx)) {
                        let (framed, handlers, delay) = item.take().unwrap();
                        this = self.as_mut().project();
                        match ver? {
                            ProtocolVersion::MQTT3 => {
                                let framed = framed.map_codec(|_| v3::codec::Codec::new());
                                this.state.set(MqttServerImplState::V3(
                                    handlers.0.call((framed, delay)),
                                ))
                            }
                            ProtocolVersion::MQTT5 => {
                                let framed = framed.map_codec(|_| v5::codec::Codec::new());
                                this.state.set(MqttServerImplState::V5(
                                    handlers.1.call((framed, delay)),
                                ))
                            }
                        }
                        continue;
                    } else {
                        return Poll::Ready(Err(MqttError::Disconnected));
                    }
                }
            }
        }
    }
}

pub struct DefaultProtocolServer<Io, Err, InitErr, Codec> {
    ver: ProtocolVersion,
    _t: PhantomData<(Io, Err, InitErr, Codec)>,
}

impl<Io, Err, InitErr, Codec> DefaultProtocolServer<Io, Err, InitErr, Codec> {
    fn new(ver: ProtocolVersion) -> Self {
        Self {
            ver,
            _t: PhantomData,
        }
    }
}

impl<Io, Err, InitErr, Codec> ServiceFactory
    for DefaultProtocolServer<Io, Err, InitErr, Codec>
{
    type Config = ();
    type Request = (Framed<Io, Codec>, Option<Delay>);
    type Response = ();
    type Error = MqttError<Err>;
    type Service = DefaultProtocolServer<Io, Err, InitErr, Codec>;
    type InitError = InitErr;
    type Future = Ready<Result<Self::Service, Self::InitError>>;

    fn new_service(&self, _: ()) -> Self::Future {
        ok(DefaultProtocolServer {
            ver: self.ver,
            _t: PhantomData,
        })
    }
}

impl<Io, Err, InitErr, Codec> Service for DefaultProtocolServer<Io, Err, InitErr, Codec> {
    type Request = (Framed<Io, Codec>, Option<Delay>);
    type Response = ();
    type Error = MqttError<Err>;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&self, _: Self::Request) -> Self::Future {
        err(MqttError::Io(io::Error::new(
            io::ErrorKind::Other,
            format!("Protocol is not supported: {:?}", self.ver),
        )))
    }
}