drasi_lib/identity/
application.rs1use super::{CredentialContext, Credentials, IdentityProvider};
24use anyhow::Result;
25use async_trait::async_trait;
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29
30type AsyncCredentialCallback = dyn Fn(&CredentialContext) -> Pin<Box<dyn Future<Output = Result<Credentials>> + Send>>
32 + Send
33 + Sync;
34
35#[derive(Clone)]
60pub struct ApplicationIdentityProvider {
61 callback: Arc<AsyncCredentialCallback>,
62}
63
64impl ApplicationIdentityProvider {
65 pub fn new<F, Fut>(callback: F) -> Self
70 where
71 F: Fn(&CredentialContext) -> Fut + Send + Sync + 'static,
72 Fut: Future<Output = Result<Credentials>> + Send + 'static,
73 {
74 let cb: Arc<AsyncCredentialCallback> =
75 Arc::new(move |ctx| Box::pin(callback(ctx)) as Pin<Box<_>>);
76 Self { callback: cb }
77 }
78
79 pub fn new_sync<F>(callback: F) -> Self
86 where
87 F: Fn(&CredentialContext) -> Result<Credentials> + Send + Sync + 'static,
88 {
89 let cb: Arc<AsyncCredentialCallback> = Arc::new(move |ctx| {
90 let result = callback(ctx);
91 Box::pin(async move { result })
92 });
93 Self { callback: cb }
94 }
95}
96
97impl std::fmt::Debug for ApplicationIdentityProvider {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("ApplicationIdentityProvider")
100 .finish_non_exhaustive()
101 }
102}
103
104#[async_trait]
105impl IdentityProvider for ApplicationIdentityProvider {
106 async fn get_credentials(&self, context: &CredentialContext) -> Result<Credentials> {
107 (self.callback)(context).await
108 }
109
110 fn clone_box(&self) -> Box<dyn IdentityProvider> {
111 Box::new(self.clone())
112 }
113}
114
#[cfg(test)]
mod tests {
    // Unit tests for `ApplicationIdentityProvider`: both constructors, context plumbing,
    // error propagation, clone sharing, and the redacting `Debug` impl.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn sync_closure_returns_username_password() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            }
        );
    }

    #[tokio::test]
    async fn async_closure_returns_token() {
        let provider = ApplicationIdentityProvider::new(|_ctx| async {
            Ok(Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            }
        );
    }

    #[tokio::test]
    async fn callback_observes_credential_context() {
        let provider = ApplicationIdentityProvider::new_sync(|ctx| {
            let host = ctx.get("hostname").unwrap_or("none").to_string();
            let port = ctx.get("port").unwrap_or("0").to_string();
            Ok(Credentials::UsernamePassword {
                username: format!("user@{host}:{port}"),
                password: "pw".into(),
            })
        });

        let ctx = CredentialContext::new()
            .with_property("hostname", "db.example.com")
            .with_property("port", "5432");

        let creds = provider.get_credentials(&ctx).await.unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "user@db.example.com:5432".into(),
                password: "pw".into(),
            }
        );
    }

    #[tokio::test]
    async fn async_callback_observes_credential_context() {
        let provider = ApplicationIdentityProvider::new(|ctx| {
            // The returned future must be 'static, so copy what it needs out of `ctx`
            // before the `async move` block.
            let host = ctx.get("hostname").unwrap_or("none").to_string();
            async move {
                Ok(Credentials::Token {
                    username: format!("svc@{host}"),
                    token: "abc.def.ghi".into(),
                })
            }
        });

        let ctx = CredentialContext::new().with_property("hostname", "db.example.com");

        let creds = provider.get_credentials(&ctx).await.unwrap();

        assert_eq!(
            creds,
            Credentials::Token {
                username: "svc@db.example.com".into(),
                token: "abc.def.ghi".into(),
            }
        );
    }

    #[tokio::test]
    async fn clone_box_shares_underlying_callback() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_cb = calls.clone();

        let provider = ApplicationIdentityProvider::new_sync(move |_ctx| {
            calls_for_cb.fetch_add(1, Ordering::SeqCst);
            Ok(Credentials::UsernamePassword {
                username: "u".into(),
                password: "p".into(),
            })
        });

        let cloned = provider.clone_box();

        provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();
        cloned
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn clones_share_one_callback_allocation() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "u".into(),
                password: "p".into(),
            })
        });
        assert_eq!(Arc::strong_count(&provider.callback), 1);

        // The shared counter in `clone_box_shares_underlying_callback` would also be bumped
        // by an independent copy of the closure, so check the `Arc` reference count instead.
        let cloned = provider.clone();
        let boxed = provider.clone_box();
        assert_eq!(Arc::strong_count(&provider.callback), 3);

        drop(boxed);
        drop(cloned);
        assert_eq!(Arc::strong_count(&provider.callback), 1);
    }

    #[tokio::test]
    async fn callback_error_is_propagated() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Err(anyhow::anyhow!("auth backend unavailable"))
        });

        let err = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap_err();

        assert!(err.to_string().contains("auth backend unavailable"));
    }

    #[tokio::test]
    async fn sync_closure_returns_certificate() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert!(creds.is_certificate());
        assert_eq!(
            creds,
            Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            }
        );
    }

    #[tokio::test]
    async fn debug_impl_does_not_leak_callback_state() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "should-not-appear".into(),
                password: "super-secret".into(),
            })
        });

        let formatted = format!("{provider:?}");

        assert!(formatted.contains("ApplicationIdentityProvider"));
        assert!(!formatted.contains("super-secret"));
        assert!(!formatted.contains("should-not-appear"));
    }
}