// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Identity providers for authentication credentials.

17use anyhow::Result;
18use async_trait::async_trait;
19
20/// Trait for identity providers that supply authentication credentials.
21///
22/// This is a plugin trait (Layer 3) — implementations return `anyhow::Result`
23/// and should use `.context()` for error chains. The framework wraps these
24/// into `DrasiError` at the public API boundary.
25#[async_trait]
26pub trait IdentityProvider: Send + Sync {
27    /// Fetch credentials for authentication.
28    async fn get_credentials(&self) -> Result<Credentials>;
29
30    /// Clone the provider into a boxed trait object.
31    fn clone_box(&self) -> Box<dyn IdentityProvider>;
32}
33
34impl Clone for Box<dyn IdentityProvider> {
35    fn clone(&self) -> Self {
36        self.clone_box()
37    }
38}
39
/// Authentication material produced by an [`IdentityProvider`].
///
/// `Debug` for this type is hand-written so that secret fields are redacted.
#[derive(PartialEq, Eq, Clone)]
pub enum Credentials {
    /// Traditional username and password authentication.
    UsernamePassword {
        /// Account name to log in as.
        username: String,
        /// Password for `username`.
        password: String,
    },
    /// Token-based authentication (Azure AD, AWS IAM, etc.).
    Token {
        /// Account name the token is presented for.
        username: String,
        /// Token used in place of a password.
        token: String,
    },
    /// Client certificate authentication (mTLS).
    ///
    /// For database connections that authenticate with a TLS client certificate
    /// rather than a password or token.
    Certificate {
        /// PEM-encoded client certificate.
        cert_pem: String,
        /// PEM-encoded private key.
        key_pem: String,
        /// Optional username (some databases require it alongside certificates).
        username: Option<String>,
    },
}

