nostr_sdk_net/native/
mod.rs

1// Copyright (c) 2022-2023 Yuki Kishimoto
2// Copyright (c) 2023-2024 Rust Nostr Developers
3// Distributed under the MIT software license
4
5//! Native Network
6
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::StreamExt;
13use thiserror::Error;
14use tokio::net::TcpStream;
15use tokio_rustls::client::TlsStream;
16use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
17use tokio_rustls::TlsConnector;
18use tokio_tungstenite::tungstenite::Error as WsError;
19pub use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
21use url_fork::{ParseError, Url};
22
23type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
24type Sink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
25type Stream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
26
27mod socks;
28
29use self::socks::TpcSocks5Stream;
30
31#[derive(Debug, Error)]
32pub enum Error {
33    /// I/O error
34    #[error("io error: {0}")]
35    IO(#[from] std::io::Error),
36    /// Ws error
37    #[error("ws error: {0}")]
38    Ws(#[from] WsError),
39    #[error("socks error: {0}")]
40    Socks(#[from] tokio_socks::Error),
41    /// Timeout
42    #[error("timeout")]
43    Timeout,
44    /// Invalid DNS name
45    #[error("invalid DNS name")]
46    InvalidDNSName,
47    /// Url parse error
48    #[error("impossible to parse URL: {0}")]
49    Url(#[from] ParseError),
50}
51
52pub async fn connect(
53    url: &Url,
54    proxy: Option<SocketAddr>,
55    timeout: Option<Duration>,
56) -> Result<(Sink, Stream), Error> {
57    let stream = match proxy {
58        Some(proxy) => connect_proxy(url, proxy, timeout).await?,
59        None => connect_direct(url, timeout).await?,
60    };
61    Ok(stream.split())
62}
63
64async fn connect_direct(url: &Url, timeout: Option<Duration>) -> Result<WebSocket, Error> {
65    let timeout = timeout.unwrap_or(Duration::from_secs(60));
66    let (stream, _) =
67        tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url.to_string()))
68            .await
69            .map_err(|_| Error::Timeout)??;
70    Ok(stream)
71}
72
73async fn connect_proxy(
74    url: &Url,
75    proxy: SocketAddr,
76    timeout: Option<Duration>,
77) -> Result<WebSocket, Error> {
78    let timeout = timeout.unwrap_or(Duration::from_secs(60));
79    let addr: String = match url.host_str() {
80        Some(host) => match url.port_or_known_default() {
81            Some(port) => format!("{host}:{port}"),
82            None => return Err(Error::Url(ParseError::EmptyHost)),
83        },
84        None => return Err(Error::Url(ParseError::InvalidPort)),
85    };
86
87    let conn = TpcSocks5Stream::connect(proxy, addr.clone()).await?;
88    let conn = match connect_with_tls(conn, url).await {
89        Ok(stream) => MaybeTlsStream::Rustls(stream),
90        Err(_) => {
91            let conn = TpcSocks5Stream::connect(proxy, addr).await?;
92            MaybeTlsStream::Plain(conn)
93        }
94    };
95
96    let (stream, _) = tokio::time::timeout(
97        timeout,
98        tokio_tungstenite::client_async(url.to_string(), conn),
99    )
100    .await
101    .map_err(|_| Error::Timeout)??;
102    Ok(stream)
103}
104
105async fn connect_with_tls(stream: TcpStream, url: &Url) -> Result<TlsStream<TcpStream>, Error> {
106    let mut root_cert_store = RootCertStore::empty();
107    root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
108        OwnedTrustAnchor::from_subject_spki_name_constraints(
109            ta.subject,
110            ta.spki,
111            ta.name_constraints,
112        )
113    }));
114    let config = ClientConfig::builder()
115        .with_safe_defaults()
116        .with_root_certificates(root_cert_store)
117        .with_no_client_auth();
118    let connector = TlsConnector::from(Arc::new(config));
119    let domain = url.domain().ok_or(Error::InvalidDNSName)?;
120    let domain = ServerName::try_from(domain).map_err(|_| Error::InvalidDNSName)?;
121    Ok(connector.connect(domain, stream).await?)
122}