Skip to main content

irc_connect/
lib.rs

1// SPDX-FileCopyrightText: 2025 xfnw
2//
3// SPDX-License-Identifier: MIT
4
5//! an abstraction over the kinds of connections useful for irc clients
6
7use pin_project_lite::pin_project;
8use std::{
9    fmt,
10    net::SocketAddr,
11    path::Path,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15};
16use tokio::{
17    io::{AsyncRead, AsyncWrite, ReadBuf},
18    net::{TcpStream, UnixStream},
19};
20use tokio_rustls::{
21    client::TlsStream,
22    rustls::{
23        client::WebPkiServerVerifier,
24        pki_types::{CertificateDer, PrivateKeyDer, ServerName},
25        ClientConfig, RootCertStore,
26    },
27    TlsConnector,
28};
29use tokio_socks::{
30    tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
31    IntoTargetAddr, TargetAddr,
32};
33
34pub use tokio_rustls;
35
36mod danger;
37
38#[deprecated(since = "0.2.1", note = "Stream was renamed to Connection")]
39pub type Stream = Connection;
40#[deprecated(
41    since = "0.2.1",
42    note = "StreamBuilder was renamed to ConnectionBuilder"
43)]
44pub type StreamBuilder<'a> = ConnectionBuilder<'a>;
45
46/// error type returned by `irc_connect`
47#[derive(Debug)]
48#[non_exhaustive]
49pub enum Error {
50    /// you specified a tls client cert without using tls
51    ClientCertNoTls,
52    /// failed to connect
53    Connect(std::io::Error),
54    /// could not sock
55    Socks(tokio_socks::Error),
56    /// could not rustls
57    Rustls(tokio_rustls::rustls::Error),
58    /// socks cannot connect to unix sockets
59    SocksToUnsupported,
60    /// invalid target address
61    InvalidTarget(tokio_socks::Error),
62    /// no tls servername provided and failed to guess it
63    NoServerName,
64}
65
66impl fmt::Display for Error {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match self {
69            Self::ClientCertNoTls => write!(f, "you specified a client cert without using tls"),
70            Self::Connect(e) => write!(f, "failed to connect: {e}"),
71            Self::Socks(e) => write!(f, "could not sock: {e}"),
72            Self::Rustls(e) => write!(f, "could not rustls: {e}"),
73            Self::SocksToUnsupported => write!(f, "socks cannot connect to unix sockets"),
74            Self::InvalidTarget(e) => write!(f, "invalid target address: {e}"),
75            Self::NoServerName => write!(f, "no tls servername provided and failed to guess it"),
76        }
77    }
78}
79
80impl std::error::Error for Error {}
81
82impl From<std::io::Error> for Error {
83    fn from(value: std::io::Error) -> Self {
84        Self::Connect(value)
85    }
86}
87
88impl From<tokio_socks::Error> for Error {
89    fn from(value: tokio_socks::Error) -> Self {
90        Self::Socks(value)
91    }
92}
93
94impl From<tokio_rustls::rustls::Error> for Error {
95    fn from(value: tokio_rustls::rustls::Error) -> Self {
96        Self::Rustls(value)
97    }
98}
99
pin_project! {
    /// an open connection
    ///
    /// implements [`AsyncRead`] and [`AsyncWrite`]; start building one with
    /// [`Connection::new_tcp`] or [`Connection::new_unix`]
    #[derive(Debug)]
    pub struct Connection {
        // the tls-or-plain layer sitting on top of the (optionally proxied)
        // base stream; all reads and writes are forwarded to it
        #[pin]
        inner: MaybeTls,
    }
}
108
109impl Connection {
110    /// start building a new stream based on a tcp connection
111    ///
112    /// ```no_run
113    /// use irc_connect::Connection;
114    /// # #[tokio::main]
115    /// # async fn main() {
116    /// let stream = Connection::new_tcp("[::1]:6667").connect().await.unwrap();
117    /// # }
118    /// ```
119    pub fn new_tcp<'a>(addr: impl IntoTargetAddr<'a>) -> ConnectionBuilder<'a> {
120        ConnectionBuilder::new(BaseParams::Tcp(addr.into_target_addr()))
121    }
122    /// start building a new stream based on a unix socket
123    ///
124    /// ```no_run
125    /// use std::path::Path;
126    /// use irc_connect::Connection;
127    /// # #[tokio::main]
128    /// # async fn main() {
129    /// let stream = Connection::new_unix(Path::new("./my-unix-socket")).connect().await.unwrap();
130    /// # }
131    /// ```
132    pub fn new_unix(path: &Path) -> ConnectionBuilder<'_> {
133        ConnectionBuilder::new(BaseParams::Unix(path))
134    }
135}
136
137impl AsyncRead for Connection {
138    #[inline]
139    fn poll_read(
140        self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        buf: &mut ReadBuf<'_>,
143    ) -> Poll<std::io::Result<()>> {
144        self.project().inner.poll_read(cx, buf)
145    }
146}
147
148impl AsyncWrite for Connection {
149    #[inline]
150    fn poll_write(
151        self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &[u8],
154    ) -> Poll<Result<usize, std::io::Error>> {
155        self.project().inner.poll_write(cx, buf)
156    }
157    #[inline]
158    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
159        self.project().inner.poll_flush(cx)
160    }
161    #[inline]
162    fn poll_shutdown(
163        self: Pin<&mut Self>,
164        cx: &mut Context<'_>,
165    ) -> Poll<Result<(), std::io::Error>> {
166        self.project().inner.poll_shutdown(cx)
167    }
168}
169
pin_project! {
    // outermost layer under `Connection`: either the stream as-is, or the
    // stream after a rustls client handshake
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    #[allow(clippy::large_enum_variant)] // you should use tls most of the time
    enum MaybeTls {
        // tls was not requested; bytes pass through unencrypted
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // tls was requested; wraps the (possibly proxied) stream
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
185
// generates `AsyncRead` and `AsyncWrite` impls for a `pin_project!` enum in
// which every variant holds a single `#[pin] inner` field: each poll method
// projects `self` and forwards to the `inner` of whichever variant is active
//
// invoked as `trivial_impl!(Type, (TypeProj::A, TypeProj::B))`; the list must
// name every variant of the projection enum, otherwise the generated `match`
// is not exhaustive and compilation fails
macro_rules! trivial_impl {
    ($ty:ty, ($($variant:path),*)) => {
        impl AsyncRead for $ty {
            #[inline]
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_read(cx, buf),)*
                }
            }
        }

        impl AsyncWrite for $ty {
            #[inline]
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_write(cx, buf),)*
                }
            }

            #[inline]
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_flush(cx),)*
                }
            }

            #[inline]
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_shutdown(cx),)*
                }
            }
        }
    };
}
233
// forward reads/writes through whichever of the plain or tls variants is active
trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
235
pin_project! {
    // middle layer: the base stream, used either directly or after a socks
    // handshake with a proxy that relays to the real target
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // no proxy configured; the base stream talks to the target itself
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // tunnelled through a socks4 proxy
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // tunnelled through a socks5 proxy
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}

