1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use crate::cassandra::error::*;
use crate::cassandra::util::{Protected, ProtectedInner};

use crate::cassandra_sys::cass_ssl_add_trusted_cert_n;
use crate::cassandra_sys::cass_ssl_free;
use crate::cassandra_sys::cass_ssl_new;
use crate::cassandra_sys::cass_ssl_set_cert_n;
#[cfg(feature = "early_access_min_tls_version")]
use crate::cassandra_sys::cass_ssl_set_min_protocol_version;
use crate::cassandra_sys::cass_ssl_set_private_key_n;
use crate::cassandra_sys::cass_ssl_set_verify_flags;
use crate::cassandra_sys::CassSsl as _Ssl;
#[cfg(feature = "early_access_min_tls_version")]
pub use crate::cassandra_sys::CassSslTlsVersion as SslTlsVersion;
use crate::cassandra_sys::CassSslVerifyFlags;

use std::os::raw::c_char;

/// The individual SSL verification levels.
///
/// Pass one or more of these to `Ssl::set_verify_flags`; they are OR-ed
/// together into the single bitmask handed to the C driver.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
#[allow(missing_docs)] // Authoritative meanings: the C driver's `CassSslVerifyFlags`.
#[allow(non_camel_case_types)] // Names are traditional.
pub enum SslVerifyFlag {
    /// No verification is performed (`CASS_SSL_VERIFY_NONE`).
    NONE,
    /// The peer certificate must be present and valid
    /// (`CASS_SSL_VERIFY_PEER_CERT`).
    PEER_CERT,
    /// The peer's IP address must match the certificate's common name or one
    /// of its subject alternative names; implies the certificate is present
    /// (`CASS_SSL_VERIFY_PEER_IDENTITY`).
    PEER_IDENTITY,
    /// Maps to `CASS_SSL_VERIFY_PEER_IDENTITY_DNS`. Presumably the hostname
    /// counterpart of `PEER_IDENTITY` — confirm against the C driver docs.
    PEER_IDENTITY_DNS,
}

// Pairs each `SslVerifyFlag` variant with its C driver constant and a string
// name. NOTE(review): the exact API generated (e.g. the `inner()` used by
// `to_bitset`, string conversions) is defined by the macro elsewhere in the
// crate — confirm there.
enhance_nullary_enum!(SslVerifyFlag, CassSslVerifyFlags, {
    (NONE, CASS_SSL_VERIFY_NONE, "NONE"),
    (PEER_CERT, CASS_SSL_VERIFY_PEER_CERT, "PEER_CERT"),
    (PEER_IDENTITY, CASS_SSL_VERIFY_PEER_IDENTITY, "PEER_IDENTITY"),
    (PEER_IDENTITY_DNS, CASS_SSL_VERIFY_PEER_IDENTITY_DNS, "PEER_IDENTITY_DNS"),
});

/// Folds a slice of verification flags into the single `int` bitmask that
/// `cass_ssl_set_verify_flags` expects. An empty slice yields `0`.
fn to_bitset(flags: &[SslVerifyFlag]) -> i32 {
    let mask: u32 = flags
        .iter()
        .map(|flag| flag.inner() as u32)
        .fold(0, |acc, bit| acc | bit);
    mask as i32
}

/// Describes the SSL configuration of a cluster.
///
/// Wraps a raw `CassSsl` pointer obtained from the C driver; the pointer is
/// released via `cass_ssl_free` when this value is dropped.
#[derive(Debug)]
pub struct Ssl(*mut _Ssl);

// SAFETY: The underlying C type has no thread-local state, and forbids only concurrent
// mutation/free: https://datastax.github.io/cpp-driver/topics/#thread-safety
// Every mutating method on `Ssl` in this file takes `&mut self`, and freeing
// happens in `Drop` (which requires ownership), so shared `&Ssl` references
// cannot race on mutation through the safe API. Note that
// `ProtectedInner::inner` hands out the raw pointer from `&self`; crate code
// using it must not mutate the context concurrently.
unsafe impl Send for Ssl {}
unsafe impl Sync for Ssl {}

impl ProtectedInner<*mut _Ssl> for Ssl {
    fn inner(&self) -> *mut _Ssl {
        self.0
    }
}

impl Protected<*mut _Ssl> for Ssl {
    /// Wraps a raw driver pointer; the resulting `Ssl` frees it on drop.
    ///
    /// # Panics
    ///
    /// Panics with "Unexpected null pointer" if `inner` is null.
    fn build(inner: *mut _Ssl) -> Self {
        assert!(!inner.is_null(), "Unexpected null pointer");
        Ssl(inner)
    }
}

