use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// How a plan grants access to a single named feature.
///
/// Serialized with serde's default (externally tagged) enum representation and
/// snake_case variant names, e.g. `{"toggle":true}` or `{"limit":1000}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FeatureAccess {
    /// On/off feature; `true` means the feature is enabled.
    Toggle(bool),
    /// Numeric ceiling for a countable feature. A ceiling of `0` is treated as
    /// "no access" by `TierInfo::has_feature`.
    Limit(u64),
}
/// A plan (tier) together with the access it grants to each named feature.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TierInfo {
    /// Name of the plan.
    pub name: String,
    /// Access rules keyed by feature name. A feature missing from this map is
    /// treated as unavailable by every query method on `TierInfo`.
    pub features: HashMap<String, FeatureAccess>,
}
/// Source that resolves an owner's current plan.
///
/// `resolve` returns a boxed future rather than being an `async fn`, which
/// keeps the trait object-safe; `TierResolver` relies on that because it
/// stores an `Arc<dyn TierBackend>`. The `Send + Sync` supertraits allow one
/// backend instance to be shared across threads.
pub trait TierBackend: Send + Sync {
    /// Looks up the plan for `owner_id`.
    ///
    /// The elided `'_` in the return type binds to `&self`, so the returned
    /// future may borrow the backend but NOT `owner_id`. Implementations that
    /// need the id inside the future must copy it first
    /// (e.g. `let id = owner_id.to_owned();`).
    ///
    /// # Errors
    ///
    /// Implementation-defined; the error is passed through unchanged by
    /// `TierResolver::resolve`.
    fn resolve(
        &self,
        owner_id: &str,
    ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>>;
}
/// Cloneable handle around a shared [`TierBackend`].
///
/// Cloning only bumps the `Arc` reference count; every clone resolves through
/// the same backend instance.
#[derive(Clone)]
pub struct TierResolver(Arc<dyn TierBackend>);
impl TierInfo {
pub fn has_feature(&self, name: &str) -> bool {
match self.features.get(name) {
Some(FeatureAccess::Toggle(v)) => *v,
Some(FeatureAccess::Limit(v)) => *v > 0,
None => false,
}
}
pub fn is_enabled(&self, name: &str) -> bool {
matches!(self.features.get(name), Some(FeatureAccess::Toggle(true)))
}
pub fn limit(&self, name: &str) -> Option<u64> {
match self.features.get(name) {
Some(FeatureAccess::Limit(v)) => Some(*v),
_ => None,
}
}
pub fn limit_ceiling(&self, name: &str) -> Result<u64> {
match self.features.get(name) {
Some(FeatureAccess::Limit(v)) => Ok(*v),
Some(FeatureAccess::Toggle(_)) => {
Err(Error::internal(format!("Feature '{name}' is not a limit")))
}
None => Err(Error::forbidden(format!(
"Feature '{name}' is not available on your current plan"
))),
}
}
pub fn check_limit(&self, name: &str, current: u64) -> Result<()> {
let ceiling = self.limit_ceiling(name)?;
if current >= ceiling {
Err(Error::forbidden(format!(
"Limit exceeded for '{name}': {current}/{ceiling}"
)))
} else {
Ok(())
}
}
}
impl TierResolver {
pub fn from_backend(backend: Arc<dyn TierBackend>) -> Self {
Self(backend)
}
pub async fn resolve(&self, owner_id: &str) -> Result<TierInfo> {
self.0.resolve(owner_id).await
}
}
/// In-memory [`TierBackend`] implementations for tests.
///
/// NOTE: this module is compiled unconditionally. The attribute below only
/// relaxes the `dead_code` lint outside `cfg(test)` / the `test-helpers`
/// feature, so these types are part of the crate's public surface.
#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
pub mod test_support {
    use super::*;

    /// Backend that answers every lookup with a copy of one fixed tier,
    /// ignoring the owner id.
    pub struct StaticTierBackend {
        fixed_tier: TierInfo,
    }

    impl StaticTierBackend {
        /// Creates a backend that always resolves to `tier`.
        pub fn new(tier: TierInfo) -> Self {
            StaticTierBackend { fixed_tier: tier }
        }
    }

    impl TierBackend for StaticTierBackend {
        fn resolve(
            &self,
            _owner_id: &str,
        ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
            // Copy the tier up front so the future owns its result.
            let snapshot = self.fixed_tier.clone();
            Box::pin(async move { Ok(snapshot) })
        }
    }

    /// Backend whose every lookup fails with an internal error; useful for
    /// exercising error-propagation paths.
    pub struct FailingTierBackend;

    impl TierBackend for FailingTierBackend {
        fn resolve(
            &self,
            _owner_id: &str,
        ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
            Box::pin(async { Err(Error::internal("test: backend failure")) })
        }
    }
}
#[cfg(test)]
mod tests {
    // Reuse the shared test backends instead of redefining a local copy.
    use super::test_support::{FailingTierBackend, StaticTierBackend};
    use super::*;

