// pooly 0.2.1
//
// A protobuf-to-Postgres adapter and connection-pooling middleware.
// Documentation
use std::collections::HashSet;

use serde::{Deserialize, Serialize};

use crate::models::utils::wildcards::WildcardPattern;
use crate::models::versioning::updatable::{StringSetCommand, Updatable, WildcardPatternSetCommand};

/// An access-control rule deciding whether a client may use a connection id.
pub trait ConnectionIdAccessEntry {
    /// Returns `true` when `client_id` is permitted to access `connection_id`
    /// according to this entry, and `false` otherwise.
    fn is_allowed(&self, client_id: &str, connection_id: &str) -> bool;
}

/// An access entry that grants one client access to an explicit set of
/// connection ids, compared by exact string equality.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct LiteralConnectionIdAccessEntry {
    /// Client this entry applies to; any other client id is rejected.
    client_id: String,
    /// Exact connection ids the client is allowed to use.
    connection_ids: HashSet<String>,
}

impl LiteralConnectionIdAccessEntry {
    /// Builds an entry that lets `client_id` use every id in `connection_ids`.
    pub fn new(client_id: &str, connection_ids: HashSet<String>) -> LiteralConnectionIdAccessEntry {
        let client_id = client_id.to_string();
        LiteralConnectionIdAccessEntry { client_id, connection_ids }
    }

    /// Convenience constructor for an entry that allows exactly one connection id.
    pub fn one(client_id: &str, connection_id: &str) -> LiteralConnectionIdAccessEntry {
        let single: HashSet<String> = std::iter::once(connection_id.to_string()).collect();
        Self::new(client_id, single)
    }
}

impl ConnectionIdAccessEntry for LiteralConnectionIdAccessEntry {
    /// Allows access only when `client_id` equals this entry's client and
    /// `connection_id` is one of the listed connection ids.
    ///
    /// An entry with an empty id set therefore never allows anything.
    fn is_allowed(&self, client_id: &str, connection_id: &str) -> bool {
        // `HashSet::contains` already returns `false` for an empty set, so no
        // separate emptiness check is needed.
        client_id == self.client_id && self.connection_ids.contains(connection_id)
    }
}

impl Updatable<StringSetCommand> for LiteralConnectionIdAccessEntry {
    /// Returns the client id, which serves as this entry's identifier.
    fn get_id(&self) -> &str {
        self.client_id.as_str()
    }

    /// Returns a new entry for the same client whose connection ids are the
    /// result of applying `update` to the current set; `self` is not modified.
    fn accept(&self, update: StringSetCommand) -> Self {
        let connection_ids = update.apply(&self.connection_ids);
        let client_id = self.client_id.clone();
        Self { client_id, connection_ids }
    }
}

/// An access entry that grants one client access to every connection id
/// matched by at least one of its wildcard patterns.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WildcardPatternConnectionIdAccessEntry {
    /// Client this entry applies to; any other client id is rejected.
    client_id: String,
    /// Patterns tested against the requested connection id.
    patterns: HashSet<WildcardPattern>,
}

impl WildcardPatternConnectionIdAccessEntry {

    pub fn new(client_id: &str,
               patterns: HashSet<WildcardPattern>) -> WildcardPatternConnectionIdAccessEntry {
        WildcardPatternConnectionIdAccessEntry {
            client_id: client_id.into(),
            patterns
        }
    }

}

impl ConnectionIdAccessEntry for WildcardPatternConnectionIdAccessEntry {
    /// Allows access only when `client_id` equals this entry's client and at
    /// least one configured pattern matches `connection_id`.
    ///
    /// An entry with no patterns therefore never allows anything.
    fn is_allowed(&self, client_id: &str, connection_id: &str) -> bool {
        client_id == self.client_id
            && self.patterns.iter().any(|pattern| pattern.matches(connection_id))
    }
}

impl Updatable<WildcardPatternSetCommand> for WildcardPatternConnectionIdAccessEntry {
    /// Returns the client id, which serves as this entry's identifier.
    fn get_id(&self) -> &str {
        self.client_id.as_str()
    }

    /// Returns a new entry for the same client whose patterns are the result
    /// of applying `update` to the current pattern set; `self` is not modified.
    fn accept(&self, update: WildcardPatternSetCommand) -> Self {
        let patterns = update.apply(&self.patterns);
        Self {
            client_id: self.client_id.clone(),
            patterns,
        }
    }
}


