rustauth-plugins 0.2.0

Official RustAuth plugin modules.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::task::Poll;

use rustauth_core::error::RustAuthError;
use rustauth_core::options::SecondaryStorage;
use time::OffsetDateTime;

use super::keys::{storage_key_by_hash, storage_key_by_id, storage_key_by_reference};
use super::listing::{list_from_secondary_storage, ListOptions, ListResult};
use super::ApiKeyStore;
use crate::api_key::models::ApiKeyRecord;

/// Upper bound on secondary-storage operations that the bounded get/set helpers
/// keep in flight at the same time.
const STORAGE_CONCURRENCY: usize = 10;

/// Number of read-modify-write rounds `mutate_ref_index` tries before giving up
/// with a conflict error.
const REF_INDEX_MAX_CAS_ATTEMPTS: usize = 32;

/// A boxed, pinned, `Send` storage future that may borrow data for `'a`; boxing
/// lets futures of different concrete types share one in-flight list.
type StorageFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, RustAuthError>> + 'a + Send>>;

impl ApiKeyStore<'_> {
    /// Lists the API keys recorded for `reference_id` in secondary storage.
    ///
    /// When no secondary storage is configured this is not an error: an empty
    /// result with `total == 0` is returned instead.
    pub(super) async fn list_secondary(
        &self,
        reference_id: &str,
        options: ListOptions,
    ) -> Result<ListResult, RustAuthError> {
        let storage = match self.secondary_storage() {
            Some(storage) => storage,
            None => {
                return Ok(ListResult {
                    api_keys: Vec::new(),
                    total: 0,
                })
            }
        };
        list_from_secondary_storage(&*storage, reference_id, &options).await
    }

    /// Writes `api_key` to secondary storage (see the module-level `set_secondary`).
    ///
    /// # Errors
    ///
    /// Returns `RustAuthError::Adapter` when no secondary storage is configured,
    /// and propagates any storage or serialization error.
    pub(super) async fn set_secondary(&self, api_key: &ApiKeyRecord) -> Result<(), RustAuthError> {
        let storage = self.secondary_storage().ok_or_else(|| {
            RustAuthError::Adapter(
                "secondary storage is required for API key secondary-storage mode".to_owned(),
            )
        })?;
        set_secondary(&*storage, api_key, self.options.fallback_to_database).await
    }

    /// Removes `api_key` from secondary storage.
    ///
    /// Deletes the by-hash and by-id entries. The per-reference index is then
    /// deleted outright when `fallback_to_database` is set; otherwise only this
    /// key's id is removed from it.
    ///
    /// # Errors
    ///
    /// Returns `RustAuthError::Adapter` when no secondary storage is configured,
    /// and propagates the first storage error encountered.
    pub(super) async fn delete_secondary(
        &self,
        api_key: &ApiKeyRecord,
    ) -> Result<(), RustAuthError> {
        let storage = self.secondary_storage().ok_or_else(|| {
            RustAuthError::Adapter(
                "secondary storage is required for API key secondary-storage mode".to_owned(),
            )
        })?;

        storage.delete(&storage_key_by_hash(&api_key.key)).await?;
        storage.delete(&storage_key_by_id(&api_key.id)).await?;

        let ref_key = storage_key_by_reference(&api_key.reference_id);
        if self.options.fallback_to_database {
            storage.delete(&ref_key).await?;
        } else {
            let removed_id = &api_key.id;
            mutate_ref_index(storage.as_ref(), &ref_key, |ids| {
                ids.retain(|id| id != removed_id)
            })
            .await?;
        }
        Ok(())
    }
}

/// Fetches the API key records stored under each id in `ids`.
///
/// At most `STORAGE_CONCURRENCY` lookups are kept in flight at once. Records are
/// returned in the same order as `ids`. Ids whose entry is missing, or whose
/// stored value does not decode (see `get_secondary`), are skipped.
///
/// # Errors
///
/// Returns the first storage error observed. Lookups still in flight at that
/// point are dropped.
pub(super) async fn get_secondary_bounded(
    storage: &dyn SecondaryStorage,
    ids: Vec<String>,
) -> Result<Vec<ApiKeyRecord>, RustAuthError> {
    // One slot per requested id. Futures complete in arbitrary order and
    // `poll_next_ready` reshuffles the in-flight list, so results are placed by
    // input position to keep the output order deterministic and equal to `ids`.
    let mut slots: Vec<Option<ApiKeyRecord>> =
        std::iter::repeat_with(|| None).take(ids.len()).collect();
    let mut pending = ids.into_iter().enumerate();
    let mut in_flight: Vec<StorageFuture<'_, (usize, Option<ApiKeyRecord>)>> = Vec::new();
    fill_secondary_gets(storage, &mut pending, &mut in_flight);
    while !in_flight.is_empty() {
        let (index, api_key) = poll_next_ready(&mut in_flight).await?;
        slots[index] = api_key;
        fill_secondary_gets(storage, &mut pending, &mut in_flight);
    }
    Ok(slots.into_iter().flatten().collect())
}

