1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Error, Result};
9
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum FeatureAccess {
14 Toggle(bool),
16 Limit(u64),
18}
19
/// The tier (plan) resolved for an owner, together with its feature grants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TierInfo {
    /// Tier identifier, e.g. `"free"` or `"pro"`.
    pub name: String,
    /// Feature name -> how that feature is granted. Features missing from this
    /// map are treated as unavailable by the `TierInfo` accessor methods.
    pub features: HashMap<String, FeatureAccess>,
}
28
/// Pluggable source of an owner's [`TierInfo`], consumed through [`TierResolver`].
///
/// The `Send + Sync` bound lets a single backend be shared behind
/// `Arc<dyn TierBackend>`.
pub trait TierBackend: Send + Sync {
    /// Looks up the tier for `owner_id`.
    ///
    /// Returns a boxed future instead of being an `async fn` so the trait stays
    /// usable as `dyn TierBackend`. The elided `'_` ties the future to `&self`
    /// only, not to `owner_id`, so an implementation that needs `owner_id`
    /// inside the returned future must copy it first
    /// (e.g. `let owner_id = owner_id.to_owned();`).
    ///
    /// # Errors
    ///
    /// Implementation-defined; `TierResolver::resolve` passes them through
    /// unchanged.
    fn resolve(
        &self,
        owner_id: &str,
    ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>>;
}
45
46#[derive(Clone)]
48pub struct TierResolver(Arc<dyn TierBackend>);
49
50impl TierInfo {
51 pub fn has_feature(&self, name: &str) -> bool {
53 match self.features.get(name) {
54 Some(FeatureAccess::Toggle(v)) => *v,
55 Some(FeatureAccess::Limit(v)) => *v > 0,
56 None => false,
57 }
58 }
59
60 pub fn is_enabled(&self, name: &str) -> bool {
62 matches!(self.features.get(name), Some(FeatureAccess::Toggle(true)))
63 }
64
65 pub fn limit(&self, name: &str) -> Option<u64> {
67 match self.features.get(name) {
68 Some(FeatureAccess::Limit(v)) => Some(*v),
69 _ => None,
70 }
71 }
72
73 pub fn limit_ceiling(&self, name: &str) -> Result<u64> {
82 match self.features.get(name) {
83 Some(FeatureAccess::Limit(v)) => Ok(*v),
84 Some(FeatureAccess::Toggle(_)) => {
85 Err(Error::internal(format!("Feature '{name}' is not a limit")))
86 }
87 None => Err(Error::forbidden(format!(
88 "Feature '{name}' is not available on your current plan"
89 ))),
90 }
91 }
92
93 pub fn check_limit(&self, name: &str, current: u64) -> Result<()> {
102 let ceiling = self.limit_ceiling(name)?;
103 if current >= ceiling {
104 Err(Error::forbidden(format!(
105 "Limit exceeded for '{name}': {current}/{ceiling}"
106 )))
107 } else {
108 Ok(())
109 }
110 }
111}
112
113impl TierResolver {
114 pub fn from_backend(backend: Arc<dyn TierBackend>) -> Self {
116 Self(backend)
117 }
118
119 pub async fn resolve(&self, owner_id: &str) -> Result<TierInfo> {
125 self.0.resolve(owner_id).await
126 }
127}
128
129#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
133pub mod test_support {
134 use super::*;
135
136 pub struct StaticTierBackend {
138 tier: TierInfo,
139 }
140
141 impl StaticTierBackend {
142 pub fn new(tier: TierInfo) -> Self {
144 Self { tier }
145 }
146 }
147
148 impl TierBackend for StaticTierBackend {
149 fn resolve(
150 &self,
151 _owner_id: &str,
152 ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
153 Box::pin(async { Ok(self.tier.clone()) })
154 }
155 }
156
157 pub struct FailingTierBackend;
159
160 impl TierBackend for FailingTierBackend {
161 fn resolve(
162 &self,
163 _owner_id: &str,
164 ) -> Pin<Box<dyn Future<Output = Result<TierInfo>> + Send + '_>> {
165 Box::pin(async { Err(Error::internal("test: backend failure")) })
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::test_support::{FailingTierBackend, StaticTierBackend};
    use super::*;

    fn free_tier() -> TierInfo {
        TierInfo {
            name: "free".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(false)),
                ("api_calls".into(), FeatureAccess::Limit(1_000)),
                ("storage_mb".into(), FeatureAccess::Limit(0)),
            ]),
        }
    }

    fn pro_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([
                ("basic_export".into(), FeatureAccess::Toggle(true)),
                ("sso".into(), FeatureAccess::Toggle(true)),
                ("api_calls".into(), FeatureAccess::Limit(100_000)),
            ]),
        }
    }

    // --- has_feature ---

    #[test]
    fn has_feature_toggle_true() {
        assert!(free_tier().has_feature("basic_export"));
    }

    #[test]
    fn has_feature_toggle_false() {
        assert!(!free_tier().has_feature("sso"));
    }

    #[test]
    fn has_feature_limit_positive() {
        assert!(free_tier().has_feature("api_calls"));
    }

    #[test]
    fn has_feature_limit_zero() {
        assert!(!free_tier().has_feature("storage_mb"));
    }

    #[test]
    fn has_feature_missing() {
        assert!(!free_tier().has_feature("nonexistent"));
    }

    // --- is_enabled ---

    #[test]
    fn is_enabled_toggle_true() {
        assert!(pro_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_toggle_false() {
        assert!(!free_tier().is_enabled("sso"));
    }

    #[test]
    fn is_enabled_limit_returns_false() {
        assert!(!free_tier().is_enabled("api_calls"));
    }

    #[test]
    fn is_enabled_missing_returns_false() {
        assert!(!free_tier().is_enabled("nonexistent"));
    }

    // --- limit ---

    #[test]
    fn limit_returns_ceiling() {
        assert_eq!(free_tier().limit("api_calls"), Some(1_000));
    }

    #[test]
    fn limit_toggle_returns_none() {
        assert_eq!(free_tier().limit("basic_export"), None);
    }

    #[test]
    fn limit_missing_returns_none() {
        assert_eq!(free_tier().limit("nonexistent"), None);
    }

    // --- limit_ceiling / check_limit ---

    #[test]
    fn limit_ceiling_returns_value() {
        assert_eq!(free_tier().limit_ceiling("api_calls").unwrap(), 1_000);
    }

    #[test]
    fn check_limit_under_ok() {
        assert!(free_tier().check_limit("api_calls", 500).is_ok());
    }

    #[test]
    fn check_limit_at_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 1_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_over_ceiling_forbidden() {
        let err = free_tier().check_limit("api_calls", 2_000).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_zero_ceiling_forbidden() {
        // A zero ceiling rejects even the very first use.
        let err = free_tier().check_limit("storage_mb", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_limit_toggle_internal_error() {
        let err = free_tier().check_limit("basic_export", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn check_limit_missing_forbidden() {
        let err = free_tier().check_limit("nonexistent", 0).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::FORBIDDEN);
    }

    // --- serde ---

    #[test]
    fn feature_access_toggle_roundtrip() {
        let v = FeatureAccess::Toggle(true);
        let json = serde_json::to_string(&v).unwrap();
        // Pin the wire format: externally tagged, snake_case variant name.
        assert_eq!(json, r#"{"toggle":true}"#);
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn feature_access_limit_roundtrip() {
        let v = FeatureAccess::Limit(5_000);
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(json, r#"{"limit":5000}"#);
        let back: FeatureAccess = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn tier_info_serde_roundtrip() {
        let tier = free_tier();
        let json = serde_json::to_string(&tier).unwrap();
        let back: TierInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "free");
        assert!(back.has_feature("basic_export"));
        assert!(!back.has_feature("sso"));
        assert_eq!(back, tier);
    }

    // --- resolver (uses the shared test doubles from `test_support`) ---

    #[tokio::test]
    async fn resolver_delegates_to_backend() {
        let resolver = TierResolver::from_backend(Arc::new(StaticTierBackend::new(pro_tier())));
        let info = resolver.resolve("tenant_123").await.unwrap();
        assert_eq!(info.name, "pro");
        assert_eq!(info, pro_tier());
    }

    #[tokio::test]
    async fn resolver_propagates_backend_error() {
        let resolver = TierResolver::from_backend(Arc::new(FailingTierBackend));
        let err = resolver.resolve("tenant_123").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }
}