entdb_server/server/
auth.rs1use async_trait::async_trait;
18use pgwire::api::auth::md5pass::hash_md5_password;
19use pgwire::api::auth::scram::gen_salted_password;
20use pgwire::api::auth::{AuthSource, LoginInfo, Password};
21use pgwire::error::PgWireResult;
22
23pub fn random_salt() -> [u8; 4] {
24 rand::random::<[u8; 4]>()
25}
26
27pub fn random_scram_salt() -> [u8; 16] {
28 rand::random::<[u8; 16]>()
29}
30
/// Password authentication mechanism the server uses for incoming clients.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMethod {
    /// MD5 challenge/response (textual name `md5`).
    Md5,
    /// SCRAM-SHA-256 (textual name `scram-sha-256`).
    ScramSha256,
}
36
37impl AuthMethod {
38 pub fn as_str(self) -> &'static str {
39 match self {
40 AuthMethod::Md5 => "md5",
41 AuthMethod::ScramSha256 => "scram-sha-256",
42 }
43 }
44}
45
46#[derive(Clone)]
47pub struct EntAuthSource {
48 pub method: AuthMethod,
49 pub expected_user: String,
50 pub expected_password: String,
51 pub scram_iterations: usize,
52}
53
54#[async_trait]
55impl AuthSource for EntAuthSource {
56 async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
57 let user = login_info.user().unwrap_or_default().to_string();
58 let password = if user == self.expected_user {
59 self.expected_password.as_str()
60 } else {
61 "invalid-user"
63 };
64
65 match self.method {
66 AuthMethod::Md5 => {
67 let salt = random_salt();
68 let md5 = hash_md5_password(&user, password, &salt);
69 Ok(Password::new(Some(salt.to_vec()), md5.into_bytes()))
70 }
71 AuthMethod::ScramSha256 => {
72 let salt = random_scram_salt();
73 let salted = gen_salted_password(password, &salt, self.scram_iterations);
74 Ok(Password::new(Some(salt.to_vec()), salted))
75 }
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::{AuthMethod, EntAuthSource};
    use pgwire::api::auth::md5pass::hash_md5_password;
    use pgwire::api::auth::scram::gen_salted_password;
    use pgwire::api::auth::{AuthSource, LoginInfo};

    /// Builds a source whose only valid credential is `entdb` / `entdb`.
    fn source(method: AuthMethod) -> EntAuthSource {
        EntAuthSource {
            method,
            expected_user: "entdb".to_string(),
            expected_password: "entdb".to_string(),
            scram_iterations: 4096,
        }
    }

    #[test]
    fn auth_method_names() {
        assert_eq!(AuthMethod::Md5.as_str(), "md5");
        assert_eq!(AuthMethod::ScramSha256.as_str(), "scram-sha-256");
    }

    #[tokio::test]
    async fn auth_source_returns_md5_password_bytes() {
        let login = LoginInfo::new(Some("entdb"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::Md5)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("md5 salt");
        assert_eq!(salt.len(), 4);
        let s = String::from_utf8(pass.password().to_vec()).expect("utf8");
        assert!(s.starts_with("md5"));
        // Pin the exact value, not just its shape.
        assert_eq!(s, hash_md5_password("entdb", "entdb", salt));
    }

    #[tokio::test]
    async fn auth_source_returns_scram_salted_password_bytes() {
        let login = LoginInfo::new(Some("entdb"), None, "127.0.0.1".to_string());
        let pass = source(AuthMethod::ScramSha256)
            .get_password(&login)
            .await
            .expect("password");
        let salt = pass.salt().expect("scram salt");
        assert_eq!(salt.len(), 16);
        assert_eq!(pass.password().len(), 32);
        // Pin the exact value, not just its length.
        assert_eq!(
            pass.password().to_vec(),
            gen_salted_password("entdb", salt, 4096)
        );
    }

    #[tokio::test]
    async fn unknown_user_never_gets_real_credential() {
        for method in [AuthMethod::Md5, AuthMethod::ScramSha256] {
            let login = LoginInfo::new(Some("intruder"), None, "127.0.0.1".to_string());
            let pass = source(method).get_password(&login).await.expect("password");
            let salt = pass.salt().expect("salt");
            // What the credential would be if the real password had leaked to this user.
            let real = match method {
                AuthMethod::Md5 => hash_md5_password("intruder", "entdb", salt).into_bytes(),
                AuthMethod::ScramSha256 => gen_salted_password("entdb", salt, 4096),
            };
            assert_ne!(pass.password().to_vec(), real);
        }
    }
}