1use crate::feature_flags::{
2 match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3 FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11use tracing::{debug, error, info, instrument, trace, warn};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LocalEvaluationResponse {
19 pub flags: Vec<FeatureFlag>,
21 #[serde(default)]
23 pub group_type_mapping: HashMap<String, String>,
24 #[serde(default)]
26 pub cohorts: HashMap<String, Cohort>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Cohort {
35 pub id: String,
37 pub name: String,
39 pub properties: serde_json::Value,
41}
42
43#[derive(Clone)]
49pub struct FlagCache {
50 flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
51 group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
52 cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
53}
54
55impl Default for FlagCache {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl FlagCache {
62 pub fn new() -> Self {
63 Self {
64 flags: Arc::new(RwLock::new(HashMap::new())),
65 group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
66 cohorts: Arc::new(RwLock::new(HashMap::new())),
67 }
68 }
69
70 pub fn update(&self, response: LocalEvaluationResponse) {
71 let flag_count = response.flags.len();
72 let mut flags = self.flags.write().unwrap();
73 flags.clear();
74 for flag in response.flags {
75 flags.insert(flag.key.clone(), flag);
76 }
77
78 let mut mapping = self.group_type_mapping.write().unwrap();
79 *mapping = response.group_type_mapping;
80
81 let mut cohorts = self.cohorts.write().unwrap();
82 *cohorts = response.cohorts;
83
84 debug!(flag_count, "Updated flag cache");
85 }
86
87 pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
88 self.flags.read().unwrap().get(key).cloned()
89 }
90
91 pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
92 self.flags.read().unwrap().values().cloned().collect()
93 }
94
95 pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
96 self.cohorts.read().unwrap().get(id).cloned()
97 }
98
99 pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
100 self.cohorts.read().unwrap().clone()
101 }
102
103 pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
105 self.cohorts
106 .read()
107 .unwrap()
108 .iter()
109 .map(|(k, v)| {
110 (
111 k.clone(),
112 CohortDefinition {
113 id: v.id.clone(),
114 properties: v.properties.clone(),
115 },
116 )
117 })
118 .collect()
119 }
120
121 pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
123 self.flags.read().unwrap().clone()
124 }
125
126 pub fn clear(&self) {
127 self.flags.write().unwrap().clear();
128 self.group_type_mapping.write().unwrap().clear();
129 self.cohorts.write().unwrap().clear();
130 }
131}
132
/// Connection and scheduling settings shared by [`FlagPoller`] and the async poller.
///
/// `Debug` is deliberately not derived because the struct carries API keys.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Personal API key, sent as `Authorization: Bearer <key>`.
    pub personal_api_key: String,
    /// Project API key, sent in the `X-PostHog-Project-Api-Key` header.
    pub project_api_key: String,
    /// Base URL of the API; a trailing `/` is tolerated.
    pub api_host: String,
    /// Delay between background refreshes of the flag definitions.
    pub poll_interval: Duration,
    /// Timeout applied to each definitions request by the HTTP client.
    pub request_timeout: Duration,
}
150
151pub struct FlagPoller {
157 config: LocalEvaluationConfig,
158 cache: FlagCache,
159 client: reqwest::blocking::Client,
160 stop_signal: Arc<AtomicBool>,
161 thread_handle: Option<std::thread::JoinHandle<()>>,
162}
163
164impl FlagPoller {
165 pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
166 let client = reqwest::blocking::Client::builder()
167 .timeout(config.request_timeout)
168 .build()
169 .unwrap();
170
171 Self {
172 config,
173 cache,
174 client,
175 stop_signal: Arc::new(AtomicBool::new(false)),
176 thread_handle: None,
177 }
178 }
179
180 pub fn start(&mut self) {
182 info!(
183 poll_interval_secs = self.config.poll_interval.as_secs(),
184 "Starting feature flag poller"
185 );
186
187 match self.load_flags() {
189 Ok(()) => info!("Initial flag definitions loaded successfully"),
190 Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
191 }
192
193 let config = self.config.clone();
194 let cache = self.cache.clone();
195 let stop_signal = self.stop_signal.clone();
196
197 let handle = std::thread::spawn(move || {
198 let client = reqwest::blocking::Client::builder()
199 .timeout(config.request_timeout)
200 .build()
201 .unwrap();
202
203 loop {
204 std::thread::sleep(config.poll_interval);
205
206 if stop_signal.load(Ordering::Relaxed) {
207 debug!("Flag poller received stop signal");
208 break;
209 }
210
211 let url = format!(
212 "{}/api/feature_flag/local_evaluation/?send_cohorts",
213 config.api_host.trim_end_matches('/')
214 );
215
216 match client
217 .get(&url)
218 .header(
219 "Authorization",
220 format!("Bearer {}", config.personal_api_key),
221 )
222 .header("X-PostHog-Project-Api-Key", &config.project_api_key)
223 .send()
224 {
225 Ok(response) => {
226 if response.status().is_success() {
227 match response.json::<LocalEvaluationResponse>() {
228 Ok(data) => {
229 trace!("Successfully fetched flag definitions");
230 cache.update(data);
231 }
232 Err(e) => {
233 warn!(error = %e, "Failed to parse flag response");
234 }
235 }
236 } else {
237 warn!(status = %response.status(), "Failed to fetch flags");
238 }
239 }
240 Err(e) => {
241 warn!(error = %e, "Failed to fetch flags");
242 }
243 }
244 }
245 });
246
247 self.thread_handle = Some(handle);
248 }
249
250 #[instrument(skip(self), level = "debug")]
252 pub fn load_flags(&self) -> Result<(), Error> {
253 let url = format!(
254 "{}/api/feature_flag/local_evaluation/?send_cohorts",
255 self.config.api_host.trim_end_matches('/')
256 );
257
258 let response = self
259 .client
260 .get(&url)
261 .header(
262 "Authorization",
263 format!("Bearer {}", self.config.personal_api_key),
264 )
265 .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
266 .send()
267 .map_err(|e| {
268 error!(error = %e, "Connection error loading flags");
269 Error::Connection(e.to_string())
270 })?;
271
272 if !response.status().is_success() {
273 let status = response.status();
274 error!(status = %status, "HTTP error loading flags");
275 return Err(Error::Connection(format!("HTTP {}", status)));
276 }
277
278 let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
279 error!(error = %e, "Failed to parse flag response");
280 Error::Serialization(e.to_string())
281 })?;
282
283 self.cache.update(data);
284 Ok(())
285 }
286
287 pub fn stop(&mut self) {
289 debug!("Stopping flag poller");
290 self.stop_signal.store(true, Ordering::Relaxed);
291 if let Some(handle) = self.thread_handle.take() {
292 handle.join().ok();
293 }
294 }
295}
296
297impl Drop for FlagPoller {
298 fn drop(&mut self) {
299 self.stop();
300 }
301}
302
/// Async counterpart of [`FlagPoller`]: refreshes a [`FlagCache`] from a tokio task
/// using the async `reqwest` client.
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    /// Endpoint, credentials and timing settings.
    config: LocalEvaluationConfig,
    /// Cache that receives every successfully parsed response.
    cache: FlagCache,
    /// Async client built with `config.request_timeout`.
    client: reqwest::Client,
    /// Set to `true` to ask the polling task to exit.
    stop_signal: Arc<AtomicBool>,
    /// Handle of the spawned polling task, if any.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether a polling task is considered active; shared with the task.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