61// Manual Debug impl to redact sensitive fields (passwords, tokens, keys)
62impl std::fmt::Debug for Credentials {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Credentials::UsernamePassword { username, .. } => f
66                .debug_struct("UsernamePassword")
67                .field("username", username)
68                .field("password", &"[REDACTED]")
69                .finish(),
70            Credentials::Token { username, .. } => f
71                .debug_struct("Token")
72                .field("username", username)
73                .field("token", &"[REDACTED]")
74                .finish(),
75            Credentials::Certificate { username, .. } => f
76                .debug_struct("Certificate")
77                .field("cert_pem", &"[REDACTED]")
78                .field("key_pem", &"[REDACTED]")
79                .field("username", username)
80                .finish(),
81        }
82    }
83}
84
85impl Credentials {
86    /// Extract username and password/token for connection string building.
87    ///
88    /// Returns `Err(self)` if this is a `Certificate` variant.
89    pub fn try_into_auth_pair(self) -> std::result::Result<(String, String), Self> {
90        match self {
91            Credentials::UsernamePassword { username, password } => Ok((username, password)),
92            Credentials::Token { username, token } => Ok((username, token)),
93            other => Err(other),
94        }
95    }
96
97    /// Extract certificate and key for TLS client authentication.
98    ///
99    /// Returns `Ok((cert_pem, key_pem, optional_username))` for `Certificate` credentials,
100    /// or `Err(self)` for other variants.
101    pub fn try_into_certificate(
102        self,
103    ) -> std::result::Result<(String, String, Option<String>), Self> {
104        match self {
105            Credentials::Certificate {
106                cert_pem,
107                key_pem,
108                username,
109            } => Ok((cert_pem, key_pem, username)),
110            other => Err(other),
111        }
112    }
113
114    /// Extract username and password/token for connection string building.
115    ///
116    /// # Panics
117    /// Panics if called on `Certificate` credentials.
118    ///
119    /// # Deprecated
120    /// Use [`try_into_auth_pair`](Self::try_into_auth_pair) instead.
121    #[deprecated(note = "Use try_into_auth_pair() which returns Result instead of panicking")]
122    pub(crate) fn into_auth_pair(self) -> (String, String) {
123        self.try_into_auth_pair()
124            .unwrap_or_else(|_| panic!("Certificate credentials cannot be converted to an auth pair. Use try_into_auth_pair() or try_into_certificate() instead."))
125    }
126
127    /// Extract certificate and key for TLS client authentication.
128    ///
129    /// # Panics
130    /// Panics if called on non-Certificate credentials.
131    ///
132    /// # Deprecated
133    /// Use [`try_into_certificate`](Self::try_into_certificate) instead.
134    #[deprecated(note = "Use try_into_certificate() which returns Result instead of panicking")]
135    pub(crate) fn into_certificate(self) -> (String, String, Option<String>) {
136        self.try_into_certificate()
137            .unwrap_or_else(|_| panic!("Not certificate credentials. Use try_into_certificate() or try_into_auth_pair() instead."))
138    }
139
140    /// Returns `true` if this is a `Certificate` variant.
141    pub fn is_certificate(&self) -> bool {
142        matches!(self, Credentials::Certificate { .. })
143    }
144}
145
146mod password;
147pub use password::PasswordIdentityProvider;
148
149#[cfg(feature = "azure-identity")]
150mod azure;
151#[cfg(feature = "azure-identity")]
152pub use azure::AzureIdentityProvider;
153
154#[cfg(feature = "aws-identity")]
155mod aws;
156#[cfg(feature = "aws-identity")]
157pub use aws::AwsIdentityProvider;
158
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_password_provider() {
        let provider = PasswordIdentityProvider::new("testuser", "testpass");
        let credentials = provider.get_credentials().await.unwrap();

        match credentials {
            Credentials::UsernamePassword { username, password } => {
                assert_eq!(username, "testuser");
                assert_eq!(password, "testpass");
            }
            _ => panic!("Expected UsernamePassword credentials"),
        }
    }

    #[tokio::test]
    async fn test_provider_clone() {
        let provider: Box<dyn IdentityProvider> =
            Box::new(PasswordIdentityProvider::new("user", "pass"));
        let cloned = provider.clone();

        let credentials = cloned.get_credentials().await.unwrap();
        assert!(matches!(credentials, Credentials::UsernamePassword { .. }));
    }

    #[test]
    fn test_try_into_auth_pair_username_password() {
        let creds = Credentials::UsernamePassword {
            username: "user".into(),
            password: "pass".into(),
        };
        let (u, p) = creds.try_into_auth_pair().unwrap();
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
    }

    #[test]
    fn test_try_into_auth_pair_token() {
        let creds = Credentials::Token {
            username: "user".into(),
            token: "tok".into(),
        };
        let (u, t) = creds.try_into_auth_pair().unwrap();
        assert_eq!(u, "user");
        assert_eq!(t, "tok");
    }

    #[test]
    fn test_try_into_auth_pair_rejects_certificate() {
        let creds = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: None,
        };
        let result = creds.clone().try_into_auth_pair();
        assert!(result.is_err());
        // The original credentials must come back untouched in the Err.
        let returned = result.unwrap_err();
        assert!(returned.is_certificate());
        assert_eq!(returned, creds);
    }

    #[test]
    fn test_try_into_certificate_success() {
        let creds = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: Some("user".into()),
        };
        let (c, k, u) = creds.try_into_certificate().unwrap();
        assert_eq!(c, "cert");
        assert_eq!(k, "key");
        assert_eq!(u, Some("user".into()));
    }

    #[test]
    fn test_try_into_certificate_rejects_password() {
        let creds = Credentials::UsernamePassword {
            username: "user".into(),
            password: "pass".into(),
        };
        let result = creds.clone().try_into_certificate();
        assert_eq!(result, Err(creds));
    }

    #[test]
    fn test_try_into_certificate_rejects_token() {
        let creds = Credentials::Token {
            username: "user".into(),
            token: "tok".into(),
        };
        let result = creds.clone().try_into_certificate();
        assert_eq!(result, Err(creds));
    }

    #[test]
    fn test_is_certificate() {
        let cert = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: None,
        };
        let token = Credentials::Token {
            username: "user".into(),
            token: "tok".into(),
        };
        assert!(cert.is_certificate());
        assert!(!token.is_certificate());
    }

    // The Debug impl exists to keep secrets out of logs; pin that guarantee.
    #[test]
    fn test_debug_redacts_username_password() {
        let creds = Credentials::UsernamePassword {
            username: "alice".into(),
            password: "hunter2".into(),
        };
        let rendered = format!("{creds:?}");
        assert!(rendered.contains("alice"));
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("hunter2"));
    }

    #[test]
    fn test_debug_redacts_token() {
        let creds = Credentials::Token {
            username: "bob".into(),
            token: "s3cr3t-token".into(),
        };
        let rendered = format!("{creds:?}");
        assert!(rendered.contains("bob"));
        assert!(!rendered.contains("s3cr3t-token"));
    }

    #[test]
    fn test_debug_redacts_certificate_material() {
        let creds = Credentials::Certificate {
            cert_pem: "-----BEGIN CERTIFICATE-----".into(),
            key_pem: "-----BEGIN PRIVATE KEY-----".into(),
            username: Some("carol".into()),
        };
        let rendered = format!("{creds:?}");
        assert!(rendered.contains("carol"));
        assert!(!rendered.contains("BEGIN CERTIFICATE"));
        assert!(!rendered.contains("BEGIN PRIVATE KEY"));
    }

    #[test]
    #[allow(deprecated)]
    #[should_panic(expected = "Certificate credentials cannot be converted to an auth pair")]
    fn test_deprecated_into_auth_pair_panics_on_certificate() {
        let creds = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: None,
        };
        let _ = creds.into_auth_pair();
    }

    #[test]
    #[allow(deprecated)]
    #[should_panic(expected = "Not certificate credentials")]
    fn test_deprecated_into_certificate_panics_on_password() {
        let creds = Credentials::UsernamePassword {
            username: "user".into(),
            password: "pass".into(),
        };
        let _ = creds.into_certificate();
    }
}