use std::future::Future;
use std::pin::Pin;
#[cfg(feature = "aes")]
use crate::Password;
/// An asynchronous source of an optional [`Password`].
///
/// Implementors must be `Send + Sync` so a provider can be shared across tasks.
#[cfg(feature = "aes")]
pub trait AsyncPasswordProvider: Send + Sync {
    /// Returns a boxed future resolving to the password, or `None` when none is available.
    ///
    /// The returned future may borrow `self` for the lifetime `'a`.
    fn get_password<'a>(&'a self)
        -> Pin<Box<dyn Future<Output = Option<Password>> + Send + 'a>>;
}
/// A provider that hands out a fixed, pre-supplied password (or none at all).
///
/// [`Default`] produces a provider holding no password, the same value as
/// [`AsyncPassword::none`].
#[cfg(feature = "aes")]
#[derive(Debug, Clone, Default)]
pub struct AsyncPassword {
    /// The password returned by every request; `None` means no password is available.
    password: Option<Password>,
}
#[cfg(feature = "aes")]
impl AsyncPassword {
    /// Builds a provider that always yields `password`.
    pub fn new(password: impl Into<Password>) -> Self {
        Self::from_option(Some(password.into()))
    }

    /// Builds a provider that never yields a password.
    pub fn none() -> Self {
        Self::from_option(None)
    }

    /// Builds a provider from a password that may or may not be present.
    pub fn from_option(password: Option<Password>) -> Self {
        Self { password }
    }
}
#[cfg(feature = "aes")]
impl AsyncPasswordProvider for AsyncPassword {
    /// Resolves immediately to a clone of the stored password, if any.
    ///
    /// The clone is taken when this method is called, so the returned future does not
    /// borrow `self`.
    fn get_password(&self) -> Pin<Box<dyn Future<Output = Option<Password>> + Send + '_>> {
        let snapshot = self.password.clone();
        Box::pin(std::future::ready(snapshot))
    }
}
/// A provider whose password is delivered later through a tokio oneshot channel,
/// for example after prompting a user.
#[cfg(feature = "aes")]
pub struct InteractivePasswordProvider {
    /// Receiver for the single password message.
    ///
    /// It is wrapped in `Option` so the first request can take it out, and behind an async
    /// mutex so `get_password(&self)` can take it through a shared reference.
    receiver:
        tokio::sync::Mutex<Option<tokio::sync::oneshot::Receiver<Option<Password>>>>,
}
#[cfg(feature = "aes")]
impl InteractivePasswordProvider {
    /// Creates a provider together with the sender that will deliver its password.
    ///
    /// Send `Some(password)`, or `None` to decline, through the returned sender.
    pub fn new() -> (tokio::sync::oneshot::Sender<Option<Password>>, Self) {
        let (sender, receiver) = tokio::sync::oneshot::channel();
        (sender, Self::from_receiver(receiver))
    }

    /// Wraps an existing oneshot receiver as a provider.
    pub fn from_receiver(receiver: tokio::sync::oneshot::Receiver<Option<Password>>) -> Self {
        let slot = tokio::sync::Mutex::new(Some(receiver));
        Self { receiver: slot }
    }
}
#[cfg(feature = "aes")]
impl AsyncPasswordProvider for InteractivePasswordProvider {
    /// Waits for the password sent through the paired oneshot sender.
    ///
    /// This is one-shot. The first call takes the receiver, and every later call resolves to
    /// `None`. It also resolves to `None` if the sender is dropped without sending, or if
    /// `None` was sent.
    fn get_password(&self) -> Pin<Box<dyn Future<Output = Option<Password>> + Send + '_>> {
        Box::pin(async move {
            // Take the receiver and release the lock before awaiting. The wait can last as long
            // as the user needs to answer. Holding the guard that long would block any
            // concurrent caller, only for it to receive `None` afterwards.
            let receiver = self.receiver.lock().await.take();
            match receiver {
                // A closed channel (RecvError) maps to `None`, as does an explicit `None`.
                Some(rx) => rx.await.ok().flatten(),
                None => None,
            }
        })
    }
}
#[cfg(feature = "aes")]
impl std::fmt::Debug for InteractivePasswordProvider {
    /// Formats as the bare type name, without the internal receiver.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("InteractivePasswordProvider")
    }
}
/// A provider that obtains the password from a user-supplied async callback.
///
/// The callback is invoked again on every password request.
#[cfg(feature = "aes")]
pub struct CallbackPasswordProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Option<Password>> + Send,
{
    /// Produces a fresh password future on each call.
    callback: F,
    // `Fut` appears only in the bound on `F`, and a where-clause bound does not count as a use
    // of a type parameter (E0392). This marker records it. `fn() -> Fut` is used rather than
    // `Fut` so the struct's `Send`/`Sync` depend only on `F`, not on whether the future type
    // itself is `Sync`.
    _future: std::marker::PhantomData<fn() -> Fut>,
}

