irc_connect/
lib.rs

1//! an abstraction over the kinds of connections useful for irc clients
2
3use pin_project_lite::pin_project;
4use std::{
5    fmt,
6    net::SocketAddr,
7    path::Path,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12use tokio::{
13    io::{AsyncRead, AsyncWrite, ReadBuf},
14    net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17    client::TlsStream,
18    rustls::{
19        client::WebPkiServerVerifier,
20        pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21        ClientConfig, RootCertStore,
22    },
23    TlsConnector,
24};
25use tokio_socks::{
26    tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27    IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
/// error type returned by `irc_connect`
// NOTE(review): foxerror presumably derives each variant's Display message from the
// `///` comment on that variant, so their wording is treated as user-visible text and
// kept verbatim — confirm against the foxerror docs before rewording them.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    /// you specified a tls client cert without using tls
    // returned by `StreamBuilder::connect` when `client_cert` was set but no `tls_*` method was
    ClientCertNoTls,
    /// failed to connect
    // io errors from opening the tcp/unix socket or from the tls handshake in `connect`
    #[err(from)]
    Connect(std::io::Error),
    /// could not sock
    // an invalid socks target address, or a failed socks4/socks5 proxy handshake
    #[err(from)]
    Socks(tokio_socks::Error),
    /// could not rustls
    // rustls rejected the configuration, e.g. `with_client_auth_cert` refused the cert/key
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
}
50
pin_project! {
    /// an open connection
    ///
    /// reads and writes are forwarded through each configured layer: an optional tls
    /// session, over an optional socks proxy, over a tcp or unix socket
    #[derive(Debug)]
    pub struct Stream {
        // outermost layer; projected as `Pin<&mut MaybeTls>` by the io trait impls
        #[pin]
        inner: MaybeTls,
    }
}
59
60impl Stream {
61    /// start building a new stream based on a tcp connection
62    ///
63    /// ```no_run
64    /// use irc_connect::Stream;
65    /// # #[tokio::main]
66    /// # async fn main() {
67    /// let stream = Stream::new_tcp(&"[::1]:6667".parse().unwrap()).connect().await.unwrap();
68    /// # }
69    /// ```
70    pub fn new_tcp(addr: &SocketAddr) -> StreamBuilder<'_> {
71        StreamBuilder::new(BaseParams::Tcp(addr))
72    }
73    /// start building a new stream based on a unix socket
74    ///
75    /// ```no_run
76    /// use std::path::Path;
77    /// use irc_connect::Stream;
78    /// # #[tokio::main]
79    /// # async fn main() {
80    /// let stream = Stream::new_unix(Path::new("./my-unix-socket")).connect().await.unwrap();
81    /// # }
82    /// ```
83    pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
84        StreamBuilder::new(BaseParams::Unix(path))
85    }
86}
87
88impl AsyncRead for Stream {
89    #[inline]
90    fn poll_read(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &mut ReadBuf<'_>,
94    ) -> Poll<std::io::Result<()>> {
95        self.project().inner.poll_read(cx, buf)
96    }
97}
98
99impl AsyncWrite for Stream {
100    #[inline]
101    fn poll_write(
102        self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &[u8],
105    ) -> Poll<Result<usize, std::io::Error>> {
106        self.project().inner.poll_write(cx, buf)
107    }
108    #[inline]
109    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
110        self.project().inner.poll_flush(cx)
111    }
112    #[inline]
113    fn poll_shutdown(
114        self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116    ) -> Poll<Result<(), std::io::Error>> {
117        self.project().inner.poll_shutdown(cx)
118    }
119}
120
pin_project! {
    // outermost layer of a `Stream`: the (possibly proxied) stream either used as-is
    // or wrapped in a rustls client session
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // the `Tls` variant is much larger than `Plain`; it is not boxed because tls is
    // expected to be the common case
    #[allow(clippy::large_enum_variant)] // you should use tls most of the time
    enum MaybeTls {
        // no `tls_*` method was called on the builder
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // tls session established by `TlsConnector::connect` in `StreamBuilder::connect`
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
136
/// expands to `AsyncRead` + `AsyncWrite` impls for a pin-projected enum whose
/// variants each wrap exactly one `#[pin] inner` stream
///
/// the variant list must name the *projection* enum's variants (e.g. `FooProj::Bar`),
/// because every generated method matches on `self.project()` and forwards the call
/// to the pinned `inner` of whichever variant is live
macro_rules! trivial_impl {
    ($enum_ty:ty, ($($variant:path),*)) => {
        impl AsyncRead for $enum_ty {
            #[inline]
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner: io } => AsyncRead::poll_read(io, cx, buf),)*
                }
            }
        }

        impl AsyncWrite for $enum_ty {
            #[inline]
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                match self.project() {
                    $($variant { inner: io } => AsyncWrite::poll_write(io, cx, buf),)*
                }
            }

            #[inline]
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner: io } => AsyncWrite::poll_flush(io, cx),)*
                }
            }

            #[inline]
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner: io } => AsyncWrite::poll_shutdown(io, cx),)*
                }
            }
        }
    };
}
184
// io impls for the tls layer: forward to the plain or tls-wrapped stream
trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
186
pin_project! {
    // middle layer: the base socket used directly, or after a socks4/socks5 proxy
    // handshake performed in `StreamBuilder::connect`
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // no socks proxy configured
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // proxied through a socks4 server
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // proxied through a socks5 server
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}

