1use std::collections::HashMap;
7
8use tonic::Request;
9use tonic::transport::{Channel, Endpoint};
10
11use crate::client::{Column, Page};
12use crate::dsn::Dsn;
13use crate::error::{Error, Result};
14use crate::proto;
15use crate::proto::execution_response::Payload;
16use crate::proto::geode_service_client::GeodeServiceClient;
17use crate::types::Value;
18
19pub struct GrpcClient {
24 client: GeodeServiceClient<Channel>,
25 session_id: String,
26}
27
28impl GrpcClient {
29 pub async fn connect(dsn: &Dsn) -> Result<Self> {
48 let addr = if dsn.tls_enabled() {
49 format!("https://{}", dsn.address())
50 } else {
51 format!("http://{}", dsn.address())
52 };
53
54 let endpoint = Endpoint::from_shared(addr.clone())
55 .map_err(|e| Error::connection(format!("Invalid endpoint: {}", e)))?;
56
57 let endpoint = if dsn.tls_enabled() && dsn.skip_verify() {
59 endpoint
61 .tls_config(tonic::transport::ClientTlsConfig::new().with_enabled_roots())
62 .map_err(|e| Error::tls(format!("TLS config error: {}", e)))?
63 } else {
64 endpoint
65 };
66
67 let channel = endpoint
68 .connect()
69 .await
70 .map_err(|e| Error::connection(format!("gRPC connection failed to {}: {}", addr, e)))?;
71
72 let grpc_client = GeodeServiceClient::new(channel);
73
74 let mut client = Self {
75 client: grpc_client,
76 session_id: String::new(),
77 };
78
79 client.handshake(dsn.username(), dsn.password()).await?;
81
82 Ok(client)
83 }
84
85 async fn handshake(&mut self, username: Option<&str>, password: Option<&str>) -> Result<()> {
87 let request = proto::HelloRequest {
88 username: username.unwrap_or("").to_string(),
89 password: password.unwrap_or("").to_string(),
90 tenant_id: None,
91 client_name: "geode-rust".to_string(),
92 client_version: crate::VERSION.to_string(),
93 wanted_conformance: "minimum".to_string(),
94 };
95
96 let response = self
97 .client
98 .handshake(Request::new(request))
99 .await
100 .map_err(|e| Error::connection(format!("Handshake failed: {}", e)))?;
101
102 let resp = response.into_inner();
103 if !resp.success {
104 return Err(Error::auth(resp.error_message));
105 }
106
107 self.session_id = resp.session_id;
108 Ok(())
109 }
110
111 pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
113 self.query_with_params(gql, &HashMap::new()).await
114 }
115
116 pub async fn query_with_params(
118 &mut self,
119 gql: &str,
120 params: &HashMap<String, Value>,
121 ) -> Result<(Page, Option<String>)> {
122 let proto_params: Vec<proto::Param> = params
123 .iter()
124 .map(|(k, v)| proto::Param {
125 name: k.clone(),
126 value: Some(v.to_proto_value()),
127 })
128 .collect();
129
130 let request = proto::ExecuteRequest {
131 session_id: self.session_id.clone(),
132 query: gql.to_string(),
133 params: proto_params,
134 };
135
136 let response = self
137 .client
138 .execute(Request::new(request))
139 .await
140 .map_err(|e| Error::query(format!("Query execution failed: {}", e)))?;
141
142 let mut stream = response.into_inner();
144 let mut columns = Vec::new();
145 let mut rows = Vec::new();
146 let mut final_page = true;
147 let mut ordered = false;
148 let mut order_keys = Vec::new();
149
150 while let Some(exec_resp) = stream
151 .message()
152 .await
153 .map_err(|e| Error::query(format!("Failed to read response: {}", e)))?
154 {
155 if let Some(payload) = exec_resp.payload {
156 match payload {
157 Payload::Schema(schema) => {
158 columns = schema
159 .columns
160 .into_iter()
161 .map(|c| Column {
162 name: c.name,
163 col_type: c.r#type,
164 })
165 .collect();
166 }
167 Payload::Page(page) => {
168 for row in page.rows {
169 let mut row_map = HashMap::new();
170 for (i, col) in columns.iter().enumerate() {
171 let value = if i < row.values.len() {
172 Self::convert_proto_value(&row.values[i])
173 } else {
174 Value::null()
175 };
176 row_map.insert(col.name.clone(), value);
177 }
178 rows.push(row_map);
179 }
180 final_page = page.r#final;
181 ordered = page.ordered;
182 order_keys = page.order_keys;
183 }
184 Payload::Error(err) => {
185 return Err(Error::Query {
186 code: err.code,
187 message: err.message,
188 });
189 }
190 Payload::Metrics(_) | Payload::Heartbeat(_) => {
191 }
193 Payload::Explain(_) | Payload::Profile(_) => {
194 }
196 }
197 }
198 }
199
200 Ok((
201 Page {
202 columns,
203 rows,
204 ordered,
205 order_keys,
206 final_page,
207 },
208 None,
209 ))
210 }
211
212 fn convert_proto_value(proto_val: &proto::Value) -> Value {
214 use crate::proto::value::Kind;
215 match &proto_val.kind {
216 Some(Kind::StringVal(s)) => Value::string(s.value.clone()),
217 Some(Kind::IntVal(i)) => Value::int(i.value),
218 Some(Kind::DoubleVal(d)) => {
219 Value::decimal(rust_decimal::Decimal::from_f64_retain(d.value).unwrap_or_default())
220 }
221 Some(Kind::BoolVal(b)) => Value::bool(*b),
222 Some(Kind::NullVal(_)) => Value::null(),
223 Some(Kind::ListVal(list)) => {
224 let values = list.values.iter().map(Self::convert_proto_value).collect();
225 Value::array(values)
226 }
227 Some(Kind::MapVal(map)) => {
228 let mut obj = HashMap::new();
229 for entry in &map.entries {
230 if let Some(ref val) = entry.value {
231 obj.insert(entry.key.clone(), Self::convert_proto_value(val));
232 }
233 }
234 Value::object(obj)
235 }
236 Some(Kind::DecimalVal(d)) => {
237 if let Ok(dec) = d.coeff.parse::<rust_decimal::Decimal>() {
238 Value::decimal(dec)
239 } else {
240 Value::string(d.orig_repr.clone())
241 }
242 }
243 Some(Kind::BytesVal(b)) => Value::string(format!("\\x{}", hex::encode(&b.value))),
244 _ => Value::null(),
245 }
246 }
247
248 pub async fn begin(&mut self) -> Result<()> {
250 let request = proto::BeginRequest {
251 read_only: false,
252 session_id: self.session_id.clone(),
253 };
254
255 self.client
256 .begin(Request::new(request))
257 .await
258 .map_err(|e| Error::connection(format!("Begin transaction failed: {}", e)))?;
259
260 Ok(())
261 }
262
263 pub async fn commit(&mut self) -> Result<()> {
265 let request = proto::CommitRequest {
266 session_id: self.session_id.clone(),
267 };
268
269 self.client
270 .commit(Request::new(request))
271 .await
272 .map_err(|e| Error::connection(format!("Commit failed: {}", e)))?;
273
274 Ok(())
275 }
276
277 pub async fn rollback(&mut self) -> Result<()> {
279 let request = proto::RollbackRequest {
280 session_id: self.session_id.clone(),
281 };
282
283 self.client
284 .rollback(Request::new(request))
285 .await
286 .map_err(|e| Error::connection(format!("Rollback failed: {}", e)))?;
287
288 Ok(())
289 }
290
291 pub async fn ping(&mut self) -> Result<bool> {
293 let response = self
294 .client
295 .ping(Request::new(proto::PingRequest {}))
296 .await
297 .map_err(|e| Error::connection(format!("Ping failed: {}", e)))?;
298
299 Ok(response.into_inner().ok)
300 }
301
302 pub fn close(&mut self) -> Result<()> {
304 Ok(())
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto;
    use crate::proto::value::Kind;

    /// Wraps `kind` in a `proto::Value` and runs it through the converter
    /// under test.
    fn convert(kind: Kind) -> Value {
        GrpcClient::convert_proto_value(&proto::Value { kind: Some(kind) })
    }

    #[test]
    fn test_convert_proto_value_string() {
        let converted = convert(Kind::StringVal(proto::StringValue {
            value: "hello".to_string(),
            kind: 0,
        }));
        assert_eq!(converted.as_string().unwrap(), "hello");
    }

    #[test]
    fn test_convert_proto_value_int() {
        let converted = convert(Kind::IntVal(proto::IntValue { value: 42, kind: 0 }));
        assert_eq!(converted.as_int().unwrap(), 42);
    }

    #[test]
    fn test_convert_proto_value_bool() {
        let converted = convert(Kind::BoolVal(true));
        assert!(converted.as_bool().unwrap());
    }

    #[test]
    fn test_convert_proto_value_null() {
        let converted = convert(Kind::NullVal(proto::NullValue {}));
        assert!(converted.is_null());
    }

    #[test]
    fn test_convert_proto_value_none() {
        // An unset `kind` cannot go through `convert`, so build it directly.
        let unset = proto::Value { kind: None };
        assert!(GrpcClient::convert_proto_value(&unset).is_null());
    }

    #[test]
    fn test_convert_proto_value_double() {
        let converted = convert(Kind::DoubleVal(proto::DoubleValue {
            value: 3.15,
            kind: 0,
        }));
        assert!(converted.as_decimal().is_ok());
    }

    #[test]
    fn test_convert_proto_value_list() {
        let int_item = |n| proto::Value {
            kind: Some(Kind::IntVal(proto::IntValue { value: n, kind: 0 })),
        };
        let converted = convert(Kind::ListVal(proto::ListValue {
            values: vec![int_item(1), int_item(2)],
        }));

        let items = converted.as_array().unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_int().unwrap(), 1);
        assert_eq!(items[1].as_int().unwrap(), 2);
    }

    #[test]
    fn test_convert_proto_value_map() {
        let name_value = proto::Value {
            kind: Some(Kind::StringVal(proto::StringValue {
                value: "Alice".to_string(),
                kind: 0,
            })),
        };
        let converted = convert(Kind::MapVal(proto::MapValue {
            entries: vec![proto::MapEntry {
                key: "name".to_string(),
                value: Some(name_value),
            }],
        }));

        let fields = converted.as_object().unwrap();
        assert_eq!(fields.get("name").unwrap().as_string().unwrap(), "Alice");
    }
}