Skip to main content

drasi_lib/identity/
application.rs

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! In-process identity provider that delegates credential acquisition to a
//! host-supplied closure.
//!
//! This is the identity-provider counterpart to `ApplicationSource`,
//! `ApplicationReaction`, and `ApplicationBootstrapProvider`: it lets a host
//! application reuse its existing authentication code (Azure AD, AWS, Vault,
//! etc.) instead of configuring a separate Drasi identity-provider plugin.
22
23use super::{CredentialContext, Credentials, IdentityProvider};
24use anyhow::Result;
25use async_trait::async_trait;
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29
30/// Boxed async credential callback that a user-supplied closure is stored as.
31type AsyncCredentialCallback = dyn Fn(&CredentialContext) -> Pin<Box<dyn Future<Output = Result<Credentials>> + Send>>
32    + Send
33    + Sync;
34
35/// Identity provider whose `get_credentials` calls a host-supplied closure.
36///
37/// Use [`ApplicationIdentityProvider::new`] for an async closure (typical for
38/// real token-acquisition flows) or [`ApplicationIdentityProvider::new_sync`]
39/// for a synchronous closure (handy for tests and static credentials).
40///
41/// # Example
42///
43/// ```no_run
44/// use std::sync::Arc;
45/// use drasi_lib::identity::{ApplicationIdentityProvider, Credentials};
46///
47/// let provider = ApplicationIdentityProvider::new(|ctx| {
48///     let host = ctx.get("hostname").unwrap_or("default").to_string();
49///     async move {
50///         // call into your existing auth code here
51///         Ok(Credentials::UsernamePassword {
52///             username: format!("user@{host}"),
53///             password: "secret".into(),
54///         })
55///     }
56/// });
57/// let _provider: Arc<dyn drasi_lib::identity::IdentityProvider> = Arc::new(provider);
58/// ```
59#[derive(Clone)]
60pub struct ApplicationIdentityProvider {
61    callback: Arc<AsyncCredentialCallback>,
62}
63
64impl ApplicationIdentityProvider {
65    /// Create a provider backed by an async closure.
66    ///
67    /// See also [`ApplicationIdentityProvider::new_sync`] for synchronous
68    /// callbacks.
69    pub fn new<F, Fut>(callback: F) -> Self
70    where
71        F: Fn(&CredentialContext) -> Fut + Send + Sync + 'static,
72        Fut: Future<Output = Result<Credentials>> + Send + 'static,
73    {
74        let cb: Arc<AsyncCredentialCallback> =
75            Arc::new(move |ctx| Box::pin(callback(ctx)) as Pin<Box<_>>);
76        Self { callback: cb }
77    }
78
79    /// Create a provider backed by a synchronous closure.
80    ///
81    /// Convenience wrapper for callbacks that don't need to await — the
82    /// closure result is wrapped in a ready future.
83    ///
84    /// See also [`ApplicationIdentityProvider::new`] for async callbacks.
85    pub fn new_sync<F>(callback: F) -> Self
86    where
87        F: Fn(&CredentialContext) -> Result<Credentials> + Send + Sync + 'static,
88    {
89        let cb: Arc<AsyncCredentialCallback> = Arc::new(move |ctx| {
90            let result = callback(ctx);
91            Box::pin(async move { result })
92        });
93        Self { callback: cb }
94    }
95}
96
97impl std::fmt::Debug for ApplicationIdentityProvider {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("ApplicationIdentityProvider")
100            .finish_non_exhaustive()
101    }
102}
103
104#[async_trait]
105impl IdentityProvider for ApplicationIdentityProvider {
106    async fn get_credentials(&self, context: &CredentialContext) -> Result<Credentials> {
107        (self.callback)(context).await
108    }
109
110    fn clone_box(&self) -> Box<dyn IdentityProvider> {
111        Box::new(self.clone())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn sync_closure_returns_username_password() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            }
        );
    }

    #[tokio::test]
    async fn async_closure_returns_token() {
        let provider = ApplicationIdentityProvider::new(|_ctx| async {
            Ok(Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            }
        );
    }

    /// The async constructor must forward the caller's context to the closure,
    /// just like the sync one (covered by `callback_observes_credential_context`).
    #[tokio::test]
    async fn async_closure_observes_credential_context() {
        let provider = ApplicationIdentityProvider::new(|ctx| {
            // The returned future must be 'static, so copy what it needs out of
            // the borrowed context before moving into the async block.
            let host = ctx.get("hostname").unwrap_or("none").to_string();
            async move {
                Ok(Credentials::Token {
                    username: format!("svc@{host}"),
                    token: "tok".into(),
                })
            }
        });

        let ctx = CredentialContext::new().with_property("hostname", "db.example.com");

        let creds = provider.get_credentials(&ctx).await.unwrap();

        assert_eq!(
            creds,
            Credentials::Token {
                username: "svc@db.example.com".into(),
                token: "tok".into(),
            }
        );
    }

    #[tokio::test]
    async fn callback_observes_credential_context() {
        let provider = ApplicationIdentityProvider::new_sync(|ctx| {
            let host = ctx.get("hostname").unwrap_or("none").to_string();
            let port = ctx.get("port").unwrap_or("0").to_string();
            Ok(Credentials::UsernamePassword {
                username: format!("user@{host}:{port}"),
                password: "pw".into(),
            })
        });

        let ctx = CredentialContext::new()
            .with_property("hostname", "db.example.com")
            .with_property("port", "5432");

        let creds = provider.get_credentials(&ctx).await.unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "user@db.example.com:5432".into(),
                password: "pw".into(),
            }
        );
    }

    #[tokio::test]
    async fn clone_box_shares_underlying_callback() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_cb = calls.clone();

        let provider = ApplicationIdentityProvider::new_sync(move |_ctx| {
            calls_for_cb.fetch_add(1, Ordering::SeqCst);
            Ok(Credentials::UsernamePassword {
                username: "u".into(),
                password: "p".into(),
            })
        });

        let cloned = provider.clone_box();

        provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();
        cloned
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn callback_error_is_propagated() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Err(anyhow::anyhow!("auth backend unavailable"))
        });

        let err = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap_err();

        assert!(err.to_string().contains("auth backend unavailable"));
    }

    /// Errors produced inside an async callback's future must surface
    /// unchanged from `get_credentials`, mirroring the sync error test above.
    #[tokio::test]
    async fn async_closure_error_is_propagated() {
        let provider = ApplicationIdentityProvider::new(|_ctx| async {
            Err::<Credentials, _>(anyhow::anyhow!("token endpoint timed out"))
        });

        let err = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap_err();

        assert!(err.to_string().contains("token endpoint timed out"));
    }

    #[tokio::test]
    async fn sync_closure_returns_certificate() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert!(creds.is_certificate());
        assert_eq!(
            creds,
            Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            }
        );
    }

    #[tokio::test]
    async fn debug_impl_does_not_leak_callback_state() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "should-not-appear".into(),
                password: "super-secret".into(),
            })
        });

        let formatted = format!("{provider:?}");

        assert!(formatted.contains("ApplicationIdentityProvider"));
        // The Debug impl is intentionally opaque — it must not surface
        // closure-captured state or anything the closure might return.
        assert!(!formatted.contains("super-secret"));
        assert!(!formatted.contains("should-not-appear"));
    }
}