// ice_rs/ssl.rs

2use tokio::{io::{AsyncRead, AsyncWrite}, net::TcpStream};
3use tokio_openssl::SslStream;
4use openssl::{ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslVerifyMode}};
5use openssl::x509::*;
6use std::path::Path;
7
8use crate::transport::Transport;
9use crate::errors::*;
10use crate::properties::Properties;
11
12use crate::ssltools::*;
13
14pub struct SslTransport {
15    stream: SslStream<TcpStream>
16}
17
18impl SslTransport {
19    pub async fn new(address: &str, properties: &Properties) -> Result<SslTransport, Box<dyn std::error::Error + Sync + Send>>
20    {
21        let mut builder = SslConnector::builder(SslMethod::tls())?;
22        let ssl_dir = Path::new(properties.get("IceSSL.DefaultDir").ok_or(Box::new(PropertyError::new("IceSSL.DefaultDir")))?);
23
24        configure_ca(&ssl_dir, properties, &mut builder)?;
25        configure_client_certs(&ssl_dir, properties, &mut builder)?;
26        configure_peer_verification(properties, &mut builder)?;
27        configure_ciphers(properties, &mut builder)?;
28        configure_protocol_versions(properties, &mut builder)?;
29
30        // connect
31        let connector = builder.build();
32        let stream = TcpStream::connect(address).await?;
33        let mut stream = SslStream::new(connector.configure()?.into_ssl(address)?, stream)?;
34        std::pin::Pin::new(&mut stream).connect().await.unwrap();
35        Ok(SslTransport {
36            stream
37        })
38    }
39}
40impl AsyncWrite for SslTransport {
41    fn poll_write(
42        self: std::pin::Pin<&mut Self>,
43        cx: &mut std::task::Context<'_>,
44        buf: &[u8],
45    ) -> std::task::Poll<Result<usize, std::io::Error>> {
46        std::pin::Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
47    }
48
49    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
50        std::pin::Pin::new(&mut self.get_mut().stream).poll_flush(cx)
51    }
52
53    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
54        std::pin::Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
55    }
56}
57
58impl AsyncRead for SslTransport {
59    fn poll_read(
60        self: std::pin::Pin<&mut Self>,
61        cx: &mut std::task::Context<'_>,
62        buf: &mut tokio::io::ReadBuf<'_>,
63    ) -> std::task::Poll<std::io::Result<()>> {
64        std::pin::Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
65    }
66}
67
68impl Transport for SslTransport {
69    fn transport_type(&self) -> String {
70        return String::from("ssl");
71    }
72}
73
74
75fn configure_ca(ssl_dir: &Path, properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
76    let ca = match properties.get("IceSSL.CAs") {
77        Some(ca_file) => {
78            // PEM
79            read_pem(ca_file, ssl_dir)?.0.unwrap()
80        }
81        _ => {
82            if let Some(ca_file) = properties.get("IceSSL.CertAuthFile") {
83                // PEM [DEPRECATED]
84                println!("[SSL] Use of deprecated property IceSSL.CertAuthFile");
85                read_pem(ca_file, ssl_dir)?.0.unwrap()
86            } else {
87                return Ok(());
88            }
89        }
90    };
91    let mut store_builder = store::X509StoreBuilder::new()?;
92    store_builder.add_cert(ca)?;
93    let store = store_builder.build();
94    builder.set_verify_cert_store(store)?;
95
96    Ok(())
97}
98
99fn configure_client_certs(ssl_dir: &Path, properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
100{
101    let (cert, pkey) = match properties.get("IceSSL.CertFile") {
102        Some(cert_file) => {
103            if let Some(key_file) = properties.get("IceSSL.KeyFile") {
104                // PEM [DEPRECATED]
105                println!("[SSL] Use of deprecated property IceSSL.KeyFile");
106                let (cert, _) = read_pem(cert_file, ssl_dir)?;
107                let (_, pkey) = read_pem(key_file, ssl_dir)?;
108                (cert.unwrap(), pkey.unwrap())
109            } else {
110                // PKCS12
111                let password = properties.get("IceSSL.Password").ok_or(Box::new(PropertyError::new("Use of IceSSL.CertFile requires IceSSL.Password to be set")))?;
112                let pkcs12 = read_pkcs12(Path::new(cert_file), password, ssl_dir)?;
113                (pkcs12.cert, pkcs12.pkey)
114            }
115        }
116        _ => {
117            return Ok(());
118        }
119    };
120    builder.set_certificate(&cert)?;
121    builder.set_private_key(&pkey)?;
122    Ok(())
123}
124
125fn configure_peer_verification(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
126{
127    match properties.get("IceSSL.VerifyPeer").unwrap_or(&String::from("1")).parse::<u8>()? {
128        0 => builder.set_verify(SslVerifyMode::NONE),
129        _ => builder.set_verify(SslVerifyMode::PEER)
130    }
131    Ok(())
132}
133
134fn configure_ciphers(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
135{
136    if let Some(ciphers) =  properties.get("IceSSL.Ciphers") {
137        builder.set_cipher_list(ciphers)?;
138    }
139    Ok(())
140}
141
142fn configure_protocol_versions(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
143{
144    let mut min_proto = None;
145    let mut max_proto = None;
146    if let Some(protocols) = properties.get("IceSSL.Protocols") {
147        for protocol in protocols.split(",").collect::<Vec<&str>>() {
148            if let Some(protocol) = parse_protocol(protocol) {
149                max_proto = Some(max_protocol(protocol, max_proto));
150                min_proto = Some(min_protocol(protocol, min_proto));
151            }
152        }
153    }
154
155    builder.set_min_proto_version(min_proto)?;
156    builder.set_max_proto_version(max_proto)?;
157    Ok(())
158}