impl Drop for Ssl {
    /// Frees a SSL context instance.
    fn drop(&mut self) {
        // SAFETY: `self.0` is the pointer this `Ssl` was constructed with (the
        // field is private to this module), and `Drop` runs at most once, so
        // the context is released exactly once.
        unsafe { cass_ssl_free(self.0) }
    }
}

impl Default for Ssl {
    /// Creates a new SSL context.
    fn default() -> Ssl {
        unsafe { Ssl(cass_ssl_new()) }
    }
}

impl Ssl {
    /// Adds a trusted certificate. This is used to verify
    /// the peer's certificate.
    ///
    /// `cert` is passed to the driver as a pointer/length pair, so it does not
    /// need to be NUL-terminated.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `to_result` turns the driver's status into an error
    /// (presumably any non-`CASS_OK` code).
    pub fn add_trusted_cert(&mut self, cert: impl AsRef<str>) -> Result<&mut Self> {
        let cert = cert.as_ref();
        // SAFETY: `self.0` is the context owned by `self`, and the pointer/length
        // pair comes from a `&str` that stays borrowed for the whole call.
        unsafe {
            let cert_ptr = cert.as_ptr() as *const c_char;
            cass_ssl_add_trusted_cert_n(self.0, cert_ptr, cert.len()).to_result(self)
        }
    }

    /// Sets verification performed on the peer's certificate.
    ///
    /// All given flags are OR-ed together into one bitmask:
    ///
    /// * [`SslVerifyFlag::NONE`] - No verification is performed
    /// * [`SslVerifyFlag::PEER_CERT`] - Certificate is present and valid
    /// * [`SslVerifyFlag::PEER_IDENTITY`] - IP address matches the certificate's
    ///   common name or one of its subject alternative names. This implies the
    ///   certificate is also present.
    /// * [`SslVerifyFlag::PEER_IDENTITY_DNS`] - maps to
    ///   `CASS_SSL_VERIFY_PEER_IDENTITY_DNS`; per the C driver header, the
    ///   hostname (rather than IP) must match the certificate, which requires
    ///   hostname resolution to be enabled.
    ///
    /// <b>Default:</b> [`SslVerifyFlag::PEER_CERT`]
    pub fn set_verify_flags(&mut self, flags: &[SslVerifyFlag]) {
        // SAFETY: `self.0` is the context owned by `self`; the flags are
        // passed by value.
        unsafe { cass_ssl_set_verify_flags(self.0, to_bitset(flags)) }
    }

    /// Set client-side certificate chain. This is used to authenticate
    /// the client on the server-side. This should contain the entire
    /// Certificate chain starting with the certificate itself.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `to_result` turns the driver's status into an error
    /// (presumably any non-`CASS_OK` code).
    pub fn set_cert(&mut self, cert: &str) -> Result<&mut Self> {
        // SAFETY: `self.0` is the context owned by `self`, and the pointer/length
        // pair comes from a `&str` that stays borrowed for the whole call.
        unsafe {
            let cert_ptr = cert.as_ptr() as *const c_char;
            cass_ssl_set_cert_n(self.0, cert_ptr, cert.len()).to_result(self)
        }
    }

    /// Set client-side private key. This is used to authenticate
    /// the client on the server-side.
    ///
    /// `password` is used by the driver to decrypt `key`; the C driver
    /// documents that an empty string may be used when no password is needed.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `to_result` turns the driver's status into an error
    /// (presumably any non-`CASS_OK` code).
    pub fn set_private_key(&mut self, key: &str, password: &str) -> Result<&mut Self> {
        // SAFETY: `self.0` is the context owned by `self`. Each pointer is paired
        // with the length of the same `&str`, so the driver never reads past
        // the end of either buffer.
        unsafe {
            let key_ptr = key.as_ptr() as *const c_char;
            // Must come from `password`: pairing `key`'s pointer with
            // `password.len()` would send the wrong bytes as the password and
            // read out of bounds whenever the password is longer than the key.
            let password_ptr = password.as_ptr() as *const c_char;
            cass_ssl_set_private_key_n(self.0, key_ptr, key.len(), password_ptr, password.len())
                .to_result(self)
        }
    }

    /// Set minimum TLS version. This helps avoid TLS downgrade attacks.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `to_result` turns the driver's status into an error
    /// (presumably any non-`CASS_OK` code).
    #[cfg(feature = "early_access_min_tls_version")]
    pub fn set_min_protocol_version(
        &mut self,
        min_tls_version: SslTlsVersion,
    ) -> Result<&mut Self> {
        // SAFETY: `self.0` is the context owned by `self`; the version is
        // passed by value.
        unsafe { cass_ssl_set_min_protocol_version(self.0, min_tls_version).to_result(self) }
    }
}