use graphql_tools::validation::validate::ValidationPlan;
use hive_console_sdk::agent::usage_agent::{AgentError, UsageAgent};
use hive_router_config::HiveRouterConfig;
use hive_router_internal::expressions::values::boolean::BooleanOrProgram;
use hive_router_internal::expressions::ExpressionCompileError;
use hive_router_internal::telemetry::TelemetryContext;
use hive_router_plan_executor::headers::{
compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan,
};
use hive_router_plan_executor::plugin_trait::RouterPluginBoxed;
use moka::future::Cache;
use moka::Expiry;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::cache_state::CacheState;
use crate::jwt::context::JwtTokenPayload;
use crate::jwt::JwtAuthRuntime;
use crate::pipeline::cors::{CORSConfigError, Cors};
use crate::pipeline::introspection_policy::compile_introspection_policy;
use crate::pipeline::parser::ParseCacheEntry;
use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator};
/// Async (`moka::future`) cache mapping a `String` key to a shared `JwtTokenPayload`.
///
/// NOTE(review): the key is presumably the raw token string; confirm against the code
/// that inserts into this cache.
pub type JwtClaimsCache = Cache<String, Arc<JwtTokenPayload>>;
/// Upper bound, in seconds, on how long `JwtClaimsExpiry` lets a cached payload live.
/// It is also the lifetime used when the payload has no `exp` claim or the system clock
/// cannot be read relative to the Unix epoch.
const DEFAULT_JWT_CACHE_TTL_SECS: u64 = 5;
/// Stateless per-entry expiry policy installed on the JWT claims cache by
/// `RouterSharedState::new`.
struct JwtClaimsExpiry;
impl Expiry<String, Arc<JwtTokenPayload>> for JwtClaimsExpiry {
    /// Chooses the lifetime of a newly inserted claims entry.
    ///
    /// The result is always `Some(..)` and never exceeds
    /// `DEFAULT_JWT_CACHE_TTL_SECS` seconds:
    ///
    /// * no `exp` claim, or a system clock that reads before the Unix epoch:
    ///   the full default TTL;
    /// * `exp` (compared as whole seconds since the Unix epoch) at or before
    ///   the current time: `Duration::ZERO`;
    /// * otherwise: the smaller of the default TTL and the time left until `exp`.
    ///
    /// Only this hook is overridden; the remaining `Expiry` methods keep the
    /// trait's default behavior.
    fn expire_after_create(
        &self,
        _key: &String,
        value: &Arc<JwtTokenPayload>,
        _created_at: std::time::Instant,
    ) -> Option<Duration> {
        let ttl_cap = Duration::from_secs(DEFAULT_JWT_CACHE_TTL_SECS);

        // Without an `exp` claim the cap is the only bound available.
        let Some(expires_at) = value.claims.exp else {
            return Some(ttl_cap);
        };

        // `duration_since` fails only when the clock is earlier than the epoch;
        // fall back to the cap rather than guessing at the remaining lifetime.
        let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) else {
            return Some(ttl_cap);
        };
        let now_secs = since_epoch.as_secs();

