kvsd/client/
tcp.rs

1use std::net::ToSocketAddrs;
2use std::sync::Arc;
3use std::{convert::TryFrom, io};
4
5use async_trait::async_trait;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpStream;
8use tokio_rustls::{
9    client::TlsStream,
10    rustls::{
11        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
12        pki_types, SignatureScheme,
13    },
14};
15use tokio_rustls::{rustls, TlsConnector};
16
17use crate::client::Api;
18use crate::common::info;
19use crate::protocol::connection::Connection;
20use crate::protocol::message::{Authenticate, Delete, Get, Message, Ping, Set};
21use crate::protocol::{Key, Value};
22use crate::{KvsdError, Result};
23
24/// Implementation of client api by tcp.
25pub struct Client<T> {
26    connection: Connection<T>,
27}
28
29/// A client that is not authenticated by the server.
30/// it provide processing allowed to clients that are not authenticate.
31pub struct UnauthenticatedClient<T> {
32    client: Client<T>,
33}
34
35impl<T> UnauthenticatedClient<T>
36where
37    T: AsyncWrite + AsyncRead + Unpin,
38{
39    /// Construct Client by given stream.
40    pub fn new(stream: T) -> Self {
41        Self {
42            client: Client::new(stream),
43        }
44    }
45
46    /// Try authenticate by given credential.
47    pub async fn authenticate<S1, S2>(mut self, username: S1, password: S2) -> Result<Client<T>>
48    where
49        S1: Into<String>,
50        S2: Into<String>,
51    {
52        let authenticate = Authenticate::new(username.into(), password.into());
53        self.client.connection.write_message(authenticate).await?;
54        match self.client.connection.read_message().await? {
55            Some(Message::Success(_)) => Ok(self.client),
56            Some(Message::Fail(_)) => Err(KvsdError::Unauthenticated),
57            // format!(..).into() does not work :(
58            msg => Err(KvsdError::Internal(Box::<
59                dyn std::error::Error + Send + Sync,
60            >::from(format!(
61                "unexpected message {:?}",
62                msg
63            )))),
64        }
65    }
66}
67
68impl UnauthenticatedClient<TcpStream> {
69    /// Return client that is not protected by TLS for tcp communication.
70    pub async fn insecure_from_addr(host: impl AsRef<str>, port: u16) -> Result<Self> {
71        let addr = (host.as_ref(), port)
72            .to_socket_addrs()?
73            .next()
74            .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
75
76        info!(%addr, "Connecting");
77
78        let stream = tokio::net::TcpStream::connect(addr).await?;
79
80        Ok(UnauthenticatedClient::new(stream))
81    }
82}
83
84impl UnauthenticatedClient<TlsStream<TcpStream>> {
85    /// Return the client with a TLS connection to the given address.
86    pub async fn from_addr(host: impl Into<String>, port: u16) -> Result<Self> {
87        let host = host.into();
88        let addr = (host.as_str(), port)
89            .to_socket_addrs()?
90            .next()
91            .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
92        let tls_config = rustls::ClientConfig::builder()
93            .dangerous()
94            .with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier::new()))
95            .with_no_client_auth();
96
97        let connector = TlsConnector::from(Arc::new(tls_config));
98
99        // TODO: remove hard code
100        let domain = pki_types::ServerName::try_from("localhost")
101            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid host"))?
102            .to_owned();
103
104        info!(%addr,?domain, "Connecting");
105
106        let stream = tokio::net::TcpStream::connect(addr).await?;
107
108        Ok(UnauthenticatedClient::new(
109            connector.connect(domain, stream).await?,
110        ))
111    }
112}
113
114impl<T> Client<T>
115where
116    T: AsyncWrite + AsyncRead + Unpin,
117{
118    fn new(stream: T) -> Self {
119        Self {
120            connection: Connection::new(stream, Some(1024 * 4)),
121        }
122    }
123}
124
125#[async_trait]
126impl<T> Api for Client<T>
127where
128    T: AsyncWrite + AsyncRead + Unpin + Send,
129{
130    // Return ping latency.
131    async fn ping(&mut self) -> Result<chrono::Duration> {
132        let ping = Ping::new().record_client_time();
133        self.connection.write_message(ping).await?;
134        match self.connection.read_message().await? {
135            Some(Message::Ping(ping)) => Ok(ping.latency().unwrap()),
136            msg => Err(format!("unexpected message {:?}", msg).into()),
137        }
138    }
139
140    async fn set(&mut self, key: Key, value: Value) -> Result<()> {
141        let set = Set::new(key, value);
142        self.connection.write_message(set).await?;
143        match self.connection.read_message().await? {
144            Some(Message::Success(_)) => Ok(()),
145            msg => Err(KvsdError::Internal(Box::<
146                dyn std::error::Error + Send + Sync,
147            >::from(format!(
148                "unexpected message: {:?}",
149                msg
150            )))),
151        }
152    }
153
154    async fn get(&mut self, key: Key) -> Result<Option<Value>> {
155        let get = Get::new(key);
156        self.connection.write_message(get).await?;
157        match self.connection.read_message().await? {
158            Some(Message::Success(success)) => Ok(success.value()),
159            _ => unreachable!(),
160        }
161    }
162
163    async fn delete(&mut self, key: Key) -> Result<Option<Value>> {
164        let delete = Delete::new(key);
165        self.connection.write_message(delete).await?;
166        match self.connection.read_message().await? {
167            Some(Message::Success(success)) => Ok(success.value()),
168            _ => unreachable!(),
169        }
170    }
171}
172
/// TLS certificate verifier that accepts every server certificate and
/// handshake signature without checking them.
///
/// SECURITY: using this disables server authentication entirely. Any peer
/// that can intercept the connection can impersonate the server.
#[derive(Debug)]
struct DangerousServerCertVerifier;

