mobc_bolt/
lib.rs

1#![warn(rust_2018_idioms)]
2
3//! [`bolt_client`] support for the [`mobc`] connection pool.
4//!
5//! # Example
6//! ```
7//! use std::env;
8//!
9//! # use mobc::Manager as MobcManager;
10//! use mobc_bolt::{
11//!     mobc::Pool,
12//!     bolt_client::Metadata,
13//!     bolt_proto::{version::*, Value},
14//!     Manager,
15//! #   bolt_client::error::{ConnectionError, Error as ClientError},
16//! #   bolt_proto::message::Success,
17//! };
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     // Configure a connection manager. We'll request Bolt v4.4 or v4.3.
22//!     let manager = Manager::new(
23//!         env::var("BOLT_TEST_ADDR")?,
24//!         env::var("BOLT_TEST_DOMAIN").ok(),
25//!         [V4_4, V4_3, 0, 0],
26//!         Metadata::from_iter(vec![
27//!             ("user_agent", "bolt-client/X.Y.Z"),
28//!             ("scheme", "basic"),
29//!             ("principal", &env::var("BOLT_TEST_USERNAME")?),
30//!             ("credentials", &env::var("BOLT_TEST_PASSWORD")?),
31//!         ]),
32//!     ).await?;
33//!
34//! #   match manager.connect().await {
35//! #       Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
36//! #           println!("skipping test: {}", ConnectionError::HandshakeFailed(versions));
37//! #           return Ok(());
38//! #       }
39//! #       Err(other) => panic!("{}", other),
40//! #       _ => {}
41//! #   }
42//!     // Create a connection pool. This should be shared across your application.
43//!     let pool = Pool::builder().build(manager);
44//!
45//!     // Fetch and use a connection from the pool
46//!     let mut conn = pool.get().await?;
47//!     let response = conn.run("RETURN 1 as num;", None, None).await?;
48//! #   Success::try_from(response.clone()).unwrap();
49//!     let pull_meta = Metadata::from_iter(vec![("n", 1)]);
50//!     let (records, response) = conn.pull(Some(pull_meta)).await?;
51//! #   Success::try_from(response).unwrap();
52//!     assert_eq!(records[0].fields(), &[Value::from(1)]);
53//!
54//!     Ok(())
//! }
//! ```
56
57use std::{io, net::SocketAddr};
58
59use async_trait::async_trait;
60use tokio::{
61    io::BufStream,
62    net::{lookup_host, ToSocketAddrs},
63};
64use tokio_util::compat::*;
65
66use bolt_client::{
67    error::{CommunicationError, ConnectionError, Error as ClientError},
68    Client, Metadata, Stream,
69};
70use bolt_proto::{error::Error as ProtocolError, message, Message, ServerState};
71
72pub use bolt_client;
73pub use bolt_client::bolt_proto;
74pub use mobc;
75
76pub struct Manager {
77    addr: SocketAddr,
78    domain: Option<String>,
79    version_specifiers: [u32; 4],
80    metadata: Metadata,
81}
82
83impl Manager {
84    pub async fn new(
85        addr: impl ToSocketAddrs,
86        domain: Option<String>,
87        version_specifiers: [u32; 4],
88        metadata: Metadata,
89    ) -> io::Result<Self> {
90        Ok(Self {
91            addr: lookup_host(addr)
92                .await?
93                .next()
94                .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
95            domain,
96            version_specifiers,
97            metadata,
98        })
99    }
100}
101
102#[async_trait]
103impl mobc::Manager for Manager {
104    // TODO: Make a runtime-agnostic stream wrapper
105    type Connection = Client<Compat<BufStream<Stream>>>;
106    type Error = ClientError;
107
108    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
109        let mut client = Client::new(
110            BufStream::new(
111                Stream::connect(self.addr, self.domain.as_ref())
112                    .await
113                    .map_err(ConnectionError::from)?,
114            )
115            .compat(),
116            &self.version_specifiers,
117        )
118        .await?;
119
120        match client.hello(self.metadata.clone()).await? {
121            Message::Success(_) => Ok(client),
122            other => Err(CommunicationError::from(io::Error::new(
123                io::ErrorKind::ConnectionAborted,
124                format!("server responded with {:?}", other),
125            ))
126            .into()),
127        }
128    }
129
130    async fn check(&self, mut conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
131        message::Success::try_from(conn.reset().await.map_err(Self::Error::from)?)
132            .map_err(ProtocolError::from)
133            .map_err(Self::Error::from)?;
134        Ok(conn)
135    }
136
137    fn validate(&self, conn: &mut Self::Connection) -> bool {
138        conn.server_state() != ServerState::Defunct
139    }
140}
141
#[cfg(test)]
mod tests {
    use std::env;

