use crate::feature_flags::{
match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
FeatureFlag, FlagValue, InconclusiveMatchError,
};
use crate::Error;
use reqwest::header::{HeaderMap, ETAG, IF_NONE_MATCH};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{debug, error, info, instrument, trace, warn};
/// Returns the `ETag` response header as an owned string.
///
/// Yields `None` when the header is absent, cannot be read as a string (`to_str` fails), or is
/// empty, so callers never send an empty `If-None-Match`.
fn extract_etag(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(ETAG)?;
    let text = raw.to_str().ok()?;
    if text.is_empty() {
        None
    } else {
        Some(text.to_owned())
    }
}
/// Body of the `/flags/definitions/?send_cohorts` response that the pollers in this module
/// parse and hand to [`FlagCache::update`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalEvaluationResponse {
/// Flag definitions; the cache indexes them by `FeatureFlag::key`.
pub flags: Vec<FeatureFlag>,
/// Group type mapping made available to flag evaluation. Empty when the field is absent.
/// NOTE(review): presumably group type index -> group type name; confirm against the API.
#[serde(default)]
pub group_type_mapping: HashMap<String, String>,
/// Cohorts keyed by a string id. Empty when the field is absent.
#[serde(default)]
pub cohorts: HashMap<String, Cohort>,
}
/// A cohort delivered alongside the flag definitions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cohort {
/// Cohort identifier; copied into `CohortDefinition::id` by [`FlagCache::get_cohort_definitions`].
pub id: String,
/// Cohort name; not passed on to the evaluator by this module.
pub name: String,
/// Raw cohort property filters, kept as untyped JSON and copied into
/// `CohortDefinition::properties`. Its schema is not validated here.
pub properties: serde_json::Value,
}
/// Shared store of the most recently fetched flag definitions, group type mapping and cohorts.
///
/// Each map lives behind its own `Arc<RwLock<..>>`, so cloning a `FlagCache` produces another
/// handle to the same underlying data rather than a copy; this is how pollers and evaluators
/// share one cache.
#[derive(Clone)]
pub struct FlagCache {
// Flag definitions keyed by flag key.
flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
// Group type mapping from the latest response.
group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
// Cohorts keyed by the string id used in the latest response.
cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
}
impl Default for FlagCache {
fn default() -> Self {
Self::new()
}
}
impl FlagCache {
pub fn new() -> Self {
Self {
flags: Arc::new(RwLock::new(HashMap::new())),
group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
cohorts: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn update(&self, response: LocalEvaluationResponse) {
let flag_count = response.flags.len();
let mut flags = self.flags.write().unwrap();
flags.clear();
for flag in response.flags {
flags.insert(flag.key.clone(), flag);
}
let mut mapping = self.group_type_mapping.write().unwrap();
*mapping = response.group_type_mapping;
let mut cohorts = self.cohorts.write().unwrap();
*cohorts = response.cohorts;
debug!(flag_count, "Updated flag cache");
}
pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
self.flags.read().unwrap().get(key).cloned()
}
pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
self.flags.read().unwrap().values().cloned().collect()
}
pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
self.cohorts.read().unwrap().get(id).cloned()
}
pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
self.cohorts.read().unwrap().clone()
}
pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
self.cohorts
.read()
.unwrap()
.iter()
.map(|(k, v)| {
(
k.clone(),
CohortDefinition {
id: v.id.clone(),
properties: v.properties.clone(),
},
)
})
.collect()
}
pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
self.flags.read().unwrap().clone()
}
pub fn get_group_type_mapping(&self) -> HashMap<String, String> {
self.group_type_mapping.read().unwrap().clone()
}
pub fn clear(&self) {
self.flags.write().unwrap().clear();
self.group_type_mapping.write().unwrap().clear();
self.cohorts.write().unwrap().clear();
}
}
/// Connection and polling settings shared by [`FlagPoller`] and the async poller.
///
/// NOTE(review): this type deliberately(?) has no `Debug` derive; if one is ever added, redact
/// the API keys so they do not end up in logs.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
/// Personal API key, sent as `Authorization: Bearer <key>` when fetching definitions.
pub personal_api_key: String,
/// Project API key, sent in the `X-PostHog-Project-Api-Key` header.
pub project_api_key: String,
/// Base URL of the API; trailing `/` characters are trimmed before the definitions path is appended.
pub api_host: String,
/// Delay between background fetches of the flag definitions.
pub poll_interval: Duration,
/// Timeout configured on the HTTP client used for every definitions request.
pub request_timeout: Duration,
}
/// Blocking poller that keeps a [`FlagCache`] up to date from a background OS thread.
pub struct FlagPoller {
// Connection and polling settings.
config: LocalEvaluationConfig,
// Cache that fetched definitions are written into.
cache: FlagCache,
// Blocking HTTP client built with `config.request_timeout`.
client: reqwest::blocking::Client,
// Set by `stop()` to ask the background thread to exit.
stop_signal: Arc<AtomicBool>,
// Background thread handle; `None` before `start()` and after `stop()`.
thread_handle: Option<std::thread::JoinHandle<()>>,
}
impl FlagPoller {
pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
let client = reqwest::blocking::Client::builder()
.timeout(config.request_timeout)
.build()
.unwrap();
Self {
config,
cache,
client,
stop_signal: Arc::new(AtomicBool::new(false)),
thread_handle: None,
}
}
pub fn start(&mut self) {
info!(
poll_interval_secs = self.config.poll_interval.as_secs(),
"Starting feature flag poller"
);
match self.load_flags() {
Ok(()) => info!("Initial flag definitions loaded successfully"),
Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
}
let config = self.config.clone();
let cache = self.cache.clone();
let stop_signal = self.stop_signal.clone();
let handle = std::thread::spawn(move || {
let client = reqwest::blocking::Client::builder()
.timeout(config.request_timeout)
.build()
.unwrap();
let mut last_etag: Option<String> = None;
loop {
std::thread::sleep(config.poll_interval);
if stop_signal.load(Ordering::Relaxed) {
debug!("Flag poller received stop signal");
break;
}
let url = format!(
"{}/flags/definitions/?send_cohorts",
config.api_host.trim_end_matches('/')
);
let mut request = client
.get(&url)
.header(
"Authorization",
format!("Bearer {}", config.personal_api_key),
)
.header("X-PostHog-Project-Api-Key", &config.project_api_key);
if let Some(ref etag) = last_etag {
request = request.header(IF_NONE_MATCH, etag.as_str());
}
match request.send() {
Ok(response) => {
if response.status() == StatusCode::NOT_MODIFIED {
debug!("Flag definitions unchanged (304 Not Modified)");
} else if response.status().is_success() {
let new_etag = extract_etag(response.headers());
match response.json::<LocalEvaluationResponse>() {
Ok(data) => {
trace!("Successfully fetched flag definitions");
cache.update(data);
last_etag = new_etag;
}
Err(e) => {
warn!(error = %e, "Failed to parse flag response");
}
}
} else {
warn!(status = %response.status(), "Failed to fetch flags");
}
}
Err(e) => {
warn!(error = %e, "Failed to fetch flags");
}
}
}
});
self.thread_handle = Some(handle);
}
#[instrument(skip(self), level = "debug")]
pub fn load_flags(&self) -> Result<(), Error> {
let url = format!(
"{}/flags/definitions/?send_cohorts",
self.config.api_host.trim_end_matches('/')
);
let response = self
.client
.get(&url)
.header(
"Authorization",
format!("Bearer {}", self.config.personal_api_key),
)
.header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
.send()
.map_err(|e| {
error!(error = %e, "Connection error loading flags");
Error::Connection(e.to_string())
})?;
if !response.status().is_success() {
let status = response.status();
error!(status = %status, "HTTP error loading flags");
return Err(Error::Connection(format!("HTTP {}", status)));
}
let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
error!(error = %e, "Failed to parse flag response");
Error::Serialization(e.to_string())
})?;
self.cache.update(data);
Ok(())
}
pub fn stop(&mut self) {
debug!("Stopping flag poller");
self.stop_signal.store(true, Ordering::Relaxed);
if let Some(handle) = self.thread_handle.take() {
handle.join().ok();
}
}
}
impl Drop for FlagPoller {
fn drop(&mut self) {
self.stop();
}
}
#[cfg(feature = "async-client")]
/// Async counterpart of [`FlagPoller`]: keeps a [`FlagCache`] up to date from a spawned tokio task.
pub struct AsyncFlagPoller {
// Connection and polling settings.
config: LocalEvaluationConfig,
// Cache that fetched definitions are written into.
cache: FlagCache,
// Async HTTP client built with `config.request_timeout`; cloned into the polling task.
client: reqwest::Client,
// Set by `stop()` so the polling task exits at its next tick.
stop_signal: Arc<AtomicBool>,
// Handle of the spawned polling task; aborted by `stop()` and on drop.
task_handle: Option<tokio::task::JoinHandle<()>>,
// Whether a polling task is active; `start()` consults it to avoid starting twice.
is_running: Arc<tokio::sync::RwLock<bool>>,
}
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates a poller that fetches definitions into `cache` using `config`.
    ///
    /// No request is made and no task is spawned until [`AsyncFlagPoller::start`] is awaited.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build async HTTP client for flag poller");
        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Loads the flag definitions once, then spawns a tokio task that re-fetches them every
    /// `poll_interval`.
    ///
    /// Does nothing if a polling task is already running. After [`AsyncFlagPoller::stop`] it may
    /// be called again to resume polling.
    pub async fn start(&mut self) {
        {
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }
        // A previous `stop()` leaves the signal set; clear it or the new task would exit on its
        // first tick while `is_running` still reported `true`.
        self.stop_signal.store(false, Ordering::SeqCst);

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );
        // Seed the poll loop with the initial response's ETag so the first poll can be a 304.
        let initial_etag = match self.fetch_and_cache().await {
            Ok(etag) => {
                info!("Initial flag definitions loaded successfully");
                etag
            }
            Err(e) => {
                warn!(error = %e, "Failed to load initial flags, will retry on next poll");
                None
            }
        };

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = Arc::clone(&self.stop_signal);
        let is_running = Arc::clone(&self.is_running);
        let client = self.client.clone();
        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(config.poll_interval);
            // If a fetch overruns the interval, wait a full period afterwards instead of firing
            // the missed ticks back-to-back.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately; skip it because the initial load just ran.
            interval.tick().await;
            let mut last_etag = initial_etag;
            loop {
                interval.tick().await;
                if stop_signal.load(Ordering::SeqCst) {
                    debug!("Async flag poller received stop signal");
                    break;
                }
                Self::poll_once(&client, &config, &cache, &mut last_etag).await;
            }
            *is_running.write().await = false;
        });
        self.task_handle = Some(task);
    }

    /// Fetches the flag definitions once and stores them in the cache.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Connection`] if the request fails or the server answers with a
    /// non-success status, and [`Error::Serialization`] if the body cannot be parsed.
    pub async fn load_flags(&self) -> Result<(), Error> {
        self.fetch_and_cache().await.map(|_| ())
    }

    /// Signals the polling task to stop, aborts it, and marks the poller as not running.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::SeqCst);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
        }
        *self.is_running.write().await = false;
    }

    /// Reports whether a polling task is currently marked as running.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }

    /// Fetches the definitions, updates the cache, and returns the response's `ETag` (if any).
    ///
    /// Every failure is logged at error level before being returned.
    #[instrument(name = "load_flags", skip(self), level = "debug")]
    async fn fetch_and_cache(&self) -> Result<Option<String>, Error> {
        let response = Self::definitions_request(&self.client, &self.config)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;
        let status = response.status();
        if !status.is_success() {
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }
        let etag = extract_etag(response.headers());
        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;
        self.cache.update(data);
        Ok(etag)
    }

    /// Performs one conditional poll.
    ///
    /// Sends `If-None-Match` when `last_etag` is set. On a successful, parsable response the
    /// cache and `last_etag` are updated. A 304 leaves both untouched. Failures are logged and
    /// otherwise ignored so the loop keeps running.
    async fn poll_once(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
        cache: &FlagCache,
        last_etag: &mut Option<String>,
    ) {
        let mut request = Self::definitions_request(client, config);
        if let Some(etag) = last_etag.as_deref() {
            request = request.header(IF_NONE_MATCH, etag);
        }
        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
                return;
            }
        };
        let status = response.status();
        if status == StatusCode::NOT_MODIFIED {
            debug!("Flag definitions unchanged (304 Not Modified)");
            return;
        }
        if !status.is_success() {
            warn!(status = %status, "Failed to fetch flags");
            return;
        }
        let new_etag = extract_etag(response.headers());
        match response.json::<LocalEvaluationResponse>().await {
            Ok(data) => {
                trace!("Successfully fetched flag definitions");
                cache.update(data);
                *last_etag = new_etag;
            }
            Err(e) => {
                warn!(error = %e, "Failed to parse flag response");
            }
        }
    }

    /// Builds the flag-definitions endpoint URL from the configured host.
    fn definitions_url(config: &LocalEvaluationConfig) -> String {
        format!(
            "{}/flags/definitions/?send_cohorts",
            config.api_host.trim_end_matches('/')
        )
    }

    /// Builds an authenticated GET request for the flag definitions.
    fn definitions_request(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
    ) -> reqwest::RequestBuilder {
        client
            .get(Self::definitions_url(config))
            .header(
                "Authorization",
                format!("Bearer {}", config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &config.project_api_key)
    }
}
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Aborts the spawned polling task, if any, when the poller is dropped.
    fn drop(&mut self) {
        // tokio detaches a task whose JoinHandle is dropped, so it must be aborted explicitly.
        if let Some(task) = &self.task_handle {
            task.abort();
        }
        self.task_handle = None;
    }
}
/// Evaluates feature flags against the definitions currently held in a [`FlagCache`]; the
/// evaluation methods only read the cache and make no network requests.
#[derive(Clone)]
pub struct LocalEvaluator {
// Cache handle the evaluator reads from (shares data with any other clone of it).
cache: FlagCache,
}
impl LocalEvaluator {
    /// Creates an evaluator reading from `cache`.
    pub fn new(cache: FlagCache) -> Self {
        Self { cache }
    }