/// Tops `in_flight` up to `STORAGE_CONCURRENCY` by-id lookups.
///
/// Each lookup is tagged with the position of its id in the original request.
fn fill_secondary_gets<'a>(
    storage: &'a dyn SecondaryStorage,
    ids: &mut std::iter::Enumerate<std::vec::IntoIter<String>>,
    in_flight: &mut Vec<StorageFuture<'a, (usize, Option<ApiKeyRecord>)>>,
) {
    while in_flight.len() < STORAGE_CONCURRENCY {
        let Some((index, id)) = ids.next() else {
            break;
        };
        in_flight.push(Box::pin(async move {
            get_secondary(storage, &storage_key_by_id(&id))
                .await
                .map(|api_key| (index, api_key))
        }));
    }
}

/// Runs `set_secondary` for every record in `api_keys`, with at most
/// `STORAGE_CONCURRENCY` writes outstanding at a time.
///
/// # Errors
///
/// Stops at the first failed write and returns its error. Writes that were
/// still in flight at that point are dropped.
pub(super) async fn set_secondary_bounded(
    storage: &dyn SecondaryStorage,
    api_keys: &[ApiKeyRecord],
    fallback_to_database: bool,
) -> Result<(), RustAuthError> {
    let mut queued = api_keys.iter();
    let mut running: Vec<StorageFuture<'_, ()>> = Vec::new();
    loop {
        fill_secondary_sets(storage, &mut queued, fallback_to_database, &mut running);
        if running.is_empty() {
            return Ok(());
        }
        poll_next_ready(&mut running).await?;
    }
}

/// Queues `set_secondary` writes for the next records from `api_keys`.
///
/// Stops once `in_flight` holds `STORAGE_CONCURRENCY` futures or the records
/// run out, whichever comes first.
fn fill_secondary_sets<'a>(
    storage: &'a dyn SecondaryStorage,
    api_keys: &mut std::slice::Iter<'a, ApiKeyRecord>,
    fallback_to_database: bool,
    in_flight: &mut Vec<StorageFuture<'a, ()>>,
) {
    // `take` on a by-ref iterator never pulls more than `free_slots` records,
    // so untaken records stay queued for the next call.
    let free_slots = STORAGE_CONCURRENCY.saturating_sub(in_flight.len());
    for record in api_keys.by_ref().take(free_slots) {
        in_flight.push(Box::pin(async move {
            set_secondary(storage, record, fallback_to_database).await
        }));
    }
}

/// Waits for the first future in `in_flight` to complete and returns its output.
///
/// Futures are polled front to back and polling stops at the first one that is
/// ready. That future is removed with `swap_remove`, so the relative order of
/// the remaining entries is not preserved.
///
/// Callers must only call this with a non-empty `in_flight`. With no futures to
/// poll it returns `Pending` without ever being woken.
async fn poll_next_ready<'a, T>(
    in_flight: &mut Vec<StorageFuture<'a, T>>,
) -> Result<T, RustAuthError> {
    std::future::poll_fn(|cx| {
        let finished = in_flight
            .iter_mut()
            .enumerate()
            .find_map(|(slot, fut)| match fut.as_mut().poll(cx) {
                Poll::Ready(output) => Some((slot, output)),
                Poll::Pending => None,
            });
        match finished {
            Some((slot, output)) => {
                drop(in_flight.swap_remove(slot));
                Poll::Ready(output)
            }
            None => Poll::Pending,
        }
    })
    .await
}

