1#![warn(rust_2018_idioms)]
2
3use std::{io, net::SocketAddr};
58
59use async_trait::async_trait;
60use bb8::ManageConnection;
61use tokio::{
62 io::BufStream,
63 net::{lookup_host, ToSocketAddrs},
64};
65use tokio_util::compat::*;
66
67use bolt_client::{
68 error::{CommunicationError, ConnectionError, Error as ClientError},
69 Client, Metadata, Stream,
70};
71use bolt_proto::{error::Error as ProtocolError, message, Message, ServerState};
72
73pub use bb8;
74pub use bolt_client;
75pub use bolt_client::bolt_proto;
76
77pub struct Manager {
78 addr: SocketAddr,
79 domain: Option<String>,
80 version_specifiers: [u32; 4],
81 metadata: Metadata,
82}
83
84impl Manager {
85 pub async fn new(
86 addr: impl ToSocketAddrs,
87 domain: Option<String>,
88 version_specifiers: [u32; 4],
89 metadata: Metadata,
90 ) -> io::Result<Self> {
91 Ok(Self {
92 addr: lookup_host(addr)
93 .await?
94 .next()
95 .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
96 domain,
97 version_specifiers,
98 metadata,
99 })
100 }
101}
102
103#[async_trait]
104impl ManageConnection for Manager {
105 type Connection = Client<Compat<BufStream<Stream>>>;
106 type Error = ClientError;
107
108 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
109 let mut client = Client::new(
110 BufStream::new(
111 Stream::connect(self.addr, self.domain.as_ref())
112 .await
113 .map_err(ConnectionError::from)?,
114 )
115 .compat(),
116 &self.version_specifiers,
117 )
118 .await?;
119
120 match client.hello(self.metadata.clone()).await? {
122 Message::Success(_) => Ok(client),
123 other => Err(CommunicationError::from(io::Error::new(
124 io::ErrorKind::ConnectionAborted,
125 format!("server responded with {:?}", other),
126 ))
127 .into()),
128 }
129 }
130
131 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
132 message::Success::try_from(conn.reset().await?).map_err(ProtocolError::from)?;
133 Ok(())
134 }
135
136 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
137 conn.server_state() == ServerState::Defunct
138 }
139}
140
#[cfg(test)]
mod tests {
    //! Integration tests that talk to a live Bolt server.
    //!
    //! Connection details are read from the environment: `BOLT_TEST_ADDR`,
    //! `BOLT_TEST_USERNAME`, `BOLT_TEST_PASSWORD`, and the optional `BOLT_TEST_DOMAIN`.

    use std::env;

    use bb8::*;
    use bolt_proto::{version::*, Value};
    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    /// Protocol versions both tests iterate over, in the order they are tried.
    fn all_versions() -> [u32; 9] {
        [V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4]
    }

    /// Builds a [`Manager`] from the `BOLT_TEST_*` environment variables.
    ///
    /// When `succeed` is `false` the password is replaced by `"invalid"`, so the
    /// server is expected to reject `HELLO`.
    ///
    /// # Panics
    ///
    /// Panics if a required variable is unset or the address cannot be resolved.
    async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
        let credentials = if succeed {
            env::var("BOLT_TEST_PASSWORD").unwrap()
        } else {
            String::from("invalid")
        };
        let addr = env::var("BOLT_TEST_ADDR").unwrap();
        let domain = env::var("BOLT_TEST_DOMAIN").ok();
        let principal = env::var("BOLT_TEST_USERNAME").unwrap();

        let metadata = Metadata::from_iter(vec![
            ("user_agent", "bolt-client/X.Y.Z"),
            ("scheme", "basic"),
            ("principal", principal.as_str()),
            ("credentials", credentials.as_str()),
        ]);

        Manager::new(addr, domain, version_specifiers, metadata)
            .await
            .unwrap()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn basic_pool() {
        const POOL_SIZE: usize = 15;
        const MAX_CONNS: usize = 50;

        for bolt_version in all_versions() {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;

            // Probe with one direct connection first, so a version the server does not
            // support is skipped rather than failing every pool checkout below.
            if let Err(err) = manager.connect().await {
                match err {
                    ClientError::ConnectionError(ConnectionError::HandshakeFailed(offered)) => {
                        println!("skipping test: {}", ConnectionError::HandshakeFailed(offered));
                        continue;
                    }
                    other => panic!("{}", other),
                }
            }

            let pool = Pool::builder()
                .max_size(POOL_SIZE as u32)
                .build(manager)
                .await
                .unwrap();

            // More concurrent queries than the pool has connections, so checkouts have to
            // wait for and reuse connections released by other tasks.
            let mut queries = FuturesUnordered::new();
            for i in 0..MAX_CONNS {
                let pool = pool.clone();
                queries.push(async move {
                    let mut client = pool.get().await.unwrap();
                    client
                        .run(format!("RETURN {} as num;", i), None, None)
                        .await
                        .unwrap();
                    let pull_metadata = Metadata::from_iter(vec![("n", 1)]);
                    let (records, response) = client.pull(Some(pull_metadata)).await.unwrap();
                    assert!(message::Success::try_from(response).is_ok());
                    assert_eq!(records[0].fields(), &[Value::from(i as i8)]);
                });
            }
            while queries.next().await.is_some() {}
        }
    }

    #[tokio::test]
    async fn invalid_init_fails() {
        for bolt_version in all_versions() {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;
            let err = match manager.connect().await {
                Ok(_) => panic!("initialization should have failed"),
                Err(err) => err,
            };

            match err {
                ClientError::ConnectionError(ConnectionError::HandshakeFailed(offered)) => {
                    println!("skipping test: {}", ConnectionError::HandshakeFailed(offered));
                }
                ClientError::CommunicationError(comm_err) => {
                    // `Manager::connect` reports a rejected HELLO as a ConnectionAborted
                    // I/O error, which is the outcome this test expects.
                    let rejected = matches!(
                        &*comm_err,
                        CommunicationError::IoError(io_err)
                            if io_err.kind() == io::ErrorKind::ConnectionAborted
                    );
                    // NOTE(review): the test stops at the first version whose HELLO is
                    // rejected, so later versions are never exercised. This may be
                    // deliberate (e.g. to limit failed logins against the test server);
                    // confirm before switching to `continue`.
                    if rejected {
                        return;
                    }
                    panic!("{}", comm_err);
                }
                other => panic!("{}", other),
            }
        }
    }
}