datafusion_postgres/
auth.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use log::warn;
6use pgwire::api::auth::{AuthSource, LoginInfo, Password};
7use pgwire::error::{PgWireError, PgWireResult};
8use tokio::sync::RwLock;
9
/// An account registered with the authentication system.
#[derive(Debug, Clone)]
pub struct User {
    /// Login name; also the key under which the user is stored by `AuthManager::add_user`.
    pub username: String,
    /// Stored credential. Despite the name, nothing in this module hashes it:
    /// `AuthManager::authenticate` compares it verbatim with the supplied password and
    /// treats an empty value as "any password is accepted".
    pub password_hash: String,
    /// Names of the roles this user belongs to.
    pub roles: Vec<String>,
    /// Superusers pass every permission and role-membership check.
    pub is_superuser: bool,
    /// When `false`, logins are refused regardless of the password.
    pub can_login: bool,
    /// Maximum number of concurrent connections. NOTE(review): not enforced anywhere in
    /// this module; presumably `None` means "unlimited".
    pub connection_limit: Option<i32>,
}
20
/// Privileges that can be granted on a resource, named after PostgreSQL's `GRANT` keywords.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Permission {
    Select,
    Insert,
    Update,
    Delete,
    Create,
    Drop,
    Alter,
    Index,
    References,
    Trigger,
    Execute,
    Usage,
    Connect,
    Temporary,
    /// Every privilege: a grant of `All` satisfies a request for any permission.
    All,
}

