1use quinn::{ClientConfig, Endpoint};
5use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::net::{SocketAddr, ToSocketAddrs};
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, BufReader};
11use tokio::time::{timeout, Duration};
12
13use crate::error::{Error, Result};
14use crate::types::Value;
15
/// ALPN protocol identifier offered in the TLS handshake of every connection.
const GEODE_ALPN: &[u8] = b"geode/1";
17
/// A result column as listed in the `columns` array of a `SCHEMA` frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
    /// Column name; row objects in `BINDINGS` frames are keyed by this name.
    pub name: String,
    /// Declared column type, (de)serialised under the JSON key `"type"`.
    /// It is handed to `crate::types::decode_value` when cells are decoded.
    #[serde(rename = "type")]
    pub col_type: String,
}
29
/// Rows returned for a query.
///
/// `Connection::query_with_params` merges every `BINDINGS` page it receives
/// into one `Page`.
#[derive(Debug, Clone)]
pub struct Page {
    /// Columns taken from the query's `SCHEMA` frame.
    pub columns: Vec<Column>,
    /// Decoded rows, each mapping a column name to its value.
    pub rows: Vec<HashMap<String, Value>>,
    /// The `ordered` flag reported by the server (`false` when absent).
    pub ordered: bool,
    /// The `order_keys` list reported by the server.
    pub order_keys: Vec<String>,
    /// Whether this page was reported/treated as the last one.
    pub final_page: bool,
}
63
/// A named savepoint.
///
/// NOTE(review): nothing in this part of the file creates or consumes a
/// `Savepoint`; how it maps onto the wire protocol is not visible here.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Savepoint name.
    pub name: String,
}
88
89#[derive(Debug, Clone)]
108pub struct PreparedStatement {
109 query: String,
111 param_names: Vec<String>,
113}
114
115impl PreparedStatement {
116 pub fn new(query: impl Into<String>) -> Self {
120 let query = query.into();
121 let param_names = Self::extract_param_names(&query);
122 Self { query, param_names }
123 }
124
125 fn extract_param_names(query: &str) -> Vec<String> {
127 let mut names = Vec::new();
128 let mut chars = query.chars().peekable();
129
130 while let Some(c) = chars.next() {
131 if c == '$' {
132 let mut name = String::new();
133 while let Some(&next) = chars.peek() {
134 if next.is_ascii_alphanumeric() || next == '_' {
135 name.push(chars.next().unwrap());
136 } else {
137 break;
138 }
139 }
140 if !name.is_empty() && !names.contains(&name) {
141 names.push(name);
142 }
143 }
144 }
145
146 names
147 }
148
149 pub fn query(&self) -> &str {
151 &self.query
152 }
153
154 pub fn param_names(&self) -> &[String] {
156 &self.param_names
157 }
158
159 pub async fn execute(
174 &self,
175 conn: &mut Connection,
176 params: &HashMap<String, crate::types::Value>,
177 ) -> crate::error::Result<(Page, Option<String>)> {
178 for name in &self.param_names {
180 if !params.contains_key(name) {
181 return Err(crate::error::Error::validation(format!(
182 "Missing required parameter: {}",
183 name
184 )));
185 }
186 }
187
188 conn.query_with_params(&self.query, params).await
189 }
190}
191
/// One node of a query plan tree, as parsed from an explain/profile response.
#[derive(Debug, Clone)]
pub struct PlanOperation {
    /// Operator name from the `type` (or `op_type`) key; `"Unknown"` when absent.
    pub op_type: String,
    /// Description from the `description` (or `desc`) key; empty when absent.
    pub description: String,
    /// Estimated row count from `estimated_rows` (or `rows`), if present.
    pub estimated_rows: Option<u64>,
    /// Child operations, parsed recursively from `children`.
    pub children: Vec<PlanOperation>,
}
204
/// Execution plan returned by `Connection::explain`.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Top-level plan operations.
    pub operations: Vec<PlanOperation>,
    /// Total estimated row count from `estimated_rows` (0 when absent).
    pub estimated_rows: u64,
    /// The complete `result` object of the server response.
    pub raw: serde_json::Value,
}
218
/// Execution profile returned by `Connection::profile`.
#[derive(Debug, Clone)]
pub struct QueryProfile {
    /// Plan parsed from the same response.
    pub plan: QueryPlan,
    /// Actual row count from `actual_rows` (0 when absent).
    pub actual_rows: u64,
    /// Execution time in milliseconds from `execution_time_ms` (0.0 when absent).
    pub execution_time_ms: f64,
    /// The complete `result` object of the server response.
    pub raw: serde_json::Value,
}
233
/// Connection settings for a Geode server, used to open [`Connection`]s.
///
/// Build one with `Client::new` or `Client::from_dsn` and adjust it with the
/// builder-style setters.
///
/// NOTE(review): there is no `Debug` derive — presumably so `password` cannot
/// leak into logs; confirm before adding one.
#[derive(Clone)]
pub struct Client {
    /// Server host name or IP literal.
    host: String,
    /// Server UDP port (DSN default: 3141).
    port: u16,
    /// When true, any server certificate is accepted (insecure).
    skip_verify: bool,
    /// Rows requested per `PULL` frame.
    page_size: usize,
    /// `client_name` sent in the HELLO handshake.
    hello_name: String,
    /// `client_ver` sent in the HELLO handshake.
    hello_ver: String,
    /// `wanted_conformance` sent in the HELLO handshake.
    conformance: String,
    /// Optional `username` sent in the HELLO handshake.
    username: Option<String>,
    /// Optional `password` sent in the HELLO handshake.
    password: Option<String>,
}
268
269impl Client {
270 pub fn new(host: impl Into<String>, port: u16) -> Self {
287 Self {
288 host: host.into(),
289 port,
290 skip_verify: false,
291 page_size: 1000,
292 hello_name: "geode-rust-quinn".to_string(),
293 hello_ver: "0.1.0".to_string(),
294 conformance: "min".to_string(),
295 username: None,
296 password: None,
297 }
298 }
299
300 pub fn from_dsn(dsn: &str) -> Result<Self> {
330 let dsn = dsn.trim();
331 if dsn.is_empty() {
332 return Err(Error::invalid_dsn("DSN cannot be empty"));
333 }
334
335 if dsn.starts_with("geode://") {
337 return Self::parse_url_dsn(dsn);
338 }
339
340 Self::parse_simple_dsn(dsn)
342 }
343
344 fn parse_url_dsn(dsn: &str) -> Result<Self> {
345 use std::collections::HashMap;
346
347 let url = url::Url::parse(dsn)
349 .map_err(|e| Error::invalid_dsn(format!("Invalid URL format: {}", e)))?;
350
351 let host = url
352 .host_str()
353 .ok_or_else(|| Error::invalid_dsn("Host is required"))?
354 .to_string();
355
356 let port = url.port().unwrap_or(3141);
357
358 let username = if !url.username().is_empty() {
360 Some(
361 urlencoding::decode(url.username())
362 .map_err(|e| Error::invalid_dsn(format!("Invalid username encoding: {}", e)))?
363 .into_owned(),
364 )
365 } else {
366 None
367 };
368
369 let password = url.password().map(|p| {
370 urlencoding::decode(p)
371 .map(|s| s.into_owned())
372 .unwrap_or_else(|_| p.to_string())
373 });
374
375 let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
377
378 let mut client = Self::new(host, port);
379 client.username = username;
380 client.password = password;
381
382 Self::apply_params(&mut client, ¶ms)?;
383
384 Ok(client)
385 }
386
387 fn parse_simple_dsn(dsn: &str) -> Result<Self> {
388 use std::collections::HashMap;
389
390 let (host_port, query_str) = if let Some(idx) = dsn.find('?') {
392 (&dsn[..idx], Some(&dsn[idx + 1..]))
393 } else {
394 (dsn, None)
395 };
396
397 let (host, port) = if let Some(idx) = host_port.rfind(':') {
399 let host = &host_port[..idx];
400 let port_str = &host_port[idx + 1..];
401 let port = port_str
402 .parse::<u16>()
403 .map_err(|_| Error::invalid_dsn(format!("Invalid port: {}", port_str)))?;
404 (host.to_string(), port)
405 } else {
406 (host_port.to_string(), 3141)
407 };
408
409 if host.is_empty() {
410 return Err(Error::invalid_dsn("Host is required"));
411 }
412
413 let mut client = Self::new(host, port);
414
415 if let Some(qs) = query_str {
417 let params: HashMap<String, String> = qs
418 .split('&')
419 .filter_map(|pair| {
420 let mut parts = pair.splitn(2, '=');
421 let key = parts.next()?;
422 let value = parts.next().unwrap_or("");
423 Some((key.to_string(), value.to_string()))
424 })
425 .collect();
426
427 Self::apply_params(&mut client, ¶ms)?;
428 }
429
430 Ok(client)
431 }
432
433 fn apply_params(
434 client: &mut Self,
435 params: &std::collections::HashMap<String, String>,
436 ) -> Result<()> {
437 for (key, value) in params {
438 match key.as_str() {
439 "page_size" => {
440 client.page_size = value
441 .parse()
442 .map_err(|_| Error::invalid_dsn(format!("Invalid page_size: {}", value)))?;
443 }
444 "hello_name" => {
445 client.hello_name = value.clone();
446 }
447 "hello_ver" => {
448 client.hello_ver = value.clone();
449 }
450 "conformance" => {
451 client.conformance = value.clone();
452 }
453 "insecure" => {
454 client.skip_verify = value == "true" || value == "1";
455 }
456 "username" | "user" => {
457 client.username = Some(value.clone());
458 }
459 "password" | "pass" => {
460 client.password = Some(value.clone());
461 }
462 _ => {
463 }
465 }
466 }
467 Ok(())
468 }
469
470 pub fn skip_verify(mut self, skip: bool) -> Self {
482 self.skip_verify = skip;
483 self
484 }
485
486 pub fn page_size(mut self, size: usize) -> Self {
495 self.page_size = size;
496 self
497 }
498
499 pub fn client_name(mut self, name: impl Into<String>) -> Self {
507 self.hello_name = name.into();
508 self
509 }
510
511 pub fn client_version(mut self, version: impl Into<String>) -> Self {
517 self.hello_ver = version.into();
518 self
519 }
520
521 pub fn conformance(mut self, level: impl Into<String>) -> Self {
527 self.conformance = level.into();
528 self
529 }
530
531 pub fn username(mut self, username: impl Into<String>) -> Self {
547 self.username = Some(username.into());
548 self
549 }
550
551 pub fn password(mut self, password: impl Into<String>) -> Self {
557 self.password = Some(password.into());
558 self
559 }
560
561 pub async fn connect(&self) -> Result<Connection> {
592 Connection::new(
593 &self.host,
594 self.port,
595 self.skip_verify,
596 self.page_size,
597 &self.hello_name,
598 &self.hello_ver,
599 &self.conformance,
600 self.username.as_deref(),
601 self.password.as_deref(),
602 )
603 .await
604 }
605}
606
/// An open QUIC connection to a Geode server.
///
/// A single bidirectional stream carries newline-delimited JSON frames in both
/// directions. Requests are issued sequentially through `&mut self`.
pub struct Connection {
    /// Underlying QUIC connection (closed by `close`).
    conn: quinn::Connection,
    /// Rows requested per `PULL` frame.
    page_size: usize,
    /// Outgoing half of the request/response stream.
    send: quinn::SendStream,
    /// Buffered incoming half, read one JSON line at a time.
    reader: BufReader<quinn::RecvStream>,
    /// `request_id` for the next `PULL`; starts at 1 and skips 0 on wrap-around.
    next_request_id: u64,
}
653
654impl Connection {
655 #[allow(clippy::too_many_arguments)]
656 async fn new(
657 host: &str,
658 port: u16,
659 skip_verify: bool,
660 page_size: usize,
661 hello_name: &str,
662 hello_ver: &str,
663 conformance: &str,
664 username: Option<&str>,
665 password: Option<&str>,
666 ) -> Result<Self> {
667 let mut last_err: Option<Error> = None;
668
669 for attempt in 1..=3 {
670 match Self::connect_once(
671 host,
672 port,
673 skip_verify,
674 page_size,
675 hello_name,
676 hello_ver,
677 conformance,
678 username,
679 password,
680 )
681 .await
682 {
683 Ok(conn) => return Ok(conn),
684 Err(e) => {
685 last_err = Some(e);
686 if attempt < 3 {
687 eprintln!("[QUINN] Connection attempt {} failed, retrying...", attempt);
688 tokio::time::sleep(Duration::from_millis(150)).await;
689 }
690 }
691 }
692 }
693
694 Err(last_err.unwrap_or_else(|| Error::connection("Failed to connect")))
695 }
696
697 #[allow(clippy::too_many_arguments)]
698 async fn connect_once(
699 host: &str,
700 port: u16,
701 skip_verify: bool,
702 page_size: usize,
703 hello_name: &str,
704 hello_ver: &str,
705 conformance: &str,
706 username: Option<&str>,
707 password: Option<&str>,
708 ) -> Result<Self> {
709 eprintln!("[QUINN] Creating connection to {}:{}", host, port);
710
711 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
713
714 let mut client_crypto = if skip_verify {
716 eprintln!("[QUINN] Skipping certificate verification (insecure)");
717 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
718 .dangerous()
719 .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
720 .with_no_client_auth()
721 } else {
722 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
723 .with_root_certificates(rustls::RootCertStore::empty())
724 .with_no_client_auth()
725 };
726
727 client_crypto.alpn_protocols = vec![GEODE_ALPN.to_vec()];
729
730 let mut client_config = ClientConfig::new(Arc::new(
731 quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)
732 .map_err(|e| Error::connection(format!("Failed to create QUIC config: {}", e)))?,
733 ));
734
735 let mut transport = quinn::TransportConfig::default();
737 transport.max_idle_timeout(Some(Duration::from_secs(30).try_into().unwrap()));
738 transport.keep_alive_interval(Some(Duration::from_secs(5)));
739 client_config.transport_config(Arc::new(transport));
740
741 let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())
743 .map_err(|e| Error::connection(format!("Failed to create endpoint: {}", e)))?;
744 endpoint.set_default_client_config(client_config);
745
746 let mut resolved_addrs = format!("{}:{}", host, port)
748 .to_socket_addrs()
749 .map_err(|e| {
750 Error::connection(format!(
751 "Failed to resolve address {}:{} - {}",
752 host, port, e
753 ))
754 })?;
755
756 let server_addr: SocketAddr = resolved_addrs
757 .find(|addr| matches!(addr, SocketAddr::V4(_) | SocketAddr::V6(_)))
758 .ok_or_else(|| Error::connection("Invalid address: could not resolve host"))?;
759
760 eprintln!("[QUINN] Connecting to {}", server_addr);
761
762 let server_name = if skip_verify {
765 "localhost" } else {
767 host
768 };
769
770 eprintln!("[QUINN] Using server name: {}", server_name);
771
772 let conn = timeout(
773 Duration::from_secs(10),
774 endpoint
775 .connect(server_addr, server_name)
776 .map_err(|e| Error::connection(format!("Connection failed: {}", e)))?,
777 )
778 .await
779 .map_err(|_| Error::connection("Connection timeout"))?
780 .map_err(|e| Error::connection(format!("Failed to establish connection: {}", e)))?;
781
782 eprintln!("[QUINN] ✓ Connection established!");
783
784 let (mut send, recv) = conn
786 .open_bi()
787 .await
788 .map_err(|e| Error::connection(format!("Failed to open stream: {}", e)))?;
789
790 let mut hello_msg = serde_json::json!({
792 "type": "HELLO",
793 "client_name": hello_name,
794 "client_ver": hello_ver,
795 "wanted_conformance": conformance,
796 });
797
798 if let Some(user) = username {
800 hello_msg["username"] = serde_json::Value::String(user.to_string());
801 }
802 if let Some(pass) = password {
803 hello_msg["password"] = serde_json::Value::String(pass.to_string());
804 }
805
806 let mut hello_line = serde_json::to_string(&hello_msg)
807 .map_err(|e| Error::connection(format!("Failed to serialize HELLO: {}", e)))?;
808 hello_line.push('\n');
809
810 send.write_all(hello_line.as_bytes())
811 .await
812 .map_err(|e| Error::connection(format!("Failed to send HELLO: {}", e)))?;
813
814 let mut reader = BufReader::new(recv);
816 let mut response_line = String::new();
817 let read_len = timeout(Duration::from_secs(5), reader.read_line(&mut response_line))
818 .await
819 .map_err(|_| Error::connection("HELLO response timeout"))?
820 .map_err(|e| Error::connection(format!("Failed to read HELLO response: {}", e)))?;
821
822 if read_len == 0 {
823 return Err(Error::connection("Connection closed before HELLO response"));
824 }
825
826 let hello_response: serde_json::Value = serde_json::from_str(&response_line)
827 .map_err(|e| Error::connection(format!("Invalid HELLO response: {}", e)))?;
828
829 if hello_response.get("type").and_then(|t| t.as_str()) != Some("HELLO") {
830 eprintln!("[QUINN] Unexpected HELLO response: {}", hello_response);
831 }
832
833 eprintln!("[QUINN] ✓ HELLO handshake complete");
834
835 Ok(Self {
836 conn,
837 page_size,
838 send,
839 reader,
840 next_request_id: 1,
841 })
842 }
843
844 async fn send_json_line(&mut self, msg: &serde_json::Value) -> Result<()> {
845 let mut line = serde_json::to_string(msg)
846 .map_err(|e| Error::connection(format!("Failed to serialize message: {}", e)))?;
847 line.push('\n');
848 self.send
849 .write_all(line.as_bytes())
850 .await
851 .map_err(|e| Error::connection(format!("Failed to send message: {}", e)))?;
852 Ok(())
853 }
854
855 async fn read_json_line(&mut self, timeout_secs: u64) -> Result<serde_json::Value> {
856 let mut line = String::new();
857 let n = timeout(
858 Duration::from_secs(timeout_secs),
859 self.reader.read_line(&mut line),
860 )
861 .await
862 .map_err(|_| Error::timeout())?
863 .map_err(|e| Error::connection(format!("Failed to read response: {}", e)))?;
864
865 if n == 0 {
866 return Err(Error::connection("Connection closed"));
867 }
868
869 serde_json::from_str(&line)
870 .map_err(|e| Error::connection(format!("Invalid JSON response: {}", e)))
871 }
872
873 async fn try_read_json_line(&mut self) -> Result<Option<serde_json::Value>> {
876 let mut line = String::new();
877 let read_result = timeout(Duration::from_millis(1), self.reader.read_line(&mut line)).await;
878 match read_result {
879 Ok(res) => {
880 let n =
881 res.map_err(|e| Error::connection(format!("Failed to read response: {}", e)))?;
882 if n == 0 {
883 return Err(Error::connection("Connection closed"));
884 }
885 let value = serde_json::from_str(&line)
886 .map_err(|e| Error::connection(format!("Invalid JSON response: {}", e)))?;
887 Ok(Some(value))
888 }
889 Err(_) => Ok(None),
890 }
891 }
892
893 fn parse_rows(
894 &self,
895 rows_value: &serde_json::Value,
896 columns: &[Column],
897 ) -> Result<Vec<HashMap<String, Value>>> {
898 let rows_array = rows_value
899 .as_array()
900 .ok_or_else(|| Error::query("Response rows is not an array"))?;
901
902 let mut rows = Vec::new();
903 for row_json in rows_array {
904 let row_obj = row_json
905 .as_object()
906 .ok_or_else(|| Error::query("Row is not an object"))?;
907
908 let mut row = HashMap::new();
909 for (col_name, col_value) in row_obj {
910 let col_type = columns
911 .iter()
912 .find(|c| &c.name == col_name)
913 .map(|c| c.col_type.as_str())
914 .unwrap_or("");
915
916 let value = crate::types::decode_value(col_value, col_type)?;
917 row.insert(col_name.clone(), value);
918 }
919 rows.push(row);
920 }
921
922 Ok(rows)
923 }
924
925 async fn send_control(&mut self, msg_type: &str) -> Result<()> {
926 let msg = serde_json::json!({ "type": msg_type });
927 self.send_json_line(&msg).await?;
928
929 let resp = self.read_json_line(5).await?;
931 if let Some(result) = resp.get("result") {
932 if result.get("type").and_then(|t| t.as_str()) == Some("ERROR") {
933 let code = result
934 .get("code")
935 .and_then(|c| c.as_str())
936 .unwrap_or("UNKNOWN");
937 let message = result
938 .get("message")
939 .and_then(|m| m.as_str())
940 .unwrap_or("Command failed");
941 return Err(Error::Query {
942 code: code.to_string(),
943 message: message.to_string(),
944 });
945 }
946 }
947
948 Ok(())
949 }
950
951 pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
981 self.query_with_params(gql, &HashMap::new()).await
982 }
983
984 pub async fn query_with_params(
1024 &mut self,
1025 gql: &str,
1026 params: &HashMap<String, Value>,
1027 ) -> Result<(Page, Option<String>)> {
1028 let params_json: serde_json::Map<String, serde_json::Value> = params
1029 .iter()
1030 .map(|(k, v)| (k.clone(), v.to_json()))
1031 .collect();
1032
1033 let run_msg = if params_json.is_empty() {
1034 serde_json::json!({
1035 "type": "RUN_GQL",
1036 "text": gql,
1037 })
1038 } else {
1039 serde_json::json!({
1040 "type": "RUN_GQL",
1041 "text": gql,
1042 "params": params_json,
1043 })
1044 };
1045
1046 self.send_json_line(&run_msg)
1047 .await
1048 .map_err(|e| Error::query(format!("{}", e)))?;
1049
1050 let schema_frame = self.read_json_line(10).await?;
1051 let result = schema_frame
1052 .get("result")
1053 .cloned()
1054 .unwrap_or_else(|| serde_json::json!({}));
1055
1056 let res_type = result.get("type").and_then(|t| t.as_str()).unwrap_or("");
1057
1058 if res_type == "ERROR" {
1059 let code = result
1060 .get("code")
1061 .and_then(|c| c.as_str())
1062 .unwrap_or("UNKNOWN");
1063 let msg = result
1064 .get("message")
1065 .and_then(|m| m.as_str())
1066 .unwrap_or("Query failed");
1067 return Err(Error::Query {
1068 code: code.to_string(),
1069 message: msg.to_string(),
1070 });
1071 }
1072
1073 if res_type != "SCHEMA" {
1074 return Err(Error::protocol(format!(
1075 "Unexpected first frame: {}",
1076 res_type
1077 )));
1078 }
1079
1080 let columns: Vec<Column> = serde_json::from_value(
1081 result
1082 .get("columns")
1083 .cloned()
1084 .unwrap_or_else(|| serde_json::Value::Array(vec![])),
1085 )
1086 .map_err(|e| Error::protocol(format!("Failed to parse columns: {}", e)))?;
1087
1088 if std::env::var("GEODE_RUST_DEBUG").is_ok() {
1089 eprintln!("[QUINN] Columns: {:?}", columns);
1090 }
1091
1092 if let Some(inline_frame) = self.try_read_json_line().await? {
1094 let inline_result = inline_frame
1095 .get("result")
1096 .cloned()
1097 .unwrap_or_else(|| serde_json::json!({}));
1098 let inline_type = inline_result
1099 .get("type")
1100 .and_then(|t| t.as_str())
1101 .unwrap_or("");
1102
1103 match inline_type {
1104 "BINDINGS" => {
1105 let ordered = inline_result
1106 .get("ordered")
1107 .and_then(|v| v.as_bool())
1108 .unwrap_or(false);
1109 let order_keys: Vec<String> = inline_result
1110 .get("order_keys")
1111 .and_then(|v| v.as_array())
1112 .map(|arr| {
1113 arr.iter()
1114 .filter_map(|v| v.as_str().map(String::from))
1115 .collect()
1116 })
1117 .unwrap_or_default();
1118 let rows = self.parse_rows(
1119 inline_result
1120 .get("rows")
1121 .unwrap_or(&serde_json::Value::Array(vec![])),
1122 &columns,
1123 )?;
1124 let final_page = inline_result
1125 .get("final")
1126 .and_then(|v| v.as_bool())
1127 .unwrap_or(true);
1128 let page = Page {
1129 columns,
1130 rows,
1131 ordered,
1132 order_keys,
1133 final_page,
1134 };
1135 return Ok((page, None));
1136 }
1137 "RESULT" | "STATUS" | "PROFILE" => {
1138 let page = Page {
1139 columns,
1140 rows: Vec::new(),
1141 ordered: false,
1142 order_keys: Vec::new(),
1143 final_page: true,
1144 };
1145 return Ok((page, None));
1146 }
1147 "ERROR" => {
1148 let code = inline_result
1149 .get("code")
1150 .and_then(|c| c.as_str())
1151 .unwrap_or("UNKNOWN");
1152 let msg = inline_result
1153 .get("message")
1154 .and_then(|m| m.as_str())
1155 .unwrap_or("Query failed");
1156 return Err(Error::Query {
1157 code: code.to_string(),
1158 message: msg.to_string(),
1159 });
1160 }
1161 other => {
1162 return Err(Error::protocol(format!(
1163 "Unexpected inline frame: {}",
1164 other
1165 )));
1166 }
1167 }
1168 }
1169
1170 let request_id = self.next_request_id;
1171 self.next_request_id = self.next_request_id.wrapping_add(1).max(1);
1172
1173 let pull_msg = serde_json::json!({
1174 "type": "PULL",
1175 "request_id": request_id,
1176 "page_size": self.page_size,
1177 });
1178 self.send_json_line(&pull_msg)
1179 .await
1180 .map_err(|e| Error::query(format!("{}", e)))?;
1181
1182 let mut all_rows: Vec<HashMap<String, Value>> = Vec::new();
1183 let mut ordered = false;
1184 let mut order_keys: Vec<String> = Vec::new();
1185 let mut final_page;
1186
1187 loop {
1188 let frame = self.read_json_line(30).await?;
1189 let result = frame
1190 .get("result")
1191 .cloned()
1192 .unwrap_or_else(|| serde_json::json!({}));
1193
1194 let r#type = result.get("type").and_then(|t| t.as_str()).unwrap_or("");
1195
1196 match r#type {
1197 "BINDINGS" => {
1198 ordered = result
1199 .get("ordered")
1200 .and_then(|v| v.as_bool())
1201 .unwrap_or(false);
1202 order_keys = result
1203 .get("order_keys")
1204 .and_then(|v| v.as_array())
1205 .map(|arr| {
1206 arr.iter()
1207 .filter_map(|v| v.as_str().map(String::from))
1208 .collect()
1209 })
1210 .unwrap_or_default();
1211
1212 let rows = self.parse_rows(
1213 result
1214 .get("rows")
1215 .unwrap_or(&serde_json::Value::Array(vec![])),
1216 &columns,
1217 )?;
1218 all_rows.extend(rows);
1219
1220 final_page = result
1221 .get("final")
1222 .and_then(|v| v.as_bool())
1223 .unwrap_or(false);
1224 if final_page {
1225 break;
1226 }
1227
1228 self.send_json_line(&pull_msg)
1229 .await
1230 .map_err(|e| Error::query(format!("{}", e)))?;
1231 }
1232 "RESULT" | "STATUS" | "PROFILE" => {
1233 final_page = true;
1234 break;
1235 }
1236 "ERROR" => {
1237 let msg = result
1238 .get("message")
1239 .and_then(|m| m.as_str())
1240 .unwrap_or("Query failed");
1241 let code = result
1242 .get("code")
1243 .and_then(|c| c.as_str())
1244 .unwrap_or("UNKNOWN");
1245 return Err(Error::Query {
1246 code: code.to_string(),
1247 message: msg.to_string(),
1248 });
1249 }
1250 other => {
1251 return Err(Error::protocol(format!("Unexpected frame type: {}", other)));
1252 }
1253 }
1254 }
1255
1256 let page = Page {
1257 columns,
1258 rows: all_rows,
1259 ordered,
1260 order_keys,
1261 final_page,
1262 };
1263
1264 Ok((page, None))
1265 }
1266
1267 pub fn query_sync(
1269 &mut self,
1270 gql: &str,
1271 params: Option<HashMap<String, serde_json::Value>>,
1272 ) -> Result<Page> {
1273 let params_map = params.unwrap_or_default();
1274 let params_typed: HashMap<String, Value> = params_map
1275 .into_iter()
1276 .map(|(k, v)| {
1277 let typed_val = crate::types::Value::from_json(v);
1278 (k, typed_val)
1279 })
1280 .collect();
1281
1282 match tokio::runtime::Handle::try_current() {
1283 Ok(handle) => {
1284 let (page, _cursor) =
1285 handle.block_on(self.query_with_params(gql, ¶ms_typed))?;
1286 Ok(page)
1287 }
1288 Err(_) => {
1289 let rt = tokio::runtime::Runtime::new()
1290 .map_err(|e| Error::query(format!("Failed to create runtime: {}", e)))?;
1291 let (page, _cursor) = rt.block_on(self.query_with_params(gql, ¶ms_typed))?;
1292 Ok(page)
1293 }
1294 }
1295 }
1296
1297 pub async fn begin(&mut self) -> Result<()> {
1322 self.send_control("BEGIN").await
1323 }
1324
1325 pub async fn commit(&mut self) -> Result<()> {
1348 self.send_control("COMMIT").await
1349 }
1350
1351 pub async fn rollback(&mut self) -> Result<()> {
1375 self.send_control("ROLLBACK").await
1376 }
1377
1378 pub fn prepare(&self, query: &str) -> Result<PreparedStatement> {
1412 Ok(PreparedStatement::new(query))
1413 }
1414
1415 pub async fn explain(&mut self, gql: &str) -> Result<QueryPlan> {
1448 let explain_msg = serde_json::json!({
1449 "type": "RUN_GQL",
1450 "text": gql,
1451 "explain": true,
1452 });
1453
1454 self.send_json_line(&explain_msg)
1455 .await
1456 .map_err(|e| Error::query(format!("{}", e)))?;
1457
1458 let response = self.read_json_line(10).await?;
1459 let result = response
1460 .get("result")
1461 .cloned()
1462 .unwrap_or_else(|| serde_json::json!({}));
1463
1464 let res_type = result.get("type").and_then(|t| t.as_str()).unwrap_or("");
1465
1466 if res_type == "ERROR" {
1467 let code = result
1468 .get("code")
1469 .and_then(|c| c.as_str())
1470 .unwrap_or("UNKNOWN");
1471 let msg = result
1472 .get("message")
1473 .and_then(|m| m.as_str())
1474 .unwrap_or("Explain failed");
1475 return Err(Error::Query {
1476 code: code.to_string(),
1477 message: msg.to_string(),
1478 });
1479 }
1480
1481 let operations = Self::parse_plan_operations(&result);
1483 let estimated_rows = result
1484 .get("estimated_rows")
1485 .and_then(|r| r.as_u64())
1486 .unwrap_or(0);
1487
1488 Ok(QueryPlan {
1489 operations,
1490 estimated_rows,
1491 raw: result,
1492 })
1493 }
1494
1495 pub async fn profile(&mut self, gql: &str) -> Result<QueryProfile> {
1526 let profile_msg = serde_json::json!({
1527 "type": "RUN_GQL",
1528 "text": gql,
1529 "profile": true,
1530 });
1531
1532 self.send_json_line(&profile_msg)
1533 .await
1534 .map_err(|e| Error::query(format!("{}", e)))?;
1535
1536 let response = self.read_json_line(30).await?;
1537 let result = response
1538 .get("result")
1539 .cloned()
1540 .unwrap_or_else(|| serde_json::json!({}));
1541
1542 let res_type = result.get("type").and_then(|t| t.as_str()).unwrap_or("");
1543
1544 if res_type == "ERROR" {
1545 let code = result
1546 .get("code")
1547 .and_then(|c| c.as_str())
1548 .unwrap_or("UNKNOWN");
1549 let msg = result
1550 .get("message")
1551 .and_then(|m| m.as_str())
1552 .unwrap_or("Profile failed");
1553 return Err(Error::Query {
1554 code: code.to_string(),
1555 message: msg.to_string(),
1556 });
1557 }
1558
1559 let operations = Self::parse_plan_operations(&result);
1561 let estimated_rows = result
1562 .get("estimated_rows")
1563 .and_then(|r| r.as_u64())
1564 .unwrap_or(0);
1565 let actual_rows = result
1566 .get("actual_rows")
1567 .and_then(|r| r.as_u64())
1568 .unwrap_or(0);
1569 let execution_time_ms = result
1570 .get("execution_time_ms")
1571 .and_then(|t| t.as_f64())
1572 .unwrap_or(0.0);
1573
1574 let plan = QueryPlan {
1575 operations,
1576 estimated_rows,
1577 raw: result.clone(),
1578 };
1579
1580 Ok(QueryProfile {
1581 plan,
1582 actual_rows,
1583 execution_time_ms,
1584 raw: result,
1585 })
1586 }
1587
1588 pub async fn batch(
1626 &mut self,
1627 queries: &[(&str, Option<&HashMap<String, Value>>)],
1628 ) -> Result<Vec<Page>> {
1629 let mut results = Vec::with_capacity(queries.len());
1630
1631 for (query, params) in queries {
1632 let (page, _) = match params {
1633 Some(p) => self.query_with_params(query, p).await?,
1634 None => self.query(query).await?,
1635 };
1636 results.push(page);
1637 }
1638
1639 Ok(results)
1640 }
1641
1642 fn parse_plan_operations(result: &serde_json::Value) -> Vec<PlanOperation> {
1644 let mut operations = Vec::new();
1645
1646 if let Some(ops) = result.get("operations").and_then(|o| o.as_array()) {
1647 for op in ops {
1648 operations.push(Self::parse_single_operation(op));
1649 }
1650 } else if let Some(plan) = result.get("plan") {
1651 operations.push(Self::parse_single_operation(plan));
1653 }
1654
1655 operations
1656 }
1657
1658 fn parse_single_operation(op: &serde_json::Value) -> PlanOperation {
1660 let op_type = op
1661 .get("type")
1662 .or_else(|| op.get("op_type"))
1663 .and_then(|t| t.as_str())
1664 .unwrap_or("Unknown")
1665 .to_string();
1666
1667 let description = op
1668 .get("description")
1669 .or_else(|| op.get("desc"))
1670 .and_then(|d| d.as_str())
1671 .unwrap_or("")
1672 .to_string();
1673
1674 let estimated_rows = op
1675 .get("estimated_rows")
1676 .or_else(|| op.get("rows"))
1677 .and_then(|r| r.as_u64());
1678
1679 let children = op
1680 .get("children")
1681 .and_then(|c| c.as_array())
1682 .map(|arr| arr.iter().map(Self::parse_single_operation).collect())
1683 .unwrap_or_default();
1684
1685 PlanOperation {
1686 op_type,
1687 description,
1688 estimated_rows,
1689 children,
1690 }
1691 }
1692
1693 pub fn close(&mut self) -> Result<()> {
1716 self.conn.close(0u32.into(), b"client closing");
1717 std::thread::sleep(std::time::Duration::from_millis(100));
1719 Ok(())
1720 }
1721}
1722
1723#[derive(Debug)]
1725struct SkipServerVerification;
1726
1727impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
1728 fn verify_server_cert(
1729 &self,
1730 _end_entity: &CertificateDer,
1731 _intermediates: &[CertificateDer],
1732 _server_name: &RustlsServerName,
1733 _ocsp_response: &[u8],
1734 _now: rustls::pki_types::UnixTime,
1735 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
1736 Ok(rustls::client::danger::ServerCertVerified::assertion())
1737 }
1738
1739 fn verify_tls12_signature(
1740 &self,
1741 _message: &[u8],
1742 _cert: &CertificateDer,
1743 _dss: &rustls::DigitallySignedStruct,
1744 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1745 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1746 }
1747
1748 fn verify_tls13_signature(
1749 &self,
1750 _message: &[u8],
1751 _cert: &CertificateDer,
1752 _dss: &rustls::DigitallySignedStruct,
1753 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1754 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1755 }
1756
1757 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
1758 vec![
1759 rustls::SignatureScheme::RSA_PKCS1_SHA256,
1760 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
1761 rustls::SignatureScheme::ED25519,
1762 ]
1763 }
1764}
1765
1766#[cfg(test)]
1767mod tests {
1768 use super::*;
1769
1770 #[test]
1773 fn test_prepared_statement_new() {
1774 let stmt = PreparedStatement::new("MATCH (n:Person {id: $id}) RETURN n");
1775 assert_eq!(stmt.query(), "MATCH (n:Person {id: $id}) RETURN n");
1776 assert_eq!(stmt.param_names(), &["id"]);
1777 }
1778
1779 #[test]
1780 fn test_prepared_statement_multiple_params() {
1781 let stmt = PreparedStatement::new(
1782 "MATCH (p:Person {name: $name}) WHERE p.age > $min_age AND p.city = $city RETURN p",
1783 );
1784 assert!(stmt.query().contains("$name"));
1785 let names = stmt.param_names();
1786 assert_eq!(names.len(), 3);
1787 assert!(names.contains(&"name".to_string()));
1788 assert!(names.contains(&"min_age".to_string()));
1789 assert!(names.contains(&"city".to_string()));
1790 }
1791
1792 #[test]
1793 fn test_prepared_statement_no_params() {
1794 let stmt = PreparedStatement::new("MATCH (n) RETURN n LIMIT 10");
1795 assert!(stmt.param_names().is_empty());
1796 }
1797
1798 #[test]
1799 fn test_prepared_statement_duplicate_params() {
1800 let stmt =
1801 PreparedStatement::new("MATCH (a {id: $id})-[:KNOWS]->(b {id: $id}) RETURN a, b");
1802 assert_eq!(stmt.param_names(), &["id"]);
1804 }
1805
1806 #[test]
1807 fn test_prepared_statement_underscore_params() {
1808 let stmt = PreparedStatement::new("MATCH (n {user_id: $user_id}) RETURN n");
1809 assert_eq!(stmt.param_names(), &["user_id"]);
1810 }
1811
1812 #[test]
1813 fn test_prepared_statement_numeric_params() {
1814 let stmt = PreparedStatement::new("RETURN $param1, $param2, $param123");
1815 let names = stmt.param_names();
1816 assert_eq!(names.len(), 3);
1817 assert!(names.contains(&"param1".to_string()));
1818 assert!(names.contains(&"param2".to_string()));
1819 assert!(names.contains(&"param123".to_string()));
1820 }
1821
/// A leaf `PlanOperation` stores exactly the values it was built with.
#[test]
fn test_plan_operation_struct() {
    let scan = PlanOperation {
        op_type: String::from("NodeScan"),
        description: String::from("Scan Person nodes"),
        estimated_rows: Some(100),
        children: Vec::new(),
    };
    assert_eq!(scan.op_type, "NodeScan");
    assert_eq!(scan.description, "Scan Person nodes");
    assert_eq!(scan.estimated_rows, Some(100));
    assert_eq!(scan.children.len(), 0);
}
1837
/// A `PlanOperation` can own nested child operations.
#[test]
fn test_plan_operation_with_children() {
    let projection = PlanOperation {
        op_type: "Projection".into(),
        description: "Project name, age".into(),
        estimated_rows: Some(50),
        children: vec![PlanOperation {
            op_type: "Filter".into(),
            description: "Filter by age".into(),
            estimated_rows: Some(50),
            children: Vec::new(),
        }],
    };
    assert_eq!(projection.children.len(), 1);
    let first_kind = projection.children.first().map(|c| c.op_type.as_str());
    assert_eq!(first_kind, Some("Filter"));
}
1855
/// A `QueryPlan` keeps its operation list, row estimate and raw server JSON.
#[test]
fn test_query_plan_struct() {
    let plan = QueryPlan {
        operations: vec![PlanOperation {
            op_type: "NodeScan".to_string(),
            description: "Full scan".to_string(),
            estimated_rows: Some(1000),
            children: vec![],
        }],
        estimated_rows: 1000,
        raw: serde_json::json!({"type": "plan"}),
    };
    assert_eq!(plan.operations.len(), 1);
    // Previously the operation contents and `raw` were built but never checked.
    assert_eq!(plan.operations[0].op_type, "NodeScan");
    assert_eq!(plan.operations[0].estimated_rows, Some(1000));
    assert_eq!(plan.estimated_rows, 1000);
    assert_eq!(plan.raw["type"], "plan");
}
1873
/// A `QueryProfile` wraps a `QueryPlan` together with actual execution stats.
#[test]
fn test_query_profile_struct() {
    let plan = QueryPlan {
        operations: vec![],
        estimated_rows: 100,
        raw: serde_json::json!({}),
    };
    let profile = QueryProfile {
        plan,
        actual_rows: 95,
        execution_time_ms: 12.5,
        raw: serde_json::json!({"type": "profile"}),
    };
    assert_eq!(profile.actual_rows, 95);
    // Tolerance comparison keeps clippy's `float_cmp` lint quiet.
    assert!((profile.execution_time_ms - 12.5).abs() < 0.001);
    // The embedded plan and raw payload must survive the move into the profile.
    assert_eq!(profile.plan.estimated_rows, 100);
    assert!(profile.plan.operations.is_empty());
    assert_eq!(profile.raw["type"], "profile");
}
1892
/// A `Page` stores its column metadata, rows and ordering/finality flags.
#[test]
fn test_page_struct() {
    let page = Page {
        columns: vec![Column {
            name: "x".to_string(),
            col_type: "INT".to_string(),
        }],
        rows: vec![],
        ordered: false,
        order_keys: vec![],
        final_page: true,
    };
    assert_eq!(page.columns.len(), 1);
    assert_eq!(page.columns[0].name, "x");
    assert_eq!(page.columns[0].col_type, "INT");
    assert!(page.rows.is_empty());
    // The ordering fields were set but previously never verified.
    assert!(!page.ordered);
    assert!(page.order_keys.is_empty());
    assert!(page.final_page);
}
1911
/// A `Column` keeps its name and type string as given.
#[test]
fn test_column_struct() {
    let Column { name, col_type } = Column {
        name: "age".into(),
        col_type: "INT".into(),
    };
    assert_eq!(name, "age");
    assert_eq!(col_type, "INT");
}
1923
/// A `Savepoint` carries the name it was created with.
#[test]
fn test_savepoint_struct() {
    let Savepoint { name } = Savepoint {
        name: String::from("before_update"),
    };
    assert_eq!(name, "before_update");
}
1933
/// Smoke test: `Client::new` with only a host and port completes without panicking.
// NOTE(review): despite the name, no default field values are asserted here;
// the `Client` internals are not visible from this section.
#[test]
fn test_client_builder_defaults() {
    let _default_client = Client::new("localhost", 3141);
}
1941
/// Smoke test: every builder setter can be chained off `Client::new`.
#[test]
fn test_client_builder_chain() {
    let _configured = Client::new("example.com", 8443)
        .skip_verify(true)
        .page_size(500)
        .client_name("test-app")
        .client_version("2.0.0")
        .conformance("full");
}
1952
/// Smoke test: a configured `Client` can be cloned.
#[test]
fn test_client_clone() {
    let original = Client::new("localhost", 3141).skip_verify(true);
    let _copy = Clone::clone(&original);
}
1959
/// An empty JSON object produces no plan operations.
#[test]
fn test_parse_plan_operations_empty() {
    let parsed = Connection::parse_plan_operations(&serde_json::json!({}));
    assert_eq!(parsed.len(), 0);
}
1968
/// An `operations` array is parsed element by element, preserving order.
#[test]
fn test_parse_plan_operations_array() {
    let response = serde_json::json!({
        "operations": [
            {"type": "NodeScan", "description": "Scan nodes", "estimated_rows": 100},
            {"type": "Filter", "description": "Apply filter", "estimated_rows": 50}
        ]
    });
    let parsed = Connection::parse_plan_operations(&response);
    let kinds: Vec<&str> = parsed.iter().map(|op| op.op_type.as_str()).collect();
    assert_eq!(kinds, ["NodeScan", "Filter"]);
}
1982
/// A single `plan` object (using the alternate `op_type`/`desc` keys) becomes
/// a one-element operation list.
#[test]
fn test_parse_plan_operations_single_plan() {
    let response = serde_json::json!({
        "plan": {"op_type": "FullScan", "desc": "Full table scan"}
    });
    let parsed = Connection::parse_plan_operations(&response);
    assert_eq!(parsed.len(), 1);
    let only = &parsed[0];
    assert_eq!(only.op_type, "FullScan");
    assert_eq!(only.description, "Full table scan");
}
1993
/// A fully-populated operation, including a nested child, is parsed field by field.
#[test]
fn test_parse_single_operation() {
    let op_json = serde_json::json!({
        "type": "IndexScan",
        "description": "Use index on Person(name)",
        "estimated_rows": 25,
        "children": [
            {"type": "Filter", "description": "Filter results"}
        ]
    });
    let op = Connection::parse_single_operation(&op_json);
    assert_eq!(op.op_type, "IndexScan");
    assert_eq!(op.description, "Use index on Person(name)");
    assert_eq!(op.estimated_rows, Some(25));
    assert_eq!(op.children.len(), 1);

    // Check the whole child, not just its type. Its omitted
    // `estimated_rows`/`children` keys should fall back to the same defaults
    // as `test_parse_single_operation_minimal`.
    let child = &op.children[0];
    assert_eq!(child.op_type, "Filter");
    assert_eq!(child.description, "Filter results");
    assert_eq!(child.estimated_rows, None);
    assert!(child.children.is_empty());
}
2011
/// With no recognised keys, every field falls back to its default.
#[test]
fn test_parse_single_operation_minimal() {
    let PlanOperation {
        op_type,
        description,
        estimated_rows,
        children,
    } = Connection::parse_single_operation(&serde_json::json!({}));
    assert_eq!(op_type, "Unknown");
    assert!(description.is_empty());
    assert_eq!(estimated_rows, None);
    assert_eq!(children.len(), 0);
}
2021
/// The alternate keys `op_type`, `desc` and `rows` are accepted in place of
/// `type`, `description` and `estimated_rows`.
#[test]
fn test_parse_single_operation_alt_fields() {
    let payload = serde_json::json!({
        "op_type": "Sort",
        "desc": "Sort by name ASC",
        "rows": 100
    });
    let PlanOperation {
        op_type,
        description,
        estimated_rows,
        ..
    } = Connection::parse_single_operation(&payload);
    assert_eq!(op_type, "Sort");
    assert_eq!(description, "Sort by name ASC");
    assert_eq!(estimated_rows, Some(100));
}
2034}