irc_connect/
lib.rs

1//! an abstraction over the kinds of connections useful for irc clients
2
3use pin_project_lite::pin_project;
4use std::{
5    fmt,
6    net::SocketAddr,
7    path::Path,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12use tokio::{
13    io::{AsyncRead, AsyncWrite, ReadBuf},
14    net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17    client::TlsStream,
18    rustls::{
19        client::WebPkiServerVerifier,
20        pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21        ClientConfig, RootCertStore,
22    },
23    TlsConnector,
24};
25use tokio_socks::{
26    tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27    IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
/// error type returned by `irc_connect`
// NOTE(review): `foxerror::FoxError` may derive the `Display` messages from the
// `///` lines in this enum, so they are kept verbatim — confirm before rewording.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    /// you specified a tls client cert without using tls
    // returned by `StreamBuilder::connect` when `client_cert` was set but tls was not
    ClientCertNoTls,
    /// failed to connect
    // io errors from opening the tcp/unix/proxy connection or from the tls handshake
    #[err(from)]
    Connect(std::io::Error),
    /// could not sock
    // the socks4/socks5 handshake with the proxy failed
    #[err(from)]
    Socks(tokio_socks::Error),
    /// could not rustls
    // rustls rejected the configuration (e.g. the client certificate/key pair)
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
    /// socks cannot connect to unix sockets
    // a socks proxy was configured on a builder created with `Stream::new_unix`
    SocksToUnsupported,
    /// invalid target address
    // the address passed to `Stream::new_tcp` could not be converted into a
    // target; not `#[err(from)]` because, assuming that attribute generates a
    // `From` impl, it would clash with the one for `Socks` (same source type)
    InvalidTarget(tokio_socks::Error),
    /// no tls servername provided and failed to guess it
    // tls was enabled with `domain: None` and the target was not a usable
    // ip address or dns name (e.g. a unix socket)
    NoServerName,
}
56
pin_project! {
    /// an open connection
    ///
    /// created with [`Stream::new_tcp`] or [`Stream::new_unix`]; read and write
    /// it through tokio's `AsyncRead` and `AsyncWrite`
    // `inner` is pinned so this type's AsyncRead/AsyncWrite impls can project
    // to it and forward every poll call; it is tls-wrapped or plain depending
    // on how the builder was configured
    #[derive(Debug)]
    pub struct Stream {
        #[pin]
        inner: MaybeTls,
    }
}
65
66impl Stream {
67    /// start building a new stream based on a tcp connection
68    ///
69    /// ```no_run
70    /// use irc_connect::Stream;
71    /// # #[tokio::main]
72    /// # async fn main() {
73    /// let stream = Stream::new_tcp("[::1]:6667").connect().await.unwrap();
74    /// # }
75    /// ```
76    pub fn new_tcp<'a>(addr: impl IntoTargetAddr<'a>) -> StreamBuilder<'a> {
77        StreamBuilder::new(BaseParams::Tcp(addr.into_target_addr()))
78    }
79    /// start building a new stream based on a unix socket
80    ///
81    /// ```no_run
82    /// use std::path::Path;
83    /// use irc_connect::Stream;
84    /// # #[tokio::main]
85    /// # async fn main() {
86    /// let stream = Stream::new_unix(Path::new("./my-unix-socket")).connect().await.unwrap();
87    /// # }
88    /// ```
89    pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
90        StreamBuilder::new(BaseParams::Unix(path))
91    }
92}
93
94impl AsyncRead for Stream {
95    #[inline]
96    fn poll_read(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        buf: &mut ReadBuf<'_>,
100    ) -> Poll<std::io::Result<()>> {
101        self.project().inner.poll_read(cx, buf)
102    }
103}
104
105impl AsyncWrite for Stream {
106    #[inline]
107    fn poll_write(
108        self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        buf: &[u8],
111    ) -> Poll<Result<usize, std::io::Error>> {
112        self.project().inner.poll_write(cx, buf)
113    }
114    #[inline]
115    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
116        self.project().inner.poll_flush(cx)
117    }
118    #[inline]
119    fn poll_shutdown(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122    ) -> Poll<Result<(), std::io::Error>> {
123        self.project().inner.poll_shutdown(cx)
124    }
125}
126
pin_project! {
    // the outermost layer of a `Stream`: tls-wrapped when the builder enabled
    // tls, otherwise the (possibly socks-proxied) stream as-is.
    // `MaybeTlsProj` is the projection enum that `trivial_impl!` matches on.
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // clippy flags the size gap between the variants; Tls is left unboxed
    // because it is expected to be the common case
    #[allow(clippy::large_enum_variant)] // you should use tls most of the time
    enum MaybeTls {
        // no tls configured
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // tokio-rustls client session running over the inner stream
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
142
143macro_rules! trivial_impl {
144    ($target:ty, ($($arm:path),*)) => {
145        impl AsyncRead for $target {
146            #[inline]
147            fn poll_read(
148                self: Pin<&mut Self>,
149                cx: &mut Context<'_>,
150                buf: &mut ReadBuf<'_>,
151            ) -> Poll<std::io::Result<()>> {
152                match self.project() {
153                    $($arm { inner } => inner.poll_read(cx, buf),)*
154                }
155            }
156        }
157
158        impl AsyncWrite for $target {
159            #[inline]
160            fn poll_write(
161                self: Pin<&mut Self>,
162                cx: &mut Context<'_>,
163                buf: &[u8],
164            ) -> Poll<Result<usize, std::io::Error>> {
165                match self.project() {
166                    $($arm { inner } => inner.poll_write(cx, buf),)*
167                }
168            }
169            #[inline]
170            fn poll_flush(
171                self: Pin<&mut Self>,
172                cx: &mut Context<'_>,
173            ) -> Poll<Result<(), std::io::Error>> {
174                match self.project() {
175                    $($arm { inner } => inner.poll_flush(cx),)*
176                }
177            }
178            #[inline]
179            fn poll_shutdown(
180                self: Pin<&mut Self>,
181                cx: &mut Context<'_>,
182            ) -> Poll<Result<(), std::io::Error>> {
183                match self.project() {
184                    $($arm { inner } => inner.poll_shutdown(cx),)*
185                }
186            }
187        }
188    };
189}
190
191trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
192
pin_project! {
    // the layer below tls: either a direct connection or one tunnelled
    // through a socks4/socks5 proxy
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // no proxy; the base stream talks to the target directly
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // socks4 tunnel; the wrapped BaseStream is the tcp connection to the proxy
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // socks5 tunnel; the wrapped BaseStream is the tcp connection to the proxy
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}

// delegate MaybeSocks's AsyncRead/AsyncWrite to whichever variant is active
trivial_impl!(
    MaybeSocks,
    (
        MaybeSocksProj::Clear,
        MaybeSocksProj::Socks4,
        MaybeSocksProj::Socks5
    )
);
220
pin_project! {
    // the raw transport underneath any socks or tls layers
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // tcp connection (to the target, or to the socks proxy)
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // unix domain socket connection
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}

// delegate BaseStream's AsyncRead/AsyncWrite to the tcp or unix stream
trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
237
/// a builder for [`Stream`]
// created by `Stream::new_tcp` / `Stream::new_unix`; the builder methods only
// record settings, and nothing touches the network until `connect` is called
#[derive(Debug)]
#[must_use = "this does nothing unless you finish building"]
pub struct StreamBuilder<'a> {
    // where to connect: a tcp target or a unix socket path
    base: BaseParams<'a>,
    // optional socks proxy; `connect` rejects it for a unix `base`
    socks: Option<SocksParams<'a>>,
    // optional tls layer on top of the (possibly proxied) connection
    tls: Option<TlsParams>,
    // optional tls client certificate; `connect` rejects it unless `tls` is set
    client_cert: Option<ClientCert>,
}
247
248impl<'a> StreamBuilder<'a> {
249    fn new(base: BaseParams<'a>) -> Self {
250        Self {
251            base,
252            socks: None,
253            tls: None,
254            client_cert: None,
255        }
256    }
257
258    fn socks(
259        mut self,
260        version: SocksVersion,
261        proxy: SocketAddr,
262        auth: Option<SocksAuth<'a>>,
263    ) -> Self {
264        self.socks = Some(SocksParams {
265            version,
266            proxy,
267            auth,
268        });
269        self
270    }
271
272    /// enable socks4 proxying
273    ///
274    /// ```
275    /// # use irc_connect::Stream;
276    /// # #[tokio::main]
277    /// # async fn main() {
278    /// # let builder = Stream::new_tcp("[::1]:6667");
279    /// let builder = builder.socks4("127.0.0.1:9050".parse().unwrap());
280    /// # }
281    /// ```
282    pub fn socks4(self, proxy: SocketAddr) -> Self {
283        self.socks(SocksVersion::Socks4, proxy, None)
284    }
285
286    /// enable socks4 proxying with a userid
287    ///
288    /// ```
289    /// # use irc_connect::Stream;
290    /// # #[tokio::main]
291    /// # async fn main() {
292    /// # let builder = Stream::new_tcp("[::1]:6667");
293    /// let builder = builder.socks4_with_userid("127.0.0.1:9050".parse().unwrap(), "meow");
294    /// # }
295    /// ```
296    pub fn socks4_with_userid(self, proxy: SocketAddr, userid: &'a str) -> Self {
297        self.socks(
298            SocksVersion::Socks4,
299            proxy,
300            Some(SocksAuth {
301                username: userid,
302                password: "h",
303            }),
304        )
305    }
306
307    /// enable socks5 proxying
308    ///
309    /// ```
310    /// # use irc_connect::Stream;
311    /// # #[tokio::main]
312    /// # async fn main() {
313    /// # let builder = Stream::new_tcp("[::1]:6667");
314    /// let builder = builder.socks5("127.0.0.1:9050".parse().unwrap());
315    /// # }
316    /// ```
317    pub fn socks5(self, proxy: SocketAddr) -> Self {
318        self.socks(SocksVersion::Socks5, proxy, None)
319    }
320
321    /// enable socks5 proxying with password authentication
322    ///
323    /// ```
324    /// # use irc_connect::Stream;
325    /// # #[tokio::main]
326    /// # async fn main() {
327    /// # let builder = Stream::new_tcp("[::1]:6667");
328    /// let builder =
329    ///     builder.socks5_with_password("127.0.0.1:9050".parse().unwrap(), "AzureDiamond", "hunter2");
330    /// # }
331    /// ```
332    pub fn socks5_with_password(
333        self,
334        proxy: SocketAddr,
335        username: &'a str,
336        password: &'a str,
337    ) -> Self {
338        self.socks(
339            SocksVersion::Socks5,
340            proxy,
341            Some(SocksAuth { username, password }),
342        )
343    }
344
345    fn tls(mut self, domain: Option<ServerName<'static>>, verification: TlsVerify) -> Self {
346        self.tls = Some(TlsParams {
347            domain,
348            verification,
349        });
350        self
351    }
352
353    /// enable tls without any verification
354    ///
355    /// ```
356    /// use tokio_rustls::rustls::pki_types::ServerName;
357    /// # use irc_connect::Stream;
358    /// # #[tokio::main]
359    /// # async fn main() {
360    /// # let builder = Stream::new_tcp("[::1]:6667");
361    /// let builder = builder.tls_danger_insecure(Some(ServerName::try_from("google.com").unwrap()));
362    /// # }
363    /// ```
364    pub fn tls_danger_insecure(self, domain: Option<ServerName<'static>>) -> Self {
365        self.tls(domain, TlsVerify::Insecure)
366    }
367
368    /// enable tls with root certificate verification
369    ///
370    /// can also be used to pin a self-signed cert as long as it has a `CA:FALSE` constraint
371    ///
372    /// ```
373    /// use tokio_rustls::rustls::RootCertStore;
374    /// use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
375    ///
376    /// # use irc_connect::Stream;
377    /// # #[tokio::main]
378    /// # async fn main() {
379    /// # let builder = Stream::new_tcp("[::1]:6667");
380    /// let mut root = RootCertStore::empty();
381    /// root.add_parsable_certificates(
382    ///     CertificateDer::pem_file_iter("/etc/ssl/cert.pem")
383    ///         .unwrap()
384    ///         .flatten(),
385    /// );
386    /// let builder = builder.tls_with_root(None, root);
387    /// # }
388    /// ```
389    pub fn tls_with_root(
390        self,
391        domain: Option<ServerName<'static>>,
392        root: impl Into<Arc<RootCertStore>>,
393    ) -> Self {
394        self.tls(domain, TlsVerify::CaStore(root.into()))
395    }
396
397    /// enable tls with a webpki verifier
398    pub fn tls_with_webpki(
399        self,
400        domain: Option<ServerName<'static>>,
401        webpki: Arc<WebPkiServerVerifier>,
402    ) -> Self {
403        self.tls(domain, TlsVerify::WebPki(webpki))
404    }
405
406    /// use a tls client certificate
407    ///
408    /// requires tls to be enabled
409    ///
410    /// ```no_run
411    /// use irc_connect::Stream;
412    /// use std::net::SocketAddr;
413    /// use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject};
414    ///
415    /// # #[tokio::main]
416    /// # async fn main() {
417    /// let builder = Stream::new_tcp("[::1]:6667").tls_danger_insecure(None);
418    /// let cert = CertificateDer::pem_file_iter("cert.pem")
419    ///     .unwrap()
420    ///     .collect::<Result<Vec<_>, _>>()
421    ///     .unwrap();
422    /// let key = PrivateKeyDer::from_pem_file("cert.key").unwrap();
423    /// let builder = builder.client_cert(cert, key);
424    /// # }
425    /// ```
426    pub fn client_cert(
427        mut self,
428        cert_chain: Vec<CertificateDer<'static>>,
429        key_der: PrivateKeyDer<'static>,
430    ) -> Self {
431        self.client_cert = Some(ClientCert {
432            cert_chain,
433            key_der,
434        });
435        self
436    }
437
438    /// finish building and open the connection
439    ///
440    /// ```no_run
441    /// # use irc_connect::Stream;
442    /// # #[tokio::main]
443    /// # async fn main() {
444    /// # let builder = Stream::new_tcp("[::1]:6667");
445    /// let stream = builder.connect().await.unwrap();
446    /// # }
447    /// ```
448    ///
449    /// # Errors
450    /// will return [`Error`] if an invalid combination of options has been
451    /// given to the builder, or if it is unable to connect
452    pub async fn connect(self) -> Result<Stream, Error> {
453        let tls = if let Some(mut params) = self.tls {
454            params.domain = params.domain.or_else(|| match &self.base {
455                BaseParams::Tcp(Ok(TargetAddr::Ip(addr))) => Some(ServerName::from(addr.ip())),
456                BaseParams::Tcp(Ok(TargetAddr::Domain(d, _))) => {
457                    ServerName::try_from(d.as_ref()).map(|s| s.to_owned()).ok()
458                }
459                _ => None,
460            });
461            Some(params)
462        } else {
463            None
464        };
465        let stream = if let Some(params) = self.socks {
466            let BaseParams::Tcp(target) = self.base else {
467                return Err(Error::SocksToUnsupported);
468            };
469            let target = target.map_err(Error::InvalidTarget)?;
470            let stream = BaseStream::Tcp {
471                inner: TcpStream::connect(params.proxy).await?,
472            };
473            match params.version {
474                SocksVersion::Socks4 => MaybeSocks::Socks4 {
475                    inner: if let Some(SocksAuth { username, .. }) = params.auth {
476                        Socks4Stream::connect_with_userid_and_socket(stream, target, username)
477                            .await?
478                    } else {
479                        Socks4Stream::connect_with_socket(stream, target).await?
480                    },
481                },
482                SocksVersion::Socks5 => MaybeSocks::Socks5 {
483                    inner: if let Some(SocksAuth { username, password }) = params.auth {
484                        Socks5Stream::connect_with_password_and_socket(
485                            stream, target, username, password,
486                        )
487                        .await?
488                    } else {
489                        Socks5Stream::connect_with_socket(stream, target).await?
490                    },
491                },
492            }
493        } else {
494            let stream = match self.base {
495                BaseParams::Tcp(addr) => {
496                    // FIXME: stick addr into connect directly, once tokio's ToSocketAddrs
497                    // stabilizes and TargetAddr implements it
498                    let inner = match addr.map_err(Error::InvalidTarget)? {
499                        TargetAddr::Ip(addr) => TcpStream::connect(addr).await?,
500                        TargetAddr::Domain(domain, port) => {
501                            TcpStream::connect((domain.as_ref(), port)).await?
502                        }
503                    };
504                    BaseStream::Tcp { inner }
505                }
506                BaseParams::Unix(path) => BaseStream::Unix {
507                    inner: UnixStream::connect(path).await?,
508                },
509            };
510            MaybeSocks::Clear { inner: stream }
511        };
512        let stream = if let Some(params) = tls {
513            let config = ClientConfig::builder();
514            let config = match params.verification {
515                TlsVerify::Insecure => {
516                    let provider = config.crypto_provider().clone();
517                    config
518                        .dangerous()
519                        .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
520                }
521                TlsVerify::CaStore(root) => config.with_root_certificates(root),
522                TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
523            };
524            let config = if let Some(ClientCert {
525                cert_chain,
526                key_der,
527            }) = self.client_cert
528            {
529                config.with_client_auth_cert(cert_chain, key_der)?
530            } else {
531                config.with_no_client_auth()
532            };
533            let connector = TlsConnector::from(Arc::new(config));
534            let domain = params.domain.ok_or(Error::NoServerName)?;
535            let inner = connector.connect(domain, stream).await?;
536            MaybeTls::Tls { inner }
537        } else {
538            if self.client_cert.is_some() {
539                return Err(Error::ClientCertNoTls);
540            }
541            MaybeTls::Plain { inner: stream }
542        };
543        Ok(Stream { inner: stream })
544    }
545}
546
/// where the builder should connect, as given to `Stream::new_tcp` / `Stream::new_unix`
#[derive(Debug)]
enum BaseParams<'a> {
    // result of `IntoTargetAddr::into_target_addr`; an `Err` is kept here and
    // surfaced as `Error::InvalidTarget` by `connect`
    Tcp(tokio_socks::Result<TargetAddr<'a>>),
    // path to a unix domain socket
    Unix(&'a Path),
}
552
/// socks proxy settings collected by the builder
struct SocksParams<'a> {
    /// which socks protocol version to speak to the proxy
    version: SocksVersion,
    /// address of the proxy itself
    proxy: SocketAddr,
    /// optional credentials (a userid for socks4, username/password for socks5)
    auth: Option<SocksAuth<'a>>,
}

// hand-written so credentials never appear in debug output: the version and
// proxy address are shown, but for `auth` only whether it is set
impl fmt::Debug for SocksParams<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SocksParams")
            .field("version", &self.version)
            .field("proxy", &self.proxy)
            .field("has_auth", &self.auth.is_some())
            .finish()
    }
}

