use std::collections::{BTreeMap, HashMap, HashSet};
use hive_router_config::override_labels::{LabelOverrideValue, OverrideLabelsConfig};
use hive_router_internal::expressions::CompileExpression;
use hive_router_plan_executor::execution::client_request_details::ClientRequestDetailsView;
use hive_router_plan_executor::request_context::RequestContextError;
use hive_router_plan_executor::request_context::SharedRequestContext;
use hive_router_query_planner::{
graph::{PlannerOverrideContext, PERCENTAGE_SCALE_FACTOR},
state::supergraph_state::SupergraphState,
};
use rand::prelude::*;
use vrl::{
compiler::Program as VrlProgram,
compiler::TargetValue as VrlTargetValue,
core::Value as VrlValue,
prelude::{
state::RuntimeState as VrlState, Context as VrlContext, ExpressionError,
TimeZone as VrlTimeZone,
},
value::Secrets as VrlSecrets,
};
#[derive(thiserror::Error, Debug)]
#[error("Failed to compile override label expression for label '{label}': {error}")]
pub struct OverrideLabelsCompileError {
pub label: String,
pub error: String,
}
#[derive(thiserror::Error, Debug)]
pub enum LabelEvaluationError {
#[error(
"Failed to resolve VRL expression for override label '{label}'. Runtime error: {source}"
)]
ExpressionResolutionFailure {
label: String,
source: ExpressionError,
},
#[error(
"VRL expression for override label '{label}' did not evaluate to a boolean. Got: {got}"
)]
ExpressionWrongType { label: String, got: String },
#[error(transparent)]
RequestContext(#[from] RequestContextError),
}
/// Per-request progressive-override inputs: the labels switched on for this request and
/// the random draw used for percentage-based overrides.
#[derive(Debug, Clone)]
pub struct RequestOverrideContext {
    /// Override labels that are active for this request.
    pub active_flags: HashSet<String>,
    /// Random value drawn uniformly from `[0, 100 * PERCENTAGE_SCALE_FACTOR)`.
    /// Percentage thresholds are compared against it with `percentage_value < threshold`.
    pub percentage_value: u64,
}

impl RequestOverrideContext {
    /// Resolves the override labels for the current request and draws its percentage value.
    ///
    /// The labels already recorded in the request context (`labels_to_override`) are passed
    /// to `override_labels_evaluator`. The evaluator adds the statically enabled labels and
    /// every label whose VRL expression evaluates to `true`.
    ///
    /// If the resulting set is non-empty, two updates are written back to the request context:
    /// - the set is stored as `labels_to_override`;
    /// - its labels are removed from `unresolved_labels`, which becomes `None` once nothing
    ///   is left.
    ///
    /// # Errors
    ///
    /// Returns [`LabelEvaluationError`] in these cases:
    /// - the request context cannot be read or updated;
    /// - an override-label expression fails at runtime;
    /// - an override-label expression does not yield a boolean.
    #[inline]
    pub fn new(
        override_labels_evaluator: &OverrideLabelsEvaluator,
        client_request_details: &impl ClientRequestDetailsView,
        request_context: &SharedRequestContext,
    ) -> Result<Self, LabelEvaluationError> {
        let progressive_override_state = request_context.snapshot()?.progressive_override;
        let active_flags = override_labels_evaluator.evaluate(
            progressive_override_state.labels_to_override.as_ref(),
            client_request_details,
        )?;

        if !active_flags.is_empty() {
            request_context.update(|ctx| {
                ctx.progressive_override.labels_to_override = Some(active_flags.clone());
                // Labels that just became active are no longer unresolved. An empty
                // remainder collapses to `None`.
                ctx.progressive_override.unresolved_labels = ctx
                    .progressive_override
                    .unresolved_labels
                    .as_ref()
                    .map(|labels| {
                        labels
                            .difference(&active_flags)
                            .cloned()
                            .collect::<HashSet<_>>()
                    })
                    .filter(|remaining| !remaining.is_empty());
            })?;
        }

        // The range is half-open on purpose. Thresholds are compared with
        // `percentage_value < threshold` (see `StableOverrideContext::new`). Drawing the
        // inclusive upper bound would therefore occasionally make a 100% override not apply.
        let percentage_value: u64 = rand::rng().random_range(0..100 * PERCENTAGE_SCALE_FACTOR);

        Ok(Self {
            active_flags,
            percentage_value,
        })
    }

    /// Refreshes `active_flags` from the request context's `labels_to_override`.
    ///
    /// Falls back to an empty set when no labels are recorded.
    ///
    /// # Errors
    ///
    /// Returns [`RequestContextError`] if the request context snapshot cannot be taken.
    pub fn update_from(
        &mut self,
        request_context: &SharedRequestContext,
    ) -> Result<(), RequestContextError> {
        let active_labels = request_context
            .snapshot()?
            .progressive_override
            .labels_to_override;
        self.active_flags = active_labels.unwrap_or_default();
        Ok(())
    }
}
impl From<&RequestOverrideContext> for PlannerOverrideContext {
fn from(value: &RequestOverrideContext) -> Self {
Self::new(value.active_flags.clone(), value.percentage_value)
}
}
/// Hashable summary of how a request's override inputs resolve against the flags and
/// percentage thresholds declared by the supergraph.
///
/// Two requests whose inputs produce the same outcomes compare equal, even if their raw
/// percentage draws differ.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct StableOverrideContext {
    /// For each supergraph flag: whether the request has it active.
    active_flags: BTreeMap<String, bool>,
    /// For each supergraph percentage threshold: whether the request's draw is below it.
    percentage_outcomes: BTreeMap<u64, bool>,
}

