1use crate::feature_flags::{
2 match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3 FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use reqwest::header::{HeaderMap, ETAG, IF_NONE_MATCH};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13use tracing::{debug, error, info, instrument, trace, warn};
14
15fn extract_etag(headers: &HeaderMap) -> Option<String> {
18 headers
19 .get(ETAG)
20 .and_then(|v| v.to_str().ok())
21 .filter(|s| !s.is_empty())
22 .map(|s| s.to_string())
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct LocalEvaluationResponse {
31 pub flags: Vec<FeatureFlag>,
33 #[serde(default)]
35 pub group_type_mapping: HashMap<String, String>,
36 #[serde(default)]
38 pub cohorts: HashMap<String, Cohort>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Cohort {
47 pub id: String,
49 pub name: String,
51 pub properties: serde_json::Value,
53}
54
55#[derive(Clone)]
61pub struct FlagCache {
62 flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
63 group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
64 cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
65}
66
67impl Default for FlagCache {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl FlagCache {
74 pub fn new() -> Self {
75 Self {
76 flags: Arc::new(RwLock::new(HashMap::new())),
77 group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
78 cohorts: Arc::new(RwLock::new(HashMap::new())),
79 }
80 }
81
82 pub fn update(&self, response: LocalEvaluationResponse) {
83 let flag_count = response.flags.len();
84 let mut flags = self.flags.write().unwrap();
85 flags.clear();
86 for flag in response.flags {
87 flags.insert(flag.key.clone(), flag);
88 }
89
90 let mut mapping = self.group_type_mapping.write().unwrap();
91 *mapping = response.group_type_mapping;
92
93 let mut cohorts = self.cohorts.write().unwrap();
94 *cohorts = response.cohorts;
95
96 debug!(flag_count, "Updated flag cache");
97 }
98
99 pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
100 self.flags.read().unwrap().get(key).cloned()
101 }
102
103 pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
104 self.flags.read().unwrap().values().cloned().collect()
105 }
106
107 pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
108 self.cohorts.read().unwrap().get(id).cloned()
109 }
110
111 pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
112 self.cohorts.read().unwrap().clone()
113 }
114
115 pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
117 self.cohorts
118 .read()
119 .unwrap()
120 .iter()
121 .map(|(k, v)| {
122 (
123 k.clone(),
124 CohortDefinition {
125 id: v.id.clone(),
126 properties: v.properties.clone(),
127 },
128 )
129 })
130 .collect()
131 }
132
133 pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
135 self.flags.read().unwrap().clone()
136 }
137
138 pub fn clear(&self) {
139 self.flags.write().unwrap().clear();
140 self.group_type_mapping.write().unwrap().clear();
141 self.cohorts.write().unwrap().clear();
142 }
143}
144
/// Settings used by the flag pollers to reach the local-evaluation endpoint.
// `Debug` is intentionally not derived here: the struct carries API keys.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Personal API key, sent as `Authorization: Bearer <key>`.
    pub personal_api_key: String,
    /// Project API key, sent in the `X-PostHog-Project-Api-Key` header.
    pub project_api_key: String,
    /// Base URL of the API host; trailing slashes are trimmed before the
    /// endpoint path is appended.
    pub api_host: String,
    /// Time between two background fetches of the flag definitions.
    pub poll_interval: Duration,
    /// Timeout applied to the HTTP client used for fetching definitions.
    pub request_timeout: Duration,
}
162
163pub struct FlagPoller {
169 config: LocalEvaluationConfig,
170 cache: FlagCache,
171 client: reqwest::blocking::Client,
172 stop_signal: Arc<AtomicBool>,
173 thread_handle: Option<std::thread::JoinHandle<()>>,
174}
175
176impl FlagPoller {
177 pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
178 let client = reqwest::blocking::Client::builder()
179 .timeout(config.request_timeout)
180 .build()
181 .unwrap();
182
183 Self {
184 config,
185 cache,
186 client,
187 stop_signal: Arc::new(AtomicBool::new(false)),
188 thread_handle: None,
189 }
190 }
191
192 pub fn start(&mut self) {
194 info!(
195 poll_interval_secs = self.config.poll_interval.as_secs(),
196 "Starting feature flag poller"
197 );
198
199 match self.load_flags() {
201 Ok(()) => info!("Initial flag definitions loaded successfully"),
202 Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
203 }
204
205 let config = self.config.clone();
206 let cache = self.cache.clone();
207 let stop_signal = self.stop_signal.clone();
208
209 let handle = std::thread::spawn(move || {
210 let client = reqwest::blocking::Client::builder()
211 .timeout(config.request_timeout)
212 .build()
213 .unwrap();
214
215 let mut last_etag: Option<String> = None;
216
217 loop {
218 std::thread::sleep(config.poll_interval);
219
220 if stop_signal.load(Ordering::Relaxed) {
221 debug!("Flag poller received stop signal");
222 break;
223 }
224
225 let url = format!(
226 "{}/api/feature_flag/local_evaluation/?send_cohorts",
227 config.api_host.trim_end_matches('/')
228 );
229
230 let mut request = client
231 .get(&url)
232 .header(
233 "Authorization",
234 format!("Bearer {}", config.personal_api_key),
235 )
236 .header("X-PostHog-Project-Api-Key", &config.project_api_key);
237
238 if let Some(ref etag) = last_etag {
239 request = request.header(IF_NONE_MATCH, etag.as_str());
240 }
241
242 match request.send() {
243 Ok(response) => {
244 if response.status() == StatusCode::NOT_MODIFIED {
245 debug!("Flag definitions unchanged (304 Not Modified)");
246 } else if response.status().is_success() {
247 let new_etag = extract_etag(response.headers());
249
250 match response.json::<LocalEvaluationResponse>() {
251 Ok(data) => {
252 trace!("Successfully fetched flag definitions");
253 cache.update(data);
254 last_etag = new_etag;
255 }
256 Err(e) => {
257 warn!(error = %e, "Failed to parse flag response");
258 }
259 }
260 } else {
261 warn!(status = %response.status(), "Failed to fetch flags");
262 }
263 }
264 Err(e) => {
265 warn!(error = %e, "Failed to fetch flags");
266 }
267 }
268 }
269 });
270
271 self.thread_handle = Some(handle);
272 }
273
274 #[instrument(skip(self), level = "debug")]
276 pub fn load_flags(&self) -> Result<(), Error> {
277 let url = format!(
278 "{}/api/feature_flag/local_evaluation/?send_cohorts",
279 self.config.api_host.trim_end_matches('/')
280 );
281
282 let response = self
283 .client
284 .get(&url)
285 .header(
286 "Authorization",
287 format!("Bearer {}", self.config.personal_api_key),
288 )
289 .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
290 .send()
291 .map_err(|e| {
292 error!(error = %e, "Connection error loading flags");
293 Error::Connection(e.to_string())
294 })?;
295
296 if !response.status().is_success() {
297 let status = response.status();
298 error!(status = %status, "HTTP error loading flags");
299 return Err(Error::Connection(format!("HTTP {}", status)));
300 }
301
302 let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
303 error!(error = %e, "Failed to parse flag response");
304 Error::Serialization(e.to_string())
305 })?;
306
307 self.cache.update(data);
308 Ok(())
309 }
310
311 pub fn stop(&mut self) {
313 debug!("Stopping flag poller");
314 self.stop_signal.store(true, Ordering::Relaxed);
315 if let Some(handle) = self.thread_handle.take() {
316 handle.join().ok();
317 }
318 }
319}
320
321impl Drop for FlagPoller {
322 fn drop(&mut self) {
323 self.stop();
324 }
325}
326
/// Async poller that keeps a [`FlagCache`] up to date from a spawned Tokio
/// task. Only compiled with the `async-client` feature.
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    /// Endpoint, credentials and timing settings.
    config: LocalEvaluationConfig,
    /// Cache that receives every successfully fetched set of definitions.
    cache: FlagCache,
    /// Async HTTP client configured with `config.request_timeout`.
    client: reqwest::Client,
    /// Set to `true` to ask the polling task to exit at its next tick.
    stop_signal: Arc<AtomicBool>,
    /// Handle of the spawned polling task, if one has been started.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether the poller currently considers itself started; shared with the
    /// polling task, which clears it when its loop ends.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
