1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use log::warn;
6use pgwire::api::auth::{AuthSource, LoginInfo, Password};
7use pgwire::error::{PgWireError, PgWireResult};
8use tokio::sync::RwLock;
9
10#[derive(Debug, Clone)]
12pub struct User {
13 pub username: String,
14 pub password_hash: String,
15 pub roles: Vec<String>,
16 pub is_superuser: bool,
17 pub can_login: bool,
18 pub connection_limit: Option<i32>,
19}
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub enum Permission {
24 Select,
25 Insert,
26 Update,
27 Delete,
28 Create,
29 Drop,
30 Alter,
31 Index,
32 References,
33 Trigger,
34 Execute,
35 Usage,
36 Connect,
37 Temporary,
38 All,
39}
40
41impl Permission {
42 pub fn from_string(s: &str) -> Option<Permission> {
43 match s.to_uppercase().as_str() {
44 "SELECT" => Some(Permission::Select),
45 "INSERT" => Some(Permission::Insert),
46 "UPDATE" => Some(Permission::Update),
47 "DELETE" => Some(Permission::Delete),
48 "CREATE" => Some(Permission::Create),
49 "DROP" => Some(Permission::Drop),
50 "ALTER" => Some(Permission::Alter),
51 "INDEX" => Some(Permission::Index),
52 "REFERENCES" => Some(Permission::References),
53 "TRIGGER" => Some(Permission::Trigger),
54 "EXECUTE" => Some(Permission::Execute),
55 "USAGE" => Some(Permission::Usage),
56 "CONNECT" => Some(Permission::Connect),
57 "TEMPORARY" => Some(Permission::Temporary),
58 "ALL" => Some(Permission::All),
59 _ => None,
60 }
61 }
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
66pub enum ResourceType {
67 Table(String),
68 Schema(String),
69 Database(String),
70 Function(String),
71 Sequence(String),
72 All,
73}
74
75#[derive(Debug, Clone)]
77pub struct Grant {
78 pub permission: Permission,
79 pub resource: ResourceType,
80 pub granted_by: String,
81 pub with_grant_option: bool,
82}
83
84#[derive(Debug, Clone)]
86pub struct Role {
87 pub name: String,
88 pub is_superuser: bool,
89 pub can_login: bool,
90 pub can_create_db: bool,
91 pub can_create_role: bool,
92 pub can_create_user: bool,
93 pub can_replication: bool,
94 pub grants: Vec<Grant>,
95 pub inherited_roles: Vec<String>,
96}
97
98#[derive(Debug, Clone)]
100pub struct RoleConfig {
101 pub name: String,
102 pub is_superuser: bool,
103 pub can_login: bool,
104 pub can_create_db: bool,
105 pub can_create_role: bool,
106 pub can_create_user: bool,
107 pub can_replication: bool,
108}
109
110#[derive(Debug)]
112pub struct AuthManager {
113 users: Arc<RwLock<HashMap<String, User>>>,
114 roles: Arc<RwLock<HashMap<String, Role>>>,
115}
116
117impl Default for AuthManager {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl AuthManager {
124 pub fn new() -> Self {
125 let auth_manager = AuthManager {
126 users: Arc::new(RwLock::new(HashMap::new())),
127 roles: Arc::new(RwLock::new(HashMap::new())),
128 };
129
130 let postgres_user = User {
132 username: "postgres".to_string(),
133 password_hash: "".to_string(), roles: vec!["postgres".to_string()],
135 is_superuser: true,
136 can_login: true,
137 connection_limit: None,
138 };
139
140 let postgres_role = Role {
141 name: "postgres".to_string(),
142 is_superuser: true,
143 can_login: true,
144 can_create_db: true,
145 can_create_role: true,
146 can_create_user: true,
147 can_replication: true,
148 grants: vec![Grant {
149 permission: Permission::All,
150 resource: ResourceType::All,
151 granted_by: "system".to_string(),
152 with_grant_option: true,
153 }],
154 inherited_roles: vec![],
155 };
156
157 let auth_manager_clone = AuthManager {
159 users: auth_manager.users.clone(),
160 roles: auth_manager.roles.clone(),
161 };
162
163 tokio::spawn({
164 let users = auth_manager.users.clone();
165 let roles = auth_manager.roles.clone();
166 let auth_manager_spawn = auth_manager_clone;
167 async move {
168 users
169 .write()
170 .await
171 .insert("postgres".to_string(), postgres_user);
172 roles
173 .write()
174 .await
175 .insert("postgres".to_string(), postgres_role);
176
177 if let Err(e) = auth_manager_spawn.create_predefined_roles().await {
179 warn!("Failed to create predefined roles: {e:?}");
180 }
181 }
182 });
183
184 auth_manager
185 }
186
187 pub async fn add_user(&self, user: User) -> PgWireResult<()> {
189 let mut users = self.users.write().await;
190 users.insert(user.username.clone(), user);
191 Ok(())
192 }
193
194 pub async fn add_role(&self, role: Role) -> PgWireResult<()> {
196 let mut roles = self.roles.write().await;
197 roles.insert(role.name.clone(), role);
198 Ok(())
199 }
200
201 pub async fn authenticate(&self, username: &str, password: &str) -> PgWireResult<bool> {
203 let users = self.users.read().await;
204
205 if let Some(user) = users.get(username) {
206 if !user.can_login {
207 return Ok(false);
208 }
209
210 if user.password_hash.is_empty() || password == user.password_hash {
213 return Ok(true);
214 }
215 }
216
217 Ok(false)
220 }
221
222 pub async fn get_user(&self, username: &str) -> Option<User> {
224 let users = self.users.read().await;
225 users.get(username).cloned()
226 }
227
228 pub async fn get_role(&self, role_name: &str) -> Option<Role> {
230 let roles = self.roles.read().await;
231 roles.get(role_name).cloned()
232 }
233
234 pub async fn user_has_role(&self, username: &str, role_name: &str) -> bool {
236 if let Some(user) = self.get_user(username).await {
237 return user.roles.contains(&role_name.to_string()) || user.is_superuser;
238 }
239 false
240 }
241
242 pub async fn list_users(&self) -> Vec<String> {
244 let users = self.users.read().await;
245 users.keys().cloned().collect()
246 }
247
248 pub async fn list_roles(&self) -> Vec<String> {
250 let roles = self.roles.read().await;
251 roles.keys().cloned().collect()
252 }
253
254 pub async fn grant_permission(
256 &self,
257 role_name: &str,
258 permission: Permission,
259 resource: ResourceType,
260 granted_by: &str,
261 with_grant_option: bool,
262 ) -> PgWireResult<()> {
263 let mut roles = self.roles.write().await;
264
265 if let Some(role) = roles.get_mut(role_name) {
266 let grant = Grant {
267 permission,
268 resource,
269 granted_by: granted_by.to_string(),
270 with_grant_option,
271 };
272 role.grants.push(grant);
273 Ok(())
274 } else {
275 Err(PgWireError::UserError(Box::new(
276 pgwire::error::ErrorInfo::new(
277 "ERROR".to_string(),
278 "42704".to_string(), format!("role \"{role_name}\" does not exist"),
280 ),
281 )))
282 }
283 }
284
285 pub async fn revoke_permission(
287 &self,
288 role_name: &str,
289 permission: Permission,
290 resource: ResourceType,
291 ) -> PgWireResult<()> {
292 let mut roles = self.roles.write().await;
293
294 if let Some(role) = roles.get_mut(role_name) {
295 role.grants
296 .retain(|grant| !(grant.permission == permission && grant.resource == resource));
297 Ok(())
298 } else {
299 Err(PgWireError::UserError(Box::new(
300 pgwire::error::ErrorInfo::new(
301 "ERROR".to_string(),
302 "42704".to_string(), format!("role \"{role_name}\" does not exist"),
304 ),
305 )))
306 }
307 }
308
309 pub async fn check_permission(
311 &self,
312 username: &str,
313 permission: Permission,
314 resource: ResourceType,
315 ) -> bool {
316 if let Some(user) = self.get_user(username).await {
318 if user.is_superuser {
319 return true;
320 }
321
322 for role_name in &user.roles {
324 if let Some(role) = self.get_role(role_name).await {
325 if role.is_superuser {
327 return true;
328 }
329
330 for grant in &role.grants {
332 if self.permission_matches(&grant.permission, &permission)
333 && self.resource_matches(&grant.resource, &resource)
334 {
335 return true;
336 }
337 }
338
339 for inherited_role in &role.inherited_roles {
341 if self
342 .check_role_permission(inherited_role, &permission, &resource)
343 .await
344 {
345 return true;
346 }
347 }
348 }
349 }
350 }
351
352 false
353 }
354
355 fn check_role_permission<'a>(
357 &'a self,
358 role_name: &'a str,
359 permission: &'a Permission,
360 resource: &'a ResourceType,
361 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'a>> {
362 Box::pin(async move {
363 if let Some(role) = self.get_role(role_name).await {
364 if role.is_superuser {
365 return true;
366 }
367
368 for grant in &role.grants {
370 if self.permission_matches(&grant.permission, permission)
371 && self.resource_matches(&grant.resource, resource)
372 {
373 return true;
374 }
375 }
376
377 for inherited_role in &role.inherited_roles {
379 if self
380 .check_role_permission(inherited_role, permission, resource)
381 .await
382 {
383 return true;
384 }
385 }
386 }
387
388 false
389 })
390 }
391
392 fn permission_matches(&self, grant_permission: &Permission, requested: &Permission) -> bool {
394 grant_permission == requested || matches!(grant_permission, Permission::All)
395 }
396
397 fn resource_matches(&self, grant_resource: &ResourceType, requested: &ResourceType) -> bool {
399 match (grant_resource, requested) {
400 (a, b) if a == b => true,
402 (ResourceType::All, _) => true,
404 (ResourceType::Schema(schema), ResourceType::Table(table)) => {
406 table.starts_with(&format!("{schema}."))
408 }
409 _ => false,
410 }
411 }
412
413 pub async fn add_role_inheritance(
415 &self,
416 child_role: &str,
417 parent_role: &str,
418 ) -> PgWireResult<()> {
419 let mut roles = self.roles.write().await;
420
421 if let Some(child) = roles.get_mut(child_role) {
422 if !child.inherited_roles.contains(&parent_role.to_string()) {
423 child.inherited_roles.push(parent_role.to_string());
424 }
425 Ok(())
426 } else {
427 Err(PgWireError::UserError(Box::new(
428 pgwire::error::ErrorInfo::new(
429 "ERROR".to_string(),
430 "42704".to_string(), format!("role \"{child_role}\" does not exist"),
432 ),
433 )))
434 }
435 }
436
437 pub async fn remove_role_inheritance(
439 &self,
440 child_role: &str,
441 parent_role: &str,
442 ) -> PgWireResult<()> {
443 let mut roles = self.roles.write().await;
444
445 if let Some(child) = roles.get_mut(child_role) {
446 child.inherited_roles.retain(|role| role != parent_role);
447 Ok(())
448 } else {
449 Err(PgWireError::UserError(Box::new(
450 pgwire::error::ErrorInfo::new(
451 "ERROR".to_string(),
452 "42704".to_string(), format!("role \"{child_role}\" does not exist"),
454 ),
455 )))
456 }
457 }
458
459 pub async fn create_role(&self, config: RoleConfig) -> PgWireResult<()> {
461 let role = Role {
462 name: config.name.clone(),
463 is_superuser: config.is_superuser,
464 can_login: config.can_login,
465 can_create_db: config.can_create_db,
466 can_create_role: config.can_create_role,
467 can_create_user: config.can_create_user,
468 can_replication: config.can_replication,
469 grants: vec![],
470 inherited_roles: vec![],
471 };
472
473 self.add_role(role).await
474 }
475
476 pub async fn create_predefined_roles(&self) -> PgWireResult<()> {
478 self.create_role(RoleConfig {
480 name: "readonly".to_string(),
481 is_superuser: false,
482 can_login: false,
483 can_create_db: false,
484 can_create_role: false,
485 can_create_user: false,
486 can_replication: false,
487 })
488 .await?;
489
490 self.grant_permission(
491 "readonly",
492 Permission::Select,
493 ResourceType::All,
494 "system",
495 false,
496 )
497 .await?;
498
499 self.create_role(RoleConfig {
501 name: "readwrite".to_string(),
502 is_superuser: false,
503 can_login: false,
504 can_create_db: false,
505 can_create_role: false,
506 can_create_user: false,
507 can_replication: false,
508 })
509 .await?;
510
511 self.grant_permission(
512 "readwrite",
513 Permission::Select,
514 ResourceType::All,
515 "system",
516 false,
517 )
518 .await?;
519
520 self.grant_permission(
521 "readwrite",
522 Permission::Insert,
523 ResourceType::All,
524 "system",
525 false,
526 )
527 .await?;
528
529 self.grant_permission(
530 "readwrite",
531 Permission::Update,
532 ResourceType::All,
533 "system",
534 false,
535 )
536 .await?;
537
538 self.grant_permission(
539 "readwrite",
540 Permission::Delete,
541 ResourceType::All,
542 "system",
543 false,
544 )
545 .await?;
546
547 self.create_role(RoleConfig {
549 name: "dbadmin".to_string(),
550 is_superuser: false,
551 can_login: true,
552 can_create_db: true,
553 can_create_role: false,
554 can_create_user: false,
555 can_replication: false,
556 })
557 .await?;
558
559 self.grant_permission(
560 "dbadmin",
561 Permission::All,
562 ResourceType::All,
563 "system",
564 true,
565 )
566 .await?;
567
568 Ok(())
569 }
570}
571
572#[derive(Clone)]
575pub struct DfAuthSource {
576 pub auth_manager: Arc<AuthManager>,
577}
578
579impl DfAuthSource {
580 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
581 DfAuthSource { auth_manager }
582 }
583}
584
585#[async_trait]
586impl AuthSource for DfAuthSource {
587 async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
588 if let Some(username) = login.user() {
589 if let Some(user) = self.auth_manager.get_user(username).await {
591 if user.can_login {
592 Ok(Password::new(None, user.password_hash.into_bytes()))
596 } else {
597 Err(PgWireError::UserError(Box::new(
598 pgwire::error::ErrorInfo::new(
599 "FATAL".to_string(),
600 "28000".to_string(), format!("User \"{username}\" is not allowed to login"),
602 ),
603 )))
604 }
605 } else {
606 Err(PgWireError::UserError(Box::new(
607 pgwire::error::ErrorInfo::new(
608 "FATAL".to_string(),
609 "28P01".to_string(), format!("password authentication failed for user \"{username}\""),
611 ),
612 )))
613 }
614 } else {
615 Err(PgWireError::UserError(Box::new(
616 pgwire::error::ErrorInfo::new(
617 "FATAL".to_string(),
618 "28P01".to_string(), "No username provided in login request".to_string(),
620 ),
621 )))
622 }
623 }
624}
625
626pub struct SimpleAuthSource {
665 auth_manager: Arc<AuthManager>,
666}
667
668impl SimpleAuthSource {
669 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
670 SimpleAuthSource { auth_manager }
671 }
672}
673
674#[async_trait]
675impl AuthSource for SimpleAuthSource {
676 async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
677 let username = login.user().unwrap_or("anonymous");
678
679 if let Some(user) = self.auth_manager.get_user(username).await {
681 if user.can_login {
682 return Ok(Password::new(None, vec![]));
684 }
685 }
686
687 if username == "postgres" {
689 return Ok(Password::new(None, vec![]));
690 }
691
692 Err(PgWireError::UserError(Box::new(
694 pgwire::error::ErrorInfo::new(
695 "FATAL".to_string(),
696 "28P01".to_string(), format!("password authentication failed for user \"{username}\""),
698 ),
699 )))
700 }
701}
702
703pub fn create_auth_source(auth_manager: Arc<AuthManager>) -> SimpleAuthSource {
705 SimpleAuthSource::new(auth_manager)
706}
707
708#[cfg(test)]
709mod tests {
710 use super::*;
711
712 #[tokio::test]
713 async fn test_auth_manager_creation() {
714 let auth_manager = AuthManager::new();
715
716 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
718
719 let users = auth_manager.list_users().await;
720 assert!(users.contains(&"postgres".to_string()));
721 }
722
723 #[tokio::test]
724 async fn test_user_authentication() {
725 let auth_manager = AuthManager::new();
726
727 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
729
730 assert!(auth_manager.authenticate("postgres", "").await.unwrap());
732 assert!(!auth_manager
733 .authenticate("nonexistent", "password")
734 .await
735 .unwrap());
736 }
737
738 #[tokio::test]
739 async fn test_role_management() {
740 let auth_manager = AuthManager::new();
741
742 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
744
745 assert!(auth_manager.user_has_role("postgres", "postgres").await);
747 assert!(auth_manager.user_has_role("postgres", "any_role").await); }
749}