1use rsasl::{
16 callback::{Context, Request, SessionCallback, SessionData},
17 config::SASLConfig,
18 mechanisms::scram::properties::ScramStoredPassword,
19 prelude::{SASLError, SASLServer, Session, SessionError, Validation},
20 property::{AuthId, AuthzId, Password},
21 validate::{Validate, ValidationError},
22};
23use std::{
24 fmt::{self, Debug, Formatter},
25 str::FromStr,
26 sync::{Arc, Mutex, PoisonError},
27};
28use tansu_sans_io::ScramMechanism;
29use tansu_storage::Storage;
30use thiserror::Error;
31use tokio::task::JoinError;
32use tracing::{debug, instrument};
33
34mod authenticate;
35mod handshake;
36
37pub use authenticate::SaslAuthenticateService;
38pub use handshake::SaslHandshakeService;
39
40#[derive(Clone, Debug, Error)]
41pub enum Error {
42 Join(Arc<JoinError>),
43 Poison,
44 SansIo(#[from] tansu_sans_io::Error),
45 Sasl(Arc<SASLError>),
46 SaslSession(Arc<SessionError>),
47}
48
49impl fmt::Display for Error {
50 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51 write!(f, "{self:?}")
52 }
53}
54
55impl From<JoinError> for Error {
56 fn from(value: JoinError) -> Self {
57 Self::Join(Arc::new(value))
58 }
59}
60
61impl<T> From<PoisonError<T>> for Error {
62 fn from(_value: PoisonError<T>) -> Self {
63 Self::Poison
64 }
65}
66
67impl From<SASLError> for Error {
68 fn from(value: SASLError) -> Self {
69 Self::Sasl(Arc::new(value))
70 }
71}
72
73impl From<SessionError> for Error {
74 fn from(value: SessionError) -> Self {
75 Self::SaslSession(Arc::new(value))
76 }
77}
78
79#[derive(Clone, Default)]
80pub struct Authentication {
81 stage: Arc<Mutex<Option<Stage>>>,
82}
83
84pub enum Stage {
85 Server(SASLServer<Justification>),
86 Session(Session<Justification>),
87 Finished(Option<Success>),
88}
89
90impl Debug for Stage {
91 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
92 f.debug_struct(stringify!(Stage)).finish()
93 }
94}
95
96impl Authentication {
97 pub fn server(config: Arc<SASLConfig>) -> Self {
98 Self {
99 stage: Arc::new(Mutex::new(Some(Stage::Server(
100 SASLServer::<Justification>::new(config),
101 )))),
102 }
103 }
104
105 pub fn is_authenticated(&self) -> bool {
106 self.stage
107 .lock()
108 .map(|guard| matches!(guard.as_ref(), Some(Stage::Finished(_))))
109 .ok()
110 .unwrap_or_default()
111 }
112}
113
114impl Debug for Authentication {
115 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
116 f.debug_struct(stringify!(Authentication)).finish()
117 }
118}
119
120#[derive(Debug, Error)]
121pub enum AuthError {
122 Bad,
123 Io(tansu_sans_io::Error),
124 MissingProperty { mechanism: String, property: String },
125 NoSuchUser,
126 UnknownMechanism(String),
127}
128
129impl fmt::Display for AuthError {
130 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
131 write!(f, "{self:?}")
132 }
133}
134
/// A successful validation, carrying the authenticated identity.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Success {
    /// The authentication identity (`AuthId`) accepted for the exchange.
    auth_id: String,
}
139
/// Marker type naming the rsasl validation whose value is produced by the
/// session callback.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Justification;
142
143impl Validation for Justification {
144 type Value = Result<Success, AuthError>;
145}
146
/// rsasl session callback whose credentials come from the storage `S`.
#[derive(Debug, Clone)]
pub struct Callback<S> {
    /// Backend queried for SCRAM credentials.
    storage: S,
}
151
152impl<S> Callback<S>
153where
154 S: Storage,
155{
156 pub fn new(storage: S) -> Self
157 where
158 S: Storage,
159 {
160 Self { storage }
161 }
162
163 #[instrument(skip_all)]
164 fn check(
165 &self,
166 session_data: &SessionData,
167 context: &Context<'_>,
168 ) -> Result<Result<Success, AuthError>, Error> {
169 debug!(mechanism = %session_data.mechanism().mechanism);
170
171 if session_data.mechanism().mechanism == "PLAIN" {
172 Ok(context
173 .get_ref::<Password>()
174 .ok_or(AuthError::MissingProperty {
175 mechanism: session_data.mechanism().mechanism.to_string(),
176 property: "Password".into(),
177 })
178 .and(
179 context
180 .get_ref::<AuthId>()
181 .inspect(|auth_id| {
182 debug!(mechanism = %session_data.mechanism().mechanism, auth_id)
183 })
184 .ok_or(AuthError::MissingProperty {
185 mechanism: session_data.mechanism().mechanism.to_string(),
186 property: "AuthId".into(),
187 }).map(ToString::to_string).map(|auth_id| {
188 Success { auth_id }
189 })
190 ))
191 } else if session_data.mechanism().mechanism.starts_with("SCRAM-") {
192 Ok(context
193 .get_ref::<AuthId>()
194 .inspect(|auth_id| debug!(mechanism = %session_data.mechanism().mechanism, auth_id))
195 .ok_or(AuthError::MissingProperty {
196 mechanism: session_data.mechanism().mechanism.to_string(),
197 property: "AuthId".into(),
198 })
199 .and_then(|auth_id| {
200 context
201 .get_ref::<AuthzId>()
202 .inspect(|authz_id| {
203 debug!(mechanism = %session_data.mechanism().mechanism, authz_id)
204 })
205 .map_or(Ok(Success{
206 auth_id:auth_id.to_string()
207 }), |authz_id| {
208 if authz_id == auth_id {
209 Ok(Success{
210 auth_id:auth_id.to_string()
211 })
212 } else {
213 Err(AuthError::Bad)
214 }
215 })
216 }))
217 } else {
218 Ok(Err(AuthError::UnknownMechanism(
219 session_data.mechanism().mechanism.to_string(),
220 )))
221 }
222 }
223}
224
225impl<S> SessionCallback for Callback<S>
226where
227 S: Storage,
228{
229 #[instrument(skip_all)]
230 fn callback(
231 &self,
232 session_data: &SessionData,
233 context: &Context<'_>,
234 request: &mut Request<'_>,
235 ) -> Result<(), SessionError> {
236 debug!(?session_data);
237
238 if session_data.mechanism().mechanism.starts_with("SCRAM-") {
239 let mechanism = ScramMechanism::from_str(session_data.mechanism().mechanism)
240 .map_err(|error| SessionError::Boxed(Box::new(error)))?;
241
242 let auth_id = context
243 .get_ref::<AuthId>()
244 .ok_or(SessionError::ValidationError(
245 ValidationError::MissingRequiredProperty,
246 ))?;
247
248 debug!(?auth_id, ?mechanism);
249
250 let rt = tokio::runtime::Builder::new_current_thread()
251 .enable_all()
252 .build()?;
253
254 let storage = self.storage.clone();
255
256 if let Ok(Some(credential)) = rt
257 .block_on(async { storage.user_scram_credential(auth_id, mechanism).await })
258 .inspect_err(|err| debug!(auth_id, ?mechanism, ?err))
259 {
260 _ = request
261 .satisfy::<ScramStoredPassword<'_>>(&ScramStoredPassword::new(
262 credential.iterations as u32,
263 &credential.salt[..],
264 &credential.stored_key[..],
265 &credential.server_key[..],
266 ))
267 .inspect_err(|err| debug!(auth_id, ?mechanism, ?err))?;
268 }
269 }
270
271 Ok(())
272 }
273
274 #[instrument(skip_all)]
275 fn validate(
276 &self,
277 session_data: &SessionData,
278 context: &Context<'_>,
279 validate: &mut Validate<'_>,
280 ) -> Result<(), ValidationError> {
281 debug!(?session_data);
282
283 _ = validate.with::<Justification, _>(|| {
284 self.check(session_data, context)
285 .map_err(|e| ValidationError::Boxed(Box::new(e)))
286 })?;
287
288 Ok(())
289 }
290}
291
292pub fn configuration<S>(storage: S) -> Result<Arc<SASLConfig>, Error>
293where
294 S: Storage,
295{
296 SASLConfig::builder()
297 .with_defaults()
298 .with_callback(Callback::new(storage))
299 .map_err(Into::into)
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    fn is_send<T: Send>() {}
    fn is_sync<T: Sync>() {}

    #[test]
    fn authentication() {
        is_send::<Authentication>();
        is_sync::<Authentication>();
    }

    #[test]
    fn default_authentication_is_not_authenticated() {
        // A default value holds no stage at all, so it must not report success.
        assert!(!Authentication::default().is_authenticated());
    }

    #[test]
    fn finished_stage_is_authenticated_and_shared_by_clones() {
        let original = Authentication::default();
        let clone = original.clone();

        *original.stage.lock().unwrap() = Some(Stage::Finished(Some(Success::default())));

        // Clones share the same `Arc<Mutex<…>>`, so both observe the change.
        assert!(original.is_authenticated());
        assert!(clone.is_authenticated());
    }
}