async fn set_secondary(
    storage: &dyn SecondaryStorage,
    api_key: &ApiKeyRecord,
    fallback_to_database: bool,
) -> Result<(), RustAuthError> {
    let ttl = ttl_seconds(api_key);
    let serialized = serde_json::to_string(api_key)
        .map_err(|error| RustAuthError::Adapter(error.to_string()))?;
    storage
        .set(&storage_key_by_hash(&api_key.key), serialized.clone(), ttl)
        .await?;
    storage
        .set(&storage_key_by_id(&api_key.id), serialized, ttl)
        .await?;
    let ref_key = storage_key_by_reference(&api_key.reference_id);
    if fallback_to_database {
        storage.delete(&ref_key).await?;
        return Ok(());
    }
    mutate_ref_index(storage, &ref_key, |ids| {
        if !ids.iter().any(|id| id == &api_key.id) {
            ids.push(api_key.id.clone());
        }
    })
    .await
}

/// Reads and decodes the `ApiKeyRecord` stored under `key`.
///
/// Returns `Ok(None)` both when the key is absent and when its value does not
/// deserialize as an `ApiKeyRecord`. Only storage errors are propagated.
pub(super) async fn get_secondary(
    storage: &dyn SecondaryStorage,
    key: &str,
) -> Result<Option<ApiKeyRecord>, RustAuthError> {
    let stored = storage.get(key).await?;
    Ok(stored.and_then(|json| serde_json::from_str::<ApiKeyRecord>(&json).ok()))
}

/// Computes the storage TTL, in whole seconds, for `api_key`'s cache entries.
///
/// Returns `None` only when the key has no `expires_at`, in which case entries
/// are stored without a TTL. Otherwise the remaining lifetime is rounded *up*
/// to a whole second and clamped to at least one second.
///
/// Previously, a key with under one second left (or already expired) produced
/// `None`: `whole_seconds` truncates to `<= 0` and that value was filtered out.
/// Its entries were then written with no TTL and never evicted. The clamp
/// guarantees any key with an expiry always gets one.
fn ttl_seconds(api_key: &ApiKeyRecord) -> Option<u64> {
    let remaining = api_key.expires_at? - OffsetDateTime::now_utc();
    let mut seconds = remaining.whole_seconds();
    // Round a positive fractional second up so the entry is never evicted
    // before the key itself expires.
    if remaining.subsec_nanoseconds() > 0 {
        seconds = seconds.saturating_add(1);
    }
    Some(seconds.max(1).unsigned_abs())
}

/// Applies `mutate` to the JSON list of key ids stored at `ref_key`, retrying
/// on conflicts.
///
/// Each round does the following:
/// 1. Reads the current value. An absent or non-JSON value is treated as an
///    empty list.
/// 2. Applies `mutate` to the decoded ids.
/// 3. Writes the result back with a conditional operation keyed on the value
///    just read:
///    - an empty list uses `delete_if_value`;
///    - a non-empty list uses `compare_and_set` with no TTL.
///
/// A `false` from that write (presumably a concurrent modification; semantics
/// live in `SecondaryStorage`) yields to the runtime and starts another round.
/// `mutate` may run once per round, hence the `Fn` bound.
///
/// # Errors
///
/// - Propagates storage errors.
/// - Returns `RustAuthError::Adapter` if re-serialization fails.
/// - Returns `RustAuthError::Adapter` if `REF_INDEX_MAX_CAS_ATTEMPTS` rounds all
///   lose the conditional write.
pub(super) async fn mutate_ref_index<F>(
    storage: &dyn SecondaryStorage,
    ref_key: &str,
    mutate: F,
) -> Result<(), RustAuthError>
where
    F: Fn(&mut Vec<String>),
{
    for _ in 0..REF_INDEX_MAX_CAS_ATTEMPTS {
        let current = storage.get(ref_key).await?;
        let mut ids: Vec<String> = match current.as_deref() {
            Some(raw) => serde_json::from_str(raw).unwrap_or_default(),
            None => Vec::new(),
        };
        mutate(&mut ids);

        let committed = if ids.is_empty() {
            storage.delete_if_value(ref_key, current).await?
        } else {
            let encoded = serde_json::to_string(&ids)
                .map_err(|error| RustAuthError::Adapter(error.to_string()))?;
            storage
                .compare_and_set(ref_key, current, encoded, None)
                .await?
        };
        if committed {
            return Ok(());
        }
        tokio::task::yield_now().await;
    }
    Err(RustAuthError::Adapter(
        "api key secondary-storage reference index update conflicted too many times".to_owned(),
    ))
}