// posthog_rs/local_evaluation.rs

1use crate::feature_flags::{
2    match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3    FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use reqwest::header::{HeaderMap, ETAG, IF_NONE_MATCH};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13use tracing::{debug, error, info, instrument, trace, warn};
14
15/// Extract the ETag header value from a response's headers.
16/// Returns None if the header is missing, invalid UTF-8, or empty.
17fn extract_etag(headers: &HeaderMap) -> Option<String> {
18    headers
19        .get(ETAG)
20        .and_then(|v| v.to_str().ok())
21        .filter(|s| !s.is_empty())
22        .map(|s| s.to_string())
23}
24
25/// Response from the PostHog local evaluation API.
26///
27/// Contains feature flag definitions, group type mappings, and cohort definitions
28/// that can be cached locally for flag evaluation without server round-trips.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct LocalEvaluationResponse {
31    /// List of feature flag definitions
32    pub flags: Vec<FeatureFlag>,
33    /// Mapping from group type keys to their display names
34    #[serde(default)]
35    pub group_type_mapping: HashMap<String, String>,
36    /// Cohort definitions for evaluating cohort membership
37    #[serde(default)]
38    pub cohorts: HashMap<String, Cohort>,
39}
40
41/// A cohort definition for local evaluation.
42///
43/// Cohorts are groups of users defined by property filters, used for
44/// targeting feature flags to specific user segments.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Cohort {
47    /// Unique identifier for this cohort
48    pub id: String,
49    /// Human-readable name of the cohort
50    pub name: String,
51    /// Property filters that define cohort membership
52    pub properties: serde_json::Value,
53}
54
55/// Thread-safe cache for feature flag definitions.
56///
57/// Stores feature flags, group type mappings, and cohort definitions that have
58/// been fetched from the PostHog API. The cache is shared between the poller
59/// (which updates it) and the evaluator (which reads from it).
60#[derive(Clone)]
61pub struct FlagCache {
62    flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
63    group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
64    cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
65}
66
67impl Default for FlagCache {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl FlagCache {
74    pub fn new() -> Self {
75        Self {
76            flags: Arc::new(RwLock::new(HashMap::new())),
77            group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
78            cohorts: Arc::new(RwLock::new(HashMap::new())),
79        }
80    }
81
82    pub fn update(&self, response: LocalEvaluationResponse) {
83        let flag_count = response.flags.len();
84        let mut flags = self.flags.write().unwrap();
85        flags.clear();
86        for flag in response.flags {
87            flags.insert(flag.key.clone(), flag);
88        }
89
90        let mut mapping = self.group_type_mapping.write().unwrap();
91        *mapping = response.group_type_mapping;
92
93        let mut cohorts = self.cohorts.write().unwrap();
94        *cohorts = response.cohorts;
95
96        debug!(flag_count, "Updated flag cache");
97    }
98
99    pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
100        self.flags.read().unwrap().get(key).cloned()
101    }
102
103    pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
104        self.flags.read().unwrap().values().cloned().collect()
105    }
106
107    pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
108        self.cohorts.read().unwrap().get(id).cloned()
109    }
110
111    pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
112        self.cohorts.read().unwrap().clone()
113    }
114
115    /// Get all cohorts as CohortDefinitions for evaluation context
116    pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
117        self.cohorts
118            .read()
119            .unwrap()
120            .iter()
121            .map(|(k, v)| {
122                (
123                    k.clone(),
124                    CohortDefinition {
125                        id: v.id.clone(),
126                        properties: v.properties.clone(),
127                    },
128                )
129            })
130            .collect()
131    }
132
133    /// Get all flags as a HashMap for evaluation context
134    pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
135        self.flags.read().unwrap().clone()
136    }
137
138    /// Get the group type mapping (group type index → group type name).
139    pub fn get_group_type_mapping(&self) -> HashMap<String, String> {
140        self.group_type_mapping.read().unwrap().clone()
141    }
142
143    pub fn clear(&self) {
144        self.flags.write().unwrap().clear();
145        self.group_type_mapping.write().unwrap().clear();
146        self.cohorts.write().unwrap().clear();
147    }
148}
149
/// Settings for fetching flag definitions used in local evaluation.
///
/// Carries the credentials, endpoint, and timing parameters that the pollers
/// use when calling the PostHog flag-definitions API.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Personal API key, sent as a `Bearer` token in the `Authorization` header.
    pub personal_api_key: String,
    /// Project API key, sent in the `X-PostHog-Project-Api-Key` header to select the project.
    pub project_api_key: String,
    /// Base URL of the PostHog API, e.g. "https://us.posthog.com"; a trailing `/` is tolerated.
    pub api_host: String,
    /// Delay between successive fetches of flag definitions.
    pub poll_interval: Duration,
    /// Timeout applied to each HTTP request.
    pub request_timeout: Duration,
}
167
168/// Synchronous poller for feature flag definitions.
169///
170/// Runs a background thread that periodically fetches flag definitions from
171/// the PostHog API and updates the shared cache. Use this for blocking/sync
172/// applications. For async applications, use [`AsyncFlagPoller`] instead.
173pub struct FlagPoller {
174    config: LocalEvaluationConfig,
175    cache: FlagCache,
176    client: reqwest::blocking::Client,
177    stop_signal: Arc<AtomicBool>,
178    thread_handle: Option<std::thread::JoinHandle<()>>,
179}
180
181impl FlagPoller {
182    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
183        let client = reqwest::blocking::Client::builder()
184            .timeout(config.request_timeout)
185            .build()
186            .unwrap();
187
188        Self {
189            config,
190            cache,
191            client,
192            stop_signal: Arc::new(AtomicBool::new(false)),
193            thread_handle: None,
194        }
195    }
196
197    /// Start the polling thread
198    pub fn start(&mut self) {
199        info!(
200            poll_interval_secs = self.config.poll_interval.as_secs(),
201            "Starting feature flag poller"
202        );
203
204        // Initial load
205        match self.load_flags() {
206            Ok(()) => info!("Initial flag definitions loaded successfully"),
207            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
208        }
209
210        let config = self.config.clone();
211        let cache = self.cache.clone();
212        let stop_signal = self.stop_signal.clone();
213
214        let handle = std::thread::spawn(move || {
215            let client = reqwest::blocking::Client::builder()
216                .timeout(config.request_timeout)
217                .build()
218                .unwrap();
219
220            let mut last_etag: Option<String> = None;
221
222            loop {
223                std::thread::sleep(config.poll_interval);
224
225                if stop_signal.load(Ordering::Relaxed) {
226                    debug!("Flag poller received stop signal");
227                    break;
228                }
229
230                let url = format!(
231                    "{}/flags/definitions/?send_cohorts",
232                    config.api_host.trim_end_matches('/')
233                );
234
235                let mut request = client
236                    .get(&url)
237                    .header(
238                        "Authorization",
239                        format!("Bearer {}", config.personal_api_key),
240                    )
241                    .header("X-PostHog-Project-Api-Key", &config.project_api_key);
242
243                if let Some(ref etag) = last_etag {
244                    request = request.header(IF_NONE_MATCH, etag.as_str());
245                }
246
247                match request.send() {
248                    Ok(response) => {
249                        if response.status() == StatusCode::NOT_MODIFIED {
250                            debug!("Flag definitions unchanged (304 Not Modified)");
251                        } else if response.status().is_success() {
252                            // Extract ETag before consuming the response body
253                            let new_etag = extract_etag(response.headers());
254
255                            match response.json::<LocalEvaluationResponse>() {
256                                Ok(data) => {
257                                    trace!("Successfully fetched flag definitions");
258                                    cache.update(data);
259                                    last_etag = new_etag;
260                                }
261                                Err(e) => {
262                                    warn!(error = %e, "Failed to parse flag response");
263                                }
264                            }
265                        } else {
266                            warn!(status = %response.status(), "Failed to fetch flags");
267                        }
268                    }
269                    Err(e) => {
270                        warn!(error = %e, "Failed to fetch flags");
271                    }
272                }
273            }
274        });
275
276        self.thread_handle = Some(handle);
277    }
278
279    /// Load flags synchronously
280    #[instrument(skip(self), level = "debug")]
281    pub fn load_flags(&self) -> Result<(), Error> {
282        let url = format!(
283            "{}/flags/definitions/?send_cohorts",
284            self.config.api_host.trim_end_matches('/')
285        );
286
287        let response = self
288            .client
289            .get(&url)
290            .header(
291                "Authorization",
292                format!("Bearer {}", self.config.personal_api_key),
293            )
294            .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
295            .send()
296            .map_err(|e| {
297                error!(error = %e, "Connection error loading flags");
298                Error::Connection(e.to_string())
299            })?;
300
301        if !response.status().is_success() {
302            let status = response.status();
303            error!(status = %status, "HTTP error loading flags");
304            return Err(Error::Connection(format!("HTTP {}", status)));
305        }
306
307        let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
308            error!(error = %e, "Failed to parse flag response");
309            Error::Serialization(e.to_string())
310        })?;
311
312        self.cache.update(data);
313        Ok(())
314    }
315
316    /// Stop the polling thread
317    pub fn stop(&mut self) {
318        debug!("Stopping flag poller");
319        self.stop_signal.store(true, Ordering::Relaxed);
320        if let Some(handle) = self.thread_handle.take() {
321            handle.join().ok();
322        }
323    }
324}
325
326impl Drop for FlagPoller {
327    fn drop(&mut self) {
328        self.stop();
329    }
330}
331
/// Tokio-based poller that keeps a [`FlagCache`] up to date.
///
/// Once started, a spawned task periodically downloads flag definitions from
/// the PostHog API and writes them into the shared cache. Intended for async
/// code; synchronous applications should use [`FlagPoller`].
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    config: LocalEvaluationConfig,
    cache: FlagCache,
    client: reqwest::Client,
    // Set to `true` to ask the polling task to exit.
    stop_signal: Arc<AtomicBool>,
    // Handle of the spawned polling task, if one is owned.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    // Whether a polling task is currently considered active.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
