use std::collections::HashMap;
use std::future::IntoFuture;
use futures::future::join_all;
use futures::Future;
use crate::{AsyncDB, Connection as ConnectionName, DBOutput};
/// A factory capable of opening new database connections asynchronously.
///
/// Each call to [`MakeConnection::make`] starts establishing one fresh
/// connection of type [`MakeConnection::Conn`].
pub trait MakeConnection {
    /// Begins opening a new connection; await the returned future to obtain it.
    fn make(&mut self) -> Self::MakeFuture;

    /// The kind of connection this factory hands out.
    type Conn: AsyncDB;

    /// Future produced by [`MakeConnection::make`]. It resolves to the new
    /// connection, or to the connection type's own error when opening fails.
    type MakeFuture: Future<Output = Result<Self::Conn, <Self::Conn as AsyncDB>::Error>>;
}
impl<D: AsyncDB, F, Fut> MakeConnection for F
where
F: FnMut() -> Fut,
Fut: IntoFuture<Output = Result<D, D::Error>>,
{
type Conn = D;
type MakeFuture = Fut::IntoFuture;
fn make(&mut self) -> Self::MakeFuture {
self().into_future()
}
}
/// A cache of named database connections that are opened on first use.
///
/// `D` is the connection type; `M` is the factory (see `MakeConnection`) used
/// to open a connection the first time a given name is requested.
pub(crate) struct Connections<D, M> {
/// Factory invoked by `get` when no connection exists yet for a name.
make_conn: M,
/// Connections opened so far, keyed by connection name.
conns: HashMap<ConnectionName, D>,
}
impl<D: AsyncDB, M: MakeConnection<Conn = D>> Connections<D, M> {
    /// Creates an empty cache that will open connections through `make_conn`.
    pub fn new(make_conn: M) -> Self {
        Self {
            make_conn,
            conns: HashMap::default(),
        }
    }

    /// Returns the connection registered under `name`, opening it on first use.
    ///
    /// # Errors
    ///
    /// Propagates the connection error if the factory fails to open a new
    /// connection; in that case nothing is cached for `name`.
    pub async fn get(&mut self, name: ConnectionName) -> Result<&mut D, D::Error> {
        use std::collections::hash_map::Entry;

        match self.conns.entry(name) {
            Entry::Occupied(slot) => Ok(slot.into_mut()),
            Entry::Vacant(slot) => {
                // Only reach for the factory when the name has never been seen.
                let fresh = self.make_conn.make().await?;
                Ok(slot.insert(fresh))
            }
        }
    }

    /// Runs `sql` on the default connection, opening it first if needed.
    ///
    /// # Errors
    ///
    /// Returns an error if opening the connection or running the statement fails.
    pub async fn run_default(&mut self, sql: &str) -> Result<DBOutput<D::ColumnType>, D::Error> {
        let conn = self.get(ConnectionName::Default).await?;
        conn.run(sql).await
    }

    /// Shuts down every cached connection concurrently and waits for all of them.
    ///
    /// The connections stay in the cache afterwards.
    pub async fn shutdown_all(&mut self) {
        let pending: Vec<_> = self
            .conns
            .values_mut()
            .map(|conn| conn.shutdown())
            .collect();
        join_all(pending).await;
    }
}