1#![warn(rust_2018_idioms)]
2
3use std::{io, net::SocketAddr};
58
59use async_trait::async_trait;
60use tokio::{
61 io::BufStream,
62 net::{lookup_host, ToSocketAddrs},
63};
64use tokio_util::compat::*;
65
66use bolt_client::{
67 error::{CommunicationError, ConnectionError, Error as ClientError},
68 Client, Metadata, Stream,
69};
70use bolt_proto::{error::Error as ProtocolError, message, Message, ServerState};
71
72pub use bolt_client;
73pub use bolt_client::bolt_proto;
74pub use mobc;
75
76pub struct Manager {
77 addr: SocketAddr,
78 domain: Option<String>,
79 version_specifiers: [u32; 4],
80 metadata: Metadata,
81}
82
83impl Manager {
84 pub async fn new(
85 addr: impl ToSocketAddrs,
86 domain: Option<String>,
87 version_specifiers: [u32; 4],
88 metadata: Metadata,
89 ) -> io::Result<Self> {
90 Ok(Self {
91 addr: lookup_host(addr)
92 .await?
93 .next()
94 .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?,
95 domain,
96 version_specifiers,
97 metadata,
98 })
99 }
100}
101
102#[async_trait]
103impl mobc::Manager for Manager {
104 type Connection = Client<Compat<BufStream<Stream>>>;
106 type Error = ClientError;
107
108 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
109 let mut client = Client::new(
110 BufStream::new(
111 Stream::connect(self.addr, self.domain.as_ref())
112 .await
113 .map_err(ConnectionError::from)?,
114 )
115 .compat(),
116 &self.version_specifiers,
117 )
118 .await?;
119
120 match client.hello(self.metadata.clone()).await? {
121 Message::Success(_) => Ok(client),
122 other => Err(CommunicationError::from(io::Error::new(
123 io::ErrorKind::ConnectionAborted,
124 format!("server responded with {:?}", other),
125 ))
126 .into()),
127 }
128 }
129
130 async fn check(&self, mut conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
131 message::Success::try_from(conn.reset().await.map_err(Self::Error::from)?)
132 .map_err(ProtocolError::from)
133 .map_err(Self::Error::from)?;
134 Ok(conn)
135 }
136
137 fn validate(&self, conn: &mut Self::Connection) -> bool {
138 conn.server_state() != ServerState::Defunct
139 }
140}
141
#[cfg(test)]
mod tests {
    use std::env;

    use bolt_proto::{version::*, Value};
    use futures_util::{stream::FuturesUnordered, StreamExt};
    use mobc::{Manager as MobcManager, Pool};

    use super::*;

    /// Protocol versions the tests try to negotiate, one per iteration.
    fn versions_under_test() -> [u32; 9] {
        [V1_0, V2_0, V3_0, V4_0, V4_1, V4_2, V4_3, V4_4, V4]
    }

    /// Build a `Manager` from the `BOLT_TEST_*` environment variables.
    ///
    /// When `succeed` is false the password is replaced by the literal
    /// `"invalid"`.
    async fn get_connection_manager(version_specifiers: [u32; 4], succeed: bool) -> Manager {
        let credentials = match succeed {
            true => env::var("BOLT_TEST_PASSWORD").unwrap(),
            false => String::from("invalid"),
        };
        let addr = env::var("BOLT_TEST_ADDR").unwrap();
        let domain = env::var("BOLT_TEST_DOMAIN").ok();
        let principal = env::var("BOLT_TEST_USERNAME").unwrap();

        let metadata = Metadata::from_iter(vec![
            ("user_agent", "bolt-client/X.Y.Z"),
            ("scheme", "basic"),
            ("principal", principal.as_str()),
            ("credentials", credentials.as_str()),
        ]);

        Manager::new(addr, domain, version_specifiers, metadata)
            .await
            .unwrap()
    }

    /// Run `MAX_CONNS` concurrent queries through a pool capped at `POOL_SIZE`
    /// open connections, for every version the server accepts.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn basic_pool() {
        const POOL_SIZE: u64 = 15;
        const MAX_CONNS: usize = 50;

        let versions = versions_under_test();
        for &bolt_version in versions.iter() {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], true).await;

            // Probe with a single connection first: a failed handshake skips this
            // version, any other error fails the test.
            if let Err(failure) = manager.connect().await {
                match failure {
                    ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                        println!(
                            "skipping test: {}",
                            ConnectionError::HandshakeFailed(versions)
                        );
                        continue;
                    }
                    other => panic!("{}", other),
                }
            }

            let pool = Pool::builder().max_open(POOL_SIZE).build(manager);

            let tasks: FuturesUnordered<_> = (0..MAX_CONNS)
                .map(|i| {
                    let pool = pool.clone();
                    async move {
                        let mut conn = pool.get().await.unwrap();
                        let query = format!("RETURN {} as num;", i);
                        conn.run(query, None, None).await.unwrap();

                        let pull_metadata = Metadata::from_iter(vec![("n", 1)]);
                        let (records, response) = conn.pull(Some(pull_metadata)).await.unwrap();

                        assert!(message::Success::try_from(response).is_ok());
                        assert_eq!(records[0].fields(), &[Value::from(i as i8)]);
                    }
                })
                .collect();

            // Drive every task to completion.
            tasks.collect::<Vec<_>>().await;
        }
    }

    /// Connecting with invalid credentials must fail with a
    /// `ConnectionAborted` I/O error.
    #[tokio::test]
    async fn invalid_init_fails() {
        let versions = versions_under_test();
        for &bolt_version in versions.iter() {
            let manager = get_connection_manager([bolt_version, 0, 0, 0], false).await;

            let failure = match manager.connect().await {
                Ok(_) => panic!("initialization should have failed"),
                Err(failure) => failure,
            };

            match failure {
                ClientError::ConnectionError(ConnectionError::HandshakeFailed(versions)) => {
                    println!(
                        "skipping test: {}",
                        ConnectionError::HandshakeFailed(versions)
                    );
                    continue;
                }
                ClientError::CommunicationError(comm_err) => {
                    let aborted = matches!(
                        &*comm_err,
                        CommunicationError::IoError(io_err)
                            if io_err.kind() == io::ErrorKind::ConnectionAborted
                    );
                    if aborted {
                        // Expected failure observed; the test ends here.
                        return;
                    }
                    panic!("{}", comm_err);
                }
                other => panic!("{}", other),
            }
        }
    }
}