1
2use tokio::{io::{AsyncRead, AsyncWrite}, net::TcpStream};
3use tokio_openssl::SslStream;
4use openssl::{ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslVerifyMode}};
5use openssl::x509::*;
6use std::path::Path;
7
8use crate::transport::Transport;
9use crate::errors::*;
10use crate::properties::Properties;
11
12use crate::ssltools::*;
13
14pub struct SslTransport {
15 stream: SslStream<TcpStream>
16}
17
18impl SslTransport {
19 pub async fn new(address: &str, properties: &Properties) -> Result<SslTransport, Box<dyn std::error::Error + Sync + Send>>
20 {
21 let mut builder = SslConnector::builder(SslMethod::tls())?;
22 let ssl_dir = Path::new(properties.get("IceSSL.DefaultDir").ok_or(Box::new(PropertyError::new("IceSSL.DefaultDir")))?);
23
24 configure_ca(&ssl_dir, properties, &mut builder)?;
25 configure_client_certs(&ssl_dir, properties, &mut builder)?;
26 configure_peer_verification(properties, &mut builder)?;
27 configure_ciphers(properties, &mut builder)?;
28 configure_protocol_versions(properties, &mut builder)?;
29
30 let connector = builder.build();
32 let stream = TcpStream::connect(address).await?;
33 let mut stream = SslStream::new(connector.configure()?.into_ssl(address)?, stream)?;
34 std::pin::Pin::new(&mut stream).connect().await.unwrap();
35 Ok(SslTransport {
36 stream
37 })
38 }
39}
40impl AsyncWrite for SslTransport {
41 fn poll_write(
42 self: std::pin::Pin<&mut Self>,
43 cx: &mut std::task::Context<'_>,
44 buf: &[u8],
45 ) -> std::task::Poll<Result<usize, std::io::Error>> {
46 std::pin::Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
47 }
48
49 fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
50 std::pin::Pin::new(&mut self.get_mut().stream).poll_flush(cx)
51 }
52
53 fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
54 std::pin::Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
55 }
56}
57
58impl AsyncRead for SslTransport {
59 fn poll_read(
60 self: std::pin::Pin<&mut Self>,
61 cx: &mut std::task::Context<'_>,
62 buf: &mut tokio::io::ReadBuf<'_>,
63 ) -> std::task::Poll<std::io::Result<()>> {
64 std::pin::Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
65 }
66}
67
68impl Transport for SslTransport {
69 fn transport_type(&self) -> String {
70 return String::from("ssl");
71 }
72}
73
74
75fn configure_ca(ssl_dir: &Path, properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
76 let ca = match properties.get("IceSSL.CAs") {
77 Some(ca_file) => {
78 read_pem(ca_file, ssl_dir)?.0.unwrap()
80 }
81 _ => {
82 if let Some(ca_file) = properties.get("IceSSL.CertAuthFile") {
83 println!("[SSL] Use of deprecated property IceSSL.CertAuthFile");
85 read_pem(ca_file, ssl_dir)?.0.unwrap()
86 } else {
87 return Ok(());
88 }
89 }
90 };
91 let mut store_builder = store::X509StoreBuilder::new()?;
92 store_builder.add_cert(ca)?;
93 let store = store_builder.build();
94 builder.set_verify_cert_store(store)?;
95
96 Ok(())
97}
98
99fn configure_client_certs(ssl_dir: &Path, properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
100{
101 let (cert, pkey) = match properties.get("IceSSL.CertFile") {
102 Some(cert_file) => {
103 if let Some(key_file) = properties.get("IceSSL.KeyFile") {
104 println!("[SSL] Use of deprecated property IceSSL.KeyFile");
106 let (cert, _) = read_pem(cert_file, ssl_dir)?;
107 let (_, pkey) = read_pem(key_file, ssl_dir)?;
108 (cert.unwrap(), pkey.unwrap())
109 } else {
110 let password = properties.get("IceSSL.Password").ok_or(Box::new(PropertyError::new("Use of IceSSL.CertFile requires IceSSL.Password to be set")))?;
112 let pkcs12 = read_pkcs12(Path::new(cert_file), password, ssl_dir)?;
113 (pkcs12.cert, pkcs12.pkey)
114 }
115 }
116 _ => {
117 return Ok(());
118 }
119 };
120 builder.set_certificate(&cert)?;
121 builder.set_private_key(&pkey)?;
122 Ok(())
123}
124
125fn configure_peer_verification(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
126{
127 match properties.get("IceSSL.VerifyPeer").unwrap_or(&String::from("1")).parse::<u8>()? {
128 0 => builder.set_verify(SslVerifyMode::NONE),
129 _ => builder.set_verify(SslVerifyMode::PEER)
130 }
131 Ok(())
132}
133
134fn configure_ciphers(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
135{
136 if let Some(ciphers) = properties.get("IceSSL.Ciphers") {
137 builder.set_cipher_list(ciphers)?;
138 }
139 Ok(())
140}
141
142fn configure_protocol_versions(properties: &Properties, builder: &mut SslConnectorBuilder) -> Result<(), Box<dyn std::error::Error + Sync + Send>>
143{
144 let mut min_proto = None;
145 let mut max_proto = None;
146 if let Some(protocols) = properties.get("IceSSL.Protocols") {
147 for protocol in protocols.split(",").collect::<Vec<&str>>() {
148 if let Some(protocol) = parse_protocol(protocol) {
149 max_proto = Some(max_protocol(protocol, max_proto));
150 min_proto = Some(min_protocol(protocol, min_proto));
151 }
152 }
153 }
154
155 builder.set_min_proto_version(min_proto)?;
156 builder.set_max_proto_version(max_proto)?;
157 Ok(())
158}