// forward reads/writes through whichever proxy variant is active
trivial_impl!(
    MaybeSocks,
    (
        MaybeSocksProj::Clear,
        MaybeSocksProj::Socks4,
        MaybeSocksProj::Socks5
    )
);
263
pin_project! {
    // innermost layer: the raw socket. when a socks proxy is used this is
    // always the tcp connection to the proxy, not to the final target
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // a tcp socket
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // a unix domain socket
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}

// forward reads/writes to the tcp or unix socket
trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
280
/// a builder for [`Connection`]
///
/// start one with [`Connection::new_tcp`] or [`Connection::new_unix`],
/// configure it with the chained methods, then open the connection with
/// [`ConnectionBuilder::connect`]
#[derive(Debug)]
#[must_use = "this does nothing unless you finish building"]
pub struct ConnectionBuilder<'a> {
    // where to connect: a tcp target (which may hold an address conversion
    // error, reported at connect time) or a unix socket path
    base: BaseParams<'a>,
    // socks proxy to tunnel through, if any; only usable with a tcp base
    socks: Option<SocksParams<'a>>,
    // tls settings; `None` means the connection stays plaintext
    tls: Option<TlsParams>,
    // tls client certificate; only usable when `tls` is also set
    client_cert: Option<ClientCert>,
}
290
291impl<'a> ConnectionBuilder<'a> {
292    fn new(base: BaseParams<'a>) -> Self {
293        Self {
294            base,
295            socks: None,
296            tls: None,
297            client_cert: None,
298        }
299    }
300
301    fn socks(
302        mut self,
303        version: SocksVersion,
304        proxy: SocketAddr,
305        auth: Option<SocksAuth<'a>>,
306    ) -> Self {
307        self.socks = Some(SocksParams {
308            version,
309            proxy,
310            auth,
311        });
312        self
313    }
314
315    /// enable socks4 proxying
316    ///
317    /// ```
318    /// # use irc_connect::Connection;
319    /// # #[tokio::main]
320    /// # async fn main() {
321    /// # let builder = Connection::new_tcp("[::1]:6667");
322    /// let builder = builder.socks4("127.0.0.1:9050".parse().unwrap());
323    /// # }
324    /// ```
325    pub fn socks4(self, proxy: SocketAddr) -> Self {
326        self.socks(SocksVersion::Socks4, proxy, None)
327    }
328
329    /// enable socks4 proxying with a userid
330    ///
331    /// ```
332    /// # use irc_connect::Connection;
333    /// # #[tokio::main]
334    /// # async fn main() {
335    /// # let builder = Connection::new_tcp("[::1]:6667");
336    /// let builder = builder.socks4_with_userid("127.0.0.1:9050".parse().unwrap(), "meow");
337    /// # }
338    /// ```
339    pub fn socks4_with_userid(self, proxy: SocketAddr, userid: &'a str) -> Self {
340        self.socks(
341            SocksVersion::Socks4,
342            proxy,
343            Some(SocksAuth {
344                username: userid,
345                password: "h",
346            }),
347        )
348    }
349
350    /// enable socks5 proxying
351    ///
352    /// ```
353    /// # use irc_connect::Connection;
354    /// # #[tokio::main]
355    /// # async fn main() {
356    /// # let builder = Connection::new_tcp("[::1]:6667");
357    /// let builder = builder.socks5("127.0.0.1:9050".parse().unwrap());
358    /// # }
359    /// ```
360    pub fn socks5(self, proxy: SocketAddr) -> Self {
361        self.socks(SocksVersion::Socks5, proxy, None)
362    }
363
364    /// enable socks5 proxying with password authentication
365    ///
366    /// ```
367    /// # use irc_connect::Connection;
368    /// # #[tokio::main]
369    /// # async fn main() {
370    /// # let builder = Connection::new_tcp("[::1]:6667");
371    /// let builder =
372    ///     builder.socks5_with_password("127.0.0.1:9050".parse().unwrap(), "AzureDiamond", "hunter2");
373    /// # }
374    /// ```
375    pub fn socks5_with_password(
376        self,
377        proxy: SocketAddr,
378        username: &'a str,
379        password: &'a str,
380    ) -> Self {
381        self.socks(
382            SocksVersion::Socks5,
383            proxy,
384            Some(SocksAuth { username, password }),
385        )
386    }
387
388    fn tls(mut self, domain: Option<ServerName<'static>>, verification: TlsVerify) -> Self {
389        self.tls = Some(TlsParams {
390            domain,
391            verification,
392        });
393        self
394    }
395
396    /// enable tls without any verification
397    ///
398    /// ```
399    /// use tokio_rustls::rustls::pki_types::ServerName;
400    /// # use irc_connect::Connection;
401    /// # #[tokio::main]
402    /// # async fn main() {
403    /// # let builder = Connection::new_tcp("[::1]:6667");
404    /// let builder = builder.tls_danger_insecure(Some(ServerName::try_from("google.com").unwrap()));
405    /// # }
406    /// ```
407    pub fn tls_danger_insecure(self, domain: Option<ServerName<'static>>) -> Self {
408        self.tls(domain, TlsVerify::Insecure)
409    }
410
411    /// enable tls with root certificate verification
412    ///
413    /// can also be used to pin a self-signed cert as long as it has a `CA:FALSE` constraint
414    ///
415    /// ```no_run
416    /// use tokio_rustls::rustls::RootCertStore;
417    /// use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
418    ///
419    /// # use irc_connect::Connection;
420    /// # #[tokio::main]
421    /// # async fn main() {
422    /// # let builder = Connection::new_tcp("[::1]:6667");
423    /// let mut root = RootCertStore::empty();
424    /// root.add_parsable_certificates(
425    ///     CertificateDer::pem_file_iter("/etc/ssl/certs/ca-bundle.crt")
426    ///         .unwrap()
427    ///         .flatten(),
428    /// );
429    /// let builder = builder.tls_with_root(None, root);
430    /// # }
431    /// ```
432    pub fn tls_with_root(
433        self,
434        domain: Option<ServerName<'static>>,
435        root: impl Into<Arc<RootCertStore>>,
436    ) -> Self {
437        self.tls(domain, TlsVerify::CaStore(root.into()))
438    }
439
440    /// enable tls with a webpki verifier
441    pub fn tls_with_webpki(
442        self,
443        domain: Option<ServerName<'static>>,
444        webpki: Arc<WebPkiServerVerifier>,
445    ) -> Self {
446        self.tls(domain, TlsVerify::WebPki(webpki))
447    }
448
449    /// use a tls client certificate
450    ///
451    /// requires tls to be enabled
452    ///
453    /// ```no_run
454    /// use irc_connect::Connection;
455    /// use std::net::SocketAddr;
456    /// use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject};
457    ///
458    /// # #[tokio::main]
459    /// # async fn main() {
460    /// let builder = Connection::new_tcp("[::1]:6667").tls_danger_insecure(None);
461    /// let cert = CertificateDer::pem_file_iter("cert.pem")
462    ///     .unwrap()
463    ///     .collect::<Result<Vec<_>, _>>()
464    ///     .unwrap();
465    /// let key = PrivateKeyDer::from_pem_file("cert.key").unwrap();
466    /// let builder = builder.client_cert(cert, key);
467    /// # }
468    /// ```
469    pub fn client_cert(
470        mut self,
471        cert_chain: Vec<CertificateDer<'static>>,
472        key_der: PrivateKeyDer<'static>,
473    ) -> Self {
474        self.client_cert = Some(ClientCert {
475            cert_chain,
476            key_der,
477        });
478        self
479    }
480
481    /// finish building and open the connection
482    ///
483    /// ```no_run
484    /// # use irc_connect::Connection;
485    /// # #[tokio::main]
486    /// # async fn main() {
487    /// # let builder = Connection::new_tcp("[::1]:6667");
488    /// let stream = builder.connect().await.unwrap();
489    /// # }
490    /// ```
491    ///
492    /// # Errors
493    /// will return [`Error`] if an invalid combination of options has been
494    /// given to the builder, or if it is unable to connect
495    pub async fn connect(self) -> Result<Connection, Error> {
496        let tls = if let Some(mut params) = self.tls {
497            params.domain = params.domain.or_else(|| match &self.base {
498                BaseParams::Tcp(Ok(TargetAddr::Ip(addr))) => Some(ServerName::from(addr.ip())),
499                BaseParams::Tcp(Ok(TargetAddr::Domain(d, _))) => {
500                    ServerName::try_from(d.as_ref()).map(|s| s.to_owned()).ok()
501                }
502                _ => None,
503            });
504            Some(params)
505        } else {
506            None
507        };
508        let stream = if let Some(params) = self.socks {
509            let BaseParams::Tcp(target) = self.base else {
510                return Err(Error::SocksToUnsupported);
511            };
512            let target = target.map_err(Error::InvalidTarget)?;
513            let stream = BaseStream::Tcp {
514                inner: TcpStream::connect(params.proxy).await?,
515            };
516            match params.version {
517                SocksVersion::Socks4 => MaybeSocks::Socks4 {
518                    inner: if let Some(SocksAuth { username, .. }) = params.auth {
519                        Socks4Stream::connect_with_userid_and_socket(stream, target, username)
520                            .await?
521                    } else {
522                        Socks4Stream::connect_with_socket(stream, target).await?
523                    },
524                },
525                SocksVersion::Socks5 => MaybeSocks::Socks5 {
526                    inner: if let Some(SocksAuth { username, password }) = params.auth {
527                        Socks5Stream::connect_with_password_and_socket(
528                            stream, target, username, password,
529                        )
530                        .await?
531                    } else {
532                        Socks5Stream::connect_with_socket(stream, target).await?
533                    },
534                },
535            }
536        } else {
537            let stream = match self.base {
538                BaseParams::Tcp(addr) => {
539                    // FIXME: stick addr into connect directly, once tokio's ToSocketAddrs
540                    // stabilizes and TargetAddr implements it
541                    let inner = match addr.map_err(Error::InvalidTarget)? {
542                        TargetAddr::Ip(addr) => TcpStream::connect(addr).await?,
543                        TargetAddr::Domain(domain, port) => {
544                            TcpStream::connect((domain.as_ref(), port)).await?
545                        }
546                    };
547                    BaseStream::Tcp { inner }
548                }
549                BaseParams::Unix(path) => BaseStream::Unix {
550                    inner: UnixStream::connect(path).await?,
551                },
552            };
553            MaybeSocks::Clear { inner: stream }
554        };
555        let stream = if let Some(params) = tls {
556            let config = ClientConfig::builder();
557            let config = match params.verification {
558                TlsVerify::Insecure => {
559                    let provider = config.crypto_provider().clone();
560                    config
561                        .dangerous()
562                        .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
563                }
564                TlsVerify::CaStore(root) => config.with_root_certificates(root),
565                TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
566            };
567            let config = if let Some(ClientCert {
568                cert_chain,
569                key_der,
570            }) = self.client_cert
571            {
572                config.with_client_auth_cert(cert_chain, key_der)?
573            } else {
574                config.with_no_client_auth()
575            };
576            let connector = TlsConnector::from(Arc::new(config));
577            let domain = params.domain.ok_or(Error::NoServerName)?;
578            let inner = connector.connect(domain, stream).await?;
579            MaybeTls::Tls { inner }
580        } else {
581            if self.client_cert.is_some() {
582                return Err(Error::ClientCertNoTls);
583            }
584            MaybeTls::Plain { inner: stream }
585        };
586        Ok(Connection { inner: stream })
587    }
588}
589
/// where the base connection goes
#[derive(Debug)]
enum BaseParams<'a> {
    /// tcp target as produced by `IntoTargetAddr`; a conversion error is kept
    /// here and reported as `Error::InvalidTarget` when connecting
    Tcp(tokio_socks::Result<TargetAddr<'a>>),
    /// path of a unix domain socket
    Unix(&'a Path),
}
595
/// socks proxy settings recorded by the builder
struct SocksParams<'a> {
    /// which socks protocol to speak to the proxy
    version: SocksVersion,
    /// address of the proxy itself (not the final target)
    proxy: SocketAddr,
    /// credentials; for socks4 only `username` (the userid) is used
    auth: Option<SocksAuth<'a>>,
}

