generic-db-observer 0.3.1

A generic observer that monitors a database for changes and triggers a change in some Event Subject.
Documentation
//! This module lets a Redis connection act as a message listener by implementing `MsgListener` for [`Redis`].

pub use redis::RedisError;

use std::collections::HashSet;

use crate::observer::MsgListener;
use futures_util::StreamExt;
use redis::{Commands, Msg, ToRedisArgs};
use secrecy::SecretString;
use std::fmt;
use tracing::{debug, info, instrument};

/// Configuration to access and authenticate to the Redis message broker.
///
/// NOTE(review): this type derives `Debug`; logging it relies on `SecretString`'s own
/// `Debug` impl to keep `redis_password` redacted — confirm against the `secrecy` version in use.
#[derive(Clone, Debug)]
pub struct RedisConfig {
    /// Host under which the Redis install can be accessed.
    ///
    /// Inserted verbatim after the `@` of the `redis://` URL built by
    /// `RedisConfig::redis_auth_address`, so it may carry a port (`host:port`).
    pub redis_host: String,
    /// Password to authenticate with redis.
    ///
    /// Percent-encoded before being embedded in the connection URL.
    pub redis_password: SecretString,
    /// The name of the set that contains the data that is being tracked.
    ///
    /// Also used to build the keyspace-notification channel the listener subscribes to.
    pub redis_set_name: String,
}

/// Represents a connection to Redis. We use `PubSub` to receive information about whether there are updates to the monitored set, but then have to get the members of the set using a SMEMBERS command.
/// Since keeping the `PubSub` channel open locks the redis Connection, we need two separate connections, one to listen to `PubSub` notifications, and one to make requests to Redis.
// NOTE(review): presumably `Debug` is skipped because the wrapped connection handles do not
// provide a useful `Debug` impl — confirm. `Redis` implements `Debug` manually and omits this field.
#[allow(missing_debug_implementations)]
pub struct RedisConnection {
    /// Connection used to make requests to Redis (synchronous; e.g. `SMEMBERS`, `DEL`, `SADD`).
    pub request_connection: redis::Connection,
    /// Connection used to listen to PubSub notifications (asynchronous).
    pub pubsub_connection: redis::aio::PubSub,
}

/// The main Redis message broker structure.
///
/// Implements `MsgListener`; its `Debug` impl is written by hand and prints only `config`.
pub struct Redis {
    /// Connection to the Redis Database (one request connection plus one `PubSub` connection).
    pub connection: RedisConnection,
    /// Configuration to authorize and reach the Redis data.
    pub config: RedisConfig,
}

impl RedisConfig {
    /// Percent-encode `s` so it can be embedded in the userinfo part of a URL.
    ///
    /// `form_urlencoded::byte_serialize` follows `application/x-www-form-urlencoded` rules and
    /// writes a space as `+`. Inside a URL's userinfo a `+` is a literal plus sign, so a password
    /// containing spaces would be altered. Every literal `+` in the input has already been
    /// encoded as `%2B`, so any remaining `+` stands for a space and is rewritten as `%20`.
    fn url_encode(s: impl AsRef<str>) -> String {
        url::form_urlencoded::byte_serialize(s.as_ref().as_bytes())
            .collect::<String>()
            .replace('+', "%20")
    }

    /// Create a redis address with authorization info, based on the configuration information.
    ///
    /// The address has the form `redis://:<percent-encoded password>@<redis_host>` and is
    /// returned as a [`SecretString`] because it embeds the password.
    #[must_use]
    pub fn redis_auth_address(&self) -> SecretString {
        use secrecy::ExposeSecret;
        // Borrow the password instead of cloning it, avoiding an extra copy of the secret.
        let encoded_password = Self::url_encode(self.redis_password.expose_secret());
        let auth_address = format!("redis://:{encoded_password}@{}", self.redis_host);

        SecretString::new(auth_address)
    }
}

impl MsgListener for Redis {
    type Config = RedisConfig;
    type Connection = RedisConnection;
    type Message = redis::Msg;
    type Error = redis::RedisError;

    /// Open the request connection and the `PubSub` connection described by `config`.
    ///
    /// A single [`redis::Client`] is built from the authenticated address and used to open both
    /// connections.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the address cannot be parsed or either connection cannot be
    /// established.
    #[instrument]
    async fn connect(config: &Self::Config) -> Result<Self, Self::Error> {
        use secrecy::ExposeSecret;

        // Build the secret address once; a `Client` can open any number of connections.
        let auth_address = config.redis_auth_address();
        let client = redis::Client::open(auth_address.expose_secret().as_str())?;

        // NOTE(review): `get_connection` is synchronous and blocks the executor thread while
        // connecting; it is required here because the field type is `redis::Connection`.
        let request_connection = client.get_connection()?;
        info!("Redis Request connection established");

        let pubsub_connection = client.get_async_connection().await?.into_pubsub();
        info!("Redis PubSub connection established");

        Ok(Self {
            config: config.clone(),
            connection: RedisConnection {
                request_connection,
                pubsub_connection,
            },
        })
    }

    /// Subscribe to keyspace notifications for the tracked set.
    ///
    /// Listens on `__keyspace@0__:<redis_set_name>`, i.e. events for that key in database 0.
    /// Redis only publishes these when keyspace notifications are enabled on the server
    /// (`notify-keyspace-events`).
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the `SUBSCRIBE` command fails.
    async fn subscribe_to_notifications(&mut self) -> Result<(), Self::Error> {
        let pubsub_channel_name = format!("__keyspace@0__:{}", self.config.redis_set_name);
        self.connection
            .pubsub_connection
            .subscribe(pubsub_channel_name)
            .await
    }