/// credentials for the socks proxy
///
/// has no `Debug` impl, which keeps the password out of formatted output
struct SocksAuth<'a> {
    username: &'a str,
    // socks4 has no password; for socks4 this holds an unused placeholder
    password: &'a str,
}

/// socks protocol version
#[derive(Debug)]
enum SocksVersion {
    Socks4,
    Socks5,
}
575
/// tls settings collected by the builder
#[derive(Debug)]
struct TlsParams {
    // server name for the handshake; when `None`, `connect` tries to derive one
    // from the tcp target and fails with `Error::NoServerName` otherwise
    domain: Option<ServerName<'static>>,
    // how the server's certificate is checked
    verification: TlsVerify,
}

/// how the server certificate is verified
#[derive(Debug)]
enum TlsVerify {
    // no verification (`StreamBuilder::tls_danger_insecure`)
    Insecure,
    // verify against a root certificate store (`StreamBuilder::tls_with_root`)
    CaStore(Arc<RootCertStore>),
    // verify with a caller-supplied webpki verifier (`StreamBuilder::tls_with_webpki`)
    WebPki(Arc<WebPkiServerVerifier>),
}

/// a tls client certificate chain and its private key
// NOTE(review): the derived Debug includes `key_der`; confirm that
// `PrivateKeyDer`'s Debug impl elides the key bytes
#[derive(Debug)]
struct ClientCert {
    cert_chain: Vec<CertificateDer<'static>>,
    key_der: PrivateKeyDer<'static>,
}