1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use actix_codec::{AsyncRead, AsyncWrite, Framed};
7use actix_service::{Service, ServiceFactory};
8use amqp_codec::protocol::{
9 self, ProtocolId, SaslChallenge, SaslCode, SaslFrameBody, SaslMechanisms, SaslOutcome, Symbols,
10};
11use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, SaslFrame};
12use bytes::Bytes;
13use bytestring::ByteString;
14use futures::future::{err, ok, Either, Ready};
15use futures::{SinkExt, StreamExt};
16
17use super::connect::{ConnectAck, ConnectOpened};
18use super::errors::{AmqpError, ServerError};
19use crate::connection::ConnectionController;
20
21pub struct Sasl<Io> {
22 framed: Framed<Io, ProtocolIdCodec>,
23 mechanisms: Symbols,
24 controller: ConnectionController,
25}
26
27impl<Io> fmt::Debug for Sasl<Io> {
28 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
29 fmt.debug_struct("SaslAuth")
30 .field("mechanisms", &self.mechanisms)
31 .finish()
32 }
33}
34
35impl<Io> Sasl<Io> {
36 pub(crate) fn new(
37 framed: Framed<Io, ProtocolIdCodec>,
38 controller: ConnectionController,
39 ) -> Self {
40 Sasl {
41 framed,
42 controller,
43 mechanisms: Symbols::default(),
44 }
45 }
46}
47
48impl<Io> Sasl<Io>
49where
50 Io: AsyncRead + AsyncWrite,
51{
52 pub fn get_ref(&self) -> &Io {
54 self.framed.get_ref()
55 }
56
57 pub fn get_mut(&mut self) -> &mut Io {
59 self.framed.get_mut()
60 }
61
62 pub fn mechanism<U: Into<String>>(mut self, symbol: U) -> Self {
64 self.mechanisms.push(ByteString::from(symbol.into()).into());
65 self
66 }
67
68 pub async fn init(self) -> Result<Init<Io>, ServerError<()>> {
70 let Sasl {
71 framed,
72 mechanisms,
73 controller,
74 ..
75 } = self;
76
77 let mut framed = framed.into_framed(AmqpCodec::<SaslFrame>::new());
78 let frame = SaslMechanisms {
79 sasl_server_mechanisms: mechanisms,
80 }
81 .into();
82
83 framed.send(frame).await.map_err(ServerError::from)?;
84 let frame = framed
85 .next()
86 .await
87 .ok_or(ServerError::Disconnected)?
88 .map_err(ServerError::from)?;
89
90 match frame.body {
91 SaslFrameBody::SaslInit(frame) => Ok(Init {
92 frame,
93 framed,
94 controller,
95 }),
96 body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
97 }
98 }
99}
100
101pub struct Init<Io> {
103 frame: protocol::SaslInit,
104 framed: Framed<Io, AmqpCodec<SaslFrame>>,
105 controller: ConnectionController,
106}
107
108impl<Io> fmt::Debug for Init<Io> {
109 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
110 fmt.debug_struct("SaslInit")
111 .field("frame", &self.frame)
112 .finish()
113 }
114}
115
116impl<Io> Init<Io>
117where
118 Io: AsyncRead + AsyncWrite,
119{
120 pub fn mechanism(&self) -> &str {
122 self.frame.mechanism.as_str()
123 }
124
125 pub fn initial_response(&self) -> Option<&[u8]> {
127 self.frame.initial_response.as_ref().map(|b| b.as_ref())
128 }
129
130 pub fn hostname(&self) -> Option<&str> {
132 self.frame.hostname.as_ref().map(|b| b.as_ref())
133 }
134
135 pub fn get_ref(&self) -> &Io {
137 self.framed.get_ref()
138 }
139
140 pub fn get_mut(&mut self) -> &mut Io {
142 self.framed.get_mut()
143 }
144
145 pub async fn challenge(self) -> Result<Response<Io>, ServerError<()>> {
147 self.challenge_with(Bytes::new()).await
148 }
149
150 pub async fn challenge_with(self, challenge: Bytes) -> Result<Response<Io>, ServerError<()>> {
152 let mut framed = self.framed;
153 let controller = self.controller;
154 let frame = SaslChallenge { challenge }.into();
155
156 framed.send(frame).await.map_err(ServerError::from)?;
157 let frame = framed
158 .next()
159 .await
160 .ok_or(ServerError::Disconnected)?
161 .map_err(ServerError::from)?;
162
163 match frame.body {
164 SaslFrameBody::SaslResponse(frame) => Ok(Response {
165 frame,
166 framed,
167 controller,
168 }),
169 body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
170 }
171 }
172
173 pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
175 let mut framed = self.framed;
176 let controller = self.controller;
177
178 let frame = SaslOutcome {
179 code,
180 additional_data: None,
181 }
182 .into();
183 framed.send(frame).await.map_err(ServerError::from)?;
184
185 Ok(Success { framed, controller })
186 }
187}
188
189pub struct Response<Io> {
190 frame: protocol::SaslResponse,
191 framed: Framed<Io, AmqpCodec<SaslFrame>>,
192 controller: ConnectionController,
193}
194
195impl<Io> fmt::Debug for Response<Io> {
196 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
197 fmt.debug_struct("SaslResponse")
198 .field("frame", &self.frame)
199 .finish()
200 }
201}
202
203impl<Io> Response<Io>
204where
205 Io: AsyncRead + AsyncWrite,
206{
207 pub fn response(&self) -> &[u8] {
209 &self.frame.response[..]
210 }
211
212 pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
214 let mut framed = self.framed;
215 let controller = self.controller;
216 let frame = SaslOutcome {
217 code,
218 additional_data: None,
219 }
220 .into();
221
222 framed.send(frame).await.map_err(ServerError::from)?;
223 framed
224 .next()
225 .await
226 .ok_or(ServerError::Disconnected)?
227 .map_err(|res| ServerError::from(res))?;
228
229 Ok(Success { framed, controller })
230 }
231}
232
233pub struct Success<Io> {
234 framed: Framed<Io, AmqpCodec<SaslFrame>>,
235 controller: ConnectionController,
236}
237
238impl<Io> Success<Io>
239where
240 Io: AsyncRead + AsyncWrite,
241{
242 pub fn get_ref(&self) -> &Io {
244 self.framed.get_ref()
245 }
246
247 pub fn get_mut(&mut self) -> &mut Io {
249 self.framed.get_mut()
250 }
251
252 pub async fn open(self) -> Result<ConnectOpened<Io>, ServerError<()>> {
254 let mut framed = self.framed.into_framed(ProtocolIdCodec);
255 let mut controller = self.controller;
256
257 let protocol = framed
258 .next()
259 .await
260 .ok_or(ServerError::from(ProtocolIdError::Disconnected))?
261 .map_err(ServerError::from)?;
262
263 match protocol {
264 ProtocolId::Amqp => {
265 framed
267 .send(ProtocolId::Amqp)
268 .await
269 .map_err(ServerError::from)?;
270
271 let mut framed = framed.into_framed(AmqpCodec::<AmqpFrame>::new());
273 let frame = framed
274 .next()
275 .await
276 .ok_or(ServerError::Disconnected)?
277 .map_err(ServerError::from)?;
278
279 let frame = frame.into_parts().1;
280 match frame {
281 protocol::Frame::Open(frame) => {
282 trace!("Got open frame: {:?}", frame);
283 controller.set_remote((&frame).into());
284 Ok(ConnectOpened::new(frame, framed, controller))
285 }
286 frame => Err(ServerError::Unexpected(frame)),
287 }
288 }
289 proto => Err(ProtocolIdError::Unexpected {
290 exp: ProtocolId::Amqp,
291 got: proto,
292 }
293 .into()),
294 }
295 }
296}
297
298pub fn no_sasl<Io, St, E>() -> NoSaslService<Io, St, E> {
300 NoSaslService::default()
301}
302
303pub struct NoSaslService<Io, St, E>(std::marker::PhantomData<(Io, St, E)>);
304
305impl<Io, St, E> Default for NoSaslService<Io, St, E> {
306 fn default() -> Self {
307 NoSaslService(std::marker::PhantomData)
308 }
309}
310
311impl<Io, St, E> ServiceFactory for NoSaslService<Io, St, E> {
312 type Config = ();
313 type Request = Sasl<Io>;
314 type Response = ConnectAck<Io, St>;
315 type Error = AmqpError;
316 type InitError = E;
317 type Service = NoSaslService<Io, St, E>;
318 type Future = Ready<Result<Self::Service, Self::InitError>>;
319
320 fn new_service(&self, _: ()) -> Self::Future {
321 ok(NoSaslService(std::marker::PhantomData))
322 }
323}
324
325impl<Io, St, E> Service for NoSaslService<Io, St, E> {
326 type Request = Sasl<Io>;
327 type Response = ConnectAck<Io, St>;
328 type Error = AmqpError;
329 type Future = Ready<Result<Self::Response, Self::Error>>;
330
331 fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
332 Poll::Ready(Ok(()))
333 }
334
335 fn call(&mut self, _: Self::Request) -> Self::Future {
336 err(AmqpError::not_implemented())
337 }
338}