1use std::collections::HashMap;
4
5#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
7pub enum FeatureTier {
8 OpenSource,
9 Premium,
10 Enterprise,
11}
12
13#[derive(Debug, Clone)]
15pub struct FeatureRegistry {
16 features: HashMap<String, FeatureTier>,
17}
18
19impl FeatureRegistry {
20 pub fn new() -> Self {
21 let mut features = HashMap::new();
22
23 features.insert(
25 "basic_position_encoding".to_string(),
26 FeatureTier::OpenSource,
27 );
28 features.insert("similarity_search".to_string(), FeatureTier::OpenSource);
29 features.insert("basic_tactical_search".to_string(), FeatureTier::OpenSource);
30 features.insert("uci_basic".to_string(), FeatureTier::OpenSource);
31 features.insert("opening_book".to_string(), FeatureTier::OpenSource);
32 features.insert("json_training_data".to_string(), FeatureTier::OpenSource);
33 features.insert("basic_persistence".to_string(), FeatureTier::OpenSource);
34
35 features.insert("advanced_nnue".to_string(), FeatureTier::Premium);
37 features.insert("gpu_acceleration".to_string(), FeatureTier::Premium);
38 features.insert("ultra_fast_loading".to_string(), FeatureTier::Premium);
39 features.insert("memory_mapped_files".to_string(), FeatureTier::Premium);
40 features.insert("advanced_tactical_search".to_string(), FeatureTier::Premium);
41 features.insert("pondering".to_string(), FeatureTier::Premium);
42 features.insert("multi_pv_analysis".to_string(), FeatureTier::Premium);
43 features.insert("advanced_pruning".to_string(), FeatureTier::Premium);
44 features.insert("parallel_search".to_string(), FeatureTier::Premium);
45
46 features.insert("distributed_training".to_string(), FeatureTier::Enterprise);
48 features.insert("cloud_deployment".to_string(), FeatureTier::Enterprise);
49 features.insert("enterprise_analytics".to_string(), FeatureTier::Enterprise);
50 features.insert("custom_algorithms".to_string(), FeatureTier::Enterprise);
51 features.insert("dedicated_support".to_string(), FeatureTier::Enterprise);
52 features.insert("unlimited_positions".to_string(), FeatureTier::Enterprise);
53
54 Self { features }
55 }
56
57 pub fn get_feature_tier(&self, feature: &str) -> Option<&FeatureTier> {
58 self.features.get(feature)
59 }
60
61 pub fn is_feature_available(&self, feature: &str, current_tier: &FeatureTier) -> bool {
62 match self.get_feature_tier(feature) {
63 Some(required_tier) => Self::tier_includes(current_tier, required_tier),
64 None => false, }
66 }
67
68 fn tier_includes(current: &FeatureTier, required: &FeatureTier) -> bool {
70 matches!(
71 (current, required),
72 (FeatureTier::OpenSource, FeatureTier::OpenSource)
73 | (FeatureTier::Premium, FeatureTier::OpenSource)
74 | (FeatureTier::Premium, FeatureTier::Premium)
75 | (FeatureTier::Enterprise, _)
76 )
77 }
78
79 pub fn get_features_for_tier(&self, tier: &FeatureTier) -> Vec<String> {
80 self.features
81 .iter()
82 .filter(|(_, required_tier)| Self::tier_includes(tier, required_tier))
83 .map(|(feature, _)| feature.clone())
84 .collect()
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct FeatureChecker {
91 registry: FeatureRegistry,
92 current_tier: FeatureTier,
93}
94
95impl FeatureChecker {
96 pub fn new(tier: FeatureTier) -> Self {
97 Self {
98 registry: FeatureRegistry::new(),
99 current_tier: tier,
100 }
101 }
102
103 pub fn check_feature(&self, feature: &str) -> Result<(), FeatureError> {
104 if self
105 .registry
106 .is_feature_available(feature, &self.current_tier)
107 {
108 Ok(())
109 } else {
110 match self.registry.get_feature_tier(feature) {
111 Some(required_tier) => Err(FeatureError::InsufficientTier {
112 feature: feature.to_string(),
113 required: required_tier.clone(),
114 current: self.current_tier.clone(),
115 }),
116 None => Err(FeatureError::UnknownFeature(feature.to_string())),
117 }
118 }
119 }
120
121 pub fn require_feature(&self, feature: &str) -> Result<(), FeatureError> {
122 self.check_feature(feature)
123 }
124
125 pub fn get_current_tier(&self) -> &FeatureTier {
126 &self.current_tier
127 }
128
129 pub fn upgrade_tier(&mut self, new_tier: FeatureTier) {
130 self.current_tier = new_tier;
131 }
132}
133
134#[derive(Debug, Clone)]
136pub enum FeatureError {
137 InsufficientTier {
138 feature: String,
139 required: FeatureTier,
140 current: FeatureTier,
141 },
142 UnknownFeature(String),
143}
144
145impl std::fmt::Display for FeatureError {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 match self {
148 FeatureError::InsufficientTier {
149 feature,
150 required,
151 current,
152 } => {
153 write!(
154 f,
155 "Feature '{feature}' requires {required:?} tier, but current tier is {current:?}. Please upgrade your subscription."
156 )
157 }
158 FeatureError::UnknownFeature(feature) => {
159 write!(f, "Unknown feature: '{feature}'")
160 }
161 }
162 }
163}
164
165impl std::error::Error for FeatureError {}
166
/// Returns early with the checker's error when `$feature` is not available.
///
/// Expands to `$checker.require_feature($feature)?`, so it must be used where
/// `?` can propagate that error (e.g. inside a function returning a `Result`
/// whose error type converts from the checker's error). A trailing comma after
/// the arguments is accepted.
#[macro_export]
macro_rules! require_feature {
    ($checker:expr, $feature:expr $(,)?) => {
        $checker.require_feature($feature)?
    };
}
174
/// Runs `$code` only when `$checker.check_feature($feature)` returns `Ok`;
/// otherwise does nothing.
///
/// There is no `else` branch, so `$code` must evaluate to `()`. A trailing
/// comma after the block is accepted.
#[macro_export]
macro_rules! if_feature {
    ($checker:expr, $feature:expr, $code:block $(,)?) => {
        if $checker.check_feature($feature).is_ok() {
            $code
        }
    };
}
184
/// Evaluates `$if_code` when `$checker.check_feature($feature)` returns `Ok`,
/// and `$else_code` otherwise.
///
/// The expansion is an `if`/`else` expression, so it yields the value of
/// whichever block runs (both blocks must have the same type). A trailing
/// comma after the last block is accepted.
#[macro_export]
macro_rules! if_feature_else {
    ($checker:expr, $feature:expr, $if_code:block, $else_code:block $(,)?) => {
        if $checker.check_feature($feature).is_ok() {
            $if_code
        } else {
            $else_code
        }
    };
}
195
196impl Default for FeatureRegistry {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl Default for FeatureChecker {
203 fn default() -> Self {
204 Self::new(FeatureTier::OpenSource)
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry lookups return the tier each feature was registered with.
    #[test]
    fn test_feature_registry() {
        let registry = FeatureRegistry::new();

        assert_eq!(
            registry.get_feature_tier("basic_position_encoding"),
            Some(&FeatureTier::OpenSource)
        );
        assert_eq!(
            registry.get_feature_tier("gpu_acceleration"),
            Some(&FeatureTier::Premium)
        );
        assert_eq!(
            registry.get_feature_tier("distributed_training"),
            Some(&FeatureTier::Enterprise)
        );
        assert_eq!(registry.get_feature_tier("nonexistent_feature"), None);
    }

    /// Higher tiers include everything available to lower tiers.
    #[test]
    fn test_tier_access() {
        let registry = FeatureRegistry::new();

        assert!(registry.is_feature_available("basic_position_encoding", &FeatureTier::OpenSource));
        assert!(!registry.is_feature_available("gpu_acceleration", &FeatureTier::OpenSource));
        assert!(!registry.is_feature_available("distributed_training", &FeatureTier::OpenSource));

        assert!(registry.is_feature_available("basic_position_encoding", &FeatureTier::Premium));
        assert!(registry.is_feature_available("gpu_acceleration", &FeatureTier::Premium));
        assert!(!registry.is_feature_available("distributed_training", &FeatureTier::Premium));

        assert!(registry.is_feature_available("basic_position_encoding", &FeatureTier::Enterprise));
        assert!(registry.is_feature_available("gpu_acceleration", &FeatureTier::Enterprise));
        assert!(registry.is_feature_available("distributed_training", &FeatureTier::Enterprise));

        // Unregistered features are unavailable regardless of tier.
        for tier in &[
            FeatureTier::OpenSource,
            FeatureTier::Premium,
            FeatureTier::Enterprise,
        ] {
            assert!(!registry.is_feature_available("nonexistent_feature", tier));
        }
    }

    /// Per-tier feature listings are cumulative and have the expected sizes.
    #[test]
    fn test_features_for_tier() {
        let registry = FeatureRegistry::new();

        // Sort locally so the assertion does not depend on listing order.
        let mut open_source = registry.get_features_for_tier(&FeatureTier::OpenSource);
        open_source.sort();
        assert_eq!(
            open_source,
            vec![
                "basic_persistence",
                "basic_position_encoding",
                "basic_tactical_search",
                "json_training_data",
                "opening_book",
                "similarity_search",
                "uci_basic",
            ]
        );

        let premium = registry.get_features_for_tier(&FeatureTier::Premium);
        assert_eq!(premium.len(), 16);
        assert!(premium.iter().any(|f| f == "gpu_acceleration"));
        assert!(!premium.iter().any(|f| f == "distributed_training"));

        let enterprise = registry.get_features_for_tier(&FeatureTier::Enterprise);
        assert_eq!(enterprise.len(), 22);
        assert!(premium.iter().all(|f| enterprise.contains(f)));
    }

    /// The checker enforces its current tier and reflects tier changes.
    #[test]
    fn test_feature_checker() {
        let mut checker = FeatureChecker::new(FeatureTier::OpenSource);

        assert!(checker.check_feature("basic_position_encoding").is_ok());
        assert!(checker.check_feature("gpu_acceleration").is_err());

        checker.upgrade_tier(FeatureTier::Premium);
        assert_eq!(checker.get_current_tier(), &FeatureTier::Premium);
        assert!(checker.check_feature("gpu_acceleration").is_ok());
        assert!(checker.require_feature("distributed_training").is_err());

        checker.upgrade_tier(FeatureTier::Enterprise);
        assert!(checker.require_feature("distributed_training").is_ok());
    }

    /// `Default` starts at the lowest tier.
    #[test]
    fn test_default_checker_is_open_source() {
        let checker = FeatureChecker::default();
        assert_eq!(checker.get_current_tier(), &FeatureTier::OpenSource);
    }

    /// Error variants carry the requested feature and both tiers.
    #[test]
    fn test_feature_error_messages() {
        let checker = FeatureChecker::new(FeatureTier::OpenSource);

        match checker.check_feature("gpu_acceleration") {
            Err(FeatureError::InsufficientTier {
                feature,
                required,
                current,
            }) => {
                assert_eq!(feature, "gpu_acceleration");
                assert_eq!(required, FeatureTier::Premium);
                assert_eq!(current, FeatureTier::OpenSource);
            }
            _ => panic!("Expected InsufficientTier error"),
        }

        match checker.check_feature("nonexistent_feature") {
            Err(FeatureError::UnknownFeature(feature)) => {
                assert_eq!(feature, "nonexistent_feature");
            }
            _ => panic!("Expected UnknownFeature error"),
        }
    }

    /// `Display` renders the user-facing messages.
    #[test]
    fn test_feature_error_display() {
        let checker = FeatureChecker::new(FeatureTier::OpenSource);

        let err = checker.check_feature("gpu_acceleration").unwrap_err();
        assert_eq!(
            err.to_string(),
            "Feature 'gpu_acceleration' requires Premium tier, but current tier is OpenSource. Please upgrade your subscription."
        );

        let err = checker.check_feature("nonexistent_feature").unwrap_err();
        assert_eq!(err.to_string(), "Unknown feature: 'nonexistent_feature'");
    }

    /// The exported macros gate code on the checker's tier.
    #[test]
    fn test_feature_macros() {
        fn gated(checker: &FeatureChecker) -> Result<&'static str, FeatureError> {
            require_feature!(checker, "gpu_acceleration");
            Ok("ran")
        }

        let open = FeatureChecker::new(FeatureTier::OpenSource);
        let premium = FeatureChecker::new(FeatureTier::Premium);

        assert!(gated(&open).is_err());
        assert_eq!(gated(&premium).unwrap(), "ran");

        let mut hits = 0;
        if_feature!(premium, "gpu_acceleration", { hits += 1; });
        if_feature!(open, "gpu_acceleration", { hits += 10; });
        assert_eq!(hits, 1);

        let label = if_feature_else!(open, "gpu_acceleration", { "gpu" }, { "cpu" });
        assert_eq!(label, "cpu");
    }
}