Skip to main content

posthog_rs/
local_evaluation.rs

1use crate::feature_flags::{
2    match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3    FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11use tracing::{debug, error, info, instrument, trace, warn};
12
13/// Response from the PostHog local evaluation API.
14///
15/// Contains feature flag definitions, group type mappings, and cohort definitions
16/// that can be cached locally for flag evaluation without server round-trips.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LocalEvaluationResponse {
19    /// List of feature flag definitions
20    pub flags: Vec<FeatureFlag>,
21    /// Mapping from group type keys to their display names
22    #[serde(default)]
23    pub group_type_mapping: HashMap<String, String>,
24    /// Cohort definitions for evaluating cohort membership
25    #[serde(default)]
26    pub cohorts: HashMap<String, Cohort>,
27}
28
29/// A cohort definition for local evaluation.
30///
31/// Cohorts are groups of users defined by property filters, used for
32/// targeting feature flags to specific user segments.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Cohort {
35    /// Unique identifier for this cohort
36    pub id: String,
37    /// Human-readable name of the cohort
38    pub name: String,
39    /// Property filters that define cohort membership
40    pub properties: serde_json::Value,
41}
42
43/// Thread-safe cache for feature flag definitions.
44///
45/// Stores feature flags, group type mappings, and cohort definitions that have
46/// been fetched from the PostHog API. The cache is shared between the poller
47/// (which updates it) and the evaluator (which reads from it).
48#[derive(Clone)]
49pub struct FlagCache {
50    flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
51    group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
52    cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
53}
54
55impl Default for FlagCache {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl FlagCache {
62    pub fn new() -> Self {
63        Self {
64            flags: Arc::new(RwLock::new(HashMap::new())),
65            group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
66            cohorts: Arc::new(RwLock::new(HashMap::new())),
67        }
68    }
69
70    pub fn update(&self, response: LocalEvaluationResponse) {
71        let flag_count = response.flags.len();
72        let mut flags = self.flags.write().unwrap();
73        flags.clear();
74        for flag in response.flags {
75            flags.insert(flag.key.clone(), flag);
76        }
77
78        let mut mapping = self.group_type_mapping.write().unwrap();
79        *mapping = response.group_type_mapping;
80
81        let mut cohorts = self.cohorts.write().unwrap();
82        *cohorts = response.cohorts;
83
84        debug!(flag_count, "Updated flag cache");
85    }
86
87    pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
88        self.flags.read().unwrap().get(key).cloned()
89    }
90
91    pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
92        self.flags.read().unwrap().values().cloned().collect()
93    }
94
95    pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
96        self.cohorts.read().unwrap().get(id).cloned()
97    }
98
99    pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
100        self.cohorts.read().unwrap().clone()
101    }
102
103    /// Get all cohorts as CohortDefinitions for evaluation context
104    pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
105        self.cohorts
106            .read()
107            .unwrap()
108            .iter()
109            .map(|(k, v)| {
110                (
111                    k.clone(),
112                    CohortDefinition {
113                        id: v.id.clone(),
114                        properties: v.properties.clone(),
115                    },
116                )
117            })
118            .collect()
119    }
120
121    /// Get all flags as a HashMap for evaluation context
122    pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
123        self.flags.read().unwrap().clone()
124    }
125
126    pub fn clear(&self) {
127        self.flags.write().unwrap().clear();
128        self.group_type_mapping.write().unwrap().clear();
129        self.cohorts.write().unwrap().clear();
130    }
131}
132
/// Settings for downloading flag definitions used by local evaluation.
///
/// Carries the credentials and timing knobs a poller needs to fetch feature flag
/// definitions from the PostHog API.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Personal API key, sent as a `Bearer` token in the `Authorization` header.
    pub personal_api_key: String,
    /// Project API key, sent in the `X-PostHog-Project-Api-Key` header to select the project.
    pub project_api_key: String,
    /// Base URL of the PostHog instance (e.g. "https://us.posthog.com"); a trailing
    /// slash is tolerated.
    pub api_host: String,
    /// Delay between successive background refreshes.
    pub poll_interval: Duration,
    /// Timeout configured on the HTTP client used for each fetch.
    pub request_timeout: Duration,
}
150
151/// Synchronous poller for feature flag definitions.
152///
153/// Runs a background thread that periodically fetches flag definitions from
154/// the PostHog API and updates the shared cache. Use this for blocking/sync
155/// applications. For async applications, use [`AsyncFlagPoller`] instead.
156pub struct FlagPoller {
157    config: LocalEvaluationConfig,
158    cache: FlagCache,
159    client: reqwest::blocking::Client,
160    stop_signal: Arc<AtomicBool>,
161    thread_handle: Option<std::thread::JoinHandle<()>>,
162}
163
164impl FlagPoller {
165    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
166        let client = reqwest::blocking::Client::builder()
167            .timeout(config.request_timeout)
168            .build()
169            .unwrap();
170
171        Self {
172            config,
173            cache,
174            client,
175            stop_signal: Arc::new(AtomicBool::new(false)),
176            thread_handle: None,
177        }
178    }
179
180    /// Start the polling thread
181    pub fn start(&mut self) {
182        info!(
183            poll_interval_secs = self.config.poll_interval.as_secs(),
184            "Starting feature flag poller"
185        );
186
187        // Initial load
188        match self.load_flags() {
189            Ok(()) => info!("Initial flag definitions loaded successfully"),
190            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
191        }
192
193        let config = self.config.clone();
194        let cache = self.cache.clone();
195        let stop_signal = self.stop_signal.clone();
196
197        let handle = std::thread::spawn(move || {
198            let client = reqwest::blocking::Client::builder()
199                .timeout(config.request_timeout)
200                .build()
201                .unwrap();
202
203            loop {
204                std::thread::sleep(config.poll_interval);
205
206                if stop_signal.load(Ordering::Relaxed) {
207                    debug!("Flag poller received stop signal");
208                    break;
209                }
210
211                let url = format!(
212                    "{}/api/feature_flag/local_evaluation/?send_cohorts",
213                    config.api_host.trim_end_matches('/')
214                );
215
216                match client
217                    .get(&url)
218                    .header(
219                        "Authorization",
220                        format!("Bearer {}", config.personal_api_key),
221                    )
222                    .header("X-PostHog-Project-Api-Key", &config.project_api_key)
223                    .send()
224                {
225                    Ok(response) => {
226                        if response.status().is_success() {
227                            match response.json::<LocalEvaluationResponse>() {
228                                Ok(data) => {
229                                    trace!("Successfully fetched flag definitions");
230                                    cache.update(data);
231                                }
232                                Err(e) => {
233                                    warn!(error = %e, "Failed to parse flag response");
234                                }
235                            }
236                        } else {
237                            warn!(status = %response.status(), "Failed to fetch flags");
238                        }
239                    }
240                    Err(e) => {
241                        warn!(error = %e, "Failed to fetch flags");
242                    }
243                }
244            }
245        });
246
247        self.thread_handle = Some(handle);
248    }
249
250    /// Load flags synchronously
251    #[instrument(skip(self), level = "debug")]
252    pub fn load_flags(&self) -> Result<(), Error> {
253        let url = format!(
254            "{}/api/feature_flag/local_evaluation/?send_cohorts",
255            self.config.api_host.trim_end_matches('/')
256        );
257
258        let response = self
259            .client
260            .get(&url)
261            .header(
262                "Authorization",
263                format!("Bearer {}", self.config.personal_api_key),
264            )
265            .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
266            .send()
267            .map_err(|e| {
268                error!(error = %e, "Connection error loading flags");
269                Error::Connection(e.to_string())
270            })?;
271
272        if !response.status().is_success() {
273            let status = response.status();
274            error!(status = %status, "HTTP error loading flags");
275            return Err(Error::Connection(format!("HTTP {}", status)));
276        }
277
278        let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
279            error!(error = %e, "Failed to parse flag response");
280            Error::Serialization(e.to_string())
281        })?;
282
283        self.cache.update(data);
284        Ok(())
285    }
286
287    /// Stop the polling thread
288    pub fn stop(&mut self) {
289        debug!("Stopping flag poller");
290        self.stop_signal.store(true, Ordering::Relaxed);
291        if let Some(handle) = self.thread_handle.take() {
292            handle.join().ok();
293        }
294    }
295}
296
297impl Drop for FlagPoller {
298    fn drop(&mut self) {
299        self.stop();
300    }
301}
302
/// Tokio-based refresher for feature flag definitions.
///
/// Once started, a spawned tokio task re-fetches definitions from the PostHog API
/// every `poll_interval` and writes them into the shared [`FlagCache`]. Suited to
/// async applications; blocking/sync applications should use [`FlagPoller`].
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    /// Credentials, host and timing settings.
    config: LocalEvaluationConfig,
    /// Cache that successful fetches are written into.
    cache: FlagCache,
    /// Async HTTP client configured with `config.request_timeout`.
    client: reqwest::Client,
    /// Raised to ask the polling task to exit.
    stop_signal: Arc<AtomicBool>,
    /// Handle of the spawned polling task, present while one has been spawned.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether a polling task is considered active; shared with the task itself.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
