use std::collections::HashMap;
use std::sync::Arc;
#[cfg(any(feature = "cloud", feature = "enterprise", feature = "database"))]
use anyhow::Context;
use anyhow::Result;
#[cfg(feature = "cloud")]
use redis_cloud::CloudClient;
#[cfg(feature = "enterprise")]
use redis_enterprise::EnterpriseClient;
use redisctl_core::Config;
use tokio::sync::RwLock;
use crate::policy::{Policy, SafetyTier};
/// Where the application obtains the credentials used to build API clients.
///
/// `AppState::new` inspects this value to decide whether to load the redisctl
/// configuration, and the client factories branch on it to choose between
/// profile-based credentials and environment variables.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum CredentialSource {
/// Profile names to use. The first entry is the fallback when a caller does
/// not request a specific profile.
Profiles(Vec<String>),
/// OAuth mode. In this file, clients created in this mode read their
/// credentials from environment variables rather than from profiles.
OAuth {
/// Not read anywhere in this file.
/// NOTE(review): presumably the expected token issuer - confirm against the auth layer.
issuer: Option<String>,
/// Not read anywhere in this file.
/// NOTE(review): presumably the expected token audience - confirm against the auth layer.
audience: Option<String>,
},
}
/// Per-backend caches of already-constructed clients and connections.
///
/// Each map only exists when the matching Cargo feature is enabled.
pub struct CachedClients {
/// Cloud API clients, keyed by the requested profile name, or `"_default"`
/// when no profile was requested.
#[cfg(feature = "cloud")]
pub cloud: HashMap<String, CloudClient>,
/// Enterprise API clients, keyed the same way as `cloud`.
#[cfg(feature = "enterprise")]
pub enterprise: HashMap<String, EnterpriseClient>,
/// Multiplexed Redis connections, keyed by the connection URL they were opened with.
#[cfg(feature = "database")]
pub database: HashMap<String, redis::aio::MultiplexedConnection>,
}
/// Shared application state: credential configuration, safety policy, and
/// lazily populated client caches.
pub struct AppState {
/// How credentials are obtained (profiles or OAuth/environment).
pub credential_source: CredentialSource,
/// Safety policy; its global tier drives `is_write_allowed` and
/// `is_destructive_allowed`.
pub policy: Arc<Policy>,
/// Optional Redis URL supplied at construction time. Stored as given; the
/// methods in this file do not read it.
pub database_url: Option<String>,
// Loaded redisctl configuration; `None` in OAuth mode or when loading failed.
config: Option<Config>,
// Profile names copied out of `CredentialSource::Profiles` (empty in OAuth mode).
profiles: Vec<String>,
// Client/connection caches guarded by an async read-write lock.
#[allow(dead_code)]
clients: RwLock<CachedClients>,
// Named aliases, each mapping to a list of commands (a command is a list of strings).
#[cfg(feature = "database")]
aliases: RwLock<HashMap<String, Vec<Vec<String>>>>,
}
impl AppState {
/// Builds a fresh state with empty client caches and no aliases.
///
/// In profile mode the profile names are copied and the redisctl configuration
/// is loaded on a best-effort basis: a load failure is discarded and leaves the
/// configuration unset, so later profile lookups report that no configuration
/// is available. In OAuth mode no configuration is loaded and the profile list
/// is empty.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn new(
    credential_source: CredentialSource,
    policy: Arc<Policy>,
    database_url: Option<String>,
) -> Result<Self> {
    // Derive both profile-dependent values in a single pass over the source.
    let (profiles, config) = match &credential_source {
        CredentialSource::Profiles(names) => (names.clone(), Config::load().ok()),
        CredentialSource::OAuth { .. } => (Vec::new(), None),
    };

    let clients = RwLock::new(CachedClients {
        #[cfg(feature = "cloud")]
        cloud: HashMap::new(),
        #[cfg(feature = "enterprise")]
        enterprise: HashMap::new(),
        #[cfg(feature = "database")]
        database: HashMap::new(),
    });

    Ok(Self {
        credential_source,
        policy,
        database_url,
        config,
        profiles,
        clients,
        #[cfg(feature = "database")]
        aliases: RwLock::new(HashMap::new()),
    })
}
/// Returns the profile names this state was configured with (empty in OAuth mode).
#[allow(dead_code)]
pub fn available_profiles(&self) -> &[String] {
    self.profiles.as_slice()
}
/// Returns the Cloud client for `profile`, building and caching it on first use.
///
/// The cache key is the profile name, or `"_default"` when `profile` is `None`.
/// The lookup and the insert use separate lock acquisitions, so concurrent
/// first calls for the same key may each build a client; the later insert
/// replaces the earlier one.
///
/// # Errors
///
/// Propagates any error from building the client.
#[cfg(feature = "cloud")]
pub async fn cloud_client_for_profile(&self, profile: Option<&str>) -> Result<CloudClient> {
    let cache_key = profile.unwrap_or("_default");

    // Fast path: the read guard is a temporary and is released at the end of this statement.
    let cached = self.clients.read().await.cloud.get(cache_key).cloned();
    if let Some(existing) = cached {
        return Ok(existing);
    }

    // Slow path: build without holding any lock, then publish a clone to the cache.
    let created = self.create_cloud_client(profile).await?;
    self.clients
        .write()
        .await
        .cloud
        .insert(cache_key.to_string(), created.clone());
    Ok(created)
}
/// Returns the Cloud client for the default profile (no explicit profile name).
#[cfg(feature = "cloud")]
#[allow(dead_code)]
pub async fn cloud_client(&self) -> Result<CloudClient> {
    let no_profile: Option<&str> = None;
    self.cloud_client_for_profile(no_profile).await
}
/// Returns the Enterprise client for `profile`, building and caching it on first use.
///
/// The cache key is the profile name, or `"_default"` when `profile` is `None`.
/// The lookup and the insert use separate lock acquisitions, so concurrent
/// first calls for the same key may each build a client; the later insert
/// replaces the earlier one.
///
/// # Errors
///
/// Propagates any error from building the client.
#[cfg(feature = "enterprise")]
pub async fn enterprise_client_for_profile(
    &self,
    profile: Option<&str>,
) -> Result<EnterpriseClient> {
    let cache_key = profile.unwrap_or("_default");

    // Fast path: the read guard is a temporary and is released at the end of this statement.
    let cached = self.clients.read().await.enterprise.get(cache_key).cloned();
    if let Some(existing) = cached {
        return Ok(existing);
    }

    // Slow path: build without holding any lock, then publish a clone to the cache.
    let created = self.create_enterprise_client(profile).await?;
    self.clients
        .write()
        .await
        .enterprise
        .insert(cache_key.to_string(), created.clone());
    Ok(created)
}
/// Returns the Enterprise client for the default profile (no explicit profile name).
#[cfg(feature = "enterprise")]
#[allow(dead_code)]
pub async fn enterprise_client(&self) -> Result<EnterpriseClient> {
    let no_profile: Option<&str> = None;
    self.enterprise_client_for_profile(no_profile).await
}
/// Builds a new Cloud client from the configured credential source.
///
/// * OAuth mode: reads `REDIS_CLOUD_API_KEY` and `REDIS_CLOUD_API_SECRET` from
///   the environment.
/// * Profile mode: resolves `profile` (falling back to the first configured
///   profile name) through the loaded configuration and uses that profile's
///   cloud credentials.
///
/// # Errors
///
/// Fails when a required environment variable is missing, when no
/// configuration was loaded, when the profile cannot be resolved or found,
/// when it carries no cloud credentials, or when the client builder fails.
#[cfg(feature = "cloud")]
async fn create_cloud_client(&self, profile: Option<&str>) -> Result<CloudClient> {
    match &self.credential_source {
        CredentialSource::OAuth { .. } => {
            let api_key =
                std::env::var("REDIS_CLOUD_API_KEY").context("REDIS_CLOUD_API_KEY not set")?;
            let api_secret = std::env::var("REDIS_CLOUD_API_SECRET")
                .context("REDIS_CLOUD_API_SECRET not set")?;

            CloudClient::builder()
                .api_key(api_key)
                .api_secret(api_secret)
                .build()
                .context("Failed to build Cloud client")
        }
        CredentialSource::Profiles(profiles) => {
            let config = self
                .config
                .as_ref()
                .context("No redisctl config available")?;

            // An explicit profile wins; otherwise fall back to the first configured name.
            let requested: Option<&str> =
                profile.or_else(|| profiles.first().map(String::as_str));
            let resolved_profile_name = config
                .resolve_cloud_profile(requested)
                .context("Failed to resolve cloud profile")?;
            let profile_config = config
                .profiles
                .get(&resolved_profile_name)
                .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;

            // NOTE(review): the resolved base URL is discarded here - confirm that
            // ignoring a profile-level endpoint override is intended.
            let credentials = profile_config
                .resolve_cloud_credentials()
                .context("Failed to resolve cloud credentials")?;
            let (api_key, api_secret, _base_url) =
                credentials.context("No cloud credentials in profile")?;

            CloudClient::builder()
                .api_key(api_key)
                .api_secret(api_secret)
                .build()
                .context("Failed to build Cloud client")
        }
    }
}
/// Builds a new Enterprise client from the configured credential source.
///
/// * OAuth mode: reads `REDIS_ENTERPRISE_URL` and `REDIS_ENTERPRISE_USER`
///   (both required), `REDIS_ENTERPRISE_PASSWORD` (optional) and
///   `REDIS_ENTERPRISE_INSECURE` (enabled only by the exact values `true` or
///   `1`) from the environment.
/// * Profile mode: resolves `profile` (falling back to the first configured
///   profile name) through the loaded configuration and uses that profile's
///   enterprise credentials, including the optional CA certificate.
///
/// # Errors
///
/// Fails when a required environment variable is missing, when no
/// configuration was loaded, when the profile cannot be resolved or found,
/// when it carries no enterprise credentials, or when the client builder fails.
#[cfg(feature = "enterprise")]
async fn create_enterprise_client(&self, profile: Option<&str>) -> Result<EnterpriseClient> {
    match &self.credential_source {
        CredentialSource::OAuth { .. } => {
            let url = std::env::var("REDIS_ENTERPRISE_URL")
                .context("REDIS_ENTERPRISE_URL not set")?;
            let username = std::env::var("REDIS_ENTERPRISE_USER")
                .context("REDIS_ENTERPRISE_USER not set")?;
            let password = std::env::var("REDIS_ENTERPRISE_PASSWORD").ok();
            // Unset, unreadable, or any other value leaves TLS verification on.
            let insecure = matches!(
                std::env::var("REDIS_ENTERPRISE_INSECURE").as_deref(),
                Ok("true") | Ok("1")
            );

            let builder = EnterpriseClient::builder()
                .base_url(&url)
                .username(&username)
                .insecure(insecure);
            let builder = match password {
                Some(pwd) => builder.password(&pwd),
                None => builder,
            };
            builder.build().context("Failed to build Enterprise client")
        }
        CredentialSource::Profiles(profiles) => {
            let config = self
                .config
                .as_ref()
                .context("No redisctl config available")?;

            // An explicit profile wins; otherwise fall back to the first configured name.
            let requested: Option<&str> =
                profile.or_else(|| profiles.first().map(String::as_str));
            let resolved_profile_name = config
                .resolve_enterprise_profile(requested)
                .context("Failed to resolve enterprise profile")?;
            let profile_config = config
                .profiles
                .get(&resolved_profile_name)
                .with_context(|| format!("Profile '{}' not found", resolved_profile_name))?;

            let credentials = profile_config
                .resolve_enterprise_credentials()
                .context("Failed to resolve enterprise credentials")?;
            let (url, username, password, insecure, ca_cert) =
                credentials.context("No enterprise credentials in profile")?;

            let builder = EnterpriseClient::builder()
                .base_url(&url)
                .username(&username)
                .insecure(insecure);
            let builder = match password {
                Some(pwd) => builder.password(&pwd),
                None => builder,
            };
            let builder = match ca_cert {
                Some(cert_path) => builder.ca_cert(&cert_path),
                None => builder,
            };
            builder.build().context("Failed to build Enterprise client")
        }
    }
}
/// Builds a `redis://` or `rediss://` connection URL from a profile's database
/// credentials.
///
/// `profile` falls back to the first configured profile name when `None`.
/// The user name and password are percent-encoded. The user-info part is
/// omitted when there is no password and the user is empty or `default`. The
/// database index is appended as a path only when it is greater than zero.
/// A host containing `:` that is not already bracketed (an IPv6 literal such
/// as `::1`) is wrapped in `[...]`, because an unbracketed colon makes the
/// URL authority ambiguous with the port separator.
///
/// # Errors
///
/// Fails when no configuration was loaded, when the profile cannot be
/// resolved or found, or when it carries no database credentials.
#[cfg(feature = "database")]
pub fn database_url_for_profile(&self, profile: Option<&str>) -> Result<String> {
    let config = self
        .config
        .as_ref()
        .context("No redisctl config available")?;
    let profile_to_use = profile
        .map(|s| s.to_string())
        .or_else(|| self.profiles.first().cloned());
    let resolved_name = config
        .resolve_database_profile(profile_to_use.as_deref())
        .context("Failed to resolve database profile")?;
    let profile_config = config
        .profiles
        .get(&resolved_name)
        .with_context(|| format!("Profile '{}' not found", resolved_name))?;
    let (host, port, password, tls, username, database) = profile_config
        .resolve_database_credentials()
        .context("Failed to resolve database credentials")?
        .context("No database credentials in profile")?;

    let scheme = if tls { "rediss" } else { "redis" };
    let auth = match (username.as_str(), password) {
        ("", None) | ("default", None) => String::new(),
        (user, Some(pass)) => format!(
            "{}:{}@",
            urlencoding::encode(user),
            urlencoding::encode(&pass)
        ),
        (user, None) => format!("{}@", urlencoding::encode(user)),
    };
    // IPv6 literals must be bracketed inside a URL authority; leave hosts that
    // are already bracketed untouched.
    let (open, close) = if host.contains(':') && !host.starts_with('[') {
        ("[", "]")
    } else {
        ("", "")
    };
    let db_path = if database > 0 {
        format!("/{}", database)
    } else {
        String::new()
    };
    Ok(format!(
        "{}://{}{}{}{}:{}{}",
        scheme, auth, open, host, close, port, db_path
    ))
}
/// Returns a multiplexed Redis connection for `url`, reusing a cached one when
/// it still answers `PING`.
///
/// The cached handle is cloned out and the cache lock is released *before* the
/// health check. The `clients` lock is shared with the Cloud and Enterprise
/// caches, so holding it across a network round-trip would stall every writer
/// (and any reader queued behind one) for as long as a slow or dead server
/// takes to answer.
///
/// When there is no cached connection, or the health check fails, a new
/// connection is opened and stored under `url`, replacing any stale entry.
///
/// # Errors
///
/// Fails when `url` cannot be turned into a Redis client or when connecting fails.
#[cfg(feature = "database")]
pub async fn redis_connection_for_url(
    &self,
    url: &str,
) -> Result<redis::aio::MultiplexedConnection> {
    // Copy the handle while holding the lock; the guard is dropped at the end of this block.
    let cached = {
        let clients = self.clients.read().await;
        clients.database.get(url).cloned()
    };

    if let Some(mut conn) = cached {
        // Health check runs with no lock held.
        if redis::cmd("PING")
            .query_async::<String>(&mut conn)
            .await
            .is_ok()
        {
            return Ok(conn);
        }
    }

    let client = redis::Client::open(url).context("Failed to create Redis client")?;
    let conn = client
        .get_multiplexed_async_connection()
        .await
        .context("Failed to connect to Redis")?;
    {
        let mut clients = self.clients.write().await;
        clients.database.insert(url.to_string(), conn.clone());
    }
    Ok(conn)
}
/// Reports whether the policy's global tier is `ReadWrite` or `Full`.
#[allow(dead_code)]
pub fn is_write_allowed(&self) -> bool {
    let tier = self.policy.global_tier();
    matches!(tier, SafetyTier::Full | SafetyTier::ReadWrite)
}
/// Reports whether the policy's global tier is `Full`.
#[allow(dead_code)]
pub fn is_destructive_allowed(&self) -> bool {
    let tier = self.policy.global_tier();
    matches!(tier, SafetyTier::Full)
}
/// Stores `commands` under `name`, replacing any alias already registered with that name.
#[cfg(feature = "database")]
pub async fn set_alias(&self, name: String, commands: Vec<Vec<String>>) {
    // The previous value, if any, is intentionally discarded.
    self.aliases.write().await.insert(name, commands);
}
/// Returns a copy of the commands registered under `name`, or `None` if there is no such alias.
#[cfg(feature = "database")]
pub async fn get_alias(&self, name: &str) -> Option<Vec<Vec<String>>> {
    self.aliases.read().await.get(name).cloned()
}
/// Lists every alias as `(name, number of commands)`, sorted by name.
#[cfg(feature = "database")]
pub async fn list_aliases(&self) -> Vec<(String, usize)> {
    let mut entries: Vec<(String, usize)> = {
        let aliases = self.aliases.read().await;
        aliases
            .iter()
            .map(|(name, commands)| (name.clone(), commands.len()))
            .collect()
    };
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
    entries
}
/// Removes the alias `name`; returns `true` if it existed.
#[cfg(feature = "database")]
pub async fn delete_alias(&self, name: &str) -> bool {
    self.aliases.write().await.remove(name).is_some()
}
}
impl Clone for AppState {
fn clone(&self) -> Self {
Self {
credential_source: self.credential_source.clone(),
policy: self.policy.clone(),
database_url: self.database_url.clone(),
config: self.config.clone(),
profiles: self.profiles.clone(),
clients: RwLock::new(CachedClients {
#[cfg(feature = "cloud")]
cloud: HashMap::new(),
#[cfg(feature = "enterprise")]
enterprise: HashMap::new(),
#[cfg(feature = "database")]
database: HashMap::new(),
}),
#[cfg(feature = "database")]
aliases: RwLock::new(HashMap::new()),
}
}
}
#[allow(dead_code)]
impl AppState {
pub fn test_policy() -> Arc<Policy> {
Arc::new(Policy::new(
crate::policy::PolicyConfig::default(),
std::collections::HashMap::new(),
"test".to_string(),
))
}
/// Builds a state whose Cloud cache already holds `client` under the `"_default"` key.
///
/// Uses the test policy, an empty profile list, and no configuration or database URL.
#[cfg(feature = "cloud")]
pub fn with_cloud_client(client: CloudClient) -> Self {
    let clients = RwLock::new(CachedClients {
        cloud: HashMap::from([("_default".to_string(), client)]),
        #[cfg(feature = "enterprise")]
        enterprise: HashMap::new(),
        #[cfg(feature = "database")]
        database: HashMap::new(),
    });

    Self {
        credential_source: CredentialSource::Profiles(Vec::new()),
        policy: Self::test_policy(),
        database_url: None,
        config: None,
        profiles: Vec::new(),
        clients,
        #[cfg(feature = "database")]
        aliases: RwLock::new(HashMap::new()),
    }
}
/// Builds a state whose Enterprise cache already holds `client` under the `"_default"` key.
///
/// Uses the test policy, an empty profile list, and no configuration or database URL.
#[cfg(feature = "enterprise")]
pub fn with_enterprise_client(client: EnterpriseClient) -> Self {
    let clients = RwLock::new(CachedClients {
        #[cfg(feature = "cloud")]
        cloud: HashMap::new(),
        enterprise: HashMap::from([("_default".to_string(), client)]),
        #[cfg(feature = "database")]
        database: HashMap::new(),
    });

    Self {
        credential_source: CredentialSource::Profiles(Vec::new()),
        policy: Self::test_policy(),
        database_url: None,
        config: None,
        profiles: Vec::new(),
        clients,
        #[cfg(feature = "database")]
        aliases: RwLock::new(HashMap::new()),
    }
}
/// Builds a state whose Cloud and Enterprise caches already hold the given
/// clients under the `"_default"` key.
///
/// Uses the test policy, an empty profile list, and no configuration or database URL.
#[cfg(all(feature = "cloud", feature = "enterprise"))]
pub fn with_clients(cloud: CloudClient, enterprise: EnterpriseClient) -> Self {
    let clients = RwLock::new(CachedClients {
        cloud: HashMap::from([("_default".to_string(), cloud)]),
        enterprise: HashMap::from([("_default".to_string(), enterprise)]),
        #[cfg(feature = "database")]
        database: HashMap::new(),
    });

    Self {
        credential_source: CredentialSource::Profiles(Vec::new()),
        policy: Self::test_policy(),
        database_url: None,
        config: None,
        profiles: Vec::new(),
        clients,
        #[cfg(feature = "database")]
        aliases: RwLock::new(HashMap::new()),
    }
}
}