use std::collections::hash_map::HashMap;
use std::collections::hash_set::HashSet;
use std::hash::BuildHasher;
use std::io::Cursor;
use log::{trace, warn};
use murmur3::murmur3_32;
use rand::Rng;
use crate::api::{Constraint, ConstraintExpression};
use crate::context::Context;
/// Constructor for a strategy: takes the strategy's optional string parameters
/// and returns the compiled [`Evaluate`] that decides, per [`Context`], whether
/// the feature is enabled. The `pub fn` strategies in this module (`default`,
/// `user_with_id`, `flexible_rollout`, ...) have this shape when instantiated
/// with the default hasher, and `constrain` invokes one with its parameters.
pub type Strategy =
Box<dyn Fn(Option<HashMap<String, String>>) -> Evaluate + Sync + Send + 'static>;
/// A compiled strategy or constraint check: a predicate over a [`Context`]
/// that can also be cloned while type-erased behind a `Box<dyn ...>`.
pub trait Evaluator: Fn(&Context) -> bool {
/// Clones `self` into a fresh boxed trait object.
fn clone_boxed(&self) -> Box<dyn Evaluator + Send + Sync + 'static>;
}
/// The boxed, type-erased, thread-shareable form in which every strategy and
/// constraint check in this module is returned.
pub type Evaluate = Box<dyn Evaluator + Send + Sync + 'static>;
// Blanket impl: any `'static`, `Clone`, `Send + Sync` closure taking `&Context`
// and returning `bool` is an `Evaluator`. Consequently a closure can only be
// boxed into an `Evaluate` if everything it captures is `Clone`.
impl<T> Evaluator for T
where
T: 'static + Clone + Sync + Send + Fn(&Context) -> bool,
{
fn clone_boxed(&self) -> Box<dyn Evaluator + Send + Sync + 'static> {
Box::new(T::clone(self))
}
}
// Makes `Evaluate` itself `Clone`, which in turn lets closures that capture an
// `Evaluate` or a `Vec<Evaluate>` (as `constrain` does) satisfy the blanket
// impl above.
impl Clone for Box<dyn Evaluator + Send + Sync + 'static> {
fn clone(&self) -> Self {
// Dispatch through the inner trait object (`as_ref()`) rather than on the
// `Box`. NOTE(review): the `Box` appears to satisfy the blanket impl itself
// (it is `Fn` and, via this impl, `Clone`), so calling `clone_boxed` on the
// box directly would likely recurse back into this `clone`.
self.as_ref().clone_boxed()
}
}
/// The `default` strategy: enabled for every context; parameters are ignored.
pub fn default<S: BuildHasher>(_: Option<HashMap<String, String, S>>) -> Evaluate {
    fn always_on(_: &Context) -> bool {
        true
    }
    Box::new(always_on)
}
/// Enables the feature when `context.user_id` is one of the comma-separated ids
/// in the `userIds` parameter (each entry trimmed of surrounding whitespace).
///
/// Missing parameters, a missing `userIds` key, or a context without a user id
/// all evaluate to `false`.
pub fn user_with_id<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    let allowed: HashSet<String> = parameters
        .as_ref()
        .and_then(|params| params.get("userIds"))
        .map(|list| {
            list.split(',')
                .map(|uid| uid.trim().to_string())
                .collect::<HashSet<String>>()
        })
        .unwrap_or_default();
    Box::new(move |context: &Context| -> bool {
        match &context.user_id {
            Some(uid) => allowed.contains(uid),
            None => false,
        }
    })
}
/// Extracts the hashing group and rollout percentage from strategy parameters.
///
/// Returns `(group, rollout)` where `group` is the `groupId` parameter (empty if
/// absent) and `rollout` is the `rollout_key` parameter parsed as `u32` (0 if
/// absent or unparseable). With no parameters at all, returns `("", 0)`.
pub fn group_and_rollout<S: BuildHasher>(
    parameters: &Option<HashMap<String, String, S>>,
    rollout_key: &str,
) -> (String, u32) {
    match parameters {
        None => (String::new(), 0),
        Some(params) => {
            let group = params.get("groupId").cloned().unwrap_or_default();
            let rollout = params
                .get(rollout_key)
                .and_then(|raw| raw.parse::<u32>().ok())
                .unwrap_or(0);
            (group, rollout)
        }
    }
}
/// Sticky percentage check: hashes `variable` within `group` into a bucket in
/// `0..100` (via `normalised_hash`) and returns `true` when that bucket is below
/// `rollout`.
///
/// Returns `false` when `variable` is `None` or hashing fails.
pub fn partial_rollout(group: &str, variable: Option<&String>, rollout: u32) -> bool {
    match variable {
        None => false,
        Some(id) => normalised_hash(group, id, 100)
            .map(|bucket| bucket < rollout)
            .unwrap_or(false),
    }
}
/// Maps `group` and `identifier` to a bucket in `0..modulus`.
///
/// The string `"{group}:{identifier}"` is hashed with `murmur3_32` (seed 0), and
/// the remainder modulo `modulus` is returned.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidInput`] error when `modulus` is zero,
/// rather than panicking on a division by zero. Any error reported by the hasher
/// while reading its in-memory input is propagated.
pub fn normalised_hash(group: &str, identifier: &str, modulus: u32) -> std::io::Result<u32> {
    if modulus == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "normalised_hash: modulus must be non-zero",
        ));
    }
    let mut reader = Cursor::new(format!("{}:{}", group, identifier));
    murmur3_32(&mut reader, 0).map(|hash| hash % modulus)
}
/// Builds a sticky rollout evaluator keyed on `context.session_id`.
///
/// The group comes from the `groupId` parameter and the percentage from
/// `rollout_key` (see `group_and_rollout`). A context without a session id
/// evaluates to `false`.
fn _session_id<S: BuildHasher>(
    parameters: Option<HashMap<String, String, S>>,
    rollout_key: &str,
) -> Evaluate {
    let (group, rollout) = group_and_rollout(&parameters, rollout_key);
    Box::new(move |context: &Context| -> bool {
        partial_rollout(&group, context.session_id.as_ref(), rollout)
    })
}
/// Builds a sticky rollout evaluator keyed on `context.user_id`.
///
/// The group comes from the `groupId` parameter and the percentage from
/// `rollout_key` (see `group_and_rollout`). A context without a user id
/// evaluates to `false`.
fn _user_id<S: BuildHasher>(
    parameters: Option<HashMap<String, String, S>>,
    rollout_key: &str,
) -> Evaluate {
    let (group, rollout) = group_and_rollout(&parameters, rollout_key);
    Box::new(move |context: &Context| -> bool {
        partial_rollout(&group, context.user_id.as_ref(), rollout)
    })
}
/// The flexible rollout strategy.
///
/// The `stickiness` parameter selects what the rollout is keyed on:
/// - `"default"`: the user id if present, else the session id, else a fresh
///   random draw on every evaluation;
/// - `"userId"` / `"sessionId"`: only that id (a missing id evaluates to `false`);
/// - `"random"`: a fresh random draw on every evaluation.
///
/// The percentage is read from the `rollout` parameter and, for sticky modes,
/// the hash group from `groupId`. Missing parameters, a missing `stickiness`,
/// or an unrecognised stickiness value yield an evaluator that is always `false`.
pub fn flexible_rollout<S: BuildHasher>(
    parameters: Option<HashMap<String, String, S>>,
) -> Evaluate {
    let stickiness = match parameters
        .as_ref()
        .and_then(|params| params.get("stickiness"))
    {
        Some(stickiness) => stickiness.as_str(),
        None => return Box::new(|_: &Context| false),
    };
    match stickiness {
        "default" => {
            let (group, rollout) = group_and_rollout(&parameters, "rollout");
            Box::new(move |context: &Context| -> bool {
                // Prefer the user id, fall back to the session id, and only
                // when neither is available use a non-sticky random draw.
                match context
                    .user_id
                    .as_ref()
                    .or_else(|| context.session_id.as_ref())
                {
                    Some(id) => partial_rollout(&group, Some(id), rollout),
                    None => {
                        let picked = rand::thread_rng().gen_range(0, 100);
                        rollout > picked
                    }
                }
            })
        }
        "userId" => _user_id(parameters, "rollout"),
        "sessionId" => _session_id(parameters, "rollout"),
        "random" => _random(parameters, "rollout"),
        _ => Box::new(|_: &Context| false),
    }
}
/// Sticky percentage rollout keyed on `context.user_id`.
///
/// Reads the percentage from the `percentage` parameter and the hash group from
/// `groupId`.
pub fn user_id<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    const ROLLOUT_KEY: &str = "percentage";
    _user_id(parameters, ROLLOUT_KEY)
}
/// Sticky percentage rollout keyed on `context.session_id`.
///
/// Reads the percentage from the `percentage` parameter and the hash group from
/// `groupId`.
pub fn session_id<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    const ROLLOUT_KEY: &str = "percentage";
    _session_id(parameters, ROLLOUT_KEY)
}
pub fn _random<S: BuildHasher>(
parameters: Option<HashMap<String, String, S>>,
rollout_key: &str,
) -> Evaluate {
let mut pct = 0;
if let Some(parameters) = parameters {
if let Some(pct_str) = parameters.get(rollout_key) {
if let Ok(percent) = pct_str.parse::<u8>() {
pct = percent
}
}
}
Box::new(move |_: &Context| -> bool {
let mut rng = rand::thread_rng();
let picked = rng.gen_range(0, 100);
pct > picked
})
}
/// Non-sticky random rollout whose percentage comes from the `percentage`
/// parameter.
pub fn random<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    const ROLLOUT_KEY: &str = "percentage";
    _random(parameters, ROLLOUT_KEY)
}
/// Enables the feature when `context.remote_address` is included (per
/// `ipaddress::IPAddress::includes`) in any entry of the comma-separated `IPs`
/// parameter.
///
/// Each entry is trimmed before parsing; entries that fail to parse are skipped
/// silently. A context without a remote address evaluates to `false`.
pub fn remote_address<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    let allowed: Vec<ipaddress::IPAddress> =
        match parameters.as_ref().and_then(|params| params.get("IPs")) {
            Some(list) => list
                .split(',')
                .filter_map(|entry| ipaddress::IPAddress::parse(entry.trim()).ok())
                .collect(),
            None => Vec::new(),
        };
    Box::new(move |context: &Context| -> bool {
        context
            .remote_address
            .as_ref()
            .map_or(false, |addr| allowed.iter().any(|net| net.includes(&addr.0)))
    })
}
/// Enables the feature when this machine's hostname (from `hostname::get()`)
/// exactly equals one of the trimmed entries of the comma-separated
/// `hostNames` parameter.
///
/// The comparison is done once, when the evaluator is built; the returned
/// evaluator ignores the context and always returns that fixed result. It is
/// `false` if the hostname cannot be read, there are no parameters, or
/// `hostNames` is absent.
pub fn hostname<S: BuildHasher>(parameters: Option<HashMap<String, String, S>>) -> Evaluate {
    let result = match (
        hostname::get(),
        parameters.as_ref().and_then(|params| params.get("hostNames")),
    ) {
        (Ok(this_hostname), Some(hostnames)) => hostnames
            .split(',')
            .any(|candidate| this_hostname == candidate.trim()),
        _ => false,
    };
    Box::new(move |_: &Context| -> bool { result })
}
/// Compiles an `In`/`NotIn` constraint over a string-valued context field.
///
/// `getter` extracts the field from the context (`None` when absent).
/// - `In(values)`: `true` iff the field is present and in `values`.
/// - `NotIn([])`: always `true`.
/// - `NotIn(values)`: `true` iff the field is present and not in `values`
///   (an absent field evaluates to `false`).
fn _compile_constraint_string<F>(expression: ConstraintExpression, getter: F) -> Evaluate
where
    F: Fn(&Context) -> Option<&String> + Clone + Sync + Send + 'static,
{
    // `expression` is owned, so its value lists are moved into the lookup sets
    // instead of being cloned element by element.
    match expression {
        ConstraintExpression::In(values) => {
            let as_set: HashSet<String> = values.into_iter().collect();
            Box::new(move |context: &Context| {
                getter(context).map(|v| as_set.contains(v)).unwrap_or(false)
            })
        }
        ConstraintExpression::NotIn(values) => {
            if values.is_empty() {
                Box::new(|_: &Context| true)
            } else {
                let as_set: HashSet<String> = values.into_iter().collect();
                Box::new(move |context: &Context| {
                    getter(context)
                        .map(|v| !as_set.contains(v))
                        .unwrap_or(false)
                })
            }
        }
    }
}
/// Parses each (trimmed) string into an `ipaddress::IPAddress`, logging a
/// warning for and skipping any entry that fails to parse. Input order is kept.
fn _ip_to_vec(ips: &[String]) -> Vec<ipaddress::IPAddress> {
    ips.iter()
        .filter_map(|ip_str| match ipaddress::IPAddress::parse(ip_str.trim()) {
            Ok(ip) => Some(ip),
            Err(_) => {
                warn!("Could not parse IP address {:?}", ip_str);
                None
            }
        })
        .collect()
}
/// Compiles an `In`/`NotIn` constraint over an IP-address context field.
///
/// Values are parsed with `_ip_to_vec` (unparseable entries are dropped), and
/// membership uses `ipaddress::IPAddress::includes`. An absent field is always
/// `false`.
/// - `In(values)`: `true` iff the address is included in any parsed entry.
/// - `NotIn([])`: always `false`.
/// - `NotIn(values)`: `true` iff at least one entry parsed and none includes
///   the address.
fn _compile_constraint_host<F>(expression: ConstraintExpression, getter: F) -> Evaluate
where
    F: Fn(&Context) -> Option<&crate::context::IPAddress> + Clone + Sync + Send + 'static,
{
    match &expression {
        ConstraintExpression::In(values) => {
            let networks = _ip_to_vec(values);
            Box::new(move |context: &Context| match getter(context) {
                Some(addr) => networks.iter().any(|net| net.includes(&addr.0)),
                None => false,
            })
        }
        ConstraintExpression::NotIn(values) if values.is_empty() => {
            Box::new(|_: &Context| false)
        }
        ConstraintExpression::NotIn(values) => {
            let networks = _ip_to_vec(values);
            Box::new(move |context: &Context| match getter(context) {
                Some(addr) => {
                    !networks.is_empty() && !networks.iter().any(|net| net.includes(&addr.0))
                }
                None => false,
            })
        }
    }
}
fn _compile_constraints(constraints: Vec<Constraint>) -> Vec<Evaluate> {
constraints
.into_iter()
.map(|constraint| {
let (context_name, expression) = (constraint.context_name, constraint.expression);
match context_name.as_str() {
"appName" => {
_compile_constraint_string(expression, |context| Some(&context.app_name))
}
"environment" => {
_compile_constraint_string(expression, |context| Some(&context.environment))
}
"remoteAddress" => {
_compile_constraint_host(expression, |context| context.remote_address.as_ref())
}
"sessionId" => {
_compile_constraint_string(expression, |context| context.session_id.as_ref())
}
"userId" => {
_compile_constraint_string(expression, |context| context.user_id.as_ref())
}
_ => _compile_constraint_string(expression, move |context| {
context.properties.get(&context_name)
}),
}
})
.collect()
}
pub fn constrain<S: Fn(Option<HashMap<String, String>>) -> Evaluate + Sync + Send + 'static>(
constraints: Option<Vec<Constraint>>,
strategy: &S,
parameters: Option<HashMap<String, String>>,
) -> Evaluate {
let compiled_strategy = strategy(parameters);
match constraints {
None => {
trace!("constrain: no constraints, bypassing");
compiled_strategy
}
Some(constraints) => {
if constraints.is_empty() {
trace!("constrain: empty constraints list, bypassing");
compiled_strategy
} else {
trace!("constrain: compiling constraints list {:?}", constraints);
let constraints = _compile_constraints(constraints);
Box::new(move |context| {
for constraint in &constraints {
if !constraint(&context) {
return false;
}
}
compiled_strategy(context)
})
}
}
}
}
#[cfg(test)]
mod tests {
    use std::collections::hash_map::HashMap;
    use std::default::Default;

