irc3/client/
conn.rs

1use std::io::Error as IoError;
2use std::pin::Pin;
3use std::borrow::Cow;
4
5use futures::io::{AsyncRead, AsyncWrite};
6use futures::task::{Context, Poll};
7
8use async_std::net::{TcpStream};
9
10#[cfg(feature = "tls")]
11use async_tls::TlsConnector;
12#[cfg(feature = "tls")]
13use async_tls::client::TlsStream;
14
/// Ensures `dest` carries an explicit port, appending `:<port>` when none is present.
///
/// A destination counts as already having a port when it contains a `:`. Bracketed IPv6
/// literals are the exception: in `[::1]` the colons belong to the address itself, so only a
/// `:` directly after the closing `]` (as in `[::1]:6697`) is treated as a port separator.
///
/// The input is handed back borrowed when nothing has to be appended; an owned string is only
/// allocated when the default port is added.
fn normalize<'a, S: AsRef<str> + ?Sized>(dest: &'a S, port: &str) -> Cow<'a, str> {
	let dest = dest.as_ref();

	let has_port = if dest.starts_with('[') {
		match dest.rfind(']') {
			// `]` is a single byte, so `close + 1` is always a valid slice boundary
			Some(close) => dest[close + 1..].starts_with(':'),
			// unterminated bracket: fall back to the plain "any colon" rule
			None => dest.contains(':'),
		}
	} else {
		dest.contains(':')
	};

	if has_port {
		Cow::Borrowed(dest)
	} else {
		let mut owned = String::with_capacity(dest.len() + 1 + port.len());
		owned.push_str(dest);
		owned.push(':');
		owned.push_str(port);
		Cow::Owned(owned)
	}
}
26
27/// The underlying Connection of a client.
28pub enum Connection {
29	Unsecure(TcpStream),
30	#[cfg(feature = "tls")]
31	Secure(TlsStream<TcpStream>),
32}
33
34impl Connection {
35	/// Creates an unsecured (plaintext) connection to an IRC server.
36	///
37	/// The function expects a resolvable IP / DOMAIN in the form `<ip>[:<port>]` where `<port>` is
38	/// optional.
39	pub async fn unsecure<S: AsRef<str> + ?Sized>(dest: &S) -> Result<Connection, IoError> {
40		Ok(Connection::Unsecure(TcpStream::connect(normalize(dest, "6667").as_ref()).await?))
41	}
42
43	/// Creates a secured (TLS) connection to an IRC server.
44	///
45	/// The function expects a resolvable DOMAIN in the form `<domain>[:<port>]` where `<port>` is
46	/// optional.
47	#[cfg(feature = "tls")]
48	pub async fn secure<S: AsRef<str> + ?Sized>(dest: &S) -> Result<Connection, IoError> {
49		// instantiate a tlsconnector and dest
50		let dest = normalize(dest, "6697");
51		let connector = TlsConnector::default();
52		
53		// build tcp stream
54		let tcp = TcpStream::connect(dest.as_ref()).await?;
55
56		// attempt a TLS handshake
57		Ok(Connection::Secure(connector.connect(dest.as_ref().split(':').next().unwrap(), tcp)?.await?))
58	}
59}
60
61impl AsyncRead for Connection {
62	fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, IoError>> {
63		match *self {
64			Connection::Unsecure(ref mut st) => Pin::new(st).poll_read(cx, buf),
65			#[cfg(feature = "tls")]
66			Connection::Secure(ref mut sts) => Pin::new(sts).poll_read(cx, buf),
67		}
68	}
69}
70
71impl AsyncWrite for Connection {
72	fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, IoError>> {
73		match *self {
74			Connection::Unsecure(ref mut st) => Pin::new(st).poll_write(cx, buf),
75			#[cfg(feature = "tls")]
76			Connection::Secure(ref mut sts) => Pin::new(sts).poll_write(cx, buf),
77		}
78
79	}
80
81	fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
82		match *self {
83			Connection::Unsecure(ref mut st) => Pin::new(st).poll_flush(cx),
84			#[cfg(feature = "tls")]
85			Connection::Secure(ref mut sts) => Pin::new(sts).poll_flush(cx),
86		}
87	}
88
89	fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
90		match *self {
91			Connection::Unsecure(ref mut st) => Pin::new(st).poll_close(cx),
92			#[cfg(feature = "tls")]
93			Connection::Secure(ref mut sts) => Pin::new(sts).poll_close(cx),
94		}
95	}
96}