Skip to main content

posthog_rs/
local_evaluation.rs

1use crate::feature_flags::{
2    match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3    FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use reqwest::header::{HeaderMap, ETAG, IF_NONE_MATCH};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13use tracing::{debug, error, info, instrument, trace, warn};
14
15/// Extract the ETag header value from a response's headers.
16/// Returns None if the header is missing, invalid UTF-8, or empty.
17fn extract_etag(headers: &HeaderMap) -> Option<String> {
18    headers
19        .get(ETAG)
20        .and_then(|v| v.to_str().ok())
21        .filter(|s| !s.is_empty())
22        .map(|s| s.to_string())
23}
24
/// Response body of the PostHog local evaluation endpoint
/// (`/api/feature_flag/local_evaluation/?send_cohorts`).
///
/// Carries feature flag definitions, the group type mapping, and cohort
/// definitions. [`FlagCache::update`] stores these so flags can be evaluated
/// locally without a server round-trip per evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalEvaluationResponse {
    /// Feature flag definitions. Required: deserialization fails if this field is missing.
    pub flags: Vec<FeatureFlag>,
    /// Group type mapping as returned by the API (string keys to string values).
    /// NOTE(review): PostHog usually keys this by group type index — confirm.
    /// Defaults to an empty map when absent from the payload.
    #[serde(default)]
    pub group_type_mapping: HashMap<String, String>,
    /// Cohort definitions used to evaluate cohort membership, presumably keyed by
    /// cohort id (TODO confirm). Defaults to an empty map when absent from the payload.
    #[serde(default)]
    pub cohorts: HashMap<String, Cohort>,
}
40
/// A cohort definition for local evaluation.
///
/// Cohorts are groups of users defined by property filters, used to target
/// feature flags at specific user segments. All three fields are required
/// when deserializing (none has a serde default).
///
/// NOTE(review): confirm this struct matches the per-cohort shape the
/// `send_cohorts` endpoint actually returns; a mismatch would make the whole
/// [`LocalEvaluationResponse`] fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cohort {
    /// Identifier of this cohort.
    pub id: String,
    /// Human-readable name of the cohort.
    pub name: String,
    /// Property filters that define cohort membership, kept as raw JSON and
    /// copied into `CohortDefinition::properties` for evaluation.
    pub properties: serde_json::Value,
}
54
/// Thread-safe cache for feature flag definitions.
///
/// Stores the feature flags, group type mapping, and cohort definitions fetched
/// from the PostHog API. Pollers write to it and [`LocalEvaluator`] reads from it.
///
/// Cloning is cheap and shares state: every clone holds the same `Arc`s, so an
/// update made through one clone is visible through all of them. Each map sits
/// behind its own `RwLock`.
#[derive(Clone)]
pub struct FlagCache {
    // Flag definitions keyed by flag key.
    flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
    // Group type mapping copied verbatim from the API response.
    group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
    // Cohort definitions keyed as in the API response.
    cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
}
66
67impl Default for FlagCache {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl FlagCache {
74    pub fn new() -> Self {
75        Self {
76            flags: Arc::new(RwLock::new(HashMap::new())),
77            group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
78            cohorts: Arc::new(RwLock::new(HashMap::new())),
79        }
80    }
81
82    pub fn update(&self, response: LocalEvaluationResponse) {
83        let flag_count = response.flags.len();
84        let mut flags = self.flags.write().unwrap();
85        flags.clear();
86        for flag in response.flags {
87            flags.insert(flag.key.clone(), flag);
88        }
89
90        let mut mapping = self.group_type_mapping.write().unwrap();
91        *mapping = response.group_type_mapping;
92
93        let mut cohorts = self.cohorts.write().unwrap();
94        *cohorts = response.cohorts;
95
96        debug!(flag_count, "Updated flag cache");
97    }
98
99    pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
100        self.flags.read().unwrap().get(key).cloned()
101    }
102
103    pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
104        self.flags.read().unwrap().values().cloned().collect()
105    }
106
107    pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
108        self.cohorts.read().unwrap().get(id).cloned()
109    }
110
111    pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
112        self.cohorts.read().unwrap().clone()
113    }
114
115    /// Get all cohorts as CohortDefinitions for evaluation context
116    pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
117        self.cohorts
118            .read()
119            .unwrap()
120            .iter()
121            .map(|(k, v)| {
122                (
123                    k.clone(),
124                    CohortDefinition {
125                        id: v.id.clone(),
126                        properties: v.properties.clone(),
127                    },
128                )
129            })
130            .collect()
131    }
132
133    /// Get all flags as a HashMap for evaluation context
134    pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
135        self.flags.read().unwrap().clone()
136    }
137
138    pub fn clear(&self) {
139        self.flags.write().unwrap().clear();
140        self.group_type_mapping.write().unwrap().clear();
141        self.cohorts.write().unwrap().clear();
142    }
143}
144
/// Configuration for local flag evaluation.
///
/// Holds the credentials and settings the pollers use to fetch feature flag
/// definitions from the PostHog API.
///
/// NOTE(review): there is no `Debug` derive, possibly deliberately, since the
/// struct holds API keys. Confirm before adding one.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Personal API key, sent as `Authorization: Bearer <key>`.
    pub personal_api_key: String,
    /// Project API key, sent in the `X-PostHog-Project-Api-Key` header to select the project.
    pub project_api_key: String,
    /// PostHog API host URL (e.g. "https://us.posthog.com"); a trailing `/` is trimmed.
    pub api_host: String,
    /// How often to poll for updated flag definitions.
    pub poll_interval: Duration,
    /// Timeout applied to the HTTP client used for every request.
    pub request_timeout: Duration,
}
162
/// Synchronous poller for feature flag definitions.
///
/// Runs a background thread that periodically fetches flag definitions from
/// the PostHog API and updates the shared cache. Use this in blocking/sync
/// applications; for async applications use [`AsyncFlagPoller`] instead.
/// Dropping the poller calls `stop()`.
pub struct FlagPoller {
    // Endpoint, credentials and timing settings.
    config: LocalEvaluationConfig,
    // Shared cache that fetched definitions are written into.
    cache: FlagCache,
    // Blocking HTTP client built with `config.request_timeout`.
    client: reqwest::blocking::Client,
    // Raised by `stop()` to ask the background thread to exit.
    stop_signal: Arc<AtomicBool>,
    // Handle of the background thread; `Some` after `start()` until `stop()` takes it.
    thread_handle: Option<std::thread::JoinHandle<()>>,
}
175
176impl FlagPoller {
177    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
178        let client = reqwest::blocking::Client::builder()
179            .timeout(config.request_timeout)
180            .build()
181            .unwrap();
182
183        Self {
184            config,
185            cache,
186            client,
187            stop_signal: Arc::new(AtomicBool::new(false)),
188            thread_handle: None,
189        }
190    }
191
192    /// Start the polling thread
193    pub fn start(&mut self) {
194        info!(
195            poll_interval_secs = self.config.poll_interval.as_secs(),
196            "Starting feature flag poller"
197        );
198
199        // Initial load
200        match self.load_flags() {
201            Ok(()) => info!("Initial flag definitions loaded successfully"),
202            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
203        }
204
205        let config = self.config.clone();
206        let cache = self.cache.clone();
207        let stop_signal = self.stop_signal.clone();
208
209        let handle = std::thread::spawn(move || {
210            let client = reqwest::blocking::Client::builder()
211                .timeout(config.request_timeout)
212                .build()
213                .unwrap();
214
215            let mut last_etag: Option<String> = None;
216
217            loop {
218                std::thread::sleep(config.poll_interval);
219
220                if stop_signal.load(Ordering::Relaxed) {
221                    debug!("Flag poller received stop signal");
222                    break;
223                }
224
225                let url = format!(
226                    "{}/api/feature_flag/local_evaluation/?send_cohorts",
227                    config.api_host.trim_end_matches('/')
228                );
229
230                let mut request = client
231                    .get(&url)
232                    .header(
233                        "Authorization",
234                        format!("Bearer {}", config.personal_api_key),
235                    )
236                    .header("X-PostHog-Project-Api-Key", &config.project_api_key);
237
238                if let Some(ref etag) = last_etag {
239                    request = request.header(IF_NONE_MATCH, etag.as_str());
240                }
241
242                match request.send() {
243                    Ok(response) => {
244                        if response.status() == StatusCode::NOT_MODIFIED {
245                            debug!("Flag definitions unchanged (304 Not Modified)");
246                        } else if response.status().is_success() {
247                            // Extract ETag before consuming the response body
248                            let new_etag = extract_etag(response.headers());
249
250                            match response.json::<LocalEvaluationResponse>() {
251                                Ok(data) => {
252                                    trace!("Successfully fetched flag definitions");
253                                    cache.update(data);
254                                    last_etag = new_etag;
255                                }
256                                Err(e) => {
257                                    warn!(error = %e, "Failed to parse flag response");
258                                }
259                            }
260                        } else {
261                            warn!(status = %response.status(), "Failed to fetch flags");
262                        }
263                    }
264                    Err(e) => {
265                        warn!(error = %e, "Failed to fetch flags");
266                    }
267                }
268            }
269        });
270
271        self.thread_handle = Some(handle);
272    }
273
274    /// Load flags synchronously
275    #[instrument(skip(self), level = "debug")]
276    pub fn load_flags(&self) -> Result<(), Error> {
277        let url = format!(
278            "{}/api/feature_flag/local_evaluation/?send_cohorts",
279            self.config.api_host.trim_end_matches('/')
280        );
281
282        let response = self
283            .client
284            .get(&url)
285            .header(
286                "Authorization",
287                format!("Bearer {}", self.config.personal_api_key),
288            )
289            .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
290            .send()
291            .map_err(|e| {
292                error!(error = %e, "Connection error loading flags");
293                Error::Connection(e.to_string())
294            })?;
295
296        if !response.status().is_success() {
297            let status = response.status();
298            error!(status = %status, "HTTP error loading flags");
299            return Err(Error::Connection(format!("HTTP {}", status)));
300        }
301
302        let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
303            error!(error = %e, "Failed to parse flag response");
304            Error::Serialization(e.to_string())
305        })?;
306
307        self.cache.update(data);
308        Ok(())
309    }
310
311    /// Stop the polling thread
312    pub fn stop(&mut self) {
313        debug!("Stopping flag poller");
314        self.stop_signal.store(true, Ordering::Relaxed);
315        if let Some(handle) = self.thread_handle.take() {
316            handle.join().ok();
317        }
318    }
319}
320
321impl Drop for FlagPoller {
322    fn drop(&mut self) {
323        self.stop();
324    }
325}
326
/// Asynchronous poller for feature flag definitions.
///
/// Runs a tokio task that periodically fetches flag definitions from the
/// PostHog API and updates the shared cache. Use this in async applications;
/// for blocking/sync applications use [`FlagPoller`] instead. Dropping the
/// poller aborts the task.
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    // Endpoint, credentials and timing settings.
    config: LocalEvaluationConfig,
    // Shared cache that fetched definitions are written into.
    cache: FlagCache,
    // Async HTTP client built with `config.request_timeout`; cloned into the task.
    client: reqwest::Client,
    // Raised by `stop()`; the task checks it after each interval tick.
    stop_signal: Arc<AtomicBool>,
    // Handle of the polling task; `Some` after `start()` until `stop()`/drop takes it.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    // Set to true by `start()`; cleared by `stop()` or when the task's loop exits.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
