#![deny(trivial_casts, unstable_features, unused_import_braces)]
#![recursion_limit = "128"]
extern crate futures;
extern crate tokio;
#[macro_use]
extern crate log;
extern crate byteorder;
extern crate chrono;
extern crate digest;
extern crate md5;
extern crate uuid;
#[macro_use]
extern crate serde_derive;
extern crate bytes;
#[cfg(test)]
extern crate env_logger;
extern crate openssl;
extern crate tokio_io;
extern crate tokio_openssl;
extern crate tokio_timer;
extern crate pleingres_macros;
pub use pleingres_macros::{HandleRowJoin, Request, sql};
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use std::collections::HashMap;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
use tokio_io::io::{read_exact, ReadExact, ReadHalf, WriteHalf};
use tokio_io::{AsyncRead, AsyncWrite};
use futures::{Async, Future, Sink};
/// Major number of the wire-protocol version written into the startup
/// message by `Connection::new` (sent as a big-endian `u16`).
const MAJOR_VERSION: u16 = 0x0003;
/// Minor number of the wire-protocol version written right after
/// `MAJOR_VERSION` in the startup message.
const MINOR_VERSION: u16 = 0x0000;
mod statement;
pub use statement::*;
mod rows;
pub use rows::*;
#[allow(missing_docs)]
mod msg;
use msg::*;
mod buffer;
pub use buffer::*;
mod spawned;
pub use spawned::*;
/// Errors returned by this crate.
///
/// Most variants wrap a lower-level error and are produced automatically
/// by the `From` impls in this file (so `?` converts them).
#[derive(Debug)]
pub enum Error {
/// I/O error on the underlying stream or during address resolution.
Io(::std::io::Error),
/// A byte string (e.g. a field of a server error message) was not valid UTF-8.
Utf8(::std::str::Utf8Error),
/// An integer failed to parse.
Parse(::std::num::ParseIntError),
/// Error reported by OpenSSL's TLS layer.
Openssl(::openssl::ssl::Error),
/// OpenSSL error stack, e.g. from building the TLS connector or loading
/// the CA / private-key files in `ConfigFile::to_parameters`.
OpensslStack(::openssl::error::ErrorStack),
/// A `futures` oneshot receiver was canceled, e.g. while a request waits
/// for its reply in `SendRequestErr`.
Canceled(::futures::Canceled),
/// The server sent data that could not be parsed as a protocol message.
Protocol,
/// The supplied network address is invalid.
NoAddr,
/// An error reported by the server in an ErrorResponse message,
/// classified by its SQLSTATE code.
Postgres(PostgresError),
/// The connection timed out.
Timeout,
/// Catch-all error without details; produced from `()` and from a failed
/// send on a `futures` mpsc channel.
Unit,
/// The SSL connection was denied.
SSLDenied,
}
/// Error reported by the server, classified by the SQLSTATE code of the
/// ErrorResponse (see the `error` parser).
///
/// Every variant carries the full server message: the values of all
/// ErrorResponse fields, one per line.
#[derive(Debug)]
pub enum PostgresError {
/// SQLSTATE `23000`.
IntegrityConstraintViolation(String),
/// SQLSTATE `23001`.
RestrictViolation(String),
/// SQLSTATE `23502`.
NotNullViolation(String),
/// SQLSTATE `23503`.
ForeignKeyViolation(String),
/// SQLSTATE `23505`.
UniqueViolation(String),
/// SQLSTATE `23514`.
CheckViolation(String),
/// SQLSTATE `23P01`.
ExclusionViolation(String),
/// Any other (or missing) SQLSTATE code.
Other { message: String },
}
impl std::error::Error for Error {
fn description(&self) -> &str {
use Error::*;
match *self {
Io(ref e) => e.description(),
Utf8(ref e) => e.description(),
Parse(ref e) => e.description(),
Openssl(ref e) => e.description(),
OpensslStack(ref e) => e.description(),
Canceled(ref e) => e.description(),
Protocol => "Postgres protocol error",
NoAddr => "The supplied network address is invalid",
Postgres(ref e) => e.description(),
Timeout => "The connection timed out",
Unit => "Other error",
SSLDenied => "SSL connection denied",
}
}
fn cause(&self) -> Option<&std::error::Error> {
use Error::*;
match *self {
Io(ref e) => Some(e),
Utf8(ref e) => Some(e),
Parse(ref e) => Some(e),
Openssl(ref e) => Some(e),
OpensslStack(ref e) => Some(e),
Canceled(ref e) => Some(e),
_ => None,
}
}
}
impl std::error::Error for PostgresError {
fn description(&self) -> &str {
match *self {
PostgresError::IntegrityConstraintViolation(ref a) => a,
PostgresError::RestrictViolation(ref a) => a,
PostgresError::NotNullViolation(ref a) => a,
PostgresError::ForeignKeyViolation(ref a) => a,
PostgresError::UniqueViolation(ref a) => a,
PostgresError::CheckViolation(ref a) => a,
PostgresError::ExclusionViolation(ref a) => a,
PostgresError::Other { ref message } => message,
}
}
fn cause(&self) -> Option<&std::error::Error> {
None
}
}
impl std::fmt::Display for Error {
    /// Delegates to the wrapped error when there is one; otherwise writes a
    /// fixed message for the variant.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        use Error::*;
        let fixed_message = match *self {
            Io(ref e) => return e.fmt(fmt),
            Utf8(ref e) => return e.fmt(fmt),
            Parse(ref e) => return e.fmt(fmt),
            Openssl(ref e) => return e.fmt(fmt),
            OpensslStack(ref e) => return e.fmt(fmt),
            Canceled(ref e) => return e.fmt(fmt),
            Postgres(ref e) => return e.fmt(fmt),
            Protocol => "Postgres protocol error",
            NoAddr => "The supplied network address is invalid",
            Timeout => "The connection timed out",
            Unit => "Other error",
            SSLDenied => "SSL connection denied",
        };
        fmt.write_str(fixed_message)
    }
}
impl std::fmt::Display for PostgresError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
use std::error::Error;
self.description().fmt(fmt)
}
}
impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
}
}
impl From<openssl::error::ErrorStack> for Error {
fn from(e: openssl::error::ErrorStack) -> Self {
Error::OpensslStack(e)
}
}
impl From<openssl::ssl::Error> for Error {
fn from(e: openssl::ssl::Error) -> Self {
Error::Openssl(e)
}
}
impl From<std::num::ParseIntError> for Error {
fn from(e: std::num::ParseIntError) -> Self {
Error::Parse(e)
}
}
impl From<std::str::Utf8Error> for Error {
fn from(e: std::str::Utf8Error) -> Self {
Error::Utf8(e)
}
}
impl From<futures::Canceled> for Error {
fn from(e: futures::Canceled) -> Self {
Error::Canceled(e)
}
}
unsafe impl Send for Error {}
/// Converts a `()` error into the detail-less `Error::Unit`.
impl std::convert::From<()> for Error {
    fn from(_: ()) -> Self {
        Error::Unit
    }
}
/// Converts a failed send on a `futures` mpsc channel into `Error::Unit`.
/// The unsent message carried by the `SendError` is dropped.
impl<E> std::convert::From<futures::sync::mpsc::SendError<E>> for Error {
    fn from(_: futures::sync::mpsc::SendError<E>) -> Self {
        // Already an `Error`; the former `.into()` was an identity conversion.
        Error::Unit
    }
}
/// Extension for writing protocol strings (NUL-terminated byte strings).
trait WritePostgresExt {
    /// Writes `s` followed by a terminating NUL byte.
    ///
    /// # Panics
    ///
    /// Panics if `s` itself contains a NUL byte, which would terminate the
    /// string early on the wire.
    fn write_string(&mut self, s: &[u8]) -> std::io::Result<()>;
}
impl<W: Write> WritePostgresExt for W {
    fn write_string(&mut self, s: &[u8]) -> std::io::Result<()> {
        assert!(s.iter().all(|&c| c != 0));
        // `write` may accept only part of the buffer; `write_all` retries
        // until everything is written, so the string is never truncated.
        self.write_all(s)?;
        self.write_all(&[0])?;
        Ok(())
    }
}
/// Serializable connection configuration; convert it with
/// `ConfigFile::to_parameters` before connecting.
#[derive(Default, Serialize, Deserialize)]
pub struct ConfigFile {
/// User name sent in the startup message.
pub user: String,
/// Password used for authentication.
pub password: String,
/// Database name.
pub db: String,
/// Server address as a string accepted by `ToSocketAddrs` (e.g. "host:port").
pub addr: String,
/// TCP keepalive, in milliseconds (converted by `to_parameters`).
pub tcp_keepalive: Option<u32>,
/// Idle timeout, in milliseconds (converted by `to_parameters`).
pub idle_timeout: Option<u64>,
/// TLS settings; `None` means no TLS connector is configured.
pub ssl: Option<SSLConfigFile>,
}
/// Serializable TLS settings, turned into `SSLParameters` by
/// `ConfigFile::to_parameters`.
#[derive(Default, Serialize, Deserialize)]
pub struct SSLConfigFile {
/// Path of a CA file loaded with OpenSSL's `set_ca_file`.
pub ca_root: Option<String>,
/// Path of a PEM client private key loaded with `set_private_key_file`.
pub private_key: Option<String>,
/// Copied into `SSLParameters::hostname` (presumably the name verified
/// against the server certificate — TODO confirm).
pub hostname: Option<String>,
}
impl ConfigFile {
pub fn to_parameters(&self) -> Result<Parameters, Error> {
use std::net::ToSocketAddrs;
let addr = self
.addr
.to_socket_addrs()
.expect("invalid pleingres socket address")
.next()
.expect("no pleingres socket address");
Ok(Parameters {
addr,
user: self.user.to_string(),
password: self.password.to_string(),
database: Some(self.db.to_string()),
tcp_keepalive: self
.tcp_keepalive
.map(|x| std::time::Duration::from_millis(x as u64)),
idle_timeout: self.idle_timeout.map(Duration::from_millis),
ssl: if let Some(ref ssl) = self.ssl {
use openssl::ssl::{SslConnector, SslMethod};
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
if let Some(ref ca_root) = ssl.ca_root {
connector.set_ca_file(ca_root)?
}
if let Some(ref client_pk) = ssl.private_key {
connector.set_private_key_file(client_pk, openssl::ssl::SslFiletype::PEM)?
}
let connector = connector.build();
Some(SSLParameters {
config: connector,
hostname: ssl.hostname.clone(),
})
} else {
None
},
})
}
}
/// Connection parameters used by `Connection`.
pub struct Parameters {
/// Server socket address.
pub addr: SocketAddr,
/// User name sent in the startup message.
pub user: String,
/// Password; used to answer an MD5 authentication request.
pub password: String,
/// Database name; sent as `database` in the startup message when `Some`.
pub database: Option<String>,
/// TCP keepalive interval (presumably applied to the socket where it is
/// opened — not used in this file; TODO confirm).
pub tcp_keepalive: Option<std::time::Duration>,
/// Delay used by `Connection::timeout`; 10 seconds when `None`.
pub idle_timeout: Option<Duration>,
/// TLS connector settings; `None` presumably means plain TCP — TODO confirm.
pub ssl: Option<SSLParameters>,
}
impl Default for Parameters {
fn default() -> Self {
use std::net::ToSocketAddrs;
Parameters {
addr: "localhost:5432".to_socket_addrs().unwrap().next().unwrap(),
user: String::new(),
password: String::new(),
database: None,
idle_timeout: None,
tcp_keepalive: None,
ssl: None,
}
}
}
/// TLS settings for a connection (built by `ConfigFile::to_parameters`).
pub struct SSLParameters {
/// OpenSSL connector configured with the optional CA file and client key.
pub config: openssl::ssl::SslConnector,
/// Optional host name (presumably used for certificate verification —
/// TODO confirm).
pub hostname: Option<String>,
}
/// State of the read half of a connection.
enum ReadState<W: AsyncRead> {
/// Reading a 5-byte message header (see `read_length`).
ReadLength(ReadExact<ReadHalf<W>, Vec<u8>>),
/// Reading the body of a message whose type byte is `msg` (see `read`).
Read {
msg: u8,
body: ReadExact<ReadHalf<W>, Vec<u8>>,
},
}
/// State of the write half of a connection: `Write` -> `Flush` -> `Idle`.
enum WriteState<W: AsyncWrite> {
/// Nothing in flight; `buf` is the buffer kept for reuse by the next write.
Idle {
w: WriteHalf<W>,
buf: Vec<u8>,
},
/// A `write_all` of the buffer is in progress.
Write {
w: tokio_io::io::WriteAll<WriteHalf<W>, Vec<u8>>,
},
/// Flushing after a write; `buf` is handed back to `Idle` afterwards.
Flush {
flush: tokio_io::io::Flush<WriteHalf<W>>,
buf: Vec<u8>,
},
}
use std::borrow::Cow;
/// A PostgreSQL connection over a stream `W`, configured by `P`.
pub struct Connection<W: AsyncWrite + AsyncRead, P: AsRef<Parameters>> {
parameters: P,
// `None` only transiently, while a state is taken out to be advanced.
write: Option<WriteState<W>>,
read: Option<ReadState<W>>,
// Set once a ReadyForQuery message arrives during the handshake.
ready_for_query: bool,
// Both filled from the server's BackendKeyData message.
process_id: Option<i32>,
secret_key: Option<i32>,
// Presumably maps query text to a prepared-statement index — TODO confirm.
parsed_queries: HashMap<Cow<'static, str>, usize>,
}
/// Future driving the startup handshake; resolves to the connection once
/// the server reports ReadyForQuery. Holds `None` after completion.
struct Connecting<W: AsyncWrite + AsyncRead, P: AsRef<Parameters>>(Option<Connection<W, P>>);
impl<W: AsyncWrite + AsyncRead, P: AsRef<Parameters> + Send> Future for Connecting<W, P> {
    type Item = Connection<W, P>;
    type Error = Error;
    /// Advances the handshake and hands the connection over once it is ready.
    ///
    /// Panics if polled again after it has resolved.
    fn poll(&mut self) -> Result<Async<Self::Item>, Self::Error> {
        debug!("connecting poll");
        if let Some(ref mut conn) = self.0 {
            match conn.poll_connecting()? {
                Async::NotReady => return Ok(Async::NotReady),
                Async::Ready(()) => {}
            }
        }
        Ok(Async::Ready(self.0.take().unwrap()))
    }
}
/// Authentication request code for success: `auth_request` keeps the
/// connection idle and sends nothing.
const AUTH_OK: u32 = 0;
/// Authentication request code asking for an MD5-hashed password; the
/// request carries the salt used by `md5`.
const AUTH_MD5_PASSWORD: u32 = 5;
/// Server message kinds, each discriminant being the message's ASCII type
/// byte (presumably the message a caller waits for — TODO confirm usage).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum Wait {
    /// `'1'`: ParseComplete.
    Parse = b'1',
    /// `'2'`: BindComplete.
    Bind = b'2',
    /// `'3'`: CloseComplete.
    CloseComplete = b'3',
    /// `'C'`: CommandComplete.
    CommandComplete = b'C',
    /// `'Z'`: ReadyForQuery.
    Ready = b'Z',
}
/// Starts reading a message header into `buf`, reusing its allocation.
fn read_length<A: AsyncRead>(stream: ReadHalf<A>, mut buf: Vec<u8>) -> ReadState<A> {
    // One message-type byte followed by a 4-byte big-endian length.
    const HEADER_LEN: usize = 5;
    buf.resize(HEADER_LEN, 0);
    let header = read_exact(stream, buf);
    ReadState::ReadLength(header)
}
/// Given a complete 5-byte header in `buf` (type byte + big-endian length),
/// starts reading the message body into the same buffer.
fn read<A: AsyncRead>(stream: ReadHalf<A>, mut buf: Vec<u8>) -> ReadState<A> {
    let msg = buf[0];
    let n = BigEndian::read_u32(&buf[1..]) as usize;
    debug!("n = {:?}", n);
    // The length counts its own 4 bytes, so a valid value is at least 4.
    // Saturate on malformed input instead of underflowing, which would panic
    // in debug builds and attempt a near-`usize::MAX` allocation in release.
    buf.resize(n.saturating_sub(4), 0);
    ReadState::Read {
        msg,
        body: read_exact(stream, buf),
    }
}
/// Future returned by `Connection::write`: resolves to the connection once
/// its buffered data has been written and flushed. Holds `None` after it
/// has resolved.
pub struct WriteFuture<W: AsyncWrite + AsyncRead, P: AsRef<Parameters>>(Option<Connection<W, P>>);
impl<W: AsyncWrite + AsyncRead, P: AsRef<Parameters> + Send> Connection<W, P> {
fn new(
stream: W,
parameters: P,
parsed_queries: HashMap<Cow<'static, str>, usize>,
) -> Connecting<W, P> {
let mut v = Vec::new();
v.extend(&[0, 0, 0, 0]);
v.write_u16::<BigEndian>(MAJOR_VERSION).unwrap();
v.write_u16::<BigEndian>(MINOR_VERSION).unwrap();
v.write_string(b"user").unwrap();
v.write_string(parameters.as_ref().user.as_bytes()).unwrap();
if let Some(ref database) = parameters.as_ref().database {
v.write_string(b"database").unwrap();
v.write_string(database.as_bytes()).unwrap();
}
v.write(&[0]).unwrap();
let len = v.len() as u32;
BigEndian::write_u32(&mut v, len);
let (read_half, write_half) = stream.split();
Connecting(Some(Connection {
write: Some(WriteState::Write {
w: tokio_io::io::write_all(write_half, v),
}),
read: Some(read_length(read_half, Vec::new())),
ready_for_query: false,
parameters: parameters,
process_id: None,
secret_key: None,
parsed_queries: parsed_queries,
}))
}
fn rows<F: HandleRow>(
self,
f: F,
timeout: Option<tokio_timer::Delay>,
) -> Result<Rows<W, F, P>, Error> {
Ok(Rows {
connection: Some((self, f)),
timeout,
last_row_was_halted: false,
})
}
fn timeout(&self) -> tokio_timer::Delay {
tokio_timer::Delay::new(
std::time::Instant::now()
+ self
.parameters
.as_ref()
.idle_timeout
.unwrap_or(std::time::Duration::from_secs(10)),
)
}
pub fn query<F: HandleRow>(
mut self,
q: &'static str,
args: &[&ToSql],
f: F,
) -> Result<Rows<W, F, P>, Error> {
{
let mut buf = self.buffer();
buf.bind(q, args).execute(0);
}
match self.write.take() {
Some(WriteState::Idle { w, buf }) => {
self.write = Some(WriteState::Write {
w: tokio_io::io::write_all(w, buf),
});
}
_ => unreachable!(),
}
let timeout = self.timeout();
Ok(self.rows(f, Some(timeout))?)
}
pub fn write(mut self) -> WriteFuture<W, P> {
match self.write.take() {
Some(WriteState::Idle { w, buf }) => {
self.write = Some(WriteState::Write {
w: tokio_io::io::write_all(w, buf),
});
WriteFuture(Some(self))
}
_ => unreachable!(),
}
}
fn poll_connecting(&mut self) -> Result<Async<()>, Error> {
loop {
match self.write.take() {
Some(WriteState::Write { mut w }) => {
debug!("write {}: {}", file!(), line!());
let (stream, buf) = match w.poll()? {
Async::Ready(s) => s,
Async::NotReady => {
self.write = Some(WriteState::Write { w });
return Ok(Async::NotReady);
}
};
self.write = Some(WriteState::Flush {
flush: tokio_io::io::flush(stream),
buf,
});
continue;
}
Some(WriteState::Flush { mut flush, buf }) => {
debug!("flush {}: {}", file!(), line!());
let w = match flush.poll()? {
Async::Ready(s) => s,
Async::NotReady => {
self.write = Some(WriteState::Flush { flush, buf });
return Ok(Async::NotReady);
}
};
self.write = Some(WriteState::Idle { w, buf });
}
Some(idle) => self.write = Some(idle),
None => unreachable!(),
}
match self.read.take() {
Some(ReadState::ReadLength(mut r)) => {
debug!("readlength: {} {}", file!(), line!());
match r.poll()? {
Async::Ready((connection, buffer)) => {
self.read = Some(read(connection, buffer));
}
Async::NotReady => {
self.read = Some(ReadState::ReadLength(r));
return Ok(Async::NotReady);
}
}
}
Some(ReadState::Read { msg, mut body }) => {
debug!("read {} {}", line!(), msg as char);
let (stream, buf) = match body.poll()? {
Async::Ready(s) => s,
Async::NotReady => {
self.read = Some(ReadState::Read {
msg: msg,
body: body,
});
return Ok(Async::NotReady);
}
};
debug!("poll connecting, read, {:?}", buf);
match msg {
MSG_AUTH => {
self.auth_request(&buf);
self.read = Some(read_length(stream, buf));
}
MSG_PARAMETER_STATUS => {
try!(parameter_status(&buf));
self.read = Some(read_length(stream, buf));
}
MSG_BACKEND_KEY_DATA => {
self.backend_key_data(&buf);
self.read = Some(read_length(stream, buf));
}
MSG_READY_FOR_QUERY => {
debug!("ready for query");
self.ready_for_query = true;
self.read = Some(read_length(stream, buf));
return Ok(Async::Ready(()));
}
MSG_ERROR => {
let err = error(&buf)?;
return Err(Error::Postgres(err));
}
msg => {
debug!("unknown message {:?}", msg);
self.read = Some(read_length(stream, buf));
}
}
}
None => unreachable!(),
}
}
}
fn auth_request(&mut self, read_buf: &[u8]) {
match self.write.take() {
Some(WriteState::Idle { w, mut buf }) => {
debug!("auth request: buf: {:?}", buf);
let auth_type = BigEndian::read_u32(&read_buf[..]);
match auth_type {
AUTH_OK => {
debug!("auth ok");
self.write = Some(WriteState::Idle { w, buf })
}
AUTH_MD5_PASSWORD => {
debug!("md5 password");
let md5 = md5(
&read_buf[4..],
self.parameters.as_ref().user.as_bytes(),
self.parameters.as_ref().password.as_bytes(),
);
buf.clear();
buf.push(MSG_PASSWORD);
buf.extend(&[0, 0, 0, 0]);
buf.extend(b"md5");
for i in md5.iter() {
let i = *i as usize;
buf.push(HEX[i >> 4]);
buf.push(HEX[i & 0xf]);
}
buf.push(0);
let len = buf.len() - 1;
BigEndian::write_u32(&mut buf[1..], len as u32);
self.write = Some(WriteState::Write {
w: tokio_io::io::write_all(w, buf),
})
}
t => panic!("auth {:?} not implemented", t),
}
}
_ => {}
}
}
fn backend_key_data(&mut self, buf: &[u8]) {
self.process_id = Some(BigEndian::read_i32(&buf[..]));
self.secret_key = Some(BigEndian::read_i32(&buf[4..]));
}
}
fn parameter_status(buf: &[u8]) -> Result<(), Error> {
if let Some(i) = (&buf[..]).iter().position(|&x| x == 0) {
if let Some(j) = (&buf[i + 1..]).iter().position(|&x| x == 0) {
debug!("parameter status: {:?}", std::str::from_utf8(&buf[..i]));
debug!(
"parameter status: {:?}",
std::str::from_utf8(&buf[i + 1..i + 1 + j])
);
return Ok(());
}
}
Err(Error::Protocol)
}
fn error(buf: &[u8]) -> Result<PostgresError, Error> {
let mut i = 0;
let mut line_count = 0;
let mut code = &[][..];
let mut into = String::new();
debug!("{:?}", buf);
while i < buf.len() && buf[i] != 0 {
if let Some(j) = (&buf[i + 1..]).iter().position(|&x| x == 0) {
if line_count == 2 {
code = &buf[i + 1..i + 1 + j];
let s = std::str::from_utf8(&buf[i + 1..i + 1 + j])?;
into.push_str(s);
into.push('\n');
} else {
let s = std::str::from_utf8(&buf[i + 1..i + 1 + j])?;
error!("Error: {:?}", s);
into.push_str(s);
into.push('\n');
}
line_count += 1;
i += j + 2
} else {
return Err(Error::Protocol);
}
}
match code {
b"23000" => Ok(PostgresError::IntegrityConstraintViolation(into)),
b"23001" => Ok(PostgresError::RestrictViolation(into)),
b"23502" => Ok(PostgresError::NotNullViolation(into)),
b"23503" => Ok(PostgresError::ForeignKeyViolation(into)),
b"23505" => Ok(PostgresError::UniqueViolation(into)),
b"23514" => Ok(PostgresError::CheckViolation(into)),
b"23P01" => Ok(PostgresError::ExclusionViolation(into)),
_ => Ok(PostgresError::Other { message: into }),
}
}
impl<W: AsyncWrite + AsyncRead, P: AsRef<Parameters>> Future for WriteFuture<W, P> {
    type Item = Connection<W, P>;
    type Error = Error;
    /// Drives the pending write and then the flush. Resolves to the
    /// connection once its writer is back to `Idle`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after resolving, or if the connection's
    /// writer is neither writing nor flushing.
    fn poll(&mut self) -> Result<Async<Self::Item>, Self::Error> {
        loop {
            debug!("writefuture poll");
            let is_flushed = match self.0 {
                Some(ref mut c) => match c.write.take() {
                    Some(WriteState::Write { mut w }) => match w.poll()? {
                        Async::Ready((stream, buf)) => {
                            debug!("written buf = {:?}", buf);
                            c.write = Some(WriteState::Flush {
                                flush: tokio_io::io::flush(stream),
                                buf,
                            });
                            false
                        }
                        Async::NotReady => {
                            c.write = Some(WriteState::Write { w });
                            return Ok(Async::NotReady);
                        }
                    },
                    Some(WriteState::Flush { mut flush, buf }) => match flush.poll()? {
                        Async::Ready(w) => {
                            c.write = Some(WriteState::Idle { w, buf });
                            true
                        }
                        Async::NotReady => {
                            c.write = Some(WriteState::Flush { flush, buf });
                            return Ok(Async::NotReady);
                        }
                    },
                    _ => unreachable!(),
                },
                // Polling after completion is a caller bug. Fail loudly
                // instead of looping forever with nothing left to drive.
                None => panic!("WriteFuture polled after completion"),
            };
            if is_flushed {
                return Ok(Async::Ready(self.0.take().unwrap()));
            }
        }
    }
}
/// Polling a `Connection` directly drives its pending writes and consumes
/// server messages while no request is in progress.
///
/// Each iteration first advances the write state (`Write` -> `Flush` ->
/// `Idle`), then the read state. An ErrorResponse is parsed with `error`
/// and then discarded. Any other message is logged at debug level and
/// ignored. This future never returns `Ready`: it returns either
/// `NotReady` or an error (I/O, or a malformed ErrorResponse).
impl<W: AsyncWrite + AsyncRead, P: AsRef<Parameters>> Future for Connection<W, P> {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Result<Async<()>, Error> {
loop {
// Write side: finish any in-flight write, then flush it.
match self.write.take() {
Some(WriteState::Write { mut w }) => match try!(w.poll()) {
Async::Ready((stream, buf)) => {
self.write = Some(WriteState::Flush {
flush: tokio_io::io::flush(stream),
buf,
});
// Go straight to flushing before touching the read side.
continue;
}
Async::NotReady => {
self.write = Some(WriteState::Write { w });
return Ok(Async::NotReady);
}
},
Some(WriteState::Flush { mut flush, buf }) => match try!(flush.poll()) {
Async::Ready(w) => self.write = Some(WriteState::Idle { w, buf }),
Async::NotReady => {
self.write = Some(WriteState::Flush { flush, buf });
return Ok(Async::NotReady);
}
},
// Idle: put it back unchanged.
Some(state) => self.write = Some(state),
None => unreachable!(),
}
debug!("connection poll");
// Read side: header first, then the body of the announced message.
match self.read.take() {
Some(ReadState::ReadLength(mut r)) => match r.poll()? {
Async::Ready((connection, buffer)) => {
self.read = Some(read(connection, buffer));
}
Async::NotReady => {
self.read = Some(ReadState::ReadLength(r));
return Ok(Async::NotReady);
}
},
Some(ReadState::Read { msg, mut body }) => {
let (stream, buf) = match try!(body.poll()) {
Async::Ready(s) => s,
Async::NotReady => {
self.read = Some(ReadState::Read {
msg: msg,
body: body,
});
return Ok(Async::NotReady);
}
};
match msg {
MSG_ERROR => {
// Parsed (which can fail on a malformed body) but not reported.
let _err = error(&buf)?;
self.read = Some(read_length(stream, buf));
}
msg => {
debug!("unknown message {:?}", msg as char);
self.read = Some(read_length(stream, buf));
}
}
}
None => unreachable!(),
}
}
}
}
/// Lowercase hexadecimal digits, indexed by nibble value.
const HEX: &'static [u8] = b"0123456789abcdef";

