1use futures_util::stream::StreamExt;
6use redis::aio::ConnectionManager;
7use redis::Client;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tracing::{error, info};
12
13use crate::error::ConfigError;
14use crate::types::Config;
15
/// Redis key under which the most recently published configuration is stored as JSON.
pub const CONFIG_KEY: &str = "hyperinfer:config";

/// Redis Pub/Sub channel on which [`ConfigUpdate`] messages are broadcast.
pub const CONFIG_CHANNEL: &str = "hyperinfer:config_updates";
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ConfigUpdate {
21 pub config: Config,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct PolicyUpdate {
26 pub key: String,
27 pub action: PolicyAction,
28 pub reason: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum PolicyAction {
34 Revoke,
35 Update,
36}
37
38#[derive(Clone)]
39pub struct ConfigManager {
40 client: Arc<Client>,
41 manager: ConnectionManager,
42}
43
44impl ConfigManager {
45 pub async fn new(redis_url: &str) -> Result<Self, ConfigError> {
46 let client = Client::open(redis_url)?;
47 let manager = ConnectionManager::new(client.clone()).await?;
48 Ok(Self {
49 client: Arc::new(client),
50 manager,
51 })
52 }
53
54 pub async fn subscribe_to_config_updates(
55 &self,
56 config: Arc<RwLock<Config>>,
57 ) -> Result<tokio::task::JoinHandle<()>, ConfigError> {
58 let client = Arc::clone(&self.client);
59
60 let handle = tokio::spawn(async move {
61 let mut backoff = 1u64;
62
63 loop {
64 let result = async {
65 let mut pubsub = client.get_async_pubsub().await?;
66 pubsub.subscribe(CONFIG_CHANNEL).await?;
67
68 info!(
69 "Subscribed to Redis config updates channel: {}",
70 CONFIG_CHANNEL
71 );
72
73 let mut stream = pubsub.on_message();
74
75 while let Some(msg) = stream.next().await {
76 let payload_str = match msg.get_payload::<String>() {
77 Ok(p) => p,
78 Err(e) => {
79 error!("Failed to get message payload: {}", e);
80 continue;
81 }
82 };
83
84 let new_config = match serde_json::from_str::<ConfigUpdate>(&payload_str) {
85 Ok(update) => update.config,
86 Err(e) => {
87 error!("Failed to parse config update: {}", e);
88 continue;
89 }
90 };
91
92 {
93 let mut cfg = config.write().await;
94 *cfg = new_config;
95 info!("Config updated via Pub/Sub");
96 }
97 }
98 Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
99 }
100 .await;
101
102 match result {
103 Err(e) => {
104 error!(
105 "Config subscription error: {}, reconnecting in {}s",
106 e, backoff
107 );
108 tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
109 backoff = (backoff * 2).min(60);
110 }
111 Ok(()) => {
112 backoff = 1;
113 info!("Config updates stream ended, reconnecting in {}s", backoff);
114 tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
115 }
116 }
117 }
118 });
119
120 Ok(handle)
121 }
122
123 pub async fn subscribe_to_policy_updates(
124 &self,
125 callback: impl Fn(PolicyUpdate) + Send + Sync + 'static,
126 ) -> Result<tokio::task::JoinHandle<()>, ConfigError> {
127 let client = Arc::clone(&self.client);
128
129 let handle = tokio::spawn(async move {
130 let mut backoff = 1u64;
131
132 loop {
133 let result = async {
134 let mut pubsub = client.get_async_pubsub().await?;
135 pubsub.subscribe("hyperinfer:policy_updates").await?;
136
137 info!("Subscribed to Redis policy updates channel");
138
139 let mut stream = pubsub.on_message();
140
141 while let Some(msg) = stream.next().await {
142 let payload = match msg.get_payload::<String>() {
143 Ok(p) => p,
144 Err(e) => {
145 error!("Failed to get policy message payload: {}", e);
146 continue;
147 }
148 };
149
150 match serde_json::from_str::<PolicyUpdate>(&payload) {
151 Ok(update) => callback(update),
152 Err(e) => error!("Failed to parse policy update: {}", e),
153 }
154 }
155 Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
156 }
157 .await;
158
159 match result {
160 Err(e) => {
161 error!(
162 "Policy subscription error: {}, reconnecting in {}s",
163 e, backoff
164 );
165 tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
166 backoff = (backoff * 2).min(60);
167 }
168 Ok(()) => {
169 backoff = 1;
170 info!("Policy updates stream ended, reconnecting in {}s", backoff);
171 tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
172 }
173 }
174 }
175 });
176
177 Ok(handle)
178 }
179
180 pub async fn ping(&self) -> Result<(), ConfigError> {
181 let mut conn = self.manager.clone();
182 redis::cmd("PING").query_async::<String>(&mut conn).await?;
183 Ok(())
184 }
185
186 pub async fn fetch_config(&self) -> Result<Config, ConfigError> {
187 let mut conn = self.manager.clone();
188
189 let data: Option<Vec<u8>> = redis::cmd("GET")
190 .arg(CONFIG_KEY)
191 .query_async(&mut conn)
192 .await?;
193
194 match data {
195 Some(bytes) => {
196 let config: Config = serde_json::from_slice(&bytes)?;
197 Ok(config)
198 }
199 None => Ok(Config {
200 api_keys: std::collections::HashMap::new(),
201 routing_rules: Vec::new(),
202 quotas: std::collections::HashMap::new(),
203 model_aliases: std::collections::HashMap::new(),
204 default_provider: None,
205 }),
206 }
207 }
208
209 pub async fn publish_config_update(&self, config: &Config) -> Result<(), ConfigError> {
210 let mut conn = self.manager.clone();
211
212 let config_bytes = serde_json::to_vec(config)?;
214
215 redis::cmd("SET")
216 .arg(CONFIG_KEY)
217 .arg(config_bytes)
218 .query_async::<()>(&mut conn)
219 .await?;
220
221 let update = ConfigUpdate {
222 config: config.clone(),
223 };
224
225 let payload = serde_json::to_string(&update)?;
226
227 redis::cmd("PUBLISH")
228 .arg(CONFIG_CHANNEL)
229 .arg(&payload)
230 .query_async::<()>(&mut conn)
231 .await?;
232
233 info!("Published config update to channel: {}", CONFIG_CHANNEL);
234
235 Ok(())
236 }
237
238 pub async fn publish_policy_update(&self, update: &PolicyUpdate) -> Result<(), ConfigError> {
239 let mut conn = self.manager.clone();
240
241 let payload = serde_json::to_string(update)?;
242
243 redis::cmd("PUBLISH")
244 .arg("hyperinfer:policy_updates")
245 .arg(&payload)
246 .query_async::<()>(&mut conn)
247 .await?;
248
249 info!("Published policy update: {:?}", update.action);
250
251 Ok(())
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;
    use std::collections::HashMap;

    /// Builds a `Config` whose collections are all empty.
    fn empty_config(default_provider: Option<Provider>) -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: Vec::new(),
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider,
        }
    }

    /// Serializes `value` to a JSON string and deserializes it back.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn test_config_update_serialization() {
        let update = ConfigUpdate {
            config: empty_config(Some(Provider::OpenAI)),
        };

        let decoded = json_round_trip(&update);

        assert_eq!(decoded.config.default_provider, Some(Provider::OpenAI));
    }

    #[test]
    fn test_policy_update_serialization() {
        let decoded = json_round_trip(&PolicyUpdate {
            key: "test-key".to_string(),
            action: PolicyAction::Revoke,
            reason: Some("Testing".to_string()),
        });

        assert_eq!(decoded.key, "test-key");
        assert_eq!(decoded.reason, Some("Testing".to_string()));
    }

    #[test]
    fn test_policy_action_revoke() {
        let encoded = serde_json::to_string(&PolicyAction::Revoke).unwrap();
        assert_eq!(encoded, "\"revoke\"");
    }

    #[test]
    fn test_policy_action_update() {
        let encoded = serde_json::to_string(&PolicyAction::Update).unwrap();
        assert_eq!(encoded, "\"update\"");
    }

    #[test]
    fn test_policy_update_without_reason() {
        let decoded = json_round_trip(&PolicyUpdate {
            key: "key123".to_string(),
            action: PolicyAction::Update,
            reason: None,
        });

        assert_eq!(decoded.key, "key123");
        assert_eq!(decoded.reason, None);
    }

    #[test]
    fn test_policy_update_clone() {
        let original = PolicyUpdate {
            key: "clone-key".to_string(),
            action: PolicyAction::Revoke,
            reason: Some("Clone test".to_string()),
        };

        let copy = original.clone();

        assert_eq!(original.key, copy.key);
        assert_eq!(original.reason, copy.reason);
    }

    #[test]
    fn test_config_update_clone() {
        let original = ConfigUpdate {
            config: empty_config(None),
        };

        let copy = original.clone();

        assert_eq!(
            original.config.routing_rules.len(),
            copy.config.routing_rules.len()
        );
    }

    #[test]
    fn test_config_channel_constant() {
        assert_eq!(CONFIG_CHANNEL, "hyperinfer:config_updates");
    }

    #[test]
    fn test_config_key_constant() {
        assert_eq!(CONFIG_KEY, "hyperinfer:config");
    }

    #[test]
    fn test_policy_action_deserialization_revoke() {
        let action: PolicyAction = serde_json::from_str("\"revoke\"").unwrap();
        assert!(matches!(action, PolicyAction::Revoke));
    }

    #[test]
    fn test_policy_action_deserialization_update() {
        let action: PolicyAction = serde_json::from_str("\"update\"").unwrap();
        assert!(matches!(action, PolicyAction::Update));
    }

    #[test]
    fn test_config_update_with_routing_rules() {
        use crate::types::RoutingRule;

        let config = Config {
            routing_rules: vec![RoutingRule {
                name: "test-rule".to_string(),
                priority: 1,
                fallback_models: vec!["model1".to_string(), "model2".to_string()],
            }],
            ..empty_config(None)
        };

        let decoded = json_round_trip(&ConfigUpdate { config });

        assert_eq!(decoded.config.routing_rules.len(), 1);
        assert_eq!(decoded.config.routing_rules[0].name, "test-rule");
    }

    #[test]
    fn test_config_update_with_model_aliases() {
        let aliases: HashMap<String, String> = [("alias1", "model1"), ("alias2", "model2")]
            .iter()
            .map(|&(alias, model)| (alias.to_string(), model.to_string()))
            .collect();

        let config = Config {
            model_aliases: aliases,
            ..empty_config(None)
        };

        let decoded = json_round_trip(&ConfigUpdate { config });

        assert_eq!(decoded.config.model_aliases.len(), 2);
    }
}