1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use pgwire::api::auth::{AuthSource, LoginInfo, Password};
6use pgwire::error::{PgWireError, PgWireResult};
7use tokio::sync::RwLock;
8
9use datafusion_pg_catalog::pg_catalog::context::*;
10
11#[derive(Debug, Clone)]
13pub struct AuthManager {
14 users: Arc<RwLock<HashMap<String, User>>>,
15 roles: Arc<RwLock<HashMap<String, Role>>>,
16}
17
18impl Default for AuthManager {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl AuthManager {
25 pub fn new() -> Self {
26 let mut users = HashMap::new();
27 let postgres_user = User {
29 username: "postgres".to_string(),
30 password_hash: "".to_string(), roles: vec!["postgres".to_string()],
32 is_superuser: true,
33 can_login: true,
34 connection_limit: None,
35 };
36 users.insert(postgres_user.username.clone(), postgres_user);
37
38 let mut roles = HashMap::new();
39 let postgres_role = Role {
40 name: "postgres".to_string(),
41 is_superuser: true,
42 can_login: true,
43 can_create_db: true,
44 can_create_role: true,
45 can_create_user: true,
46 can_replication: true,
47 grants: vec![Grant {
48 permission: Permission::All,
49 resource: ResourceType::All,
50 granted_by: "system".to_string(),
51 with_grant_option: true,
52 }],
53 inherited_roles: vec![],
54 };
55 roles.insert(postgres_role.name.clone(), postgres_role);
56
57 AuthManager {
58 users: Arc::new(RwLock::new(users)),
59 roles: Arc::new(RwLock::new(roles)),
60 }
61 }
62
63 pub async fn add_user(&self, user: User) -> PgWireResult<()> {
65 let mut users = self.users.write().await;
66 users.insert(user.username.clone(), user);
67 Ok(())
68 }
69
70 pub async fn add_role(&self, role: Role) -> PgWireResult<()> {
72 let mut roles = self.roles.write().await;
73 roles.insert(role.name.clone(), role);
74 Ok(())
75 }
76
77 pub async fn authenticate(&self, username: &str, password: &str) -> PgWireResult<bool> {
79 let users = self.users.read().await;
80
81 if let Some(user) = users.get(username) {
82 if !user.can_login {
83 return Ok(false);
84 }
85
86 if user.password_hash.is_empty() || password == user.password_hash {
89 return Ok(true);
90 }
91 }
92
93 Ok(false)
96 }
97
98 pub async fn get_user(&self, username: &str) -> Option<User> {
100 let users = self.users.read().await;
101 users.get(username).cloned()
102 }
103
104 pub async fn get_role(&self, role_name: &str) -> Option<Role> {
106 let roles = self.roles.read().await;
107 roles.get(role_name).cloned()
108 }
109
110 pub async fn user_has_role(&self, username: &str, role_name: &str) -> bool {
112 if let Some(user) = self.get_user(username).await {
113 return user.roles.contains(&role_name.to_string()) || user.is_superuser;
114 }
115 false
116 }
117
118 pub async fn list_users(&self) -> Vec<String> {
120 let users = self.users.read().await;
121 users.keys().cloned().collect()
122 }
123
124 pub async fn list_roles(&self) -> Vec<String> {
126 let roles = self.roles.read().await;
127 roles.keys().cloned().collect()
128 }
129
130 pub async fn grant_permission(
132 &self,
133 role_name: &str,
134 permission: Permission,
135 resource: ResourceType,
136 granted_by: &str,
137 with_grant_option: bool,
138 ) -> PgWireResult<()> {
139 let mut roles = self.roles.write().await;
140
141 if let Some(role) = roles.get_mut(role_name) {
142 let grant = Grant {
143 permission,
144 resource,
145 granted_by: granted_by.to_string(),
146 with_grant_option,
147 };
148 role.grants.push(grant);
149 Ok(())
150 } else {
151 Err(PgWireError::UserError(Box::new(
152 pgwire::error::ErrorInfo::new(
153 "ERROR".to_string(),
154 "42704".to_string(), format!("role \"{role_name}\" does not exist"),
156 ),
157 )))
158 }
159 }
160
161 pub async fn revoke_permission(
163 &self,
164 role_name: &str,
165 permission: Permission,
166 resource: ResourceType,
167 ) -> PgWireResult<()> {
168 let mut roles = self.roles.write().await;
169
170 if let Some(role) = roles.get_mut(role_name) {
171 role.grants
172 .retain(|grant| !(grant.permission == permission && grant.resource == resource));
173 Ok(())
174 } else {
175 Err(PgWireError::UserError(Box::new(
176 pgwire::error::ErrorInfo::new(
177 "ERROR".to_string(),
178 "42704".to_string(), format!("role \"{role_name}\" does not exist"),
180 ),
181 )))
182 }
183 }
184
185 pub async fn check_permission(
187 &self,
188 username: &str,
189 permission: Permission,
190 resource: ResourceType,
191 ) -> bool {
192 if let Some(user) = self.get_user(username).await {
194 if user.is_superuser {
195 return true;
196 }
197
198 for role_name in &user.roles {
200 if let Some(role) = self.get_role(role_name).await {
201 if role.is_superuser {
203 return true;
204 }
205
206 for grant in &role.grants {
208 if self.permission_matches(&grant.permission, &permission)
209 && self.resource_matches(&grant.resource, &resource)
210 {
211 return true;
212 }
213 }
214
215 for inherited_role in &role.inherited_roles {
217 if self
218 .check_role_permission(inherited_role, &permission, &resource)
219 .await
220 {
221 return true;
222 }
223 }
224 }
225 }
226 }
227
228 false
229 }
230
231 fn check_role_permission<'a>(
233 &'a self,
234 role_name: &'a str,
235 permission: &'a Permission,
236 resource: &'a ResourceType,
237 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'a>> {
238 Box::pin(async move {
239 if let Some(role) = self.get_role(role_name).await {
240 if role.is_superuser {
241 return true;
242 }
243
244 for grant in &role.grants {
246 if self.permission_matches(&grant.permission, permission)
247 && self.resource_matches(&grant.resource, resource)
248 {
249 return true;
250 }
251 }
252
253 for inherited_role in &role.inherited_roles {
255 if self
256 .check_role_permission(inherited_role, permission, resource)
257 .await
258 {
259 return true;
260 }
261 }
262 }
263
264 false
265 })
266 }
267
268 fn permission_matches(&self, grant_permission: &Permission, requested: &Permission) -> bool {
270 grant_permission == requested || matches!(grant_permission, Permission::All)
271 }
272
273 fn resource_matches(&self, grant_resource: &ResourceType, requested: &ResourceType) -> bool {
275 match (grant_resource, requested) {
276 (a, b) if a == b => true,
278 (ResourceType::All, _) => true,
280 (ResourceType::Schema(schema), ResourceType::Table(table)) => {
282 table.starts_with(&format!("{schema}."))
284 }
285 _ => false,
286 }
287 }
288
289 pub async fn add_role_inheritance(
291 &self,
292 child_role: &str,
293 parent_role: &str,
294 ) -> PgWireResult<()> {
295 let mut roles = self.roles.write().await;
296
297 if let Some(child) = roles.get_mut(child_role) {
298 if !child.inherited_roles.contains(&parent_role.to_string()) {
299 child.inherited_roles.push(parent_role.to_string());
300 }
301 Ok(())
302 } else {
303 Err(PgWireError::UserError(Box::new(
304 pgwire::error::ErrorInfo::new(
305 "ERROR".to_string(),
306 "42704".to_string(), format!("role \"{child_role}\" does not exist"),
308 ),
309 )))
310 }
311 }
312
313 pub async fn remove_role_inheritance(
315 &self,
316 child_role: &str,
317 parent_role: &str,
318 ) -> PgWireResult<()> {
319 let mut roles = self.roles.write().await;
320
321 if let Some(child) = roles.get_mut(child_role) {
322 child.inherited_roles.retain(|role| role != parent_role);
323 Ok(())
324 } else {
325 Err(PgWireError::UserError(Box::new(
326 pgwire::error::ErrorInfo::new(
327 "ERROR".to_string(),
328 "42704".to_string(), format!("role \"{child_role}\" does not exist"),
330 ),
331 )))
332 }
333 }
334
335 pub async fn create_role(&self, config: RoleConfig) -> PgWireResult<()> {
337 let role = Role {
338 name: config.name.clone(),
339 is_superuser: config.is_superuser,
340 can_login: config.can_login,
341 can_create_db: config.can_create_db,
342 can_create_role: config.can_create_role,
343 can_create_user: config.can_create_user,
344 can_replication: config.can_replication,
345 grants: vec![],
346 inherited_roles: vec![],
347 };
348
349 self.add_role(role).await
350 }
351
352 pub async fn create_predefined_roles(&self) -> PgWireResult<()> {
354 self.create_role(RoleConfig {
356 name: "readonly".to_string(),
357 is_superuser: false,
358 can_login: false,
359 can_create_db: false,
360 can_create_role: false,
361 can_create_user: false,
362 can_replication: false,
363 })
364 .await?;
365
366 self.grant_permission(
367 "readonly",
368 Permission::Select,
369 ResourceType::All,
370 "system",
371 false,
372 )
373 .await?;
374
375 self.create_role(RoleConfig {
377 name: "readwrite".to_string(),
378 is_superuser: false,
379 can_login: false,
380 can_create_db: false,
381 can_create_role: false,
382 can_create_user: false,
383 can_replication: false,
384 })
385 .await?;
386
387 self.grant_permission(
388 "readwrite",
389 Permission::Select,
390 ResourceType::All,
391 "system",
392 false,
393 )
394 .await?;
395
396 self.grant_permission(
397 "readwrite",
398 Permission::Insert,
399 ResourceType::All,
400 "system",
401 false,
402 )
403 .await?;
404
405 self.grant_permission(
406 "readwrite",
407 Permission::Update,
408 ResourceType::All,
409 "system",
410 false,
411 )
412 .await?;
413
414 self.grant_permission(
415 "readwrite",
416 Permission::Delete,
417 ResourceType::All,
418 "system",
419 false,
420 )
421 .await?;
422
423 self.create_role(RoleConfig {
425 name: "dbadmin".to_string(),
426 is_superuser: false,
427 can_login: true,
428 can_create_db: true,
429 can_create_role: false,
430 can_create_user: false,
431 can_replication: false,
432 })
433 .await?;
434
435 self.grant_permission(
436 "dbadmin",
437 Permission::All,
438 ResourceType::All,
439 "system",
440 true,
441 )
442 .await?;
443
444 Ok(())
445 }
446}
447
448#[async_trait]
449impl PgCatalogContextProvider for AuthManager {
450 async fn roles(&self) -> Vec<String> {
452 self.list_roles().await
453 }
454
455 async fn role(&self, name: &str) -> Option<Role> {
457 self.get_role(name).await
458 }
459}
460
461#[derive(Clone)]
464pub struct DfAuthSource {
465 pub auth_manager: Arc<AuthManager>,
466}
467
468impl DfAuthSource {
469 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
470 DfAuthSource { auth_manager }
471 }
472}
473
474#[async_trait]
475impl AuthSource for DfAuthSource {
476 async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
477 if let Some(username) = login.user() {
478 if let Some(user) = self.auth_manager.get_user(username).await {
480 if user.can_login {
481 Ok(Password::new(None, user.password_hash.into_bytes()))
485 } else {
486 Err(PgWireError::UserError(Box::new(
487 pgwire::error::ErrorInfo::new(
488 "FATAL".to_string(),
489 "28000".to_string(), format!("User \"{username}\" is not allowed to login"),
491 ),
492 )))
493 }
494 } else {
495 Err(PgWireError::UserError(Box::new(
496 pgwire::error::ErrorInfo::new(
497 "FATAL".to_string(),
498 "28P01".to_string(), format!("password authentication failed for user \"{username}\""),
500 ),
501 )))
502 }
503 } else {
504 Err(PgWireError::UserError(Box::new(
505 pgwire::error::ErrorInfo::new(
506 "FATAL".to_string(),
507 "28P01".to_string(), "No username provided in login request".to_string(),
509 ),
510 )))
511 }
512 }
513}
514
515pub struct SimpleAuthSource {
554 auth_manager: Arc<AuthManager>,
555}
556
557impl SimpleAuthSource {
558 pub fn new(auth_manager: Arc<AuthManager>) -> Self {
559 SimpleAuthSource { auth_manager }
560 }
561}
562
563#[async_trait]
564impl AuthSource for SimpleAuthSource {
565 async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
566 let username = login.user().unwrap_or("anonymous");
567
568 if let Some(user) = self.auth_manager.get_user(username).await {
570 if user.can_login {
571 return Ok(Password::new(None, vec![]));
573 }
574 }
575
576 if username == "postgres" {
578 return Ok(Password::new(None, vec![]));
579 }
580
581 Err(PgWireError::UserError(Box::new(
583 pgwire::error::ErrorInfo::new(
584 "FATAL".to_string(),
585 "28P01".to_string(), format!("password authentication failed for user \"{username}\""),
587 ),
588 )))
589 }
590}
591
592pub fn create_auth_source(auth_manager: Arc<AuthManager>) -> SimpleAuthSource {
594 SimpleAuthSource::new(auth_manager)
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    // `AuthManager::new` fills its maps synchronously, so the tests query it
    // immediately; no waiting for background setup is needed.

    #[tokio::test]
    async fn test_auth_manager_creation() {
        let auth_manager = AuthManager::new();

        let users = auth_manager.list_users().await;
        assert!(users.contains(&"postgres".to_string()));
    }

    #[tokio::test]
    async fn test_user_authentication() {
        let auth_manager = AuthManager::new();

        // The seeded `postgres` user has an empty password hash.
        assert!(auth_manager.authenticate("postgres", "").await.unwrap());
        assert!(!auth_manager
            .authenticate("nonexistent", "password")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_role_management() {
        let auth_manager = AuthManager::new();

        assert!(auth_manager.user_has_role("postgres", "postgres").await);
        // Superusers are reported as holding every role.
        assert!(auth_manager.user_has_role("postgres", "any_role").await);
    }

    #[tokio::test]
    async fn test_password_and_permissions_for_readonly_user() {
        let auth_manager = AuthManager::new();
        auth_manager.create_predefined_roles().await.unwrap();
        auth_manager
            .add_user(User {
                username: "reader".to_string(),
                password_hash: "secret".to_string(),
                roles: vec!["readonly".to_string()],
                is_superuser: false,
                can_login: true,
                connection_limit: None,
            })
            .await
            .unwrap();

        assert!(auth_manager.authenticate("reader", "secret").await.unwrap());
        assert!(!auth_manager.authenticate("reader", "wrong").await.unwrap());

        assert!(
            auth_manager
                .check_permission("reader", Permission::Select, ResourceType::All)
                .await
        );
        assert!(
            !auth_manager
                .check_permission("reader", Permission::Insert, ResourceType::All)
                .await
        );
        assert!(
            !auth_manager
                .check_permission("nobody", Permission::Select, ResourceType::All)
                .await
        );
    }

    #[tokio::test]
    async fn test_grant_on_missing_role_is_error() {
        let auth_manager = AuthManager::new();

        assert!(auth_manager
            .grant_permission(
                "missing",
                Permission::Select,
                ResourceType::All,
                "system",
                false,
            )
            .await
            .is_err());
    }
}