nostr_sdk_net/native/
mod.rs1use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::StreamExt;
13use thiserror::Error;
14use tokio::net::TcpStream;
15use tokio_rustls::client::TlsStream;
16use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
17use tokio_rustls::TlsConnector;
18use tokio_tungstenite::tungstenite::Error as WsError;
19pub use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
21use url_fork::{ParseError, Url};
22
23type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
24type Sink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
25type Stream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
26
27mod socks;
28
29use self::socks::TpcSocks5Stream;
30
31#[derive(Debug, Error)]
32pub enum Error {
33 #[error("io error: {0}")]
35 IO(#[from] std::io::Error),
36 #[error("ws error: {0}")]
38 Ws(#[from] WsError),
39 #[error("socks error: {0}")]
40 Socks(#[from] tokio_socks::Error),
41 #[error("timeout")]
43 Timeout,
44 #[error("invalid DNS name")]
46 InvalidDNSName,
47 #[error("impossible to parse URL: {0}")]
49 Url(#[from] ParseError),
50}
51
52pub async fn connect(
53 url: &Url,
54 proxy: Option<SocketAddr>,
55 timeout: Option<Duration>,
56) -> Result<(Sink, Stream), Error> {
57 let stream = match proxy {
58 Some(proxy) => connect_proxy(url, proxy, timeout).await?,
59 None => connect_direct(url, timeout).await?,
60 };
61 Ok(stream.split())
62}
63
64async fn connect_direct(url: &Url, timeout: Option<Duration>) -> Result<WebSocket, Error> {
65 let timeout = timeout.unwrap_or(Duration::from_secs(60));
66 let (stream, _) =
67 tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url.to_string()))
68 .await
69 .map_err(|_| Error::Timeout)??;
70 Ok(stream)
71}
72
73async fn connect_proxy(
74 url: &Url,
75 proxy: SocketAddr,
76 timeout: Option<Duration>,
77) -> Result<WebSocket, Error> {
78 let timeout = timeout.unwrap_or(Duration::from_secs(60));
79 let addr: String = match url.host_str() {
80 Some(host) => match url.port_or_known_default() {
81 Some(port) => format!("{host}:{port}"),
82 None => return Err(Error::Url(ParseError::EmptyHost)),
83 },
84 None => return Err(Error::Url(ParseError::InvalidPort)),
85 };
86
87 let conn = TpcSocks5Stream::connect(proxy, addr.clone()).await?;
88 let conn = match connect_with_tls(conn, url).await {
89 Ok(stream) => MaybeTlsStream::Rustls(stream),
90 Err(_) => {
91 let conn = TpcSocks5Stream::connect(proxy, addr).await?;
92 MaybeTlsStream::Plain(conn)
93 }
94 };
95
96 let (stream, _) = tokio::time::timeout(
97 timeout,
98 tokio_tungstenite::client_async(url.to_string(), conn),
99 )
100 .await
101 .map_err(|_| Error::Timeout)??;
102 Ok(stream)
103}
104
105async fn connect_with_tls(stream: TcpStream, url: &Url) -> Result<TlsStream<TcpStream>, Error> {
106 let mut root_cert_store = RootCertStore::empty();
107 root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
108 OwnedTrustAnchor::from_subject_spki_name_constraints(
109 ta.subject,
110 ta.spki,
111 ta.name_constraints,
112 )
113 }));
114 let config = ClientConfig::builder()
115 .with_safe_defaults()
116 .with_root_certificates(root_cert_store)
117 .with_no_client_auth();
118 let connector = TlsConnector::from(Arc::new(config));
119 let domain = url.domain().ok_or(Error::InvalidDNSName)?;
120 let domain = ServerName::try_from(domain).map_err(|_| Error::InvalidDNSName)?;
121 Ok(connector.connect(domain, stream).await?)
122}