revoke_config/
memory.rs

1use crate::types::{ChangeType, ConfigCache, ConfigChange, ConfigMetadata, ConfigValue};
2use crate::utils::UnpinStream;
3use async_trait::async_trait;
4use futures::Stream;
5use parking_lot::RwLock;
6use revoke_core::{ConfigProvider, Result, RevokeError};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tracing::info;
11
12pub struct MemoryConfigProvider {
13    cache: ConfigCache,
14    watchers: Arc<RwLock<Vec<mpsc::UnboundedSender<ConfigChange>>>>,
15}
16
17impl MemoryConfigProvider {
18    pub fn new() -> Self {
19        Self {
20            cache: ConfigCache::new(),
21            watchers: Arc::new(RwLock::new(Vec::new())),
22        }
23    }
24
25    pub fn with_initial_values(values: HashMap<String, serde_json::Value>) -> Self {
26        let provider = Self::new();
27
28        for (key, value) in values {
29            let config_value = ConfigValue {
30                key: key.clone(),
31                value,
32                version: 1,
33                metadata: ConfigMetadata::new(),
34            };
35            provider.cache.set(key, config_value);
36        }
37
38        provider
39    }
40
41    fn notify_watchers(&self, change: ConfigChange) {
42        let mut watchers = self.watchers.write();
43        watchers.retain(|tx| tx.send(change.clone()).is_ok());
44    }
45}
46
47impl Default for MemoryConfigProvider {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53#[async_trait]
54impl ConfigProvider for MemoryConfigProvider {
55    async fn get(&self, key: &str) -> Result<String> {
56        self.cache
57            .get(key)
58            .map(|v| v.value.to_string())
59            .ok_or_else(|| RevokeError::ConfigError(format!("Key not found: {}", key)))
60    }
61
62    async fn set(&self, key: &str, value: &str) -> Result<()> {
63        let json_value: serde_json::Value = serde_json::from_str(value)
64            .unwrap_or_else(|_| serde_json::Value::String(value.to_string()));
65
66        let old_value = self.cache.get(key);
67        let version = old_value.as_ref().map(|v| v.version + 1).unwrap_or(1);
68
69        let mut metadata = old_value
70            .as_ref()
71            .map(|v| v.metadata.clone())
72            .unwrap_or_else(ConfigMetadata::new);
73        metadata.updated_at = chrono::Utc::now();
74
75        let config_value = ConfigValue {
76            key: key.to_string(),
77            value: json_value.clone(),
78            version,
79            metadata,
80        };
81
82        self.cache.set(key.to_string(), config_value);
83
84        let change = ConfigChange {
85            key: key.to_string(),
86            old_value: old_value.as_ref().map(|v| v.value.clone()),
87            new_value: Some(json_value),
88            change_type: if old_value.is_some() {
89                ChangeType::Updated
90            } else {
91                ChangeType::Created
92            },
93        };
94
95        self.notify_watchers(change);
96        info!("Config key '{}' updated", key);
97
98        Ok(())
99    }
100
101    async fn watch(&self, key: &str) -> Result<Box<dyn Stream<Item = String> + Send + Unpin>> {
102        let (tx, rx) = mpsc::unbounded_channel();
103        self.watchers.write().push(tx);
104
105        let key = key.to_string();
106        let stream = async_stream::stream! {
107            let mut rx = rx;
108            while let Some(change) = rx.recv().await {
109                if change.key == key {
110                    if let Some(new_value) = change.new_value {
111                        yield new_value.to_string();
112                    }
113                }
114            }
115        };
116
117        Ok(Box::new(UnpinStream::new(stream)) as Box<dyn Stream<Item = String> + Send + Unpin>)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt, StreamExt};

    #[tokio::test]
    async fn test_memory_config() {
        let provider = MemoryConfigProvider::new();

        // Non-JSON input is stored as a JSON string, so `get` returns it quoted.
        provider.set("test.key", "test_value").await.unwrap();
        let value = provider.get("test.key").await.unwrap();
        assert_eq!(value, "\"test_value\"");

        // Missing keys are reported as errors.
        let result = provider.get("missing.key").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_json_input_is_stored_structurally() {
        let provider = MemoryConfigProvider::new();

        provider.set("num.key", "42").await.unwrap();
        assert_eq!(provider.get("num.key").await.unwrap(), "42");
    }

    #[tokio::test]
    async fn test_version_increments_on_update() {
        let provider = MemoryConfigProvider::new();

        provider.set("test.key", "a").await.unwrap();
        provider.set("test.key", "b").await.unwrap();

        assert_eq!(provider.cache.get("test.key").unwrap().version, 2);
        assert_eq!(provider.get("test.key").await.unwrap(), "\"b\"");
    }

    #[tokio::test]
    async fn test_with_initial_values() {
        let mut values = HashMap::new();
        values.insert("flag".to_string(), serde_json::Value::Bool(true));

        let provider = MemoryConfigProvider::with_initial_values(values);

        assert_eq!(provider.get("flag").await.unwrap(), "true");
        assert_eq!(provider.cache.get("flag").unwrap().version, 1);
    }

    #[tokio::test]
    async fn test_watch() {
        let provider = MemoryConfigProvider::new();

        let mut stream = provider.watch("test.key").await.unwrap();

        // A change to another key must be filtered out of this watcher's stream.
        provider.set("other.key", "ignored").await.unwrap();
        provider.set("test.key", "value1").await.unwrap();

        // `set` enqueues synchronously, so the value must already be available. Polling once
        // instead of awaiting makes a regression fail fast rather than hang the test run.
        let value = stream
            .next()
            .now_or_never()
            .expect("watcher should have a buffered change ready")
            .expect("watch stream ended unexpectedly");
        assert_eq!(value, "\"value1\"");
    }
}