#[cfg(feature = "aes")]
impl<F, Fut> CallbackPasswordProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Option<Password>> + Send,
{
    /// Wraps `callback` as a password provider.
    pub fn new(callback: F) -> Self {
        Self {
            callback,
            _future: std::marker::PhantomData,
        }
    }
}
#[cfg(feature = "aes")]
impl<F, Fut> AsyncPasswordProvider for CallbackPasswordProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Option<Password>> + Send + 'static,
{
    /// Calls the callback and boxes the future it returns.
    ///
    /// The callback itself runs synchronously inside this method; only the future it returns
    /// is deferred until awaited.
    fn get_password(&self) -> Pin<Box<dyn Future<Output = Option<Password>> + Send + '_>> {
        let callback = &self.callback;
        let pending: Fut = callback();
        Box::pin(pending)
    }
}
#[cfg(feature = "aes")]
impl<F, Fut> std::fmt::Debug for CallbackPasswordProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Option<Password>> + Send,
{
    /// Formats as the bare type name; the callback is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CallbackPasswordProvider")
    }
}
#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;

    /// Compile-time check that `T` can be shared across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn test_async_password_with_value() {
        let provider = AsyncPassword::new("test_password");
        let password = provider
            .get_password()
            .await
            .expect("provider should yield its password");
        assert_eq!(password.as_str(), "test_password");
    }

    #[tokio::test]
    async fn test_async_password_none() {
        let provider = AsyncPassword::none();
        assert!(provider.get_password().await.is_none());
    }

    #[tokio::test]
    async fn test_async_password_can_be_requested_repeatedly() {
        let provider = AsyncPassword::new("repeat");
        for _ in 0..2 {
            let password = provider
                .get_password()
                .await
                .expect("provider should yield its password every time");
            assert_eq!(password.as_str(), "repeat");
        }
    }

    #[tokio::test]
    async fn test_interactive_password_provider() {
        let (tx, provider) = InteractivePasswordProvider::new();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            tx.send(Some(Password::new("interactive_password"))).ok();
        });
        let password = provider
            .get_password()
            .await
            .expect("password sent through the channel should arrive");
        assert_eq!(password.as_str(), "interactive_password");
    }

    #[tokio::test]
    async fn test_interactive_password_provider_is_one_shot() {
        let (tx, provider) = InteractivePasswordProvider::new();
        assert!(tx.send(Some(Password::new("once"))).is_ok());
        let first = provider
            .get_password()
            .await
            .expect("first request should receive the password");
        assert_eq!(first.as_str(), "once");
        // The receiver was consumed by the first request.
        assert!(provider.get_password().await.is_none());
    }

    #[tokio::test]
    async fn test_interactive_password_provider_cancelled() {
        let (tx, provider) = InteractivePasswordProvider::new();
        drop(tx);
        assert!(provider.get_password().await.is_none());
    }

    #[tokio::test]
    async fn test_interactive_password_provider_declined() {
        let (tx, provider) = InteractivePasswordProvider::new();
        assert!(tx.send(None).is_ok());
        assert!(provider.get_password().await.is_none());
    }

    #[tokio::test]
    async fn test_callback_password_provider() {
        let provider =
            CallbackPasswordProvider::new(|| async { Some(Password::new("callback_password")) });
        let password = provider
            .get_password()
            .await
            .expect("callback should yield a password");
        assert_eq!(password.as_str(), "callback_password");
    }

    #[tokio::test]
    async fn test_callback_password_provider_invokes_callback_each_time() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let provider = CallbackPasswordProvider::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
            async { Some(Password::new("callback_password")) }
        });
        assert!(provider.get_password().await.is_some());
        assert!(provider.get_password().await.is_some());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_async_password_is_send_sync() {
        assert_send_sync::<AsyncPassword>();
    }

    #[test]
    fn test_interactive_password_provider_is_send_sync() {
        assert_send_sync::<InteractivePasswordProvider>();
    }
}