Skip to main content

entdb_server/server/
auth.rs

1/*
2 * Copyright 2026 EntDB Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use async_trait::async_trait;
18use pgwire::api::auth::md5pass::hash_md5_password;
19use pgwire::api::auth::scram::gen_salted_password;
20use pgwire::api::auth::{AuthSource, LoginInfo, Password};
21use pgwire::error::PgWireResult;
22
23pub fn random_salt() -> [u8; 4] {
24    rand::random::<[u8; 4]>()
25}
26
27pub fn random_scram_salt() -> [u8; 16] {
28    rand::random::<[u8; 16]>()
29}
30
/// Password scheme used by [`EntAuthSource`] when producing the stored credential.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// Salted MD5 digest of the username/password pair.
    Md5,
    /// SCRAM-SHA-256 salted password.
    ScramSha256,
}

impl AuthMethod {
    /// Returns the method's lowercase name: `"md5"` or `"scram-sha-256"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::ScramSha256 => "scram-sha-256",
            Self::Md5 => "md5",
        }
    }
}
45
46#[derive(Clone)]
47pub struct EntAuthSource {
48    pub method: AuthMethod,
49    pub expected_user: String,
50    pub expected_password: String,
51    pub scram_iterations: usize,
52}
53
54#[async_trait]
55impl AuthSource for EntAuthSource {
56    async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
57        let user = login_info.user().unwrap_or_default().to_string();
58        let password = if user == self.expected_user {
59            self.expected_password.as_str()
60        } else {
61            // Keep auth-time shape uniform for unknown users.
62            "invalid-user"
63        };
64
65        match self.method {
66            AuthMethod::Md5 => {
67                let salt = random_salt();
68                let md5 = hash_md5_password(&user, password, &salt);
69                Ok(Password::new(Some(salt.to_vec()), md5.into_bytes()))
70            }
71            AuthMethod::ScramSha256 => {
72                let salt = random_scram_salt();
73                let salted = gen_salted_password(password, &salt, self.scram_iterations);
74                Ok(Password::new(Some(salt.to_vec()), salted))
75            }
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::{AuthMethod, EntAuthSource};
    use pgwire::api::auth::md5pass::hash_md5_password;
    use pgwire::api::auth::scram::gen_salted_password;
    use pgwire::api::auth::{AuthSource, LoginInfo};

    /// Builds a source configured for `entdb`/`entdb`, hashing with `method`.
    fn source(method: AuthMethod) -> EntAuthSource {
        EntAuthSource {
            method,
            expected_user: "entdb".to_string(),
            expected_password: "entdb".to_string(),
            scram_iterations: 4096,
        }
    }

    #[tokio::test]
    async fn auth_source_returns_md5_password_bytes() {
        let login = LoginInfo::new(Some("entdb"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::Md5)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("md5 salt");
        assert_eq!(salt.len(), 4);
        let s = String::from_utf8(pass.password().to_vec()).expect("utf8");
        assert!(s.starts_with("md5"));
        // The digest must come from the configured password and the returned salt.
        assert_eq!(s, hash_md5_password("entdb", "entdb", salt));
    }

    #[tokio::test]
    async fn auth_source_returns_scram_salted_password_bytes() {
        let login = LoginInfo::new(Some("entdb"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::ScramSha256)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("scram salt");
        assert_eq!(salt.len(), 16);
        // SHA-256 salted password bytes
        assert_eq!(pass.password().len(), 32);
        assert_eq!(
            pass.password(),
            gen_salted_password("entdb", salt, 4096).as_slice()
        );
    }

    #[tokio::test]
    async fn unknown_user_gets_uniformly_shaped_md5_credential() {
        let login = LoginInfo::new(Some("ghost"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::Md5)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("md5 salt");
        assert_eq!(salt.len(), 4);
        let s = String::from_utf8(pass.password().to_vec()).expect("utf8");
        assert!(s.starts_with("md5"));
        // "md5" prefix + 32 hex digits, same as for the known user.
        assert_eq!(s.len(), 35);
        // The configured password must not be usable under another username.
        assert_ne!(s, hash_md5_password("ghost", "entdb", salt));
    }

    #[tokio::test]
    async fn unknown_user_gets_uniformly_shaped_scram_credential() {
        let login = LoginInfo::new(Some("ghost"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::ScramSha256)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("scram salt");
        assert_eq!(salt.len(), 16);
        assert_eq!(pass.password().len(), 32);
        assert_ne!(
            pass.password(),
            gen_salted_password("entdb", salt, 4096).as_slice()
        );
    }
}