    /// Wait for the next message on the subscribed `PubSub` channel(s).
    ///
    /// This does not subscribe by itself; call `subscribe_to_notifications` first. Returns
    /// `None` if the message stream ends before a message arrives.
    async fn notification_stream(&mut self) -> Option<Msg> {
        self.connection.pubsub_connection.on_message().next().await
    }

    /// Get all members of the referenced Redis set via `SMEMBERS`.
    ///
    /// A key that does not exist is reported by Redis as an empty set, so an empty result is a
    /// legitimate state rather than an error.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the query fails (for example a lost connection, or the key
    /// holding a non-set value) instead of masking the failure as an empty set.
    // The `debug!` expansion presumably inflates the measured complexity of this short function.
    #[allow(clippy::cognitive_complexity)]
    async fn get_current_state(&mut self) -> Result<HashSet<String>, Self::Error> {
        let set_members = redis::cmd("SMEMBERS")
            .arg(self.config.redis_set_name.as_str())
            .query::<HashSet<String>>(&mut self.connection.request_connection)?;

        debug!("Set members retrieved from Redis: {set_members:?}");
        Ok(set_members)
    }
}

impl fmt::Debug for Redis {
    /// Formats the broker for diagnostics, exposing only its configuration.
    ///
    /// The connection handles are omitted; the output ends with `..` to signal that.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut view = formatter.debug_struct("Redis");
        view.field("config", &self.config);
        view.finish_non_exhaustive()
    }
}

impl Redis {
    /// Deletes a key from Redis using the request connection.
    ///
    /// Accepts any key type convertible to Redis arguments (e.g. `&str`, `String`), matching
    /// [`Redis::sadd`] and [`Redis::srem`].
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the `DEL` command fails.
    pub fn del<K: ToRedisArgs>(&mut self, key: K) -> redis::RedisResult<()> {
        self.connection.request_connection.del(key)
    }

    /// Adds a member to a set using the request connection.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the `SADD` command fails.
    pub fn sadd<S: ToRedisArgs, T: ToRedisArgs>(
        &mut self,
        key: S,
        member: T,
    ) -> redis::RedisResult<()> {
        self.connection.request_connection.sadd(key, member)
    }

    /// Removes a member from a set using the request connection.
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the `SREM` command fails.
    pub fn srem<S: ToRedisArgs, T: ToRedisArgs>(
        &mut self,
        key: S,
        member: T,
    ) -> redis::RedisResult<()> {
        self.connection.request_connection.srem(key, member)
    }
}

#[cfg(test)]
mod tests {
    use kube::{Client, Config};
    use kube_discovery::LabelSelector;

    use super::*;
    use std::time::Duration;

    /// Label selector identifying the CI Redis workload and service.
    const REDIS_LABELS: &str = "app=redis-twitch-observer,environment=ci";
    /// Workload environment variable whose secret value is the Redis password.
    const REDIS_PASSWORD_ENV_VAR: &str = "REDIS_PASSWORD";
    /// Set used as the tracked key during tests.
    const REDIS_TESTING_SET: &str = "tracked_channels";

    /// Loads the Redis connection configuration (host and password) from the Kubernetes cluster.
    ///
    /// # Errors
    ///
    /// Returns an error if the kube config cannot be inferred, the client cannot be built, or
    /// the secret or service lookup fails.
    #[allow(clippy::significant_drop_tightening)] // false positive
    async fn load_redis_conn_config() -> Result<(String, SecretString), Box<dyn std::error::Error>>
    {
        let kube_config = Config::infer().await?;

        // `Client::try_from` consumes the config, which is still needed for the service lookup.
        let kube_client = Client::try_from(kube_config.clone())?;

        let password = LabelSelector(REDIS_LABELS)
            .load_secret_value_through_workload(&kube_client, REDIS_PASSWORD_ENV_VAR)
            .await?;

        // The host is interpolated directly into a `redis://` URL, so it is expected in
        // `host:port` form, e.g. `{cluster_ip}:{nodeport}`.
        let redis_host = LabelSelector(REDIS_LABELS)
            .load_service_host(&kube_config, &kube_client)
            .await?;

        Ok((redis_host, password))
    }

    #[rstest::rstest]
    #[test_log::test(tokio::test)]
    #[timeout(Duration::from_secs(15))]
    async fn receive_notification() {
        let (redis_host, redis_password) = load_redis_conn_config()
            .await
            .expect("failed to load the Redis connection config from the cluster");
        let redis_config = RedisConfig {
            redis_host,
            redis_password,
            redis_set_name: REDIS_TESTING_SET.to_string(),
        };
        let mut redis_broker = Redis::connect(&redis_config)
            .await
            .expect("failed to connect to Redis");
        info!("Established Redis Connection");

        // A notification is only actually received if the added data is new, so we clear the set first
        redis_broker
            .del(REDIS_TESTING_SET)
            .expect("failed to clear the testing set");

        redis_broker
            .subscribe_to_notifications()
            .await
            .expect("failed to subscribe to keyspace notifications");
        info!("Subscribed to notifications");

        // Add an element to the redis set
        redis_broker
            .sadd(REDIS_TESTING_SET, "testdata")
            .expect("failed to add sample data to the set");
        info!("Added sample data to set");

        // Test if a notification is received; `None` means the stream closed without one.
        let notification = redis_broker.notification_stream().await;
        assert!(
            notification.is_some(),
            "PubSub stream ended before a notification arrived"
        );
        info!("Notification received");
    }
}