use buffer;
use buffer::{sync, Buffer};
use futures;
use futures::sync::{mpsc, oneshot};
use std;
use std::borrow::Cow;
use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio_io::io::{flush, read_exact, write_all, Flush, ReadExact, WriteAll};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_timer::Delay;
use {
Connecting, Connection, Error, HandleRow, Parameters, Request, Rows, SendRequest,
SendRequestErr, SendState, WriteFuture,
};
use bytes::buf::Buf;
use futures::{Async, AsyncSink, Future, Poll, Sink, Stream};
use tokio;
use tokio_openssl::{ConnectAsync, ConnectConfigurationExt, SslConnectorExt, SslStream};
/// Background task owning one database connection and serving the requests
/// that [`Handle`]s push into its channel.
///
/// It is driven as a `Future`: the connect / optional SSL / startup phases and
/// the per-query phases are all states stored in `state`.
struct DatabaseConnection<R, P>
where
    R: Request + HandleRow,
    P: AsRef<Parameters>,
{
    /// Current state of the state machine. `poll` takes it out while
    /// advancing it and is expected to always store a successor.
    state: Option<DbState<R, P>>,
    /// Connection parameters (address, keep-alive, SSL settings, idle timeout).
    parameters: P,
    /// Sending half of the request channel; every `Handle` gets a clone.
    sender: mpsc::UnboundedSender<Option<(R, oneshot::Sender<(R, Option<Error>)>)>>,
    /// Receiving half of the request channel. A `None` message (or the end
    /// of the stream) makes the task finish.
    receiver: mpsc::UnboundedReceiver<Option<(R, oneshot::Sender<(R, Option<Error>)>)>>,
    /// Serialized messages of the current query; kept so that a query can be
    /// replayed after a timeout-triggered reconnect.
    query_buffer: Vec<u8>,
    /// Parse messages produced while serializing the current query.
    parse_buffer: Vec<u8>,
    /// Request whose query reported `Error::Timeout`; it is re-sent once the
    /// connection has been re-established.
    leftover_oneshot: Option<(oneshot::Sender<(R, Option<Error>)>, R)>,
}
enum OptionSSL {
Plain(tokio::net::TcpStream),
SSL(SslStream<tokio::net::TcpStream>),
}
impl AsyncWrite for OptionSSL {
fn shutdown(&mut self) -> Poll<(), std::io::Error> {
match *self {
OptionSSL::Plain(ref mut p) => p.shutdown(),
OptionSSL::SSL(ref mut p) => p.shutdown(),
}
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, std::io::Error> {
match *self {
OptionSSL::Plain(ref mut p) => p.write_buf(buf),
OptionSSL::SSL(ref mut p) => p.write_buf(buf),
}
}
}
use std::io::{Read, Write};
/// Forwards reads to the wrapped stream.
impl Read for OptionSSL {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        match *self {
            OptionSSL::Plain(ref mut tcp) => Read::read(tcp, buf),
            OptionSSL::SSL(ref mut tls) => Read::read(tls, buf),
        }
    }
}

