1use std::io;
2#[allow(unused_imports)]
3use std::ops::DerefMut;
4use wasmer_bus_ws::prelude::*;
5use tokio::io::AsyncRead;
6use tokio::io::AsyncWrite;
7#[allow(unused_imports, dead_code)]
8use tracing::{debug, error, info, trace, warn};
9use ate_crypto::KeySize;
10use ate_crypto::NodeId;
11
12use crate::HelloMetadata;
13use super::protocol::StreamRx;
14use super::protocol::StreamTx;
15use super::CertificateValidation;
16use super::certificate_validation::GLOBAL_CERTIFICATES;
17
18pub struct StreamClient
19{
20 rx: StreamRx,
21 tx: StreamTx,
22 hello: HelloMetadata,
23}
24
25pub use super::security::StreamSecurity;
26
27impl StreamClient
28{
29 pub async fn connect(connect_url: url::Url, path: &str, security: StreamSecurity, #[allow(unused)] dns_server: Option<String>, #[allow(unused)] dns_sec: bool) -> Result<Self, Box<dyn std::error::Error>>
30 {
31 let https = match connect_url.scheme() {
32 "https" => true,
33 "wss" => true,
34 _ => false,
35 };
36
37 let host = connect_url
38 .host()
39 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "URL does not have a host component"))?;
40 let domain = match &host {
41 url::Host::Domain(a) => Some(*a),
42 url::Host::Ipv4(ip) if ip.is_loopback() => Some("localhost"),
43 url::Host::Ipv6(ip) if ip.is_loopback() => Some("localhost"),
44 _ => None
45 }
46 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "URL does not have a domain component"))?;
47
48 #[allow(unused_variables)]
49 let mut validation = {
50 let mut certs = Vec::new();
51
52 #[cfg(feature = "dns")]
53 #[cfg(not(target_family = "wasm"))]
54 {
55 let dns_server = dns_server.as_ref().map(|a| a.as_ref()).unwrap_or("8.8.8.8");
56 let mut dns = crate::Dns::connect(dns_server, dns_sec).await;
57 for cert in dns.dns_certs(domain).await {
58 certs.push(cert);
59 }
60 }
61 for cert in GLOBAL_CERTIFICATES.read().unwrap().iter() {
62 if certs.contains(cert) == false {
63 certs.push(cert.clone());
64 }
65 }
66 if certs.len() > 0 {
67 CertificateValidation::AllowedCertificates(certs)
68 } else if domain == "localhost" {
69 CertificateValidation::AllowAll
70 } else {
71 CertificateValidation::DenyAll
72 }
73 };
74 #[allow(unused_assignments)]
75 if domain == "localhost" || security.quantum_encryption(https) == false {
76 validation = CertificateValidation::AllowAll;
77 }
78
79 let socket = SocketBuilder::new(connect_url.clone())
80 .open()
81 .await?;
82
83 let (tx, rx) = socket.split();
84 let tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static> = Box::new(tx);
85 let rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static> = Box::new(rx);
86
87 let key_size = if security.quantum_encryption(https) == true {
90 Some(KeySize::Bit192)
91 } else {
92 None
93 };
94
95 let node_id = NodeId::generate_client_id();
97 let (mut proto, hello_metadata) = super::hello::mesh_hello_exchange_sender(
98 rx,
99 tx,
100 node_id,
101 path.to_string(),
102 domain.to_string(),
103 key_size,
104 )
105 .await?;
106
107 #[cfg(feature = "quantum")]
109 let ek = match hello_metadata.encryption {
110 Some(key_size) => Some(
111 super::key_exchange::mesh_key_exchange_sender(
112 proto.deref_mut(),
113 key_size,
114 validation,
115 )
116 .await?,
117 ),
118 None => None,
119 };
120 #[cfg(not(feature = "quantum"))]
121 let ek = None;
122
123 let (rx, tx) = proto.split(ek);
125 Ok(
126 Self {
127 rx,
128 tx,
129 hello: hello_metadata,
130 }
131 )
132 }
133
134 pub fn split(self) -> (StreamRx, StreamTx)
135 {
136 (
137 self.rx,
138 self.tx,
139 )
140 }
141
142 pub fn hello_metadata(&self) -> &HelloMetadata {
143 &self.hello
144 }
145}