#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::models::auth::access::{ConnectionIdAccessEntry, LiteralConnectionIdAccessEntry, WildcardPatternConnectionIdAccessEntry};
    use crate::models::utils::wildcards::WildcardPattern;
    use crate::models::versioning::versioned::Versioned;

    const CLIENT_ID: &str = "client-id-1";
    const NOT_CLIENT_ID: &str = "not-client-id-1";

    const FIRST_CONNECTION_ID: &str = "first-connection-id";
    const SECOND_CONNECTION_ID: &str = "second-connection-id";
    const THIRD_CONNECTION_ID: &str = "third___connection___id";
    const FOURTH_CONNECTION_ID: &str = "fourth-connection-id";

    const FIRST_CONN_ID: &str = "first_conn_id";

    /// A literal entry allows exactly the ids it lists, and rejects an id
    /// once it has been removed from the set.
    #[test]
    fn test_matches_correctly_exact() {
        let all_ids = get_should_match();

        let ace = LiteralConnectionIdAccessEntry::new(CLIENT_ID, all_ids.clone());

        for connection_id in &all_ids {
            assert!(ace.is_allowed(CLIENT_ID, connection_id));
        }

        let mut without_first = all_ids.clone();
        assert!(without_first.remove(FIRST_CONNECTION_ID));

        let ace = LiteralConnectionIdAccessEntry::new(CLIENT_ID, without_first);

        // Iterate over the *full* set so the removed id is actually checked and
        // must be rejected; iterating the reduced set would never exercise it.
        for connection_id in &all_ids {
            assert_eq!(ace.is_allowed(CLIENT_ID, connection_id),
                       connection_id.as_str() != FIRST_CONNECTION_ID);
        }
    }

    /// `one` builds an entry that allows a single id for a single client.
    #[test]
    fn test_one_allows_only_its_connection() {
        let ace = LiteralConnectionIdAccessEntry::one(CLIENT_ID, FIRST_CONNECTION_ID);

        assert!(ace.is_allowed(CLIENT_ID, FIRST_CONNECTION_ID));
        assert!(!ace.is_allowed(CLIENT_ID, SECOND_CONNECTION_ID));
        assert!(!ace.is_allowed(NOT_CLIENT_ID, FIRST_CONNECTION_ID));
    }

    #[test]
    fn test_matches_by_patterns() {
        let mut patterns = HashSet::new();

        patterns.insert(WildcardPattern::parse("*connection-id").unwrap());
        patterns.insert(WildcardPattern::parse("*connection*").unwrap());
        patterns.insert(WildcardPattern::parse("first*").unwrap());

        let ace = WildcardPatternConnectionIdAccessEntry::new(
            CLIENT_ID, patterns);

        for connection_id in get_should_match() {
            assert!(ace.is_allowed(CLIENT_ID, &connection_id));
        }
    }

    #[test]
    fn test_does_not_match_on_client_id_mismatch() {
        let mut patterns = HashSet::new();

        patterns.insert(WildcardPattern::parse("*connection-id").unwrap());
        patterns.insert(WildcardPattern::parse("*connection*").unwrap());
        patterns.insert(WildcardPattern::parse("first*").unwrap());

        let should_match = get_should_match();

        let wildcard_ace = WildcardPatternConnectionIdAccessEntry::new(
            NOT_CLIENT_ID, patterns);
        let literal_ace = LiteralConnectionIdAccessEntry::new(
            NOT_CLIENT_ID, should_match.clone());

        for connection_id in &should_match {
            assert!(!wildcard_ace.is_allowed(CLIENT_ID, connection_id));
            assert!(!literal_ace.is_allowed(CLIENT_ID, connection_id));
        }
    }

    #[test]
    fn empty_never_matches() {
        let wildcard_ace = WildcardPatternConnectionIdAccessEntry::new(
            CLIENT_ID, HashSet::new());

        let literal_ace = LiteralConnectionIdAccessEntry::new(
            CLIENT_ID,
            HashSet::new());

        for connection_id in get_should_match() {
            assert!(!wildcard_ace.is_allowed(CLIENT_ID, &connection_id));
            assert!(!literal_ace.is_allowed(CLIENT_ID, &connection_id));
        }
    }

    /// Every connection id used by these tests; each is matched by at least
    /// one of the wildcard patterns used above.
    fn get_should_match() -> HashSet<String> {
        let mut response = HashSet::new();

        response.insert(FIRST_CONNECTION_ID.to_string());
        response.insert(SECOND_CONNECTION_ID.to_string());
        response.insert(THIRD_CONNECTION_ID.to_string());
        response.insert(FOURTH_CONNECTION_ID.to_string());
        response.insert(FIRST_CONN_ID.to_string());

        response
    }

}