gremlin_client/
client.rs

1use crate::io::GraphSON;
2use crate::message::{
3    message_with_args, message_with_args_and_uuid, message_with_args_v2, Message, Response,
4};
5use crate::pool::GremlinConnectionManager;
6use crate::process::traversal::Bytecode;
7use crate::ToGValue;
8use crate::{ConnectionOptions, GremlinError, GremlinResult};
9use crate::{GResultSet, GValue};
10use base64::encode;
11use r2d2::Pool;
12use serde::Serialize;
13use std::collections::{HashMap, VecDeque};
14
15type SessionedClient = GremlinClient;
16
17impl SessionedClient {
18    pub fn close_session(&mut self) -> GremlinResult<GResultSet> {
19        if let Some(session_name) = self.session.take() {
20            let mut args = HashMap::new();
21            args.insert(String::from("session"), GValue::from(session_name.clone()));
22            let args = self.options.serializer.write(&GValue::from(args))?;
23
24            let processor = "session".to_string();
25
26            let message = match self.options.serializer {
27                GraphSON::V2 => message_with_args_v2(String::from("close"), processor, args),
28                GraphSON::V3 => message_with_args(String::from("close"), processor, args),
29            };
30
31            let conn = self.pool.get()?;
32
33            self.send_message(conn, message)
34        } else {
35            Err(GremlinError::Generic("No session to close".to_string()))
36        }
37    }
38}
39
/// Client that sends requests to a Gremlin Server over pooled connections.
///
/// A client may be bound to a server-side session (`create_session` / `close_session`) and to
/// a traversal-source alias (`alias`).
///
/// NOTE(review): `alias` and `send_message` clone this struct; presumably the r2d2 `Pool`
/// handle is shared (not duplicated) on clone — confirm against r2d2's `Pool: Clone` docs.
#[derive(Clone, Debug)]
pub struct GremlinClient {
    /// Pool of connections that requests are sent over.
    pool: Pool<GremlinConnectionManager>,
    /// Name of the server-side session. When set, `execute` routes requests through the
    /// `session` processor and includes this name in the request arguments.
    session: Option<String>,
    /// When set, bound to `g` in the `aliases` request argument.
    alias: Option<String>,
    /// Serializer, deserializer, credentials and pool settings used by this client.
    options: ConnectionOptions,
}
47
48impl GremlinClient {
49    pub fn connect<T>(options: T) -> GremlinResult<GremlinClient>
50    where
51        T: Into<ConnectionOptions>,
52    {
53        let opts = options.into();
54        let pool_size = opts.pool_size;
55        let manager = GremlinConnectionManager::new(opts.clone());
56
57        let mut pool_builder = Pool::builder().max_size(pool_size);
58
59        if let Some(get_connection_timeout) = opts.pool_get_connection_timeout {
60            pool_builder = pool_builder.connection_timeout(get_connection_timeout);
61        }
62
63        Ok(GremlinClient {
64            pool: pool_builder.build(manager)?,
65            session: None,
66            alias: None,
67            options: opts,
68        })
69    }
70
71    pub fn create_session(&mut self, name: String) -> GremlinResult<SessionedClient> {
72        let manager = GremlinConnectionManager::new(self.options.clone());
73        Ok(SessionedClient {
74            pool: Pool::builder().max_size(1).build(manager)?,
75            session: Some(name),
76            alias: None,
77            options: self.options.clone(),
78        })
79    }
80
81    /// Return a cloned client with the provided alias
82    pub fn alias<T>(&self, alias: T) -> GremlinClient
83    where
84        T: Into<String>,
85    {
86        let mut cloned = self.clone();
87        cloned.alias = Some(alias.into());
88        cloned
89    }
90
91    pub fn execute<T>(
92        &self,
93        script: T,
94        params: &[(&str, &dyn ToGValue)],
95    ) -> GremlinResult<GResultSet>
96    where
97        T: Into<String>,
98    {
99        let mut args = HashMap::new();
100
101        args.insert(String::from("gremlin"), GValue::String(script.into()));
102        args.insert(
103            String::from("language"),
104            GValue::String(String::from("gremlin-groovy")),
105        );
106
107        let aliases = self
108            .alias
109            .clone()
110            .map(|s| {
111                let mut map = HashMap::new();
112                map.insert(String::from("g"), GValue::String(s));
113                map
114            })
115            .unwrap_or_else(HashMap::new);
116
117        args.insert(String::from("aliases"), GValue::from(aliases));
118
119        let bindings: HashMap<String, GValue> = params
120            .iter()
121            .map(|(k, v)| (String::from(*k), v.to_gvalue()))
122            .collect();
123
124        args.insert(String::from("bindings"), GValue::from(bindings));
125
126        if let Some(session_name) = &self.session {
127            args.insert(String::from("session"), GValue::from(session_name.clone()));
128        }
129
130        let args = self.options.serializer.write(&GValue::from(args))?;
131
132        let processor = if self.session.is_some() {
133            "session".to_string()
134        } else {
135            String::default()
136        };
137
138        let message = match self.options.serializer {
139            GraphSON::V2 => message_with_args_v2(String::from("eval"), processor, args),
140            GraphSON::V3 => message_with_args(String::from("eval"), processor, args),
141        };
142
143        let conn = self.pool.get()?;
144
145        self.send_message(conn, message)
146    }
147
148    pub(crate) fn write_message<T: Serialize>(
149        &self,
150        conn: &mut r2d2::PooledConnection<GremlinConnectionManager>,
151        msg: Message<T>,
152    ) -> GremlinResult<()> {
153        let message = self.build_message(msg)?;
154
155        let content_type = self.options.serializer.content_type();
156        let payload = String::from("") + content_type + &message;
157
158        let mut binary = payload.into_bytes();
159        binary.insert(0, content_type.len() as u8);
160
161        conn.send(binary)?;
162
163        Ok(())
164    }
165
166    pub(crate) fn send_message<T: Serialize>(
167        &self,
168        mut conn: r2d2::PooledConnection<GremlinConnectionManager>,
169        msg: Message<T>,
170    ) -> GremlinResult<GResultSet> {
171        self.write_message(&mut conn, msg)?;
172
173        let (response, results) = self.read_response(&mut conn)?;
174
175        Ok(GResultSet::new(self.clone(), results, response, conn))
176    }
177
178    pub fn generate_message(
179        &self,
180        bytecode: &Bytecode,
181    ) -> GremlinResult<Message<serde_json::Value>> {
182        let mut args = HashMap::new();
183
184        args.insert(String::from("gremlin"), GValue::Bytecode(bytecode.clone()));
185
186        let aliases = self
187            .alias
188            .clone()
189            .or_else(|| Some(String::from("g")))
190            .map(|s| {
191                let mut map = HashMap::new();
192                map.insert(String::from("g"), GValue::String(s));
193                map
194            })
195            .unwrap_or_else(HashMap::new);
196
197        args.insert(String::from("aliases"), GValue::from(aliases));
198
199        let args = self.options.serializer.write(&GValue::from(args))?;
200
201        Ok(message_with_args(
202            String::from("bytecode"),
203            String::from("traversal"),
204            args,
205        ))
206    }
207
208    pub(crate) fn submit_traversal(&self, bytecode: &Bytecode) -> GremlinResult<GResultSet> {
209        let message = self.generate_message(bytecode)?;
210
211        let conn = self.pool.get()?;
212
213        self.send_message(conn, message)
214    }
215
216    pub(crate) fn read_response(
217        &self,
218        conn: &mut r2d2::PooledConnection<GremlinConnectionManager>,
219    ) -> GremlinResult<(Response, VecDeque<GValue>)> {
220        let result = conn.recv()?;
221        let response: Response = serde_json::from_slice(&result)?;
222
223        match response.status.code {
224            200 | 206 => {
225                let results: VecDeque<GValue> = self
226                    .options
227                    .deserializer
228                    .read(&response.result.data)?
229                    .map(|v| v.into())
230                    .unwrap_or_else(VecDeque::new);
231
232                Ok((response, results))
233            }
234            204 => Ok((response, VecDeque::new())),
235            407 => match &self.options.credentials {
236                Some(c) => {
237                    let mut args = HashMap::new();
238
239                    args.insert(
240                        String::from("sasl"),
241                        GValue::String(encode(&format!("\0{}\0{}", c.username, c.password))),
242                    );
243
244                    let args = self.options.serializer.write(&GValue::from(args))?;
245                    let message = message_with_args_and_uuid(
246                        String::from("authentication"),
247                        String::from("traversal"),
248                        response.request_id,
249                        args,
250                    );
251
252                    self.write_message(conn, message)?;
253
254                    self.read_response(conn)
255                }
256                None => Err(GremlinError::Request((
257                    response.status.code,
258                    response.status.message,
259                ))),
260            },
261            _ => Err(GremlinError::Request((
262                response.status.code,
263                response.status.message,
264            ))),
265        }
266    }
267    fn build_message<T: Serialize>(&self, msg: Message<T>) -> GremlinResult<String> {
268        serde_json::to_string(&msg).map_err(GremlinError::from)
269    }
270}