//! mobc-reql 0.6.4
//!
//! RethinkDB support for the mobc connection pool.
use blocking::unblock;
use futures::lock::Mutex;
use futures::{Future, TryStreamExt};
use futures_timer::Delay;
use mobc::{async_trait, Manager};
use reql::cmd::connect::Options;
use reql::cmd::run::{self, Arg};
use reql::types::{Change, ServerStatus};
use reql::{r, Command, Connection, Driver, Error, Result};
use std::cmp::Ordering;
use std::io;
use std::net::{IpAddr, TcpStream};
use std::ops::Deref;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::trace;

/// A mobc connection pool whose connections are created by [`SessionManager`].
pub type Pool = mobc::Pool<SessionManager>;

/// Extension trait for obtaining a [`Session`] asynchronously.
///
/// Implemented in this file for [`Pool`].
#[async_trait]
pub trait GetSession {
    /// Returns a [`Session`], or a `reql` error if one could not be obtained.
    async fn session(&self) -> Result<Session>;
}

#[async_trait]
impl GetSession for Pool {
    /// Checks a connection out of the pool and wraps it in a [`Session`].
    ///
    /// # Errors
    ///
    /// Any error reported by the pool is translated into a `reql` error by
    /// `to_reql` before being returned.
    async fn session(&self) -> Result<Session> {
        let conn = self.get().await.map_err(to_reql)?;
        Ok(Session { conn })
    }
}

/// A RethinkDB session backed by a connection checked out of a [`Pool`].
///
/// Dereferences to `reql::Session` (see the `Deref` impl in this file).
pub struct Session {
    /// The pooled connection this session wraps.
    conn: mobc::Connection<SessionManager>,
}

/// Lets a pooled [`Session`] be used wherever a `&reql::Session` is expected.
impl Deref for Session {
    type Target = reql::Session;

    /// Borrows the `reql::Session` held by the pooled connection.
    fn deref(&self) -> &Self::Target {
        &*self.conn
    }
}

impl AsRef<reql::Session> for Session {
    fn as_ref(&self) -> &reql::Session {
        self.deref()
    }
}

impl Arg for &Session {
    fn into_run_opts(self) -> Result<(Connection, run::Options)> {
        self.deref().into_run_opts()
    }
}

/// A RethinkDB server learned from the `server_status` system table.
///
/// Equality is decided by `name` while ordering is decided by `latency`
/// (see the manual `PartialEq` and `Ord` impls in this file).
#[derive(Debug, Clone, Eq)]
struct Server {
    /// Server name as reported by its status record.
    name: String,
    /// Hosts taken from the status record's canonical addresses.
    addresses: Vec<IpAddr>,
    /// Port taken from the status record's `reql_port`.
    port: u16,
    /// Measured TCP connect time; starts at `u64::MAX` milliseconds until
    /// `set_latency` records a measurement.
    latency: Duration,
}

/// Orders servers by measured latency only, so sorting puts the fastest first.
///
/// NOTE(review): equality for `Server` is defined on `name`, so two servers
/// may compare `Ordering::Equal` here without being `==`.
impl Ord for Server {
    fn cmp(&self, other: &Server) -> Ordering {
        Ord::cmp(&self.latency, &other.latency)
    }
}

/// Total ordering is available, so the partial ordering simply wraps it.
impl PartialOrd for Server {
    fn partial_cmp(&self, other: &Server) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

/// Two servers are considered the same when their names match; addresses,
/// port and latency are not compared.
impl PartialEq for Server {
    fn eq(&self, other: &Server) -> bool {
        self.name.as_str() == other.name.as_str()
    }
}

impl Server {
    /// Builds a `Server` from a `server_status` record.
    ///
    /// The name, the hosts of the canonical addresses and the ReQL port are
    /// copied over. The latency starts at `u64::MAX` milliseconds, the value
    /// a server keeps until a measurement is recorded for it.
    fn from_status(status: ServerStatus) -> Self {
        let name = status.name;
        let network = status.network;
        let port = network.reql_port;
        let addresses = network
            .canonical_addresses
            .into_iter()
            .map(|address| address.host)
            .collect();
        Self {
            name,
            addresses,
            port,
            latency: Duration::from_millis(u64::MAX),
        }
    }
}

/// The mobc `Manager` that opens and health-checks RethinkDB sessions.
///
/// Clones share the same discovered-server list through the `Arc`.
#[derive(Clone)]
pub struct SessionManager {
    /// Connection options; used as-is when no servers have been discovered
    /// and combined with a discovered address otherwise.
    opts: Options,
    /// Servers found by host discovery, replaced wholesale on every refresh.
    servers: Arc<Mutex<Vec<Server>>>,
    /// Pool used internally by host discovery. `None` on managers created by
    /// `new`; set only on the clone that `discover_hosts` moves into its future.
    pool: Option<Pool>,
}

impl SessionManager {
    /// Creates a manager with the given connection options, an empty list of
    /// discovered servers and no discovery pool.
    pub fn new(opts: Options) -> Self {
        let servers = Arc::new(Mutex::new(Vec::new()));
        Self {
            opts,
            servers,
            pool: None,
        }
    }