impl StableOverrideContext {
    /// Projects `request_override_context` onto the supergraph's progressive overrides.
    ///
    /// A flag maps to `true` when it is in the request's `active_flags`. A threshold maps to
    /// `true` when the request's `percentage_value` is strictly below it.
    pub fn new(
        supergraph: &SupergraphState,
        request_override_context: &RequestOverrideContext,
    ) -> Self {
        let overrides = &supergraph.progressive_overrides;
        let draw = request_override_context.percentage_value;

        let active_flags = overrides
            .flags
            .iter()
            .map(|flag| {
                let enabled = request_override_context.active_flags.contains(flag);
                (flag.clone(), enabled)
            })
            .collect();

        let percentage_outcomes = overrides
            .percentages
            .iter()
            .map(|&threshold| (threshold, draw < threshold))
            .collect();

        Self {
            active_flags,
            percentage_outcomes,
        }
    }
}
/// Decides which progressive-override labels are active for a request.
///
/// Labels configured with a literal `true` are always active. Labels configured with an
/// expression are active only when their compiled VRL program evaluates to `true`.
pub struct OverrideLabelsEvaluator {
    /// Labels configured with a literal `true`.
    static_enabled_labels: HashSet<String>,
    /// Compiled VRL programs, keyed by label name.
    expressions: HashMap<String, VrlProgram>,
}

impl OverrideLabelsEvaluator {
    /// Builds an evaluator from the router's override-label configuration.
    ///
    /// Each entry is handled according to its value:
    /// - `Boolean(true)`: the label is recorded as statically enabled.
    /// - `Expression { .. }`: the expression is compiled up front.
    /// - Any other value: contributes nothing.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideLabelsCompileError`] for the first expression that fails to compile.
    pub(crate) fn from_config(
        override_labels_config: &OverrideLabelsConfig,
    ) -> Result<Self, OverrideLabelsCompileError> {
        let mut evaluator = Self {
            static_enabled_labels: HashSet::new(),
            expressions: HashMap::new(),
        };

        for (label, value) in override_labels_config.iter() {
            match value {
                LabelOverrideValue::Boolean(true) => {
                    evaluator.static_enabled_labels.insert(label.clone());
                }
                LabelOverrideValue::Expression { expression } => {
                    let compiled = expression.compile_expression(None).map_err(|err| {
                        OverrideLabelsCompileError {
                            label: label.clone(),
                            error: err.to_string(),
                        }
                    })?;
                    evaluator.expressions.insert(label.clone(), compiled);
                }
                // Anything else (e.g. a literal `false`) never enables the label.
                _ => {}
            }
        }

        Ok(evaluator)
    }

    /// Computes the set of active override labels for one request.
    ///
    /// The result starts as `resolved_labels` (if given) plus the statically enabled labels.
    /// Every expression-backed label not already in that set is then evaluated against a
    /// VRL target of the form `{ "request": <client request> }`. A `true` result enables the
    /// label; `false` leaves it off.
    ///
    /// # Errors
    ///
    /// Stops at the first failing expression and returns one of:
    /// - [`LabelEvaluationError::ExpressionResolutionFailure`] for a runtime error;
    /// - [`LabelEvaluationError::ExpressionWrongType`] for a non-boolean result.
    pub(crate) fn evaluate(
        &self,
        resolved_labels: Option<&HashSet<String>>,
        client_request: &impl ClientRequestDetailsView,
    ) -> Result<HashSet<String>, LabelEvaluationError> {
        let mut enabled: HashSet<String> = if let Some(resolved) = resolved_labels {
            resolved
                .union(&self.static_enabled_labels)
                .cloned()
                .collect()
        } else {
            self.static_enabled_labels.clone()
        };

        // With no expression-backed labels there is nothing to run, so skip building the
        // VRL runtime.
        if self.expressions.is_empty() {
            return Ok(enabled);
        }

        let request_value = client_request.to_vrl_value();
        let mut vrl_target = VrlTargetValue {
            value: VrlValue::Object(BTreeMap::from([("request".into(), request_value)])),
            metadata: VrlValue::Object(BTreeMap::new()),
            secrets: VrlSecrets::default(),
        };
        let mut runtime_state = VrlState::default();
        let tz = VrlTimeZone::default();
        // NOTE(review): every expression below shares one target and one runtime state, and
        // iteration follows HashMap order. An expression that mutates `.request` or leaves
        // variables behind could therefore affect later ones; confirm expressions are read-only.
        let mut vrl_ctx = VrlContext::new(&mut vrl_target, &mut runtime_state, &tz);

        for (label, program) in &self.expressions {
            if enabled.contains(label) {
                // Already active (statically or from the request context); no need to run it.
                continue;
            }

            let outcome = program.resolve(&mut vrl_ctx).map_err(|source| {
                LabelEvaluationError::ExpressionResolutionFailure {
                    label: label.clone(),
                    source,
                }
            })?;

            match outcome {
                VrlValue::Boolean(true) => {
                    enabled.insert(label.clone());
                }
                VrlValue::Boolean(false) => {}
                other => {
                    return Err(LabelEvaluationError::ExpressionWrongType {
                        label: label.clone(),
                        got: format!("{other:?}"),
                    });
                }
            }
        }

        Ok(enabled)
    }
}