/// Computes the MD5 password response digest:
/// `md5(hex(md5(password ++ user)) ++ salt)`.
///
/// Returns the raw 16-byte digest. The caller hex-encodes it and prefixes
/// `"md5"` before sending.
fn md5(salt: &[u8], user: &[u8], password: &[u8]) -> [u8; 16] {
    let inner = {
        let mut ctx = md5::Context::new();
        ctx.consume(password);
        ctx.consume(user);
        ctx.compute()
    };
    let mut outer = md5::Context::new();
    for &byte in inner.iter() {
        let high = HEX[(byte >> 4) as usize];
        let low = HEX[(byte & 0xf) as usize];
        outer.consume(&[high, low]);
    }
    outer.consume(salt);
    outer.compute().0
}
#[test]
fn test_md5() {
    // Known digest for user "pe", password "password" and this salt.
    let salt = [0xc4, 0x4e, 0x74, 0x12];
    let digest = md5(&salt[..], b"pe", b"password");
    let mut hex = Vec::with_capacity(2 * digest.len());
    for &byte in digest.iter() {
        hex.push(HEX[(byte >> 4) as usize]);
        hex.push(HEX[(byte & 0xf) as usize]);
    }
    assert_eq!(&hex[..], b"66008dc01e22d8e920e0458c43584d9e")
}
/// A request that can encode itself into a connection's `Buffer` (the
/// implementations in this file bind a query and call `execute`).
pub trait Request {
/// Appends this request's protocol messages to `buf`.
fn request(&mut self, buf: &mut Buffer);
}
/// Conversion between a request type and an enclosing request type `R`
/// (presumably an enum of several request kinds — TODO confirm).
pub trait SubRequest<R>: Sized {
/// Wraps `self` into `R`.
fn from(self) -> R;
/// Extracts `Self` from `r`, handing `r` back unchanged on failure.
fn try_into(r: R) -> Result<Self, R>;
}
/// The converse of `SubRequest`, implemented on the enclosing type.
pub trait SuperRequest<R>: Sized {
/// Wraps `r` into `Self`.
fn from(r: R) -> Self;
/// Extracts an `R` from `self`, handing `self` back unchanged on failure.
fn try_into(self) -> Result<R, Self>;
}
/// A raw SQL string is a request: it is bound with no parameters through
/// `bind_unnamed` and executed.
impl Request for String {
fn request(&mut self, buf: &mut Buffer) {
buf.bind_unnamed(self, &[]).execute(0)
}
}
/// Uses `HandleRow`'s default methods (defined elsewhere; presumably rows
/// are ignored — TODO confirm).
impl HandleRow for String {}
use futures::sync::oneshot::Sender;
/// Progress of a `SendRequestErr`.
enum SendState<R> {
/// Not yet queued: the request and the sender for its reply.
Init(R, Sender<(R, Option<Error>)>),
/// Queued; waiting for the channel to finish sending.
Sending,
/// Sent; waiting on the oneshot for the request to come back.
Receiving,
}
/// Future that sends a request to the connection task and resolves to the
/// request handed back along with an optional error.
pub struct SendRequestErr<R> {
// Channel to the connection task; requests are sent wrapped in `Some`
// (presumably `None` is a shutdown signal — TODO confirm).
sender: futures::sync::mpsc::UnboundedSender<
Option<(R, futures::sync::oneshot::Sender<(R, Option<Error>)>)>,
>,
// Receives the request back once it has been processed.
receiver: futures::sync::oneshot::Receiver<(R, Option<Error>)>,
// `None` after the future has resolved.
state: Option<SendState<R>>,
}
/// Like `SendRequestErr`, but an error returned with the request fails the
/// future, and only the request is yielded on success.
pub struct SendRequest<R>(SendRequestErr<R>);
use futures::AsyncSink;
impl<R> Future for SendRequestErr<R> {
    type Item = (R, Option<Error>);
    type Error = Error;
    /// Queues the request on the channel, waits for the send to complete,
    /// then waits for the reply.
    ///
    /// # Panics
    ///
    /// Panics if polled again after resolving.
    fn poll(&mut self) -> Result<Async<Self::Item>, Self::Error> {
        loop {
            debug!("send request poll");
            let current = match self.state.take() {
                Some(state) => state,
                None => panic!("sendstate polled after completion"),
            };
            match current {
                SendState::Init(request, reply_to) => {
                    debug!("state init");
                    match self.sender.start_send(Some((request, reply_to)))? {
                        AsyncSink::Ready => self.state = Some(SendState::Sending),
                        AsyncSink::NotReady(Some((request, reply_to))) => {
                            self.state = Some(SendState::Init(request, reply_to));
                            return Ok(Async::NotReady);
                        }
                        AsyncSink::NotReady(None) => unreachable!(),
                    }
                }
                SendState::Sending => {
                    debug!("state sending");
                    if let Async::NotReady = self.sender.poll_complete()? {
                        self.state = Some(SendState::Sending);
                        return Ok(Async::NotReady);
                    }
                    self.state = Some(SendState::Receiving);
                }
                SendState::Receiving => {
                    debug!("state receiving");
                    if let Async::Ready(reply) = self.receiver.poll()? {
                        return Ok(Async::Ready(reply));
                    }
                    debug!("not ready");
                    self.state = Some(SendState::Receiving);
                    return Ok(Async::NotReady);
                }
            }
        }
    }
}
impl<R> Future for SendRequest<R> {
    type Item = R;
    type Error = Error;
    /// Resolves to the request once it comes back. An error returned with
    /// it becomes the future's error, and the request is then dropped.
    fn poll(&mut self) -> Result<Async<Self::Item>, Self::Error> {
        let (request, failure) = match self.0.poll()? {
            Async::NotReady => return Ok(Async::NotReady),
            Async::Ready(reply) => reply,
        };
        if let Some(e) = failure {
            return Err(e);
        }
        Ok(Async::Ready(request))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use env_logger;
    use futures;
    use tokio;
    use uuid::Uuid;

    /// Request shared by the integration tests: looks up the `id` of the
    /// user whose `login` matches and stores the parsed UUID in `id`.
    #[derive(Debug)]
    struct LoginRequest {
        login: String,
        id: Option<Uuid>,
    }

    impl super::Request for LoginRequest {
        fn request(&mut self, buf: &mut Buffer) {
            buf.bind("SELECT id FROM users WHERE login=$1", &[&self.login])
                .execute(0);
        }
    }

    impl HandleRow for LoginRequest {
        fn row(&mut self, mut row: Row) -> bool {
            if let Some(id_) = row.next() {
                debug!("{:?}", std::str::from_utf8(id_));
                let id_ = std::str::from_utf8(id_).unwrap().parse().unwrap();
                self.id = Some(id_)
            }
            true
        }
    }

    /// Reads a required environment variable, panicking with a hint if unset.
    fn required_env(name: &str) -> String {
        match std::env::var(name) {
            Ok(value) => value,
            Err(_) => panic!("set the {} environment variable to run this test", name),
        }
    }

    #[test]
    fn basic_connection() {
        use futures::Future;
        use std::net::ToSocketAddrs;
        use std::sync::Arc;
        env_logger::try_init().unwrap_or(());
        let p = Arc::new(Parameters {
            addr: "127.0.0.1:5432".to_socket_addrs().unwrap().next().unwrap(),
            user: "pe".to_string(),
            password: "password".to_string(),
            database: Some("pijul".to_string()),
            idle_timeout: Some(std::time::Duration::from_millis(20_000)),
            tcp_keepalive: Some(std::time::Duration::from_millis(10000)),
            ssl: None,
        });
        tokio::run(futures::lazy(move || {
            let db: Handle<LoginRequest> = spawn_connection(p.clone()).unwrap();
            db.send_request(LoginRequest {
                login: "me".to_string(),
                id: None,
            })
            .and_then(move |dbreq| {
                debug!("{:?}", dbreq);
                if let Some(ref id) = dbreq.id {
                    debug!("id: {:?}", id)
                }
                db.shutdown()
            })
            .map_err(|_| ())
        }));
        println!("done");
    }

    /// Needs a TLS-enabled server, the CA and client-key files, and the
    /// credentials in `PLEINGRES_SSL_ADDR`, `PLEINGRES_SSL_USER`,
    /// `PLEINGRES_SSL_PASSWORD` and `PLEINGRES_SSL_DB`. Credentials are
    /// never committed to source. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn ssl_connection() {
        use futures::Future;
        use std::sync::Arc;
        env_logger::try_init().unwrap_or(());
        let p = Arc::new(
            ConfigFile {
                addr: required_env("PLEINGRES_SSL_ADDR"),
                user: required_env("PLEINGRES_SSL_USER"),
                password: required_env("PLEINGRES_SSL_PASSWORD"),
                db: required_env("PLEINGRES_SSL_DB"),
                idle_timeout: Some(20_000),
                tcp_keepalive: Some(10_000),
                ssl: Some(SSLConfigFile {
                    ca_root: Some("../sql-server-ca.pem".to_string()),
                    private_key: Some("../sql-client-key.pem".to_string()),
                    hostname: Some("pijul-163713:nest".to_string()),
                }),
            }
            .to_parameters()
            .unwrap(),
        );
        tokio::run(futures::lazy(move || {
            let db: Handle<LoginRequest> = spawn_connection(p.clone()).unwrap();
            debug!("spawned");
            db.send_request(LoginRequest {
                login: "me".to_string(),
                id: None,
            })
            .and_then(move |dbreq| {
                if let Some(ref id) = dbreq.id {
                    debug!("id: {:?}", id)
                }
                db.shutdown()
            })
            .map_err(|_| ())
        }))
    }
}