1use crate::{
2 md5::md5_password,
3 scram::{
4 generate_nonce, generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError,
5 Sha256Out,
6 },
7 AuthType, CredentialData,
8};
9use tracing::error;
10
/// Result of advancing a [`ClientAuth`] state machine with one [`ClientAuthDrive`].
#[derive(Debug)]
pub enum ClientAuthResponse {
    /// The first payload produced for the selected mechanism, tagged with that mechanism
    /// (the password bytes for `Plain`, the salted hash for `Md5`, or the first SCRAM
    /// message for `ScramSha256`).
    Initial(AuthType, Vec<u8>),
    /// A further SCRAM payload produced while the SCRAM exchange is in progress.
    Continue(Vec<u8>),
    /// The exchange finished; the state machine is now complete.
    Complete,
    /// The client has nothing more to send and is waiting for a final `ClientAuthDrive::Ok`.
    Waiting,
    /// The stored credentials could not be used for the requested mechanism (or SCRAM
    /// failed to start). This is returned inside `Ok(..)` by `ClientAuth::drive`, and the
    /// state machine is then in its terminal state.
    Error(ClientAuthError),
}
19
/// Errors produced by the client-side authentication state machine.
#[derive(Debug, thiserror::Error)]
pub enum ClientAuthError {
    /// The SCRAM transaction rejected a message or failed to start.
    #[error("SCRAM protocol error: {0}")]
    ScramError(#[from] SCRAMError),
    /// A drive was received that is not valid in the current state.
    #[error("Invalid authentication state")]
    InvalidState,
    /// The stored credentials cannot be used with the requested mechanism.
    #[error("Invalid credentials")]
    InvalidCredentials,
    /// NOTE(review): not constructed by the code visible in this file section;
    /// presumably used by callers for out-of-sequence protocol messages — confirm.
    #[error("Unexpected message during authentication")]
    UnexpectedMessage,
}
31
/// Internal state of [`ClientAuth`].
#[derive(Debug)]
enum ClientAuthState {
    /// No drive processed yet: holds the username and credentials until the first drive
    /// selects a mechanism.
    Initial(String, CredentialData),
    /// Terminal state. Reached after a successful exchange, but also after any attempt
    /// rejected with `ClientAuthResponse::Error`.
    Complete,
    /// The client has sent everything it will send; the only valid next drive is `Ok`.
    Waiting,
    /// A SCRAM-SHA-256 exchange is in progress. Holds the transaction and the
    /// password-backed environment it uses.
    Sasl(ClientTransaction, ClientEnvironmentImpl),
}
39
/// Input that advances a [`ClientAuth`] state machine; one is passed per call to
/// `ClientAuth::drive`.
#[derive(Debug)]
pub enum ClientAuthDrive<'a> {
    /// Authentication accepted. As the first drive it selects `AuthType::Trust`; in the
    /// `Waiting` state it completes the exchange.
    Ok,
    /// Selects cleartext-password authentication; answered with the password bytes.
    Plain,
    /// Selects MD5-password authentication using the given 4-byte salt.
    Md5([u8; 4]),
    /// Selects SCRAM-SHA-256; answered with the client's first SCRAM message.
    Scram,
    /// A SCRAM message to feed into the in-progress SCRAM transaction.
    ScramResponse(&'a [u8]),
}
53
/// Client side of an authentication exchange, modelled as a state machine that is
/// advanced one step at a time with `ClientAuth::drive`.
#[derive(Debug)]
pub struct ClientAuth {
    // Current position in the exchange.
    state: ClientAuthState,
    // Mechanism chosen by the first drive; `None` until then.
    auth_type: Option<AuthType>,
}
59
60impl ClientAuth {
61 pub fn new(username: String, credentials: CredentialData) -> Self {
63 Self {
64 state: ClientAuthState::Initial(username, credentials),
65 auth_type: None,
66 }
67 }
68
69 pub fn is_complete(&self) -> bool {
70 matches!(self.state, ClientAuthState::Complete)
71 }
72
73 pub fn auth_type(&self) -> Option<AuthType> {
74 self.auth_type
75 }
76
77 pub fn drive(&mut self, drive: ClientAuthDrive) -> Result<ClientAuthResponse, ClientAuthError> {
78 match (&mut self.state, drive) {
79 (ClientAuthState::Initial(username, credentials), drive) => {
80 let username = std::mem::take(username);
81 let credentials = std::mem::replace(credentials, CredentialData::Deny);
82 self.handle_initial(username, credentials, drive)
83 }
84 (ClientAuthState::Sasl(tx, env), ClientAuthDrive::ScramResponse(message)) => {
86 let response = tx.process_message(&message, env)?;
87 match response {
88 Some(response) => Ok(ClientAuthResponse::Continue(response)),
89 None => {
90 self.state = ClientAuthState::Waiting;
91 Ok(ClientAuthResponse::Waiting)
92 }
93 }
94 }
95 (ClientAuthState::Waiting, ClientAuthDrive::Ok) => {
97 self.state = ClientAuthState::Complete;
98 Ok(ClientAuthResponse::Complete)
99 }
100 (_, drive) => {
102 error!("Received invalid drive {drive:?} in state {:?}", self.state);
103 Err(ClientAuthError::InvalidState)
104 }
105 }
106 }
107
108 fn handle_initial(
109 &mut self,
110 username: String,
111 credentials: CredentialData,
112 drive: ClientAuthDrive,
113 ) -> Result<ClientAuthResponse, ClientAuthError> {
114 let (auth_type, (state, response)) = match drive {
115 ClientAuthDrive::Ok => (
116 AuthType::Trust,
117 match credentials {
118 CredentialData::Deny => (
119 ClientAuthState::Complete,
120 ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
121 ),
122 _ => (ClientAuthState::Complete, ClientAuthResponse::Complete),
123 },
124 ),
125 ClientAuthDrive::Plain => (
126 AuthType::Plain,
127 match credentials {
128 CredentialData::Plain(credentials) => (
129 ClientAuthState::Waiting,
130 ClientAuthResponse::Initial(
131 AuthType::Plain,
132 credentials.clone().into_bytes(),
133 ),
134 ),
135 _ => (
136 ClientAuthState::Complete,
137 ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
138 ),
139 },
140 ),
141 ClientAuthDrive::Md5(salt) => (
142 AuthType::Md5,
143 match credentials {
144 CredentialData::Md5(credentials) => (
145 ClientAuthState::Waiting,
146 ClientAuthResponse::Initial(
147 AuthType::Md5,
148 credentials.salted(salt).into_bytes(),
149 ),
150 ),
151 CredentialData::Plain(credentials) => (
152 ClientAuthState::Waiting,
153 ClientAuthResponse::Initial(
154 AuthType::Md5,
155 md5_password(&credentials, &username, salt).into_bytes(),
156 ),
157 ),
158 _ => (
159 ClientAuthState::Complete,
160 ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
161 ),
162 },
163 ),
164 ClientAuthDrive::Scram => (
165 AuthType::ScramSha256,
166 match credentials {
167 CredentialData::Plain(credentials) => {
168 let env = ClientEnvironmentImpl {
169 password: credentials,
170 };
171 let mut tx = ClientTransaction::new(username.into());
172 let response = tx.process_message(&[], &env);
173 match response {
174 Ok(Some(response)) => (
175 ClientAuthState::Sasl(tx, env),
176 ClientAuthResponse::Initial(AuthType::ScramSha256, response),
177 ),
178 Ok(None) => (
179 ClientAuthState::Complete,
180 ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
181 ),
182 Err(e) => (
183 ClientAuthState::Complete,
184 ClientAuthResponse::Error(ClientAuthError::ScramError(e)),
185 ),
186 }
187 }
188 _ => (
189 ClientAuthState::Complete,
190 ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
191 ),
192 },
193 ),
194 _ => {
195 error!("Received invalid drive {drive:?} in state Initial");
196 return Err(ClientAuthError::InvalidState);
197 }
198 };
199
200 self.auth_type = Some(auth_type);
201 self.state = state;
202 Ok(response)
203 }
204}
205
/// SCRAM client environment backed by a plaintext password.
///
/// `Debug` is implemented by hand instead of derived so that the password never appears
/// in `{:?}` output. This type is embedded in `ClientAuthState`, whose `Debug` output may
/// be logged.
struct ClientEnvironmentImpl {
    password: String,
}

impl std::fmt::Debug for ClientEnvironmentImpl {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientEnvironmentImpl")
            .field("password", &"<redacted>")
            .finish()
    }
}
210
211impl ClientEnvironment for ClientEnvironmentImpl {
212 fn generate_nonce(&self) -> String {
213 generate_nonce()
214 }
215
216 fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
217 generate_salted_password(self.password.as_bytes(), salt, iterations)
218 }
219}