open_feature_flagd/resolver/in_process/storage/mod.rs

1pub mod connector;
2use crate::error::FlagdError;
3pub use connector::{Connector, QueuePayload, QueuePayloadType};
4use tracing::{debug, error, warn};
5
6use crate::resolver::in_process::model::feature_flag::FeatureFlag;
7use crate::resolver::in_process::model::flag_parser::FlagParser;
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use tokio::sync::RwLock;
12use tokio::sync::mpsc::{Receiver, Sender, channel};
13
/// Health of the flag storage, as carried by [`StorageStateChange`] notifications.
///
/// The default state is [`StorageState::Ok`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum StorageState {
    /// Storage is healthy and up-to-date (the default).
    #[default]
    Ok,
    /// Storage data may be stale (connection issues)
    Stale,
    /// Storage encountered an error
    Error,
}
30
31/// Represents a change in storage state with affected flags
32#[derive(Debug, Clone, PartialEq)]
33pub struct StorageStateChange {
34    /// Current state of the storage
35    pub storage_state: StorageState,
36    /// Keys of flags that changed in this update
37    pub changed_flags_keys: Vec<String>,
38    /// Metadata from the sync operation
39    pub sync_metadata: HashMap<String, serde_json::Value>,
40}
41
42impl Default for StorageStateChange {
43    fn default() -> Self {
44        Self {
45            storage_state: StorageState::Ok,
46            changed_flags_keys: Vec::new(),
47            sync_metadata: HashMap::new(),
48        }
49    }
50}
51
52/// Result of querying a flag from storage
53#[derive(Debug, Clone)]
54pub struct StorageQueryResult {
55    /// The feature flag if found
56    pub feature_flag: Option<FeatureFlag>,
57    /// Metadata associated with the flag set
58    pub flag_set_metadata: HashMap<String, serde_json::Value>,
59}
60
61pub struct FlagStore {
62    flags: Arc<RwLock<HashMap<String, FeatureFlag>>>,
63    flag_set_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
64    state_sender: Sender<StorageStateChange>,
65    connector: Arc<dyn Connector>,
66    shutdown: Arc<AtomicBool>,
67}
68
69impl FlagStore {
70    pub fn new(connector: Arc<dyn Connector>) -> (Self, Receiver<StorageStateChange>) {
71        let (state_sender, state_receiver) = channel(1000);
72
73        (
74            Self {
75                flags: Arc::new(RwLock::new(HashMap::new())),
76                flag_set_metadata: Arc::new(RwLock::new(HashMap::new())),
77                state_sender,
78                connector,
79                shutdown: Arc::new(AtomicBool::new(false)),
80            },
81            state_receiver,
82        )
83    }
84
85    pub async fn init(&self) -> Result<(), FlagdError> {
86        debug!("Initializing flag store");
87        self.connector.init().await?;
88
89        // Handle initial sync
90        let stream = self.connector.get_stream();
91        let mut receiver = stream.lock().await;
92        debug!("Waiting for initial sync message");
93
94        if let Some(receiver_ref) = receiver.as_mut() {
95            match tokio::time::timeout(std::time::Duration::from_secs(5), receiver_ref.recv())
96                .await?
97            {
98                Some(payload) => {
99                    debug!("Received initial sync message");
100                    match payload.payload_type {
101                        QueuePayloadType::Data => {
102                            debug!("Parsing flag data: {}", &payload.flag_data);
103                            let parsing_result = FlagParser::parse_string(&payload.flag_data)?;
104                            let mut flags_write = self.flags.write().await;
105                            let mut metadata_write = self.flag_set_metadata.write().await;
106                            let flag_keys: Vec<String> =
107                                parsing_result.flags.keys().cloned().collect();
108                            *flags_write = parsing_result.flags;
109                            *metadata_write = parsing_result.flag_set_metadata;
110                            debug!("Successfully parsed {} flags", flags_write.len());
111
112                            // Send initial state change so FileResolver knows init completed
113                            let _ = self
114                                .state_sender
115                                .send(StorageStateChange {
116                                    storage_state: StorageState::Ok,
117                                    changed_flags_keys: flag_keys,
118                                    sync_metadata: payload.metadata.unwrap_or_default(),
119                                })
120                                .await;
121                        }
122                        QueuePayloadType::Error => {
123                            error!("Error in initial sync: {}", payload.flag_data);
124                            return Err(FlagdError::Sync(format!(
125                                "Error in initial sync: {}",
126                                payload.flag_data
127                            )));
128                        }
129                    }
130                }
131                None => {
132                    error!("No initial sync message received");
133                    return Err(FlagdError::Sync(
134                        "No initial sync message received".to_string(),
135                    ));
136                }
137            }
138        }
139
140        // Start continuous stream processing
141        self.start_stream_listener().await;
142        Ok(())
143    }
144
145    pub async fn shutdown(&self) -> Result<(), FlagdError> {
146        debug!("Shutting down flag store");
147        self.shutdown.store(true, Ordering::Relaxed);
148        self.connector.shutdown().await
149    }
150
151    pub async fn get_flag(&self, key: &str) -> StorageQueryResult {
152        let flags = self.flags.read().await;
153        let metadata = self.flag_set_metadata.read().await;
154
155        StorageQueryResult {
156            feature_flag: flags.get(key).cloned(),
157            flag_set_metadata: metadata.clone(),
158        }
159    }
160
161    /// Compute which flags have changed between old and new flag sets
162    fn compute_changed_flags(
163        old_flags: &HashMap<String, FeatureFlag>,
164        new_flags: &HashMap<String, FeatureFlag>,
165    ) -> Vec<String> {
166        let mut changed = Vec::new();
167
168        // Check for modified or added flags
169        for (key, new_flag) in new_flags {
170            match old_flags.get(key) {
171                Some(old_flag) if old_flag != new_flag => {
172                    changed.push(key.clone());
173                }
174                None => {
175                    changed.push(key.clone());
176                }
177                _ => {}
178            }
179        }
180
181        // Check for deleted flags
182        let old_keys: HashSet<_> = old_flags.keys().collect();
183        let new_keys: HashSet<_> = new_flags.keys().collect();
184        for key in old_keys.difference(&new_keys) {
185            changed.push((*key).clone());
186        }
187
188        changed
189    }
190
191    async fn start_stream_listener(&self) {
192        let flags = self.flags.clone();
193        let metadata = self.flag_set_metadata.clone();
194        let sender = self.state_sender.clone();
195        let stream = self.connector.get_stream();
196        let shutdown = self.shutdown.clone();
197
198        tokio::spawn(async move {
199            let mut receiver = stream.lock().await;
200            if let Some(receiver) = receiver.as_mut() {
201                while let Some(payload) = receiver.recv().await {
202                    if shutdown.load(Ordering::Relaxed) {
203                        debug!("Stream listener shutting down");
204                        break;
205                    }
206
207                    match payload.payload_type {
208                        QueuePayloadType::Data => {
209                            match FlagParser::parse_string(&payload.flag_data) {
210                                Ok(parsing_result) => {
211                                    let mut flags_write = flags.write().await;
212                                    let mut metadata_write = metadata.write().await;
213
214                                    // Compute changed flags before updating
215                                    let changed_keys = Self::compute_changed_flags(
216                                        &flags_write,
217                                        &parsing_result.flags,
218                                    );
219
220                                    let num_changes = changed_keys.len();
221                                    *flags_write = parsing_result.flags;
222                                    *metadata_write = parsing_result.flag_set_metadata;
223
224                                    debug!(
225                                        "Flag store updated: {} flags changed ({} total flags)",
226                                        num_changes,
227                                        flags_write.len()
228                                    );
229
230                                    let _ = sender
231                                        .send(StorageStateChange {
232                                            storage_state: StorageState::Ok,
233                                            changed_flags_keys: changed_keys,
234                                            sync_metadata: payload.metadata.unwrap_or_default(),
235                                        })
236                                        .await;
237                                }
238                                Err(e) => {
239                                    warn!("Failed to parse flag data: {}", e);
240                                    let _ = sender
241                                        .send(StorageStateChange {
242                                            storage_state: StorageState::Error,
243                                            changed_flags_keys: vec![],
244                                            sync_metadata: HashMap::new(),
245                                        })
246                                        .await;
247                                }
248                            }
249                        }
250                        QueuePayloadType::Error => {
251                            error!("Received error from connector: {}", payload.flag_data);
252                            let _ = sender
253                                .send(StorageStateChange {
254                                    storage_state: StorageState::Error,
255                                    changed_flags_keys: vec![],
256                                    sync_metadata: HashMap::new(),
257                                })
258                                .await;
259                        }
260                    }
261                }
262            }
263            debug!("Stream listener stopped");
264        });
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a boolean flag with `on`/`off` variants and the given state and default variant.
    fn create_test_flag(state: &str, default_variant: &str) -> FeatureFlag {
        FeatureFlag {
            state: state.to_string(),
            default_variant: default_variant.to_string(),
            variants: HashMap::from([
                ("on".to_string(), json!(true)),
                ("off".to_string(), json!(false)),
            ]),
            targeting: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_compute_changed_flags_no_changes() {
        let mut flags = HashMap::new();
        flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        flags.insert("flag2".to_string(), create_test_flag("ENABLED", "off"));

        let changed = FlagStore::compute_changed_flags(&flags, &flags);
        assert!(
            changed.is_empty(),
            "Expected no changes for identical flags"
        );
    }

    #[test]
    fn test_compute_changed_flags_both_empty() {
        let empty = HashMap::new();
        let changed = FlagStore::compute_changed_flags(&empty, &empty);
        assert!(changed.is_empty(), "Expected no changes for two empty flag sets");
    }

    #[test]
    fn test_compute_changed_flags_added_flag() {
        let old_flags = HashMap::new();
        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_removed_flag() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        let new_flags = HashMap::new();

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_modified_flag() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));

        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "off")); // Changed default

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 1);
        assert!(changed.contains(&"flag1".to_string()));
    }

    #[test]
    fn test_compute_changed_flags_mixed_changes() {
        let mut old_flags = HashMap::new();
        old_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on"));
        old_flags.insert("flag2".to_string(), create_test_flag("ENABLED", "on"));
        old_flags.insert("flag3".to_string(), create_test_flag("ENABLED", "on"));

        let mut new_flags = HashMap::new();
        new_flags.insert("flag1".to_string(), create_test_flag("ENABLED", "on")); // Unchanged
        new_flags.insert("flag2".to_string(), create_test_flag("DISABLED", "on")); // Modified
        new_flags.insert("flag4".to_string(), create_test_flag("ENABLED", "on")); // Added
        // flag3 is removed

        let changed = FlagStore::compute_changed_flags(&old_flags, &new_flags);
        assert_eq!(changed.len(), 3);
        assert!(changed.contains(&"flag2".to_string())); // Modified
        assert!(changed.contains(&"flag3".to_string())); // Removed
        assert!(changed.contains(&"flag4".to_string())); // Added
        assert!(!changed.contains(&"flag1".to_string())); // Unchanged
    }

    #[test]
    fn test_storage_state_equality() {
        assert_eq!(StorageState::Ok, StorageState::Ok);
        assert_ne!(StorageState::Ok, StorageState::Error);
        assert_ne!(StorageState::Error, StorageState::Stale);
    }

    #[test]
    fn test_storage_state_default_is_ok() {
        assert_eq!(StorageState::default(), StorageState::Ok);
    }

    #[test]
    fn test_storage_state_change_default_is_empty_ok() {
        let change = StorageStateChange::default();
        assert_eq!(change.storage_state, StorageState::Ok);
        assert!(change.changed_flags_keys.is_empty());
        assert!(change.sync_metadata.is_empty());
    }
}