346
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates a poller that writes fetched definitions into `cache`.
    ///
    /// The async HTTP client is built once here with `config.request_timeout`
    /// and is cloned into the polling task.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be constructed.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build HTTP client for the async feature flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Loads flag definitions once, then spawns a task that re-fetches them
    /// every `poll_interval`.
    ///
    /// Does nothing if the poller is already running. A poller that was
    /// stopped with [`AsyncFlagPoller::stop`] can be started again.
    pub async fn start(&mut self) {
        // Check-and-set under one write lock so start is not entered twice.
        {
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        // Clear a stop request left over from an earlier `stop()`; otherwise the
        // new task would exit on its first tick.
        self.stop_signal.store(false, Ordering::SeqCst);

        // Keep the initial response's ETag so the first poll can be conditional.
        let initial_etag = match self.fetch_and_store().await {
            Ok(etag) => {
                info!("Initial flag definitions loaded successfully");
                etag
            }
            Err(e) => {
                warn!(error = %e, "Failed to load initial flags, will retry on next poll");
                None
            }
        };

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = Arc::clone(&self.stop_signal);
        let is_running = Arc::clone(&self.is_running);
        let client = self.client.clone();

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(config.poll_interval);
            // If a poll overruns the period, wait a full period afterwards
            // rather than firing the missed ticks back-to-back.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval.tick().await; // The first tick completes immediately; skip it.

            let mut last_etag = initial_etag;
            loop {
                interval.tick().await;

                if stop_signal.load(Ordering::SeqCst) {
                    debug!("Async flag poller received stop signal");
                    break;
                }

                Self::poll_once(&client, &config, &cache, &mut last_etag).await;
            }

            // Clear running flag when task exits
            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Fetches flag definitions once and replaces the cache contents.
    ///
    /// # Errors
    ///
    /// Returns `Error::Connection` in two cases:
    /// - the request fails;
    /// - the server answers with a non-success status.
    ///
    /// Returns `Error::Serialization` if the body cannot be parsed. The cache
    /// is left untouched on error.
    pub async fn load_flags(&self) -> Result<(), Error> {
        self.fetch_and_store().await.map(|_etag| ())
    }

    /// Unconditionally fetches definitions and stores them in the cache.
    ///
    /// Returns the response's ETag, if any.
    #[instrument(name = "load_flags", skip(self), level = "debug")]
    async fn fetch_and_store(&self) -> Result<Option<String>, Error> {
        let response = Self::definitions_request(&self.client, &self.config, None)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        let status = response.status();
        if !status.is_success() {
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        // Read the ETag before `json()` consumes the response.
        let etag = extract_etag(response.headers());

        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(etag)
    }

    /// Performs one conditional poll.
    ///
    /// Updates `cache` and `last_etag` only when new definitions are parsed
    /// successfully. Failures are logged and otherwise ignored.
    async fn poll_once(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
        cache: &FlagCache,
        last_etag: &mut Option<String>,
    ) {
        let request = Self::definitions_request(client, config, last_etag.as_deref());
        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
                return;
            }
        };

        let status = response.status();
        if status == StatusCode::NOT_MODIFIED {
            debug!("Flag definitions unchanged (304 Not Modified)");
            return;
        }
        if !status.is_success() {
            warn!(status = %status, "Failed to fetch flags");
            return;
        }

        // Read the ETag before `json()` consumes the response.
        let new_etag = extract_etag(response.headers());
        match response.json::<LocalEvaluationResponse>().await {
            Ok(data) => {
                trace!("Successfully fetched flag definitions");
                cache.update(data);
                *last_etag = new_etag;
            }
            Err(e) => {
                warn!(error = %e, "Failed to parse flag response");
            }
        }
    }

    /// Builds the authenticated GET request for the flag-definitions endpoint.
    ///
    /// Adds `If-None-Match` when an ETag is supplied.
    fn definitions_request(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
        etag: Option<&str>,
    ) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/flags/definitions/?send_cohorts",
            config.api_host.trim_end_matches('/')
        );

        let request = client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &config.project_api_key);

        match etag {
            Some(etag) => request.header(IF_NONE_MATCH, etag),
            None => request,
        }
    }

    /// Cancels the polling task, waits for it to finish, and marks the poller
    /// as not running.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::SeqCst);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
            // Wait for the cancellation to complete so the old task cannot race
            // with a subsequent `start()`. A cancellation or panic error is expected here.
            let _ = handle.await;
        }
        *self.is_running.write().await = false;
    }

    /// Reports whether a polling task is currently considered active.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }
}
515
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Cancels the polling task, if one is still owned.
    ///
    /// A task cannot be awaited from synchronous `drop`, so it is only aborted.
    fn drop(&mut self) {
        match self.task_handle.take() {
            Some(task) => task.abort(),
            None => {}
        }
    }
}
525
526/// Evaluates feature flags using locally cached definitions.
527///
528/// The evaluator reads from a [`FlagCache`] to determine flag values without
529/// making network requests. Supports cohort membership checks and flag
530/// dependencies through the evaluation context.
531#[derive(Clone)]
532pub struct LocalEvaluator {
533    cache: FlagCache,
534}
535
536impl LocalEvaluator {
537    pub fn new(cache: FlagCache) -> Self {
538        Self { cache }
539    }
540
541    /// Access the underlying flag cache (e.g. to read group type mappings).
542    pub fn cache(&self) -> &FlagCache {
543        &self.cache
544    }
545
546    /// Evaluate a feature flag locally with full context support.
547    ///
548    /// Supports cohort membership checks, flag dependency evaluation, and
549    /// group / mixed-targeting flags. `groups` and `group_properties` are
550    /// only consulted when the flag (or one of its conditions) targets a
551    /// group via `aggregation_group_type_index`; pass empty maps for
552    /// person-targeted flags.
553    #[instrument(
554        skip(self, person_properties, groups, group_properties),
555        level = "trace"
556    )]
557    pub fn evaluate_flag(
558        &self,
559        key: &str,
560        distinct_id: &str,
561        person_properties: &HashMap<String, serde_json::Value>,
562        groups: &HashMap<String, String>,
563        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
564    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
565        match self.cache.get_flag(key) {
566            Some(flag) => {
567                // Build evaluation context with cohorts, flags, and group info
568                let cohorts = self.cache.get_cohort_definitions();
569                let flags = self.cache.get_flags_map();
570                let group_type_mapping = self.cache.get_group_type_mapping();
571
572                let ctx = EvaluationContext {
573                    cohorts: &cohorts,
574                    flags: &flags,
575                    distinct_id,
576                    groups,
577                    group_properties,
578                    group_type_mapping: &group_type_mapping,
579                };
580
581                let result = match_feature_flag_with_context(&flag, person_properties, &ctx);
582                trace!(key, ?result, "Local flag evaluation");
583                result.map(Some)
584            }
585            None => {
586                trace!(key, "Flag not found in local cache");
587                Ok(None)
588            }
589        }
590    }
591
592    /// Evaluate a feature flag locally (simple version without cohort/flag dependency support).
593    /// Use this when you know the flag doesn't have cohort or flag dependency conditions.
594    #[instrument(
595        skip(self, person_properties, groups, group_properties),
596        level = "trace"
597    )]
598    pub fn evaluate_flag_simple(
599        &self,
600        key: &str,
601        distinct_id: &str,
602        person_properties: &HashMap<String, serde_json::Value>,
603        groups: &HashMap<String, String>,
604        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
605    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
606        match self.cache.get_flag(key) {
607            Some(flag) => {
608                let group_type_mapping = self.cache.get_group_type_mapping();
609                let result = match_feature_flag(
610                    &flag,
611                    distinct_id,
612                    person_properties,
613                    groups,
614                    group_properties,
615                    &group_type_mapping,
616                );
617                trace!(key, ?result, "Local flag evaluation (simple)");
618                result.map(Some)
619            }
620            None => {
621                trace!(key, "Flag not found in local cache");
622                Ok(None)
623            }
624        }
625    }
626
627    /// Get all flags and evaluate them with full context support.
628    #[instrument(
629        skip(self, person_properties, groups, group_properties),
630        level = "debug"
631    )]
632    pub fn evaluate_all_flags(
633        &self,
634        distinct_id: &str,
635        person_properties: &HashMap<String, serde_json::Value>,
636        groups: &HashMap<String, String>,
637        group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
638    ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
639        let mut results = HashMap::new();
640
641        // Build evaluation context once for all flags
642        let cohorts = self.cache.get_cohort_definitions();
643        let flags = self.cache.get_flags_map();
644        let group_type_mapping = self.cache.get_group_type_mapping();
645
646        let ctx = EvaluationContext {
647            cohorts: &cohorts,
648            flags: &flags,
649            distinct_id,
650            groups,
651            group_properties,
652            group_type_mapping: &group_type_mapping,
653        };
654
655        for flag in self.cache.get_all_flags() {
656            let result = match_feature_flag_with_context(&flag, person_properties, &ctx);
657            results.insert(flag.key.clone(), result);
658        }
659
660        debug!(flag_count = results.len(), "Evaluated all local flags");
661        results
662    }
663}