1use crate::types::{ChangeType, ConfigCache, ConfigChange, ConfigMetadata, ConfigValue};
2use crate::utils::UnpinStream;
3use async_trait::async_trait;
4use futures::Stream;
5use parking_lot::RwLock;
6use revoke_core::{ConfigProvider, Result, RevokeError};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tracing::info;
11
12pub struct MemoryConfigProvider {
13 cache: ConfigCache,
14 watchers: Arc<RwLock<Vec<mpsc::UnboundedSender<ConfigChange>>>>,
15}
16
17impl MemoryConfigProvider {
18 pub fn new() -> Self {
19 Self {
20 cache: ConfigCache::new(),
21 watchers: Arc::new(RwLock::new(Vec::new())),
22 }
23 }
24
25 pub fn with_initial_values(values: HashMap<String, serde_json::Value>) -> Self {
26 let provider = Self::new();
27
28 for (key, value) in values {
29 let config_value = ConfigValue {
30 key: key.clone(),
31 value,
32 version: 1,
33 metadata: ConfigMetadata::new(),
34 };
35 provider.cache.set(key, config_value);
36 }
37
38 provider
39 }
40
41 fn notify_watchers(&self, change: ConfigChange) {
42 let mut watchers = self.watchers.write();
43 watchers.retain(|tx| tx.send(change.clone()).is_ok());
44 }
45}
46
47impl Default for MemoryConfigProvider {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53#[async_trait]
54impl ConfigProvider for MemoryConfigProvider {
55 async fn get(&self, key: &str) -> Result<String> {
56 self.cache
57 .get(key)
58 .map(|v| v.value.to_string())
59 .ok_or_else(|| RevokeError::ConfigError(format!("Key not found: {}", key)))
60 }
61
62 async fn set(&self, key: &str, value: &str) -> Result<()> {
63 let json_value: serde_json::Value = serde_json::from_str(value)
64 .unwrap_or_else(|_| serde_json::Value::String(value.to_string()));
65
66 let old_value = self.cache.get(key);
67 let version = old_value.as_ref().map(|v| v.version + 1).unwrap_or(1);
68
69 let mut metadata = old_value
70 .as_ref()
71 .map(|v| v.metadata.clone())
72 .unwrap_or_else(ConfigMetadata::new);
73 metadata.updated_at = chrono::Utc::now();
74
75 let config_value = ConfigValue {
76 key: key.to_string(),
77 value: json_value.clone(),
78 version,
79 metadata,
80 };
81
82 self.cache.set(key.to_string(), config_value);
83
84 let change = ConfigChange {
85 key: key.to_string(),
86 old_value: old_value.as_ref().map(|v| v.value.clone()),
87 new_value: Some(json_value),
88 change_type: if old_value.is_some() {
89 ChangeType::Updated
90 } else {
91 ChangeType::Created
92 },
93 };
94
95 self.notify_watchers(change);
96 info!("Config key '{}' updated", key);
97
98 Ok(())
99 }
100
101 async fn watch(&self, key: &str) -> Result<Box<dyn Stream<Item = String> + Send + Unpin>> {
102 let (tx, rx) = mpsc::unbounded_channel();
103 self.watchers.write().push(tx);
104
105 let key = key.to_string();
106 let stream = async_stream::stream! {
107 let mut rx = rx;
108 while let Some(change) = rx.recv().await {
109 if change.key == key {
110 if let Some(new_value) = change.new_value {
111 yield new_value.to_string();
112 }
113 }
114 }
115 };
116
117 Ok(Box::new(UnpinStream::new(stream)) as Box<dyn Stream<Item = String> + Send + Unpin>)
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be read back as JSON text; an unknown key is an error.
    #[tokio::test]
    async fn test_memory_config() {
        let provider = MemoryConfigProvider::new();

        provider.set("test.key", "test_value").await.unwrap();

        let stored = provider.get("test.key").await.unwrap();
        assert_eq!(stored, "\"test_value\"");

        assert!(provider.get("missing.key").await.is_err());
    }

    /// A watcher registered before a write receives the new value.
    #[tokio::test]
    async fn test_watch() {
        use futures::StreamExt;

        let provider = MemoryConfigProvider::new();
        let mut updates = provider.watch("test.key").await.unwrap();

        provider.set("test.key", "value1").await.unwrap();

        match updates.next().await {
            Some(received) => assert_eq!(received, "\"value1\""),
            None => panic!("Expected to receive a value"),
        }
    }
}