Skip to main content

tansu_auth/
handshake.rs

1// Copyright ⓒ 2024-2026 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{Authentication, Error, Stage};
16use rama::{Context, Service};
17use rsasl::prelude::Mechname;
18use tansu_sans_io::{ApiKey, ErrorCode, SaslHandshakeRequest, SaslHandshakeResponse};
19use tracing::{debug, instrument};
20
/// Stateless service that answers SASL handshake requests.
///
/// The type carries no data of its own; the per-connection authentication
/// state is looked up from the request context when the service runs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaslHandshakeService;
23
24impl ApiKey for SaslHandshakeService {
25    const KEY: i16 = SaslHandshakeRequest::KEY;
26}
27
28impl<S> Service<S, SaslHandshakeRequest> for SaslHandshakeService
29where
30    S: Send + Sync + 'static,
31{
32    type Response = SaslHandshakeResponse;
33    type Error = Error;
34
35    #[instrument(skip(self, ctx), ret)]
36    async fn serve(
37        &self,
38        ctx: Context<S>,
39        req: SaslHandshakeRequest,
40    ) -> Result<Self::Response, Self::Error> {
41        if let Some(authentication) = ctx.get::<Authentication>().cloned() {
42            authentication.stage
43            .lock()
44            .map_err(Into::into)
45            .and_then(|mut guard| {
46                if let Some(Stage::Server(server)) = guard.take()
47                    && let Ok(mechanism) = Mechname::parse(req.mechanism.as_bytes())
48                {
49                    debug!(available = ?server.get_available().into_iter().map(|mechanism|mechanism.mechanism.as_str()).collect::<Vec<_>>());
50
51                    server
52                        .start_suggested(mechanism)
53                        .inspect_err(|err| debug!(?err, ?mechanism))
54                        .map_err(Into::into)
55                        .map(|session| {
56                            let mechanisms = [session.get_mechname().to_string()];
57
58                            _ = guard.replace(Stage::Session(session));
59
60                            SaslHandshakeResponse::default()
61                                .error_code(ErrorCode::None.into())
62                                .mechanisms(Some(mechanisms.into()))
63                        })
64                } else {
65                    Ok(SaslHandshakeResponse::default()
66                        .error_code(ErrorCode::UnsupportedSaslMechanism.into())
67                        .mechanisms(Some([req.mechanism].into())))
68                }
69            })
70        } else {
71            Ok(SaslHandshakeResponse::default()
72                .error_code(ErrorCode::UnsupportedSaslMechanism.into())
73                .mechanisms(Some([req.mechanism].into())))
74        }
75    }
76}