Skip to main content

snap7_client/
tls.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use tokio::net::TcpStream;
4use tokio_rustls::rustls::pki_types::ServerName;
5use tokio_rustls::rustls::{ClientConfig, RootCertStore};
6use tokio_rustls::TlsConnector;
7
8use crate::error::Error;
9
10pub type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
11
12/// Build a `rustls` `ClientConfig` with webpki system roots.
13/// If `extra_ca_der` is provided, it is added as a trusted CA certificate.
14pub fn make_tls_config(
15    extra_ca_der: Option<&[u8]>,
16) -> std::result::Result<Arc<ClientConfig>, Error> {
17    // Install the ring crypto provider if no process-level provider has been set yet.
18    // `install_default` returns an error when already installed; we ignore that case.
19    let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
20
21    let mut root_store = RootCertStore::empty();
22    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
23    if let Some(ca_bytes) = extra_ca_der {
24        let ca_cert = tokio_rustls::rustls::pki_types::CertificateDer::from(ca_bytes.to_vec());
25        root_store.add(ca_cert).map_err(|e| {
26            Error::Io(std::io::Error::new(
27                std::io::ErrorKind::InvalidData,
28                format!("invalid CA cert: {e}"),
29            ))
30        })?;
31    }
32    let config = ClientConfig::builder()
33        .with_root_certificates(root_store)
34        .with_no_client_auth();
35    Ok(Arc::new(config))
36}
37
38/// Connect a TLS stream to `addr` with SNI `server_name`.
39pub async fn tls_connect(
40    addr: SocketAddr,
41    server_name: &str,
42    extra_ca_der: Option<&[u8]>,
43) -> std::result::Result<TlsStream, Error> {
44    let config = make_tls_config(extra_ca_der)?;
45    let connector = TlsConnector::from(config);
46    let tcp = TcpStream::connect(addr).await.map_err(Error::Io)?;
47    let server_name = ServerName::try_from(server_name.to_string()).map_err(|e| {
48        Error::Io(std::io::Error::new(
49            std::io::ErrorKind::InvalidInput,
50            format!("invalid server name: {e}"),
51        ))
52    })?;
53    connector.connect(server_name, tcp).await.map_err(Error::Io)
54}
55
#[cfg(test)]
mod tests {
    // Brings `make_tls_config` and the parent's `ServerName` import into scope,
    // so the tests exercise exactly the types the production code uses.
    use super::*;

    /// Building a config without an extra CA must succeed.
    #[test]
    fn tls_config_builds_with_system_roots() {
        let _cfg = make_tls_config(None).unwrap();
    }

    /// Bytes that are not a DER certificate must be reported as an error
    /// instead of being silently accepted as a trust anchor.
    #[test]
    fn tls_config_rejects_invalid_extra_ca() {
        let bogus: &[u8] = &b"not a certificate"[..];
        assert!(make_tls_config(Some(bogus)).is_err());
    }

    /// A well-formed DNS name converts into a `ServerName`.
    #[test]
    fn tls_config_server_name_parses() {
        let name = ServerName::try_from("plc.example.com".to_string());
        assert!(name.is_ok());
    }

    /// A string that is neither a DNS name nor an IP address is rejected.
    #[test]
    fn tls_config_server_name_rejects_invalid() {
        let name = ServerName::try_from("not a valid name".to_string());
        assert!(name.is_err());
    }
}