/// `AsyncRead` marker impl; no provided method is overridden here.
impl AsyncRead for OptionSSL {}
/// Forwards writes and flushes to the wrapped stream.
impl Write for OptionSSL {
    fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
        match *self {
            OptionSSL::Plain(ref mut tcp) => Write::write(tcp, buf),
            OptionSSL::SSL(ref mut tls) => Write::write(tls, buf),
        }
    }

    fn flush(&mut self) -> Result<(), std::io::Error> {
        match *self {
            OptionSSL::Plain(ref mut tcp) => Write::flush(tcp),
            OptionSSL::SSL(ref mut tls) => Write::flush(tls),
        }
    }
}
enum DbState<R: HandleRow, P: AsRef<Parameters>> {
Reconnecting(HashMap<Cow<'static, str>, usize>),
Connect(
tokio::net::tcp::ConnectFuture,
HashMap<Cow<'static, str>, usize>,
),
SSLInit(
WriteAll<TcpStream, &'static [u8]>,
HashMap<Cow<'static, str>, usize>,
),
SSLFlush(Flush<TcpStream>, HashMap<Cow<'static, str>, usize>),
SSLConfirm(
ReadExact<TcpStream, [u8; 1]>,
HashMap<Cow<'static, str>, usize>,
),
SSLConnect(ConnectAsync<TcpStream>, HashMap<Cow<'static, str>, usize>),
Init(Connecting<OptionSSL, P>, HashMap<Cow<'static, str>, usize>),
InitWrite(WriteFuture<OptionSSL, P>),
Connected {
connection: Connection<OptionSSL, P>,
timeout: Option<Delay>,
},
Write {
connection: WriteFuture<OptionSSL, P>,
query: R,
reply: oneshot::Sender<(R, Option<Error>)>,
},
Query {
rows: Rows<OptionSSL, R, P>,
reply: oneshot::Sender<(R, Option<Error>)>,
},
}
impl<R: Request + HandleRow, P: AsRef<Parameters> + Clone + Send> Future
for DatabaseConnection<R, P>
{
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
debug!("polling databaseconnection");
match self.state.take() {
Some(DbState::Reconnecting(parsed_queries)) => {
debug!("reconnecting");
self.state = Some(DbState::Connect(
TcpStream::connect(&self.parameters.as_ref().addr),
parsed_queries,
));
}
Some(DbState::Connect(mut stream, parsed_queries)) => {
debug!("connect");
match stream.poll() {
Ok(Async::Ready(stream)) => {
stream.set_keepalive(self.parameters.as_ref().tcp_keepalive)?;
if self.parameters.as_ref().ssl.is_some() {
debug!("sending {:?}", [0, 0, 0, 8, 4, 210, 22, 47]);
self.state = Some(DbState::SSLInit(
write_all(stream, &[0, 0, 0, 8, 4, 210, 22, 47]),
parsed_queries,
))
} else {
self.state = Some(DbState::Init(
Connection::new(
OptionSSL::Plain(stream),
self.parameters.clone(),
HashMap::new(),
),
parsed_queries,
))
}
}
Ok(Async::NotReady) => {
debug!("not ready");
self.state = Some(DbState::Connect(stream, parsed_queries));
return Ok(Async::NotReady);
}
Err(e) => {
error!("{:?}", e);
self.state = Some(DbState::Reconnecting(parsed_queries))
}
}
}
Some(DbState::SSLInit(mut w, parsed_queries)) => {
debug!("SSLInit");
if let Async::Ready((w, _)) = w.poll()? {
debug!("flush");
self.state = Some(DbState::SSLFlush(flush(w), parsed_queries))
} else {
debug!("still init");
self.state = Some(DbState::SSLInit(w, parsed_queries));
return Ok(Async::NotReady);
}
}
Some(DbState::SSLFlush(mut w, parsed_queries)) => {
debug!("SSLFlush");
if let Async::Ready(w) = w.poll()? {
debug!("confirm");
self.state =
Some(DbState::SSLConfirm(read_exact(w, [0; 1]), parsed_queries))
} else {
debug!("still flush");
self.state = Some(DbState::SSLFlush(w, parsed_queries));
return Ok(Async::NotReady);
}
}
Some(DbState::SSLConfirm(mut r, parsed_queries)) => {
debug!("SSLConfirm");
if let Async::Ready((r, buf)) = r.poll()? {
if buf[0] == b'S' {
let connect = if let Some(ref ssl) = self.parameters.as_ref().ssl {
if let Some(ref hostname) = ssl.hostname {
debug!("hostname {:?}", hostname);
ssl.config.connect_async(hostname, r)
} else {
debug!("no host name");
ssl.config
.configure()?
.verify_hostname(false)
.connect_async("", r)
}
} else {
unreachable!()
};
self.state = Some(DbState::SSLConnect(connect, parsed_queries))
} else {
debug!("buf = {:?}", buf);
return Err(Error::SSLDenied);
}
} else {
debug!("confirm not ready");
self.state = Some(DbState::SSLConfirm(r, parsed_queries));
return Ok(Async::NotReady);
}
}
Some(DbState::SSLConnect(mut w, parsed_queries)) => {
debug!("SSLConnect");
if let Async::Ready(w) = w.poll()? {
debug!("Connected");
self.state = Some(DbState::Init(
Connection::new(
OptionSSL::SSL(w),
self.parameters.clone(),
HashMap::new(),
),
parsed_queries,
))
} else {
debug!("SSLConnect not ready");
self.state = Some(DbState::SSLConnect(w, parsed_queries));
return Ok(Async::NotReady);
}
}
Some(DbState::Init(mut connected, parsed_queries)) => {
debug!("init");
match connected.poll() {
Ok(Async::Ready(mut connected)) => {
{
let mut buf = connected.buffer();
for (a, b) in parsed_queries.iter() {
buf.parse(&buffer::name(*b), a, &[]);
}
}
connected.parsed_queries = parsed_queries;
self.state = Some(DbState::InitWrite(connected.write()));
}
Ok(Async::NotReady) => {
self.state = Some(DbState::Init(connected, parsed_queries));
return Ok(Async::NotReady);
}
Err(e) => {
error!("{:?}", e);
self.state = Some(DbState::Reconnecting(parsed_queries))
}
}
}
Some(DbState::InitWrite(mut connected)) => {
debug!("initwrite");
match connected.poll() {
Ok(Async::Ready(connected)) => {
self.state = Some(DbState::Connected {
connection: connected,
timeout: if let Some(t) = self.parameters.as_ref().idle_timeout {
Some(Delay::new(std::time::Instant::now() + t))
} else {
None
},
})
}
Ok(Async::NotReady) => {
self.state = Some(DbState::InitWrite(connected));
return Ok(Async::NotReady);
}
Err(e) => {
error!("{:?}", e);
self.state = Some(DbState::Reconnecting(
connected.0.take().unwrap().parsed_queries,
))
}
}
}
Some(DbState::Connected {
mut connection,
mut timeout,
}) => {
debug!("connected");
if let Some((reply, query)) = self.leftover_oneshot.take() {
debug!("self.query_buffer = {:?}", self.query_buffer);
connection.buffer().clone_from_buf(&self.query_buffer);
self.state = Some(DbState::Write {
connection: connection.write(),
reply: reply,
query: query,
});
} else if connection.ready_for_query {
debug!("receiver polled");
match self.receiver.poll() {
Ok(Async::Ready(Some(Some((mut query, reply))))) => {
self.query_buffer.clear();
self.parse_buffer.clear();
debug!("ready ? {:?}", connection.ready_for_query);
query.request(&mut Buffer::buf(
&mut self.query_buffer,
&mut self.parse_buffer,
&mut connection.parsed_queries,
));
sync(&mut self.query_buffer);
{
let mut b = connection.buffer();
b.clone_from_buf(&self.parse_buffer);
b.extend(&self.query_buffer);
}
connection.ready_for_query = false;
self.state = Some(DbState::Write {
connection: connection.write(),
reply: reply,
query: query,
});
}
Ok(Async::Ready(_)) => return Ok(Async::Ready(())),
Ok(Async::NotReady) => {
let timeout_is_ready = if let Some(ref mut timeout) = timeout {
match timeout.poll() {
Ok(Async::Ready(_)) => true,
_ => false,
}
} else {
false
};
if timeout_is_ready {
self.state =
Some(DbState::Reconnecting(connection.parsed_queries));
} else {
self.state = Some(DbState::Connected {
connection,
timeout,
});
return Ok(Async::NotReady);
}
}
Err(e) => {
error!("{:?}", e);
self.state = Some(DbState::Reconnecting(connection.parsed_queries))
}
}
} else {
match connection.poll_connecting() {
Ok(Async::Ready(())) => {
self.state = Some(DbState::Connected {
connection,
timeout,
});
}
Ok(Async::NotReady) => {
self.state = Some(DbState::Connected {
connection,
timeout,
});
return Ok(Async::NotReady);
}
Err(e) => {
error!("error: {:?}", e)
}
}
}
}
Some(DbState::Write {
mut connection,
reply,
query,
}) => {
debug!("write");
match connection.poll() {
Ok(Async::Ready(connection)) => {
let timeout = connection.timeout();
self.state = Some(DbState::Query {
rows: connection.rows(query, Some(timeout))?,
reply: reply,
})
}
Ok(Async::NotReady) => {
self.state = Some(DbState::Write {
connection: connection,
query: query,
reply: reply,
});
return Ok(Async::NotReady);
}
Err(e) => {
error!("{:?}", e);
self.state = Some(DbState::Reconnecting(
connection.0.take().unwrap().parsed_queries,
))
}
}
}
Some(DbState::Query { reply, mut rows }) => {
debug!("polling rows");
match rows.poll() {
Ok(Async::Ready(((connection, query), err))) => {
if let Some(Error::Timeout) = err {
self.leftover_oneshot = Some((reply, query));
self.state = Some(DbState::Reconnecting(connection.parsed_queries));
} else {
if reply.send((query, err)).is_err() {
error!("Could not send result")
}
self.state = Some(DbState::Connected {
connection,
timeout: if let Some(t) = self.parameters.as_ref().idle_timeout
{
Some(Delay::new(std::time::Instant::now() + t))
} else {
None
},
});
}
continue;
}
Ok(Async::NotReady) => {}
Err(_) => {
unreachable!()
}
}
self.state = Some(DbState::Query {
rows: rows,
reply: reply,
});
return Ok(Async::NotReady);
}
None => panic!("None"),
}
}
}
}
impl<R, P> DatabaseConnection<R, P>
where
    R: Request + HandleRow,
    P: AsRef<Parameters> + Send,
{
    /// Creates a connection task for the parameters `p`.
    ///
    /// No socket is opened here. The task starts in `Reconnecting` with an
    /// empty parsed-queries map and connects when it is first polled.
    pub fn new(p: P) -> Self {
        let (sender, receiver) = mpsc::unbounded();
        let initial = DbState::Reconnecting(HashMap::new());
        DatabaseConnection {
            parameters: p,
            state: Some(initial),
            receiver,
            sender,
            query_buffer: Vec::new(),
            parse_buffer: Vec::new(),
            leftover_oneshot: None,
        }
    }

    /// Returns a new [`Handle`] that submits requests to this task.
    pub fn handle(&self) -> Handle<R> {
        let sender = self.sender.clone();
        Handle { sender }
    }
}
/// Cloneable handle for submitting requests to a running connection task.
pub struct Handle<R> {
    /// Sending half of the task's request channel. A `None` message asks
    /// the task to stop.
    sender: mpsc::UnboundedSender<Option<(R, oneshot::Sender<(R, Option<Error>)>)>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add an
// `R: Clone` bound, whereas only the channel sender needs cloning.
impl<R> Clone for Handle<R> {
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Handle { sender }
    }
}
impl<R: Request + HandleRow> Handle<R> {
pub fn send_request<F>(&self, r: F) -> SendRequest<R>
where
R: std::convert::From<F>,
{
SendRequest(self.send_request_err(r))
}
pub fn send_request_err<F>(&self, r: F) -> SendRequestErr<R>
where
R: std::convert::From<F>,
{
let (sender, receiver) = futures::sync::oneshot::channel();
SendRequestErr {
sender: self.sender.clone(),
receiver: receiver,
state: Some(SendState::Init(r.into(), sender)),
}
}
pub fn shutdown(&self) -> ShutdownFuture<R> {
debug!("shutting down send ?");
ShutdownFuture {
sender: self.sender.clone(),
shutdown: false,
}
}
}
pub struct ShutdownFuture<R> {
sender: futures::sync::mpsc::UnboundedSender<
Option<(R, futures::sync::oneshot::Sender<(R, Option<Error>)>)>,
>,
shutdown: bool,
}
impl<R> Future for ShutdownFuture<R> {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
debug!("shutting down sender");
loop {
if self.shutdown {
return Ok(self.sender.close()?);
} else {
match self.sender.start_send(None)? {
AsyncSink::NotReady(None) => return Ok(Async::NotReady),
AsyncSink::Ready => self.shutdown = true,
AsyncSink::NotReady(_) => unreachable!(),
}
}
}
}
}
pub fn spawn_connection<
R: HandleRow + Request + Send + 'static,
P: AsRef<Parameters> + Clone + Send + 'static,
>(
p: P,
) -> Result<Handle<R>, Error> {
let db = DatabaseConnection::new(p);
let db_handle = db.handle();
tokio::spawn(db.map_err(|err| println!("err {:?}", err)));
Ok(db_handle)
}