bb8_bolt/
lib.rs

1#![warn(rust_2018_idioms)]
2
3//! [`bolt_client`] support for the [`bb8`] connection pool.
4//!
5//! # Example
6//! ```
7//! use std::env;
8//!
9//! # use bb8::ManageConnection;
10//! use bb8_bolt::{
11//!     bb8::Pool,
12//!     bolt_client::Metadata,
13//!     bolt_proto::{version::*, Value},
14//!     Manager,
15//! #   bolt_client::error::{ConnectionError, Error as ClientError},
16//! #   bolt_proto::message::Success,
17//! };
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     // Configure a connection manager. We'll request Bolt v4.4 or v4.3.
22//!     let manager = Manager::new(
23//!         env::var("BOLT_TEST_ADDR")?,
24//!         env::var("BOLT_TEST_DOMAIN").ok(),
25//!         [V4_4, V4_3, 0, 0],
26//!         Metadata::from_iter(vec![
27//!             ("user_agent", "bolt-client/X.Y.Z"),
28//!             ("scheme", "basic"),
29//!             ("principal", &env::var("BOLT_TEST_USERNAME")?),
30//!             ("credentials", &env::var("BOLT_TEST_PASSWORD")?),
31//!         ]),
32//!     ).await?;
33//!
34//! #   match manager.connect().await {
35//! #       Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
36//! #           println!("skipping test: {}", ConnectionError::HandshakeFailed(versions));
37//! #           return Ok(());
38//! #       }
39//! #       Err(other) => panic!("{}", other),
40//! #       _ => {}
41//! #   }
42//!     // Create a connection pool. This should be shared across your application.
43//!     let pool = Pool::builder().build(manager).await?;
44//!
45//!     // Fetch and use a connection from the pool
46//!     let mut conn = pool.get().await?;
47//!     let response = conn.run("RETURN 1 as num;", None, None).await?;
48//! #   Success::try_from(response.clone()).unwrap();
49//!     let pull_meta = Metadata::from_iter(vec![("n", 1)]);
50//!     let (records, response) = conn.pull(Some(pull_meta)).await?;
51//! #   Success::try_from(response).unwrap();
52//!     assert_eq!(records[0].fields(), &[Value::from(1)]);
53//!
54//!     Ok(())
//! }
//! ```
56
57use std::{io, net::SocketAddr};
58
59use async_trait::async_trait;
60use bb8::ManageConnection;
61use tokio::{
62    io::BufStream,
63    net::{lookup_host, ToSocketAddrs},
64};
65use tokio_util::compat::*;
66
67use bolt_client::{
68    error::{CommunicationError, ConnectionError, Error as ClientError},
69    Client, Metadata, Stream,
70};
71use bolt_proto::{error::Error as ProtocolError, message, Message, ServerState};
72
73pub use bb8;
74pub use bolt_client;
75pub use bolt_client::bolt_proto;
76
/// A [`bb8`] connection manager that opens Bolt client connections.
///
/// Every connection produced by this manager is opened against the same
/// address, negotiated with the same version specifiers, and initialized
/// with the same HELLO metadata.
pub struct Manager {
    /// Socket address each new connection is opened against.
    addr: SocketAddr,
    /// Optional domain forwarded to `Stream::connect`.
    // NOTE(review): presumably enables TLS for this domain when set — confirm
    // against the `bolt_client::Stream` documentation.
    domain: Option<String>,
    /// The four Bolt version specifiers handed to `Client::new` for each new
    /// connection.
    version_specifiers: [u32; 4],
    /// Metadata sent with the HELLO message on every new connection.
    metadata: Metadata,
}
83
84impl Manager {
85    pub async fn new(
86        addr: impl ToSocketAddrs,
87        domain: Option<String>,
88        version_specifiers: [u32; 4],
89        metadata: Metadata,
90    ) -> io::Result<Self> {
91        Ok(Self {
92            addr: lookup_host(addr)
93                .await?
94                .next()
95                .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
96            domain,
97            version_specifiers,
98            metadata,
99        })
100    }
101}
102
103#[async_trait]
104impl ManageConnection for Manager {
105    type Connection = Client<Compat<BufStream<Stream>>>;
106    type Error = ClientError;
107
108    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
109        let mut client = Client::new(
110            BufStream::new(
111                Stream::connect(self.addr, self.domain.as_ref())
112                    .await
113                    .map_err(ConnectionError::from)?,
114            )
115            .compat(),
116            &self.version_specifiers,
117        )
118        .await?;
119
120        // TODO: Should we send HELLO now, or let the user do it later?
121        match client.hello(self.metadata.clone()).await? {
122            Message::Success(_) => Ok(client),
123            other => Err(CommunicationError::from(io::Error::new(
124                io::ErrorKind::ConnectionAborted,
125                format!("server responded with {:?}", other),
126            ))
127            .into()),
128        }
129    }
130
131    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
132        message::Success::try_from(conn.reset().await?).map_err(ProtocolError::from)?;
133        Ok(())
134    }
135
136    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
137        conn.server_state() == ServerState::Defunct
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::env;

    use bb8::*;
    use bolt_proto::{version::*, Value};
    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    /// Every Bolt version specifier the tests attempt, one at a time.
    const BOLT_VERSIONS: [u32; 9] = [V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4];

    /// Builds a [`Manager`] from the `BOLT_TEST_*` environment variables.
    ///
    /// When `succeed` is `false`, the real password is replaced by the
    /// literal `"invalid"` so that initialization is expected to be rejected.
    async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
        let password = match succeed {
            true => env::var("BOLT_TEST_PASSWORD").unwrap(),
            false => String::from("invalid"),
        };
        let addr = env::var("BOLT_TEST_ADDR").unwrap();
        let domain = env::var("BOLT_TEST_DOMAIN").ok();
        let username = env::var("BOLT_TEST_USERNAME").unwrap();

        let metadata = Metadata::from_iter(vec![
            ("user_agent", "bolt-client/X.Y.Z"),
            ("scheme", "basic"),
            ("principal", username.as_str()),
            ("credentials", password.as_str()),
        ]);

        Manager::new(addr, domain, version_specifiers, metadata)
            .await
            .unwrap()
    }

    /// Runs more concurrent queries than the pool has connections and checks
    /// that each one gets back the number it asked for.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn basic_pool() {
        const POOL_SIZE: usize = 15;
        const MAX_CONNS: usize = 50;

        for &bolt_version in &BOLT_VERSIONS {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;

            // Don't even test connection pool if server doesn't support this Bolt version
            if let Err(err) = manager.connect().await {
                match err {
                    ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                        println!(
                            "skipping test: {}",
                            ConnectionError::HandshakeFailed(versions)
                        );
                        continue;
                    }
                    other => panic!("{}", other),
                }
            }

            let pool = Pool::builder()
                .max_size(POOL_SIZE as u32)
                .build(manager)
                .await
                .unwrap();

            let mut tasks = FuturesUnordered::new();
            for i in 0..MAX_CONNS {
                let pool = pool.clone();
                tasks.push(async move {
                    let mut client = pool.get().await.unwrap();
                    let statement = format!("RETURN {} as num;", i);
                    client.run(statement, None, None).await.unwrap();

                    let pull_meta = Metadata::from_iter(vec![("n", 1)]);
                    let (records, response) = client.pull(Some(pull_meta)).await.unwrap();
                    assert!(message::Success::try_from(response).is_ok());
                    assert_eq!(records[0].fields(), &[Value::from(i as i8)]);
                });
            }

            // Drive every query to completion.
            while tasks.next().await.is_some() {}
        }
    }

    /// Connecting with bad credentials must fail with a `ConnectionAborted`
    /// I/O error wrapped in a communication error.
    #[tokio::test]
    async fn invalid_init_fails() {
        for &bolt_version in &BOLT_VERSIONS {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;

            let err = match manager.connect().await {
                Ok(_) => panic!("initialization should have failed"),
                Err(err) => err,
            };

            match err {
                ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                    println!(
                        "skipping test: {}",
                        ConnectionError::HandshakeFailed(versions)
                    );
                    continue;
                }
                ClientError::CommunicationError(comm_err) => match &*comm_err {
                    // Test passed. We only check the first compatible version since
                    // sending too many invalid credentials will cause us to get
                    // rate-limited.
                    CommunicationError::IoError(io_err)
                        if io_err.kind() == io::ErrorKind::ConnectionAborted =>
                    {
                        return;
                    }
                    _ => panic!("{}", comm_err),
                },
                other => panic!("{}", other),
            }
        }
    }
}