317
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates a poller that fetches definitions described by `config` into `cache`.
    ///
    /// No network traffic happens until [`AsyncFlagPoller::start`] or
    /// [`AsyncFlagPoller::load_flags`] is called.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` fails to build the async HTTP client.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build async HTTP client for the flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Loads definitions once, then spawns a tokio task that refreshes the cache every
    /// `poll_interval`.
    ///
    /// A failed initial load is logged, and the next poll retries it. Calling `start`
    /// while the poller is running is a no-op. Calling it after
    /// [`AsyncFlagPoller::stop`] spawns a fresh task.
    pub async fn start(&mut self) {
        // Check if already running
        {
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        // Initial load
        match self.load_flags().await {
            Ok(()) => info!("Initial flag definitions loaded successfully"),
            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
        }

        // Clear a stop request left over from a previous `stop()`; otherwise a restarted
        // task would exit on its first tick.
        self.stop_signal.store(false, Ordering::Release);

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = Arc::clone(&self.stop_signal);
        let is_running = Arc::clone(&self.is_running);
        let client = self.client.clone();

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(config.poll_interval);
            // After a slow fetch or a stalled runtime, wait a full period rather than
            // firing the missed ticks back-to-back (tokio's default is `Burst`).
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately and the initial load already ran.
            interval.tick().await;

            loop {
                interval.tick().await;
                if stop_signal.load(Ordering::Acquire) {
                    debug!("Async flag poller received stop signal");
                    break;
                }
                Self::poll_once(&client, &config, &cache).await;
            }

            // Clear running flag when task exits
            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Performs one background refresh, logging (not propagating) any failure so the
    /// previously cached definitions stay in place.
    async fn poll_once(client: &reqwest::Client, config: &LocalEvaluationConfig, cache: &FlagCache) {
        match Self::definitions_request(client, config).send().await {
            Ok(response) => {
                if response.status().is_success() {
                    match response.json::<LocalEvaluationResponse>().await {
                        Ok(data) => {
                            trace!("Successfully fetched flag definitions");
                            cache.update(data);
                        }
                        Err(e) => {
                            warn!(error = %e, "Failed to parse flag response");
                        }
                    }
                } else {
                    warn!(status = %response.status(), "Failed to fetch flags");
                }
            }
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
            }
        }
    }

    /// Builds the authenticated GET request for the local-evaluation endpoint
    /// (with cohorts), tolerating a trailing slash on `api_host`.
    fn definitions_request(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
    ) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/api/feature_flag/local_evaluation/?send_cohorts",
            config.api_host.trim_end_matches('/')
        );

        client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &config.project_api_key)
    }

    /// Fetches flag definitions and stores them in the cache.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Connection`] if the request fails or the response status is not
    /// successful. Returns [`Error::Serialization`] if the body cannot be parsed. The
    /// cache is left untouched in both cases.
    #[instrument(skip(self), level = "debug")]
    pub async fn load_flags(&self) -> Result<(), Error> {
        let response = Self::definitions_request(&self.client, &self.config)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        if !response.status().is_success() {
            let status = response.status();
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(())
    }

    /// Signals the polling task to stop, aborts it, and marks the poller as not running.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::Release);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
        }
        *self.is_running.write().await = false;
    }

    /// Returns whether the poller is currently marked as running.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }
}
475
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Cancels the background polling task, if one was spawned.
    ///
    /// `drop` cannot `.await`, so the shared `is_running` flag is left as-is here.
    fn drop(&mut self) {
        if let Some(task) = self.task_handle.take() {
            task.abort();
        }
    }
}
485
486/// Evaluates feature flags using locally cached definitions.
487///
488/// The evaluator reads from a [`FlagCache`] to determine flag values without
489/// making network requests. Supports cohort membership checks and flag
490/// dependencies through the evaluation context.
491#[derive(Clone)]
492pub struct LocalEvaluator {
493    cache: FlagCache,
494}
495
496impl LocalEvaluator {
497    pub fn new(cache: FlagCache) -> Self {
498        Self { cache }
499    }
500
501    /// Evaluate a feature flag locally with full context support
502    /// This supports cohort membership checks and flag dependency evaluation
503    #[instrument(skip(self, person_properties), level = "trace")]
504    pub fn evaluate_flag(
505        &self,
506        key: &str,
507        distinct_id: &str,
508        person_properties: &HashMap<String, serde_json::Value>,
509    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
510        match self.cache.get_flag(key) {
511            Some(flag) => {
512                // Build evaluation context with cohorts and flags for dependency resolution
513                let cohorts = self.cache.get_cohort_definitions();
514                let flags = self.cache.get_flags_map();
515
516                let ctx = EvaluationContext {
517                    cohorts: &cohorts,
518                    flags: &flags,
519                    distinct_id,
520                };
521
522                let result =
523                    match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
524                trace!(key, ?result, "Local flag evaluation");
525                result.map(Some)
526            }
527            None => {
528                trace!(key, "Flag not found in local cache");
529                Ok(None)
530            }
531        }
532    }
533
534    /// Evaluate a feature flag locally (simple version without cohort/flag dependency support)
535    /// Use this when you know the flag doesn't have cohort or flag dependency conditions
536    #[instrument(skip(self, person_properties), level = "trace")]
537    pub fn evaluate_flag_simple(
538        &self,
539        key: &str,
540        distinct_id: &str,
541        person_properties: &HashMap<String, serde_json::Value>,
542    ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
543        match self.cache.get_flag(key) {
544            Some(flag) => {
545                let result = match_feature_flag(&flag, distinct_id, person_properties);
546                trace!(key, ?result, "Local flag evaluation (simple)");
547                result.map(Some)
548            }
549            None => {
550                trace!(key, "Flag not found in local cache");
551                Ok(None)
552            }
553        }
554    }
555
556    /// Get all flags and evaluate them with full context support
557    #[instrument(skip(self, person_properties), level = "debug")]
558    pub fn evaluate_all_flags(
559        &self,
560        distinct_id: &str,
561        person_properties: &HashMap<String, serde_json::Value>,
562    ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
563        let mut results = HashMap::new();
564
565        // Build evaluation context once for all flags
566        let cohorts = self.cache.get_cohort_definitions();
567        let flags = self.cache.get_flags_map();
568
569        let ctx = EvaluationContext {
570            cohorts: &cohorts,
571            flags: &flags,
572            distinct_id,
573        };
574
575        for flag in self.cache.get_all_flags() {
576            let result =
577                match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
578            results.insert(flag.key.clone(), result);
579        }
580
581        debug!(flag_count = results.len(), "Evaluated all local flags");
582        results
583    }
584}