actix_amqp/server/sasl.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use actix_codec::{AsyncRead, AsyncWrite, Framed};
7use actix_service::{Service, ServiceFactory};
8use amqp_codec::protocol::{
9    self, ProtocolId, SaslChallenge, SaslCode, SaslFrameBody, SaslMechanisms, SaslOutcome, Symbols,
10};
11use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, SaslFrame};
12use bytes::Bytes;
13use bytestring::ByteString;
14use futures::future::{err, ok, Either, Ready};
15use futures::{SinkExt, StreamExt};
16
17use super::connect::{ConnectAck, ConnectOpened};
18use super::errors::{AmqpError, ServerError};
19use crate::connection::ConnectionController;
20
21pub struct Sasl<Io> {
22    framed: Framed<Io, ProtocolIdCodec>,
23    mechanisms: Symbols,
24    controller: ConnectionController,
25}
26
27impl<Io> fmt::Debug for Sasl<Io> {
28    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
29        fmt.debug_struct("SaslAuth")
30            .field("mechanisms", &self.mechanisms)
31            .finish()
32    }
33}
34
35impl<Io> Sasl<Io> {
36    pub(crate) fn new(
37        framed: Framed<Io, ProtocolIdCodec>,
38        controller: ConnectionController,
39    ) -> Self {
40        Sasl {
41            framed,
42            controller,
43            mechanisms: Symbols::default(),
44        }
45    }
46}
47
48impl<Io> Sasl<Io>
49where
50    Io: AsyncRead + AsyncWrite,
51{
52    /// Returns reference to io object
53    pub fn get_ref(&self) -> &Io {
54        self.framed.get_ref()
55    }
56
57    /// Returns mutable reference to io object
58    pub fn get_mut(&mut self) -> &mut Io {
59        self.framed.get_mut()
60    }
61
62    /// Add supported sasl mechanism
63    pub fn mechanism<U: Into<String>>(mut self, symbol: U) -> Self {
64        self.mechanisms.push(ByteString::from(symbol.into()).into());
65        self
66    }
67
68    /// Initialize sasl auth procedure
69    pub async fn init(self) -> Result<Init<Io>, ServerError<()>> {
70        let Sasl {
71            framed,
72            mechanisms,
73            controller,
74            ..
75        } = self;
76
77        let mut framed = framed.into_framed(AmqpCodec::<SaslFrame>::new());
78        let frame = SaslMechanisms {
79            sasl_server_mechanisms: mechanisms,
80        }
81        .into();
82
83        framed.send(frame).await.map_err(ServerError::from)?;
84        let frame = framed
85            .next()
86            .await
87            .ok_or(ServerError::Disconnected)?
88            .map_err(ServerError::from)?;
89
90        match frame.body {
91            SaslFrameBody::SaslInit(frame) => Ok(Init {
92                frame,
93                framed,
94                controller,
95            }),
96            body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
97        }
98    }
99}
100
101/// Initialization stage of sasl negotiation
102pub struct Init<Io> {
103    frame: protocol::SaslInit,
104    framed: Framed<Io, AmqpCodec<SaslFrame>>,
105    controller: ConnectionController,
106}
107
108impl<Io> fmt::Debug for Init<Io> {
109    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
110        fmt.debug_struct("SaslInit")
111            .field("frame", &self.frame)
112            .finish()
113    }
114}
115
116impl<Io> Init<Io>
117where
118    Io: AsyncRead + AsyncWrite,
119{
120    /// Sasl mechanism
121    pub fn mechanism(&self) -> &str {
122        self.frame.mechanism.as_str()
123    }
124
125    /// Sasl initial response
126    pub fn initial_response(&self) -> Option<&[u8]> {
127        self.frame.initial_response.as_ref().map(|b| b.as_ref())
128    }
129
130    /// Sasl initial response
131    pub fn hostname(&self) -> Option<&str> {
132        self.frame.hostname.as_ref().map(|b| b.as_ref())
133    }
134
135    /// Returns reference to io object
136    pub fn get_ref(&self) -> &Io {
137        self.framed.get_ref()
138    }
139
140    /// Returns mutable reference to io object
141    pub fn get_mut(&mut self) -> &mut Io {
142        self.framed.get_mut()
143    }
144
145    /// Initiate sasl challenge
146    pub async fn challenge(self) -> Result<Response<Io>, ServerError<()>> {
147        self.challenge_with(Bytes::new()).await
148    }
149
150    /// Initiate sasl challenge with challenge payload
151    pub async fn challenge_with(self, challenge: Bytes) -> Result<Response<Io>, ServerError<()>> {
152        let mut framed = self.framed;
153        let controller = self.controller;
154        let frame = SaslChallenge { challenge }.into();
155
156        framed.send(frame).await.map_err(ServerError::from)?;
157        let frame = framed
158            .next()
159            .await
160            .ok_or(ServerError::Disconnected)?
161            .map_err(ServerError::from)?;
162
163        match frame.body {
164            SaslFrameBody::SaslResponse(frame) => Ok(Response {
165                frame,
166                framed,
167                controller,
168            }),
169            body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
170        }
171    }
172
173    /// Sasl challenge outcome
174    pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
175        let mut framed = self.framed;
176        let controller = self.controller;
177
178        let frame = SaslOutcome {
179            code,
180            additional_data: None,
181        }
182        .into();
183        framed.send(frame).await.map_err(ServerError::from)?;
184
185        Ok(Success { framed, controller })
186    }
187}
188
189pub struct Response<Io> {
190    frame: protocol::SaslResponse,
191    framed: Framed<Io, AmqpCodec<SaslFrame>>,
192    controller: ConnectionController,
193}
194
195impl<Io> fmt::Debug for Response<Io> {
196    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
197        fmt.debug_struct("SaslResponse")
198            .field("frame", &self.frame)
199            .finish()
200    }
201}
202
203impl<Io> Response<Io>
204where
205    Io: AsyncRead + AsyncWrite,
206{
207    /// Client response payload
208    pub fn response(&self) -> &[u8] {
209        &self.frame.response[..]
210    }
211
212    /// Sasl challenge outcome
213    pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
214        let mut framed = self.framed;
215        let controller = self.controller;
216        let frame = SaslOutcome {
217            code,
218            additional_data: None,
219        }
220        .into();
221
222        framed.send(frame).await.map_err(ServerError::from)?;
223        framed
224            .next()
225            .await
226            .ok_or(ServerError::Disconnected)?
227            .map_err(|res| ServerError::from(res))?;
228
229        Ok(Success { framed, controller })
230    }
231}
232
233pub struct Success<Io> {
234    framed: Framed<Io, AmqpCodec<SaslFrame>>,
235    controller: ConnectionController,
236}
237
238impl<Io> Success<Io>
239where
240    Io: AsyncRead + AsyncWrite,
241{
242    /// Returns reference to io object
243    pub fn get_ref(&self) -> &Io {
244        self.framed.get_ref()
245    }
246
247    /// Returns mutable reference to io object
248    pub fn get_mut(&mut self) -> &mut Io {
249        self.framed.get_mut()
250    }
251
252    /// Wait for connection open frame
253    pub async fn open(self) -> Result<ConnectOpened<Io>, ServerError<()>> {
254        let mut framed = self.framed.into_framed(ProtocolIdCodec);
255        let mut controller = self.controller;
256
257        let protocol = framed
258            .next()
259            .await
260            .ok_or(ServerError::from(ProtocolIdError::Disconnected))?
261            .map_err(ServerError::from)?;
262
263        match protocol {
264            ProtocolId::Amqp => {
265                // confirm protocol
266                framed
267                    .send(ProtocolId::Amqp)
268                    .await
269                    .map_err(ServerError::from)?;
270
271                // Wait for connection open frame
272                let mut framed = framed.into_framed(AmqpCodec::<AmqpFrame>::new());
273                let frame = framed
274                    .next()
275                    .await
276                    .ok_or(ServerError::Disconnected)?
277                    .map_err(ServerError::from)?;
278
279                let frame = frame.into_parts().1;
280                match frame {
281                    protocol::Frame::Open(frame) => {
282                        trace!("Got open frame: {:?}", frame);
283                        controller.set_remote((&frame).into());
284                        Ok(ConnectOpened::new(frame, framed, controller))
285                    }
286                    frame => Err(ServerError::Unexpected(frame)),
287                }
288            }
289            proto => Err(ProtocolIdError::Unexpected {
290                exp: ProtocolId::Amqp,
291                got: proto,
292            }
293            .into()),
294        }
295    }
296}
297
298/// Create service factory with disabled sasl support
299pub fn no_sasl<Io, St, E>() -> NoSaslService<Io, St, E> {
300    NoSaslService::default()
301}
302
/// Service and service factory that refuses SASL negotiation.
///
/// It stores no data; the marker only ties the type to its `Io`, `St` and
/// `E` parameters.
pub struct NoSaslService<Io, St, E>(std::marker::PhantomData<(Io, St, E)>);

// Written by hand rather than derived: `#[derive(Default)]` would add
// `Io: Default`, `St: Default` and `E: Default` bounds the marker does not need.
impl<Io, St, E> Default for NoSaslService<Io, St, E> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}
310
311impl<Io, St, E> ServiceFactory for NoSaslService<Io, St, E> {
312    type Config = ();
313    type Request = Sasl<Io>;
314    type Response = ConnectAck<Io, St>;
315    type Error = AmqpError;
316    type InitError = E;
317    type Service = NoSaslService<Io, St, E>;
318    type Future = Ready<Result<Self::Service, Self::InitError>>;
319
320    fn new_service(&self, _: ()) -> Self::Future {
321        ok(NoSaslService(std::marker::PhantomData))
322    }
323}
324
325impl<Io, St, E> Service for NoSaslService<Io, St, E> {
326    type Request = Sasl<Io>;
327    type Response = ConnectAck<Io, St>;
328    type Error = AmqpError;
329    type Future = Ready<Result<Self::Response, Self::Error>>;
330
331    fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
332        Poll::Ready(Ok(()))
333    }
334
335    fn call(&mut self, _: Self::Request) -> Self::Future {
336        err(AmqpError::not_implemented())
337    }
338}