/// Error types for this module tree, including [`McpError`].
pub mod error;
/// State types: [`AppState`] and [`CredentialSource`].
pub mod state;
/// Tool functions and their input types, including the `cloud`, `enterprise`,
/// and `redis` submodules.
pub mod tools;
// Re-exported so callers can name these types directly from this module.
pub use error::McpError;
pub use state::{AppState, CredentialSource};
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
/// A `CredentialSource::Profile` built with `Some("test")` yields that same
/// name when destructured.
#[test]
fn test_credential_source_profile() {
    let profile = CredentialSource::Profile(Some("test".to_string()));
    if let CredentialSource::Profile(Some(profile_name)) = profile {
        assert_eq!(profile_name, "test");
    } else {
        panic!("Expected Profile variant");
    }
}
/// A `CredentialSource::OAuth` value keeps the issuer and audience it was
/// constructed with.
#[test]
fn test_credential_source_oauth() {
    let oauth = CredentialSource::OAuth {
        issuer: Some("https://example.com".to_string()),
        audience: Some("my-api".to_string()),
    };
    if let CredentialSource::OAuth { issuer, audience } = oauth {
        // Compare as `Option<&str>` to avoid allocating expected `String`s.
        assert_eq!(issuer.as_deref(), Some("https://example.com"));
        assert_eq!(audience.as_deref(), Some("my-api"));
    } else {
        panic!("Expected OAuth variant");
    }
}
/// Passing `true` as the second argument of `AppState::new` must leave writes
/// disallowed.
#[test]
fn test_app_state_read_only() {
    let credentials = CredentialSource::Profile(None);
    let app_state = AppState::new(credentials, true, None).unwrap();
    let write_allowed = app_state.is_write_allowed();
    assert!(!write_allowed);
}
/// Passing `false` as the second argument of `AppState::new` must allow
/// writes.
#[test]
fn test_app_state_write_allowed() {
    let credentials = CredentialSource::Profile(None);
    let app_state = AppState::new(credentials, false, None).unwrap();
    let write_allowed = app_state.is_write_allowed();
    assert!(write_allowed);
}
/// The URL given as the third argument of `AppState::new` is readable back
/// from the `database_url` field.
#[test]
fn test_app_state_database_url() {
    let url = "redis://localhost:6379";
    let credentials = CredentialSource::Profile(None);
    let app_state = AppState::new(credentials, true, Some(url.to_string())).unwrap();
    assert_eq!(app_state.database_url.as_deref(), Some(url));
}
/// Calls every `tools::cloud` function exercised here with a handle to one
/// shared `AppState` and discards the result; the test passes as long as
/// construction of the state and each call complete without panicking.
#[test]
fn test_cloud_tools_build() {
    let app_state = AppState::new(CredentialSource::Profile(None), true, None).unwrap();
    let shared = Arc::new(app_state);

    // Subscriptions and databases.
    let _ = tools::cloud::list_subscriptions(Arc::clone(&shared));
    let _ = tools::cloud::get_subscription(Arc::clone(&shared));
    let _ = tools::cloud::list_databases(Arc::clone(&shared));
    let _ = tools::cloud::get_database(Arc::clone(&shared));
    let _ = tools::cloud::get_backup_status(Arc::clone(&shared));
    let _ = tools::cloud::get_slow_log(Arc::clone(&shared));
    let _ = tools::cloud::get_tags(Arc::clone(&shared));

    // Account-level lookups.
    let _ = tools::cloud::get_account(Arc::clone(&shared));
    let _ = tools::cloud::get_regions(Arc::clone(&shared));
    let _ = tools::cloud::get_modules(Arc::clone(&shared));
    let _ = tools::cloud::list_account_users(Arc::clone(&shared));

    // ACLs.
    let _ = tools::cloud::list_acl_users(Arc::clone(&shared));
    let _ = tools::cloud::list_acl_roles(Arc::clone(&shared));
    let _ = tools::cloud::list_redis_rules(Arc::clone(&shared));

    // Tasks.
    let _ = tools::cloud::list_tasks(Arc::clone(&shared));
    let _ = tools::cloud::get_task(Arc::clone(&shared));
}
/// Calls every `tools::enterprise` function exercised here with a handle to
/// one shared `AppState` and discards the result; the test passes as long as
/// construction of the state and each call complete without panicking.
#[test]
fn test_enterprise_tools_build() {
    let app_state = AppState::new(CredentialSource::Profile(None), true, None).unwrap();
    let shared = Arc::new(app_state);

    // Cluster.
    let _ = tools::enterprise::get_cluster(Arc::clone(&shared));
    let _ = tools::enterprise::get_cluster_stats(Arc::clone(&shared));

    // Databases.
    let _ = tools::enterprise::list_databases(Arc::clone(&shared));
    let _ = tools::enterprise::get_database(Arc::clone(&shared));
    let _ = tools::enterprise::get_database_stats(Arc::clone(&shared));
    let _ = tools::enterprise::get_database_endpoints(Arc::clone(&shared));
    let _ = tools::enterprise::list_database_alerts(Arc::clone(&shared));

    // Nodes.
    let _ = tools::enterprise::list_nodes(Arc::clone(&shared));
    let _ = tools::enterprise::get_node(Arc::clone(&shared));
    let _ = tools::enterprise::get_node_stats(Arc::clone(&shared));

    // Users, alerts, shards.
    let _ = tools::enterprise::list_users(Arc::clone(&shared));
    let _ = tools::enterprise::get_user(Arc::clone(&shared));
    let _ = tools::enterprise::list_alerts(Arc::clone(&shared));
    let _ = tools::enterprise::list_shards(Arc::clone(&shared));
}
/// Calls every `tools::redis` function exercised here with a handle to one
/// shared `AppState` and discards the result; the test passes as long as
/// construction of the state and each call complete without panicking.
#[test]
fn test_redis_tools_build() {
    let app_state = AppState::new(CredentialSource::Profile(None), true, None).unwrap();
    let shared = Arc::new(app_state);

    // Server-level functions.
    let _ = tools::redis::ping(Arc::clone(&shared));
    let _ = tools::redis::info(Arc::clone(&shared));
    let _ = tools::redis::dbsize(Arc::clone(&shared));
    let _ = tools::redis::client_list(Arc::clone(&shared));
    let _ = tools::redis::cluster_info(Arc::clone(&shared));

    // Key-level functions.
    let _ = tools::redis::keys(Arc::clone(&shared));
    let _ = tools::redis::get(Arc::clone(&shared));
    let _ = tools::redis::key_type(Arc::clone(&shared));
    let _ = tools::redis::ttl(Arc::clone(&shared));
    let _ = tools::redis::exists(Arc::clone(&shared));
    let _ = tools::redis::memory_usage(Arc::clone(&shared));

    // Collection reads.
    let _ = tools::redis::hgetall(Arc::clone(&shared));
    let _ = tools::redis::lrange(Arc::clone(&shared));
    let _ = tools::redis::smembers(Arc::clone(&shared));
    let _ = tools::redis::zrange(Arc::clone(&shared));
}
/// The cloud tool input types deserialize from JSON and expose the ids that
/// were supplied.
#[test]
fn test_cloud_input_deserialization() {
    // An empty object must be accepted by the list-subscriptions input.
    let _list_subs = serde_json::from_str::<tools::cloud::ListSubscriptionsInput>("{}").unwrap();

    let get_sub_json = r#"{"subscription_id": 123}"#;
    let get_sub = serde_json::from_str::<tools::cloud::GetSubscriptionInput>(get_sub_json).unwrap();
    assert_eq!(get_sub.subscription_id, 123);

    let list_dbs_json = r#"{"subscription_id": 456}"#;
    let list_dbs = serde_json::from_str::<tools::cloud::ListDatabasesInput>(list_dbs_json).unwrap();
    assert_eq!(list_dbs.subscription_id, 456);

    let get_db_json = r#"{"subscription_id": 789, "database_id": 101}"#;
    let get_db = serde_json::from_str::<tools::cloud::GetDatabaseInput>(get_db_json).unwrap();
    assert_eq!(get_db.subscription_id, 789);
    assert_eq!(get_db.database_id, 101);
}
/// The enterprise tool input types deserialize from JSON; `name_filter` is
/// `None` when the key is absent.
#[test]
fn test_enterprise_input_deserialization() {
    // An empty object must be accepted by the get-cluster input.
    let _cluster = serde_json::from_str::<tools::enterprise::GetClusterInput>("{}").unwrap();

    let filtered_json = r#"{"name_filter": "test"}"#;
    let filtered =
        serde_json::from_str::<tools::enterprise::ListDatabasesInput>(filtered_json).unwrap();
    assert_eq!(filtered.name_filter.as_deref(), Some("test"));

    let unfiltered = serde_json::from_str::<tools::enterprise::ListDatabasesInput>("{}").unwrap();
    assert_eq!(unfiltered.name_filter, None);

    let get_db_json = r#"{"uid": 42}"#;
    let get_db = serde_json::from_str::<tools::enterprise::GetDatabaseInput>(get_db_json).unwrap();
    assert_eq!(get_db.uid, 42);

    // An empty object must also be accepted by the list-nodes input.
    let _nodes = serde_json::from_str::<tools::enterprise::ListNodesInput>("{}").unwrap();
}
/// The redis tool input types deserialize from JSON: supplied fields are
/// preserved, and an empty `KeysInput` reads back pattern `"*"` and limit 100.
#[test]
fn test_redis_input_deserialization() {
    let ping_json = r#"{"url": "redis://localhost:6379"}"#;
    let ping_with_url = serde_json::from_str::<tools::redis::PingInput>(ping_json).unwrap();
    assert_eq!(ping_with_url.url.as_deref(), Some("redis://localhost:6379"));

    let ping_bare = serde_json::from_str::<tools::redis::PingInput>("{}").unwrap();
    assert_eq!(ping_bare.url, None);

    let info_json = r#"{"section": "memory"}"#;
    let info = serde_json::from_str::<tools::redis::InfoInput>(info_json).unwrap();
    assert_eq!(info.section.as_deref(), Some("memory"));

    let keys_json = r#"{"pattern": "user:*", "limit": 50}"#;
    let keys_explicit = serde_json::from_str::<tools::redis::KeysInput>(keys_json).unwrap();
    assert_eq!(keys_explicit.pattern, "user:*");
    assert_eq!(keys_explicit.limit, 50);

    // With no fields supplied, the values asserted below are what the type reports.
    let keys_bare = serde_json::from_str::<tools::redis::KeysInput>("{}").unwrap();
    assert_eq!(keys_bare.pattern, "*");
    assert_eq!(keys_bare.limit, 100);
}
/// Converting an `anyhow` error into `McpError` yields the `ToolExecution`
/// variant and keeps the original message in the rendered text.
#[test]
fn test_mcp_error_from_anyhow() {
    let converted: McpError = anyhow::anyhow!("test error").into();
    assert!(matches!(converted, McpError::ToolExecution(_)));

    let rendered = converted.to_string();
    assert!(rendered.contains("test error"));
}
/// Each listed `McpError` variant renders a message containing the expected
/// fragment.
#[test]
fn test_mcp_error_variants() {
    // (error value, text its rendered message must contain)
    let cases = [
        (McpError::Configuration("config issue".to_string()), "config issue"),
        (McpError::CloudApi("cloud issue".to_string()), "cloud issue"),
        (McpError::EnterpriseApi("enterprise issue".to_string()), "enterprise issue"),
        (McpError::ReadOnlyMode, "read-only"),
    ];
    for (err, expected_fragment) in cases {
        assert!(err.to_string().contains(expected_fragment));
    }
}
}