Skip to main content

geode_client/
grpc.rs

//! gRPC transport implementation for Geode.
//!
//! This module provides gRPC client functionality using the tonic-generated
//! `GeodeServiceClient` from `crate::proto::geode_service_client`.
5
6use std::collections::HashMap;
7
8use tonic::Request;
9use tonic::transport::{Channel, Endpoint};
10
11use crate::client::{Column, Page};
12use crate::dsn::Dsn;
13use crate::error::{Error, Result};
14use crate::proto;
15use crate::proto::execution_response::Payload;
16use crate::proto::geode_service_client::GeodeServiceClient;
17use crate::types::Value;
18
19/// gRPC client for Geode.
20///
21/// Provides gRPC-based connection to the Geode database server using the
22/// tonic-generated service client.
23pub struct GrpcClient {
24    client: GeodeServiceClient<Channel>,
25    session_id: String,
26}
27
28impl GrpcClient {
29    /// Connect to a Geode server using gRPC.
30    ///
31    /// # Arguments
32    ///
33    /// * `dsn` - Parsed DSN with gRPC transport
34    ///
35    /// # Example
36    ///
37    /// ```no_run
38    /// use geode_client::dsn::Dsn;
39    /// use geode_client::grpc::GrpcClient;
40    ///
41    /// # async fn example() -> geode_client::Result<()> {
42    /// let dsn = Dsn::parse("grpc://localhost:50051")?;
43    /// let client = GrpcClient::connect(&dsn).await?;
44    /// # Ok(())
45    /// # }
46    /// ```
47    pub async fn connect(dsn: &Dsn) -> Result<Self> {
48        let addr = if dsn.tls_enabled() {
49            format!("https://{}", dsn.address())
50        } else {
51            format!("http://{}", dsn.address())
52        };
53
54        let endpoint = Endpoint::from_shared(addr.clone())
55            .map_err(|e| Error::connection(format!("Invalid endpoint: {}", e)))?;
56
57        // Configure TLS if needed
58        let endpoint = if dsn.tls_enabled() && dsn.skip_verify() {
59            // Skip TLS verification (insecure - for development only)
60            endpoint
61                .tls_config(tonic::transport::ClientTlsConfig::new().with_enabled_roots())
62                .map_err(|e| Error::tls(format!("TLS config error: {}", e)))?
63        } else {
64            endpoint
65        };
66
67        let channel = endpoint
68            .connect()
69            .await
70            .map_err(|e| Error::connection(format!("gRPC connection failed to {}: {}", addr, e)))?;
71
72        let grpc_client = GeodeServiceClient::new(channel);
73
74        let mut client = Self {
75            client: grpc_client,
76            session_id: String::new(),
77        };
78
79        // Perform handshake
80        client.handshake(dsn.username(), dsn.password()).await?;
81
82        Ok(client)
83    }
84
85    /// Perform authentication handshake.
86    async fn handshake(&mut self, username: Option<&str>, password: Option<&str>) -> Result<()> {
87        let request = proto::HelloRequest {
88            username: username.unwrap_or("").to_string(),
89            password: password.unwrap_or("").to_string(),
90            tenant_id: None,
91            client_name: "geode-rust".to_string(),
92            client_version: crate::VERSION.to_string(),
93            wanted_conformance: "minimum".to_string(),
94        };
95
96        let response = self
97            .client
98            .handshake(Request::new(request))
99            .await
100            .map_err(|e| Error::connection(format!("Handshake failed: {}", e)))?;
101
102        let resp = response.into_inner();
103        if !resp.success {
104            return Err(Error::auth(resp.error_message));
105        }
106
107        self.session_id = resp.session_id;
108        Ok(())
109    }
110
111    /// Execute a GQL query.
112    pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
113        self.query_with_params(gql, &HashMap::new()).await
114    }
115
116    /// Execute a GQL query with parameters.
117    pub async fn query_with_params(
118        &mut self,
119        gql: &str,
120        params: &HashMap<String, Value>,
121    ) -> Result<(Page, Option<String>)> {
122        let proto_params: Vec<proto::Param> = params
123            .iter()
124            .map(|(k, v)| proto::Param {
125                name: k.clone(),
126                value: Some(v.to_proto_value()),
127            })
128            .collect();
129
130        let request = proto::ExecuteRequest {
131            session_id: self.session_id.clone(),
132            query: gql.to_string(),
133            params: proto_params,
134        };
135
136        let response = self
137            .client
138            .execute(Request::new(request))
139            .await
140            .map_err(|e| Error::query(format!("Query execution failed: {}", e)))?;
141
142        // Process streaming response
143        let mut stream = response.into_inner();
144        let mut columns = Vec::new();
145        let mut rows = Vec::new();
146        let mut final_page = true;
147        let mut ordered = false;
148        let mut order_keys = Vec::new();
149
150        while let Some(exec_resp) = stream
151            .message()
152            .await
153            .map_err(|e| Error::query(format!("Failed to read response: {}", e)))?
154        {
155            if let Some(payload) = exec_resp.payload {
156                match payload {
157                    Payload::Schema(schema) => {
158                        columns = schema
159                            .columns
160                            .into_iter()
161                            .map(|c| Column {
162                                name: c.name,
163                                col_type: c.r#type,
164                            })
165                            .collect();
166                    }
167                    Payload::Page(page) => {
168                        for row in page.rows {
169                            let mut row_map = HashMap::new();
170                            for (i, col) in columns.iter().enumerate() {
171                                let value = if i < row.values.len() {
172                                    Self::convert_proto_value(&row.values[i])
173                                } else {
174                                    Value::null()
175                                };
176                                row_map.insert(col.name.clone(), value);
177                            }
178                            rows.push(row_map);
179                        }
180                        final_page = page.r#final;
181                        ordered = page.ordered;
182                        order_keys = page.order_keys;
183                    }
184                    Payload::Error(err) => {
185                        return Err(Error::Query {
186                            code: err.code,
187                            message: err.message,
188                        });
189                    }
190                    Payload::Metrics(_) | Payload::Heartbeat(_) => {
191                        // Informational payloads, continue
192                    }
193                    Payload::Explain(_) | Payload::Profile(_) => {
194                        // Plan/profile payloads, continue
195                    }
196                }
197            }
198        }
199
200        Ok((
201            Page {
202                columns,
203                rows,
204                ordered,
205                order_keys,
206                final_page,
207            },
208            None,
209        ))
210    }
211
212    /// Convert a proto Value to our Value type.
213    fn convert_proto_value(proto_val: &proto::Value) -> Value {
214        use crate::proto::value::Kind;
215        match &proto_val.kind {
216            Some(Kind::StringVal(s)) => Value::string(s.value.clone()),
217            Some(Kind::IntVal(i)) => Value::int(i.value),
218            Some(Kind::DoubleVal(d)) => {
219                Value::decimal(rust_decimal::Decimal::from_f64_retain(d.value).unwrap_or_default())
220            }
221            Some(Kind::BoolVal(b)) => Value::bool(*b),
222            Some(Kind::NullVal(_)) => Value::null(),
223            Some(Kind::ListVal(list)) => {
224                let values = list.values.iter().map(Self::convert_proto_value).collect();
225                Value::array(values)
226            }
227            Some(Kind::MapVal(map)) => {
228                let mut obj = HashMap::new();
229                for entry in &map.entries {
230                    if let Some(ref val) = entry.value {
231                        obj.insert(entry.key.clone(), Self::convert_proto_value(val));
232                    }
233                }
234                Value::object(obj)
235            }
236            Some(Kind::DecimalVal(d)) => {
237                if let Ok(dec) = d.coeff.parse::<rust_decimal::Decimal>() {
238                    Value::decimal(dec)
239                } else {
240                    Value::string(d.orig_repr.clone())
241                }
242            }
243            Some(Kind::BytesVal(b)) => Value::string(format!("\\x{}", hex::encode(&b.value))),
244            _ => Value::null(),
245        }
246    }
247
248    /// Begin a transaction.
249    pub async fn begin(&mut self) -> Result<()> {
250        let request = proto::BeginRequest {
251            read_only: false,
252            session_id: self.session_id.clone(),
253        };
254
255        self.client
256            .begin(Request::new(request))
257            .await
258            .map_err(|e| Error::connection(format!("Begin transaction failed: {}", e)))?;
259
260        Ok(())
261    }
262
263    /// Commit a transaction.
264    pub async fn commit(&mut self) -> Result<()> {
265        let request = proto::CommitRequest {
266            session_id: self.session_id.clone(),
267        };
268
269        self.client
270            .commit(Request::new(request))
271            .await
272            .map_err(|e| Error::connection(format!("Commit failed: {}", e)))?;
273
274        Ok(())
275    }
276
277    /// Rollback a transaction.
278    pub async fn rollback(&mut self) -> Result<()> {
279        let request = proto::RollbackRequest {
280            session_id: self.session_id.clone(),
281        };
282
283        self.client
284            .rollback(Request::new(request))
285            .await
286            .map_err(|e| Error::connection(format!("Rollback failed: {}", e)))?;
287
288        Ok(())
289    }
290
291    /// Send a ping request.
292    pub async fn ping(&mut self) -> Result<bool> {
293        let response = self
294            .client
295            .ping(Request::new(proto::PingRequest {}))
296            .await
297            .map_err(|e| Error::connection(format!("Ping failed: {}", e)))?;
298
299        Ok(response.into_inner().ok)
300    }
301
302    /// Close the connection.
303    pub fn close(&mut self) -> Result<()> {
304        // gRPC channels are automatically closed when dropped
305        Ok(())
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto;
    use crate::proto::value::Kind;

    /// Wraps `kind` in a `proto::Value` and runs it through the converter under test.
    fn convert(kind: Kind) -> crate::types::Value {
        GrpcClient::convert_proto_value(&proto::Value { kind: Some(kind) })
    }

    #[test]
    fn test_convert_proto_value_string() {
        let converted = convert(Kind::StringVal(proto::StringValue {
            value: "hello".to_string(),
            kind: 0,
        }));
        assert_eq!(converted.as_string().unwrap(), "hello");
    }

    #[test]
    fn test_convert_proto_value_int() {
        let converted = convert(Kind::IntVal(proto::IntValue { value: 42, kind: 0 }));
        assert_eq!(converted.as_int().unwrap(), 42);
    }

    #[test]
    fn test_convert_proto_value_bool() {
        let converted = convert(Kind::BoolVal(true));
        assert!(converted.as_bool().unwrap());
    }

    #[test]
    fn test_convert_proto_value_null() {
        let converted = convert(Kind::NullVal(proto::NullValue {}));
        assert!(converted.is_null());
    }

    #[test]
    fn test_convert_proto_value_none() {
        // A value with no kind at all must also come out as null.
        let empty = proto::Value { kind: None };
        assert!(GrpcClient::convert_proto_value(&empty).is_null());
    }

    #[test]
    fn test_convert_proto_value_double() {
        let converted = convert(Kind::DoubleVal(proto::DoubleValue {
            value: 3.15,
            kind: 0,
        }));
        assert!(converted.as_decimal().is_ok());
    }

    #[test]
    fn test_convert_proto_value_list() {
        let first = proto::Value {
            kind: Some(Kind::IntVal(proto::IntValue { value: 1, kind: 0 })),
        };
        let second = proto::Value {
            kind: Some(Kind::IntVal(proto::IntValue { value: 2, kind: 0 })),
        };

        let converted = convert(Kind::ListVal(proto::ListValue {
            values: vec![first, second],
        }));

        let items = converted.as_array().unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_int().unwrap(), 1);
        assert_eq!(items[1].as_int().unwrap(), 2);
    }

    #[test]
    fn test_convert_proto_value_map() {
        let name_entry = proto::MapEntry {
            key: "name".to_string(),
            value: Some(proto::Value {
                kind: Some(Kind::StringVal(proto::StringValue {
                    value: "Alice".to_string(),
                    kind: 0,
                })),
            }),
        };

        let converted = convert(Kind::MapVal(proto::MapValue {
            entries: vec![name_entry],
        }));

        let fields = converted.as_object().unwrap();
        assert_eq!(fields.get("name").unwrap().as_string().unwrap(), "Alice");
    }
}