datafusion_postgres/
auth.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use pgwire::api::auth::{AuthSource, LoginInfo, Password};
6use pgwire::error::{PgWireError, PgWireResult};
7use tokio::sync::RwLock;
8
9use datafusion_pg_catalog::pg_catalog::context::*;
10
/// Authentication manager that handles users and roles.
///
/// Both tables live behind `Arc<RwLock<..>>`, so cloning an `AuthManager`
/// produces another handle to the same underlying user and role maps.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// Registered users, keyed by username.
    users: Arc<RwLock<HashMap<String, User>>>,
    /// Registered roles, keyed by role name.
    roles: Arc<RwLock<HashMap<String, Role>>>,
}
17
18impl Default for AuthManager {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl AuthManager {
25    pub fn new() -> Self {
26        let mut users = HashMap::new();
27        // Initialize with default postgres superuser
28        let postgres_user = User {
29            username: "postgres".to_string(),
30            password_hash: "".to_string(), // Empty password for now
31            roles: vec!["postgres".to_string()],
32            is_superuser: true,
33            can_login: true,
34            connection_limit: None,
35        };
36        users.insert(postgres_user.username.clone(), postgres_user);
37
38        let mut roles = HashMap::new();
39        let postgres_role = Role {
40            name: "postgres".to_string(),
41            is_superuser: true,
42            can_login: true,
43            can_create_db: true,
44            can_create_role: true,
45            can_create_user: true,
46            can_replication: true,
47            grants: vec![Grant {
48                permission: Permission::All,
49                resource: ResourceType::All,
50                granted_by: "system".to_string(),
51                with_grant_option: true,
52            }],
53            inherited_roles: vec![],
54        };
55        roles.insert(postgres_role.name.clone(), postgres_role);
56
57        AuthManager {
58            users: Arc::new(RwLock::new(users)),
59            roles: Arc::new(RwLock::new(roles)),
60        }
61    }
62
63    /// Add a new user to the system
64    pub async fn add_user(&self, user: User) -> PgWireResult<()> {
65        let mut users = self.users.write().await;
66        users.insert(user.username.clone(), user);
67        Ok(())
68    }
69
70    /// Add a new role to the system
71    pub async fn add_role(&self, role: Role) -> PgWireResult<()> {
72        let mut roles = self.roles.write().await;
73        roles.insert(role.name.clone(), role);
74        Ok(())
75    }
76
77    /// Authenticate a user with username and password
78    pub async fn authenticate(&self, username: &str, password: &str) -> PgWireResult<bool> {
79        let users = self.users.read().await;
80
81        if let Some(user) = users.get(username) {
82            if !user.can_login {
83                return Ok(false);
84            }
85
86            // For now, accept empty password or any password for existing users
87            // In production, this should use proper password hashing (bcrypt, etc.)
88            if user.password_hash.is_empty() || password == user.password_hash {
89                return Ok(true);
90            }
91        }
92
93        // If user doesn't exist, check if we should create them dynamically
94        // For now, only accept known users
95        Ok(false)
96    }
97
98    /// Get user information
99    pub async fn get_user(&self, username: &str) -> Option<User> {
100        let users = self.users.read().await;
101        users.get(username).cloned()
102    }
103
104    /// Get role information
105    pub async fn get_role(&self, role_name: &str) -> Option<Role> {
106        let roles = self.roles.read().await;
107        roles.get(role_name).cloned()
108    }
109
110    /// Check if user has a specific role
111    pub async fn user_has_role(&self, username: &str, role_name: &str) -> bool {
112        if let Some(user) = self.get_user(username).await {
113            return user.roles.contains(&role_name.to_string()) || user.is_superuser;
114        }
115        false
116    }
117
118    /// List all users (for administrative purposes)
119    pub async fn list_users(&self) -> Vec<String> {
120        let users = self.users.read().await;
121        users.keys().cloned().collect()
122    }
123
124    /// List all roles (for administrative purposes)
125    pub async fn list_roles(&self) -> Vec<String> {
126        let roles = self.roles.read().await;
127        roles.keys().cloned().collect()
128    }
129
130    /// Grant permission to a role
131    pub async fn grant_permission(
132        &self,
133        role_name: &str,
134        permission: Permission,
135        resource: ResourceType,
136        granted_by: &str,
137        with_grant_option: bool,
138    ) -> PgWireResult<()> {
139        let mut roles = self.roles.write().await;
140
141        if let Some(role) = roles.get_mut(role_name) {
142            let grant = Grant {
143                permission,
144                resource,
145                granted_by: granted_by.to_string(),
146                with_grant_option,
147            };
148            role.grants.push(grant);
149            Ok(())
150        } else {
151            Err(PgWireError::UserError(Box::new(
152                pgwire::error::ErrorInfo::new(
153                    "ERROR".to_string(),
154                    "42704".to_string(), // undefined_object
155                    format!("role \"{role_name}\" does not exist"),
156                ),
157            )))
158        }
159    }
160
161    /// Revoke permission from a role
162    pub async fn revoke_permission(
163        &self,
164        role_name: &str,
165        permission: Permission,
166        resource: ResourceType,
167    ) -> PgWireResult<()> {
168        let mut roles = self.roles.write().await;
169
170        if let Some(role) = roles.get_mut(role_name) {
171            role.grants
172                .retain(|grant| !(grant.permission == permission && grant.resource == resource));
173            Ok(())
174        } else {
175            Err(PgWireError::UserError(Box::new(
176                pgwire::error::ErrorInfo::new(
177                    "ERROR".to_string(),
178                    "42704".to_string(), // undefined_object
179                    format!("role \"{role_name}\" does not exist"),
180                ),
181            )))
182        }
183    }
184
185    /// Check if a user has a specific permission on a resource
186    pub async fn check_permission(
187        &self,
188        username: &str,
189        permission: Permission,
190        resource: ResourceType,
191    ) -> bool {
192        // Superusers have all permissions
193        if let Some(user) = self.get_user(username).await {
194            if user.is_superuser {
195                return true;
196            }
197
198            // Check permissions for each role the user has
199            for role_name in &user.roles {
200                if let Some(role) = self.get_role(role_name).await {
201                    // Superuser role has all permissions
202                    if role.is_superuser {
203                        return true;
204                    }
205
206                    // Check direct grants
207                    for grant in &role.grants {
208                        if self.permission_matches(&grant.permission, &permission)
209                            && self.resource_matches(&grant.resource, &resource)
210                        {
211                            return true;
212                        }
213                    }
214
215                    // Check inherited roles recursively
216                    for inherited_role in &role.inherited_roles {
217                        if self
218                            .check_role_permission(inherited_role, &permission, &resource)
219                            .await
220                        {
221                            return true;
222                        }
223                    }
224                }
225            }
226        }
227
228        false
229    }
230
231    /// Check if a role has a specific permission (helper for recursive checking)
232    fn check_role_permission<'a>(
233        &'a self,
234        role_name: &'a str,
235        permission: &'a Permission,
236        resource: &'a ResourceType,
237    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'a>> {
238        Box::pin(async move {
239            if let Some(role) = self.get_role(role_name).await {
240                if role.is_superuser {
241                    return true;
242                }
243
244                // Check direct grants
245                for grant in &role.grants {
246                    if self.permission_matches(&grant.permission, permission)
247                        && self.resource_matches(&grant.resource, resource)
248                    {
249                        return true;
250                    }
251                }
252
253                // Check inherited roles
254                for inherited_role in &role.inherited_roles {
255                    if self
256                        .check_role_permission(inherited_role, permission, resource)
257                        .await
258                    {
259                        return true;
260                    }
261                }
262            }
263
264            false
265        })
266    }
267
268    /// Check if a permission grant matches the requested permission
269    fn permission_matches(&self, grant_permission: &Permission, requested: &Permission) -> bool {
270        grant_permission == requested || matches!(grant_permission, Permission::All)
271    }
272
273    /// Check if a resource grant matches the requested resource
274    fn resource_matches(&self, grant_resource: &ResourceType, requested: &ResourceType) -> bool {
275        match (grant_resource, requested) {
276            // Exact match
277            (a, b) if a == b => true,
278            // All resource type grants access to everything
279            (ResourceType::All, _) => true,
280            // Schema grants access to all tables in that schema
281            (ResourceType::Schema(schema), ResourceType::Table(table)) => {
282                // For simplicity, assume table names are schema.table format
283                table.starts_with(&format!("{schema}."))
284            }
285            _ => false,
286        }
287    }
288
289    /// Add role inheritance
290    pub async fn add_role_inheritance(
291        &self,
292        child_role: &str,
293        parent_role: &str,
294    ) -> PgWireResult<()> {
295        let mut roles = self.roles.write().await;
296
297        if let Some(child) = roles.get_mut(child_role) {
298            if !child.inherited_roles.contains(&parent_role.to_string()) {
299                child.inherited_roles.push(parent_role.to_string());
300            }
301            Ok(())
302        } else {
303            Err(PgWireError::UserError(Box::new(
304                pgwire::error::ErrorInfo::new(
305                    "ERROR".to_string(),
306                    "42704".to_string(), // undefined_object
307                    format!("role \"{child_role}\" does not exist"),
308                ),
309            )))
310        }
311    }
312
313    /// Remove role inheritance
314    pub async fn remove_role_inheritance(
315        &self,
316        child_role: &str,
317        parent_role: &str,
318    ) -> PgWireResult<()> {
319        let mut roles = self.roles.write().await;
320
321        if let Some(child) = roles.get_mut(child_role) {
322            child.inherited_roles.retain(|role| role != parent_role);
323            Ok(())
324        } else {
325            Err(PgWireError::UserError(Box::new(
326                pgwire::error::ErrorInfo::new(
327                    "ERROR".to_string(),
328                    "42704".to_string(), // undefined_object
329                    format!("role \"{child_role}\" does not exist"),
330                ),
331            )))
332        }
333    }
334
335    /// Create a new role with specific capabilities
336    pub async fn create_role(&self, config: RoleConfig) -> PgWireResult<()> {
337        let role = Role {
338            name: config.name.clone(),
339            is_superuser: config.is_superuser,
340            can_login: config.can_login,
341            can_create_db: config.can_create_db,
342            can_create_role: config.can_create_role,
343            can_create_user: config.can_create_user,
344            can_replication: config.can_replication,
345            grants: vec![],
346            inherited_roles: vec![],
347        };
348
349        self.add_role(role).await
350    }
351
352    /// Create common predefined roles
353    pub async fn create_predefined_roles(&self) -> PgWireResult<()> {
354        // Read-only role
355        self.create_role(RoleConfig {
356            name: "readonly".to_string(),
357            is_superuser: false,
358            can_login: false,
359            can_create_db: false,
360            can_create_role: false,
361            can_create_user: false,
362            can_replication: false,
363        })
364        .await?;
365
366        self.grant_permission(
367            "readonly",
368            Permission::Select,
369            ResourceType::All,
370            "system",
371            false,
372        )
373        .await?;
374
375        // Read-write role
376        self.create_role(RoleConfig {
377            name: "readwrite".to_string(),
378            is_superuser: false,
379            can_login: false,
380            can_create_db: false,
381            can_create_role: false,
382            can_create_user: false,
383            can_replication: false,
384        })
385        .await?;
386
387        self.grant_permission(
388            "readwrite",
389            Permission::Select,
390            ResourceType::All,
391            "system",
392            false,
393        )
394        .await?;
395
396        self.grant_permission(
397            "readwrite",
398            Permission::Insert,
399            ResourceType::All,
400            "system",
401            false,
402        )
403        .await?;
404
405        self.grant_permission(
406            "readwrite",
407            Permission::Update,
408            ResourceType::All,
409            "system",
410            false,
411        )
412        .await?;
413
414        self.grant_permission(
415            "readwrite",
416            Permission::Delete,
417            ResourceType::All,
418            "system",
419            false,
420        )
421        .await?;
422
423        // Database admin role
424        self.create_role(RoleConfig {
425            name: "dbadmin".to_string(),
426            is_superuser: false,
427            can_login: true,
428            can_create_db: true,
429            can_create_role: false,
430            can_create_user: false,
431            can_replication: false,
432        })
433        .await?;
434
435        self.grant_permission(
436            "dbadmin",
437            Permission::All,
438            ResourceType::All,
439            "system",
440            true,
441        )
442        .await?;
443
444        Ok(())
445    }
446}
447
448#[async_trait]
449impl PgCatalogContextProvider for AuthManager {
450    // retrieve all database role names
451    async fn roles(&self) -> Vec<String> {
452        self.list_roles().await
453    }
454
455    // retrieve database role information
456    async fn role(&self, name: &str) -> Option<Role> {
457        self.get_role(name).await
458    }
459}
460
/// `AuthSource` implementation for integration with pgwire authentication.
///
/// Provides proper password-based authentication instead of a custom startup
/// handler: it hands the stored password of a known, login-enabled user to
/// the pgwire authentication handler.
#[derive(Clone)]
pub struct DfAuthSource {
    /// Shared manager that the user records are looked up in.
    pub auth_manager: Arc<AuthManager>,
}
467
468impl DfAuthSource {
469    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
470        DfAuthSource { auth_manager }
471    }
472}
473
474#[async_trait]
475impl AuthSource for DfAuthSource {
476    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
477        if let Some(username) = login.user() {
478            // Check if user exists in our RBAC system
479            if let Some(user) = self.auth_manager.get_user(username).await {
480                if user.can_login {
481                    // Return the stored password hash for authentication
482                    // The pgwire authentication handlers (cleartext/md5/scram) will
483                    // handle the actual password verification process
484                    Ok(Password::new(None, user.password_hash.into_bytes()))
485                } else {
486                    Err(PgWireError::UserError(Box::new(
487                        pgwire::error::ErrorInfo::new(
488                            "FATAL".to_string(),
489                            "28000".to_string(), // invalid_authorization_specification
490                            format!("User \"{username}\" is not allowed to login"),
491                        ),
492                    )))
493                }
494            } else {
495                Err(PgWireError::UserError(Box::new(
496                    pgwire::error::ErrorInfo::new(
497                        "FATAL".to_string(),
498                        "28P01".to_string(), // invalid_password
499                        format!("password authentication failed for user \"{username}\""),
500                    ),
501                )))
502            }
503        } else {
504            Err(PgWireError::UserError(Box::new(
505                pgwire::error::ErrorInfo::new(
506                    "FATAL".to_string(),
507                    "28P01".to_string(), // invalid_password
508                    "No username provided in login request".to_string(),
509                ),
510            )))
511        }
512    }
513}
514
515// REMOVED: Custom startup handler approach
516//
517// Instead of implementing a custom StartupHandler, use the proper pgwire authentication:
518//
519// For cleartext authentication:
520// ```rust
521// use pgwire::api::auth::cleartext::CleartextStartupHandler;
522//
523// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
524// let authenticator = CleartextStartupHandler::new(
525//     auth_source,
526//     Arc::new(DefaultServerParameterProvider::default())
527// );
528// ```
529//
530// For MD5 authentication:
531// ```rust
532// use pgwire::api::auth::md5::MD5StartupHandler;
533//
534// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
535// let authenticator = MD5StartupHandler::new(
536//     auth_source,
537//     Arc::new(DefaultServerParameterProvider::default())
538// );
539// ```
540//
541// For SCRAM authentication (requires "server-api-scram" feature):
542// ```rust
543// use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
544//
545// let auth_source = Arc::new(DfAuthSource::new(auth_manager));
546// let authenticator = SASLScramAuthStartupHandler::new(
547//     auth_source,
548//     Arc::new(DefaultServerParameterProvider::default())
549// );
550// ```
551
/// Simple `AuthSource` implementation that performs no real password check.
///
/// It hands back an empty password for any known user that is allowed to
/// log in, and always for the `postgres` user; every other login is rejected.
pub struct SimpleAuthSource {
    /// Shared manager that the user records are looked up in.
    auth_manager: Arc<AuthManager>,
}
556
557impl SimpleAuthSource {
558    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
559        SimpleAuthSource { auth_manager }
560    }
561}
562
563#[async_trait]
564impl AuthSource for SimpleAuthSource {
565    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
566        let username = login.user().unwrap_or("anonymous");
567
568        // Check if user exists and can login
569        if let Some(user) = self.auth_manager.get_user(username).await {
570            if user.can_login {
571                // Return empty password for now (no authentication required)
572                return Ok(Password::new(None, vec![]));
573            }
574        }
575
576        // For postgres user, always allow
577        if username == "postgres" {
578            return Ok(Password::new(None, vec![]));
579        }
580
581        // User not found or cannot login
582        Err(PgWireError::UserError(Box::new(
583            pgwire::error::ErrorInfo::new(
584                "FATAL".to_string(),
585                "28P01".to_string(), // invalid_password
586                format!("password authentication failed for user \"{username}\""),
587            ),
588        )))
589    }
590}
591
592/// Helper function to create auth source with auth manager
593pub fn create_auth_source(auth_manager: Arc<AuthManager>) -> SimpleAuthSource {
594    SimpleAuthSource::new(auth_manager)
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    // `AuthManager::new` seeds the default `postgres` user and role
    // synchronously, so the tests can query the manager right away without
    // sleeping.

    /// A fresh manager already contains the `postgres` user.
    #[tokio::test]
    async fn test_auth_manager_creation() {
        let auth_manager = AuthManager::new();

        let users = auth_manager.list_users().await;
        assert!(users.contains(&"postgres".to_string()));
    }

    /// The default user authenticates with an empty password; unknown users
    /// are rejected.
    #[tokio::test]
    async fn test_user_authentication() {
        let auth_manager = AuthManager::new();

        // Test postgres user authentication
        assert!(auth_manager.authenticate("postgres", "").await.unwrap());
        assert!(!auth_manager
            .authenticate("nonexistent", "password")
            .await
            .unwrap());
    }

    /// The default user holds its own role and, as superuser, any other role.
    #[tokio::test]
    async fn test_role_management() {
        let auth_manager = AuthManager::new();

        // Test role checking
        assert!(auth_manager.user_has_role("postgres", "postgres").await);
        assert!(auth_manager.user_has_role("postgres", "any_role").await); // superuser
    }
}