irc_connect/
lib.rs

1//! an abstraction over the kinds of connections useful for irc clients
2
3use pin_project_lite::pin_project;
4use std::{
5    fmt,
6    net::SocketAddr,
7    path::Path,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12use tokio::{
13    io::{AsyncRead, AsyncWrite, ReadBuf},
14    net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17    client::TlsStream,
18    rustls::{
19        client::WebPkiServerVerifier,
20        pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21        ClientConfig, RootCertStore,
22    },
23    TlsConnector,
24};
25use tokio_socks::{
26    tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27    IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
/// error type returned by `irc_connect`
// NOTE(review): foxerror presumably builds each variant's Display text from its `///` line,
// so rewording those lines may change user-visible error messages — confirm before editing
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    // returned by `StreamBuilder::connect` when `client_cert` was set but tls was not enabled
    /// you specified a tls client cert without using tls
    ClientCertNoTls,
    // converted via `?` from std::io errors: opening the tcp/unix socket or the tls handshake
    /// failed to connect
    #[err(from)]
    Connect(std::io::Error),
    // converted via `?`: invalid socks target address or a failed socks4/socks5 handshake
    /// could not sock
    #[err(from)]
    Socks(tokio_socks::Error),
    // converted via `?`: building the tls client config failed (e.g. the client cert/key
    // were rejected by `with_client_auth_cert`)
    /// could not rustls
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
}
50
pin_project! {
    /// an open connection
    // layers, outermost first: optional tls (`MaybeTls`) -> optional socks proxy
    // (`MaybeSocks`) -> tcp or unix socket (`BaseStream`)
    #[derive(Debug)]
    pub struct Stream {
        // pinned so the AsyncRead/AsyncWrite impls can project through to the layered stream
        #[pin]
        inner: MaybeTls,
    }
}
59
60impl Stream {
61    /// start building a new stream based on a tcp connection
62    ///
63    /// ```no_run
64    /// use irc_connect::Stream;
65    /// # #[tokio::main]
66    /// # async fn main() {
67    /// let stream = Stream::new_tcp(&"[::1]:6667".parse().unwrap()).connect().await.unwrap();
68    /// # }
69    /// ```
70    pub fn new_tcp(addr: &SocketAddr) -> StreamBuilder<'_> {
71        StreamBuilder::new(BaseParams::Tcp(addr))
72    }
73    /// start building a new stream based on a unix socket
74    ///
75    /// ```no_run
76    /// use std::path::Path;
77    /// use irc_connect::Stream;
78    /// # #[tokio::main]
79    /// # async fn main() {
80    /// let stream = Stream::new_unix(Path::new("./my-unix-socket")).connect().await.unwrap();
81    /// # }
82    /// ```
83    pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
84        StreamBuilder::new(BaseParams::Unix(path))
85    }
86}
87
88impl AsyncRead for Stream {
89    #[inline]
90    fn poll_read(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &mut ReadBuf<'_>,
94    ) -> Poll<std::io::Result<()>> {
95        self.project().inner.poll_read(cx, buf)
96    }
97}
98
99impl AsyncWrite for Stream {
100    #[inline]
101    fn poll_write(
102        self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &[u8],
105    ) -> Poll<Result<usize, std::io::Error>> {
106        self.project().inner.poll_write(cx, buf)
107    }
108    #[inline]
109    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
110        self.project().inner.poll_flush(cx)
111    }
112    #[inline]
113    fn poll_shutdown(
114        self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116    ) -> Poll<Result<(), std::io::Error>> {
117        self.project().inner.poll_shutdown(cx)
118    }
119}
120
pin_project! {
    // the optional tls layer; this is the outermost layer wrapped by `Stream`
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    #[allow(clippy::large_enum_variant)] // you should use tls most of the time
    enum MaybeTls {
        // tls was not requested; bytes go directly to the (possibly proxied) connection
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // a client-side tls session running over the (possibly proxied) connection
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
136
137impl AsyncRead for MaybeTls {
138    #[inline]
139    fn poll_read(
140        self: Pin<&mut Self>,
141        cx: &mut Context,
142        buf: &mut ReadBuf<'_>,
143    ) -> Poll<std::io::Result<()>> {
144        match self.project() {
145            MaybeTlsProj::Plain { inner } => inner.poll_read(cx, buf),
146            MaybeTlsProj::Tls { inner } => inner.poll_read(cx, buf),
147        }
148    }
149}
150
151impl AsyncWrite for MaybeTls {
152    #[inline]
153    fn poll_write(
154        self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156        buf: &[u8],
157    ) -> Poll<Result<usize, std::io::Error>> {
158        match self.project() {
159            MaybeTlsProj::Plain { inner } => inner.poll_write(cx, buf),
160            MaybeTlsProj::Tls { inner } => inner.poll_write(cx, buf),
161        }
162    }
163    #[inline]
164    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
165        match self.project() {
166            MaybeTlsProj::Plain { inner } => inner.poll_flush(cx),
167            MaybeTlsProj::Tls { inner } => inner.poll_flush(cx),
168        }
169    }
170    #[inline]
171    fn poll_shutdown(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174    ) -> Poll<Result<(), std::io::Error>> {
175        match self.project() {
176            MaybeTlsProj::Plain { inner } => inner.poll_shutdown(cx),
177            MaybeTlsProj::Tls { inner } => inner.poll_shutdown(cx),
178        }
179    }
180}
181
pin_project! {
    // the optional socks proxy layer, sitting between the tls layer and the base socket
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // no proxy: the base socket is used as-is
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // the base socket after a completed socks4 handshake (done in `StreamBuilder::connect`)
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // the base socket after a completed socks5 handshake (done in `StreamBuilder::connect`)
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}
200
201impl AsyncRead for MaybeSocks {
202    #[inline]
203    fn poll_read(
204        self: Pin<&mut Self>,
205        cx: &mut Context<'_>,
206        buf: &mut ReadBuf<'_>,
207    ) -> Poll<std::io::Result<()>> {
208        match self.project() {
209            MaybeSocksProj::Clear { inner } => inner.poll_read(cx, buf),
210            MaybeSocksProj::Socks4 { inner } => inner.poll_read(cx, buf),
211            MaybeSocksProj::Socks5 { inner } => inner.poll_read(cx, buf),
212        }
213    }
214}
215
216impl AsyncWrite for MaybeSocks {
217    #[inline]
218    fn poll_write(
219        self: Pin<&mut Self>,
220        cx: &mut Context<'_>,
221        buf: &[u8],
222    ) -> Poll<Result<usize, std::io::Error>> {
223        match self.project() {
224            MaybeSocksProj::Clear { inner } => inner.poll_write(cx, buf),
225            MaybeSocksProj::Socks4 { inner } => inner.poll_write(cx, buf),
226            MaybeSocksProj::Socks5 { inner } => inner.poll_write(cx, buf),
227        }
228    }
229    #[inline]
230    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
231        match self.project() {
232            MaybeSocksProj::Clear { inner } => inner.poll_flush(cx),
233            MaybeSocksProj::Socks4 { inner } => inner.poll_flush(cx),
234            MaybeSocksProj::Socks5 { inner } => inner.poll_flush(cx),
235        }
236    }
237    #[inline]
238    fn poll_shutdown(
239        self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241    ) -> Poll<Result<(), std::io::Error>> {
242        match self.project() {
243            MaybeSocksProj::Clear { inner } => inner.poll_shutdown(cx),
244            MaybeSocksProj::Socks4 { inner } => inner.poll_shutdown(cx),
245            MaybeSocksProj::Socks5 { inner } => inner.poll_shutdown(cx),
246        }
247    }
248}
249
pin_project! {
    // the innermost transport: the socket opened by `StreamBuilder::connect`
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // opened from `BaseParams::Tcp` (see `Stream::new_tcp`)
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // opened from `BaseParams::Unix` (see `Stream::new_unix`)
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}
264
265impl AsyncRead for BaseStream {
266    #[inline]
267    fn poll_read(
268        self: Pin<&mut Self>,
269        cx: &mut Context,
270        buf: &mut ReadBuf<'_>,
271    ) -> Poll<std::io::Result<()>> {
272        match self.project() {
273            BaseStreamProj::Tcp { inner } => inner.poll_read(cx, buf),
274            BaseStreamProj::Unix { inner } => inner.poll_read(cx, buf),
275        }
276    }
277}
278
279impl AsyncWrite for BaseStream {
280    #[inline]
281    fn poll_write(
282        self: Pin<&mut Self>,
283        cx: &mut Context<'_>,
284        buf: &[u8],
285    ) -> Poll<Result<usize, std::io::Error>> {
286        match self.project() {
287            BaseStreamProj::Tcp { inner } => inner.poll_write(cx, buf),
288            BaseStreamProj::Unix { inner } => inner.poll_write(cx, buf),
289        }
290    }
291    #[inline]
292    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
293        match self.project() {
294            BaseStreamProj::Tcp { inner } => inner.poll_flush(cx),
295            BaseStreamProj::Unix { inner } => inner.poll_flush(cx),
296        }
297    }
298    #[inline]
299    fn poll_shutdown(
300        self: Pin<&mut Self>,
301        cx: &mut Context<'_>,
302    ) -> Poll<Result<(), std::io::Error>> {
303        match self.project() {
304            BaseStreamProj::Tcp { inner } => inner.poll_shutdown(cx),
305            BaseStreamProj::Unix { inner } => inner.poll_shutdown(cx),
306        }
307    }
308}
309
310/// a builder for [`Stream`]
311#[derive(Debug)]
312pub struct StreamBuilder<'a> {
313    base: BaseParams<'a>,
314    socks: Option<SocksParams<'a>>,
315    tls: Option<TlsParams>,
316    client_cert: Option<ClientCert>,
317}
318
319impl<'a> StreamBuilder<'a> {
320    fn new(base: BaseParams<'a>) -> Self {
321        Self {
322            base,
323            socks: None,
324            tls: None,
325            client_cert: None,
326        }
327    }
328
329    fn socks(
330        mut self,
331        version: SocksVersion,
332        target: impl IntoTargetAddr<'a>,
333        auth: Option<SocksAuth<'a>>,
334    ) -> Self {
335        self.socks = Some(SocksParams {
336            version,
337            target: target.into_target_addr(),
338            auth,
339        });
340        self
341    }
342
343    /// enable socks4 proxying
344    ///
345    /// ```
346    /// # use irc_connect::Stream;
347    /// # #[tokio::main]
348    /// # async fn main() {
349    /// # let addr = "[::1]:6667".parse().unwrap();
350    /// # let builder = Stream::new_tcp(&addr);
351    /// let builder = builder.socks4("127.0.0.1:9050");
352    /// # }
353    /// ```
354    pub fn socks4(self, target: impl IntoTargetAddr<'a>) -> Self {
355        self.socks(SocksVersion::Socks4, target, None)
356    }
357
358    /// enable socks4 proxying with a userid
359    ///
360    /// ```
361    /// # use irc_connect::Stream;
362    /// # #[tokio::main]
363    /// # async fn main() {
364    /// # let addr = "[::1]:6667".parse().unwrap();
365    /// # let builder = Stream::new_tcp(&addr);
366    /// let builder = builder.socks4_with_userid("127.0.0.1:9050", "meow");
367    /// # }
368    /// ```
369    pub fn socks4_with_userid(self, target: impl IntoTargetAddr<'a>, userid: &'a str) -> Self {
370        self.socks(
371            SocksVersion::Socks4,
372            target,
373            Some(SocksAuth {
374                username: userid,
375                password: "h",
376            }),
377        )
378    }
379
380    /// enable socks5 proxying
381    ///
382    /// ```
383    /// # use irc_connect::Stream;
384    /// # #[tokio::main]
385    /// # async fn main() {
386    /// # let addr = "[::1]:6667".parse().unwrap();
387    /// # let builder = Stream::new_tcp(&addr);
388    /// let builder = builder.socks5("127.0.0.1:9050");
389    /// # }
390    /// ```
391    pub fn socks5(self, target: impl IntoTargetAddr<'a>) -> Self {
392        self.socks(SocksVersion::Socks5, target, None)
393    }
394
395    /// enable socks5 proxying with password authentication
396    ///
397    /// ```
398    /// # use irc_connect::Stream;
399    /// # #[tokio::main]
400    /// # async fn main() {
401    /// # let addr = "[::1]:6667".parse().unwrap();
402    /// # let builder = Stream::new_tcp(&addr);
403    /// let builder = builder.socks5_with_password("127.0.0.1:9050", "AzureDiamond", "hunter2");
404    /// # }
405    /// ```
406    pub fn socks5_with_password(
407        self,
408        target: impl IntoTargetAddr<'a>,
409        username: &'a str,
410        password: &'a str,
411    ) -> Self {
412        self.socks(
413            SocksVersion::Socks5,
414            target,
415            Some(SocksAuth { username, password }),
416        )
417    }
418
419    fn tls(mut self, domain: ServerName<'static>, verification: TlsVerify) -> Self {
420        self.tls = Some(TlsParams {
421            domain,
422            verification,
423        });
424        self
425    }
426
427    /// enable tls without any verification
428    ///
429    /// ```
430    /// use tokio_rustls::rustls::pki_types::ServerName;
431    /// # use irc_connect::Stream;
432    /// # #[tokio::main]
433    /// # async fn main() {
434    /// # let addr = "[::1]:6667".parse().unwrap();
435    /// # let builder = Stream::new_tcp(&addr);
436    /// let builder = builder.tls_danger_insecure(ServerName::try_from("google.com").unwrap());
437    /// # }
438    /// ```
439    pub fn tls_danger_insecure(self, domain: ServerName<'static>) -> Self {
440        self.tls(domain, TlsVerify::Insecure)
441    }
442
443    /// enable tls with root certificate verification
444    ///
445    /// can also be used to pin a self-signed cert as long as it has a `CA:FALSE` constraint
446    ///
447    /// ```
448    /// use tokio_rustls::rustls::RootCertStore;
449    /// use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
450    ///
451    /// # use irc_connect::Stream;
452    /// # #[tokio::main]
453    /// # async fn main() {
454    /// # let addr = "[::1]:6667".parse().unwrap();
455    /// # let builder = Stream::new_tcp(&addr);
456    /// let mut root = RootCertStore::empty();
457    /// root.add_parsable_certificates(
458    ///     CertificateDer::pem_file_iter("/etc/ssl/cert.pem")
459    ///         .unwrap()
460    ///         .flatten(),
461    /// );
462    /// let builder = builder.tls_with_root(ServerName::try_from("google.com").unwrap(), root);
463    /// # }
464    /// ```
465    pub fn tls_with_root(
466        self,
467        domain: ServerName<'static>,
468        root: impl Into<Arc<RootCertStore>>,
469    ) -> Self {
470        self.tls(domain, TlsVerify::CaStore(root.into()))
471    }
472
473    /// enable tls with a webpki verifier
474    pub fn tls_with_webpki(
475        self,
476        domain: ServerName<'static>,
477        webpki: Arc<WebPkiServerVerifier>,
478    ) -> Self {
479        self.tls(domain, TlsVerify::WebPki(webpki))
480    }
481
482    /// use a tls client certificate
483    ///
484    /// requires tls to be enabled
485    ///
486    /// ```no_run
487    /// use irc_connect::Stream;
488    /// use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject};
489    ///
490    /// # #[tokio::main]
491    /// # async fn main() {
492    /// let addr = "[::1]:6667".parse().unwrap();
493    /// let builder = Stream::new_tcp(&addr).tls_danger_insecure(ServerName::from(addr.ip()));
494    /// let cert = CertificateDer::pem_file_iter("cert.pem")
495    ///     .unwrap()
496    ///     .collect::<Result<Vec<_>, _>>()
497    ///     .unwrap();
498    /// let key = PrivateKeyDer::from_pem_file("cert.key").unwrap();
499    /// let builder = builder.client_cert(cert, key);
500    /// # }
501    /// ```
502    pub fn client_cert(
503        mut self,
504        cert_chain: Vec<CertificateDer<'static>>,
505        key_der: PrivateKeyDer<'static>,
506    ) -> Self {
507        self.client_cert = Some(ClientCert {
508            cert_chain,
509            key_der,
510        });
511        self
512    }
513
514    /// finish building and open the connection
515    ///
516    /// ```no_run
517    /// # use irc_connect::Stream;
518    /// # #[tokio::main]
519    /// # async fn main() {
520    /// # let addr = "[::1]:6667".parse().unwrap();
521    /// # let builder = Stream::new_tcp(&addr);
522    /// let stream = builder.connect().await.unwrap();
523    /// # }
524    /// ```
525    pub async fn connect(self) -> Result<Stream, Error> {
526        let stream = match self.base {
527            BaseParams::Tcp(addr) => BaseStream::Tcp {
528                inner: TcpStream::connect(addr).await?,
529            },
530            BaseParams::Unix(path) => BaseStream::Unix {
531                inner: UnixStream::connect(path).await?,
532            },
533        };
534        let stream = if let Some(params) = self.socks {
535            let target = params.target?;
536            match params.version {
537                SocksVersion::Socks4 => MaybeSocks::Socks4 {
538                    inner: if let Some(SocksAuth { username, .. }) = params.auth {
539                        Socks4Stream::connect_with_userid_and_socket(stream, target, username)
540                            .await?
541                    } else {
542                        Socks4Stream::connect_with_socket(stream, target).await?
543                    },
544                },
545                SocksVersion::Socks5 => MaybeSocks::Socks5 {
546                    inner: if let Some(SocksAuth { username, password }) = params.auth {
547                        Socks5Stream::connect_with_password_and_socket(
548                            stream, target, username, password,
549                        )
550                        .await?
551                    } else {
552                        Socks5Stream::connect_with_socket(stream, target).await?
553                    },
554                },
555            }
556        } else {
557            MaybeSocks::Clear { inner: stream }
558        };
559        let stream = if let Some(params) = self.tls {
560            let config = ClientConfig::builder();
561            let config = match params.verification {
562                TlsVerify::Insecure => {
563                    let provider = config.crypto_provider().clone();
564                    config
565                        .dangerous()
566                        .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
567                }
568                TlsVerify::CaStore(root) => config.with_root_certificates(root),
569                TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
570            };
571            let config = if let Some(ClientCert {
572                cert_chain,
573                key_der,
574            }) = self.client_cert
575            {
576                config.with_client_auth_cert(cert_chain, key_der)?
577            } else {
578                config.with_no_client_auth()
579            };
580            let connector = TlsConnector::from(Arc::new(config));
581            let inner = connector.connect(params.domain, stream).await?;
582            MaybeTls::Tls { inner }
583        } else {
584            if self.client_cert.is_some() {
585                return Err(Error::ClientCertNoTls);
586            }
587            MaybeTls::Plain { inner: stream }
588        };
589        Ok(Stream { inner: stream })
590    }
591}
592
/// where the base connection goes: a tcp socket address or a unix socket path
///
/// when a socks proxy is configured, this is the proxy's address
#[derive(Debug)]
enum BaseParams<'a> {
    // we cannot use [`tokio::net::ToSocketAddrs`] because they don't expose it :(
    Tcp(&'a SocketAddr),
    Unix(&'a Path),
}
599
/// socks proxy settings recorded by the builder
///
/// `Debug` is implemented by hand rather than derived, and prints only `version`
struct SocksParams<'a> {
    version: SocksVersion,
    // result of `IntoTargetAddr::into_target_addr`; an error is only surfaced in `connect`
    target: tokio_socks::Result<TargetAddr<'a>>,
    // `None` means no authentication
    auth: Option<SocksAuth<'a>>,
}
605
606impl fmt::Debug for SocksParams<'_> {
607    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
608        fmt::Debug::fmt(&self.version, f)
609    }
610}
611
/// socks credentials
///
/// for socks4 only `username` is used (sent as the userid); for socks5 both are sent.
// NOTE(review): no `Debug` derive here, presumably so credentials can't be printed — confirm
struct SocksAuth<'a> {
    username: &'a str,
    password: &'a str,
}
616
/// which socks protocol to speak to the proxy
#[derive(Debug)]
enum SocksVersion {
    Socks4,
    Socks5,
}
622
/// tls settings recorded by the builder
#[derive(Debug)]
struct TlsParams {
    // server name passed to `TlsConnector::connect`
    domain: ServerName<'static>,
    // how the server's certificate is checked
    verification: TlsVerify,
}
628
/// how the server's tls certificate is verified
#[derive(Debug)]
enum TlsVerify {
    // no verification: uses `danger::PhonyVerify` as a custom certificate verifier
    Insecure,
    // verify against these roots (`ConfigBuilder::with_root_certificates`)
    CaStore(Arc<RootCertStore>),
    // verify with a caller-supplied verifier (`ConfigBuilder::with_webpki_verifier`)
    WebPki(Arc<WebPkiServerVerifier>),
}
635
/// tls client certificate and key, passed to `ConfigBuilder::with_client_auth_cert`
#[derive(Debug)]
struct ClientCert {
    cert_chain: Vec<CertificateDer<'static>>,
    key_der: PrivateKeyDer<'static>,
}