impl Permission {
    /// Parses a privilege keyword case-insensitively.
    ///
    /// Besides the variant names, this accepts PostgreSQL's `TEMP` alias for `TEMPORARY` and
    /// the optional noise word in `ALL PRIVILEGES`. Returns `None` for anything else,
    /// including the empty string.
    pub fn from_string(s: &str) -> Option<Permission> {
        // SQL keywords are ASCII. `str::to_uppercase` applies full Unicode case mapping, which
        // would turn e.g. "ſelect" (U+017F LATIN SMALL LETTER LONG S) into "SELECT".
        let keyword = s.to_ascii_uppercase();
        let permission = match keyword.as_str() {
            "SELECT" => Permission::Select,
            "INSERT" => Permission::Insert,
            "UPDATE" => Permission::Update,
            "DELETE" => Permission::Delete,
            "CREATE" => Permission::Create,
            "DROP" => Permission::Drop,
            "ALTER" => Permission::Alter,
            "INDEX" => Permission::Index,
            "REFERENCES" => Permission::References,
            "TRIGGER" => Permission::Trigger,
            "EXECUTE" => Permission::Execute,
            "USAGE" => Permission::Usage,
            "CONNECT" => Permission::Connect,
            "TEMPORARY" | "TEMP" => Permission::Temporary,
            "ALL" | "ALL PRIVILEGES" => Permission::All,
            _ => return None,
        };
        Some(permission)
    }
}
63
/// Kinds of objects a privilege can be granted on.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ResourceType {
    /// A table by name; schema-qualified tables are expected in `schema.table` form.
    Table(String),
    /// A schema; a grant on it also covers tables named `<schema>.<table>`.
    Schema(String),
    /// A database by name.
    Database(String),
    /// A function by name.
    Function(String),
    /// A sequence by name.
    Sequence(String),
    /// Wildcard: as a grant target it covers every resource.
    All,
}
74
75/// Grant entry for specific permissions on resources
76#[derive(Debug, Clone)]
77pub struct Grant {
78    pub permission: Permission,
79    pub resource: ResourceType,
80    pub granted_by: String,
81    pub with_grant_option: bool,
82}
83
84/// Role information for access control
85#[derive(Debug, Clone)]
86pub struct Role {
87    pub name: String,
88    pub is_superuser: bool,
89    pub can_login: bool,
90    pub can_create_db: bool,
91    pub can_create_role: bool,
92    pub can_create_user: bool,
93    pub can_replication: bool,
94    pub grants: Vec<Grant>,
95    pub inherited_roles: Vec<String>,
96}
97
/// Attributes for a new role. A role built from it starts with no grants and no
/// inherited roles.
#[derive(Debug, Clone)]
pub struct RoleConfig {
    /// Name of the role to create.
    pub name: String,
    /// Whether members of the role bypass permission checks.
    pub is_superuser: bool,
    /// Login capability flag.
    pub can_login: bool,
    /// Database-creation capability flag.
    pub can_create_db: bool,
    /// Role-creation capability flag.
    pub can_create_role: bool,
    /// User-creation capability flag.
    pub can_create_user: bool,
    /// Replication capability flag.
    pub can_replication: bool,
}
109
110/// Authentication manager that handles users and roles
111#[derive(Debug)]
112pub struct AuthManager {
113    users: Arc<RwLock<HashMap<String, User>>>,
114    roles: Arc<RwLock<HashMap<String, Role>>>,
115}
116
117impl Default for AuthManager {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl AuthManager {
124    pub fn new() -> Self {
125        let auth_manager = AuthManager {
126            users: Arc::new(RwLock::new(HashMap::new())),
127            roles: Arc::new(RwLock::new(HashMap::new())),
128        };
129
130        // Initialize with default postgres superuser
131        let postgres_user = User {
132            username: "postgres".to_string(),
133            password_hash: "".to_string(), // Empty password for now
134            roles: vec!["postgres".to_string()],
135            is_superuser: true,
136            can_login: true,
137            connection_limit: None,
138        };
139
140        let postgres_role = Role {
141            name: "postgres".to_string(),
142            is_superuser: true,
143            can_login: true,
144            can_create_db: true,
145            can_create_role: true,
146            can_create_user: true,
147            can_replication: true,
148            grants: vec![Grant {
149                permission: Permission::All,
150                resource: ResourceType::All,
151                granted_by: "system".to_string(),
152                with_grant_option: true,
153            }],
154            inherited_roles: vec![],
155        };
156
157        // Add default users and roles
158        let auth_manager_clone = AuthManager {
159            users: auth_manager.users.clone(),
160            roles: auth_manager.roles.clone(),
161        };
162
163        tokio::spawn({
164            let users = auth_manager.users.clone();
165            let roles = auth_manager.roles.clone();
166            let auth_manager_spawn = auth_manager_clone;
167            async move {
168                users
169                    .write()
170                    .await
171                    .insert("postgres".to_string(), postgres_user);
172                roles
173                    .write()
174                    .await
175                    .insert("postgres".to_string(), postgres_role);
176
177                // Create predefined roles
178                if let Err(e) = auth_manager_spawn.create_predefined_roles().await {
179                    warn!("Failed to create predefined roles: {e:?}");
180                }
181            }
182        });
183
184        auth_manager
185    }
186
187    /// Add a new user to the system
188    pub async fn add_user(&self, user: User) -> PgWireResult<()> {
189        let mut users = self.users.write().await;
190        users.insert(user.username.clone(), user);
191        Ok(())
192    }
193
194    /// Add a new role to the system
195    pub async fn add_role(&self, role: Role) -> PgWireResult<()> {
196        let mut roles = self.roles.write().await;
197        roles.insert(role.name.clone(), role);
198        Ok(())
199    }
200
201    /// Authenticate a user with username and password
202    pub async fn authenticate(&self, username: &str, password: &str) -> PgWireResult<bool> {
203        let users = self.users.read().await;
204
205        if let Some(user) = users.get(username) {
206            if !user.can_login {
207                return Ok(false);
208            }
209
210            // For now, accept empty password or any password for existing users
211            // In production, this should use proper password hashing (bcrypt, etc.)
212            if user.password_hash.is_empty() || password == user.password_hash {
213                return Ok(true);
214            }
215        }
216
217        // If user doesn't exist, check if we should create them dynamically
218        // For now, only accept known users
219        Ok(false)
220    }
221
222    /// Get user information
223    pub async fn get_user(&self, username: &str) -> Option<User> {
224        let users = self.users.read().await;
225        users.get(username).cloned()
226    }
227
228    /// Get role information
229    pub async fn get_role(&self, role_name: &str) -> Option<Role> {
230        let roles = self.roles.read().await;
231        roles.get(role_name).cloned()
232    }
233
234    /// Check if user has a specific role
235    pub async fn user_has_role(&self, username: &str, role_name: &str) -> bool {
236        if let Some(user) = self.get_user(username).await {
237            return user.roles.contains(&role_name.to_string()) || user.is_superuser;
238        }
239        false
240    }
241
242    /// List all users (for administrative purposes)
243    pub async fn list_users(&self) -> Vec<String> {
244        let users = self.users.read().await;
245        users.keys().cloned().collect()
246    }
247
248    /// List all roles (for administrative purposes)
249    pub async fn list_roles(&self) -> Vec<String> {
250        let roles = self.roles.read().await;
251        roles.keys().cloned().collect()
252    }
253
254    /// Grant permission to a role
255    pub async fn grant_permission(
256        &self,
257        role_name: &str,
258        permission: Permission,
259        resource: ResourceType,
260        granted_by: &str,
261        with_grant_option: bool,
262    ) -> PgWireResult<()> {
263        let mut roles = self.roles.write().await;
264
265        if let Some(role) = roles.get_mut(role_name) {
266            let grant = Grant {
267                permission,
268                resource,
269                granted_by: granted_by.to_string(),
270                with_grant_option,
271            };
272            role.grants.push(grant);
273            Ok(())
274        } else {
275            Err(PgWireError::UserError(Box::new(
276                pgwire::error::ErrorInfo::new(
277                    "ERROR".to_string(),
278                    "42704".to_string(), // undefined_object
279                    format!("role \"{role_name}\" does not exist"),
280                ),
281            )))
282        }
283    }
284
285    /// Revoke permission from a role
286    pub async fn revoke_permission(
287        &self,
288        role_name: &str,
289        permission: Permission,
290        resource: ResourceType,
291    ) -> PgWireResult<()> {
292        let mut roles = self.roles.write().await;
293
294        if let Some(role) = roles.get_mut(role_name) {
295            role.grants
296                .retain(|grant| !(grant.permission == permission && grant.resource == resource));
297            Ok(())
298        } else {
299            Err(PgWireError::UserError(Box::new(
300                pgwire::error::ErrorInfo::new(
301                    "ERROR".to_string(),
302                    "42704".to_string(), // undefined_object
303                    format!("role \"{role_name}\" does not exist"),
304                ),
305            )))
306        }
307    }
308
309    /// Check if a user has a specific permission on a resource
310    pub async fn check_permission(
311        &self,
312        username: &str,
313        permission: Permission,
314        resource: ResourceType,
315    ) -> bool {
316        // Superusers have all permissions
317        if let Some(user) = self.get_user(username).await {
318            if user.is_superuser {
319                return true;
320            }
321
322            // Check permissions for each role the user has
323            for role_name in &user.roles {
324                if let Some(role) = self.get_role(role_name).await {
325                    // Superuser role has all permissions
326                    if role.is_superuser {
327                        return true;
328                    }
329
330                    // Check direct grants
331                    for grant in &role.grants {
332                        if self.permission_matches(&grant.permission, &permission)
333                            && self.resource_matches(&grant.resource, &resource)
334                        {
335                            return true;
336                        }
337                    }
338
339                    // Check inherited roles recursively
340                    for inherited_role in &role.inherited_roles {
341                        if self
342                            .check_role_permission(inherited_role, &permission, &resource)
343                            .await
344                        {
345                            return true;
346                        }
347                    }
348                }
349            }
350        }
351
352        false
353    }
354
355    /// Check if a role has a specific permission (helper for recursive checking)
356    fn check_role_permission<'a>(
357        &'a self,
358        role_name: &'a str,
359        permission: &'a Permission,
360        resource: &'a ResourceType,
361    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'a>> {
362        Box::pin(async move {
363            if let Some(role) = self.get_role(role_name).await {
364                if role.is_superuser {
365                    return true;
366                }
367
368                // Check direct grants
369                for grant in &role.grants {
370                    if self.permission_matches(&grant.permission, permission)
371                        && self.resource_matches(&grant.resource, resource)
372                    {
373                        return true;
374                    }
375                }
376
377                // Check inherited roles
378                for inherited_role in &role.inherited_roles {
379                    if self
380                        .check_role_permission(inherited_role, permission, resource)
381                        .await
382                    {
383                        return true;
384                    }
385                }
386            }
387
388            false
389        })
390    }
391
392    /// Check if a permission grant matches the requested permission
393    fn permission_matches(&self, grant_permission: &Permission, requested: &Permission) -> bool {
394        grant_permission == requested || matches!(grant_permission, Permission::All)
395    }
396
397    /// Check if a resource grant matches the requested resource
398    fn resource_matches(&self, grant_resource: &ResourceType, requested: &ResourceType) -> bool {
399        match (grant_resource, requested) {
400            // Exact match
401            (a, b) if a == b => true,
402            // All resource type grants access to everything
403            (ResourceType::All, _) => true,
404            // Schema grants access to all tables in that schema
405            (ResourceType::Schema(schema), ResourceType::Table(table)) => {
406                // For simplicity, assume table names are schema.table format
407                table.starts_with(&format!("{schema}."))
408            }
409            _ => false,
410        }
411    }
412
413    /// Add role inheritance
414    pub async fn add_role_inheritance(
415        &self,
416        child_role: &str,
417        parent_role: &str,
418    ) -> PgWireResult<()> {
419        let mut roles = self.roles.write().await;
420
421        if let Some(child) = roles.get_mut(child_role) {
422            if !child.inherited_roles.contains(&parent_role.to_string()) {
423                child.inherited_roles.push(parent_role.to_string());
424            }
425            Ok(())
426        } else {
427            Err(PgWireError::UserError(Box::new(
428                pgwire::error::ErrorInfo::new(
429                    "ERROR".to_string(),
430                    "42704".to_string(), // undefined_object
431                    format!("role \"{child_role}\" does not exist"),
432                ),
433            )))
434        }
435    }
436
437    /// Remove role inheritance
438    pub async fn remove_role_inheritance(
439        &self,
440        child_role: &str,
441        parent_role: &str,
442    ) -> PgWireResult<()> {
443        let mut roles = self.roles.write().await;
444
445        if let Some(child) = roles.get_mut(child_role) {
446            child.inherited_roles.retain(|role| role != parent_role);
447            Ok(())
448        } else {
449            Err(PgWireError::UserError(Box::new(
450                pgwire::error::ErrorInfo::new(
451                    "ERROR".to_string(),
452                    "42704".to_string(), // undefined_object
453                    format!("role \"{child_role}\" does not exist"),
454                ),
455            )))
456        }
457    }
458
459    /// Create a new role with specific capabilities
460    pub async fn create_role(&self, config: RoleConfig) -> PgWireResult<()> {
461        let role = Role {
462            name: config.name.clone(),
463            is_superuser: config.is_superuser,
464            can_login: config.can_login,
465            can_create_db: config.can_create_db,
466            can_create_role: config.can_create_role,
467            can_create_user: config.can_create_user,
468            can_replication: config.can_replication,
469            grants: vec![],
470            inherited_roles: vec![],
471        };
472
473        self.add_role(role).await
474    }
475
476    /// Create common predefined roles
477    pub async fn create_predefined_roles(&self) -> PgWireResult<()> {
478        // Read-only role
479        self.create_role(RoleConfig {
480            name: "readonly".to_string(),
481            is_superuser: false,
482            can_login: false,
483            can_create_db: false,
484            can_create_role: false,
485            can_create_user: false,
486            can_replication: false,
487        })
488        .await?;
489
490        self.grant_permission(
491            "readonly",
492            Permission::Select,
493            ResourceType::All,
494            "system",
495            false,
496        )
497        .await?;
498
499        // Read-write role
500        self.create_role(RoleConfig {
501            name: "readwrite".to_string(),
502            is_superuser: false,
503            can_login: false,
504            can_create_db: false,
505            can_create_role: false,
506            can_create_user: false,
507            can_replication: false,
508        })
509        .await?;
510
511        self.grant_permission(
512            "readwrite",
513            Permission::Select,
514            ResourceType::All,
515            "system",
516            false,
517        )
518        .await?;
519
520        self.grant_permission(
521            "readwrite",
522            Permission::Insert,
523            ResourceType::All,
524            "system",
525            false,
526        )
527        .await?;
528
529        self.grant_permission(
530            "readwrite",
531            Permission::Update,
532            ResourceType::All,
533            "system",
534            false,
535        )
536        .await?;
537
538        self.grant_permission(
539            "readwrite",
540            Permission::Delete,
541            ResourceType::All,
542            "system",
543            false,
544        )
545        .await?;
546
547        // Database admin role
548        self.create_role(RoleConfig {
549            name: "dbadmin".to_string(),
550            is_superuser: false,
551            can_login: true,
552            can_create_db: true,
553            can_create_role: false,
554            can_create_user: false,
555            can_replication: false,
556        })
557        .await?;
558
559        self.grant_permission(
560            "dbadmin",
561            Permission::All,
562            ResourceType::All,
563            "system",
564            true,
565        )
566        .await?;
567
568        Ok(())
569    }
570}
571
572/// AuthSource implementation for integration with pgwire authentication
573/// Provides proper password-based authentication instead of custom startup handler
574#[derive(Clone)]
575pub struct DfAuthSource {
576    pub auth_manager: Arc<AuthManager>,
577}
578
579impl DfAuthSource {
580    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
581        DfAuthSource { auth_manager }
582    }
583}
584
585#[async_trait]
586impl AuthSource for DfAuthSource {
587    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
588        if let Some(username) = login.user() {
589            // Check if user exists in our RBAC system
590            if let Some(user) = self.auth_manager.get_user(username).await {
591                if user.can_login {
592                    // Return the stored password hash for authentication
593                    // The pgwire authentication handlers (cleartext/md5/scram) will
594                    // handle the actual password verification process
595                    Ok(Password::new(None, user.password_hash.into_bytes()))
596                } else {
597                    Err(PgWireError::UserError(Box::new(
598                        pgwire::error::ErrorInfo::new(
599                            "FATAL".to_string(),
600                            "28000".to_string(), // invalid_authorization_specification
601                            format!("User \"{username}\" is not allowed to login"),
602                        ),
603                    )))
604                }
605            } else {
606                Err(PgWireError::UserError(Box::new(
607                    pgwire::error::ErrorInfo::new(
608                        "FATAL".to_string(),
609                        "28P01".to_string(), // invalid_password
610                        format!("password authentication failed for user \"{username}\""),
611                    ),
612                )))
613            }
614        } else {
615            Err(PgWireError::UserError(Box::new(
616                pgwire::error::ErrorInfo::new(
617                    "FATAL".to_string(),
618                    "28P01".to_string(), // invalid_password
619                    "No username provided in login request".to_string(),
620                ),
621            )))
622        }
623    }
624}
625
626// REMOVED: Custom startup handler approach
627//
628// Instead of implementing a custom StartupHandler, use the proper pgwire authentication:
629//
630// For cleartext authentication:
631// ```rust
632// use pgwire::api::auth::cleartext::CleartextStartupHandler;
633//
634// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
635// let authenticator = CleartextStartupHandler::new(
636//     auth_source,
637//     Arc::new(DefaultServerParameterProvider::default())
638// );
639// ```
640//
641// For MD5 authentication:
642// ```rust
643// use pgwire::api::auth::md5::MD5StartupHandler;
644//
645// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
646// let authenticator = MD5StartupHandler::new(
647//     auth_source,
648//     Arc::new(DefaultServerParameterProvider::default())
649// );
650// ```
651//
652// For SCRAM authentication (requires "server-api-scram" feature):
653// ```rust
654// use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
655//
656// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
657// let authenticator = SASLScramAuthStartupHandler::new(
658//     auth_source,
659//     Arc::new(DefaultServerParameterProvider::default())
660// );
661// ```
662
663/// Simple AuthSource implementation that accepts any user with empty password
664pub struct SimpleAuthSource {
665    auth_manager: Arc<AuthManager>,
666}
667
668impl SimpleAuthSource {
669    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
670        SimpleAuthSource { auth_manager }
671    }
672}
673
674#[async_trait]
675impl AuthSource for SimpleAuthSource {
676    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
677        let username = login.user().unwrap_or("anonymous");
678
679        // Check if user exists and can login
680        if let Some(user) = self.auth_manager.get_user(username).await {
681            if user.can_login {
682                // Return empty password for now (no authentication required)
683                return Ok(Password::new(None, vec![]));
684            }
685        }
686
687        // For postgres user, always allow
688        if username == "postgres" {
689            return Ok(Password::new(None, vec![]));
690        }
691
692        // User not found or cannot login
693        Err(PgWireError::UserError(Box::new(
694            pgwire::error::ErrorInfo::new(
695                "FATAL".to_string(),
696                "28P01".to_string(), // invalid_password
697                format!("password authentication failed for user \"{username}\""),
698            ),
699        )))
700    }
701}
702
703/// Helper function to create auth source with auth manager
704pub fn create_auth_source(auth_manager: Arc<AuthManager>) -> SimpleAuthSource {
705    SimpleAuthSource::new(auth_manager)
706}
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields briefly so that any background initialization started by `AuthManager::new`
    /// can finish before the default accounts are inspected.
    async fn settle() {
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }

    /// A non-superuser role with no capabilities, for tests that build their own roles.
    fn plain_role(name: &str) -> RoleConfig {
        RoleConfig {
            name: name.to_string(),
            is_superuser: false,
            can_login: false,
            can_create_db: false,
            can_create_role: false,
            can_create_user: false,
            can_replication: false,
        }
    }

    /// A login-capable, non-superuser user belonging to `roles`.
    fn member(username: &str, password: &str, roles: &[&str]) -> User {
        User {
            username: username.to_string(),
            password_hash: password.to_string(),
            roles: roles.iter().map(|role| role.to_string()).collect(),
            is_superuser: false,
            can_login: true,
            connection_limit: None,
        }
    }

    #[tokio::test]
    async fn test_auth_manager_creation() {
        let auth_manager = AuthManager::new();
        settle().await;

        let users = auth_manager.list_users().await;
        assert!(users.contains(&"postgres".to_string()));
    }

    #[tokio::test]
    async fn test_user_authentication() {
        let auth_manager = AuthManager::new();
        settle().await;

        // Test postgres user authentication
        assert!(auth_manager.authenticate("postgres", "").await.unwrap());
        assert!(!auth_manager
            .authenticate("nonexistent", "password")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_role_management() {
        let auth_manager = AuthManager::new();
        settle().await;

        assert!(auth_manager.user_has_role("postgres", "postgres").await);
        assert!(auth_manager.user_has_role("postgres", "any_role").await); // superuser
    }

    #[tokio::test]
    async fn test_password_check() {
        let auth_manager = AuthManager::new();
        auth_manager
            .add_user(member("alice", "secret", &[]))
            .await
            .unwrap();

        assert!(auth_manager.authenticate("alice", "secret").await.unwrap());
        assert!(!auth_manager.authenticate("alice", "wrong").await.unwrap());
    }

    #[tokio::test]
    async fn test_grant_and_revoke_permission() {
        let auth_manager = AuthManager::new();
        auth_manager.create_role(plain_role("analyst")).await.unwrap();
        auth_manager
            .add_user(member("alice", "secret", &["analyst"]))
            .await
            .unwrap();

        let orders = ResourceType::Table("public.orders".to_string());
        let public = ResourceType::Schema("public".to_string());

        assert!(
            !auth_manager
                .check_permission("alice", Permission::Select, orders.clone())
                .await
        );

        auth_manager
            .grant_permission("analyst", Permission::Select, public.clone(), "postgres", false)
            .await
            .unwrap();
        assert!(
            auth_manager
                .check_permission("alice", Permission::Select, orders.clone())
                .await
        );
        assert!(
            !auth_manager
                .check_permission("alice", Permission::Insert, orders.clone())
                .await
        );
        // A schema grant must not leak into schemas that merely share a name prefix.
        assert!(
            !auth_manager
                .check_permission(
                    "alice",
                    Permission::Select,
                    ResourceType::Table("publicity.orders".to_string()),
                )
                .await
        );

        auth_manager
            .revoke_permission("analyst", Permission::Select, public)
            .await
            .unwrap();
        assert!(
            !auth_manager
                .check_permission("alice", Permission::Select, orders)
                .await
        );

        assert!(auth_manager
            .grant_permission("missing", Permission::Select, ResourceType::All, "postgres", false)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_inherited_permissions() {
        let auth_manager = AuthManager::new();
        auth_manager.create_role(plain_role("base")).await.unwrap();
        auth_manager.create_role(plain_role("analyst")).await.unwrap();
        auth_manager
            .add_user(member("alice", "secret", &["analyst"]))
            .await
            .unwrap();

        let table = ResourceType::Table("t".to_string());
        auth_manager
            .grant_permission("base", Permission::Select, table.clone(), "postgres", false)
            .await
            .unwrap();
        assert!(
            !auth_manager
                .check_permission("alice", Permission::Select, table.clone())
                .await
        );

        auth_manager
            .add_role_inheritance("analyst", "base")
            .await
            .unwrap();
        assert!(
            auth_manager
                .check_permission("alice", Permission::Select, table.clone())
                .await
        );

        auth_manager
            .remove_role_inheritance("analyst", "base")
            .await
            .unwrap();
        assert!(
            !auth_manager
                .check_permission("alice", Permission::Select, table)
                .await
        );

        assert!(auth_manager
            .add_role_inheritance("missing", "base")
            .await
            .is_err());
    }
}