    /// Returns the cache this evaluator reads from.
    pub fn cache(&self) -> &FlagCache {
        &self.cache
    }

    /// Evaluates a single flag with full context (cohorts and other flags available to the
    /// matcher).
    ///
    /// Returns `Ok(None)` when `key` is not in the cache, `Ok(Some(value))` on a conclusive
    /// match result, and `Err` when the matcher cannot decide locally.
    #[instrument(
        skip(self, person_properties, groups, group_properties),
        level = "trace"
    )]
    pub fn evaluate_flag(
        &self,
        key: &str,
        distinct_id: &str,
        person_properties: &HashMap<String, serde_json::Value>,
        groups: &HashMap<String, String>,
        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
        match self.cache.get_flag(key) {
            Some(flag) => {
                // Snapshot the rest of the context only for flags that exist, keeping misses cheap.
                let cohorts = self.cache.get_cohort_definitions();
                let flags = self.cache.get_flags_map();
                let group_type_mapping = self.cache.get_group_type_mapping();
                let ctx = EvaluationContext {
                    cohorts: &cohorts,
                    flags: &flags,
                    distinct_id,
                    groups,
                    group_properties,
                    group_type_mapping: &group_type_mapping,
                };
                let result = match_feature_flag_with_context(&flag, person_properties, &ctx);
                trace!(key, ?result, "Local flag evaluation");
                result.map(Some)
            }
            None => {
                trace!(key, "Flag not found in local cache");
                Ok(None)
            }
        }
    }

    /// Evaluates a single flag with `match_feature_flag`, i.e. without the cohort/flag context
    /// that [`LocalEvaluator::evaluate_flag`] supplies.
    ///
    /// Returns `Ok(None)` when `key` is not in the cache.
    #[instrument(
        skip(self, person_properties, groups, group_properties),
        level = "trace"
    )]
    pub fn evaluate_flag_simple(
        &self,
        key: &str,
        distinct_id: &str,
        person_properties: &HashMap<String, serde_json::Value>,
        groups: &HashMap<String, String>,
        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
        match self.cache.get_flag(key) {
            Some(flag) => {
                let group_type_mapping = self.cache.get_group_type_mapping();
                let result = match_feature_flag(
                    &flag,
                    distinct_id,
                    person_properties,
                    groups,
                    group_properties,
                    &group_type_mapping,
                );
                trace!(key, ?result, "Local flag evaluation (simple)");
                result.map(Some)
            }
            None => {
                trace!(key, "Flag not found in local cache");
                Ok(None)
            }
        }
    }

    /// Evaluates every cached flag with full context.
    ///
    /// Returns one entry per flag, keyed by flag key, holding either the flag's value or the
    /// reason it could not be decided locally.
    #[instrument(
        skip(self, person_properties, groups, group_properties),
        level = "debug"
    )]
    pub fn evaluate_all_flags(
        &self,
        distinct_id: &str,
        person_properties: &HashMap<String, serde_json::Value>,
        groups: &HashMap<String, String>,
        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
    ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
        let cohorts = self.cache.get_cohort_definitions();
        // One snapshot is both the set of flags evaluated and the dependency table in `ctx`,
        // so the two always agree and the flags are cloned only once.
        let flags = self.cache.get_flags_map();
        let group_type_mapping = self.cache.get_group_type_mapping();
        let ctx = EvaluationContext {
            cohorts: &cohorts,
            flags: &flags,
            distinct_id,
            groups,
            group_properties,
            group_type_mapping: &group_type_mapping,
        };
        let results: HashMap<String, Result<FlagValue, InconclusiveMatchError>> = flags
            .values()
            .map(|flag| {
                (
                    flag.key.clone(),
                    match_feature_flag_with_context(flag, person_properties, &ctx),
                )
            })
            .collect();
        debug!(flag_count = results.len(), "Evaluated all local flags");
        results
    }
}