1#![warn(rust_2018_idioms)]
2
3use std::{convert::Infallible, io, net::SocketAddr};
58
59use async_trait::async_trait;
60use deadpool::managed::RecycleResult;
61use tokio::{
62 io::BufStream,
63 net::{lookup_host, ToSocketAddrs},
64};
65use tokio_util::compat::*;
66
67use bolt_client::{
68 error::{CommunicationError, ConnectionError, Error as ClientError},
69 Client, Metadata, Stream,
70};
71use bolt_proto::{error::Error as ProtocolError, message, Message};
72
73pub use bolt_client;
74pub use bolt_client::bolt_proto;
75
76pub use deadpool::managed::reexports::*;
// Invoke deadpool's re-export helper for this crate's `Manager`.
// NOTE(review): per deadpool's `managed_reexports!` documentation this is
// expected to generate the crate-local pool type aliases (the tests below rely
// on a `Pool` name being in scope); the arguments appear to be, in order, the
// crate name used in generated docs, the manager type, the pooled object
// wrapper, the manager's error type, and the configuration error type
// (`Infallible` here) — confirm against the deadpool version in use.
deadpool::managed_reexports!(
    "bolt_client",
    Manager,
    deadpool::managed::Object<Manager>,
    ClientError,
    Infallible
);
85
/// Connection manager that creates and recycles Bolt clients on behalf of a
/// deadpool pool.
pub struct Manager {
    /// Server address; resolved once in `Manager::new` and reused for every
    /// connection that is created.
    addr: SocketAddr,
    /// Optional domain handed to `Stream::connect` together with `addr`.
    /// NOTE(review): presumably selects a TLS connection when set — confirm
    /// against the bolt_client documentation.
    domain: Option<String>,
    /// The four version specifiers passed to `Client::new` when connecting.
    version_specifiers: [u32; 4],
    /// Metadata cloned and sent with the `hello` call on each new connection.
    metadata: Metadata,
}
92
93impl Manager {
94 pub async fn new(
95 addr: impl ToSocketAddrs,
96 domain: Option<String>,
97 version_specifiers: [u32; 4],
98 metadata: Metadata,
99 ) -> io::Result<Self> {
100 Ok(Self {
101 addr: lookup_host(addr)
102 .await?
103 .next()
104 .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
105 domain,
106 version_specifiers,
107 metadata,
108 })
109 }
110}
111
112#[async_trait]
113impl deadpool::managed::Manager for Manager {
114 type Type = Client<Compat<BufStream<Stream>>>;
116 type Error = ClientError;
117
118 async fn create(&self) -> Result<Self::Type, Self::Error> {
119 let mut client = Client::new(
120 BufStream::new(
121 Stream::connect(self.addr, self.domain.as_ref())
122 .await
123 .map_err(ConnectionError::from)?,
124 )
125 .compat(),
126 &self.version_specifiers,
127 )
128 .await?;
129
130 match client.hello(self.metadata.clone()).await? {
131 Message::Success(_) => Ok(client),
132 other => Err(CommunicationError::from(io::Error::new(
133 io::ErrorKind::ConnectionAborted,
134 format!("server responded with {:?}", other),
135 ))
136 .into()),
137 }
138 }
139
140 async fn recycle(&self, conn: &mut Self::Type) -> RecycleResult<Self::Error> {
141 message::Success::try_from(conn.reset().await.map_err(Self::Error::from)?)
142 .map_err(ProtocolError::from)
143 .map_err(Self::Error::from)?;
144 Ok(())
145 }
146}
147
148#[cfg(test)]
149mod tests {
150 use std::env;
151
152 use bolt_proto::{version::*, Value};
153 use deadpool::managed::Manager as DeadpoolManager;
154 use futures_util::{stream::FuturesUnordered, StreamExt};
155
156 use super::*;
157
158 async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
159 let credentials = if succeed {
160 env::var("BOLT_TEST_PASSWORD").unwrap()
161 } else {
162 String::from("invalid")
163 };
164
165 Manager::new(
166 env::var("BOLT_TEST_ADDR").unwrap(),
167 env::var("BOLT_TEST_DOMAIN").ok(),
168 version_specifiers,
169 Metadata::from_iter(vec![
170 ("user_agent", "bolt-client/X.Y.Z"),
171 ("scheme", "basic"),
172 ("principal", &env::var("BOLT_TEST_USERNAME").unwrap()),
173 ("credentials", &credentials),
174 ]),
175 )
176 .await
177 .unwrap()
178 }
179
180 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
181 async fn basic_pool() {
182 const POOL_SIZE: usize = 15;
183 const MAX_CONNS: usize = 50;
184
185 for &bolt_version in &[V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4] {
186 let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;
187
188 match manager.create().await {
190 Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
191 println!(
192 "skipping test: {}",
193 ConnectionError::HandshakeFailed(versions)
194 );
195 continue;
196 }
197 Err(other) => panic!("{}", other),
198 _ => {}
199 }
200
201 let pool = Pool::builder(manager).max_size(POOL_SIZE).build().unwrap();
202
203 (0..MAX_CONNS)
204 .map(|i| {
205 let pool = pool.clone();
206 async move {
207 let mut client = pool.get().await.unwrap();
208 let statement = format!("RETURN {} as num;", i);
209 client.run(statement, None, None).await.unwrap();
210 let (records, response) = client
211 .pull(Some(Metadata::from_iter(vec![("n", 1)])))
212 .await
213 .unwrap();
214 assert!(message::Success::try_from(response).is_ok());
215 assert_eq!(records[0].fields(), &[Value::from(i as i8)]);
216 }
217 })
218 .collect::<FuturesUnordered<_>>()
219 .collect::<Vec<_>>()
220 .await;
221 }
222 }
223
224 #[tokio::test]
225 async fn invalid_init_fails() {
226 for &bolt_version in &[V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4] {
227 let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;
228 match manager.create().await {
229 Ok(_) => panic!("initialization should have failed"),
230 Err(ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions))) => {
231 println!(
232 "skipping test: {}",
233 ConnectionError::HandshakeFailed(versions)
234 );
235 continue;
236 }
237 Err(ClientError::CommunicationError(comm_err)) => {
238 if let CommunicationError::IoError(io_err) = &*comm_err {
239 if io_err.kind() == io::ErrorKind::ConnectionAborted {
240 return;
244 }
245 }
246 panic!("{}", comm_err);
247 }
248 Err(other) => panic!("{}", other),
249 }
250 }
251 }
252}