        let ttl = if expires_at <= now_secs {
            // Token is already past its expiry.
            Duration::ZERO
        } else {
            // The guard above guarantees the subtraction cannot underflow.
            ttl_cap.min(Duration::from_secs(expires_at - now_secs))
        };
        Some(ttl)
    }
}
/// State assembled once by [`RouterSharedState::new`] from the router configuration and
/// pre-built runtime components.
///
/// NOTE(review): presumably shared across request handlers for the lifetime of the
/// router; confirm with the code that owns this value.
pub struct RouterSharedState {
/// GraphQL validation plan passed to `new`, wrapped in an `Arc` there.
pub validation_plan: Arc<ValidationPlan>,
/// Cache of `ParseCacheEntry` values keyed by `u64`; `new` clones the handle from
/// `CacheState::parse_cache`.
/// NOTE(review): the `u64` key is presumably a hash of the operation text; confirm.
pub parse_cache: Cache<u64, ParseCacheEntry>,
/// The router configuration the other compiled fields were derived from.
pub router_config: Arc<HiveRouterConfig>,
/// Header rules plan compiled from `router_config.headers`.
pub headers_plan: HeaderRulesPlan,
/// Evaluator built from `router_config.override_labels`.
pub override_labels_evaluator: OverrideLabelsEvaluator,
/// Result of `Cors::from_config(&router_config.cors)`.
/// NOTE(review): `None` presumably means CORS handling is not configured; confirm.
pub cors_runtime: Option<Cors>,
/// JWT payload cache, created in `new` with a maximum capacity of 10_000 entries and
/// the `JwtClaimsExpiry` per-entry expiry policy.
pub jwt_claims_cache: JwtClaimsCache,
/// Optional JWT authentication runtime, supplied by the caller of `new`.
pub jwt_auth_runtime: Option<JwtAuthRuntime>,
/// Optional Hive usage agent, supplied by the caller of `new`.
pub hive_usage_agent: Option<UsageAgent>,
/// Introspection policy compiled from `router_config.introspection`.
pub introspection_policy: BooleanOrProgram,
/// Telemetry context supplied by the caller of `new`.
pub telemetry_context: Arc<TelemetryContext>,
/// Optional shared list of router plugins, supplied by the caller of `new`.
pub plugins: Option<Arc<Vec<RouterPluginBoxed>>>,
}
impl RouterSharedState {
pub fn new(
router_config: Arc<HiveRouterConfig>,
jwt_auth_runtime: Option<JwtAuthRuntime>,
hive_usage_agent: Option<UsageAgent>,
validation_plan: ValidationPlan,
telemetry_context: Arc<TelemetryContext>,
plugins: Option<Arc<Vec<RouterPluginBoxed>>>,
cache_state: Arc<CacheState>,
) -> Result<Self, SharedStateError> {
let parse_cache = cache_state.parse_cache.clone();
Ok(Self {
validation_plan: Arc::new(validation_plan),
headers_plan: compile_headers_plan(&router_config.headers).map_err(Box::new)?,
parse_cache,
cors_runtime: Cors::from_config(&router_config.cors).map_err(Box::new)?,
jwt_claims_cache: Cache::builder()
.max_capacity(10_000)
.expire_after(JwtClaimsExpiry)
.build(),
router_config: router_config.clone(),
override_labels_evaluator: OverrideLabelsEvaluator::from_config(
&router_config.override_labels,
)
.map_err(Box::new)?,
jwt_auth_runtime,
hive_usage_agent,
introspection_policy: compile_introspection_policy(&router_config.introspection)
.map_err(Box::new)?,
telemetry_context,
plugins,
})
}
}
/// Errors returned while building [`RouterSharedState`].
///
/// Every variant wraps its source error in a `Box` and derives `From<Box<_>>`, which is
/// what lets `RouterSharedState::new` write `.map_err(Box::new)?`.
/// NOTE(review): the boxing is presumably there to keep this enum (and every `Result`
/// carrying it) small; confirm the intent.
#[derive(thiserror::Error, Debug)]
pub enum SharedStateError {
/// The header rules configuration could not be compiled.
#[error("invalid headers config: {0}")]
HeaderRuleCompile(#[from] Box<HeaderRuleCompileError>),
/// The CORS configuration was rejected.
#[error("invalid regex in CORS config: {0}")]
CORSConfig(#[from] Box<CORSConfigError>),
/// The override labels configuration could not be compiled.
#[error("invalid override labels config: {0}")]
OverrideLabelsCompile(#[from] Box<OverrideLabelsCompileError>),
/// Creating the Hive usage agent failed.
/// NOTE(review): `RouterSharedState::new` receives an already-built agent and does not
/// produce this variant itself; it is presumably raised by callers elsewhere.
#[error("error creating hive usage agent: {0}")]
UsageAgent(#[from] Box<AgentError>),
/// The introspection policy expression could not be compiled.
#[error("invalid introspection config: {0}")]
IntrospectionPolicyCompile(#[from] Box<ExpressionCompileError>),
}