impl DangerousServerCertVerifier {
    /// Create the verifier. It carries no state.
    fn new() -> Self {
        DangerousServerCertVerifier
    }
}
181
182impl ServerCertVerifier for DangerousServerCertVerifier {
183    fn verify_server_cert(
184        &self,
185        _end_entity: &pki_types::CertificateDer<'_>,
186        _intermediates: &[pki_types::CertificateDer<'_>],
187        _server_name: &pki_types::ServerName<'_>,
188        _ocsp_response: &[u8],
189        _now: pki_types::UnixTime,
190    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
191        Ok(ServerCertVerified::assertion())
192    }
193
194    fn verify_tls12_signature(
195        &self,
196        _message: &[u8],
197        _cert: &pki_types::CertificateDer<'_>,
198        _dss: &rustls::DigitallySignedStruct,
199    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
200        Ok(HandshakeSignatureValid::assertion())
201    }
202
203    fn verify_tls13_signature(
204        &self,
205        _message: &[u8],
206        _cert: &pki_types::CertificateDer<'_>,
207        _dss: &rustls::DigitallySignedStruct,
208    ) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
209    {
210        Ok(HandshakeSignatureValid::assertion())
211    }
212
213    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
214        vec![
215            SignatureScheme::RSA_PKCS1_SHA1,
216            SignatureScheme::RSA_PKCS1_SHA256,
217            SignatureScheme::RSA_PKCS1_SHA384,
218            SignatureScheme::RSA_PKCS1_SHA512,
219            SignatureScheme::ECDSA_NISTP256_SHA256,
220            SignatureScheme::ECDSA_NISTP384_SHA384,
221            SignatureScheme::ECDSA_NISTP521_SHA512,
222            SignatureScheme::ED25519,
223            SignatureScheme::RSA_PSS_SHA256,
224            SignatureScheme::RSA_PSS_SHA384,
225            SignatureScheme::RSA_PSS_SHA512,
226        ]
227    }
228}