1use crate::feature_flags::{
2 match_feature_flag, match_feature_flag_with_context, CohortDefinition, EvaluationContext,
3 FeatureFlag, FlagValue, InconclusiveMatchError,
4};
5use crate::Error;
6use reqwest::header::{HeaderMap, ETAG, IF_NONE_MATCH};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13use tracing::{debug, error, info, instrument, trace, warn};
14
15fn extract_etag(headers: &HeaderMap) -> Option<String> {
18 headers
19 .get(ETAG)
20 .and_then(|v| v.to_str().ok())
21 .filter(|s| !s.is_empty())
22 .map(|s| s.to_string())
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct LocalEvaluationResponse {
31 pub flags: Vec<FeatureFlag>,
33 #[serde(default)]
35 pub group_type_mapping: HashMap<String, String>,
36 #[serde(default)]
38 pub cohorts: HashMap<String, Cohort>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Cohort {
47 pub id: String,
49 pub name: String,
51 pub properties: serde_json::Value,
53}
54
55#[derive(Clone)]
61pub struct FlagCache {
62 flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
63 group_type_mapping: Arc<RwLock<HashMap<String, String>>>,
64 cohorts: Arc<RwLock<HashMap<String, Cohort>>>,
65}
66
67impl Default for FlagCache {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl FlagCache {
74 pub fn new() -> Self {
75 Self {
76 flags: Arc::new(RwLock::new(HashMap::new())),
77 group_type_mapping: Arc::new(RwLock::new(HashMap::new())),
78 cohorts: Arc::new(RwLock::new(HashMap::new())),
79 }
80 }
81
82 pub fn update(&self, response: LocalEvaluationResponse) {
83 let flag_count = response.flags.len();
84 let mut flags = self.flags.write().unwrap();
85 flags.clear();
86 for flag in response.flags {
87 flags.insert(flag.key.clone(), flag);
88 }
89
90 let mut mapping = self.group_type_mapping.write().unwrap();
91 *mapping = response.group_type_mapping;
92
93 let mut cohorts = self.cohorts.write().unwrap();
94 *cohorts = response.cohorts;
95
96 debug!(flag_count, "Updated flag cache");
97 }
98
99 pub fn get_flag(&self, key: &str) -> Option<FeatureFlag> {
100 self.flags.read().unwrap().get(key).cloned()
101 }
102
103 pub fn get_all_flags(&self) -> Vec<FeatureFlag> {
104 self.flags.read().unwrap().values().cloned().collect()
105 }
106
107 pub fn get_cohort(&self, id: &str) -> Option<Cohort> {
108 self.cohorts.read().unwrap().get(id).cloned()
109 }
110
111 pub fn get_all_cohorts(&self) -> HashMap<String, Cohort> {
112 self.cohorts.read().unwrap().clone()
113 }
114
115 pub fn get_cohort_definitions(&self) -> HashMap<String, CohortDefinition> {
117 self.cohorts
118 .read()
119 .unwrap()
120 .iter()
121 .map(|(k, v)| {
122 (
123 k.clone(),
124 CohortDefinition {
125 id: v.id.clone(),
126 properties: v.properties.clone(),
127 },
128 )
129 })
130 .collect()
131 }
132
133 pub fn get_flags_map(&self) -> HashMap<String, FeatureFlag> {
135 self.flags.read().unwrap().clone()
136 }
137
138 pub fn get_group_type_mapping(&self) -> HashMap<String, String> {
140 self.group_type_mapping.read().unwrap().clone()
141 }
142
143 pub fn clear(&self) {
144 self.flags.write().unwrap().clear();
145 self.group_type_mapping.write().unwrap().clear();
146 self.cohorts.write().unwrap().clear();
147 }
148}
149
/// Connection and scheduling settings shared by the flag pollers.
// NOTE(review): holds API keys; avoid adding a derived `Debug` that would print them.
#[derive(Clone)]
pub struct LocalEvaluationConfig {
    /// Sent as `Authorization: Bearer <key>` on definition requests.
    pub personal_api_key: String,
    /// Sent in the `X-PostHog-Project-Api-Key` header.
    pub project_api_key: String,
    /// Base URL; trailing `/` characters are trimmed before the endpoint path is appended.
    pub api_host: String,
    /// Delay between background refreshes.
    pub poll_interval: Duration,
    /// Timeout configured on the HTTP client used for each request.
    pub request_timeout: Duration,
}
167
168pub struct FlagPoller {
174 config: LocalEvaluationConfig,
175 cache: FlagCache,
176 client: reqwest::blocking::Client,
177 stop_signal: Arc<AtomicBool>,
178 thread_handle: Option<std::thread::JoinHandle<()>>,
179}
180
181impl FlagPoller {
182 pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
183 let client = reqwest::blocking::Client::builder()
184 .timeout(config.request_timeout)
185 .build()
186 .unwrap();
187
188 Self {
189 config,
190 cache,
191 client,
192 stop_signal: Arc::new(AtomicBool::new(false)),
193 thread_handle: None,
194 }
195 }
196
197 pub fn start(&mut self) {
199 info!(
200 poll_interval_secs = self.config.poll_interval.as_secs(),
201 "Starting feature flag poller"
202 );
203
204 match self.load_flags() {
206 Ok(()) => info!("Initial flag definitions loaded successfully"),
207 Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
208 }
209
210 let config = self.config.clone();
211 let cache = self.cache.clone();
212 let stop_signal = self.stop_signal.clone();
213
214 let handle = std::thread::spawn(move || {
215 let client = reqwest::blocking::Client::builder()
216 .timeout(config.request_timeout)
217 .build()
218 .unwrap();
219
220 let mut last_etag: Option<String> = None;
221
222 loop {
223 std::thread::sleep(config.poll_interval);
224
225 if stop_signal.load(Ordering::Relaxed) {
226 debug!("Flag poller received stop signal");
227 break;
228 }
229
230 let url = format!(
231 "{}/flags/definitions/?send_cohorts",
232 config.api_host.trim_end_matches('/')
233 );
234
235 let mut request = client
236 .get(&url)
237 .header(
238 "Authorization",
239 format!("Bearer {}", config.personal_api_key),
240 )
241 .header("X-PostHog-Project-Api-Key", &config.project_api_key);
242
243 if let Some(ref etag) = last_etag {
244 request = request.header(IF_NONE_MATCH, etag.as_str());
245 }
246
247 match request.send() {
248 Ok(response) => {
249 if response.status() == StatusCode::NOT_MODIFIED {
250 debug!("Flag definitions unchanged (304 Not Modified)");
251 } else if response.status().is_success() {
252 let new_etag = extract_etag(response.headers());
254
255 match response.json::<LocalEvaluationResponse>() {
256 Ok(data) => {
257 trace!("Successfully fetched flag definitions");
258 cache.update(data);
259 last_etag = new_etag;
260 }
261 Err(e) => {
262 warn!(error = %e, "Failed to parse flag response");
263 }
264 }
265 } else {
266 warn!(status = %response.status(), "Failed to fetch flags");
267 }
268 }
269 Err(e) => {
270 warn!(error = %e, "Failed to fetch flags");
271 }
272 }
273 }
274 });
275
276 self.thread_handle = Some(handle);
277 }
278
279 #[instrument(skip(self), level = "debug")]
281 pub fn load_flags(&self) -> Result<(), Error> {
282 let url = format!(
283 "{}/flags/definitions/?send_cohorts",
284 self.config.api_host.trim_end_matches('/')
285 );
286
287 let response = self
288 .client
289 .get(&url)
290 .header(
291 "Authorization",
292 format!("Bearer {}", self.config.personal_api_key),
293 )
294 .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
295 .send()
296 .map_err(|e| {
297 error!(error = %e, "Connection error loading flags");
298 Error::Connection(e.to_string())
299 })?;
300
301 if !response.status().is_success() {
302 let status = response.status();
303 error!(status = %status, "HTTP error loading flags");
304 return Err(Error::Connection(format!("HTTP {}", status)));
305 }
306
307 let data = response.json::<LocalEvaluationResponse>().map_err(|e| {
308 error!(error = %e, "Failed to parse flag response");
309 Error::Serialization(e.to_string())
310 })?;
311
312 self.cache.update(data);
313 Ok(())
314 }
315
316 pub fn stop(&mut self) {
318 debug!("Stopping flag poller");
319 self.stop_signal.store(true, Ordering::Relaxed);
320 if let Some(handle) = self.thread_handle.take() {
321 handle.join().ok();
322 }
323 }
324}
325
326impl Drop for FlagPoller {
327 fn drop(&mut self) {
328 self.stop();
329 }
330}
331
/// Keeps a [`FlagCache`] up to date from a tokio task using the async HTTP
/// client. Only compiled with the `async-client` feature.
#[cfg(feature = "async-client")]
pub struct AsyncFlagPoller {
    /// Endpoint, credentials and timing settings.
    config: LocalEvaluationConfig,
    /// Cache that successful fetches write into.
    cache: FlagCache,
    /// Async HTTP client built with `config.request_timeout`.
    client: reqwest::Client,
    /// Set by `stop()`; the polling task checks it on every tick.
    stop_signal: Arc<AtomicBool>,
    /// Handle of the spawned polling task; `None` when none was started.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether a polling task is considered active; used to make `start` idempotent.
    is_running: Arc<tokio::sync::RwLock<bool>>,
}
346
#[cfg(feature = "async-client")]
impl AsyncFlagPoller {
    /// Creates an async poller that refreshes `cache` using `config`.
    ///
    /// # Panics
    ///
    /// Panics if the async HTTP client cannot be built.
    pub fn new(config: LocalEvaluationConfig, cache: FlagCache) -> Self {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .build()
            .expect("failed to build HTTP client for async feature flag poller");

        Self {
            config,
            cache,
            client,
            stop_signal: Arc::new(AtomicBool::new(false)),
            task_handle: None,
            is_running: Arc::new(tokio::sync::RwLock::new(false)),
        }
    }

    /// Loads the flag definitions once, then spawns a tokio task that refreshes
    /// them every `poll_interval`.
    ///
    /// Calling `start` while the poller is running is a no-op. A poller stopped
    /// with [`AsyncFlagPoller::stop`] can be started again.
    pub async fn start(&mut self) {
        {
            let mut is_running = self.is_running.write().await;
            if *is_running {
                debug!("Flag poller already running, skipping start");
                return;
            }
            *is_running = true;
        }

        info!(
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Starting async feature flag poller"
        );

        match self.load_flags().await {
            Ok(()) => info!("Initial flag definitions loaded successfully"),
            Err(e) => warn!(error = %e, "Failed to load initial flags, will retry on next poll"),
        }

        // `stop()` leaves the signal set. Without this reset a restarted task would
        // break out on its first tick.
        self.stop_signal.store(false, Ordering::Relaxed);

        let config = self.config.clone();
        let cache = self.cache.clone();
        let stop_signal = self.stop_signal.clone();
        let is_running = self.is_running.clone();
        let client = self.client.clone();

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(config.poll_interval);
            // If a fetch overruns the period, wait a full period before the next
            // one rather than firing the missed ticks back-to-back.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately; skip it because the initial
            // load already ran in `start`.
            interval.tick().await;

            let mut last_etag: Option<String> = None;

            loop {
                interval.tick().await;
                if stop_signal.load(Ordering::Relaxed) {
                    debug!("Async flag poller received stop signal");
                    break;
                }
                Self::poll_once(&client, &config, &cache, &mut last_etag).await;
            }

            *is_running.write().await = false;
        });

        self.task_handle = Some(task);
    }

    /// Fetches the flag definitions once, without `If-None-Match`, and replaces
    /// the cache contents with the result.
    ///
    /// # Errors
    ///
    /// Returns `Error::Connection` when the request fails or the response status
    /// is not a success. Returns `Error::Serialization` when the body cannot be
    /// parsed as a `LocalEvaluationResponse`.
    #[instrument(skip(self), level = "debug")]
    pub async fn load_flags(&self) -> Result<(), Error> {
        let url = Self::definitions_url(&self.config.api_host);

        let response = self
            .client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", self.config.personal_api_key),
            )
            .header("X-PostHog-Project-Api-Key", &self.config.project_api_key)
            .send()
            .await
            .map_err(|e| {
                error!(error = %e, "Connection error loading flags");
                Error::Connection(e.to_string())
            })?;

        if !response.status().is_success() {
            let status = response.status();
            error!(status = %status, "HTTP error loading flags");
            return Err(Error::Connection(format!("HTTP {}", status)));
        }

        let data = response
            .json::<LocalEvaluationResponse>()
            .await
            .map_err(|e| {
                error!(error = %e, "Failed to parse flag response");
                Error::Serialization(e.to_string())
            })?;

        self.cache.update(data);
        Ok(())
    }

    /// Signals and aborts the polling task, waits for it to finish, and marks
    /// the poller as not running.
    ///
    /// Awaiting the aborted handle guarantees the old task is gone before a
    /// subsequent `start` spawns a new one.
    pub async fn stop(&mut self) {
        debug!("Stopping async flag poller");
        self.stop_signal.store(true, Ordering::Relaxed);
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
            // The result is normally a cancellation `JoinError`; it carries no
            // information the caller needs.
            let _ = handle.await;
        }
        *self.is_running.write().await = false;
    }

    /// Returns whether a polling task is currently marked as running.
    pub async fn is_running(&self) -> bool {
        *self.is_running.read().await
    }

    /// Builds the flag-definitions URL, tolerating trailing `/` on the host.
    fn definitions_url(api_host: &str) -> String {
        format!(
            "{}/flags/definitions/?send_cohorts",
            api_host.trim_end_matches('/')
        )
    }

    /// Performs one conditional fetch and updates `cache` on success.
    ///
    /// `last_etag`, when set, is sent as `If-None-Match`. It is replaced by the
    /// `ETag` of every successfully parsed response. All failures are logged and
    /// otherwise ignored so the next tick can retry.
    async fn poll_once(
        client: &reqwest::Client,
        config: &LocalEvaluationConfig,
        cache: &FlagCache,
        last_etag: &mut Option<String>,
    ) {
        let url = Self::definitions_url(&config.api_host);

        let mut request = client
            .get(&url)
            .header("Authorization", format!("Bearer {}", config.personal_api_key))
            .header("X-PostHog-Project-Api-Key", &config.project_api_key);

        if let Some(etag) = last_etag.as_deref() {
            request = request.header(IF_NONE_MATCH, etag);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                warn!(error = %e, "Failed to fetch flags");
                return;
            }
        };

        let status = response.status();
        if status == StatusCode::NOT_MODIFIED {
            debug!("Flag definitions unchanged (304 Not Modified)");
            return;
        }
        if !status.is_success() {
            warn!(status = %status, "Failed to fetch flags");
            return;
        }

        // Read the ETag before `json()` consumes the response.
        let new_etag = extract_etag(response.headers());
        match response.json::<LocalEvaluationResponse>().await {
            Ok(data) => {
                trace!("Successfully fetched flag definitions");
                cache.update(data);
                *last_etag = new_etag;
            }
            Err(e) => {
                warn!(error = %e, "Failed to parse flag response");
            }
        }
    }
}
515
#[cfg(feature = "async-client")]
impl Drop for AsyncFlagPoller {
    /// Aborts the polling task, if one was spawned.
    fn drop(&mut self) {
        // Dropping a tokio `JoinHandle` only detaches the task, so cancel it
        // explicitly to stop background polling.
        let pending = self.task_handle.take();
        if let Some(task) = pending {
            task.abort();
        }
    }
}
525
526#[derive(Clone)]
532pub struct LocalEvaluator {
533 cache: FlagCache,
534}
535
536impl LocalEvaluator {
537 pub fn new(cache: FlagCache) -> Self {
538 Self { cache }
539 }
540
541 pub fn cache(&self) -> &FlagCache {
543 &self.cache
544 }
545
546 #[instrument(
554 skip(self, person_properties, groups, group_properties),
555 level = "trace"
556 )]
557 pub fn evaluate_flag(
558 &self,
559 key: &str,
560 distinct_id: &str,
561 person_properties: &HashMap<String, serde_json::Value>,
562 groups: &HashMap<String, String>,
563 group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
564 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
565 match self.cache.get_flag(key) {
566 Some(flag) => {
567 let cohorts = self.cache.get_cohort_definitions();
569 let flags = self.cache.get_flags_map();
570 let group_type_mapping = self.cache.get_group_type_mapping();
571
572 let ctx = EvaluationContext {
573 cohorts: &cohorts,
574 flags: &flags,
575 distinct_id,
576 groups,
577 group_properties,
578 group_type_mapping: &group_type_mapping,
579 };
580
581 let result = match_feature_flag_with_context(&flag, person_properties, &ctx);
582 trace!(key, ?result, "Local flag evaluation");
583 result.map(Some)
584 }
585 None => {
586 trace!(key, "Flag not found in local cache");
587 Ok(None)
588 }
589 }
590 }
591
592 #[instrument(
595 skip(self, person_properties, groups, group_properties),
596 level = "trace"
597 )]
598 pub fn evaluate_flag_simple(
599 &self,
600 key: &str,
601 distinct_id: &str,
602 person_properties: &HashMap<String, serde_json::Value>,
603 groups: &HashMap<String, String>,
604 group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
605 ) -> Result<Option<FlagValue>, InconclusiveMatchError> {
606 match self.cache.get_flag(key) {
607 Some(flag) => {
608 let group_type_mapping = self.cache.get_group_type_mapping();
609 let result = match_feature_flag(
610 &flag,
611 distinct_id,
612 person_properties,
613 groups,
614 group_properties,
615 &group_type_mapping,
616 );
617 trace!(key, ?result, "Local flag evaluation (simple)");
618 result.map(Some)
619 }
620 None => {
621 trace!(key, "Flag not found in local cache");
622 Ok(None)
623 }
624 }
625 }
626
627 #[instrument(
629 skip(self, person_properties, groups, group_properties),
630 level = "debug"
631 )]
632 pub fn evaluate_all_flags(
633 &self,
634 distinct_id: &str,
635 person_properties: &HashMap<String, serde_json::Value>,
636 groups: &HashMap<String, String>,
637 group_properties: &HashMap<String, HashMap<String, serde_json::Value>>,
638 ) -> HashMap<String, Result<FlagValue, InconclusiveMatchError>> {
639 let mut results = HashMap::new();
640
641 let cohorts = self.cache.get_cohort_definitions();
643 let flags = self.cache.get_flags_map();
644 let group_type_mapping = self.cache.get_group_type_mapping();
645
646 let ctx = EvaluationContext {
647 cohorts: &cohorts,
648 flags: &flags,
649 distinct_id,
650 groups,
651 group_properties,
652 group_type_mapping: &group_type_mapping,
653 };
654
655 for flag in self.cache.get_all_flags() {
656 let result = match_feature_flag_with_context(&flag, person_properties, &ctx);
657 results.insert(flag.key.clone(), result);
658 }
659
660 debug!(flag_count = results.len(), "Evaluated all local flags");
661 results
662 }
663}