    use maplit::hashmap;

    use crate::api::{Constraint, ConstraintExpression};
    use crate::context::{Context, IPAddress};

    /// Parses `addr` into the context's IP wrapper; panics on bad input (test-only).
    fn parse_ip(addr: &str) -> Option<IPAddress> {
        Some(IPAddress(ipaddress::IPAddress::parse(addr).unwrap()))
    }

    /// Converts string literals into the owned values constraint expressions hold.
    fn strings(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    /// Builds a constraint on `context_name`.
    fn constraint(context_name: &str, expression: ConstraintExpression) -> Constraint {
        Constraint {
            context_name: context_name.into(),
            expression,
        }
    }

    /// Evaluates the always-on `default` strategy gated by `constraints`.
    fn constrained_default(constraints: Option<Vec<Constraint>>, context: &Context) -> bool {
        super::constrain(constraints, &super::default, None)(context)
    }

    fn env_context(environment: &str) -> Context {
        Context {
            environment: environment.into(),
            ..Default::default()
        }
    }

    fn user_context(user_id: &str) -> Context {
        Context {
            user_id: Some(user_id.into()),
            ..Default::default()
        }
    }

    fn session_context(session_id: &str) -> Context {
        Context {
            session_id: Some(session_id.into()),
            ..Default::default()
        }
    }

    fn remote_context(addr: &str) -> Context {
        Context {
            remote_address: parse_ip(addr),
            ..Default::default()
        }
    }

    /// Parameters for a sticky flexible rollout.
    fn sticky_params(stickiness: &str, group: &str, rollout: &str) -> HashMap<String, String> {
        hashmap! {
            "stickiness".into() => stickiness.into(),
            "groupId".into() => group.into(),
            "rollout".into() => rollout.into(),
        }
    }

    /// Parameters for a random (non-sticky) flexible rollout.
    fn random_params(rollout: &str) -> HashMap<String, String> {
        hashmap! {
            "stickiness".into() => "random".into(),
            "rollout".into() => rollout.into(),
        }
    }

    fn flexible(params: HashMap<String, String>, context: &Context) -> bool {
        super::flexible_rollout(Some(params))(context)
    }

    #[test]
    fn test_constrain() {
        // No constraints, or an empty list, leaves the strategy unconstrained.
        assert!(constrained_default(None, &Context::default()));
        assert!(constrained_default(Some(vec![]), &Context::default()));

        // An unknown context field with no allowed values never matches.
        assert!(!constrained_default(
            Some(vec![constraint("", ConstraintExpression::In(vec![]))]),
            &env_context("development"),
        ));

        assert!(!constrained_default(
            Some(vec![constraint(
                "environment",
                ConstraintExpression::In(strings(&["development"])),
            )]),
            &env_context("production"),
        ));
        assert!(!constrained_default(
            Some(vec![constraint(
                "environment",
                ConstraintExpression::NotIn(strings(&["development"])),
            )]),
            &env_context("development"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "environment",
                ConstraintExpression::In(strings(&["development"])),
            )]),
            &env_context("development"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "environment",
                ConstraintExpression::In(strings(&["staging", "development"])),
            )]),
            &env_context("development"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "environment",
                ConstraintExpression::NotIn(strings(&["staging", "development"])),
            )]),
            &env_context("production"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "userId",
                ConstraintExpression::In(strings(&["fred"])),
            )]),
            &user_context("fred"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "sessionId",
                ConstraintExpression::In(strings(&["qwerty"])),
            )]),
            &session_context("qwerty"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "remoteAddress",
                ConstraintExpression::In(strings(&["10.0.0.0/8"])),
            )]),
            &remote_context("10.20.30.40"),
        ));
        assert!(constrained_default(
            Some(vec![constraint(
                "remoteAddress",
                ConstraintExpression::NotIn(strings(&["10.0.0.0/8"])),
            )]),
            &remote_context("1.2.3.4"),
        ));

        // Every constraint in the list must pass.
        let development = env_context("development");
        assert!(constrained_default(
            Some(vec![
                constraint(
                    "environment",
                    ConstraintExpression::In(strings(&["development"])),
                ),
                constraint(
                    "environment",
                    ConstraintExpression::In(strings(&["development"])),
                ),
            ]),
            &development,
        ));
        assert!(!constrained_default(
            Some(vec![
                constraint(
                    "environment",
                    ConstraintExpression::In(strings(&["development"])),
                ),
                constraint("environment", ConstraintExpression::In(vec![])),
            ]),
            &development,
        ));
    }

    #[test]
    fn test_user_with_id() {
        let params: HashMap<String, String> = hashmap! {
            "userIds".into() => "fred,barney".into(),
        };
        assert!(super::user_with_id(Some(params.clone()))(&user_context("fred")));
        assert!(super::user_with_id(Some(params.clone()))(&user_context("barney")));
        assert!(!super::user_with_id(Some(params))(&user_context("betty")));
    }

    #[test]
    fn test_flexible_rollout() {
        // Random stickiness: 0% is always off, 100% always on.
        let anonymous = Context::default();
        assert!(!flexible(random_params("0"), &anonymous));
        assert!(flexible(random_params("100"), &anonymous));

        // Session stickiness.
        let session1 = session_context("session1");
        let session2 = session_context("session2");
        assert!(!flexible(sticky_params("sessionId", "group1", "0"), &session1));
        assert!(flexible(sticky_params("sessionId", "group1", "100"), &session1));
        let params = sticky_params("sessionId", "group1", "50");
        assert!(flexible(params.clone(), &session1));
        assert!(!flexible(params, &session2));
        let params = sticky_params("sessionId", "group3", "50");
        assert!(!flexible(params.clone(), &session1));
        assert!(flexible(params, &session2));

        // User stickiness.
        let user1 = user_context("user1");
        let user3 = user_context("user3");
        assert!(!flexible(sticky_params("userId", "group1", "0"), &user1));
        assert!(flexible(sticky_params("userId", "group1", "100"), &user1));
        let params = sticky_params("userId", "group1", "50");
        assert!(flexible(params.clone(), &user1));
        assert!(!flexible(params, &user3));
        let params = sticky_params("userId", "group2", "50");
        assert!(!flexible(params.clone(), &user3));
        assert!(flexible(params, &user1));
    }

    #[test]
    fn test_random() {
        let params: HashMap<String, String> = hashmap! {
            "percentage".into() => "0".into()
        };
        assert!(!super::random(Some(params))(&Context::default()));
        let params: HashMap<String, String> = hashmap! {
            "percentage".into() => "100".into()
        };
        assert!(super::random(Some(params))(&Context::default()));
    }

    #[test]
    fn test_remote_address() {
        let params: HashMap<String, String> = hashmap! {
            "IPs".into() => "1.2/8,2.3.4.5,2222:FF:0:1234::/64".into()
        };
        let matches =
            |addr: &str| super::remote_address(Some(params.clone()))(&remote_context(addr));
        assert!(matches("1.2.3.4"));
        assert!(matches("2.3.4.5"));
        assert!(matches("2222:FF:0:1234::FDEC"));
        assert!(!matches("2.3.4.4"));
    }

    #[test]
    fn test_hostname() {
        let context = Context::default();
        let this_hostname = hostname::get().unwrap().into_string().unwrap();
        let params: HashMap<String, String> = hashmap! {
            "hostNames".into() => format!("foo,{},bar", this_hostname)
        };
        assert!(super::hostname(Some(params))(&context));
        let params: HashMap<String, String> = hashmap! {
            "hostNames".into() => "foo,bar".into()
        };
        assert!(!super::hostname(Some(params))(&context));
    }

    #[test]
    fn normalised_hash() {
        assert!(50 > super::normalised_hash("AB12A", "122", 100).unwrap());
    }
}