Skip to main content

modo/tier/
types.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Error, Result};
9
10/// Whether a feature is a boolean toggle or a usage limit.
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum FeatureAccess {
14    /// Feature is enabled or disabled.
15    Toggle(bool),
16    /// Feature has a usage limit ceiling.
17    Limit(u64),
18}
19
20/// Resolved tier information for an owner.
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
22pub struct TierInfo {
23    /// Plan name (e.g., "free", "pro", "enterprise").
24    pub name: String,
25    /// Feature map: feature name → access level.
26    pub features: HashMap<String, FeatureAccess>,
27}
28
29/// Backend trait for tier resolution. Object-safe.
30///
31/// The app implements this with its own storage/logic — the framework
32/// provides the trait, wrapper, middleware, and guards.
33pub trait TierBackend: Send + Sync {
34    /// Resolve tier information for the given owner.
35    ///
36    /// # Errors
37    ///
38    /// Implementation-defined. Errors are surfaced by
39    /// [`TierLayer`](super::TierLayer) as HTTP error responses.
40    fn resolve(
41        &self,
42        owner_id: &str,
43    ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>>;
44}
45
46/// Concrete wrapper around a [`TierBackend`]. `Arc` internally, cheap to clone.
47#[derive(Clone)]
48pub struct TierResolver(Arc<dyn TierBackend>);
49
50impl TierInfo {
51    /// Feature is available (Toggle=true or Limit>0).
52    pub fn has_feature(&self, name: &str) -> bool {
53        match self.features.get(name) {
54            Some(FeatureAccess::Toggle(v)) => *v,
55            Some(FeatureAccess::Limit(v)) => *v > 0,
56            None => false,
57        }
58    }
59
60    /// Feature is explicitly enabled (Toggle only, false for Limit or missing).
61    pub fn is_enabled(&self, name: &str) -> bool {
62        matches!(self.features.get(name), Some(FeatureAccess::Toggle(true)))
63    }
64
65    /// Get the limit ceiling (Limit only, None for Toggle or missing).
66    pub fn limit(&self, name: &str) -> Option<u64> {
67        match self.features.get(name) {
68            Some(FeatureAccess::Limit(v)) => Some(*v),
69            _ => None,
70        }
71    }
72
73    /// Get the limit ceiling, returning typed errors for missing or non-limit features.
74    ///
75    /// Returns `Ok(ceiling)` for `Limit` features.
76    ///
77    /// # Errors
78    ///
79    /// - [`Error::forbidden`](crate::Error::forbidden) if the feature is missing.
80    /// - [`Error::internal`](crate::Error::internal) if the feature is a `Toggle` (not a limit).
81    pub fn limit_ceiling(&self, name: &str) -> Result<u64> {
82        match self.features.get(name) {
83            Some(FeatureAccess::Limit(v)) => Ok(*v),
84            Some(FeatureAccess::Toggle(_)) => {
85                Err(Error::internal(format!("Feature '{name}' is not a limit")))
86            }
87            None => Err(Error::forbidden(format!(
88                "Feature '{name}' is not available on your current plan"
89            ))),
90        }
91    }
92
93    /// Check current usage against limit ceiling.
94    ///
95    /// Returns `Ok(())` if usage is under the limit.
96    ///
97    /// # Errors
98    ///
99    /// - [`Error::forbidden`](crate::Error::forbidden) if the feature is missing or usage >= limit.
100    /// - [`Error::internal`](crate::Error::internal) if the feature is a `Toggle` (not a limit).
101    pub fn check_limit(&self, name: &str, current: u64) -> Result<()> {
102        let ceiling = self.limit_ceiling(name)?;
103        if current >= ceiling {
104            Err(Error::forbidden(format!(
105                "Limit exceeded for '{name}': {current}/{ceiling}"
106            )))
107        } else {
108            Ok(())
109        }
110    }
111}
112
113impl TierResolver {
114    /// Create from a custom backend.
115    pub fn from_backend(backend: Arc<dyn TierBackend>) -> Self {
116        Self(backend)
117    }
118
119    /// Resolve tier information for an owner.
120    ///
121    /// # Errors
122    ///
123    /// Returns any error produced by the underlying [`TierBackend`].
124    pub async fn resolve(&self, owner_id: &str) -> Result<TierInfo> {
125        self.0.resolve(owner_id).await
126    }
127}
128
129/// Test helpers for the tier module.
130///
131/// Available when running tests or when the `test-helpers` feature is enabled.
132#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
133pub mod test_support {
134    use super::*;
135
136    /// In-memory backend that returns a fixed `TierInfo` for any owner ID.
137    pub struct StaticTierBackend {
138        tier: TierInfo,
139    }
140
141    impl StaticTierBackend {
142        /// Create a backend that always returns the given tier.
143        pub fn new(tier: TierInfo) -> Self {
144            Self { tier }
145        }
146    }
147
148    impl TierBackend for StaticTierBackend {
149        fn resolve(
150            &self,
151            _owner_id: &str,
152        ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
153            Box::pin(async { Ok(self.tier.clone()) })
154        }
155    }
156
157    /// In-memory backend that always returns an error.
158    pub struct FailingTierBackend;
159
160    impl TierBackend for FailingTierBackend {
161        fn resolve(
162            &self,
163            _owner_id: &str,
164        ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
165            Box::pin(async { Err(Error::internal("test: backend failure")) })
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    // Reuse the shared test backends instead of redefining them locally.
    use super::test_support::{FailingTierBackend, StaticTierBackend};
    use super::*;

    fn free_tier() -> TierInfo {
        TierInfo {
            name: "free".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(false)),
                ("api_calls".into(), FeatureAccess::Limit(1_000)),
                ("storage_mb".into(), FeatureAccess::Limit(0)),
            ]),
        }
    }

    fn pro_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(true)),
                ("api_calls".into(), FeatureAccess::Limit(100_000)),
            ]),
        }
    }

    // --- has_feature ---

    #[test]
    fn has_feature_toggle_true() {
        assert!(free_tier().has_feature("basic_export"));
    }

    #[test]
    fn has_feature_toggle_false() {
        assert!(!free_tier().has_feature("sso"));
    }

    #[test]
    fn has_feature_limit_positive() {
        assert!(free_tier().has_feature("api_calls"));
    }

    #[test]
    fn has_feature_limit_zero() {
        assert!(!free_tier().has_feature("storage_mb"));
    }

    #[test]
    fn has_feature_missing() {
        assert!(!free_tier().has_feature("nonexistent"));
    }

    // --- is_enabled ---

    #[test]
    fn is_enabled_toggle_true() {
        assert!(pro_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_toggle_false() {
        assert!(!free_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_limit_returns_false() {
        assert!(!free_tier().is_enabled("api_calls"));
    }

    #[test]
    fn is_enabled_missing_returns_false() {
        assert!(!free_tier().is_enabled("nonexistent"));
    }

    // --- limit ---

    #[test]
    fn limit_returns_ceiling() {
        assert_eq!(free_tier().limit("api_calls"), Some(1_000));
    }

    #[test]
    fn limit_toggle_returns_none() {
        assert_eq!(free_tier().limit("basic_export"), None);
    }

    #[test]
    fn limit_missing_returns_none() {
        assert_eq!(free_tier().limit("nonexistent"), None);
    }

    // --- limit_ceiling ---

    #[test]
    fn limit_ceiling_returns_ceiling() {
        assert_eq!(free_tier().limit_ceiling("api_calls").unwrap(), 1_000);
    }

    #[test]
    fn limit_ceiling_zero_is_ok() {
        assert_eq!(free_tier().limit_ceiling("storage_mb").unwrap(), 0);
    }

    #[test]
    fn limit_ceiling_toggle_internal_error() {
        let err = free_tier().limit_ceiling("basic_export").unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn limit_ceiling_missing_forbidden() {
        let err = free_tier().limit_ceiling("nonexistent").unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    // --- check_limit ---

    #[test]
    fn check_limit_under_ok() {
        assert!(free_tier().check_limit("api_calls", 500).is_ok());
    }

    #[test]
    fn check_limit_just_below_ceiling_ok() {
        assert!(free_tier().check_limit("api_calls", 999).is_ok());
    }

    #[test]
    fn check_limit_at_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 1_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_over_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 2_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_zero_ceiling_forbidden() {
        let err = free_tier().check_limit("storage_mb", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_toggle_internal_error() {
        let err = free_tier().check_limit("basic_export", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn check_limit_missing_forbidden() {
        let err = free_tier().check_limit("nonexistent", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    // --- FeatureAccess / TierInfo serde ---

    #[test]
    fn feature_access_toggle_roundtrip() {
        let v = FeatureAccess::Toggle(true);
        let json = serde_json::to_string(&v).unwrap();
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn feature_access_limit_roundtrip() {
        let v = FeatureAccess::Limit(5_000);
        let json = serde_json::to_string(&v).unwrap();
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn feature_access_wire_format_is_snake_case_externally_tagged() {
        assert_eq!(
            serde_json::to_value(FeatureAccess::Toggle(true)).unwrap(),
            serde_json::json!({ "toggle": true })
        );
        assert_eq!(
            serde_json::to_value(FeatureAccess::Limit(5)).unwrap(),
            serde_json::json!({ "limit": 5 })
        );
    }

    #[test]
    fn tier_info_serde_roundtrip() {
        let tier = free_tier();
        let json = serde_json::to_string(&tier).unwrap();
        let back: TierInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(back, tier);
    }

    // --- TierResolver ---

    #[tokio::test]
    async fn resolver_delegates_to_backend() {
        let resolver = TierResolver::from_backend(Arc::new(StaticTierBackend::new(pro_tier())));
        let info = resolver.resolve("tenant_123").await.unwrap();
        assert_eq!(info, pro_tier());
    }

    #[tokio::test]
    async fn resolver_propagates_backend_error() {
        let resolver = TierResolver::from_backend(Arc::new(FailingTierBackend));
        let err = resolver.resolve("tenant_123").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }
}