Skip to main content

open_feature_flagd/resolver/in_process/storage/
mod.rs

1pub mod connector;
2use crate::error::FlagdError;
3pub use connector::{
4    Connector, FileConnector, GrpcStreamConnector, QueuePayload, QueuePayloadType,
5};
6use tracing::{debug, error, warn};
7
8use crate::resolver::in_process::model::feature_flag::FeatureFlag;
9use crate::resolver::in_process::model::flag_parser::FlagParser;
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use tokio::sync::RwLock;
14use tokio::sync::mpsc::{Receiver, Sender, channel};
15
/// Health of the in-process flag storage.
///
/// The [`Default`] variant is [`StorageState::Ok`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum StorageState {
    /// Flags are current and the storage is healthy.
    #[default]
    Ok,
    /// Flags may be out of date (e.g. because of connection problems).
    Stale,
    /// The storage hit an error.
    Error,
}
27
28/// Represents a change in storage state with affected flags
29#[derive(Debug, Clone, PartialEq)]
30pub struct StorageStateChange {
31    /// Current state of the storage
32    pub storage_state: StorageState,
33    /// Keys of flags that changed in this update
34    pub changed_flags_keys: Vec<String>,
35    /// Metadata from the sync operation
36    pub sync_metadata: HashMap<String, serde_json::Value>,
37}
38
39impl Default for StorageStateChange {
40    fn default() -> Self {
41        Self {
42            storage_state: StorageState::Ok,
43            changed_flags_keys: Vec::new(),
44            sync_metadata: HashMap::new(),
45        }
46    }
47}
48
49/// Result of querying a flag from storage
50#[derive(Debug, Clone)]
51pub struct StorageQueryResult {
52    /// The feature flag if found
53    pub feature_flag: Option<FeatureFlag>,
54    /// Metadata associated with the flag set
55    pub flag_set_metadata: HashMap<String, serde_json::Value>,
56}
57
58pub struct FlagStore {
59    flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
60    flag_set_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
61    state_sender: Sender<StorageStateChange>,
62    connector: Arc<dyn Connector>,
63    shutdown: Arc<AtomicBool>,
64}
65
66impl FlagStore {
67    pub fn new(connector: Arc<dyn Connector>) -> (Self, Receiver<StorageStateChange>) {
68        let (state_sender, state_receiver) = channel(1000);
69
70        (
71            Self {
72                flags: Arc::new(RwLock::new(HashMap::new())),
73                flag_set_metadata: Arc::new(RwLock::new(HashMap::new())),
74                state_sender,
75                connector,
76                shutdown: Arc::new(AtomicBool::new(false)),
77            },
78            state_receiver,
79        )
80    }
81
82    pub async fn init(&self) -> Result<(), FlagdError> {
83        debug!("Initializing flag store");
84        self.connector.init().await?;
85
86        // Handle initial sync
87        let stream = self.connector.get_stream();
88        let mut receiver = stream.lock().await;
89        debug!("Waiting for initial sync message");
90
91        if let Some(receiver_ref) = receiver.as_mut() {
92            match tokio::time::timeout(std::time::Duration::from_secs(5), receiver_ref.recv())
93                .await?
94            {
95                Some(payload) => {
96                    debug!("Received initial sync message");
97                    match payload.payload_type {
98                        QueuePayloadType::Data => {
99                            debug!("Parsing flag data: {}", &payload.flag_data);
100                            let parsing_result = FlagParser::parse_string(&payload.flag_data)?;
101                            let mut flags_write = self.flags.write().await;
102                            let mut metadata_write = self.flag_set_metadata.write().await;
103                            let flag_keys: Vec<String> =
104                                parsing_result.flags.keys().cloned().collect();
105                            *flags_write = parsing_result.flags;
106                            *metadata_write = parsing_result.flag_set_metadata;
107                            debug!("Successfully parsed {} flags", flags_write.len());
108
109                            // Send initial state change so FileResolver knows init completed
110                            let _ = self
111                                .state_sender
112                                .send(StorageStateChange {
113                                    storage_state: StorageState::Ok,
114                                    changed_flags_keys: flag_keys,
115                                    sync_metadata: payload.metadata.unwrap_or_default(),
116                                })
117                                .await;
118                        }
119                        QueuePayloadType::Error => {
120                            error!("Error in initial sync: {}", payload.flag_data);
121                            return Err(FlagdError::Sync(format!(
122                                "Error in initial sync: {}",
123                                payload.flag_data
124                            )));
125                        }
126                    }
127                }
128                None => {
129                    error!("No initial sync message received");
130                    return Err(FlagdError::Sync(
131                        "No initial sync message received".to_string(),
132                    ));
133                }
134            }
135        }
136
137        // Start continuous stream processing
138        self.start_stream_listener().await;
139        Ok(())
140    }
141
142    pub async fn shutdown(&self) -> Result<(), FlagdError> {
143        debug!("Shutting down flag store");
144        self.shutdown.store(true, Ordering::Relaxed);
145        self.connector.shutdown().await
146    }
147
148    pub async fn get_flag(&self, key: &str) -> StorageQueryResult {
149        let flags = self.flags.read().await;
150        let metadata = self.flag_set_metadata.read().await;
151
152        StorageQueryResult {
153            feature_flag: flags.get(key).cloned(),
154            flag_set_metadata: metadata.clone(),
155        }
156    }
157
158    /// Compute which flags have changed between old and new flag sets
159    fn compute_changed_flags(
160        old_flags: &HashMap<String, FeatureFlag>,
161        new_flags: &HashMap<String, FeatureFlag>,
162    ) -> Vec<String> {
163        let mut changed = Vec::new();
164
165        // Check for modified or added flags
166        for (key, new_flag) in new_flags {
167            match old_flags.get(key) {
168                Some(old_flag) if old_flag != new_flag => {
169                    changed.push(key.clone());
170                }
171                None => {
172                    changed.push(key.clone());
173                }
174                _ => {}
175            }
176        }
177
178        // Check for deleted flags
179        let old_keys: HashSet<_> = old_flags.keys().collect();
180        let new_keys: HashSet<_> = new_flags.keys().collect();
181        for key in old_keys.difference(&new_keys) {
182            changed.push((*key).clone());
183        }
184
185        changed
186    }
187
188    async fn start_stream_listener(&self) {
189        let flags = self.flags.clone();
190        let metadata = self.flag_set_metadata.clone();
191        let sender = self.state_sender.clone();
192        let stream = self.connector.get_stream();
193        let shutdown = self.shutdown.clone();
194
195        tokio::spawn(async move {
196            let mut receiver = stream.lock().await;
197            if let Some(receiver) = receiver.as_mut() {
198                while let Some(payload) = receiver.recv().await {
199                    if shutdown.load(Ordering::Relaxed) {
200                        debug!("Stream listener shutting down");
201                        break;
202                    }
203
204                    match payload.payload_type {
205                        QueuePayloadType::Data => {
206                            match FlagParser::parse_string(&payload.flag_data) {
207                                Ok(parsing_result) => {
208                                    let mut flags_write = flags.write().await;
209                                    let mut metadata_write = metadata.write().await;
210
211                                    // Compute changed flags before updating
212                                    let changed_keys = Self::compute_changed_flags(
213                                        &flags_write,
214                                        &parsing_result.flags,
215                                    );
216
217                                    let num_changes = changed_keys.len();
218                                    *flags_write = parsing_result.flags;
219                                    *metadata_write = parsing_result.flag_set_metadata;
220
221                                    debug!(
222                                        "Flag store updated: {} flags changed ({} total flags)",
223                                        num_changes,
224                                        flags_write.len()
225                                    );
226
227                                    let _ = sender
228                                        .send(StorageStateChange {
229                                            storage_state: StorageState::Ok,
230                                            changed_flags_keys: changed_keys,
231                                            sync_metadata: payload.metadata.unwrap_or_default(),
232                                        })
233                                        .await;
234                                }
235                                Err(e) => {
236                                    warn!("Failed to parse flag data: {}", e);
237                                    let _ = sender
238                                        .send(StorageStateChange {
239                                            storage_state: StorageState::Error,
240                                            changed_flags_keys: vec![],
241                                            sync_metadata: HashMap::new(),
242                                        })
243                                        .await;
244                                }
245                            }
246                        }
247                        QueuePayloadType::Error => {
248                            error!("Received error from connector: {}", payload.flag_data);
249                            let _ = sender
250                                .send(StorageStateChange {
251                                    storage_state: StorageState::Error,
252                                    changed_flags_keys: vec![],
253                                    sync_metadata: HashMap::new(),
254                                })
255                                .await;
256                        }
257                    }
258                }
259            }
260            debug!("Stream listener stopped");
261        });
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn create_test_flag(state: &str, default_variant: &str) -> FeatureFlag {
        FeatureFlag {
            state: state.to_string(),
            default_variant: default_variant.to_string(),
            variants: {
                let mut map = HashMap::new();
                map.insert("on".to_string(), json!(true));
                map.insert("off".to_string(), json!(false));
                map
            },
            targeting: None,
            metadata: HashMap::new(),
        }
    }

    /// Sorts `keys` so assertions do not depend on `HashMap` iteration order.
    fn sorted(mut keys: Vec<String>) -> Vec<String> {
        keys.sort_unstable();
        keys
    }

    #[test]
    fn test_compute_changed_flags_no_changes() {
        let mut flags = HashMap::new();
        flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        flags.insert("flag2".to_string(), create_test_flag("ENABLED", "off"));

        let changed = FlagStore::compute_changed_flags(&flags, &flags);
        assert!(
            changed.is_empty(),
            "Expected no changes for identical flags"
        );
    }

    #[test]
    fn test_compute_changed_flags_both_empty() {
        let empty: HashMap<String, FeatureFlag> = HashMap::new();
        assert!(FlagStore::compute_changed_flags(&empty, &empty).is_empty());
    }

    #[test]
    fn test_compute_changed_flags_added_flag() {
        let old_flags = HashMap::new();
        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_removed_flag() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        let new_flags = HashMap::new();

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_modified_flag() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));

        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "off")); // Changed default

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_mixed_changes() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        old_flags.insert("flag2".to_string(), create_test_flag("ENABLED", "on"));
        old_flags.insert("flag3".to_string(), create_test_flag("ENABLED", "on"));

        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on")); // Unchanged
        new_flags.insert("flag2".to_string(), create_test_flag("DISABLED", "on")); // Modified
        new_flags.insert("flag4".to_string(), create_test_flag("ENABLED", "on")); // Added
        // flag3 is removed

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 3);
        assert!(changed.contains(&"flag2".to_string())); // Modified
        assert!(changed.contains(&"flag3".to_string())); // Removed
        assert!(changed.contains(&"flag4".to_string())); // Added
        assert!(!changed.contains(&"flag1".to_string())); // Unchanged
    }

    #[test]
    fn test_compute_changed_flags_reports_each_key_once() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        old_flags.insert("flag2".to_string(), create_test_flag("ENABLED", "on"));

        let mut new_flags = HashMap::new();
        new_flags.insert("flag2".to_string(), create_test_flag("DISABLED", "on")); // Modified
        new_flags.insert("flag3".to_string(), create_test_flag("ENABLED", "on")); // Added
        // flag1 is removed

        let changed = sorted(FlagStore::compute_changed_flags(&old_flags, &new_flags));
        assert_eq!(changed, vec!["flag1", "flag2", "flag3"]);
    }

    #[test]
    fn test_storage_state_equality() {
        assert_eq!(StorageState::Ok, StorageState::Ok);
        assert_ne!(StorageState::Ok, StorageState::Error);
        assert_ne!(StorageState::Error, StorageState::Stale);
    }

    #[test]
    fn test_storage_state_default_is_ok() {
        assert_eq!(StorageState::default(), StorageState::Ok);
    }

    #[test]
    fn test_storage_state_change_default() {
        let change = StorageStateChange::default();
        assert_eq!(change.storage_state, StorageState::Ok);
        assert!(change.changed_flags_keys.is_empty());
        assert!(change.sync_metadata.is_empty());
    }
}