gel_auth/handshake/
server_auth.rs1use crate::{
2 md5::StoredHash,
3 scram::{SCRAMError, ServerTransaction, StoredKey},
4 AuthType, CredentialData,
5};
6use tracing::error;
7
8#[derive(Debug)]
9pub enum ServerAuthResponse {
10 Initial(AuthType, Vec<u8>),
11 Continue(Vec<u8>),
12 Complete(Vec<u8>),
13 Error(ServerAuthError),
14}
15
16#[derive(Debug, derive_more::Error, derive_more::Display, derive_more::From)]
17pub enum ServerAuthError {
18 #[display("Invalid authorization specification")]
19 InvalidAuthorizationSpecification,
20 #[display("Invalid password")]
21 InvalidPassword,
22 #[display("Invalid SASL message ({_0})")]
23 InvalidSaslMessage(#[from] SCRAMError),
24 #[display("Unsupported authentication type")]
25 UnsupportedAuthType,
26 #[display("Invalid message type")]
27 InvalidMessageType,
28}
29
30#[derive(Debug)]
31enum ServerAuthState {
32 Initial,
33 Password(CredentialData),
34 MD5([u8; 4], CredentialData),
35 Sasl(ServerTransaction, StoredKey),
36 Complete,
37}
38
39#[derive(Debug)]
40pub enum ServerAuthDrive<'a> {
41 Initial,
42 Message(AuthType, &'a [u8]),
43}
44
45#[derive(Debug)]
46pub struct ServerAuth {
47 state: ServerAuthState,
48 username: String,
49 auth_type: AuthType,
50 credential_data: CredentialData,
51}
52
53impl ServerAuth {
54 pub fn new(username: String, auth_type: AuthType, credential_data: CredentialData) -> Self {
55 Self {
56 state: ServerAuthState::Initial,
57 username,
58 auth_type,
59 credential_data,
60 }
61 }
62
63 pub fn is_complete(&self) -> bool {
64 matches!(self.state, ServerAuthState::Complete)
65 }
66
67 pub fn is_initial_message(&self) -> bool {
68 match &self.state {
69 ServerAuthState::Initial => false,
70 ServerAuthState::Sasl(tx, _) => tx.initial(),
71 _ => true,
72 }
73 }
74
75 pub fn auth_type(&self) -> AuthType {
76 self.auth_type
77 }
78
79 pub fn drive(&mut self, drive: ServerAuthDrive) -> ServerAuthResponse {
80 match (&mut self.state, drive) {
81 (ServerAuthState::Initial, ServerAuthDrive::Initial) => self.handle_initial(),
82 (ServerAuthState::Password(data), ServerAuthDrive::Message(AuthType::Plain, input)) => {
83 let client_password = input;
84 let success = match data {
85 CredentialData::Deny => false,
86 CredentialData::Trust => true,
87 CredentialData::Plain(password) => client_password == password.as_bytes(),
88 CredentialData::Md5(md5) => {
89 let md5_1 = StoredHash::generate(client_password, &self.username);
90 md5_1 == *md5
91 }
92 CredentialData::Scram(scram) => {
93 let key =
94 StoredKey::generate(client_password, &scram.salt, scram.iterations);
95 key.stored_key == scram.stored_key
96 }
97 };
98 self.state = ServerAuthState::Complete;
99 if success {
100 ServerAuthResponse::Complete(Vec::new())
101 } else {
102 ServerAuthResponse::Error(ServerAuthError::InvalidPassword)
103 }
104 }
105 (ServerAuthState::MD5(salt, data), ServerAuthDrive::Message(AuthType::Md5, input)) => {
106 let success = match data {
107 CredentialData::Deny => false,
108 CredentialData::Trust => true,
109 CredentialData::Plain(password) => {
110 let server_md5 = StoredHash::generate(password.as_bytes(), &self.username);
111 server_md5.matches(input, *salt)
112 }
113 CredentialData::Md5(server_md5) => server_md5.matches(input, *salt),
114 CredentialData::Scram(_) => {
115 false
117 }
118 };
119
120 self.state = ServerAuthState::Complete;
121 if success {
122 ServerAuthResponse::Complete(Vec::new())
123 } else {
124 ServerAuthResponse::Error(ServerAuthError::InvalidPassword)
125 }
126 }
127 (
128 ServerAuthState::Sasl(tx, data),
129 ServerAuthDrive::Message(AuthType::ScramSha256, input),
130 ) => {
131 let initial = tx.initial();
132 match tx.process_message(input, data) {
133 Ok(final_message) => {
134 if initial {
135 ServerAuthResponse::Continue(final_message)
136 } else {
137 self.state = ServerAuthState::Complete;
138 ServerAuthResponse::Complete(final_message)
139 }
140 }
141 Err(e) => {
142 self.state = ServerAuthState::Complete;
143 ServerAuthResponse::Error(ServerAuthError::InvalidSaslMessage(e))
144 }
145 }
146 }
147 (_, drive) => {
148 self.state = ServerAuthState::Complete;
149 error!("Received invalid drive {drive:?} in state {:?}", self.state);
150 ServerAuthResponse::Error(ServerAuthError::InvalidMessageType)
151 }
152 }
153 }
154
155 fn handle_initial(&mut self) -> ServerAuthResponse {
156 match self.auth_type {
157 AuthType::Deny => {
158 self.state = ServerAuthState::Complete;
159 ServerAuthResponse::Error(ServerAuthError::InvalidAuthorizationSpecification)
160 }
161 AuthType::Trust => {
162 self.state = ServerAuthState::Complete;
163 ServerAuthResponse::Complete(Vec::new())
164 }
165 AuthType::Plain => {
166 self.state = ServerAuthState::Password(self.credential_data.clone());
167 ServerAuthResponse::Initial(AuthType::Plain, Vec::new())
168 }
169 AuthType::Md5 => {
170 let salt: [u8; 4] = rand::random();
171 match self.credential_data {
172 CredentialData::Scram(..) => {
173 ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType)
174 }
175 _ => {
176 self.state = ServerAuthState::MD5(salt, self.credential_data.clone());
177 ServerAuthResponse::Initial(AuthType::Md5, salt.into())
178 }
179 }
180 }
181 AuthType::ScramSha256 => {
182 let salt: [u8; 32] = rand::random();
183 let scram = match &self.credential_data {
184 CredentialData::Scram(scram) => scram.clone(),
185 CredentialData::Plain(password) => {
186 StoredKey::generate(password.as_bytes(), &salt, 4096)
187 }
188 CredentialData::Deny => StoredKey::generate(b"", &salt, 4096),
189 _ => {
190 return ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType);
191 }
192 };
193 let tx = ServerTransaction::default();
194 self.state = ServerAuthState::Sasl(tx, scram);
195 ServerAuthResponse::Initial(AuthType::ScramSha256, Vec::new())
196 }
197 }
198 }
199}