// hand-written rather than derived so credentials never reach debug output
// (and therefore logs), while the version and proxy address stay visible
impl fmt::Debug for SocksParams<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SocksParams")
            .field("version", &self.version)
            .field("proxy", &self.proxy)
            .field("auth", &self.auth.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

/// credentials sent to a socks proxy
struct SocksAuth<'a> {
    username: &'a str,
    password: &'a str,
}

/// the socks protocol version to speak to the proxy
#[derive(Debug)]
enum SocksVersion {
    Socks4,
    Socks5,
}
618
/// tls settings recorded by the builder
#[derive(Debug)]
struct TlsParams {
    /// server name handed to rustls for the handshake; `None` means it is
    /// guessed from the tcp target when connecting
    domain: Option<ServerName<'static>>,
    /// how the server's certificate is verified
    verification: TlsVerify,
}
624
/// how the server's certificate is verified
#[derive(Debug)]
enum TlsVerify {
    /// no verification (chosen by `tls_danger_insecure`); uses the verifier
    /// from the `danger` module
    Insecure,
    /// verify against these root certificates
    CaStore(Arc<RootCertStore>),
    /// verify with a caller-supplied webpki verifier
    WebPki(Arc<WebPkiServerVerifier>),
}
631
/// a tls client certificate chain and its private key
#[derive(Debug)]
struct ClientCert {
    cert_chain: Vec<CertificateDer<'static>>,
    // NOTE(review): the derived Debug includes this field; confirm that
    // `PrivateKeyDer`'s Debug impl redacts the key material
    key_der: PrivateKeyDer<'static>,
}