Skip to main content

redisctl_mcp/
state.rs

1//! Application state and credential resolution
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[cfg(any(feature = "cloud", feature = "enterprise", feature = "database"))]
7use anyhow::Context;
8use anyhow::Result;
9#[cfg(feature = "cloud")]
10use redis_cloud::CloudClient;
11#[cfg(feature = "enterprise")]
12use redis_enterprise::EnterpriseClient;
13use redisctl_core::Config;
14use tokio::sync::RwLock;
15
16use crate::policy::{Policy, SafetyTier};
17
/// Strategy used to obtain backend credentials.
// Variant payloads (e.g. the OAuth `issuer`/`audience` fields) are not read everywhere,
// so dead-code warnings are silenced for the whole enum.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub enum CredentialSource {
    /// Local mode: credentials are resolved from the named redisctl profiles.
    ///
    /// An empty list defers to the redisctl config's own default-profile resolution.
    Profiles(Vec<String>),
    /// HTTP mode: credentials are associated with OAuth token claims.
    OAuth {
        /// Token issuer, if configured.
        issuer: Option<String>,
        /// Token audience, if configured.
        audience: Option<String>,
    },
}
31
/// Per-backend caches of API clients and Redis connections (multi-cluster support).
///
/// Each map is only present when the matching backend feature is enabled.
pub struct CachedClients {
    /// Redis Cloud clients keyed by profile name (`"_default"` when no profile was named).
    #[cfg(feature = "cloud")]
    pub cloud: HashMap<String, CloudClient>,
    /// Redis Enterprise clients keyed by profile name (`"_default"` when no profile was named).
    #[cfg(feature = "enterprise")]
    pub enterprise: HashMap<String, EnterpriseClient>,
    /// Multiplexed Redis connections keyed by the connection URL.
    #[cfg(feature = "database")]
    pub database: HashMap<String, redis::aio::MultiplexedConnection>,
}
41
42/// Shared application state
43pub struct AppState {
44    /// Credential source configuration
45    pub credential_source: CredentialSource,
46    /// Resolved policy for granular tool access control
47    pub policy: Arc<Policy>,
48    /// Optional Redis database URL for direct connections
49    pub database_url: Option<String>,
50    /// redisctl config (for profile-based auth)
51    config: Option<Config>,
52    /// Configured profiles (for multi-cluster support)
53    profiles: Vec<String>,
54    /// Cached API clients (keyed by profile name, "_default" for default)
55    #[allow(dead_code)]
56    clients: RwLock<CachedClients>,
57    /// Session-scoped command aliases (name → list of command arg arrays)
58    #[cfg(feature = "database")]
59    aliases: RwLock<HashMap<String, Vec<Vec<String>>>>,
60}
61
62impl AppState {
63    /// Create new application state
64    pub fn new(
65        credential_source: CredentialSource,
66        policy: Arc<Policy>,
67        database_url: Option<String>,
68    ) -> Result<Self> {
69        // Extract profiles list
70        let profiles = match &credential_source {
71            CredentialSource::Profiles(p) => p.clone(),
72            CredentialSource::OAuth { .. } => vec![],
73        };
74
75        // Load config if using profile-based auth
76        let config = match &credential_source {
77            CredentialSource::Profiles(_) => Config::load().ok(),
78            CredentialSource::OAuth { .. } => None,
79        };
80
81        Ok(Self {
82            credential_source,
83            policy,
84            database_url,
85            config,
86            profiles,
87            clients: RwLock::new(CachedClients {
88                #[cfg(feature = "cloud")]
89                cloud: HashMap::new(),
90                #[cfg(feature = "enterprise")]
91                enterprise: HashMap::new(),
92                #[cfg(feature = "database")]
93                database: HashMap::new(),
94            }),
95            #[cfg(feature = "database")]
96            aliases: RwLock::new(HashMap::new()),
97        })
98    }
99
100    /// Get the list of configured profiles
101    #[allow(dead_code)]
102    pub fn available_profiles(&self) -> &[String] {
103        &self.profiles
104    }
105
106    /// Get or create Cloud API client for a specific profile
107    ///
108    /// If profile is None, uses the first configured profile or default from config
109    #[cfg(feature = "cloud")]
110    pub async fn cloud_client_for_profile(&self, profile: Option<&str>) -> Result<CloudClient> {
111        let cache_key = profile.unwrap_or("_default").to_string();
112
113        // Check cache first
114        {
115            let clients = self.clients.read().await;
116            if let Some(client) = clients.cloud.get(&cache_key) {
117                return Ok(client.clone());
118            }
119        }
120
121        // Create new client
122        let client = self.create_cloud_client(profile).await?;
123
124        // Cache it
125        {
126            let mut clients = self.clients.write().await;
127            clients.cloud.insert(cache_key, client.clone());
128        }
129
130        Ok(client)
131    }
132
133    /// Get or create Cloud API client (uses default profile)
134    #[cfg(feature = "cloud")]
135    #[allow(dead_code)]
136    pub async fn cloud_client(&self) -> Result<CloudClient> {
137        self.cloud_client_for_profile(None).await
138    }
139
140    /// Get or create Enterprise API client for a specific profile
141    ///
142    /// If profile is None, uses the first configured profile or default from config
143    #[cfg(feature = "enterprise")]
144    pub async fn enterprise_client_for_profile(
145        &self,
146        profile: Option<&str>,
147    ) -> Result<EnterpriseClient> {
148        let cache_key = profile.unwrap_or("_default").to_string();
149
150        // Check cache first
151        {
152            let clients = self.clients.read().await;
153            if let Some(client) = clients.enterprise.get(&cache_key) {
154                return Ok(client.clone());
155            }
156        }
157
158        // Create new client
159        let client = self.create_enterprise_client(profile).await?;
160
161        // Cache it
162        {
163            let mut clients = self.clients.write().await;
164            clients.enterprise.insert(cache_key, client.clone());
165        }
166
167        Ok(client)
168    }
169
170    /// Get or create Enterprise API client (uses default profile)
171    #[cfg(feature = "enterprise")]
172    #[allow(dead_code)]
173    pub async fn enterprise_client(&self) -> Result<EnterpriseClient> {
174        self.enterprise_client_for_profile(None).await
175    }
176
177    /// Create a new Cloud client from credentials
178    #[cfg(feature = "cloud")]
179    async fn create_cloud_client(&self, profile: Option<&str>) -> Result<CloudClient> {
180        match &self.credential_source {
181            CredentialSource::Profiles(profiles) => {
182                let config = self
183                    .config
184                    .as_ref()
185                    .context("No redisctl config available")?;
186
187                // Use specified profile, first configured profile, or let config resolve default
188                let profile_to_use = profile
189                    .map(|s| s.to_string())
190                    .or_else(|| profiles.first().cloned());
191
192                // Resolve the profile name
193                let resolved_profile_name = config
194                    .resolve_cloud_profile(profile_to_use.as_deref())
195                    .context("Failed to resolve cloud profile")?;
196
197                // Get the profile
198                let profile = config
199                    .profiles
200                    .get(&resolved_profile_name)
201                    .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;
202
203                // Get credentials
204                let (api_key, api_secret, _base_url) = profile
205                    .resolve_cloud_credentials()
206                    .context("Failed to resolve cloud credentials")?
207                    .context("No cloud credentials in profile")?;
208
209                CloudClient::builder()
210                    .api_key(api_key)
211                    .api_secret(api_secret)
212                    .build()
213                    .context("Failed to build Cloud client")
214            }
215            CredentialSource::OAuth { .. } => {
216                // In OAuth mode, credentials come from environment variables
217                let api_key =
218                    std::env::var("REDIS_CLOUD_API_KEY").context("REDIS_CLOUD_API_KEY not set")?;
219                let api_secret = std::env::var("REDIS_CLOUD_API_SECRET")
220                    .context("REDIS_CLOUD_API_SECRET not set")?;
221
222                CloudClient::builder()
223                    .api_key(api_key)
224                    .api_secret(api_secret)
225                    .build()
226                    .context("Failed to build Cloud client")
227            }
228        }
229    }
230
231    /// Create a new Enterprise client from credentials
232    #[cfg(feature = "enterprise")]
233    async fn create_enterprise_client(&self, profile: Option<&str>) -> Result<EnterpriseClient> {
234        match &self.credential_source {
235            CredentialSource::Profiles(profiles) => {
236                let config = self
237                    .config
238                    .as_ref()
239                    .context("No redisctl config available")?;
240
241                // Use specified profile, first configured profile, or let config resolve default
242                let profile_to_use = profile
243                    .map(|s| s.to_string())
244                    .or_else(|| profiles.first().cloned());
245
246                // Resolve the profile name
247                let resolved_profile_name = config
248                    .resolve_enterprise_profile(profile_to_use.as_deref())
249                    .context("Failed to resolve enterprise profile")?;
250
251                // Get the profile
252                let profile_config = config
253                    .profiles
254                    .get(&resolved_profile_name)
255                    .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;
256
257                // Get credentials
258                let (url, username, password, insecure, ca_cert) = profile_config
259                    .resolve_enterprise_credentials()
260                    .context("Failed to resolve enterprise credentials")?
261                    .context("No enterprise credentials in profile")?;
262
263                let mut builder = EnterpriseClient::builder()
264                    .base_url(&url)
265                    .username(&username)
266                    .insecure(insecure);
267
268                if let Some(pwd) = password {
269                    builder = builder.password(&pwd);
270                }
271
272                if let Some(cert_path) = ca_cert {
273                    builder = builder.ca_cert(&cert_path);
274                }
275
276                builder.build().context("Failed to build Enterprise client")
277            }
278            CredentialSource::OAuth { .. } => {
279                // In OAuth mode, credentials come from environment variables
280                let url = std::env::var("REDIS_ENTERPRISE_URL")
281                    .context("REDIS_ENTERPRISE_URL not set")?;
282                let username = std::env::var("REDIS_ENTERPRISE_USER")
283                    .context("REDIS_ENTERPRISE_USER not set")?;
284                let password = std::env::var("REDIS_ENTERPRISE_PASSWORD").ok();
285                let insecure = std::env::var("REDIS_ENTERPRISE_INSECURE")
286                    .map(|v| v == "true" || v == "1")
287                    .unwrap_or(false);
288
289                let mut builder = EnterpriseClient::builder()
290                    .base_url(&url)
291                    .username(&username)
292                    .insecure(insecure);
293
294                if let Some(pwd) = password {
295                    builder = builder.password(&pwd);
296                }
297
298                builder.build().context("Failed to build Enterprise client")
299            }
300        }
301    }
302
303    /// Resolve a Redis database URL from a profile
304    ///
305    /// If profile is `None`, uses the first configured profile or default from config
306    #[cfg(feature = "database")]
307    pub fn database_url_for_profile(&self, profile: Option<&str>) -> Result<String> {
308        let config = self
309            .config
310            .as_ref()
311            .context("No redisctl config available")?;
312
313        let profile_to_use = profile
314            .map(|s| s.to_string())
315            .or_else(|| self.profiles.first().cloned());
316
317        let resolved_name = config
318            .resolve_database_profile(profile_to_use.as_deref())
319            .context("Failed to resolve database profile")?;
320
321        let profile_config = config
322            .profiles
323            .get(&resolved_name)
324            .with_context(|| format!("Profile '{}' not found", resolved_name))?;
325
326        let (host, port, password, tls, username, database) = profile_config
327            .resolve_database_credentials()
328            .context("Failed to resolve database credentials")?
329            .context("No database credentials in profile")?;
330
331        // Build Redis URL: redis[s]://[username[:password]@]host:port[/database]
332        let scheme = if tls { "rediss" } else { "redis" };
333        let auth = match (username.as_str(), password) {
334            ("", None) | ("default", None) => String::new(),
335            (user, Some(pass)) => format!(
336                "{}:{}@",
337                urlencoding::encode(user),
338                urlencoding::encode(&pass)
339            ),
340            (user, None) => format!("{}@", urlencoding::encode(user)),
341        };
342        let db_path = if database > 0 {
343            format!("/{}", database)
344        } else {
345            String::new()
346        };
347
348        Ok(format!("{}://{}{}:{}{}", scheme, auth, host, port, db_path))
349    }
350
351    /// Get or create a cached Redis connection for a resolved URL.
352    ///
353    /// Connections are cached by URL. If a cached connection fails a PING
354    /// health check, it is evicted and a fresh connection is created.
355    #[cfg(feature = "database")]
356    pub async fn redis_connection_for_url(
357        &self,
358        url: &str,
359    ) -> Result<redis::aio::MultiplexedConnection> {
360        // Check cache first
361        {
362            let clients = self.clients.read().await;
363            if let Some(conn) = clients.database.get(url) {
364                // Quick health check -- if PING fails the connection is stale
365                let mut test_conn = conn.clone();
366                if redis::cmd("PING")
367                    .query_async::<String>(&mut test_conn)
368                    .await
369                    .is_ok()
370                {
371                    return Ok(conn.clone());
372                }
373                // Fall through to evict + reconnect
374            }
375        }
376
377        // Create new connection (or reconnect after eviction)
378        let client = redis::Client::open(url).context("Failed to create Redis client")?;
379        let conn = client
380            .get_multiplexed_async_connection()
381            .await
382            .context("Failed to connect to Redis")?;
383
384        // Cache it
385        {
386            let mut clients = self.clients.write().await;
387            clients.database.insert(url.to_string(), conn.clone());
388        }
389
390        Ok(conn)
391    }
392
393    /// Check if write operations are allowed by the global policy tier.
394    ///
395    /// Returns `true` for `ReadWrite` and `Full` tiers.
396    /// Used for defense-in-depth in non-destructive write tool handlers.
397    #[allow(dead_code)]
398    pub fn is_write_allowed(&self) -> bool {
399        matches!(
400            self.policy.global_tier(),
401            SafetyTier::ReadWrite | SafetyTier::Full
402        )
403    }
404
405    /// Check if destructive operations are allowed by the global policy tier.
406    ///
407    /// Returns `true` only for `Full` tier.
408    /// Used for defense-in-depth in destructive tool handlers.
409    #[allow(dead_code)]
410    pub fn is_destructive_allowed(&self) -> bool {
411        matches!(self.policy.global_tier(), SafetyTier::Full)
412    }
413
414    /// Store a named command alias (session-scoped, in-memory only).
415    #[cfg(feature = "database")]
416    pub async fn set_alias(&self, name: String, commands: Vec<Vec<String>>) {
417        let mut aliases = self.aliases.write().await;
418        aliases.insert(name, commands);
419    }
420
421    /// Retrieve a named command alias.
422    #[cfg(feature = "database")]
423    pub async fn get_alias(&self, name: &str) -> Option<Vec<Vec<String>>> {
424        let aliases = self.aliases.read().await;
425        aliases.get(name).cloned()
426    }
427
428    /// List all aliases with their command counts.
429    #[cfg(feature = "database")]
430    pub async fn list_aliases(&self) -> Vec<(String, usize)> {
431        let aliases = self.aliases.read().await;
432        let mut entries: Vec<_> = aliases.iter().map(|(k, v)| (k.clone(), v.len())).collect();
433        entries.sort_by(|a, b| a.0.cmp(&b.0));
434        entries
435    }
436
437    /// Delete a named alias. Returns true if it existed.
438    #[cfg(feature = "database")]
439    pub async fn delete_alias(&self, name: &str) -> bool {
440        let mut aliases = self.aliases.write().await;
441        aliases.remove(name).is_some()
442    }
443}
444
445impl Clone for AppState {
446    fn clone(&self) -> Self {
447        // Note: We don't clone the clients cache, each clone gets fresh cache
448        Self {
449            credential_source: self.credential_source.clone(),
450            policy: self.policy.clone(),
451            database_url: self.database_url.clone(),
452            config: self.config.clone(),
453            profiles: self.profiles.clone(),
454            clients: RwLock::new(CachedClients {
455                #[cfg(feature = "cloud")]
456                cloud: HashMap::new(),
457                #[cfg(feature = "enterprise")]
458                enterprise: HashMap::new(),
459                #[cfg(feature = "database")]
460                database: HashMap::new(),
461            }),
462            #[cfg(feature = "database")]
463            aliases: RwLock::new(HashMap::new()),
464        }
465    }
466}
467
468/// Test helpers for creating AppState with pre-configured clients
469#[allow(dead_code)]
470impl AppState {
471    /// Create a default read-only policy for tests
472    pub fn test_policy() -> Arc<Policy> {
473        Arc::new(Policy::new(
474            crate::policy::PolicyConfig::default(),
475            std::collections::HashMap::new(),
476            "test".to_string(),
477        ))
478    }
479
480    /// Create test state with a pre-configured Cloud client
481    #[cfg(feature = "cloud")]
482    pub fn with_cloud_client(client: CloudClient) -> Self {
483        let mut cloud = HashMap::new();
484        cloud.insert("_default".to_string(), client);
485        Self {
486            credential_source: CredentialSource::Profiles(vec![]),
487            policy: Self::test_policy(),
488            database_url: None,
489            config: None,
490            profiles: vec![],
491            clients: RwLock::new(CachedClients {
492                cloud,
493                #[cfg(feature = "enterprise")]
494                enterprise: HashMap::new(),
495                #[cfg(feature = "database")]
496                database: HashMap::new(),
497            }),
498            #[cfg(feature = "database")]
499            aliases: RwLock::new(HashMap::new()),
500        }
501    }
502
503    /// Create test state with a pre-configured Enterprise client
504    #[cfg(feature = "enterprise")]
505    pub fn with_enterprise_client(client: EnterpriseClient) -> Self {
506        let mut enterprise = HashMap::new();
507        enterprise.insert("_default".to_string(), client);
508        Self {
509            credential_source: CredentialSource::Profiles(vec![]),
510            policy: Self::test_policy(),
511            database_url: None,
512            config: None,
513            profiles: vec![],
514            clients: RwLock::new(CachedClients {
515                #[cfg(feature = "cloud")]
516                cloud: HashMap::new(),
517                enterprise,
518                #[cfg(feature = "database")]
519                database: HashMap::new(),
520            }),
521            #[cfg(feature = "database")]
522            aliases: RwLock::new(HashMap::new()),
523        }
524    }
525
526    /// Create test state with both Cloud and Enterprise clients
527    #[cfg(all(feature = "cloud", feature = "enterprise"))]
528    pub fn with_clients(cloud: CloudClient, enterprise: EnterpriseClient) -> Self {
529        let mut cloud_map = HashMap::new();
530        cloud_map.insert("_default".to_string(), cloud);
531        let mut enterprise_map = HashMap::new();
532        enterprise_map.insert("_default".to_string(), enterprise);
533        Self {
534            credential_source: CredentialSource::Profiles(vec![]),
535            policy: Self::test_policy(),
536            database_url: None,
537            config: None,
538            profiles: vec![],
539            clients: RwLock::new(CachedClients {
540                cloud: cloud_map,
541                enterprise: enterprise_map,
542                #[cfg(feature = "database")]
543                database: HashMap::new(),
544            }),
545            #[cfg(feature = "database")]
546            aliases: RwLock::new(HashMap::new()),
547        }
548    }
549}