deadpool_bolt/
lib.rs

1#![warn(rust_2018_idioms)]
2
3//! [`bolt_client`] support for the [`deadpool`] connection pool.
4//!
5//! # Example
6//! ```
7//! use std::env;
8//!
9//! # use deadpool::managed::Manager as DeadpoolManager;
10//! use deadpool_bolt::{
11//!     bolt_client::Metadata,
12//!     bolt_proto::{version::*, Value},
13//!     Manager,
14//!     Pool,
15//! #   bolt_client::error::{ConnectionError, Error as ClientError},
16//! #   bolt_proto::message::Success,
17//! };
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     // Configure a connection manager. We'll request Bolt v4.4 or v4.3.
22//!     let manager = Manager::new(
23//!         env::var("BOLT_TEST_ADDR")?,
24//!         env::var("BOLT_TEST_DOMAIN").ok(),
25//!         [V4_4, V4_3, 0, 0],
26//!         Metadata::from_iter(vec![
27//!             ("user_agent", "bolt-client/X.Y.Z"),
28//!             ("scheme", "basic"),
29//!             ("principal", &env::var("BOLT_TEST_USERNAME")?),
30//!             ("credentials", &env::var("BOLT_TEST_PASSWORD")?),
31//!         ]),
32//!     ).await?;
33//!
34//! #   match manager.create().await {
35//! #       Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
36//! #           println!("skipping test: {}", ConnectionError::HandshakeFailed(versions));
37//! #           return Ok(());
38//! #       }
39//! #       Err(other) => panic!("{}", other),
40//! #       _ => {}
41//! #   }
42//!     // Create a connection pool. This should be shared across your application.
43//!     let pool = Pool::builder(manager).build()?;
44//!
45//!     // Fetch and use a connection from the pool
46//!     let mut conn = pool.get().await?;
47//!     let response = conn.run("RETURN 1 as num;", None, None).await?;
48//! #   Success::try_from(response.clone()).unwrap();
49//!     let pull_meta = Metadata::from_iter(vec![("n", 1)]);
50//!     let (records, response) = conn.pull(Some(pull_meta)).await?;
51//! #   Success::try_from(response).unwrap();
52//!     assert_eq!(records[0].fields(), &[Value::from(1)]);
53//!
54//!     Ok(())
55//! }
56
57use std::{convert::Infallible, io, net::SocketAddr};
58
59use async_trait::async_trait;
60use deadpool::managed::RecycleResult;
61use tokio::{
62    io::BufStream,
63    net::{lookup_host, ToSocketAddrs},
64};
65use tokio_util::compat::*;
66
67use bolt_client::{
68    error::{CommunicationError, ConnectionError, Error as ClientError},
69    Client, Metadata, Stream,
70};
71use bolt_proto::{error::Error as ProtocolError, message, Message};
72
73pub use bolt_client;
74pub use bolt_client::bolt_proto;
75
76pub use deadpool::managed::reexports::*;
77deadpool::managed_reexports!(
78    "bolt_client",
79    Manager,
80    deadpool::managed::Object<Manager>,
81    ClientError,
82    // We're not validating configuration before requesting connections
83    Infallible
84);
85
86pub struct Manager {
87    addr: SocketAddr,
88    domain: Option<String>,
89    version_specifiers: [u32; 4],
90    metadata: Metadata,
91}
92
93impl Manager {
94    pub async fn new(
95        addr: impl ToSocketAddrs,
96        domain: Option<String>,
97        version_specifiers: [u32; 4],
98        metadata: Metadata,
99    ) -> io::Result<Self> {
100        Ok(Self {
101            addr: lookup_host(addr)
102                .await?
103                .next()
104                .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
105            domain,
106            version_specifiers,
107            metadata,
108        })
109    }
110}
111
112#[async_trait]
113impl deadpool::managed::Manager for Manager {
114    // TODO: Make a runtime-agnostic stream wrapper
115    type Type = Client<Compat<BufStream<Stream>>>;
116    type Error = ClientError;
117
118    async fn create(&self) -> Result<Self::Type, Self::Error> {
119        let mut client = Client::new(
120            BufStream::new(
121                Stream::connect(self.addr, self.domain.as_ref())
122                    .await
123                    .map_err(ConnectionError::from)?,
124            )
125            .compat(),
126            &self.version_specifiers,
127        )
128        .await?;
129
130        match client.hello(self.metadata.clone()).await? {
131            Message::Success(_) => Ok(client),
132            other => Err(CommunicationError::from(io::Error::new(
133                io::ErrorKind::ConnectionAborted,
134                format!("server responded with {:?}", other),
135            ))
136            .into()),
137        }
138    }
139
140    async fn recycle(&self, conn: &mut Self::Type) -> RecycleResult<Self::Error> {
141        message::Success::try_from(conn.reset().await.map_err(Self::Error::from)?)
142            .map_err(ProtocolError::from)
143            .map_err(Self::Error::from)?;
144        Ok(())
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::env;

    use bolt_proto::{version::*, Value};
    use deadpool::managed::Manager as DeadpoolManager;
    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    /// Bolt versions exercised by every test. A version the server rejects during the
    /// handshake is skipped rather than failing the test.
    const BOLT_VERSIONS: &[u32] = &[V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4];

    /// Builds a `Manager` from the `BOLT_TEST_*` environment variables.
    ///
    /// When `succeed` is false, the real password is replaced with an invalid one so that
    /// initialization is expected to be rejected by the server.
    ///
    /// # Panics
    ///
    /// Panics if a required environment variable is unset or the address cannot be resolved.
    async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
        let credentials = if succeed {
            env::var("BOLT_TEST_PASSWORD").expect("BOLT_TEST_PASSWORD must be set")
        } else {
            String::from("invalid")
        };
        let username = env::var("BOLT_TEST_USERNAME").expect("BOLT_TEST_USERNAME must be set");

        Manager::new(
            env::var("BOLT_TEST_ADDR").expect("BOLT_TEST_ADDR must be set"),
            env::var("BOLT_TEST_DOMAIN").ok(),
            version_specifiers,
            Metadata::from_iter(vec![
                ("user_agent", "bolt-client/X.Y.Z"),
                ("scheme", "basic"),
                ("principal", username.as_str()),
                ("credentials", credentials.as_str()),
            ]),
        )
        .await
        .expect("failed to resolve BOLT_TEST_ADDR")
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn basic_pool() {
        const POOL_SIZE: usize = 15;
        const MAX_CONNS: usize = 50;

        for &bolt_version in BOLT_VERSIONS {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;

            // Don't even test connection pool if server doesn't support this Bolt version
            match manager.create().await {
                Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
                    println!(
                        "skipping test: {}",
                        ConnectionError::HandshakeFailed(versions)
                    );
                    continue;
                }
                Err(other) => panic!("{}", other),
                _ => {}
            }

            let pool = Pool::builder(manager).max_size(POOL_SIZE).build().unwrap();

            // More concurrent requests than pool slots, so connections must be recycled.
            (0..MAX_CONNS)
                .map(|i| {
                    let pool = pool.clone();
                    async move {
                        // Checked conversion: a truncating cast would wrap (and make the
                        // assertion fail spuriously) if MAX_CONNS were ever raised.
                        let expected = i64::try_from(i).expect("query index fits in i64");

                        let mut client = pool.get().await.unwrap();
                        let statement = format!("RETURN {} as num;", i);
                        client.run(statement, None, None).await.unwrap();
                        let (records, response) = client
                            .pull(Some(Metadata::from_iter(vec![("n", 1)])))
                            .await
                            .unwrap();
                        assert!(message::Success::try_from(response).is_ok());

                        let record = records.first().expect("query should return one record");
                        assert_eq!(record.fields(), &[Value::from(expected)]);
                    }
                })
                .collect::<FuturesUnordered<_>>()
                .collect::<Vec<_>>()
                .await;
        }
    }

    #[tokio::test]
    async fn invalid_init_fails() {
        for &bolt_version in BOLT_VERSIONS {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;
            match manager.create().await {
                Ok(_) => panic!("initialization should have failed"),
                Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
                    println!(
                        "skipping test: {}",
                        ConnectionError::HandshakeFailed(versions)
                    );
                    continue;
                }
                Err(ClientError::CommunicationError(comm_err)) => {
                    // `Manager::create` reports a rejected `hello` as ConnectionAborted.
                    if let CommunicationError::IoError(io_err) = &*comm_err {
                        if io_err.kind() == io::ErrorKind::ConnectionAborted {
                            // Test passed. We only check the first compatible version since
                            // sending too many invalid credentials will cause us to get
                            // rate-limited.
                            return;
                        }
                    }
                    panic!("{}", comm_err);
                }
                Err(other) => panic!("{}", other),
            }
        }
    }
}