    /// Free plan fixture covering every access shape: enabled and disabled
    /// toggles, a positive limit, and a zero limit.
    fn free_tier() -> TierInfo {
        TierInfo {
            name: "free".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(false)),
                ("api_calls".into(), FeatureAccess::Limit(1_000)),
                ("storage_mb".into(), FeatureAccess::Limit(0)),
            ]),
        }
    }

    /// Pro plan fixture with SSO enabled and a higher API limit.
    fn pro_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(true)),
                ("api_calls".into(), FeatureAccess::Limit(100_000)),
            ]),
        }
    }

    /// Resolver that always answers with `tier`.
    fn resolver_for(tier: TierInfo) -> TierResolver {
        TierResolver::from_backend(Arc::new(StaticTierBackend::new(tier)))
    }

    // --- has_feature ---

    #[test]
    fn has_feature_toggle_true() {
        assert!(free_tier().has_feature("basic_export"));
    }

    #[test]
    fn has_feature_toggle_false() {
        assert!(!free_tier().has_feature("sso"));
    }

    #[test]
    fn has_feature_limit_positive() {
        assert!(free_tier().has_feature("api_calls"));
    }

    #[test]
    fn has_feature_limit_zero() {
        assert!(!free_tier().has_feature("storage_mb"));
    }

    #[test]
    fn has_feature_missing() {
        assert!(!free_tier().has_feature("nonexistent"));
    }

    // --- is_enabled ---

    #[test]
    fn is_enabled_toggle_true() {
        assert!(pro_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_toggle_false() {
        assert!(!free_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_limit_returns_false() {
        assert!(!free_tier().is_enabled("api_calls"));
    }

    #[test]
    fn is_enabled_missing_returns_false() {
        assert!(!free_tier().is_enabled("nonexistent"));
    }

    // --- limit / limit_ceiling ---

    #[test]
    fn limit_returns_ceiling() {
        assert_eq!(free_tier().limit("api_calls"), Some(1_000));
    }

    #[test]
    fn limit_zero_returns_some_zero() {
        // A zero ceiling is still a limit, distinct from "absent".
        assert_eq!(free_tier().limit("storage_mb"), Some(0));
    }

    #[test]
    fn limit_toggle_returns_none() {
        assert_eq!(free_tier().limit("basic_export"), None);
    }

    #[test]
    fn limit_missing_returns_none() {
        assert_eq!(free_tier().limit("nonexistent"), None);
    }

    #[test]
    fn limit_ceiling_returns_value() {
        assert_eq!(free_tier().limit_ceiling("api_calls").unwrap(), 1_000);
    }

    // --- check_limit ---

    #[test]
    fn check_limit_under_ok() {
        assert!(free_tier().check_limit("api_calls", 500).is_ok());
    }

    #[test]
    fn check_limit_at_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 1_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_over_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 2_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_zero_ceiling_forbidden() {
        // With a zero ceiling even the first unit is rejected.
        let err = free_tier().check_limit("storage_mb", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_toggle_internal_error() {
        let err = free_tier().check_limit("basic_export", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn check_limit_missing_forbidden() {
        let err = free_tier().check_limit("nonexistent", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    // --- serde ---

    #[test]
    fn feature_access_wire_format() {
        // Pin the JSON shape so attribute changes on the enum are caught.
        assert_eq!(
            serde_json::to_string(&FeatureAccess::Toggle(true)).unwrap(),
            r#"{"toggle":true}"#
        );
        assert_eq!(
            serde_json::to_string(&FeatureAccess::Limit(5_000)).unwrap(),
            r#"{"limit":5000}"#
        );
    }

    #[test]
    fn feature_access_toggle_roundtrip() {
        let v = FeatureAccess::Toggle(true);
        let json = serde_json::to_string(&v).unwrap();
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn feature_access_limit_roundtrip() {
        let v = FeatureAccess::Limit(5_000);
        let json = serde_json::to_string(&v).unwrap();
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn tier_info_serde_roundtrip() {
        let tier = free_tier();
        let json = serde_json::to_string(&tier).unwrap();
        let back: TierInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "free");
        assert_eq!(back, tier);
    }

    // --- TierResolver ---

    #[tokio::test]
    async fn resolver_delegates_to_backend() {
        let info = resolver_for(pro_tier()).resolve("tenant_123").await.unwrap();
        assert_eq!(info.name, "pro");
        assert_eq!(info, pro_tier());
    }

    #[tokio::test]
    async fn resolver_clones_share_backend() {
        let resolver = resolver_for(free_tier());
        let cloned = resolver.clone();
        assert_eq!(
            cloned.resolve("tenant_a").await.unwrap(),
            resolver.resolve("tenant_b").await.unwrap()
        );
    }

    #[tokio::test]
    async fn resolver_propagates_backend_error() {
        let resolver = TierResolver::from_backend(Arc::new(FailingTierBackend));
        let err = resolver.resolve("tenant_123").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }
}