// io impls for the proxy layer: forward to whichever variant is live
trivial_impl!(
    MaybeSocks,
    (
        MaybeSocksProj::Clear,
        MaybeSocksProj::Socks4,
        MaybeSocksProj::Socks5
    )
);
214
pin_project! {
    // innermost layer: the actual os socket opened by `StreamBuilder::connect`
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // opened via `Stream::new_tcp`
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // opened via `Stream::new_unix`
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}

// io impls for the socket layer: forward to the tcp or unix stream
trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
231
232/// a builder for [`Stream`]
233#[derive(Debug)]
234pub struct StreamBuilder<'a> {
235    base: BaseParams<'a>,
236    socks: Option<SocksParams<'a>>,
237    tls: Option<TlsParams>,
238    client_cert: Option<ClientCert>,
239}
240
241impl<'a> StreamBuilder<'a> {
242    fn new(base: BaseParams<'a>) -> Self {
243        Self {
244            base,
245            socks: None,
246            tls: None,
247            client_cert: None,
248        }
249    }
250
251    fn socks(
252        mut self,
253        version: SocksVersion,
254        target: impl IntoTargetAddr<'a>,
255        auth: Option<SocksAuth<'a>>,
256    ) -> Self {
257        self.socks = Some(SocksParams {
258            version,
259            target: target.into_target_addr(),
260            auth,
261        });
262        self
263    }
264
265    /// enable socks4 proxying
266    ///
267    /// ```
268    /// # use irc_connect::Stream;
269    /// # #[tokio::main]
270    /// # async fn main() {
271    /// let addr = "127.0.0.1:9050".parse().unwrap();
272    /// let builder = Stream::new_tcp(&addr);
273    /// let builder = builder.socks4("irc.example.com:6667");
274    /// # }
275    /// ```
276    #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
277    pub fn socks4(self, target: impl IntoTargetAddr<'a>) -> Self {
278        self.socks(SocksVersion::Socks4, target, None)
279    }
280
281    /// enable socks4 proxying with a userid
282    ///
283    /// ```
284    /// # use irc_connect::Stream;
285    /// # #[tokio::main]
286    /// # async fn main() {
287    /// let addr = "127.0.0.1:9050".parse().unwrap();
288    /// let builder = Stream::new_tcp(&addr);
289    /// let builder = builder.socks4_with_userid("irc.example.com:6667", "meow");
290    /// # }
291    /// ```
292    #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
293    pub fn socks4_with_userid(self, target: impl IntoTargetAddr<'a>, userid: &'a str) -> Self {
294        self.socks(
295            SocksVersion::Socks4,
296            target,
297            Some(SocksAuth {
298                username: userid,
299                password: "h",
300            }),
301        )
302    }
303
304    /// enable socks5 proxying
305    ///
306    /// ```
307    /// # use irc_connect::Stream;
308    /// # #[tokio::main]
309    /// # async fn main() {
310    /// let addr = "127.0.0.1:9050".parse().unwrap();
311    /// let builder = Stream::new_tcp(&addr);
312    /// let builder = builder.socks5("irc.example.com:6667");
313    /// # }
314    /// ```
315    #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
316    pub fn socks5(self, target: impl IntoTargetAddr<'a>) -> Self {
317        self.socks(SocksVersion::Socks5, target, None)
318    }
319
320    /// enable socks5 proxying with password authentication
321    ///
322    /// ```
323    /// # use irc_connect::Stream;
324    /// # #[tokio::main]
325    /// # async fn main() {
326    /// let addr = "127.0.0.1:9050".parse().unwrap();
327    /// let builder = Stream::new_tcp(&addr);
328    /// let builder = builder.socks5_with_password("irc.example.com:6667", "AzureDiamond", "hunter2");
329    /// # }
330    /// ```
331    #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
332    pub fn socks5_with_password(
333        self,
334        target: impl IntoTargetAddr<'a>,
335        username: &'a str,
336        password: &'a str,
337    ) -> Self {
338        self.socks(
339            SocksVersion::Socks5,
340            target,
341            Some(SocksAuth { username, password }),
342        )
343    }
344
345    fn tls(mut self, domain: ServerName<'static>, verification: TlsVerify) -> Self {
346        self.tls = Some(TlsParams {
347            domain,
348            verification,
349        });
350        self
351    }
352
353    /// enable tls without any verification
354    ///
355    /// ```
356    /// use tokio_rustls::rustls::pki_types::ServerName;
357    /// # use irc_connect::Stream;
358    /// # #[tokio::main]
359    /// # async fn main() {
360    /// # let addr = "[::1]:6667".parse().unwrap();
361    /// # let builder = Stream::new_tcp(&addr);
362    /// let builder = builder.tls_danger_insecure(ServerName::try_from("google.com").unwrap());
363    /// # }
364    /// ```
365    pub fn tls_danger_insecure(self, domain: ServerName<'static>) -> Self {
366        self.tls(domain, TlsVerify::Insecure)
367    }
368
369    /// enable tls with root certificate verification
370    ///
371    /// can also be used to pin a self-signed cert as long as it has a `CA:FALSE` constraint
372    ///
373    /// ```
374    /// use tokio_rustls::rustls::RootCertStore;
375    /// use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
376    ///
377    /// # use irc_connect::Stream;
378    /// # #[tokio::main]
379    /// # async fn main() {
380    /// # let addr = "[::1]:6667".parse().unwrap();
381    /// # let builder = Stream::new_tcp(&addr);
382    /// let mut root = RootCertStore::empty();
383    /// root.add_parsable_certificates(
384    ///     CertificateDer::pem_file_iter("/etc/ssl/cert.pem")
385    ///         .unwrap()
386    ///         .flatten(),
387    /// );
388    /// let builder = builder.tls_with_root(ServerName::try_from("google.com").unwrap(), root);
389    /// # }
390    /// ```
391    pub fn tls_with_root(
392        self,
393        domain: ServerName<'static>,
394        root: impl Into<Arc<RootCertStore>>,
395    ) -> Self {
396        self.tls(domain, TlsVerify::CaStore(root.into()))
397    }
398
399    /// enable tls with a webpki verifier
400    pub fn tls_with_webpki(
401        self,
402        domain: ServerName<'static>,
403        webpki: Arc<WebPkiServerVerifier>,
404    ) -> Self {
405        self.tls(domain, TlsVerify::WebPki(webpki))
406    }
407
408    /// use a tls client certificate
409    ///
410    /// requires tls to be enabled
411    ///
412    /// ```no_run
413    /// use irc_connect::Stream;
414    /// use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject};
415    ///
416    /// # #[tokio::main]
417    /// # async fn main() {
418    /// let addr = "[::1]:6667".parse().unwrap();
419    /// let builder = Stream::new_tcp(&addr).tls_danger_insecure(ServerName::from(addr.ip()));
420    /// let cert = CertificateDer::pem_file_iter("cert.pem")
421    ///     .unwrap()
422    ///     .collect::<Result<Vec<_>, _>>()
423    ///     .unwrap();
424    /// let key = PrivateKeyDer::from_pem_file("cert.key").unwrap();
425    /// let builder = builder.client_cert(cert, key);
426    /// # }
427    /// ```
428    pub fn client_cert(
429        mut self,
430        cert_chain: Vec<CertificateDer<'static>>,
431        key_der: PrivateKeyDer<'static>,
432    ) -> Self {
433        self.client_cert = Some(ClientCert {
434            cert_chain,
435            key_der,
436        });
437        self
438    }
439
440    /// finish building and open the connection
441    ///
442    /// ```no_run
443    /// # use irc_connect::Stream;
444    /// # #[tokio::main]
445    /// # async fn main() {
446    /// # let addr = "[::1]:6667".parse().unwrap();
447    /// # let builder = Stream::new_tcp(&addr);
448    /// let stream = builder.connect().await.unwrap();
449    /// # }
450    /// ```
451    pub async fn connect(self) -> Result<Stream, Error> {
452        let stream = match self.base {
453            BaseParams::Tcp(addr) => BaseStream::Tcp {
454                inner: TcpStream::connect(addr).await?,
455            },
456            BaseParams::Unix(path) => BaseStream::Unix {
457                inner: UnixStream::connect(path).await?,
458            },
459        };
460        let stream = if let Some(params) = self.socks {
461            let target = params.target?;
462            match params.version {
463                SocksVersion::Socks4 => MaybeSocks::Socks4 {
464                    inner: if let Some(SocksAuth { username, .. }) = params.auth {
465                        Socks4Stream::connect_with_userid_and_socket(stream, target, username)
466                            .await?
467                    } else {
468                        Socks4Stream::connect_with_socket(stream, target).await?
469                    },
470                },
471                SocksVersion::Socks5 => MaybeSocks::Socks5 {
472                    inner: if let Some(SocksAuth { username, password }) = params.auth {
473                        Socks5Stream::connect_with_password_and_socket(
474                            stream, target, username, password,
475                        )
476                        .await?
477                    } else {
478                        Socks5Stream::connect_with_socket(stream, target).await?
479                    },
480                },
481            }
482        } else {
483            MaybeSocks::Clear { inner: stream }
484        };
485        let stream = if let Some(params) = self.tls {
486            let config = ClientConfig::builder();
487            let config = match params.verification {
488                TlsVerify::Insecure => {
489                    let provider = config.crypto_provider().clone();
490                    config
491                        .dangerous()
492                        .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
493                }
494                TlsVerify::CaStore(root) => config.with_root_certificates(root),
495                TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
496            };
497            let config = if let Some(ClientCert {
498                cert_chain,
499                key_der,
500            }) = self.client_cert
501            {
502                config.with_client_auth_cert(cert_chain, key_der)?
503            } else {
504                config.with_no_client_auth()
505            };
506            let connector = TlsConnector::from(Arc::new(config));
507            let inner = connector.connect(params.domain, stream).await?;
508            MaybeTls::Tls { inner }
509        } else {
510            if self.client_cert.is_some() {
511                return Err(Error::ClientCertNoTls);
512            }
513            MaybeTls::Plain { inner: stream }
514        };
515        Ok(Stream { inner: stream })
516    }
517}
518
/// the transport opened first, beneath any proxy or tls layer
#[derive(Debug)]
enum BaseParams<'a> {
    // we cannot use [`tokio::net::ToSocketAddrs`] because they don't expose it :(
    // so a concrete socket address is stored instead
    /// a tcp socket address, from `Stream::new_tcp`
    Tcp(&'a SocketAddr),
    /// a unix domain socket path, from `Stream::new_unix`
    Unix(&'a Path),
}
525
/// socks proxy settings recorded by the builder
struct SocksParams<'a> {
    /// which socks protocol to speak to the proxy
    version: SocksVersion,
    /// the address the proxy is asked to connect to; a conversion error is kept here
    /// and only surfaced by `StreamBuilder::connect`
    target: tokio_socks::Result<TargetAddr<'a>>,
    /// optional credentials; socks4 uses only `username`, socks5 uses both fields
    auth: Option<SocksAuth<'a>>,
}
531
532impl fmt::Debug for SocksParams<'_> {
533    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534        fmt::Debug::fmt(&self.version, f)
535    }
536}
537
/// proxy credentials borrowed from the caller
///
/// this type has no `Debug` impl, so it cannot be printed through `{:?}`
struct SocksAuth<'a> {
    /// socks5 username, or the socks4 userid
    username: &'a str,
    /// socks5 password; not sent for socks4
    password: &'a str,
}
542
/// which socks protocol version to speak to the proxy
#[derive(Debug)]
enum SocksVersion {
    /// socks4, optionally with a userid
    Socks4,
    /// socks5, optionally with username/password authentication
    Socks5,
}
548
/// tls settings recorded by the builder
#[derive(Debug)]
struct TlsParams {
    /// server name passed to `TlsConnector::connect`
    domain: ServerName<'static>,
    /// how the server certificate is checked
    verification: TlsVerify,
}
554
/// server certificate verification mode
#[derive(Debug)]
enum TlsVerify {
    /// no verification at all; installs `danger::PhonyVerify` as the verifier
    Insecure,
    /// verify against the given root certificate store
    CaStore(Arc<RootCertStore>),
    /// verify with a caller-supplied webpki verifier
    WebPki(Arc<WebPkiServerVerifier>),
}
561
/// a tls client certificate, handed to rustls' `with_client_auth_cert` when connecting
#[derive(Debug)]
struct ClientCert {
    /// certificate chain presented to the server
    cert_chain: Vec<CertificateDer<'static>>,
    /// private key matching the first certificate in `cert_chain`
    key_der: PrivateKeyDer<'static>,
}