1use std::future::Future;
5
6#[allow(unused)]
8pub use bb8::PooledConnection;
9pub use bb8::{ManageConnection, Pool};
10
11use crate::cluster::storage::MembershipStorage;
12use crate::protocol::ClientError;
13
14use super::Client;
15use super::ClientBuilder;
16use super::DEFAULT_TIMEOUT_MILLIS;
17
/// [`bb8`] connection manager whose pooled connections are [`Client`]s.
///
/// The manager only holds the settings needed to build a client: a membership
/// storage handle (cloned into every client it creates) and a client timeout.
pub struct ClientConnectionManager<S> {
    /// Timeout forwarded to `ClientBuilder::timeout_millis` for every client;
    /// `new` initialises it to `DEFAULT_TIMEOUT_MILLIS`.
    pub(crate) timeout_millis: u64,
    /// Membership storage; each newly built client receives its own clone.
    pub(crate) members_storage: S,
}
30
31impl<S: MembershipStorage + 'static> ClientConnectionManager<S> {
32 pub fn new(members_storage: S) -> Self {
33 ClientConnectionManager {
34 members_storage,
35 timeout_millis: DEFAULT_TIMEOUT_MILLIS,
36 }
37 }
38
39 pub fn pool() -> bb8::Builder<Self> {
40 Pool::builder()
41 }
42}
43
44impl<S: MembershipStorage + 'static> ManageConnection for ClientConnectionManager<S> {
45 type Connection = Client<S>;
46 type Error = ClientError;
47 fn connect(&self) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send {
48 futures::future::ready(
49 ClientBuilder::new()
50 .members_storage(self.members_storage.clone())
51 .timeout_millis(self.timeout_millis)
52 .build()
53 .map_err(|err| ClientError::Unknown(err.to_string())),
54 )
55 }
56
57 fn is_valid(
58 &self,
59 _conn: &mut Self::Connection,
60 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
61 futures::future::ok(())
62 }
63
64 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
65 false
66 }
67}
68
#[cfg(test)]
mod test {
    use bb8::Pool;

    use crate::cluster::storage::local::LocalStorage;
    use crate::cluster::storage::Member;

    use super::*;

    /// Two connections checked out of the same pool must both see the single
    /// member registered in the shared local storage.
    #[tokio::test]
    async fn basic_usage() {
        let storage = LocalStorage::default();
        let member = Member::new("0.0.0.0".to_string(), "9999".to_string());
        storage.push(member).await.unwrap();

        let pool = Pool::builder()
            .build(ClientConnectionManager::new(storage))
            .await
            .unwrap();

        let mut first = pool.get().await.unwrap();
        let second = pool.get().await.unwrap();

        // NOTE(review): this may need something reachable at 0.0.0.0:9999 —
        // confirm against `Client::fetch_active_servers`.
        first.fetch_active_servers().await.unwrap();

        for conn in [&first, &second] {
            assert_eq!(conn.membership_storage.members().await.unwrap().len(), 1);
        }
    }
}