341
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates a poller that will write fetched definitions into `cache`.
    ///
    /// No network traffic happens until [`AsyncFlagPoller::start`] or
    /// [`AsyncFlagPoller::load_flags`] is called.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build async HTTP client for the flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Loads the definitions once, then spawns a Tokio task that re-fetches
    /// them every `poll_interval`.
    ///
    /// A failed initial load is only logged; the task retries on its next
    /// tick. Calling `start` while the poller is running is a no-op, and
    /// calling it after [`AsyncFlagPoller::stop`] restarts polling.
    ///
    /// The task sends `If-None-Match` with the ETag of the last response it
    /// parsed successfully. The initial load does not record an ETag, so the
    /// first background request is unconditional.
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime. The spawned task panics if
    /// `poll_interval` is zero, because `tokio::time::interval` rejects a zero
    /// period.
    pub async fn start(&mut self) {
        {
            // Check-and-set under one write guard so two concurrent callers
            // cannot both decide to start.
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }

        // `stop` leaves the signal set; without clearing it a restarted task
        // would exit on its first tick and polling would silently never resume.
        self.stop_signal.store(false, Ordering::Relaxed);

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        match self.load_flags().await {
            Ok(()) => info!("Initial flag definitions loaded successfully"),
            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
        }

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = self.stop_signal.clone();
        let is_running = self.is_running.clone();
        let client = self.client.clone();

        let task = tokio::spawn(async move {
            let url = Self::local_evaluation_url(&config.api_host);
            let mut interval = tokio::time::interval(config.poll_interval);
            // The first tick completes immediately; consume it so the first
            // background fetch happens one full interval after the initial load.
            interval.tick().await;
            let mut last_etag: Option<String> = None;

            loop {
                interval.tick().await;

                if stop_signal.load(Ordering::Relaxed) {
                    debug!("Async flag poller received stop signal");
                    break;
                }

                let mut request = client
                    .get(&url)
                    .header(
                        "Authorization",
                        format!("Bearer {}", config.personal_api_key),
                    )
                    .header("X-PostHog-Project-Api-Key", &config.project_api_key);

                if let Some(ref etag) = last_etag {
                    request = request.header(IF_NONE_MATCH, etag.as_str());
                }

                match request.send().await {
                    Ok(response) => {
                        if response.status() == StatusCode::NOT_MODIFIED {
                            debug!("Flag definitions unchanged (304 Not Modified)");
                        } else if response.status().is_success() {
                            // Read the header before `json()` consumes the response.
                            let new_etag = extract_etag(response.headers());

                            match response.json::<LocalEvaluationResponse>().await {
                                Ok(data) => {
                                    trace!("Successfully fetched flag definitions");
                                    cache.update(data);
                                    // Only remember the ETag of a body that was
                                    // actually applied to the cache.
                                    last_etag = new_etag;
                                }
                                Err(e) => {
                                    warn!(error = %e, "Failed to parse flag response");
                                }
                            }
                        } else {
                            warn!(status = %response.status(), "Failed to fetch flags");
                        }
                    }
                    Err(e) => {
                        warn!(error = %e, "Failed to fetch flags");
                    }
                }
            }

            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Fetches the definitions once and stores them in the cache.
    ///
    /// # Errors
    ///
    /// Returns `Error::Connection` if the request cannot be sent or the server
    /// answers with a non-success status, and `Error::Serialization` if the
    /// response body cannot be parsed. The cache is left untouched on error.
    #[instrument(skip(self), level = "debug")]
    pub async fn load_flags(&self) -> Result<(), Error> {
        let url = Self::local_evaluation_url(&self.config.api_host);

        let response = self
            .client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", self.config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        if !response.status().is_success() {
            let status = response.status();
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(())
    }

    /// Stops the polling task and marks the poller as not running.
    ///
    /// The task is aborted and then awaited, so once this returns the task can
    /// no longer update the cache or the running flag.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::Relaxed);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
            // Wait for the task to actually finish; the result is either its
            // normal completion or a cancellation error, neither of which
            // matters here. Without this, a task that was just leaving its
            // loop could clear `is_running` after a subsequent `start`.
            let _ = handle.await;
        }
        *self.is_running.write().await = false;
    }

    /// Reports whether the poller currently considers itself started.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }

    /// Builds the local-evaluation endpoint URL for `api_host`, ignoring any
    /// trailing slashes on the host.
    fn local_evaluation_url(api_host: &str) -> String {
        format!(
            "{}/api/feature_flag/local_evaluation/?send_cohorts",
            api_host.trim_end_matches('/')
        )
    }
}
510
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Aborts the polling task, if one is still attached, when the poller is
    /// dropped. The abort is only requested here; it is not awaited.
    fn drop(&mut self) {
        let pending_task = self.task_handle.take();
        if let Some(task) = pending_task {
            task.abort();
        }
    }
}
520
521#[derive(Clone)]
527pub struct LocalEvaluator {
528 cache: FlagCache,
529}
530
531impl LocalEvaluator {
532 pub fn new(cache: FlagCache) -> Self {
533 Self { cache }
534 }
535
536 #[instrument(skip(self, person_properties), level = "trace")]
539 pub fn evaluate_flag(
540 &self,
541 key: &str,
542 distinct_id: &str,
543 person_properties: &HashMap<String, serde_json::Value>,
544 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
545 match self.cache.get_flag(key) {
546 Some(flag) => {
547 let cohorts = self.cache.get_cohort_definitions();
549 let flags = self.cache.get_flags_map();
550
551 let ctx = EvaluationContext {
552 cohorts: &cohorts,
553 flags: &flags,
554 distinct_id,
555 };
556
557 let result =
558 match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
559 trace!(key, ?result, "Local flag evaluation");
560 result.map(Some)
561 }
562 None => {
563 trace!(key, "Flag not found in local cache");
564 Ok(None)
565 }
566 }
567 }
568
569 #[instrument(skip(self, person_properties), level = "trace")]
572 pub fn evaluate_flag_simple(
573 &self,
574 key: &str,
575 distinct_id: &str,
576 person_properties: &HashMap<String, serde_json::Value>,
577 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
578 match self.cache.get_flag(key) {
579 Some(flag) => {
580 let result = match_feature_flag(&flag, distinct_id, person_properties);
581 trace!(key, ?result, "Local flag evaluation (simple)");
582 result.map(Some)
583 }
584 None => {
585 trace!(key, "Flag not found in local cache");
586 Ok(None)
587 }
588 }
589 }
590
591 #[instrument(skip(self, person_properties), level = "debug")]
593 pub fn evaluate_all_flags(
594 &self,
595 distinct_id: &str,
596 person_properties: &HashMap<String, serde_json::Value>,
597 ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
598 let mut results = HashMap::new();
599
600 let cohorts = self.cache.get_cohort_definitions();
602 let flags = self.cache.get_flags_map();
603
604 let ctx = EvaluationContext {
605 cohorts: &cohorts,
606 flags: &flags,
607 distinct_id,
608 };
609
610 for flag in self.cache.get_all_flags() {
611 let result =
612 match_feature_flag_with_context(&flag, distinct_id, person_properties, &ctx);
613 results.insert(flag.key.clone(), result);
614 }
615
616 debug!(flag_count = results.len(), "Evaluated all local flags");
617 results
618 }
619}