gremlin_client/
connection.rs

1use std::{net::TcpStream, time::Duration};
2
3use crate::{GraphSON, GremlinError, GremlinResult};
4use native_tls::TlsConnector;
5use tungstenite::{
6    client::{uri_mode, IntoClientRequest},
7    client_tls_with_config,
8    protocol::WebSocketConfig,
9    stream::{MaybeTlsStream, Mode, NoDelay},
10    Connector, Message, WebSocket,
11};
12
13struct ConnectionStream(WebSocket<MaybeTlsStream<TcpStream>>);
14
15impl std::fmt::Debug for ConnectionStream {
16    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
17        write!(f, "Connection")
18    }
19}
20
21impl ConnectionStream {
22    fn connect(options: ConnectionOptions) -> GremlinResult<Self> {
23        let connector = match options.tls_options.as_ref() {
24            Some(option) => Some(Connector::NativeTls(
25                option
26                    .tls_connector()
27                    .map_err(|e| GremlinError::Generic(e.to_string()))?,
28            )),
29            _ => None,
30        };
31
32        let request = options
33            .websocket_url()
34            .into_client_request()
35            .map_err(|e| GremlinError::Generic(e.to_string()))?;
36        let uri = request.uri();
37        let mode = uri_mode(uri).map_err(|e| GremlinError::Generic(e.to_string()))?;
38        let host = request
39            .uri()
40            .host()
41            .ok_or_else(|| GremlinError::Generic("No Hostname".into()))?;
42        let port = uri.port_u16().unwrap_or(match mode {
43            Mode::Plain => 80,
44            Mode::Tls => 443,
45        });
46        let mut stream = TcpStream::connect((host, port))
47            .map_err(|e| GremlinError::Generic(format!("Unable to connect {e:?}")))?;
48        NoDelay::set_nodelay(&mut stream, true)
49            .map_err(|e| GremlinError::Generic(e.to_string()))?;
50
51        let websocket_config = options
52            .websocket_options
53            .as_ref()
54            .map(WebSocketConfig::from);
55
56        let (client, _response) =
57            client_tls_with_config(options.websocket_url(), stream, websocket_config, connector)
58                .map_err(|e| GremlinError::Generic(e.to_string()))?;
59
60        Ok(ConnectionStream(client))
61    }
62
63    fn send(&mut self, payload: Vec<u8>) -> GremlinResult<()> {
64        self.0
65            .write_message(Message::Binary(payload))
66            .map_err(GremlinError::from)
67    }
68
69    fn recv(&mut self) -> GremlinResult<Vec<u8>> {
70        match self.0.read_message()? {
71            Message::Binary(binary) => Ok(binary),
72            _ => unimplemented!(),
73        }
74    }
75}
76
77#[derive(Debug)]
78pub(crate) struct Connection {
79    stream: ConnectionStream,
80    broken: bool,
81}
82
83impl Into<ConnectionOptions> for (&str, u16) {
84    fn into(self) -> ConnectionOptions {
85        ConnectionOptions {
86            host: String::from(self.0),
87            port: self.1,
88            ..Default::default()
89        }
90    }
91}
92
93impl Into<ConnectionOptions> for &str {
94    fn into(self) -> ConnectionOptions {
95        ConnectionOptions {
96            host: String::from(self),
97            ..Default::default()
98        }
99    }
100}
101
102pub struct ConnectionOptionsBuilder(ConnectionOptions);
103
104impl ConnectionOptionsBuilder {
105    pub fn host<T>(mut self, host: T) -> Self
106    where
107        T: Into<String>,
108    {
109        self.0.host = host.into();
110        self
111    }
112
113    pub fn port(mut self, port: u16) -> Self {
114        self.0.port = port;
115        self
116    }
117
118    pub fn pool_size(mut self, pool_size: u32) -> Self {
119        self.0.pool_size = pool_size;
120        self
121    }
122
123    /// Both the sync and async pool providers use a default of 30 seconds,
124    /// Async pool interprets `None` as no timeout. Sync pool maps `None` to the default value
125    pub fn pool_connection_timeout(mut self, pool_connection_timeout: Option<Duration>) -> Self {
126        self.0.pool_get_connection_timeout = pool_connection_timeout;
127        self
128    }
129
130    pub fn build(self) -> ConnectionOptions {
131        self.0
132    }
133
134    pub fn credentials(mut self, username: &str, password: &str) -> Self {
135        self.0.credentials = Some(Credentials {
136            username: String::from(username),
137            password: String::from(password),
138        });
139        self
140    }
141
142    pub fn ssl(mut self, ssl: bool) -> Self {
143        self.0.ssl = ssl;
144        self
145    }
146
147    pub fn tls_options(mut self, options: TlsOptions) -> Self {
148        self.0.tls_options = Some(options);
149        self
150    }
151
152    pub fn websocket_options(mut self, options: WebSocketOptions) -> Self {
153        self.0.websocket_options = Some(options);
154        self
155    }
156
157    pub fn serializer(mut self, serializer: GraphSON) -> Self {
158        self.0.serializer = serializer;
159        self
160    }
161
162    pub fn deserializer(mut self, deserializer: GraphSON) -> Self {
163        self.0.deserializer = deserializer;
164        self
165    }
166}
167
168#[derive(Clone, Debug)]
169pub struct ConnectionOptions {
170    pub(crate) host: String,
171    pub(crate) port: u16,
172    pub(crate) pool_size: u32,
173    pub(crate) pool_get_connection_timeout: Option<Duration>,
174    pub(crate) credentials: Option<Credentials>,
175    pub(crate) ssl: bool,
176    pub(crate) tls_options: Option<TlsOptions>,
177    pub(crate) serializer: GraphSON,
178    pub(crate) deserializer: GraphSON,
179    pub(crate) websocket_options: Option<WebSocketOptions>,
180}
181
/// Username/password pair set through `ConnectionOptionsBuilder::credentials`.
///
/// NOTE(review): the derived `Debug` prints the password in clear text — confirm this is
/// acceptable wherever `ConnectionOptions` gets logged.
#[derive(Clone, Debug)]
pub(crate) struct Credentials {
    pub(crate) username: String,
    pub(crate) password: String,
}
187
/// TLS settings used to build the native-tls connector for `wss` connections.
#[derive(Clone, Debug)]
pub struct TlsOptions {
    /// Forwarded to native-tls `danger_accept_invalid_certs`. When `true`, certificate
    /// validation is disabled, so any certificate is trusted; intended for testing only.
    pub accept_invalid_certs: bool,
}
192
/// Size limits applied to the WebSocket connection (converted into tungstenite's
/// `WebSocketConfig`).
#[derive(Clone, Debug)]
pub struct WebSocketOptions {
    /// Largest accepted message; `None` means unlimited. Defaults to 64 MiB.
    pub(crate) max_message_size: Option<usize>,
    /// Largest accepted frame payload (frame header not included); `None` means unlimited.
    /// Defaults to 16 MiB.
    pub(crate) max_frame_size: Option<usize>,
}
201
202impl WebSocketOptions {
203    pub fn builder() -> WebSocketOptionsBuilder {
204        WebSocketOptionsBuilder(Self::default())
205    }
206}
207
208impl Default for WebSocketOptions {
209    fn default() -> Self {
210        Self {
211            max_message_size: Some(64 << 20),
212            max_frame_size: Some(16 << 20),
213        }
214    }
215}
216
217impl From<WebSocketOptions> for tungstenite::protocol::WebSocketConfig {
218    fn from(value: WebSocketOptions) -> Self {
219        (&value).into()
220    }
221}
222
223impl From<&WebSocketOptions> for tungstenite::protocol::WebSocketConfig {
224    fn from(value: &WebSocketOptions) -> Self {
225        let mut config = tungstenite::protocol::WebSocketConfig::default();
226        config.max_message_size = value.max_message_size;
227        config.max_frame_size = value.max_frame_size;
228        config
229    }
230}
231
232pub struct WebSocketOptionsBuilder(WebSocketOptions);
233
234impl WebSocketOptionsBuilder {
235    pub fn build(self) -> WebSocketOptions {
236        self.0
237    }
238
239    pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
240        self.0.max_message_size = max_message_size;
241        self
242    }
243
244    pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
245        self.0.max_frame_size = max_frame_size;
246        self
247    }
248}
249
250impl Default for ConnectionOptions {
251    fn default() -> ConnectionOptions {
252        ConnectionOptions {
253            host: String::from("localhost"),
254            port: 8182,
255            pool_size: 10,
256            pool_get_connection_timeout: Some(Duration::from_secs(30)),
257            credentials: None,
258            ssl: false,
259            tls_options: None,
260            serializer: GraphSON::V3,
261            deserializer: GraphSON::V3,
262            websocket_options: None,
263        }
264    }
265}
266
267impl ConnectionOptions {
268    pub fn builder() -> ConnectionOptionsBuilder {
269        ConnectionOptionsBuilder(ConnectionOptions::default())
270    }
271
272    pub fn websocket_url(&self) -> String {
273        let protocol = if self.ssl { "wss" } else { "ws" };
274        format!("{}://{}:{}/gremlin", protocol, self.host, self.port)
275    }
276}
277
278impl Connection {
279    pub fn connect<T>(options: T) -> GremlinResult<Connection>
280    where
281        T: Into<ConnectionOptions>,
282    {
283        Ok(Connection {
284            stream: ConnectionStream::connect(options.into())?,
285            broken: false,
286        })
287    }
288
289    pub fn send(&mut self, payload: Vec<u8>) -> GremlinResult<()> {
290        self.stream.send(payload).map_err(|e| {
291            if let GremlinError::WebSocket(_) = e {
292                self.broken = true;
293            }
294            e
295        })
296    }
297
298    pub fn recv(&mut self) -> GremlinResult<Vec<u8>> {
299        self.stream.recv().map_err(|e| {
300            if let GremlinError::WebSocket(_) = e {
301                self.broken = true
302            }
303            e
304        })
305    }
306
307    pub fn is_broken(&self) -> bool {
308        self.broken
309    }
310}
311
312impl TlsOptions {
313    pub(crate) fn tls_connector(&self) -> native_tls::Result<TlsConnector> {
314        TlsConnector::builder()
315            .danger_accept_invalid_certs(self.accept_invalid_certs)
316            .build()
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Requires a Gremlin server listening on localhost:8182.
    #[test]
    fn it_should_connect() {
        let result = Connection::connect(("localhost", 8182));
        result.unwrap();
    }

    /// The URL scheme follows the `ssl` flag; host, port and path are fixed.
    #[test]
    fn connection_option_build_url() {
        let cases = [
            (false, "ws://localhost:8182/gremlin"),
            (true, "wss://localhost:8182/gremlin"),
        ];

        for &(ssl, expected) in cases.iter() {
            let options = ConnectionOptions {
                host: "localhost".into(),
                port: 8182,
                ssl,
                ..Default::default()
            };
            assert_eq!(options.websocket_url(), expected);
        }
    }
}