1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
use crate::{error::PskClientError, PskClient};

use std::io::Read;
use std::net::ToSocketAddrs;

/// OpenSSL cipher names a freshly created builder starts with.
///
/// The entries are joined with `:` into an OpenSSL cipher string when a client is built, in the
/// order listed here, so keep the ordering stable when editing.
const DEFAULT_PSK_CIPHERS: &[&str] = &[
    // 256-bit AEAD suites combining the PSK with RSA or DHE key exchange.
    "RSA-PSK-AES256-GCM-SHA384",
    "DHE-PSK-AES256-GCM-SHA384",
    "RSA-PSK-CHACHA20-POLY1305",
    "DHE-PSK-CHACHA20-POLY1305",
    "DHE-PSK-AES256-CCM8",
    "DHE-PSK-AES256-CCM",
    // 256-bit AEAD suites authenticated by the PSK alone.
    "PSK-AES256-GCM-SHA384",
    "PSK-CHACHA20-POLY1305",
    "PSK-AES256-CCM8",
    "PSK-AES256-CCM",
    // 128-bit fallbacks.
    "RSA-PSK-AES128-GCM-SHA256",
    "DHE-PSK-AES128-GCM-SHA256",
];

/// The builder for a PSK client. It collects the target host, cipher list, PSK identity/key and
/// FIPS options, and validates them when the client is built.
///
/// The `Debug` output redacts the key, so the pre-shared secret does not end up in logs.
/// Note that `Default` produces an empty cipher list, unlike `PskClientBuilder::new`.
#[derive(Clone, PartialEq, Default)]
pub struct PskClientBuilder<'a, H: ToSocketAddrs> {
    /// Address to connect to; resolved when the client is built.
    host: H,
    /// OpenSSL cipher names, joined with `:` when the client is built.
    cipher_list: Vec<&'a str>,
    /// PSK identity, if one has been set.
    identity: Option<String>,
    /// Hex-encoded pre-shared key, if one has been set; decoded when the client is built.
    key: Option<Vec<u8>>,
    /// Whether to try enabling OpenSSL FIPS mode when the client is built.
    use_fips: bool,
    /// Whether failing to enable FIPS mode should make the build fail.
    require_fips: bool,
}

// Hand-written instead of derived so the pre-shared key never appears in `{:?}` output.
impl<H: ToSocketAddrs + std::fmt::Debug> std::fmt::Debug for PskClientBuilder<'_, H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PskClientBuilder")
            .field("host", &self.host)
            .field("cipher_list", &self.cipher_list)
            .field("identity", &self.identity)
            .field("key", &self.key.as_ref().map(|_| "<redacted>"))
            .field("use_fips", &self.use_fips)
            .field("require_fips", &self.require_fips)
            .finish()
    }
}

impl<'a, H: ToSocketAddrs> PskClientBuilder<'a, H> {
    /// Create a new `PskClientBuilder` with the default cipher list, a `None` identity and an empty key.
    pub fn new(host: H) -> Self {
        PskClientBuilder {
            host,
            cipher_list: DEFAULT_PSK_CIPHERS.to_vec(),
            identity: None,
            key: None,
            use_fips: false,
            require_fips: false,
        }
    }

    /// Returns a new `PskClient` which can be used to
    /// ```rust
    /// use psk_client::PskClient;
    /// use std::io::Write;
    ///
    /// if let Ok(client) = PskClient::builder("127.0.0.1:4433")
    ///     .identity("some-client")
    ///     .key("1A2B3C4D")
    ///     .build()
    /// {
    ///     if let Ok(mut connection) = client.connect() {
    ///         connection.write_all(b"oing").unwrap();
    ///     }
    /// }
    /// ```
    pub fn build(self) -> Result<PskClient, PskClientError> {
        let host = match self.host.to_socket_addrs() {
            Ok(mut hosts) => match hosts.next() {
                Some(host) => host,
                None => {
                    unreachable!("Impossible to have valid hosts but have none in the interator.")
                }
            },
            Err(e) => return Err(PskClientError::NoValidHost(e)),
        };

        let identity = self
            .identity
            .ok_or(PskClientError::MissingIdentity)
            .map(|id| [&id, "\0"].join(""))?;

        let key = self
            .key
            .ok_or(PskClientError::MissingKey)
            .map(hex::decode)?
            .map_err(PskClientError::UnparseableKeyHex)?;

        if openssl::fips::enable(self.use_fips).is_err()
            || (self.require_fips && !openssl::fips::enabled())
        {
            return Err(PskClientError::FIPSError);
        }

        Ok(PskClient {
            host,
            ciphers: self.cipher_list.join(":"),
            identity,
            key,
        })
    }

