1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum Variant {
12 Control,
14 Treatment,
16 Custom(String),
18}
19
20impl Variant {
21 #[must_use]
23 pub fn is_control(&self) -> bool {
24 matches!(self, Variant::Control)
25 }
26
27 #[must_use]
29 pub fn is_treatment(&self) -> bool {
30 matches!(self, Variant::Treatment)
31 }
32
33 #[must_use]
35 pub fn name(&self) -> &str {
36 match self {
37 Variant::Control => "control",
38 Variant::Treatment => "treatment",
39 Variant::Custom(name) => name,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Experiment {
59 pub id: String,
61 pub name: String,
63 pub description: String,
65 pub rollout_percentage: u8,
67 pub enabled: bool,
69 pub variant_weights: Option<HashMap<String, u8>>,
71 pub sticky: bool,
73}
74
75impl Experiment {
76 #[must_use]
78 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
79 Self {
80 id: id.into(),
81 name: name.into(),
82 description: String::new(),
83 rollout_percentage: 0,
84 enabled: false,
85 variant_weights: None,
86 sticky: true,
87 }
88 }
89
90 #[must_use]
92 pub fn with_description(mut self, description: impl Into<String>) -> Self {
93 self.description = description.into();
94 self
95 }
96
97 #[must_use]
99 pub fn with_rollout_percentage(mut self, percentage: u8) -> Self {
100 self.rollout_percentage = percentage.min(100);
101 self
102 }
103
104 #[must_use]
106 pub fn enabled(mut self) -> Self {
107 self.enabled = true;
108 self
109 }
110
111 #[must_use]
113 pub fn with_variant_weights(mut self, weights: HashMap<String, u8>) -> Self {
114 self.variant_weights = Some(weights);
115 self
116 }
117
118 #[must_use]
120 pub fn non_sticky(mut self) -> Self {
121 self.sticky = false;
122 self
123 }
124
125 #[must_use]
129 pub fn assign_variant(&self, user_id: &str) -> Variant {
130 if !self.enabled {
131 return Variant::Control;
132 }
133
134 let hash = self.hash_user(user_id);
136 let bucket = hash % 100;
137
138 if bucket >= u64::from(self.rollout_percentage) {
140 return Variant::Control;
141 }
142
143 if let Some(weights) = &self.variant_weights {
145 let mut cumulative = 0u8;
146 for (variant_name, weight) in weights {
147 cumulative += weight;
148 if bucket < u64::from(cumulative) {
149 return Variant::Custom(variant_name.clone());
150 }
151 }
152 }
153
154 Variant::Treatment
156 }
157
158 #[must_use]
160 pub fn is_enrolled(&self, user_id: &str) -> bool {
161 self.enabled && !self.assign_variant(user_id).is_control()
162 }
163
164 fn hash_user(&self, user_id: &str) -> u64 {
166 let combined = format!("{}{}", self.id, user_id);
168 let mut hash = 0u64;
169 for byte in combined.bytes() {
170 hash = hash.wrapping_mul(31).wrapping_add(u64::from(byte));
171 }
172 hash
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct ExperimentResult {
179 pub experiment_id: String,
181 pub user_id: String,
183 pub variant: Variant,
185 pub assigned_at: u64,
187 pub metrics: HashMap<String, f64>,
189}
190
191impl ExperimentResult {
192 #[must_use]
194 pub fn new(experiment_id: String, user_id: String, variant: Variant, assigned_at: u64) -> Self {
195 Self {
196 experiment_id,
197 user_id,
198 variant,
199 assigned_at,
200 metrics: HashMap::new(),
201 }
202 }
203
204 pub fn add_metric(&mut self, name: impl Into<String>, value: f64) {
206 self.metrics.insert(name.into(), value);
207 }
208
209 #[must_use]
211 pub fn get_metric(&self, name: &str) -> Option<f64> {
212 self.metrics.get(name).copied()
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct GradualRollout {
221 pub feature_id: String,
223 pub current_percentage: u8,
225 pub target_percentage: u8,
227 pub increment_step: u8,
229 pub enabled: bool,
231}
232
233impl GradualRollout {
234 #[must_use]
236 pub fn new(feature_id: impl Into<String>) -> Self {
237 Self {
238 feature_id: feature_id.into(),
239 current_percentage: 0,
240 target_percentage: 100,
241 increment_step: 10,
242 enabled: false,
243 }
244 }
245
246 #[must_use]
248 pub fn with_target(mut self, target: u8) -> Self {
249 self.target_percentage = target.min(100);
250 self
251 }
252
253 #[must_use]
255 pub fn with_step(mut self, step: u8) -> Self {
256 self.increment_step = step.max(1);
257 self
258 }
259
260 #[must_use]
262 pub fn enabled(mut self) -> Self {
263 self.enabled = true;
264 self
265 }
266
267 pub fn ramp_up(&mut self) {
269 if self.enabled && self.current_percentage < self.target_percentage {
270 self.current_percentage =
271 (self.current_percentage + self.increment_step).min(self.target_percentage);
272 }
273 }
274
275 pub fn ramp_down(&mut self) {
277 if self.current_percentage > 0 {
278 self.current_percentage = self.current_percentage.saturating_sub(self.increment_step);
279 }
280 }
281
282 #[must_use]
284 pub fn has_access(&self, user_id: &str) -> bool {
285 if !self.enabled {
286 return false;
287 }
288
289 let hash = self.hash_user(user_id);
290 let bucket = hash % 100;
291 bucket < u64::from(self.current_percentage)
292 }
293
294 #[must_use]
296 pub fn is_complete(&self) -> bool {
297 self.current_percentage >= self.target_percentage
298 }
299
300 fn hash_user(&self, user_id: &str) -> u64 {
301 let combined = format!("{}{}", self.feature_id, user_id);
302 let mut hash = 0u64;
303 for byte in combined.bytes() {
304 hash = hash.wrapping_mul(31).wrapping_add(u64::from(byte));
305 }
306 hash
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variant_is_control() {
        assert!(Variant::Control.is_control());
        assert!(!Variant::Treatment.is_control());
        assert!(!Variant::Custom("test".to_string()).is_control());
    }

    #[test]
    fn test_variant_is_treatment() {
        assert!(!Variant::Control.is_treatment());
        assert!(Variant::Treatment.is_treatment());
        assert!(!Variant::Custom("test".to_string()).is_treatment());
    }

    #[test]
    fn test_variant_name() {
        assert_eq!(Variant::Control.name(), "control");
        assert_eq!(Variant::Treatment.name(), "treatment");
        assert_eq!(Variant::Custom("custom".to_string()).name(), "custom");
    }

    #[test]
    fn test_experiment_disabled() {
        let exp = Experiment::new("test", "Test Experiment");
        assert!(!exp.enabled);
        assert_eq!(exp.assign_variant("user123"), Variant::Control);
    }

    #[test]
    fn test_experiment_rollout_percentage() {
        let exp = Experiment::new("test", "Test Experiment")
            .with_rollout_percentage(50)
            .enabled();

        // Assignment is deterministic for a given user.
        let variant1 = exp.assign_variant("user123");
        let variant2 = exp.assign_variant("user123");
        assert_eq!(variant1, variant2);

        // A 50% rollout over a sizeable population yields both groups.
        let enrolled = (0..1000)
            .filter(|i| exp.is_enrolled(&format!("user{i}")))
            .count();
        assert!(enrolled > 0 && enrolled < 1000);
    }

    #[test]
    fn test_experiment_rollout_percentage_clamping() {
        let exp = Experiment::new("test", "Test").with_rollout_percentage(150);
        assert_eq!(exp.rollout_percentage, 100);
    }

    #[test]
    fn test_experiment_zero_rollout_is_control() {
        let exp = Experiment::new("test", "Test").enabled();
        for i in 0..100 {
            assert_eq!(exp.assign_variant(&format!("user{i}")), Variant::Control);
        }
    }

    #[test]
    fn test_experiment_is_enrolled() {
        let exp = Experiment::new("test", "Test")
            .with_rollout_percentage(100)
            .enabled();

        assert!(exp.is_enrolled("user123"));
    }

    #[test]
    fn test_experiment_disabled_not_enrolled() {
        let exp = Experiment::new("test", "Test").with_rollout_percentage(100);
        assert!(!exp.is_enrolled("user123"));
    }

    #[test]
    fn test_experiment_single_weighted_variant() {
        let mut weights = HashMap::new();
        weights.insert("blue".to_string(), 100);
        let exp = Experiment::new("test", "Test")
            .with_rollout_percentage(100)
            .with_variant_weights(weights)
            .enabled();

        for i in 0..100 {
            assert_eq!(
                exp.assign_variant(&format!("user{i}")),
                Variant::Custom("blue".to_string())
            );
        }
    }

    #[test]
    fn test_experiment_result() {
        let mut result = ExperimentResult::new(
            "exp1".to_string(),
            "user123".to_string(),
            Variant::Treatment,
            1000,
        );

        result.add_metric("conversion_rate", 0.25);
        result.add_metric("revenue", 100.0);

        assert_eq!(result.get_metric("conversion_rate"), Some(0.25));
        assert_eq!(result.get_metric("revenue"), Some(100.0));
        assert_eq!(result.get_metric("nonexistent"), None);
    }

    #[test]
    fn test_gradual_rollout_new() {
        let rollout = GradualRollout::new("feature1");
        assert_eq!(rollout.current_percentage, 0);
        assert_eq!(rollout.target_percentage, 100);
        assert!(!rollout.enabled);
    }

    #[test]
    fn test_gradual_rollout_builder_clamping() {
        let rollout = GradualRollout::new("feature1").with_target(150).with_step(0);
        assert_eq!(rollout.target_percentage, 100);
        assert_eq!(rollout.increment_step, 1);
    }

    #[test]
    fn test_gradual_rollout_ramp_up() {
        let mut rollout = GradualRollout::new("feature1")
            .with_step(10)
            .with_target(50)
            .enabled();

        assert_eq!(rollout.current_percentage, 0);
        rollout.ramp_up();
        assert_eq!(rollout.current_percentage, 10);
        rollout.ramp_up();
        assert_eq!(rollout.current_percentage, 20);

        rollout.ramp_up();
        rollout.ramp_up();
        rollout.ramp_up();
        assert_eq!(rollout.current_percentage, 50);

        // Further ramping stays at the target.
        rollout.ramp_up();
        assert_eq!(rollout.current_percentage, 50);
    }

    #[test]
    fn test_gradual_rollout_disabled() {
        let mut rollout = GradualRollout::new("feature1");
        rollout.ramp_up();
        assert_eq!(rollout.current_percentage, 0);

        rollout.current_percentage = 100;
        assert!(!rollout.has_access("user123"));
    }

    #[test]
    fn test_gradual_rollout_ramp_down() {
        let mut rollout = GradualRollout::new("feature1").with_step(10).enabled();

        rollout.current_percentage = 50;
        rollout.ramp_down();
        assert_eq!(rollout.current_percentage, 40);

        // Stepping below zero saturates at 0.
        rollout.current_percentage = 5;
        rollout.ramp_down();
        assert_eq!(rollout.current_percentage, 0);
        rollout.ramp_down();
        assert_eq!(rollout.current_percentage, 0);
    }

    #[test]
    fn test_gradual_rollout_has_access() {
        let mut rollout = GradualRollout::new("feature1").enabled();

        assert!(!rollout.has_access("user123"));

        rollout.current_percentage = 100;
        assert!(rollout.has_access("user123"));
    }

    #[test]
    fn test_gradual_rollout_is_complete() {
        let mut rollout = GradualRollout::new("feature1").with_target(50);

        assert!(!rollout.is_complete());
        rollout.current_percentage = 50;
        assert!(rollout.is_complete());
        rollout.current_percentage = 60;
        assert!(rollout.is_complete());
    }

    #[test]
    fn test_experiment_serde() {
        let exp = Experiment::new("test", "Test Experiment")
            .with_description("A test experiment")
            .with_rollout_percentage(50)
            .enabled();

        let json = serde_json::to_string(&exp).unwrap();
        let decoded: Experiment = serde_json::from_str(&json).unwrap();
        assert_eq!(exp.id, decoded.id);
        assert_eq!(exp.rollout_percentage, decoded.rollout_percentage);
        assert_eq!(exp.enabled, decoded.enabled);
    }

    #[test]
    fn test_gradual_rollout_serde() {
        let rollout = GradualRollout::new("feature1")
            .with_target(75)
            .with_step(25)
            .enabled();

        let json = serde_json::to_string(&rollout).unwrap();
        let decoded: GradualRollout = serde_json::from_str(&json).unwrap();
        assert_eq!(rollout.feature_id, decoded.feature_id);
        assert_eq!(rollout.target_percentage, decoded.target_percentage);
        assert_eq!(rollout.increment_step, decoded.increment_step);
    }
}