341
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Create an idle async poller that fetches with `config` and writes into `cache`.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built (for example when the TLS
    /// backend fails to initialise).
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build HTTP client for feature flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Start the polling task.
    ///
    /// First performs one load inline. A failure there is logged, not returned.
    /// Then it spawns a task that re-fetches every `poll_interval`, sending
    /// `If-None-Match` with the last ETag. A no-op if the poller is already
    /// running. Calling it after [`AsyncFlagPoller::stop`] restarts polling.
    ///
    /// # Panics
    ///
    /// Uses `tokio::spawn`, so it must run inside a Tokio runtime.
    /// `tokio::time::interval` panics if `poll_interval` is zero.
    pub async fn start(&mut self) {
        // Check and claim the running flag under one write lock.
        {
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }
        // stop() leaves the signal raised; clear it or a restarted task would exit on its first tick.
        self.stop_signal.store(false, Ordering::SeqCst);

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        // Initial load; its ETag seeds the first conditional poll.
        let mut last_etag = match self.load_flags_with_etag().await {
            Ok(etag) => {
                info!("Initial flag definitions loaded successfully");
                etag
            }
            Err(e) => {
                warn!(error = %e, "Failed to load initial flags, will retry on next poll");
                None
            }
        };

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = Arc::clone(&self.stop_signal);
        let is_running = Arc::clone(&self.is_running);
        let client = self.client.clone();

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(config.poll_interval);
            // The first tick completes immediately; the initial load above already covered it.
            interval.tick().await;

            loop {
                interval.tick().await;
                if stop_signal.load(Ordering::SeqCst) {
                    debug!("Async flag poller received stop signal");
                    break;
                }
                Self::poll_once(&client, &config, &cache, &mut last_etag).await;
            }

            // Clear the running flag when the loop exits on its own.
            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Perform one conditional fetch and update `cache` if new definitions arrive.
    ///
    /// Failures are logged and otherwise ignored so the next tick can retry.
    /// `last_etag` only advances after a response body has been applied.
    async fn poll_once(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
        cache: &FlagCache,
        last_etag: &mut Option<String>,
    ) {
        let mut request = Self::build_request(client, config);
        if let Some(etag) = last_etag.as_deref() {
            request = request.header(IF_NONE_MATCH, etag);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
                return;
            }
        };

        let status = response.status();
        if status == StatusCode::NOT_MODIFIED {
            debug!("Flag definitions unchanged (304 Not Modified)");
        } else if status.is_success() {
            // Read the ETag before `json()` consumes the response.
            let new_etag = extract_etag(response.headers());
            match response.json::<LocalEvaluationResponse>().await {
                Ok(data) => {
                    trace!("Successfully fetched flag definitions");
                    cache.update(data);
                    *last_etag = new_etag;
                }
                Err(e) => warn!(error = %e, "Failed to parse flag response"),
            }
        } else {
            warn!(status = %status, "Failed to fetch flags");
        }
    }

    /// Build an authenticated GET request for the local-evaluation endpoint (with cohorts).
    fn build_request(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
    ) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/api/feature_flag/local_evaluation/?send_cohorts",
            config.api_host.trim_end_matches('/')
        );
        client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &config.project_api_key)
    }

    /// Load flags asynchronously, replacing the cache contents on success.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Connection`] if the request fails or the server responds
    /// with a non-success status. Returns [`Error::Serialization`] if the body
    /// cannot be parsed. The cache is left untouched on error.
    pub async fn load_flags(&self) -> Result<(), Error> {
        self.load_flags_with_etag().await.map(|_etag| ())
    }

    /// Unconditional fetch shared by `load_flags` and `start`.
    ///
    /// On success it returns the response's ETag (if any) so polling can continue
    /// with conditional requests.
    #[instrument(name = "load_flags", skip(self), level = "debug")]
    async fn load_flags_with_etag(&self) -> Result<Option<String>, Error> {
        let response = Self::build_request(&self.client, &self.config)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        let status = response.status();
        if !status.is_success() {
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        // Read the ETag before `json()` consumes the response.
        let etag = extract_etag(response.headers());
        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(etag)
    }

    /// Stop the polling task and wait for it to terminate.
    ///
    /// Safe to call more than once. A later [`AsyncFlagPoller::start`] resumes polling.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::SeqCst);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
            // Await termination so the aborted task cannot race with the reset of
            // `is_running` below or with a subsequent start(). Cancellation is expected.
            if let Err(e) = handle.await {
                if e.is_panic() {
                    warn!("Async flag poller task panicked");
                }
            }
        }
        *self.is_running.write().await = false;
    }

    /// Check whether the poller is running (between `start()` and `stop()`/task exit).
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }
}
510
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Request cancellation of the polling task, if one was started.
    ///
    /// `Drop` cannot await, so this only aborts the task and does not wait for it to finish.
    fn drop(&mut self) {
        let pending = self.task_handle.take();
        if let Some(task) = pending {
            task.abort();
        }
    }
}
520
/// Evaluates feature flags using locally cached definitions.
///
/// The evaluator reads from a [`FlagCache`] to determine flag values without
/// making network requests. It supports cohort membership checks and flag
/// dependencies through the evaluation context. Cloning is cheap because it
/// only clones the shared cache handle.
#[derive(Clone)]
pub struct LocalEvaluator {
    // Shared cache the evaluator reads definitions from.
    cache: FlagCache,
}
530
531impl LocalEvaluator {
532    pub fn new(cache: FlagCache) -> Self {
533        Self { cache }
534    }
535
536    /// Evaluate a feature flag locally with full context support
537    /// This supports cohort membership checks and flag dependency evaluation
538    #[instrument(skip(self, person_properties), level = "trace")]
539    pub fn evaluate_flag(
540        &self,
541        key: &str,
542        distinct_id: &str,
543        person_properties: &HashMap<String, serde_json::Value>,
544    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
545        match self.cache.get_flag(key) {
546            Some(flag) => {
547                // Build evaluation context with cohorts and flags for dependency resolution
548                let cohorts = self.cache.get_cohort_definitions();
549                let flags = self.cache.get_flags_map();
550
551                let ctx = EvaluationContext {
552                    cohorts: &cohorts,
553                    flags: &flags,
554                    distinct_id,
555                };
556
557                let result =
558                    match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
559                trace!(key, ?result, "Local flag evaluation");
560                result.map(Some)
561            }
562            None => {
563                trace!(key, "Flag not found in local cache");
564                Ok(None)
565            }
566        }
567    }
568
569    /// Evaluate a feature flag locally (simple version without cohort/flag dependency support)
570    /// Use this when you know the flag doesn't have cohort or flag dependency conditions
571    #[instrument(skip(self, person_properties), level = "trace")]
572    pub fn evaluate_flag_simple(
573        &self,
574        key: &str,
575        distinct_id: &str,
576        person_properties: &HashMap<String, serde_json::Value>,
577    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
578        match self.cache.get_flag(key) {
579            Some(flag) => {
580                let result = match_feature_flag(&flag, distinct_id, person_properties);
581                trace!(key, ?result, "Local flag evaluation (simple)");
582                result.map(Some)
583            }
584            None => {
585                trace!(key, "Flag not found in local cache");
586                Ok(None)
587            }
588        }
589    }
590
591    /// Get all flags and evaluate them with full context support
592    #[instrument(skip(self, person_properties), level = "debug")]
593    pub fn evaluate_all_flags(
594        &self,
595        distinct_id: &str,
596        person_properties: &HashMap<String, serde_json::Value>,
597    ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
598        let mut results = HashMap::new();
599
600        // Build evaluation context once for all flags
601        let cohorts = self.cache.get_cohort_definitions();
602        let flags = self.cache.get_flags_map();
603
604        let ctx = EvaluationContext {
605            cohorts: &cohorts,
606            flags: &flags,
607            distinct_id,
608        };
609
610        for flag in self.cache.get_all_flags() {
611            let result =
612                match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
613            results.insert(flag.key.clone(), result);
614        }
615
616        debug!(flag_count = results.len(), "Evaluated all local flags");
617        results
618    }
619}