    /// Returns a future that keeps the discovered-server list up to date.
    ///
    /// The future never completes: it listens for `server_status` changes in
    /// a loop. After each failure it sleeps for `wait` seconds and then
    /// increases `wait` by one second, capped at 300. `wait` is reset to zero
    /// whenever a change is processed successfully.
    ///
    /// Discovery uses its own pool (at most two open connections) built from
    /// a clone of this manager, so it shares the same server list.
    pub fn discover_hosts(&self) -> impl Future<Output = ()> {
        let discovery_pool = Pool::builder().max_open(2).build(self.clone());
        let mut manager = self.clone();
        manager.pool = Some(discovery_pool);
        async move {
            let mut wait = 0;
            loop {
                let error = match manager.listen_for_hosts(&mut wait).await {
                    // The feed ended without an error; start listening again.
                    Ok(()) => continue,
                    Err(error) => error,
                };
                trace!(
                    "listening for host changes; error: {}, wait: {}s",
                    error,
                    wait
                );
                Delay::new(Duration::from_secs(wait)).await;
                wait = (wait + 1).min(300);
            }
        }
    }

    /// Follows the `server_status` changefeed and refreshes the shared server
    /// list after every change, resetting `wait` to zero each time.
    ///
    /// Returns `Ok(())` when the feed ends, or the first error encountered.
    ///
    /// # Panics
    ///
    /// Panics if `self.pool` is `None`.
    async fn listen_for_hosts(&self, wait: &mut u64) -> Result<()> {
        let conn = self.pool.as_ref().unwrap().session().await?;
        let mut feed = server_status()
            .changes(())
            .run::<_, Change<ServerStatus, ServerStatus>>(&conn);
        loop {
            if feed.try_next().await?.is_none() {
                return Ok(());
            }
            let refreshed = self.get_servers().await?;
            *self.servers.lock().await = refreshed;
            *wait = 0;
        }
    }

    /// Reads every row of `server_status`, measures each server's latency and
    /// returns the servers sorted by that latency (lowest first).
    ///
    /// # Panics
    ///
    /// Panics if `self.pool` is `None`.
    async fn get_servers(&self) -> Result<Vec<Server>> {
        let mut servers = Vec::new();
        let conn = self.pool.as_ref().unwrap().session().await?;
        let mut rows = server_status().run(&conn);
        loop {
            match rows.try_next().await? {
                Some(status) => servers.push(Server::from_status(status)),
                None => break,
            }
        }
        set_latency(&mut servers).await;
        servers.sort();
        Ok(servers)
    }
}

async fn set_latency(servers: &mut Vec<Server>) {
    for server in servers {
        let port = server.port;
        for (i, host) in server.addresses.iter().enumerate() {
            let host = *host;
            let latency = unblock(move || {
                let start = Instant::now();
                if TcpStream::connect((host, port)).is_ok() {
                    return Some(start.elapsed());
                }
                None
            })
            .await;
            if let Some(latency) = latency {
                if latency > server.latency || i == 0 {
                    server.latency = latency;
                }
            }
        }
    }
}

/// Builds the ReQL command selecting the `server_status` table of the
/// `rethinkdb` database.
fn server_status() -> Command {
    let system_db = r.db("rethinkdb");
    system_db.table("server_status")
}

#[async_trait]
impl Manager for SessionManager {
    type Connection = reql::Session;
    type Error = Error;

    async fn connect(&self) -> Result<Self::Connection> {
        let opts = &self.opts;
        let servers = &self.servers.lock().await;
        if servers.is_empty() {
            trace!(
                "no discovered servers; host: {}, port: {}",
                opts.host,
                opts.port
            );
            return r.connect(opts.clone()).await;
        } else {
            for server in servers.iter() {
                for host in &server.addresses {
                    trace!(
                        "discovered server {}; host: {}, port: {}",
                        server.name,
                        host,
                        server.port
                    );
                    let addr = (*host, server.port);
                    if let Ok(conn) = r.connect(r.args((addr, opts.clone()))).await {
                        return Ok(conn);
                    }
                }
            }
        }
        Err(io::Error::new(
            io::ErrorKind::ConnectionRefused,
            "no RethinkDB servers available",
        )
        .into())
    }

    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection> {
        let msg = 200;
        match r.expr(msg).run(&conn).try_next().await? {
            Some(res) => verify(res, msg)?,
            None => {
                return Err(Driver::ConnectionBroken.into());
            }
        }
        Ok(conn)
    }

    fn validate(&self, conn: &mut Self::Connection) -> bool {
        !conn.is_broken()
    }
}

/// Confirms that the value echoed by the server (`res`) equals the value that
/// was sent (`msg`).
///
/// # Errors
///
/// Returns `Driver::ConnectionBroken` when the two values differ.
fn verify(res: u32, msg: u32) -> Result<()> {
    if res == msg {
        Ok(())
    } else {
        Err(Driver::ConnectionBroken.into())
    }
}

fn to_reql(error: mobc::Error<Error>) -> Error {
    match error {
        mobc::Error::Inner(error) => error,
        mobc::Error::Timeout => io::Error::from(io::ErrorKind::TimedOut).into(),
        mobc::Error::BadConn => Driver::ConnectionBroken.into(),
    }
}