1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/*
 * Created on Wed May 05 2021
 *
 * Copyright (c) 2021 Sayan Nandan <nandansayan@outlook.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *    http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
*/

//! # Asynchronous database connections
//!
//! This module provides async interfaces for database connections. There are two versions:
//! - The [`Connection`]: a connection to the database over Skyhash/TCP
//! - The [`TlsConnection`]: a connection to the database over Skyhash/TLS
//!
//! All the [async actions][crate::actions::AsyncActions] can be used with either connection type
//!

use crate::deserializer::{ParseError, Parser, RawResponse};
use crate::error::SkyhashError;
use crate::IoResult;
use crate::Query;
use crate::SkyResult;
use bytes::{Buf, BytesMut};
use std::io::{Error as IoError, ErrorKind};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;

/// Initial capacity, in bytes, of every connection's read buffer (4 KiB)
const BUF_CAP: usize = 4 * 1024;

/// Implements the query/response cycle and [`crate::actions::AsyncSocket`] for a connection
/// type that has a `stream` field (readable and writable) and a `buffer: BytesMut` field
macro_rules! impl_async_methods {
    ($ty:ty) => {
        impl $ty {
            /// Write a [`Query`] to the stream, flush it, and read the server's response.
            ///
            /// Bytes are read into the connection's buffer until the parser reports a complete
            /// response. That response is then consumed from the buffer and returned.
            ///
            /// ## Errors
            /// - I/O errors raised while writing, flushing or reading are propagated
            /// - an I/O error of kind [`ErrorKind::ConnectionReset`] if the stream reaches EOF
            ///   before a complete response has been received
            /// - [`SkyhashError::InvalidResponse`], [`SkyhashError::ParseError`] or
            ///   [`SkyhashError::UnknownDataType`] if the response cannot be parsed. In these
            ///   cases the buffered bytes are discarded so they cannot be mistaken for the
            ///   beginning of the next response.
            ///
            /// ## Panics
            /// - if the [`Query`] supplied is empty (i.e has no arguments)
            /// - if the server replies with a pipelined response, which is not supported yet
            pub async fn run_simple_query(&mut self, query: &Query) -> SkyResult {
                assert!(query.len() != 0, "A `Query` cannot be of zero length!");
                query.write_query_to(&mut self.stream).await?;
                self.stream.flush().await?;
                loop {
                    if self.stream.read_buf(&mut self.buffer).await? == 0 {
                        // EOF: the peer closed the connection mid-response
                        return Err(IoError::from(ErrorKind::ConnectionReset).into());
                    }
                    match self.try_response() {
                        Ok((response, forward_by)) => {
                            // drop exactly the bytes that made up this response
                            self.buffer.advance(forward_by);
                            match response {
                                RawResponse::SimpleQuery(s) => return Ok(s),
                                RawResponse::PipelinedQuery(_) => {
                                    unimplemented!("Pipelined queries aren't implemented yet")
                                }
                            }
                        }
                        Err(e) => match e {
                            // incomplete response: go round the loop and read more bytes
                            ParseError::NotEnough => (),
                            ParseError::BadPacket | ParseError::UnexpectedByte => {
                                self.buffer.clear();
                                return Err(SkyhashError::InvalidResponse.into());
                            }
                            ParseError::DataTypeError => {
                                // discard the rejected bytes, otherwise they would be parsed
                                // as the prefix of the next query's response
                                self.buffer.clear();
                                return Err(SkyhashError::ParseError.into());
                            }
                            ParseError::Empty => {
                                return Err(IoError::from(ErrorKind::ConnectionReset).into());
                            }
                            ParseError::UnknownDatatype => {
                                // same reasoning as above: never leave a rejected response
                                // sitting in the buffer
                                self.buffer.clear();
                                return Err(SkyhashError::UnknownDataType.into());
                            }
                        },
                    }
                }
            }
            /// Subroutine of `run_simple_query`: try to parse one response from the bytes
            /// currently buffered.
            ///
            /// On success, returns the parsed response together with the number of buffered
            /// bytes to consume. Returns `ParseError::Empty` if the buffer holds no bytes.
            fn try_response(&mut self) -> Result<(RawResponse, usize), ParseError> {
                if self.buffer.is_empty() {
                    // The connection was possibly reset
                    return Err(ParseError::Empty);
                }
                Parser::new(&self.buffer).parse()
            }
        }
        impl crate::actions::AsyncSocket for $ty {
            /// Run `q` as a simple query (see `run_simple_query`)
            fn run(&mut self, q: Query) -> crate::actions::AsyncResult<SkyResult> {
                Box::pin(async move { self.run_simple_query(&q).await })
            }
        }
    };
}

cfg_async!(
    /// An asynchronous connection to a Skytable database over Skyhash/TCP
    pub struct Connection {
        /// The TCP stream, wrapped in a write buffer that is drained on flush
        stream: BufWriter<TcpStream>,
        /// Bytes received from the server that have not been consumed as a response yet
        buffer: BytesMut,
    }

    impl Connection {
        /// Connect to the Skytable instance listening on `host`:`port`
        ///
        /// Any I/O error raised while establishing the TCP connection is returned as-is
        pub async fn new(host: &str, port: u16) -> IoResult<Self> {
            let tcp = TcpStream::connect((host, port)).await?;
            let stream = BufWriter::new(tcp);
            let buffer = BytesMut::with_capacity(BUF_CAP);
            Ok(Self { stream, buffer })
        }
    }
    impl_async_methods!(Connection);
);

cfg_async_ssl_any!(
    use tokio_openssl::SslStream;
    use openssl::ssl::{SslContext, SslMethod, SslVerifyMode, Ssl};
    use core::pin::Pin;
    use crate::error::SslError;

    /// An asynchronous database connection over Skyhash/TLS
    pub struct TlsConnection {
        /// The TLS stream, wrapped in a write buffer so that a query is handed to the TLS
        /// layer in as few writes (and hence TLS records) as possible before flushing
        stream: BufWriter<SslStream<TcpStream>>,
        /// Bytes received from the server that have not been consumed as a response yet
        buffer: BytesMut,
    }

    impl TlsConnection {
        /// Pass the `host` and `port` and the path to the CA certificate to use for TLS
        ///
        /// The server's certificate chain is verified against the CA certificate(s) loaded
        /// from `sslcert`, and the handshake fails if that verification fails.
        ///
        /// ## Errors
        /// Returns an error if the TLS context cannot be built, the CA file cannot be loaded,
        /// the TCP connection cannot be established or the TLS handshake fails.
        pub async fn new(host: &str, port: u16, sslcert: &str) -> Result<Self, SslError> {
            let mut ctx = SslContext::builder(SslMethod::tls_client())?;
            ctx.set_ca_file(sslcert)?;
            // A bare `SslContext` keeps OpenSSL's default of not verifying the peer, which
            // would make the CA file above meaningless and accept any server certificate
            ctx.set_verify(SslVerifyMode::PEER);
            // NOTE(review): the certificate is not checked against `host` (no hostname
            // verification is configured); consider adding it once callers are known to
            // connect by a name present in the server certificate
            let ssl = Ssl::new(&ctx.build())?;
            let stream = TcpStream::connect((host, port)).await?;
            let mut stream = SslStream::new(ssl, stream)?;
            Pin::new(&mut stream).connect().await?;
            Ok(Self {
                stream: BufWriter::new(stream),
                buffer: BytesMut::with_capacity(BUF_CAP),
            })
        }
    }
    impl_async_methods!(TlsConnection);
);