1use std::{net::TcpStream, time::Duration};
2
3use crate::{GraphSON, GremlinError, GremlinResult};
4use native_tls::TlsConnector;
5use tungstenite::{
6 client::{uri_mode, IntoClientRequest},
7 client_tls_with_config,
8 protocol::WebSocketConfig,
9 stream::{MaybeTlsStream, Mode, NoDelay},
10 Connector, Message, WebSocket,
11};
12
13struct ConnectionStream(WebSocket<MaybeTlsStream<TcpStream>>);
14
15impl std::fmt::Debug for ConnectionStream {
16 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
17 write!(f, "Connection")
18 }
19}
20
21impl ConnectionStream {
22 fn connect(options: ConnectionOptions) -> GremlinResult<Self> {
23 let connector = match options.tls_options.as_ref() {
24 Some(option) => Some(Connector::NativeTls(
25 option
26 .tls_connector()
27 .map_err(|e| GremlinError::Generic(e.to_string()))?,
28 )),
29 _ => None,
30 };
31
32 let request = options
33 .websocket_url()
34 .into_client_request()
35 .map_err(|e| GremlinError::Generic(e.to_string()))?;
36 let uri = request.uri();
37 let mode = uri_mode(uri).map_err(|e| GremlinError::Generic(e.to_string()))?;
38 let host = request
39 .uri()
40 .host()
41 .ok_or_else(|| GremlinError::Generic("No Hostname".into()))?;
42 let port = uri.port_u16().unwrap_or(match mode {
43 Mode::Plain => 80,
44 Mode::Tls => 443,
45 });
46 let mut stream = TcpStream::connect((host, port))
47 .map_err(|e| GremlinError::Generic(format!("Unable to connect {e:?}")))?;
48 NoDelay::set_nodelay(&mut stream, true)
49 .map_err(|e| GremlinError::Generic(e.to_string()))?;
50
51 let websocket_config = options
52 .websocket_options
53 .as_ref()
54 .map(WebSocketConfig::from);
55
56 let (client, _response) =
57 client_tls_with_config(options.websocket_url(), stream, websocket_config, connector)
58 .map_err(|e| GremlinError::Generic(e.to_string()))?;
59
60 Ok(ConnectionStream(client))
61 }
62
63 fn send(&mut self, payload: Vec<u8>) -> GremlinResult<()> {
64 self.0
65 .write_message(Message::Binary(payload))
66 .map_err(GremlinError::from)
67 }
68
69 fn recv(&mut self) -> GremlinResult<Vec<u8>> {
70 match self.0.read_message()? {
71 Message::Binary(binary) => Ok(binary),
72 _ => unimplemented!(),
73 }
74 }
75}
76
77#[derive(Debug)]
78pub(crate) struct Connection {
79 stream: ConnectionStream,
80 broken: bool,
81}
82
83impl Into<ConnectionOptions> for (&str, u16) {
84 fn into(self) -> ConnectionOptions {
85 ConnectionOptions {
86 host: String::from(self.0),
87 port: self.1,
88 ..Default::default()
89 }
90 }
91}
92
93impl Into<ConnectionOptions> for &str {
94 fn into(self) -> ConnectionOptions {
95 ConnectionOptions {
96 host: String::from(self),
97 ..Default::default()
98 }
99 }
100}
101
102pub struct ConnectionOptionsBuilder(ConnectionOptions);
103
104impl ConnectionOptionsBuilder {
105 pub fn host<T>(mut self, host: T) -> Self
106 where
107 T: Into<String>,
108 {
109 self.0.host = host.into();
110 self
111 }
112
113 pub fn port(mut self, port: u16) -> Self {
114 self.0.port = port;
115 self
116 }
117
118 pub fn pool_size(mut self, pool_size: u32) -> Self {
119 self.0.pool_size = pool_size;
120 self
121 }
122
123 pub fn pool_connection_timeout(mut self, pool_connection_timeout: Option<Duration>) -> Self {
126 self.0.pool_get_connection_timeout = pool_connection_timeout;
127 self
128 }
129
130 pub fn build(self) -> ConnectionOptions {
131 self.0
132 }
133
134 pub fn credentials(mut self, username: &str, password: &str) -> Self {
135 self.0.credentials = Some(Credentials {
136 username: String::from(username),
137 password: String::from(password),
138 });
139 self
140 }
141
142 pub fn ssl(mut self, ssl: bool) -> Self {
143 self.0.ssl = ssl;
144 self
145 }
146
147 pub fn tls_options(mut self, options: TlsOptions) -> Self {
148 self.0.tls_options = Some(options);
149 self
150 }
151
152 pub fn websocket_options(mut self, options: WebSocketOptions) -> Self {
153 self.0.websocket_options = Some(options);
154 self
155 }
156
157 pub fn serializer(mut self, serializer: GraphSON) -> Self {
158 self.0.serializer = serializer;
159 self
160 }
161
162 pub fn deserializer(mut self, deserializer: GraphSON) -> Self {
163 self.0.deserializer = deserializer;
164 self
165 }
166}
167
168#[derive(Clone, Debug)]
169pub struct ConnectionOptions {
170 pub(crate) host: String,
171 pub(crate) port: u16,
172 pub(crate) pool_size: u32,
173 pub(crate) pool_get_connection_timeout: Option<Duration>,
174 pub(crate) credentials: Option<Credentials>,
175 pub(crate) ssl: bool,
176 pub(crate) tls_options: Option<TlsOptions>,
177 pub(crate) serializer: GraphSON,
178 pub(crate) deserializer: GraphSON,
179 pub(crate) websocket_options: Option<WebSocketOptions>,
180}
181
/// Username/password pair used to authenticate against the server.
///
/// `Debug` is implemented by hand so the password never appears in debug output.
/// This includes the derived `Debug` of structs that embed `Credentials`.
#[derive(Clone)]
pub(crate) struct Credentials {
    pub(crate) username: String,
    pub(crate) password: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `password`; `..` in the output marks the hidden field.
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .finish_non_exhaustive()
    }
}
187
/// TLS settings used when building a custom native-tls connector.
#[derive(Clone, Debug)]
pub struct TlsOptions {
    /// When `true`, certificate validation is disabled (`danger_accept_invalid_certs`).
    pub accept_invalid_certs: bool,
}
192
/// Size limits forwarded to tungstenite's `WebSocketConfig`.
#[derive(Clone, Debug)]
pub struct WebSocketOptions {
    /// Maximum size of a complete message, in bytes.
    pub(crate) max_message_size: Option<usize>,
    /// Maximum size of a single frame, in bytes.
    pub(crate) max_frame_size: Option<usize>,
}
201
202impl WebSocketOptions {
203 pub fn builder() -> WebSocketOptionsBuilder {
204 WebSocketOptionsBuilder(Self::default())
205 }
206}
207
208impl Default for WebSocketOptions {
209 fn default() -> Self {
210 Self {
211 max_message_size: Some(64 << 20),
212 max_frame_size: Some(16 << 20),
213 }
214 }
215}
216
217impl From<WebSocketOptions> for tungstenite::protocol::WebSocketConfig {
218 fn from(value: WebSocketOptions) -> Self {
219 (&value).into()
220 }
221}
222
223impl From<&WebSocketOptions> for tungstenite::protocol::WebSocketConfig {
224 fn from(value: &WebSocketOptions) -> Self {
225 let mut config = tungstenite::protocol::WebSocketConfig::default();
226 config.max_message_size = value.max_message_size;
227 config.max_frame_size = value.max_frame_size;
228 config
229 }
230}
231
232pub struct WebSocketOptionsBuilder(WebSocketOptions);
233
234impl WebSocketOptionsBuilder {
235 pub fn build(self) -> WebSocketOptions {
236 self.0
237 }
238
239 pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
240 self.0.max_message_size = max_message_size;
241 self
242 }
243
244 pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
245 self.0.max_frame_size = max_frame_size;
246 self
247 }
248}
249
250impl Default for ConnectionOptions {
251 fn default() -> ConnectionOptions {
252 ConnectionOptions {
253 host: String::from("localhost"),
254 port: 8182,
255 pool_size: 10,
256 pool_get_connection_timeout: Some(Duration::from_secs(30)),
257 credentials: None,
258 ssl: false,
259 tls_options: None,
260 serializer: GraphSON::V3,
261 deserializer: GraphSON::V3,
262 websocket_options: None,
263 }
264 }
265}
266
267impl ConnectionOptions {
268 pub fn builder() -> ConnectionOptionsBuilder {
269 ConnectionOptionsBuilder(ConnectionOptions::default())
270 }
271
272 pub fn websocket_url(&self) -> String {
273 let protocol = if self.ssl { "wss" } else { "ws" };
274 format!("{}://{}:{}/gremlin", protocol, self.host, self.port)
275 }
276}
277
278impl Connection {
279 pub fn connect<T>(options: T) -> GremlinResult<Connection>
280 where
281 T: Into<ConnectionOptions>,
282 {
283 Ok(Connection {
284 stream: ConnectionStream::connect(options.into())?,
285 broken: false,
286 })
287 }
288
289 pub fn send(&mut self, payload: Vec<u8>) -> GremlinResult<()> {
290 self.stream.send(payload).map_err(|e| {
291 if let GremlinError::WebSocket(_) = e {
292 self.broken = true;
293 }
294 e
295 })
296 }
297
298 pub fn recv(&mut self) -> GremlinResult<Vec<u8>> {
299 self.stream.recv().map_err(|e| {
300 if let GremlinError::WebSocket(_) = e {
301 self.broken = true
302 }
303 e
304 })
305 }
306
307 pub fn is_broken(&self) -> bool {
308 self.broken
309 }
310}
311
312impl TlsOptions {
313 pub(crate) fn tls_connector(&self) -> native_tls::Result<TlsConnector> {
314 TlsConnector::builder()
315 .danger_accept_invalid_certs(self.accept_invalid_certs)
316 .build()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Needs a Gremlin server reachable at localhost:8182.
    #[test]
    fn it_should_connect() {
        Connection::connect(("localhost", 8182)).unwrap();
    }

    /// The URL scheme follows the `ssl` flag; host, port and path are fixed.
    #[test]
    fn connection_option_build_url() {
        let cases = [
            (false, "ws://localhost:8182/gremlin"),
            (true, "wss://localhost:8182/gremlin"),
        ];

        for (ssl, expected) in cases {
            let options = ConnectionOptions {
                host: "localhost".into(),
                port: 8182,
                ssl,
                ..Default::default()
            };
            assert_eq!(options.websocket_url(), expected);
        }
    }
}