// ate_comms/client.rs

1use std::io;
2#[allow(unused_imports)]
3use std::ops::DerefMut;
4use wasmer_bus_ws::prelude::*;
5use tokio::io::AsyncRead;
6use tokio::io::AsyncWrite;
7#[allow(unused_imports, dead_code)]
8use tracing::{debug, error, info, trace, warn};
9use ate_crypto::KeySize;
10use ate_crypto::NodeId;
11
12use crate::HelloMetadata;
13use super::protocol::StreamRx;
14use super::protocol::StreamTx;
15use super::CertificateValidation;
16use super::certificate_validation::GLOBAL_CERTIFICATES;
17
18pub struct StreamClient
19{
20    rx: StreamRx,
21    tx: StreamTx,
22    hello: HelloMetadata,
23}
24
25pub use super::security::StreamSecurity;
26
27impl StreamClient
28{
29    pub async fn connect(connect_url: url::Url, path: &str, security: StreamSecurity, #[allow(unused)] dns_server: Option<String>, #[allow(unused)] dns_sec: bool) -> Result<Self, Box<dyn std::error::Error>>
30    {
31        let https = match connect_url.scheme() {
32            "https" => true,
33            "wss" => true,
34            _ => false,
35        };
36
37        let host = connect_url
38            .host()
39            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "URL does not have a host component"))?;
40        let domain = match &host {
41                url::Host::Domain(a) => Some(*a),
42                url::Host::Ipv4(ip) if ip.is_loopback() => Some("localhost"),
43                url::Host::Ipv6(ip) if ip.is_loopback() => Some("localhost"),
44                _ => None
45            }
46            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "URL does not have a domain component"))?;
47
48        #[allow(unused_variables)]
49        let mut validation = {
50            let mut certs = Vec::new();
51
52            #[cfg(feature = "dns")]
53            #[cfg(not(target_family = "wasm"))]
54            {
55                let dns_server = dns_server.as_ref().map(|a| a.as_ref()).unwrap_or("8.8.8.8");
56                let mut dns = crate::Dns::connect(dns_server, dns_sec).await;
57                for cert in dns.dns_certs(domain).await {
58                    certs.push(cert);
59                }
60            }
61            for cert in GLOBAL_CERTIFICATES.read().unwrap().iter() {
62                if certs.contains(cert) == false {
63                    certs.push(cert.clone());
64                }
65            }
66            if certs.len() > 0 {
67                CertificateValidation::AllowedCertificates(certs)
68            } else if domain == "localhost" {
69                CertificateValidation::AllowAll
70            } else {
71                CertificateValidation::DenyAll
72            }
73        };
74        #[allow(unused_assignments)]
75        if domain == "localhost" || security.quantum_encryption(https) == false {
76            validation = CertificateValidation::AllowAll;
77        }
78
79        let socket = SocketBuilder::new(connect_url.clone())
80            .open()
81            .await?;
82            
83        let (tx, rx) = socket.split(); 
84        let tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static> = Box::new(tx);
85        let rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static> = Box::new(rx);
86
87        // We only encrypt if it actually has a certificate (otherwise
88        // a simple man-in-the-middle could intercept anyway)
89        let key_size = if security.quantum_encryption(https) == true {
90            Some(KeySize::Bit192)
91        } else {
92            None
93        };
94
95        // Say hello
96        let node_id = NodeId::generate_client_id();
97        let (mut proto, hello_metadata) = super::hello::mesh_hello_exchange_sender(
98            rx,
99            tx,
100            node_id,
101            path.to_string(),
102            domain.to_string(),
103            key_size,
104        )
105        .await?;
106
107        // If we are using wire encryption then exchange secrets
108        #[cfg(feature = "quantum")]
109        let ek = match hello_metadata.encryption {
110            Some(key_size) => Some(
111                super::key_exchange::mesh_key_exchange_sender(
112                    proto.deref_mut(),
113                    key_size,
114                    validation,
115                )
116                .await?,
117            ),
118            None => None,
119        };
120        #[cfg(not(feature = "quantum"))]
121        let ek = None;
122
123        // Create the rx and tx message streams
124        let (rx, tx) = proto.split(ek);
125        Ok(
126            Self {
127                rx,
128                tx,
129                hello: hello_metadata,
130            }
131        )
132    }
133
134    pub fn split(self) -> (StreamRx, StreamTx)
135    {
136        (
137            self.rx,
138            self.tx,
139        )
140    }
141
142    pub fn hello_metadata(&self) -> &HelloMetadata {
143        &self.hello
144    }
145}