317
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates an async poller that fetches definitions using `config` and stores them in `cache`.
    ///
    /// No request is sent until [`AsyncFlagPoller::start`] or
    /// [`AsyncFlagPoller::load_flags`] is awaited.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` async client cannot be built.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build HTTP client for the async feature flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Builds the local-evaluation endpoint URL, tolerating a trailing `/` on `api_host`.
    fn definitions_url(api_host: &str) -> String {
        format!(
            "{}/api/feature_flag/local_evaluation/?send_cohorts",
            api_host.trim_end_matches('/')
        )
    }

    /// Sends the authenticated definitions request.
    ///
    /// Shared by `load_flags` and the polling task so both always use the same
    /// endpoint and headers.
    async fn request_definitions(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
    ) -> reqwest::Result<reqwest::Response> {
        client
            .get(Self::definitions_url(&config.api_host))
            .header(
                "Authorization",
                format!("Bearer {}", config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &config.project_api_key)
            .send()
            .await
    }

    /// Runs one background refresh: fetch, parse and store the definitions.
    ///
    /// Failures are logged as warnings and otherwise ignored; the next tick retries.
    async fn poll_once(client: &reqwest::Client, config: &LocalEvaluationConfig, cache: &FlagCache) {
        match Self::request_definitions(client, config).await {
            Ok(response) => {
                if response.status().is_success() {
                    match response.json::<LocalEvaluationResponse>().await {
                        Ok(data) => {
                            trace!("Successfully fetched flag definitions");
                            cache.update(data);
                        }
                        Err(e) => {
                            warn!(error = %e, "Failed to parse flag response");
                        }
                    }
                } else {
                    warn!(status = %response.status(), "Failed to fetch flags");
                }
            }
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
            }
        }
    }

    /// Loads the definitions once, then spawns a task that refreshes them every `poll_interval`.
    ///
    /// A failed initial load is logged and retried on the next tick rather than returned.
    /// Does nothing if a task is already running. Calling it after
    /// [`AsyncFlagPoller::stop`] spawns a fresh task.
    pub async fn start(&mut self) {
        let already_running = *self.is_running.read().await;
        if already_running {
            debug!("Flag poller already running, skipping start");
            return;
        }

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        match self.load_flags().await {
            Ok(()) => info!("Initial flag definitions loaded successfully"),
            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
        }

        // A previous `stop` leaves the signal set; without clearing it the new
        // task would exit on its first tick and never poll.
        self.stop_signal.store(false, Ordering::SeqCst);

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = Arc::clone(&self.stop_signal);
        let is_running = Arc::clone(&self.is_running);
        let client = self.client.clone();
        // `tokio::time::interval` panics on a zero period.
        let period = config.poll_interval.max(Duration::from_millis(1));

        // Mark the poller running only here. This is the last `.await` in
        // `start`, so if the caller drops this future earlier the flag is never
        // left `true` without a task behind it.
        *self.is_running.write().await = true;

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(period);
            // After a slow fetch, wait a full period instead of firing the missed
            // ticks back-to-back at the API.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately; consume it so the first
            // background fetch happens one period after the initial load.
            interval.tick().await;

            loop {
                interval.tick().await;

                if stop_signal.load(Ordering::SeqCst) {
                    debug!("Async flag poller received stop signal");
                    break;
                }

                Self::poll_once(&client, &config, &cache).await;
            }

            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Fetches the definitions once and replaces the cache contents with them.
    ///
    /// # Errors
    ///
    /// Returns `Error::Connection` if the request fails or the response status is
    /// not a success. Returns `Error::Serialization` if the body cannot be parsed
    /// as a [`LocalEvaluationResponse`]. The cache is left untouched on error.
    #[instrument(skip(self), level = "debug")]
    pub async fn load_flags(&self) -> Result<(), Error> {
        let response = Self::request_definitions(&self.client, &self.config)
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        if !response.status().is_success() {
            let status = response.status();
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(())
    }

    /// Signals and aborts the polling task, then marks the poller as not running.
    ///
    /// Safe to call when no task was started.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::SeqCst);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
        }
        *self.is_running.write().await = false;
    }

    /// Reports whether a polling task is currently marked as running.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }
}
475
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Requests cancellation of the polling task, if one was spawned.
    fn drop(&mut self) {
        let pending = self.task_handle.take();
        if let Some(task) = pending {
            task.abort();
        }
    }
}
485
486#[derive(Clone)]
492pub struct LocalEvaluator {
493 cache: FlagCache,
494}
495
496impl LocalEvaluator {
497 pub fn new(cache: FlagCache) -> Self {
498 Self { cache }
499 }
500
501 #[instrument(skip(self, person_properties), level = "trace")]
504 pub fn evaluate_flag(
505 &self,
506 key: &str,
507 distinct_id: &str,
508 person_properties: &HashMap<String, serde_json::Value>,
509 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
510 match self.cache.get_flag(key) {
511 Some(flag) => {
512 let cohorts = self.cache.get_cohort_definitions();
514 let flags = self.cache.get_flags_map();
515
516 let ctx = EvaluationContext {
517 cohorts: &cohorts,
518 flags: &flags,
519 distinct_id,
520 };
521
522 let result =
523 match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
524 trace!(key, ?result, "Local flag evaluation");
525 result.map(Some)
526 }
527 None => {
528 trace!(key, "Flag not found in local cache");
529 Ok(None)
530 }
531 }
532 }
533
534 #[instrument(skip(self, person_properties), level = "trace")]
537 pub fn evaluate_flag_simple(
538 &self,
539 key: &str,
540 distinct_id: &str,
541 person_properties: &HashMap<String, serde_json::Value>,
542 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
543 match self.cache.get_flag(key) {
544 Some(flag) => {
545 let result = match_feature_flag(&flag, distinct_id, person_properties);
546 trace!(key, ?result, "Local flag evaluation (simple)");
547 result.map(Some)
548 }
549 None => {
550 trace!(key, "Flag not found in local cache");
551 Ok(None)
552 }
553 }
554 }
555
556 #[instrument(skip(self, person_properties), level = "debug")]
558 pub fn evaluate_all_flags(
559 &self,
560 distinct_id: &str,
561 person_properties: &HashMap<String, serde_json::Value>,
562 ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
563 let mut results = HashMap::new();
564
565 let cohorts = self.cache.get_cohort_definitions();
567 let flags = self.cache.get_flags_map();
568
569 let ctx = EvaluationContext {
570 cohorts: &cohorts,
571 flags: &flags,
572 distinct_id,
573 };
574
575 for flag in self.cache.get_all_flags() {
576 let result =
577 match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
578 results.insert(flag.key.clone(), result);
579 }
580
581 debug!(flag_count = results.len(), "Evaluated all local flags");
582 results
583 }
584}