    use bolt_proto::{version::*, Value};
    use futures_util::{stream::FuturesUnordered, StreamExt};
    use mobc::{Manager as MobcManager, Pool};

    use super::*;

    /// Builds a [`Manager`] from the `BOLT_TEST_*` environment variables.
    ///
    /// When `succeed` is `false`, the literal string `"invalid"` is sent as the credentials
    /// instead of `BOLT_TEST_PASSWORD`. Panics if a required variable is missing or the
    /// manager cannot be created.
    async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
        let password = if succeed {
            env::var("BOLT_TEST_PASSWORD").unwrap()
        } else {
            String::from("invalid")
        };
        let address = env::var("BOLT_TEST_ADDR").unwrap();
        let domain = env::var("BOLT_TEST_DOMAIN").ok();
        let username = env::var("BOLT_TEST_USERNAME").unwrap();

        let metadata = Metadata::from_iter(vec![
            ("user_agent", "bolt-client/X.Y.Z"),
            ("scheme", "basic"),
            ("principal", username.as_str()),
            ("credentials", password.as_str()),
        ]);

        Manager::new(address, domain, version_specifiers, metadata)
            .await
            .unwrap()
    }

    /// Runs `MAX_CONNS` concurrent queries through a pool capped at `POOL_SIZE` open
    /// connections, once per Bolt version that completes the handshake.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn basic_pool() {
        const POOL_SIZE: u64 = 15;
        const MAX_CONNS: usize = 50;

        for bolt_version in [V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4] {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;

            // Don't even test connection pool if server doesn't support this Bolt version
            if let Err(connect_err) = manager.connect().await {
                match connect_err {
                    ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                        println!(
                            "skipping test: {}",
                            ConnectionError::HandshakeFailed(versions)
                        );
                        continue;
                    }
                    other => panic!("{}", other),
                }
            }

            let pool = Pool::builder().max_open(POOL_SIZE).build(manager);

            let mut queries: FuturesUnordered<_> = (0..MAX_CONNS)
                .map(|num| {
                    let shared_pool = pool.clone();
                    async move {
                        let mut conn = shared_pool.get().await.unwrap();
                        conn.run(format!("RETURN {} as num;", num), None, None)
                            .await
                            .unwrap();
                        let pull_meta = Metadata::from_iter(vec![("n", 1)]);
                        let (records, response) = conn.pull(Some(pull_meta)).await.unwrap();
                        assert!(message::Success::try_from(response).is_ok());
                        assert_eq!(records[0].fields(), &[Value::from(num as i8)]);
                    }
                })
                .collect();

            // Drive every query to completion; the individual outputs are `()`.
            while queries.next().await.is_some() {}
        }
    }

    /// Verifies that connecting with invalid credentials fails with a `ConnectionAborted`
    /// I/O error on the first Bolt version that completes the handshake.
    #[tokio::test]
    async fn invalid_init_fails() {
        for bolt_version in [V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4] {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;

            let failure = match manager.connect().await {
                Ok(_) => panic!("initialization should have failed"),
                Err(failure) => failure,
            };

            match failure {
                ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                    println!(
                        "skipping test: {}",
                        ConnectionError::HandshakeFailed(versions)
                    );
                    continue;
                }
                ClientError::CommunicationError(comm_err) => {
                    let aborted = matches!(
                        &*comm_err,
                        CommunicationError::IoError(io_err)
                            if io_err.kind() == io::ErrorKind::ConnectionAborted
                    );
                    if aborted {
                        // Test passed. We only check the first compatible version since
                        // sending too many invalid credentials will cause us to get
                        // rate-limited.
                        return;
                    }
                    panic!("{}", comm_err);
                }
                other => panic!("{}", other),
            }
        }
    }
}
246}