1use std::net::ToSocketAddrs;
2use std::sync::Arc;
3use std::{convert::TryFrom, io};
4
5use async_trait::async_trait;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpStream;
8use tokio_rustls::{
9 client::TlsStream,
10 rustls::{
11 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
12 pki_types, SignatureScheme,
13 },
14};
15use tokio_rustls::{rustls, TlsConnector};
16
17use crate::client::Api;
18use crate::common::info;
19use crate::protocol::connection::Connection;
20use crate::protocol::message::{Authenticate, Delete, Get, Message, Ping, Set};
21use crate::protocol::{Key, Value};
22use crate::{KvsdError, Result};
23
/// Client handle returned by [`UnauthenticatedClient::authenticate`] on success.
///
/// The key/value operations are exposed through its [`Api`] implementation.
pub struct Client<T> {
    /// Connection used to write requests and read replies as protocol `Message`s.
    connection: Connection<T>,
}
28
/// Connected client that has not yet authenticated.
///
/// Call [`UnauthenticatedClient::authenticate`] to exchange it for a [`Client`].
pub struct UnauthenticatedClient<T> {
    /// Inner client; only handed to the caller once authentication succeeds.
    client: Client<T>,
}
34
35impl<T> UnauthenticatedClient<T>
36where
37 T: AsyncWrite + AsyncRead + Unpin,
38{
39 pub fn new(stream: T) -> Self {
41 Self {
42 client: Client::new(stream),
43 }
44 }
45
46 pub async fn authenticate<S1, S2>(mut self, username: S1, password: S2) -> Result<Client<T>>
48 where
49 S1: Into<String>,
50 S2: Into<String>,
51 {
52 let authenticate = Authenticate::new(username.into(), password.into());
53 self.client.connection.write_message(authenticate).await?;
54 match self.client.connection.read_message().await? {
55 Some(Message::Success(_)) => Ok(self.client),
56 Some(Message::Fail(_)) => Err(KvsdError::Unauthenticated),
57 msg => Err(KvsdError::Internal(Box::<
59 dyn std::error::Error + Send + Sync,
60 >::from(format!(
61 "unexpected message {:?}",
62 msg
63 )))),
64 }
65 }
66}
67
68impl UnauthenticatedClient<TcpStream> {
69 pub async fn insecure_from_addr(host: impl AsRef<str>, port: u16) -> Result<Self> {
71 let addr = (host.as_ref(), port)
72 .to_socket_addrs()?
73 .next()
74 .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
75
76 info!(%addr, "Connecting");
77
78 let stream = tokio::net::TcpStream::connect(addr).await?;
79
80 Ok(UnauthenticatedClient::new(stream))
81 }
82}
83
84impl UnauthenticatedClient<TlsStream<TcpStream>> {
85 pub async fn from_addr(host: impl Into<String>, port: u16) -> Result<Self> {
87 let host = host.into();
88 let addr = (host.as_str(), port)
89 .to_socket_addrs()?
90 .next()
91 .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
92 let tls_config = rustls::ClientConfig::builder()
93 .dangerous()
94 .with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier::new()))
95 .with_no_client_auth();
96
97 let connector = TlsConnector::from(Arc::new(tls_config));
98
99 let domain = pki_types::ServerName::try_from("localhost")
101 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid host"))?
102 .to_owned();
103
104 info!(%addr,?domain, "Connecting");
105
106 let stream = tokio::net::TcpStream::connect(addr).await?;
107
108 Ok(UnauthenticatedClient::new(
109 connector.connect(domain, stream).await?,
110 ))
111 }
112}
113
114impl<T> Client<T>
115where
116 T: AsyncWrite + AsyncRead + Unpin,
117{
118 fn new(stream: T) -> Self {
119 Self {
120 connection: Connection::new(stream, Some(1024 * 4)),
121 }
122 }
123}
124
125#[async_trait]
126impl<T> Api for Client<T>
127where
128 T: AsyncWrite + AsyncRead + Unpin + Send,
129{
130 async fn ping(&mut self) -> Result<chrono::Duration> {
132 let ping = Ping::new().record_client_time();
133 self.connection.write_message(ping).await?;
134 match self.connection.read_message().await? {
135 Some(Message::Ping(ping)) => Ok(ping.latency().unwrap()),
136 msg => Err(format!("unexpected message {:?}", msg).into()),
137 }
138 }
139
140 async fn set(&mut self, key: Key, value: Value) -> Result<()> {
141 let set = Set::new(key, value);
142 self.connection.write_message(set).await?;
143 match self.connection.read_message().await? {
144 Some(Message::Success(_)) => Ok(()),
145 msg => Err(KvsdError::Internal(Box::<
146 dyn std::error::Error + Send + Sync,
147 >::from(format!(
148 "unexpected message: {:?}",
149 msg
150 )))),
151 }
152 }
153
154 async fn get(&mut self, key: Key) -> Result<Option<Value>> {
155 let get = Get::new(key);
156 self.connection.write_message(get).await?;
157 match self.connection.read_message().await? {
158 Some(Message::Success(success)) => Ok(success.value()),
159 _ => unreachable!(),
160 }
161 }
162
163 async fn delete(&mut self, key: Key) -> Result<Option<Value>> {
164 let delete = Delete::new(key);
165 self.connection.write_message(delete).await?;
166 match self.connection.read_message().await? {
167 Some(Message::Success(success)) => Ok(success.value()),
168 _ => unreachable!(),
169 }
170 }
171}
172
/// `ServerCertVerifier` that accepts every server certificate and every handshake
/// signature without checking them.
///
/// A TLS connection configured with it is encrypted but the server is NOT
/// authenticated, so it offers no protection against man-in-the-middle attacks.
#[derive(Debug)]
struct DangerousServerCertVerifier {}
175
176impl DangerousServerCertVerifier {
177 fn new() -> Self {
178 Self {}
179 }
180}
181
182impl ServerCertVerifier for DangerousServerCertVerifier {
183 fn verify_server_cert(
184 &self,
185 _end_entity: &pki_types::CertificateDer<'_>,
186 _intermediates: &[pki_types::CertificateDer<'_>],
187 _server_name: &pki_types::ServerName<'_>,
188 _ocsp_response: &[u8],
189 _now: pki_types::UnixTime,
190 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
191 Ok(ServerCertVerified::assertion())
192 }
193
194 fn verify_tls12_signature(
195 &self,
196 _message: &[u8],
197 _cert: &pki_types::CertificateDer<'_>,
198 _dss: &rustls::DigitallySignedStruct,
199 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
200 Ok(HandshakeSignatureValid::assertion())
201 }
202
203 fn verify_tls13_signature(
204 &self,
205 _message: &[u8],
206 _cert: &pki_types::CertificateDer<'_>,
207 _dss: &rustls::DigitallySignedStruct,
208 ) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
209 {
210 Ok(HandshakeSignatureValid::assertion())
211 }
212
213 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
214 vec![
215 SignatureScheme::RSA_PKCS1_SHA1,
216 SignatureScheme::RSA_PKCS1_SHA256,
217 SignatureScheme::RSA_PKCS1_SHA384,
218 SignatureScheme::RSA_PKCS1_SHA512,
219 SignatureScheme::ECDSA_NISTP256_SHA256,
220 SignatureScheme::ECDSA_NISTP384_SHA384,
221 SignatureScheme::ECDSA_NISTP521_SHA512,
222 SignatureScheme::ED25519,
223 SignatureScheme::RSA_PSS_SHA256,
224 SignatureScheme::RSA_PSS_SHA384,
225 SignatureScheme::RSA_PSS_SHA512,
226 ]
227 }
228}