1use std::collections::HashMap;
4use std::sync::Arc;
5
6#[cfg(any(feature = "cloud", feature = "enterprise", feature = "database"))]
7use anyhow::Context;
8use anyhow::Result;
9#[cfg(feature = "cloud")]
10use redis_cloud::CloudClient;
11#[cfg(feature = "enterprise")]
12use redis_enterprise::EnterpriseClient;
13use redisctl_core::Config;
14use tokio::sync::RwLock;
15
16use crate::policy::{Policy, SafetyTier};
17
/// Where the server obtains credentials for Redis Cloud / Enterprise API clients.
// `issuer` and `audience` are never read in this module (the client factories match
// `OAuth { .. }`), so the dead-code lint is silenced for the whole enum.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub enum CredentialSource {
    /// Use redisctl profiles, identified by name. When a caller does not name a
    /// profile, the first name in this list (if any) is handed to the config's
    /// profile resolver.
    Profiles(Vec<String>),
    /// OAuth mode: API clients are built from environment variables instead of the
    /// redisctl config.
    OAuth {
        /// OAuth issuer, if configured.
        // NOTE(review): presumably the expected token issuer; not read here — confirm.
        issuer: Option<String>,
        /// OAuth audience, if configured.
        // NOTE(review): presumably the expected token audience; not read here — confirm.
        audience: Option<String>,
    },
}
31
/// Lazily-populated client cache held by `AppState`.
///
/// Each map exists only when its Cargo feature is enabled. Cloud and Enterprise
/// clients are keyed by profile name (`"_default"` when no profile was requested);
/// database connections are keyed by the Redis URL they were opened with.
pub struct CachedClients {
    /// Redis Cloud API clients by profile name.
    #[cfg(feature = "cloud")]
    pub cloud: HashMap<String, CloudClient>,
    /// Redis Enterprise API clients by profile name.
    #[cfg(feature = "enterprise")]
    pub enterprise: HashMap<String, EnterpriseClient>,
    /// Multiplexed Redis connections by connection URL.
    #[cfg(feature = "database")]
    pub database: HashMap<String, redis::aio::MultiplexedConnection>,
}
41
42pub struct AppState {
44 pub credential_source: CredentialSource,
46 pub policy: Arc<Policy>,
48 pub database_url: Option<String>,
50 config: Option<Config>,
52 profiles: Vec<String>,
54 #[allow(dead_code)]
56 clients: RwLock<CachedClients>,
57 #[cfg(feature = "database")]
59 aliases: RwLock<HashMap<String, Vec<Vec<String>>>>,
60}
61
62impl AppState {
63 pub fn new(
65 credential_source: CredentialSource,
66 policy: Arc<Policy>,
67 database_url: Option<String>,
68 ) -> Result<Self> {
69 let profiles = match &credential_source {
71 CredentialSource::Profiles(p) => p.clone(),
72 CredentialSource::OAuth { .. } => vec![],
73 };
74
75 let config = match &credential_source {
77 CredentialSource::Profiles(_) => Config::load().ok(),
78 CredentialSource::OAuth { .. } => None,
79 };
80
81 Ok(Self {
82 credential_source,
83 policy,
84 database_url,
85 config,
86 profiles,
87 clients: RwLock::new(CachedClients {
88 #[cfg(feature = "cloud")]
89 cloud: HashMap::new(),
90 #[cfg(feature = "enterprise")]
91 enterprise: HashMap::new(),
92 #[cfg(feature = "database")]
93 database: HashMap::new(),
94 }),
95 #[cfg(feature = "database")]
96 aliases: RwLock::new(HashMap::new()),
97 })
98 }
99
100 #[allow(dead_code)]
102 pub fn available_profiles(&self) -> &[String] {
103 &self.profiles
104 }
105
106 #[cfg(feature = "cloud")]
110 pub async fn cloud_client_for_profile(&self, profile: Option<&str>) -> Result<CloudClient> {
111 let cache_key = profile.unwrap_or("_default").to_string();
112
113 {
115 let clients = self.clients.read().await;
116 if let Some(client) = clients.cloud.get(&cache_key) {
117 return Ok(client.clone());
118 }
119 }
120
121 let client = self.create_cloud_client(profile).await?;
123
124 {
126 let mut clients = self.clients.write().await;
127 clients.cloud.insert(cache_key, client.clone());
128 }
129
130 Ok(client)
131 }
132
133 #[cfg(feature = "cloud")]
135 #[allow(dead_code)]
136 pub async fn cloud_client(&self) -> Result<CloudClient> {
137 self.cloud_client_for_profile(None).await
138 }
139
140 #[cfg(feature = "enterprise")]
144 pub async fn enterprise_client_for_profile(
145 &self,
146 profile: Option<&str>,
147 ) -> Result<EnterpriseClient> {
148 let cache_key = profile.unwrap_or("_default").to_string();
149
150 {
152 let clients = self.clients.read().await;
153 if let Some(client) = clients.enterprise.get(&cache_key) {
154 return Ok(client.clone());
155 }
156 }
157
158 let client = self.create_enterprise_client(profile).await?;
160
161 {
163 let mut clients = self.clients.write().await;
164 clients.enterprise.insert(cache_key, client.clone());
165 }
166
167 Ok(client)
168 }
169
170 #[cfg(feature = "enterprise")]
172 #[allow(dead_code)]
173 pub async fn enterprise_client(&self) -> Result<EnterpriseClient> {
174 self.enterprise_client_for_profile(None).await
175 }
176
177 #[cfg(feature = "cloud")]
179 async fn create_cloud_client(&self, profile: Option<&str>) -> Result<CloudClient> {
180 match &self.credential_source {
181 CredentialSource::Profiles(profiles) => {
182 let config = self
183 .config
184 .as_ref()
185 .context("No redisctl config available")?;
186
187 let profile_to_use = profile
189 .map(|s| s.to_string())
190 .or_else(|| profiles.first().cloned());
191
192 let resolved_profile_name = config
194 .resolve_cloud_profile(profile_to_use.as_deref())
195 .context("Failed to resolve cloud profile")?;
196
197 let profile = config
199 .profiles
200 .get(&resolved_profile_name)
201 .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;
202
203 let (api_key, api_secret, _base_url) = profile
205 .resolve_cloud_credentials()
206 .context("Failed to resolve cloud credentials")?
207 .context("No cloud credentials in profile")?;
208
209 CloudClient::builder()
210 .api_key(api_key)
211 .api_secret(api_secret)
212 .build()
213 .context("Failed to build Cloud client")
214 }
215 CredentialSource::OAuth { .. } => {
216 let api_key =
218 std::env::var("REDIS_CLOUD_API_KEY").context("REDIS_CLOUD_API_KEY not set")?;
219 let api_secret = std::env::var("REDIS_CLOUD_API_SECRET")
220 .context("REDIS_CLOUD_API_SECRET not set")?;
221
222 CloudClient::builder()
223 .api_key(api_key)
224 .api_secret(api_secret)
225 .build()
226 .context("Failed to build Cloud client")
227 }
228 }
229 }
230
231 #[cfg(feature = "enterprise")]
233 async fn create_enterprise_client(&self, profile: Option<&str>) -> Result<EnterpriseClient> {
234 match &self.credential_source {
235 CredentialSource::Profiles(profiles) => {
236 let config = self
237 .config
238 .as_ref()
239 .context("No redisctl config available")?;
240
241 let profile_to_use = profile
243 .map(|s| s.to_string())
244 .or_else(|| profiles.first().cloned());
245
246 let resolved_profile_name = config
248 .resolve_enterprise_profile(profile_to_use.as_deref())
249 .context("Failed to resolve enterprise profile")?;
250
251 let profile_config = config
253 .profiles
254 .get(&resolved_profile_name)
255 .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;
256
257 let (url, username, password, insecure, ca_cert) = profile_config
259 .resolve_enterprise_credentials()
260 .context("Failed to resolve enterprise credentials")?
261 .context("No enterprise credentials in profile")?;
262
263 let mut builder = EnterpriseClient::builder()
264 .base_url(&url)
265 .username(&username)
266 .insecure(insecure);
267
268 if let Some(pwd) = password {
269 builder = builder.password(&pwd);
270 }
271
272 if let Some(cert_path) = ca_cert {
273 builder = builder.ca_cert(&cert_path);
274 }
275
276 builder.build().context("Failed to build Enterprise client")
277 }
278 CredentialSource::OAuth { .. } => {
279 let url = std::env::var("REDIS_ENTERPRISE_URL")
281 .context("REDIS_ENTERPRISE_URL not set")?;
282 let username = std::env::var("REDIS_ENTERPRISE_USER")
283 .context("REDIS_ENTERPRISE_USER not set")?;
284 let password = std::env::var("REDIS_ENTERPRISE_PASSWORD").ok();
285 let insecure = std::env::var("REDIS_ENTERPRISE_INSECURE")
286 .map(|v| v == "true" || v == "1")
287 .unwrap_or(false);
288
289 let mut builder = EnterpriseClient::builder()
290 .base_url(&url)
291 .username(&username)
292 .insecure(insecure);
293
294 if let Some(pwd) = password {
295 builder = builder.password(&pwd);
296 }
297
298 builder.build().context("Failed to build Enterprise client")
299 }
300 }
301 }
302
303 #[cfg(feature = "database")]
307 pub fn database_url_for_profile(&self, profile: Option<&str>) -> Result<String> {
308 let config = self
309 .config
310 .as_ref()
311 .context("No redisctl config available")?;
312
313 let profile_to_use = profile
314 .map(|s| s.to_string())
315 .or_else(|| self.profiles.first().cloned());
316
317 let resolved_name = config
318 .resolve_database_profile(profile_to_use.as_deref())
319 .context("Failed to resolve database profile")?;
320
321 let profile_config = config
322 .profiles
323 .get(&resolved_name)
324 .with_context(|| format!("Profile '{}' not found", resolved_name))?;
325
326 let (host, port, password, tls, username, database) = profile_config
327 .resolve_database_credentials()
328 .context("Failed to resolve database credentials")?
329 .context("No database credentials in profile")?;
330
331 let scheme = if tls { "rediss" } else { "redis" };
333 let auth = match (username.as_str(), password) {
334 ("", None) | ("default", None) => String::new(),
335 (user, Some(pass)) => format!(
336 "{}:{}@",
337 urlencoding::encode(user),
338 urlencoding::encode(&pass)
339 ),
340 (user, None) => format!("{}@", urlencoding::encode(user)),
341 };
342 let db_path = if database > 0 {
343 format!("/{}", database)
344 } else {
345 String::new()
346 };
347
348 Ok(format!("{}://{}{}:{}{}", scheme, auth, host, port, db_path))
349 }
350
351 #[cfg(feature = "database")]
356 pub async fn redis_connection_for_url(
357 &self,
358 url: &str,
359 ) -> Result<redis::aio::MultiplexedConnection> {
360 {
362 let clients = self.clients.read().await;
363 if let Some(conn) = clients.database.get(url) {
364 let mut test_conn = conn.clone();
366 if redis::cmd("PING")
367 .query_async::<String>(&mut test_conn)
368 .await
369 .is_ok()
370 {
371 return Ok(conn.clone());
372 }
373 }
375 }
376
377 let client = redis::Client::open(url).context("Failed to create Redis client")?;
379 let conn = client
380 .get_multiplexed_async_connection()
381 .await
382 .context("Failed to connect to Redis")?;
383
384 {
386 let mut clients = self.clients.write().await;
387 clients.database.insert(url.to_string(), conn.clone());
388 }
389
390 Ok(conn)
391 }
392
393 #[allow(dead_code)]
398 pub fn is_write_allowed(&self) -> bool {
399 matches!(
400 self.policy.global_tier(),
401 SafetyTier::ReadWrite | SafetyTier::Full
402 )
403 }
404
405 #[allow(dead_code)]
410 pub fn is_destructive_allowed(&self) -> bool {
411 matches!(self.policy.global_tier(), SafetyTier::Full)
412 }
413
414 #[cfg(feature = "database")]
416 pub async fn set_alias(&self, name: String, commands: Vec<Vec<String>>) {
417 let mut aliases = self.aliases.write().await;
418 aliases.insert(name, commands);
419 }
420
421 #[cfg(feature = "database")]
423 pub async fn get_alias(&self, name: &str) -> Option<Vec<Vec<String>>> {
424 let aliases = self.aliases.read().await;
425 aliases.get(name).cloned()
426 }
427
428 #[cfg(feature = "database")]
430 pub async fn list_aliases(&self) -> Vec<(String, usize)> {
431 let aliases = self.aliases.read().await;
432 let mut entries: Vec<_> = aliases.iter().map(|(k, v)| (k.clone(), v.len())).collect();
433 entries.sort_by(|a, b| a.0.cmp(&b.0));
434 entries
435 }
436
437 #[cfg(feature = "database")]
439 pub async fn delete_alias(&self, name: &str) -> bool {
440 let mut aliases = self.aliases.write().await;
441 aliases.remove(name).is_some()
442 }
443}
444
445impl Clone for AppState {
446 fn clone(&self) -> Self {
447 Self {
449 credential_source: self.credential_source.clone(),
450 policy: self.policy.clone(),
451 database_url: self.database_url.clone(),
452 config: self.config.clone(),
453 profiles: self.profiles.clone(),
454 clients: RwLock::new(CachedClients {
455 #[cfg(feature = "cloud")]
456 cloud: HashMap::new(),
457 #[cfg(feature = "enterprise")]
458 enterprise: HashMap::new(),
459 #[cfg(feature = "database")]
460 database: HashMap::new(),
461 }),
462 #[cfg(feature = "database")]
463 aliases: RwLock::new(HashMap::new()),
464 }
465 }
466}
467
468#[allow(dead_code)]
470impl AppState {
471 pub fn test_policy() -> Arc<Policy> {
473 Arc::new(Policy::new(
474 crate::policy::PolicyConfig::default(),
475 std::collections::HashMap::new(),
476 "test".to_string(),
477 ))
478 }
479
480 #[cfg(feature = "cloud")]
482 pub fn with_cloud_client(client: CloudClient) -> Self {
483 let mut cloud = HashMap::new();
484 cloud.insert("_default".to_string(), client);
485 Self {
486 credential_source: CredentialSource::Profiles(vec![]),
487 policy: Self::test_policy(),
488 database_url: None,
489 config: None,
490 profiles: vec![],
491 clients: RwLock::new(CachedClients {
492 cloud,
493 #[cfg(feature = "enterprise")]
494 enterprise: HashMap::new(),
495 #[cfg(feature = "database")]
496 database: HashMap::new(),
497 }),
498 #[cfg(feature = "database")]
499 aliases: RwLock::new(HashMap::new()),
500 }
501 }
502
503 #[cfg(feature = "enterprise")]
505 pub fn with_enterprise_client(client: EnterpriseClient) -> Self {
506 let mut enterprise = HashMap::new();
507 enterprise.insert("_default".to_string(), client);
508 Self {
509 credential_source: CredentialSource::Profiles(vec![]),
510 policy: Self::test_policy(),
511 database_url: None,
512 config: None,
513 profiles: vec![],
514 clients: RwLock::new(CachedClients {
515 #[cfg(feature = "cloud")]
516 cloud: HashMap::new(),
517 enterprise,
518 #[cfg(feature = "database")]
519 database: HashMap::new(),
520 }),
521 #[cfg(feature = "database")]
522 aliases: RwLock::new(HashMap::new()),
523 }
524 }
525
526 #[cfg(all(feature = "cloud", feature = "enterprise"))]
528 pub fn with_clients(cloud: CloudClient, enterprise: EnterpriseClient) -> Self {
529 let mut cloud_map = HashMap::new();
530 cloud_map.insert("_default".to_string(), cloud);
531 let mut enterprise_map = HashMap::new();
532 enterprise_map.insert("_default".to_string(), enterprise);
533 Self {
534 credential_source: CredentialSource::Profiles(vec![]),
535 policy: Self::test_policy(),
536 database_url: None,
537 config: None,
538 profiles: vec![],
539 clients: RwLock::new(CachedClients {
540 cloud: cloud_map,
541 enterprise: enterprise_map,
542 #[cfg(feature = "database")]
543 database: HashMap::new(),
544 }),
545 #[cfg(feature = "database")]
546 aliases: RwLock::new(HashMap::new()),
547 }
548 }
549}