Skip to main content

hyperinfer_core/
redis.rs

//! Redis utilities for HyperInfer
//!
//! Provides Redis-backed storage and Pub/Sub propagation of configuration and policy updates.
4
5use futures_util::stream::StreamExt;
6use redis::aio::ConnectionManager;
7use redis::Client;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tracing::{error, info};
12
13use crate::error::ConfigError;
14use crate::types::Config;
15
/// Redis key under which the JSON-serialized current [`Config`] is stored.
pub const CONFIG_KEY: &str = "hyperinfer:config";
/// Pub/Sub channel on which [`ConfigUpdate`] messages are published.
pub const CONFIG_CHANNEL: &str = "hyperinfer:config_updates";
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ConfigUpdate {
21    pub config: Config,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct PolicyUpdate {
26    pub key: String,
27    pub action: PolicyAction,
28    pub reason: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum PolicyAction {
34    Revoke,
35    Update,
36}
37
38#[derive(Clone)]
39pub struct ConfigManager {
40    client: Arc<Client>,
41    manager: ConnectionManager,
42}
43
44impl ConfigManager {
45    pub async fn new(redis_url: &str) -> Result<Self, ConfigError> {
46        let client = Client::open(redis_url)?;
47        let manager = ConnectionManager::new(client.clone()).await?;
48        Ok(Self {
49            client: Arc::new(client),
50            manager,
51        })
52    }
53
54    pub async fn subscribe_to_config_updates(
55        &self,
56        config: Arc<RwLock<Config>>,
57    ) -> Result<tokio::task::JoinHandle<()>, ConfigError> {
58        let client = Arc::clone(&self.client);
59
60        let handle = tokio::spawn(async move {
61            let mut backoff = 1u64;
62
63            loop {
64                let result = async {
65                    let mut pubsub = client.get_async_pubsub().await?;
66                    pubsub.subscribe(CONFIG_CHANNEL).await?;
67
68                    info!(
69                        "Subscribed to Redis config updates channel: {}",
70                        CONFIG_CHANNEL
71                    );
72
73                    let mut stream = pubsub.on_message();
74
75                    while let Some(msg) = stream.next().await {
76                        let payload_str = match msg.get_payload::<String>() {
77                            Ok(p) => p,
78                            Err(e) => {
79                                error!("Failed to get message payload: {}", e);
80                                continue;
81                            }
82                        };
83
84                        let new_config = match serde_json::from_str::<ConfigUpdate>(&payload_str) {
85                            Ok(update) => update.config,
86                            Err(e) => {
87                                error!("Failed to parse config update: {}", e);
88                                continue;
89                            }
90                        };
91
92                        {
93                            let mut cfg = config.write().await;
94                            *cfg = new_config;
95                            info!("Config updated via Pub/Sub");
96                        }
97                    }
98                    Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
99                }
100                .await;
101
102                match result {
103                    Err(e) => {
104                        error!(
105                            "Config subscription error: {}, reconnecting in {}s",
106                            e, backoff
107                        );
108                        tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
109                        backoff = (backoff * 2).min(60);
110                    }
111                    Ok(()) => {
112                        backoff = 1;
113                        info!("Config updates stream ended, reconnecting in {}s", backoff);
114                        tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
115                    }
116                }
117            }
118        });
119
120        Ok(handle)
121    }
122
123    pub async fn subscribe_to_policy_updates(
124        &self,
125        callback: impl Fn(PolicyUpdate) + Send + Sync + 'static,
126    ) -> Result<tokio::task::JoinHandle<()>, ConfigError> {
127        let client = Arc::clone(&self.client);
128
129        let handle = tokio::spawn(async move {
130            let mut backoff = 1u64;
131
132            loop {
133                let result = async {
134                    let mut pubsub = client.get_async_pubsub().await?;
135                    pubsub.subscribe("hyperinfer:policy_updates").await?;
136
137                    info!("Subscribed to Redis policy updates channel");
138
139                    let mut stream = pubsub.on_message();
140
141                    while let Some(msg) = stream.next().await {
142                        let payload = match msg.get_payload::<String>() {
143                            Ok(p) => p,
144                            Err(e) => {
145                                error!("Failed to get policy message payload: {}", e);
146                                continue;
147                            }
148                        };
149
150                        match serde_json::from_str::<PolicyUpdate>(&payload) {
151                            Ok(update) => callback(update),
152                            Err(e) => error!("Failed to parse policy update: {}", e),
153                        }
154                    }
155                    Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
156                }
157                .await;
158
159                match result {
160                    Err(e) => {
161                        error!(
162                            "Policy subscription error: {}, reconnecting in {}s",
163                            e, backoff
164                        );
165                        tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
166                        backoff = (backoff * 2).min(60);
167                    }
168                    Ok(()) => {
169                        backoff = 1;
170                        info!("Policy updates stream ended, reconnecting in {}s", backoff);
171                        tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
172                    }
173                }
174            }
175        });
176
177        Ok(handle)
178    }
179
180    pub async fn ping(&self) -> Result<(), ConfigError> {
181        let mut conn = self.manager.clone();
182        redis::cmd("PING").query_async::<String>(&mut conn).await?;
183        Ok(())
184    }
185
186    pub async fn fetch_config(&self) -> Result<Config, ConfigError> {
187        let mut conn = self.manager.clone();
188
189        let data: Option<Vec<u8>> = redis::cmd("GET")
190            .arg(CONFIG_KEY)
191            .query_async(&mut conn)
192            .await?;
193
194        match data {
195            Some(bytes) => {
196                let config: Config = serde_json::from_slice(&bytes)?;
197                Ok(config)
198            }
199            None => Ok(Config {
200                api_keys: std::collections::HashMap::new(),
201                routing_rules: Vec::new(),
202                quotas: std::collections::HashMap::new(),
203                model_aliases: std::collections::HashMap::new(),
204                default_provider: None,
205            }),
206        }
207    }
208
209    pub async fn publish_config_update(&self, config: &Config) -> Result<(), ConfigError> {
210        let mut conn = self.manager.clone();
211
212        // Store config first so it's available when subscribers receive notification
213        let config_bytes = serde_json::to_vec(config)?;
214
215        redis::cmd("SET")
216            .arg(CONFIG_KEY)
217            .arg(config_bytes)
218            .query_async::<()>(&mut conn)
219            .await?;
220
221        let update = ConfigUpdate {
222            config: config.clone(),
223        };
224
225        let payload = serde_json::to_string(&update)?;
226
227        redis::cmd("PUBLISH")
228            .arg(CONFIG_CHANNEL)
229            .arg(&payload)
230            .query_async::<()>(&mut conn)
231            .await?;
232
233        info!("Published config update to channel: {}", CONFIG_CHANNEL);
234
235        Ok(())
236    }
237
238    pub async fn publish_policy_update(&self, update: &PolicyUpdate) -> Result<(), ConfigError> {
239        let mut conn = self.manager.clone();
240
241        let payload = serde_json::to_string(update)?;
242
243        redis::cmd("PUBLISH")
244            .arg("hyperinfer:policy_updates")
245            .arg(&payload)
246            .query_async::<()>(&mut conn)
247            .await?;
248
249        info!("Published policy update: {:?}", update.action);
250
251        Ok(())
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;
    use std::collections::HashMap;

    /// A `Config` with every collection empty and the given default provider.
    fn empty_config(default_provider: Option<Provider>) -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: Vec::new(),
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider,
        }
    }

    /// Encodes `value` as JSON and decodes it back.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[test]
    fn test_config_update_serialization() {
        let update = ConfigUpdate {
            config: empty_config(Some(Provider::OpenAI)),
        };

        let decoded = json_round_trip(&update);

        assert_eq!(decoded.config.default_provider, Some(Provider::OpenAI));
    }

    #[test]
    fn test_policy_update_serialization() {
        let decoded = json_round_trip(&PolicyUpdate {
            key: "test-key".to_string(),
            action: PolicyAction::Revoke,
            reason: Some("Testing".to_string()),
        });

        assert_eq!(decoded.key, "test-key");
        assert_eq!(decoded.reason, Some("Testing".to_string()));
    }

    #[test]
    fn test_policy_action_revoke() {
        let encoded = serde_json::to_string(&PolicyAction::Revoke).unwrap();
        assert_eq!(encoded, "\"revoke\"");
    }

    #[test]
    fn test_policy_action_update() {
        let encoded = serde_json::to_string(&PolicyAction::Update).unwrap();
        assert_eq!(encoded, "\"update\"");
    }

    #[test]
    fn test_policy_update_without_reason() {
        let decoded = json_round_trip(&PolicyUpdate {
            key: "key123".to_string(),
            action: PolicyAction::Update,
            reason: None,
        });

        assert_eq!(decoded.key, "key123");
        assert_eq!(decoded.reason, None);
    }

    #[test]
    fn test_policy_update_clone() {
        let original = PolicyUpdate {
            key: "clone-key".to_string(),
            action: PolicyAction::Revoke,
            reason: Some("Clone test".to_string()),
        };

        let copy = original.clone();

        assert_eq!(original.key, copy.key);
        assert_eq!(original.reason, copy.reason);
    }

    #[test]
    fn test_config_update_clone() {
        let original = ConfigUpdate {
            config: empty_config(None),
        };

        let copy = original.clone();

        assert_eq!(
            original.config.routing_rules.len(),
            copy.config.routing_rules.len()
        );
    }

    #[test]
    fn test_config_channel_constant() {
        assert_eq!(CONFIG_CHANNEL, "hyperinfer:config_updates");
    }

    #[test]
    fn test_config_key_constant() {
        assert_eq!(CONFIG_KEY, "hyperinfer:config");
    }

    #[test]
    fn test_policy_action_deserialization_revoke() {
        let action: PolicyAction = serde_json::from_str("\"revoke\"").unwrap();
        assert!(matches!(action, PolicyAction::Revoke));
    }

    #[test]
    fn test_policy_action_deserialization_update() {
        let action: PolicyAction = serde_json::from_str("\"update\"").unwrap();
        assert!(matches!(action, PolicyAction::Update));
    }

    #[test]
    fn test_config_update_with_routing_rules() {
        use crate::types::RoutingRule;

        let mut config = empty_config(None);
        config.routing_rules.push(RoutingRule {
            name: "test-rule".to_string(),
            priority: 1,
            fallback_models: vec!["model1".to_string(), "model2".to_string()],
        });

        let decoded = json_round_trip(&ConfigUpdate { config });

        assert_eq!(decoded.config.routing_rules.len(), 1);
        assert_eq!(decoded.config.routing_rules[0].name, "test-rule");
    }

    #[test]
    fn test_config_update_with_model_aliases() {
        let mut config = empty_config(None);
        config
            .model_aliases
            .insert("alias1".to_string(), "model1".to_string());
        config
            .model_aliases
            .insert("alias2".to_string(), "model2".to_string());

        let decoded = json_round_trip(&ConfigUpdate { config });

        assert_eq!(decoded.config.model_aliases.len(), 2);
    }
}