    /// Tries to enable FIPS for this client session, requires that OpenSSL has been compiled with
    /// FIPS enabled. Will silently fail if FIPS cannot be enabled.
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .use_fips();
    /// ```
    pub fn use_fips(mut self) -> Self {
        self.use_fips = true;
        self
    }

    /// Tries to enable FIPS for this client session, requires that OpenSSL has been compiled with
    /// FIPS enabled. Will return an error on client build if FIPS cannot be enabled.
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .require_fips();
    /// ```
    pub fn require_fips(mut self) -> Self {
        self.require_fips = true;
        self.use_fips = true;
        self
    }

    /// Sets the identity to use for this session. This is required.
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .identity("some-client");
    /// ```
    #[must_use]
    pub fn identity<S: Into<String>>(mut self, identity: S) -> Self {
        self.identity = Some(identity.into());
        self
    }

    /// Sets the identity to use for this session, taking the identity from an object implementing
    /// `Read`
    /// ```rust
    /// use psk_client::PskClient;
    /// use psk_client::error::PskClientError;
    /// use std::io::Cursor;
    ///
    /// // Create a dummy file
    /// let file = Cursor::new(b"some-identity");
    ///
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .identity_from(file)
    ///     .unwrap();
    /// ```
    pub fn identity_from<R: Read>(self, mut reader: R) -> Result<Self, PskClientError> {
        let mut buffer = String::new();

        reader
            .read_to_string(&mut buffer)
            .map_err(PskClientError::ReadError)?;

        Ok(self.identity(buffer))
    }

    /// Sets the key to use for this session. Must be valid hex. This is required.
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .key("1A2B3C4D");
    /// ```
    #[must_use]
    pub fn key<S: Into<Vec<u8>>>(mut self, key: S) -> Self {
        self.key = Some(key.into());
        self
    }

    /// Sets the key to sue for this session, taking the value from an object which implements `Read`.
    /// Must be valid hex. It will also cleanup non alphanumeric
    /// characters to special sequences (like new lines, trailing whitespace) are not included.
    /// ```rust
    /// use psk_client::PskClient;
    /// use psk_client::error::PskClientError;
    /// use std::io::Cursor;
    ///
    /// // Create a dummy file
    /// let file = Cursor::new(b"a1b2c3d4\n");
    ///
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .key_from(file)
    ///     .unwrap();
    /// ```
    pub fn key_from<R: Read>(self, mut reader: R) -> Result<Self, PskClientError> {
        let mut buffer = Vec::new();

        reader
            .read_to_end(&mut buffer)
            .map_err(PskClientError::ReadError)?;

        buffer = buffer
            .into_iter()
            .filter(u8::is_ascii_alphanumeric)
            .collect();

        Ok(self.key(buffer))
    }

    /// Adds a given cipher to the list of ciphers to send to the server
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .cipher("PSK-AES256-CBC-SHA");
    /// ```
    pub fn cipher(mut self, cipher: &'a str) -> Self {
        self.cipher_list.push(cipher);
        self
    }

    /// Clears the current list of ciphers, which are initialised with a default
    /// PSK set.
    /// ```rust
    /// use psk_client::PskClient;
    /// let builder = PskClient::builder("127.0.0.1:4433")
    ///     .reset_ciphers();
    /// ```
    pub fn reset_ciphers(mut self) -> Self {
        self.cipher_list.clear();
        self
    }
}