1use log::{debug, trace, warn};
5use quinn::{ClientConfig, Endpoint};
6use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName};
7use secrecy::{ExposeSecret, SecretString};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::sync::Arc;
12use tokio::time::{Duration, timeout};
13
14use crate::dsn::{Dsn, Transport};
15use crate::error::{Error, Result};
16use crate::proto;
17use crate::types::Value;
18use crate::validate;
19
/// ALPN protocol identifier offered in the TLS handshake of every QUIC connection.
const GEODE_ALPN: &[u8] = b"geode/1";
21
/// Return a copy of `dsn` with embedded passwords replaced by `[REDACTED]`, for safe logging.
///
/// Two places are scrubbed:
/// * the userinfo password in `scheme://user:password@host...`. Only an `@` inside the
///   authority (before the first `/`, `?` or `#`) counts as the userinfo separator, and the
///   *last* such `@` is used, so a literal `@` inside the password is still fully hidden;
/// * every `password=` / `pass=` value (matched ASCII-case-insensitively), up to the next `&`
///   or the end of the string.
///
/// The username is kept so the redacted DSN remains useful for diagnostics.
#[allow(dead_code)] // no caller in this module; kept for logging call sites
fn redact_dsn(dsn: &str) -> String {
    const MASK: &str = "[REDACTED]";
    let mut result = dsn.to_string();

    if let Some(scheme_end) = result.find("://") {
        let authority_start = scheme_end + 3;
        // An '@' after the authority belongs to the path/query (e.g. `?email=a@b`), not userinfo.
        let authority_end = result[authority_start..]
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .map_or(result.len(), |i| authority_start + i);
        if let Some(at_rel) = result[authority_start..authority_end].rfind('@') {
            let at_pos = authority_start + at_rel;
            if let Some(colon_rel) = result[authority_start..at_pos].find(':') {
                let password_start = authority_start + colon_rel + 1;
                result.replace_range(password_start..at_pos, MASK);
            }
        }
    }

    for pattern in ["password=", "pass="] {
        let mut search_from = 0;
        loop {
            // ASCII-only lowercasing keeps every byte offset identical to `result`, so indices
            // found in `lower` are valid for slicing `result`. Full Unicode lowercasing can
            // change byte lengths (e.g. 'İ') and would misalign or panic.
            let lower = result.to_ascii_lowercase();
            let Some(rel) = lower[search_from..].find(pattern) else {
                break;
            };
            let value_start = search_from + rel + pattern.len();
            let value_end = result[value_start..]
                .find('&')
                .map_or(result.len(), |i| value_start + i);
            result.replace_range(value_start..value_end, MASK);
            // Continue after the mask so later occurrences are redacted too.
            search_from = value_start + MASK.len();
        }
    }

    result
}
72
/// A result column: its name and the type name reported by the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
    /// Column name; each row map in a [`Page`] is keyed by this name.
    pub name: String,
    /// Type name as reported by the server for this column (serialized as `"type"`).
    #[serde(rename = "type")]
    pub col_type: String,
}
84
/// A page of query results as assembled by [`Connection::query`] and friends.
#[derive(Debug, Clone)]
pub struct Page {
    /// Result schema, in the order reported by the server.
    pub columns: Vec<Column>,
    /// Rows keyed by column name; a column with no value in a row is stored as null.
    pub rows: Vec<HashMap<String, Value>>,
    /// `ordered` flag copied from the server's page (presumably whether rows are sorted by
    /// `order_keys` — confirm against the server protocol).
    pub ordered: bool,
    /// Ordering keys copied from the server's page.
    pub order_keys: Vec<String>,
    /// `true` when the server marked this page final (no more pages to pull).
    pub final_page: bool,
}
118
/// A named savepoint. No savepoint operations appear in the code shown here;
/// presumably used within transactions — confirm with its callers.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Savepoint name.
    pub name: String,
}
143
/// A GQL query together with the `$name` parameters it references.
///
/// Created with [`PreparedStatement::new`] or [`Connection::prepare`]; executing it checks that
/// every referenced parameter is supplied before sending the query.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// The query text, sent unchanged on execution.
    query: String,
    /// Distinct parameter names (without `$`) in first-seen order.
    param_names: Vec<String>,
}
169
170impl PreparedStatement {
171 pub fn new(query: impl Into<String>) -> Self {
175 let query = query.into();
176 let param_names = Self::extract_param_names(&query);
177 Self { query, param_names }
178 }
179
180 fn extract_param_names(query: &str) -> Vec<String> {
182 let mut names = Vec::new();
183 let mut chars = query.chars().peekable();
184
185 while let Some(c) = chars.next() {
186 if c == '$' {
187 let mut name = String::new();
188 while let Some(&next) = chars.peek() {
189 if next.is_ascii_alphanumeric() || next == '_' {
190 name.push(chars.next().unwrap());
191 } else {
192 break;
193 }
194 }
195 if !name.is_empty() && !names.contains(&name) {
196 names.push(name);
197 }
198 }
199 }
200
201 names
202 }
203
204 pub fn query(&self) -> &str {
206 &self.query
207 }
208
209 pub fn param_names(&self) -> &[String] {
211 &self.param_names
212 }
213
214 pub async fn execute(
229 &self,
230 conn: &mut Connection,
231 params: &HashMap<String, crate::types::Value>,
232 ) -> crate::error::Result<(Page, Option<String>)> {
233 for name in &self.param_names {
235 if !params.contains_key(name) {
236 return Err(crate::error::Error::validation(format!(
237 "Missing required parameter: {}",
238 name
239 )));
240 }
241 }
242
243 conn.query_with_params(&self.query, params).await
244 }
245}
246
/// One node of a query execution plan.
#[derive(Debug, Clone)]
pub struct PlanOperation {
    /// Operation type name (read from the plan JSON's `type`, or `op_type`, key).
    pub op_type: String,
    /// Human-readable description of the operation.
    pub description: String,
    /// Estimated row count, if the plan provides one.
    pub estimated_rows: Option<u64>,
    /// Child operations of this node.
    pub children: Vec<PlanOperation>,
}
259
/// A query execution plan as returned by [`Connection::explain`].
///
/// NOTE(review): `explain` currently returns this with no operations, `estimated_rows == 0`
/// and an empty `raw` object; the EXPLAIN result rows are not parsed.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Top-level plan operations.
    pub operations: Vec<PlanOperation>,
    /// Estimated number of result rows.
    pub estimated_rows: u64,
    /// Raw plan JSON.
    pub raw: serde_json::Value,
}
273
/// Plan plus runtime statistics, as returned by [`Connection::profile`].
///
/// NOTE(review): `profile` currently fills only `actual_rows` (the row count of the PROFILE
/// result page); the plan is empty, `execution_time_ms` is `0.0` and `raw` is `{}`.
#[derive(Debug, Clone)]
pub struct QueryProfile {
    /// The execution plan.
    pub plan: QueryPlan,
    /// Number of rows observed.
    pub actual_rows: u64,
    /// Execution time in milliseconds.
    pub execution_time_ms: f64,
    /// Raw profile JSON.
    pub raw: serde_json::Value,
}
288
/// Connection configuration and factory for [`Connection`]s.
///
/// Build with [`Client::new`] or [`Client::from_dsn`], adjust with the builder methods, then
/// call [`Client::connect`].
#[derive(Clone)]
pub struct Client {
    /// Transport to use (QUIC or gRPC).
    transport: Transport,
    /// Server host name or address.
    host: String,
    /// Server port.
    port: u16,
    /// TLS flag; only passed to the gRPC transport (the QUIC path always uses TLS 1.3).
    tls_enabled: bool,
    /// Disable server certificate verification (development only).
    skip_verify: bool,
    /// Page size; validated in `validate` and passed to the connection.
    page_size: usize,
    /// Client name for the HELLO handshake.
    hello_name: String,
    /// Client version for the HELLO handshake.
    hello_ver: String,
    /// Requested conformance level for the HELLO handshake.
    conformance: String,
    /// Optional username for authentication.
    username: Option<String>,
    /// Optional password, kept in a `SecretString`.
    password: Option<SecretString>,
    /// QUIC connect timeout in seconds (builder clamps to at least 1).
    connect_timeout_secs: u64,
    /// HELLO handshake timeout in seconds (builder clamps to at least 1).
    hello_timeout_secs: u64,
    /// QUIC idle timeout in seconds (builder clamps to at least 1).
    idle_timeout_secs: u64,
}
349
350impl Client {
351 pub fn new(host: impl Into<String>, port: u16) -> Self {
371 Self {
372 transport: Transport::Quic,
373 host: host.into(),
374 port,
375 tls_enabled: true,
376 skip_verify: false,
377 page_size: 1000,
378 hello_name: "geode-rust".to_string(),
379 hello_ver: env!("CARGO_PKG_VERSION").to_string(),
380 conformance: "min".to_string(),
381 username: None,
382 password: None,
383 connect_timeout_secs: 10,
384 hello_timeout_secs: 5,
385 idle_timeout_secs: 30,
386 }
387 }
388
389 pub fn from_dsn(dsn_str: &str) -> Result<Self> {
437 let dsn = Dsn::parse(dsn_str)?;
438
439 Ok(Self {
440 transport: dsn.transport(),
441 host: dsn.host().to_string(),
442 port: dsn.port(),
443 tls_enabled: dsn.tls_enabled(),
444 skip_verify: dsn.skip_verify(),
445 page_size: dsn.page_size(),
446 hello_name: dsn.client_name().to_string(),
447 hello_ver: dsn.client_version().to_string(),
448 conformance: dsn.conformance().to_string(),
449 username: dsn.username().map(String::from),
450 password: dsn.password().map(|p| SecretString::from(p.to_string())),
451 connect_timeout_secs: 10,
452 hello_timeout_secs: 5,
453 idle_timeout_secs: 30,
454 })
455 }
456
457 pub fn transport(&self) -> Transport {
459 self.transport
460 }
461
462 pub fn skip_verify(mut self, skip: bool) -> Self {
474 self.skip_verify = skip;
475 self
476 }
477
478 pub fn page_size(mut self, size: usize) -> Self {
487 self.page_size = size;
488 self
489 }
490
491 pub fn client_name(mut self, name: impl Into<String>) -> Self {
499 self.hello_name = name.into();
500 self
501 }
502
503 pub fn client_version(mut self, version: impl Into<String>) -> Self {
509 self.hello_ver = version.into();
510 self
511 }
512
513 pub fn conformance(mut self, level: impl Into<String>) -> Self {
519 self.conformance = level.into();
520 self
521 }
522
523 pub fn username(mut self, username: impl Into<String>) -> Self {
539 self.username = Some(username.into());
540 self
541 }
542
543 pub fn password(mut self, password: impl Into<String>) -> Self {
553 self.password = Some(SecretString::from(password.into()));
554 self
555 }
556
557 pub fn connect_timeout(mut self, seconds: u64) -> Self {
566 self.connect_timeout_secs = seconds.max(1);
567 self
568 }
569
570 pub fn hello_timeout(mut self, seconds: u64) -> Self {
579 self.hello_timeout_secs = seconds.max(1);
580 self
581 }
582
583 pub fn idle_timeout(mut self, seconds: u64) -> Self {
592 self.idle_timeout_secs = seconds.max(1);
593 self
594 }
595
596 pub fn validate(&self) -> Result<()> {
624 validate::hostname(&self.host)?;
626
627 validate::port(self.port)?;
629
630 validate::page_size(self.page_size)?;
632
633 Ok(())
634 }
635
636 pub async fn connect(&self) -> Result<Connection> {
667 self.validate()?;
669
670 let password_ref = self.password.as_ref().map(|s| s.expose_secret());
672
673 match self.transport {
674 Transport::Quic => {
675 Connection::new_quic(
676 &self.host,
677 self.port,
678 self.skip_verify,
679 self.page_size,
680 &self.hello_name,
681 &self.hello_ver,
682 &self.conformance,
683 self.username.as_deref(),
684 password_ref,
685 self.connect_timeout_secs,
686 self.hello_timeout_secs,
687 self.idle_timeout_secs,
688 )
689 .await
690 }
691 Transport::Grpc => {
692 #[cfg(feature = "grpc")]
693 {
694 Connection::new_grpc(
695 &self.host,
696 self.port,
697 self.tls_enabled,
698 self.skip_verify,
699 self.page_size,
700 self.username.as_deref(),
701 password_ref,
702 )
703 .await
704 }
705 #[cfg(not(feature = "grpc"))]
706 {
707 Err(Error::connection(
708 "gRPC transport requires the 'grpc' feature to be enabled",
709 ))
710 }
711 }
712 }
713 }
714}
715
/// Transport-specific state behind a [`Connection`].
#[allow(dead_code)] // several fields (e.g. `conn`, `buffer`, `next_request_id`) are not read in the code shown
enum ConnectionKind {
    /// A QUIC connection and the single bidirectional stream opened at connect time.
    Quic {
        /// Underlying QUIC connection handle, held alongside its streams.
        conn: quinn::Connection,
        /// Send half of the request/response stream.
        send: quinn::SendStream,
        /// Receive half of the request/response stream.
        recv: quinn::RecvStream,
        /// Initialised empty at connect; not read in the code shown (presumably a read buffer — confirm).
        buffer: Vec<u8>,
        /// Initialised to 1 at connect; the Pull loop uses its own local counter instead.
        next_request_id: u64,
        /// Session id from the HELLO response; sent with Execute/Begin/Commit/Rollback requests.
        session_id: String,
    },
    /// A gRPC connection (only with the `grpc` feature).
    #[cfg(feature = "grpc")]
    Grpc { client: crate::grpc::GrpcClient },
}
735
/// An open session with a Geode server, created by [`Client::connect`].
pub struct Connection {
    /// Transport-specific connection state.
    kind: ConnectionKind,
    /// Page size from the [`Client`]; not read by the QUIC query path shown here, which
    /// requests 1000 rows per Pull.
    #[allow(dead_code)]
    page_size: usize,
}
786
787impl Connection {
788 #[allow(clippy::too_many_arguments)]
790 async fn new_quic(
791 host: &str,
792 port: u16,
793 skip_verify: bool,
794 page_size: usize,
795 hello_name: &str,
796 hello_ver: &str,
797 conformance: &str,
798 username: Option<&str>,
799 password: Option<&str>,
800 connect_timeout_secs: u64,
801 hello_timeout_secs: u64,
802 idle_timeout_secs: u64,
803 ) -> Result<Self> {
804 let mut last_err: Option<Error> = None;
805
806 for attempt in 1..=3 {
807 match Self::connect_quic_once(
808 host,
809 port,
810 skip_verify,
811 page_size,
812 hello_name,
813 hello_ver,
814 conformance,
815 username,
816 password,
817 connect_timeout_secs,
818 hello_timeout_secs,
819 idle_timeout_secs,
820 )
821 .await
822 {
823 Ok(conn) => return Ok(conn),
824 Err(e) => {
825 last_err = Some(e);
826 if attempt < 3 {
827 debug!("Connection attempt {} failed, retrying...", attempt);
828 tokio::time::sleep(Duration::from_millis(150)).await;
829 }
830 }
831 }
832 }
833
834 Err(last_err.unwrap_or_else(|| Error::connection("Failed to connect")))
835 }
836
/// Open a gRPC connection by building a `grpc://` DSN and handing it to the gRPC client.
///
/// The DSN has the form `grpc://[user:pass@]host:port?tls={1|0}&insecure={true|false}`.
/// Credentials are included only when both username and password are present.
///
/// NOTE(review): credentials and host are interpolated unencoded. A password containing
/// `@`, `:`, `/`, `?`, `&` or `#`, or a bare IPv6 host, may be mis-parsed by `Dsn::parse`;
/// confirm how the DSN parser handles percent-encoding before changing this.
#[cfg(feature = "grpc")]
#[allow(clippy::too_many_arguments)]
async fn new_grpc(
    host: &str,
    port: u16,
    tls_enabled: bool,
    skip_verify: bool,
    page_size: usize,
    username: Option<&str>,
    password: Option<&str>,
) -> Result<Self> {
    let tls_flag = if tls_enabled { "1" } else { "0" };
    let authority = match (username, password) {
        (Some(user), Some(pass)) => format!("{}:{}@{}:{}", user, pass, host, port),
        _ => format!("{}:{}", host, port),
    };
    let dsn_str = format!("grpc://{}?tls={}&insecure={}", authority, tls_flag, skip_verify);

    let dsn = Dsn::parse(&dsn_str)?;
    let client = crate::grpc::GrpcClient::connect(&dsn).await?;

    Ok(Self {
        kind: ConnectionKind::Grpc { client },
        page_size,
    })
}
873
874 #[allow(clippy::too_many_arguments)]
875 async fn connect_quic_once(
876 host: &str,
877 port: u16,
878 skip_verify: bool,
879 page_size: usize,
880 _hello_name: &str,
881 _hello_ver: &str,
882 _conformance: &str,
883 username: Option<&str>,
884 password: Option<&str>,
885 connect_timeout_secs: u64,
886 _hello_timeout_secs: u64,
887 idle_timeout_secs: u64,
888 ) -> Result<Self> {
889 debug!("Creating connection to {}:{}", host, port);
890
891 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
893
894 let mut client_crypto = if skip_verify {
896 warn!(
899 "TLS certificate verification DISABLED - connection to {}:{} is vulnerable to MITM attacks. \
900 Do NOT use skip_verify in production!",
901 host, port
902 );
903 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
904 .dangerous()
905 .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
906 .with_no_client_auth()
907 } else {
908 let mut root_store = rustls::RootCertStore::empty();
910
911 let cert_result = rustls_native_certs::load_native_certs();
912
913 for err in &cert_result.errors {
915 warn!("Error loading native certificate: {:?}", err);
916 }
917
918 let mut certs_loaded = 0;
919 let mut certs_failed = 0;
920
921 for cert in cert_result.certs {
922 match root_store.add(cert) {
923 Ok(()) => certs_loaded += 1,
924 Err(_) => certs_failed += 1,
925 }
926 }
927
928 if certs_loaded == 0 {
929 return Err(Error::tls(
930 "No system root certificates found. TLS verification cannot proceed. \
931 Either install system CA certificates or use skip_verify(true) for development only.",
932 ));
933 }
934
935 debug!(
936 "Loaded {} system root certificates ({} failed to parse)",
937 certs_loaded, certs_failed
938 );
939
940 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
941 .with_root_certificates(root_store)
942 .with_no_client_auth()
943 };
944
945 client_crypto.alpn_protocols = vec![GEODE_ALPN.to_vec()];
947
948 let mut client_config = ClientConfig::new(Arc::new(
949 quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)
950 .map_err(|e| Error::connection(format!("Failed to create QUIC config: {}", e)))?,
951 ));
952
953 let mut transport = quinn::TransportConfig::default();
955 let idle_timeout = Duration::from_secs(idle_timeout_secs.min(146_000 * 365 * 24 * 3600));
958 transport.max_idle_timeout(Some(idle_timeout.try_into().map_err(|_| {
959 Error::connection("Idle timeout value too large for QUIC protocol")
960 })?));
961 transport.keep_alive_interval(Some(Duration::from_secs(5)));
962 client_config.transport_config(Arc::new(transport));
963
964 let mut endpoint = Endpoint::client(
967 "0.0.0.0:0"
968 .parse()
969 .expect("0.0.0.0:0 is a valid socket address"),
970 )
971 .map_err(|e| Error::connection(format!("Failed to create endpoint: {}", e)))?;
972 endpoint.set_default_client_config(client_config);
973
974 let mut resolved_addrs = format!("{}:{}", host, port)
976 .to_socket_addrs()
977 .map_err(|e| {
978 Error::connection(format!(
979 "Failed to resolve address {}:{} - {}",
980 host, port, e
981 ))
982 })?;
983
984 let server_addr: SocketAddr = resolved_addrs
985 .find(|addr| matches!(addr, SocketAddr::V4(_) | SocketAddr::V6(_)))
986 .ok_or_else(|| Error::connection("Invalid address: could not resolve host"))?;
987
988 debug!("Connecting to {}", server_addr);
989
990 let server_name = if skip_verify {
993 "localhost" } else {
995 host
996 };
997
998 trace!("Using server name for SNI: {}", server_name);
999
1000 let conn = timeout(
1001 Duration::from_secs(connect_timeout_secs),
1002 endpoint
1003 .connect(server_addr, server_name)
1004 .map_err(|e| Error::connection(format!("Connection failed: {}", e)))?,
1005 )
1006 .await
1007 .map_err(|_| Error::connection("Connection timeout"))?
1008 .map_err(|e| Error::connection(format!("Failed to establish connection: {}", e)))?;
1009
1010 debug!("Connection established to {}:{}", host, port);
1011
1012 let (mut send, mut recv) = conn
1014 .open_bi()
1015 .await
1016 .map_err(|e| Error::connection(format!("Failed to open stream: {}", e)))?;
1017
1018 let hello_req = proto::HelloRequest {
1020 username: username.unwrap_or("").to_string(),
1021 password: password.unwrap_or("").to_string(),
1022 tenant_id: None,
1023 client_name: String::new(),
1024 client_version: String::new(),
1025 wanted_conformance: String::new(),
1026 };
1027 let msg = proto::QuicClientMessage {
1028 msg: Some(proto::quic_client_message::Msg::Hello(hello_req)),
1029 };
1030 let data = proto::encode_with_length_prefix(&msg);
1031
1032 send.write_all(&data)
1033 .await
1034 .map_err(|e| Error::connection(format!("Failed to send HELLO: {}", e)))?;
1035
1036 let mut length_buf = [0u8; 4];
1038 timeout(Duration::from_secs(5), recv.read_exact(&mut length_buf))
1039 .await
1040 .map_err(|_| Error::connection("HELLO response timeout"))?
1041 .map_err(|e| {
1042 Error::connection(format!("Failed to read HELLO response length: {}", e))
1043 })?;
1044
1045 let msg_len = u32::from_be_bytes(length_buf) as usize;
1046 let mut msg_buf = vec![0u8; msg_len];
1047 recv.read_exact(&mut msg_buf)
1048 .await
1049 .map_err(|e| Error::connection(format!("Failed to read HELLO response body: {}", e)))?;
1050
1051 let hello_response = proto::decode_quic_server_message(&msg_buf)?;
1052
1053 let session_id = match hello_response.msg {
1054 Some(proto::quic_server_message::Msg::Hello(ref hello_resp)) => {
1055 if !hello_resp.success {
1056 return Err(Error::connection(format!(
1057 "Authentication failed: {}",
1058 hello_resp.error_message
1059 )));
1060 }
1061 hello_resp.session_id.clone()
1062 }
1063 _ => {
1064 return Err(Error::connection("Expected HELLO response"));
1065 }
1066 };
1067
1068 debug!("HELLO handshake complete, session_id={}", session_id);
1069
1070 Ok(Self {
1071 kind: ConnectionKind::Quic {
1072 conn,
1073 send,
1074 recv,
1075 buffer: Vec::new(),
1076 next_request_id: 1,
1077 session_id,
1078 },
1079 page_size,
1080 })
1081 }
1082
1083 async fn send_proto_quic(
1085 send: &mut quinn::SendStream,
1086 msg: &proto::QuicClientMessage,
1087 ) -> Result<()> {
1088 let data = proto::encode_with_length_prefix(msg);
1089 send.write_all(&data)
1090 .await
1091 .map_err(|e| Error::connection(format!("Failed to send message: {}", e)))?;
1092 Ok(())
1093 }
1094
1095 async fn read_proto_quic(
1099 recv: &mut quinn::RecvStream,
1100 timeout_secs: u64,
1101 ) -> Result<proto::QuicServerMessage> {
1102 timeout(Duration::from_secs(timeout_secs), async {
1103 let mut length_buf = [0u8; 4];
1105 recv.read_exact(&mut length_buf)
1106 .await
1107 .map_err(|e| Error::connection(format!("Failed to read response length: {}", e)))?;
1108
1109 let msg_len = u32::from_be_bytes(length_buf) as usize;
1110 let mut msg_buf = vec![0u8; msg_len];
1111 recv.read_exact(&mut msg_buf)
1112 .await
1113 .map_err(|e| Error::connection(format!("Failed to read response body: {}", e)))?;
1114
1115 proto::decode_quic_server_message(&msg_buf)
1116 })
1117 .await
1118 .map_err(|_| Error::timeout())?
1119 }
1120
1121 async fn try_read_proto_quic(
1125 recv: &mut quinn::RecvStream,
1126 ) -> Result<Option<proto::QuicServerMessage>> {
1127 let read_result = timeout(Duration::from_millis(5000), async {
1128 let mut length_buf = [0u8; 4];
1129 recv.read_exact(&mut length_buf)
1130 .await
1131 .map_err(|e| Error::connection(format!("Failed to read response: {}", e)))?;
1132
1133 let msg_len = u32::from_be_bytes(length_buf) as usize;
1134 let mut msg_buf = vec![0u8; msg_len];
1135 recv.read_exact(&mut msg_buf)
1136 .await
1137 .map_err(|e| Error::connection(format!("Failed to read response body: {}", e)))?;
1138
1139 proto::decode_quic_server_message(&msg_buf)
1140 })
1141 .await;
1142
1143 match read_result {
1144 Ok(Ok(msg)) => Ok(Some(msg)),
1145 Ok(Err(e)) => Err(e),
1146 Err(_) => Ok(None), }
1148 }
1149
1150 fn parse_proto_rows_static(
1152 proto_rows: &[proto::Row],
1153 columns: &[Column],
1154 ) -> Result<Vec<HashMap<String, Value>>> {
1155 let mut rows = Vec::new();
1156 for proto_row in proto_rows {
1157 let mut row = HashMap::new();
1158 for (i, col) in columns.iter().enumerate() {
1159 let value = if i < proto_row.values.len() {
1160 Self::convert_proto_value_static(&proto_row.values[i])
1161 } else {
1162 Value::null()
1163 };
1164 row.insert(col.name.clone(), value);
1165 }
1166 rows.push(row);
1167 }
1168 Ok(rows)
1169 }
1170
1171 fn convert_proto_value_static(proto_val: &proto::Value) -> Value {
1173 match &proto_val.kind {
1174 Some(proto::value::Kind::NullVal(_)) => Value::null(),
1175 Some(proto::value::Kind::StringVal(s)) => Value::string(s.value.clone()),
1176 Some(proto::value::Kind::IntVal(i)) => Value::int(i.value),
1177 Some(proto::value::Kind::DoubleVal(d)) => {
1178 Value::decimal(rust_decimal::Decimal::from_f64_retain(d.value).unwrap_or_default())
1179 }
1180 Some(proto::value::Kind::BoolVal(b)) => Value::bool(*b),
1181 Some(proto::value::Kind::ListVal(list)) => {
1182 let values: Vec<Value> = list
1183 .values
1184 .iter()
1185 .map(Self::convert_proto_value_static)
1186 .collect();
1187 Value::array(values)
1188 }
1189 Some(proto::value::Kind::MapVal(map)) => {
1190 let mut obj = std::collections::HashMap::new();
1191 for entry in &map.entries {
1192 let val = entry
1193 .value
1194 .as_ref()
1195 .map(Self::convert_proto_value_static)
1196 .unwrap_or_else(Value::null);
1197 obj.insert(entry.key.clone(), val);
1198 }
1199 Value::object(obj)
1200 }
1201 Some(proto::value::Kind::NodeVal(node)) => {
1202 let mut obj = std::collections::HashMap::new();
1203 obj.insert("id".to_string(), Value::int(node.id as i64));
1204 let labels: Vec<Value> = node
1205 .labels
1206 .iter()
1207 .map(|l| Value::string(l.clone()))
1208 .collect();
1209 obj.insert("labels".to_string(), Value::array(labels));
1210 let mut props = std::collections::HashMap::new();
1211 for entry in &node.properties {
1212 let val = entry
1213 .value
1214 .as_ref()
1215 .map(Self::convert_proto_value_static)
1216 .unwrap_or_else(Value::null);
1217 props.insert(entry.key.clone(), val);
1218 }
1219 obj.insert("properties".to_string(), Value::object(props));
1220 Value::object(obj)
1221 }
1222 Some(proto::value::Kind::EdgeVal(edge)) => {
1223 let mut obj = std::collections::HashMap::new();
1224 obj.insert("id".to_string(), Value::int(edge.id as i64));
1225 obj.insert("start_node".to_string(), Value::int(edge.from_id as i64));
1226 obj.insert("end_node".to_string(), Value::int(edge.to_id as i64));
1227 obj.insert("type".to_string(), Value::string(edge.label.clone()));
1228 let mut props = std::collections::HashMap::new();
1229 for entry in &edge.properties {
1230 let val = entry
1231 .value
1232 .as_ref()
1233 .map(Self::convert_proto_value_static)
1234 .unwrap_or_else(Value::null);
1235 props.insert(entry.key.clone(), val);
1236 }
1237 obj.insert("properties".to_string(), Value::object(props));
1238 Value::object(obj)
1239 }
1240 Some(proto::value::Kind::DecimalVal(d)) => {
1241 if let Ok(dec) = d.coeff.parse::<rust_decimal::Decimal>() {
1243 Value::decimal(dec)
1244 } else {
1245 Value::string(d.orig_repr.clone())
1246 }
1247 }
1248 Some(proto::value::Kind::BytesVal(b)) => {
1249 Value::string(format!("\\x{}", hex::encode(&b.value)))
1250 }
1251 _ => Value::null(),
1252 }
1253 }
1254
1255 async fn send_begin_quic(
1257 send: &mut quinn::SendStream,
1258 recv: &mut quinn::RecvStream,
1259 session_id: &str,
1260 ) -> Result<()> {
1261 let msg = proto::QuicClientMessage {
1262 msg: Some(proto::quic_client_message::Msg::Begin(
1263 proto::BeginRequest {
1264 session_id: session_id.to_string(),
1265 ..Default::default()
1266 },
1267 )),
1268 };
1269 Self::send_proto_quic(send, &msg).await?;
1270
1271 let resp = Self::read_proto_quic(recv, 5).await?;
1272 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Begin(_))) {
1273 return Err(Error::protocol("Expected BEGIN response"));
1274 }
1275 Ok(())
1276 }
1277
1278 async fn send_commit_quic(
1280 send: &mut quinn::SendStream,
1281 recv: &mut quinn::RecvStream,
1282 session_id: &str,
1283 ) -> Result<()> {
1284 let msg = proto::QuicClientMessage {
1285 msg: Some(proto::quic_client_message::Msg::Commit(
1286 proto::CommitRequest {
1287 session_id: session_id.to_string(),
1288 },
1289 )),
1290 };
1291 Self::send_proto_quic(send, &msg).await?;
1292
1293 let resp = Self::read_proto_quic(recv, 5).await?;
1294 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Commit(_))) {
1295 return Err(Error::protocol("Expected COMMIT response"));
1296 }
1297 Ok(())
1298 }
1299
1300 async fn send_rollback_quic(
1302 send: &mut quinn::SendStream,
1303 recv: &mut quinn::RecvStream,
1304 session_id: &str,
1305 ) -> Result<()> {
1306 let msg = proto::QuicClientMessage {
1307 msg: Some(proto::quic_client_message::Msg::Rollback(
1308 proto::RollbackRequest {
1309 session_id: session_id.to_string(),
1310 },
1311 )),
1312 };
1313 Self::send_proto_quic(send, &msg).await?;
1314
1315 let resp = Self::read_proto_quic(recv, 5).await?;
1316 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Rollback(_))) {
1317 return Err(Error::protocol("Expected ROLLBACK response"));
1318 }
1319 Ok(())
1320 }
1321
1322 pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
1352 self.query_with_params(gql, &HashMap::new()).await
1353 }
1354
1355 pub async fn query_with_params(
1395 &mut self,
1396 gql: &str,
1397 params: &HashMap<String, Value>,
1398 ) -> Result<(Page, Option<String>)> {
1399 match &mut self.kind {
1400 ConnectionKind::Quic {
1401 send,
1402 recv,
1403 session_id,
1404 ..
1405 } => Self::query_with_params_quic(send, recv, gql, params, session_id).await,
1406 #[cfg(feature = "grpc")]
1407 ConnectionKind::Grpc { client } => client.query_with_params(gql, params).await,
1408 }
1409 }
1410
1411 async fn query_with_params_quic(
1413 send: &mut quinn::SendStream,
1414 recv: &mut quinn::RecvStream,
1415 gql: &str,
1416 params: &HashMap<String, Value>,
1417 session_id: &str,
1418 ) -> Result<(Page, Option<String>)> {
1419 let (page, cursor) =
1420 Self::query_with_params_quic_inner(send, recv, gql, params, session_id).await?;
1421
1422 if !page.final_page {
1424 let mut all_rows = page.rows;
1425 let columns = page.columns;
1426 let mut ordered = page.ordered;
1427 let mut order_keys = page.order_keys;
1428 let mut request_id: u64 = 0;
1429
1430 loop {
1431 request_id += 1;
1432 let pull_req = proto::QuicClientMessage {
1433 msg: Some(proto::quic_client_message::Msg::Pull(proto::PullRequest {
1434 request_id,
1435 page_size: 1000,
1436 session_id: String::new(),
1437 })),
1438 };
1439 Self::send_proto_quic(send, &pull_req).await?;
1440
1441 let resp = Self::read_proto_quic(recv, 30).await?;
1442
1443 let exec_resp = match &resp.msg {
1445 Some(proto::quic_server_message::Msg::Pull(pull)) => pull.response.as_ref(),
1446 Some(proto::quic_server_message::Msg::Execute(e)) => Some(e),
1447 _ => None,
1448 };
1449
1450 let exec_resp = match exec_resp {
1451 Some(e) => e,
1452 None => break,
1453 };
1454
1455 if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload
1456 {
1457 return Err(Error::Query {
1458 code: err.code.clone(),
1459 message: err.message.clone(),
1460 });
1461 }
1462
1463 if let Some(proto::execution_response::Payload::Page(ref page_data)) =
1464 exec_resp.payload
1465 {
1466 let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
1467 all_rows.extend(rows);
1468 ordered = page_data.ordered;
1469 order_keys = page_data.order_keys.clone();
1470 if page_data.r#final {
1471 break;
1472 }
1473 } else {
1474 break;
1475 }
1476 }
1477
1478 let final_page = Page {
1479 columns,
1480 rows: all_rows,
1481 ordered,
1482 order_keys,
1483 final_page: true,
1484 };
1485 return Ok((final_page, cursor));
1486 }
1487
1488 Ok((page, cursor))
1489 }
1490
/// Send one Execute request and read back its first result page (no Pull follow-ups).
///
/// Protocol as implemented here:
/// 1. Send `ExecuteRequest { session_id, query, params }` and read the first reply (10 s).
///    A non-Execute reply is a protocol error.
/// 2. An `Error` payload drains one more message (up to ~5 s, result ignored) and then
///    returns `Error::Query`.
/// 3. A `Schema` payload supplies the columns; any other payload leaves them empty.
/// 4. Try to read an inline follow-up (up to 5 s). If it is an Execute message, return its
///    error or page, or an empty final page if it has neither.
/// 5. Otherwise, if the first reply itself was a `Page`, return it.
/// 6. Otherwise read one more message (30 s) and return its error or page. If it has neither,
///    return an empty final page.
///
/// The returned cursor is always `None`.
///
/// NOTE(review): if the first reply is a `Page` rather than a `Schema`, `columns` is empty.
/// The try-read in step 4 then waits up to 5 s (or consumes an unrelated message), and the
/// rows parse into empty maps.
/// NOTE(review): a try-read that times out mid-frame leaves the stream misaligned.
async fn query_with_params_quic_inner(
    send: &mut quinn::SendStream,
    recv: &mut quinn::RecvStream,
    gql: &str,
    params: &HashMap<String, Value>,
    session_id: &str,
) -> Result<(Page, Option<String>)> {
    // Named parameters in protobuf form.
    let params_proto: Vec<proto::Param> = params
        .iter()
        .map(|(k, v)| proto::Param {
            name: k.clone(),
            value: Some(v.to_proto_value()),
        })
        .collect();

    let exec_req = proto::ExecuteRequest {
        session_id: session_id.to_string(),
        query: gql.to_string(),
        params: params_proto,
    };
    let msg = proto::QuicClientMessage {
        msg: Some(proto::quic_client_message::Msg::Execute(exec_req)),
    };
    // Send failures are reported as query errors rather than connection errors.
    Self::send_proto_quic(send, &msg)
        .await
        .map_err(|e| Error::query(format!("{}", e)))?;

    let resp = Self::read_proto_quic(recv, 10).await?;

    let exec_resp = match resp.msg {
        Some(proto::quic_server_message::Msg::Execute(e)) => e,
        _ => return Err(Error::protocol("Expected Execute response")),
    };

    if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
        // Drain one trailing message, if any arrives within the try-read window; its content
        // (or read error) is ignored.
        let _ = Self::try_read_proto_quic(recv).await;
        return Err(Error::Query {
            code: err.code.clone(),
            message: err.message.clone(),
        });
    }

    // Columns come from a Schema payload; any other payload yields no columns.
    let columns: Vec<Column> = match exec_resp.payload {
        Some(proto::execution_response::Payload::Schema(ref s)) => s
            .columns
            .iter()
            .map(|c| Column {
                name: c.name.clone(),
                col_type: c.r#type.clone(),
            })
            .collect(),
        _ => Vec::new(),
    };

    trace!("Schema columns: {:?}", columns);

    // Look for an inline follow-up (typically the first page after the schema).
    if let Some(inline_resp) = Self::try_read_proto_quic(recv).await? {
        // A non-Execute inline message is discarded and the code falls through below.
        if let Some(proto::quic_server_message::Msg::Execute(inline_exec)) = inline_resp.msg {
            if let Some(proto::execution_response::Payload::Error(ref err)) =
                inline_exec.payload
            {
                return Err(Error::Query {
                    code: err.code.clone(),
                    message: err.message.clone(),
                });
            }

            if let Some(proto::execution_response::Payload::Page(ref page_data)) =
                inline_exec.payload
            {
                let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
                let page = Page {
                    columns,
                    rows,
                    ordered: page_data.ordered,
                    order_keys: page_data.order_keys.clone(),
                    final_page: page_data.r#final,
                };
                return Ok((page, None));
            }

            // Execute message with neither error nor page: report an empty, final result.
            let page = Page {
                columns,
                rows: Vec::new(),
                ordered: false,
                order_keys: Vec::new(),
                final_page: true,
            };
            return Ok((page, None));
        }
    }

    // The first reply itself carried a page (no schema message).
    if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload {
        let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
        let page = Page {
            columns,
            rows,
            ordered: page_data.ordered,
            order_keys: page_data.order_keys.clone(),
            final_page: page_data.r#final,
        };
        return Ok((page, None));
    }

    // Nothing usable yet: wait (longer) for one more message.
    let resp = Self::read_proto_quic(recv, 30).await?;
    if let Some(proto::quic_server_message::Msg::Execute(exec_resp)) = resp.msg {
        if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
            return Err(Error::Query {
                code: err.code.clone(),
                message: err.message.clone(),
            });
        }

        if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload
        {
            let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
            let page = Page {
                columns,
                rows,
                ordered: page_data.ordered,
                order_keys: page_data.order_keys.clone(),
                final_page: page_data.r#final,
            };
            return Ok((page, None));
        }
    }

    // No page arrived: return the schema with no rows, marked final.
    let page = Page {
        columns,
        rows: Vec::new(),
        ordered: false,
        order_keys: Vec::new(),
        final_page: true,
    };

    Ok((page, None))
}
1640
1641 pub fn query_sync(
1643 &mut self,
1644 gql: &str,
1645 params: Option<HashMap<String, serde_json::Value>>,
1646 ) -> Result<Page> {
1647 let params_map = params.unwrap_or_default();
1648 let params_typed: HashMap<String, Value> = params_map
1649 .into_iter()
1650 .map(|(k, v)| {
1651 let typed_val = crate::types::Value::from_json(v);
1652 (k, typed_val)
1653 })
1654 .collect();
1655
1656 match tokio::runtime::Handle::try_current() {
1657 Ok(handle) => {
1658 let (page, _cursor) =
1659 handle.block_on(self.query_with_params(gql, ¶ms_typed))?;
1660 Ok(page)
1661 }
1662 Err(_) => {
1663 let rt = tokio::runtime::Runtime::new()
1664 .map_err(|e| Error::query(format!("Failed to create runtime: {}", e)))?;
1665 let (page, _cursor) = rt.block_on(self.query_with_params(gql, ¶ms_typed))?;
1666 Ok(page)
1667 }
1668 }
1669 }
1670
1671 pub async fn begin(&mut self) -> Result<()> {
1696 match &mut self.kind {
1697 ConnectionKind::Quic {
1698 send,
1699 recv,
1700 session_id,
1701 ..
1702 } => Self::send_begin_quic(send, recv, session_id).await,
1703 #[cfg(feature = "grpc")]
1704 ConnectionKind::Grpc { client } => client.begin().await,
1705 }
1706 }
1707
1708 pub async fn commit(&mut self) -> Result<()> {
1731 match &mut self.kind {
1732 ConnectionKind::Quic {
1733 send,
1734 recv,
1735 session_id,
1736 ..
1737 } => Self::send_commit_quic(send, recv, session_id).await,
1738 #[cfg(feature = "grpc")]
1739 ConnectionKind::Grpc { client } => client.commit().await,
1740 }
1741 }
1742
1743 pub async fn rollback(&mut self) -> Result<()> {
1767 match &mut self.kind {
1768 ConnectionKind::Quic {
1769 send,
1770 recv,
1771 session_id,
1772 ..
1773 } => Self::send_rollback_quic(send, recv, session_id).await,
1774 #[cfg(feature = "grpc")]
1775 ConnectionKind::Grpc { client } => client.rollback().await,
1776 }
1777 }
1778
1779 pub fn prepare(&self, query: &str) -> Result<PreparedStatement> {
1813 Ok(PreparedStatement::new(query))
1814 }
1815
1816 pub async fn explain(&mut self, gql: &str) -> Result<QueryPlan> {
1849 let explain_query = format!("EXPLAIN {}", gql);
1851 let (_page, _) = self.query(&explain_query).await?;
1852
1853 Ok(QueryPlan {
1856 operations: Vec::new(),
1857 estimated_rows: 0,
1858 raw: serde_json::json!({}),
1859 })
1860 }
1861
1862 pub async fn profile(&mut self, gql: &str) -> Result<QueryProfile> {
1893 let profile_query = format!("PROFILE {}", gql);
1895 let (page, _) = self.query(&profile_query).await?;
1896
1897 let plan = QueryPlan {
1899 operations: Vec::new(),
1900 estimated_rows: 0,
1901 raw: serde_json::json!({}),
1902 };
1903
1904 Ok(QueryProfile {
1905 plan,
1906 actual_rows: page.rows.len() as u64,
1907 execution_time_ms: 0.0,
1908 raw: serde_json::json!({}),
1909 })
1910 }
1911
1912 pub async fn batch(
1950 &mut self,
1951 queries: &[(&str, Option<&HashMap<String, Value>>)],
1952 ) -> Result<Vec<Page>> {
1953 let mut results = Vec::with_capacity(queries.len());
1954
1955 for (query, params) in queries {
1956 let (page, _) = match params {
1957 Some(p) => self.query_with_params(query, p).await?,
1958 None => self.query(query).await?,
1959 };
1960 results.push(page);
1961 }
1962
1963 Ok(results)
1964 }
1965
1966 #[allow(dead_code)]
1969 fn parse_plan_operations(result: &serde_json::Value) -> Vec<PlanOperation> {
1970 let mut operations = Vec::new();
1971
1972 if let Some(ops) = result.get("operations").and_then(|o| o.as_array()) {
1973 for op in ops {
1974 operations.push(Self::parse_single_operation(op));
1975 }
1976 } else if let Some(plan) = result.get("plan") {
1977 operations.push(Self::parse_single_operation(plan));
1979 }
1980
1981 operations
1982 }
1983
1984 #[allow(dead_code)]
1986 fn parse_single_operation(op: &serde_json::Value) -> PlanOperation {
1987 let op_type = op
1988 .get("type")
1989 .or_else(|| op.get("op_type"))
1990 .and_then(|t| t.as_str())
1991 .unwrap_or("Unknown")
1992 .to_string();
1993
1994 let description = op
1995 .get("description")
1996 .or_else(|| op.get("desc"))
1997 .and_then(|d| d.as_str())
1998 .unwrap_or("")
1999 .to_string();
2000
2001 let estimated_rows = op
2002 .get("estimated_rows")
2003 .or_else(|| op.get("rows"))
2004 .and_then(|r| r.as_u64());
2005
2006 let children = op
2007 .get("children")
2008 .and_then(|c| c.as_array())
2009 .map(|arr| arr.iter().map(Self::parse_single_operation).collect())
2010 .unwrap_or_default();
2011
2012 PlanOperation {
2013 op_type,
2014 description,
2015 estimated_rows,
2016 children,
2017 }
2018 }
2019
2020 pub fn close(&mut self) -> Result<()> {
2043 match &mut self.kind {
2044 ConnectionKind::Quic { conn, .. } => {
2045 conn.close(0u32.into(), b"client closing");
2049 Ok(())
2050 }
2051 #[cfg(feature = "grpc")]
2052 ConnectionKind::Grpc { client } => client.close(),
2053 }
2054 }
2055
2056 pub fn is_healthy(&self) -> bool {
2071 match &self.kind {
2072 ConnectionKind::Quic { conn, .. } => {
2073 conn.close_reason().is_none()
2075 }
2076 #[cfg(feature = "grpc")]
2077 ConnectionKind::Grpc { .. } => {
2078 true
2080 }
2081 }
2082 }
2083}
2084
2085#[derive(Debug)]
2087struct SkipServerVerification;
2088
2089impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
2090 fn verify_server_cert(
2091 &self,
2092 _end_entity: &CertificateDer,
2093 _intermediates: &[CertificateDer],
2094 _server_name: &RustlsServerName,
2095 _ocsp_response: &[u8],
2096 _now: rustls::pki_types::UnixTime,
2097 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
2098 Ok(rustls::client::danger::ServerCertVerified::assertion())
2099 }
2100
2101 fn verify_tls12_signature(
2102 &self,
2103 _message: &[u8],
2104 _cert: &CertificateDer,
2105 _dss: &rustls::DigitallySignedStruct,
2106 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
2107 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
2108 }
2109
2110 fn verify_tls13_signature(
2111 &self,
2112 _message: &[u8],
2113 _cert: &CertificateDer,
2114 _dss: &rustls::DigitallySignedStruct,
2115 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
2116 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
2117 }
2118
2119 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
2120 vec![
2121 rustls::SignatureScheme::RSA_PKCS1_SHA256,
2122 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
2123 rustls::SignatureScheme::ED25519,
2124 ]
2125 }
2126}
2127
// Unit tests for the network-free parts of this module:
// - PreparedStatement parameter extraction
// - plain data structs
// - the Client builder and `validate`
// - DSN redaction
// - plan-JSON parsing
// - proto -> Value conversion of nodes and edges
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PreparedStatement: `$name` parameter extraction ----

    #[test]
    fn test_prepared_statement_new() {
        let stmt = PreparedStatement::new("MATCH (n:Person {id: $id}) RETURN n");
        assert_eq!(stmt.query(), "MATCH (n:Person {id: $id}) RETURN n");
        assert_eq!(stmt.param_names(), &["id"]);
    }

    #[test]
    fn test_prepared_statement_multiple_params() {
        let stmt = PreparedStatement::new(
            "MATCH (p:Person {name: $name}) WHERE p.age > $min_age AND p.city = $city RETURN p",
        );
        assert!(stmt.query().contains("$name"));
        let names = stmt.param_names();
        // Only membership is checked, so the test does not depend on ordering.
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"name".to_string()));
        assert!(names.contains(&"min_age".to_string()));
        assert!(names.contains(&"city".to_string()));
    }

    #[test]
    fn test_prepared_statement_no_params() {
        let stmt = PreparedStatement::new("MATCH (n) RETURN n LIMIT 10");
        assert!(stmt.param_names().is_empty());
    }

    #[test]
    fn test_prepared_statement_duplicate_params() {
        // A parameter referenced twice must be reported only once.
        let stmt =
            PreparedStatement::new("MATCH (a {id: $id})-[:KNOWS]->(b {id: $id}) RETURN a, b");
        assert_eq!(stmt.param_names(), &["id"]);
    }

    #[test]
    fn test_prepared_statement_underscore_params() {
        // Underscores are part of the parameter name.
        let stmt = PreparedStatement::new("MATCH (n {user_id: $user_id}) RETURN n");
        assert_eq!(stmt.param_names(), &["user_id"]);
    }

    #[test]
    fn test_prepared_statement_numeric_params() {
        // Digits after the first character are part of the parameter name.
        let stmt = PreparedStatement::new("RETURN $param1, $param2, $param123");
        let names = stmt.param_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"param1".to_string()));
        assert!(names.contains(&"param2".to_string()));
        assert!(names.contains(&"param123".to_string()));
    }

    // ---- Plan / profile data structures ----

    #[test]
    fn test_plan_operation_struct() {
        let op = PlanOperation {
            op_type: "NodeScan".to_string(),
            description: "Scan Person nodes".to_string(),
            estimated_rows: Some(100),
            children: vec![],
        };
        assert_eq!(op.op_type, "NodeScan");
        assert_eq!(op.description, "Scan Person nodes");
        assert_eq!(op.estimated_rows, Some(100));
        assert!(op.children.is_empty());
    }

    #[test]
    fn test_plan_operation_with_children() {
        let child = PlanOperation {
            op_type: "Filter".to_string(),
            description: "Filter by age".to_string(),
            estimated_rows: Some(50),
            children: vec![],
        };
        let parent = PlanOperation {
            op_type: "Projection".to_string(),
            description: "Project name, age".to_string(),
            estimated_rows: Some(50),
            children: vec![child],
        };
        assert_eq!(parent.children.len(), 1);
        assert_eq!(parent.children[0].op_type, "Filter");
    }

    #[test]
    fn test_query_plan_struct() {
        let plan = QueryPlan {
            operations: vec![PlanOperation {
                op_type: "NodeScan".to_string(),
                description: "Full scan".to_string(),
                estimated_rows: Some(1000),
                children: vec![],
            }],
            estimated_rows: 1000,
            raw: serde_json::json!({"type": "plan"}),
        };
        assert_eq!(plan.operations.len(), 1);
        assert_eq!(plan.estimated_rows, 1000);
    }

    #[test]
    fn test_query_profile_struct() {
        let plan = QueryPlan {
            operations: vec![],
            estimated_rows: 100,
            raw: serde_json::json!({}),
        };
        let profile = QueryProfile {
            plan,
            actual_rows: 95,
            execution_time_ms: 12.5,
            raw: serde_json::json!({"type": "profile"}),
        };
        assert_eq!(profile.actual_rows, 95);
        // Floating-point comparison with a tolerance rather than `==`.
        assert!((profile.execution_time_ms - 12.5).abs() < 0.001);
    }

    // ---- Page / Column / Savepoint ----

    #[test]
    fn test_page_struct() {
        let page = Page {
            columns: vec![Column {
                name: "x".to_string(),
                col_type: "INT".to_string(),
            }],
            rows: vec![],
            ordered: false,
            order_keys: vec![],
            final_page: true,
        };
        assert_eq!(page.columns.len(), 1);
        assert!(page.rows.is_empty());
        assert!(page.final_page);
    }

    #[test]
    fn test_column_struct() {
        let col = Column {
            name: "age".to_string(),
            col_type: "INT".to_string(),
        };
        assert_eq!(col.name, "age");
        assert_eq!(col.col_type, "INT");
    }

    #[test]
    fn test_savepoint_struct() {
        let sp = Savepoint {
            name: "before_update".to_string(),
        };
        assert_eq!(sp.name, "before_update");
    }

    // ---- Client builder ----
    // These tests assert nothing; they only check that construction, chaining
    // and cloning compile and do not panic.

    #[test]
    fn test_client_builder_defaults() {
        let _client = Client::new("localhost", 3141);
    }

    #[test]
    fn test_client_builder_chain() {
        let _client = Client::new("example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .client_name("test-app")
            .client_version("2.0.0")
            .conformance("full");
    }

    #[test]
    fn test_client_clone() {
        let client = Client::new("localhost", 3141).skip_verify(true);
        let _cloned = client.clone();
    }

    // ---- EXPLAIN / PROFILE plan JSON parsing ----

    #[test]
    fn test_parse_plan_operations_empty() {
        let result = serde_json::json!({});
        let ops = Connection::parse_plan_operations(&result);
        assert!(ops.is_empty());
    }

    #[test]
    fn test_parse_plan_operations_array() {
        // An "operations" array yields one PlanOperation per element, in order.
        let result = serde_json::json!({
            "operations": [
                {"type": "NodeScan", "description": "Scan nodes", "estimated_rows": 100},
                {"type": "Filter", "description": "Apply filter", "estimated_rows": 50}
            ]
        });
        let ops = Connection::parse_plan_operations(&result);
        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0].op_type, "NodeScan");
        assert_eq!(ops[1].op_type, "Filter");
    }

    #[test]
    fn test_parse_plan_operations_single_plan() {
        // The fallback "plan" object also exercises the alias keys `op_type`/`desc`.
        let result = serde_json::json!({
            "plan": {"op_type": "FullScan", "desc": "Full table scan"}
        });
        let ops = Connection::parse_plan_operations(&result);
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].op_type, "FullScan");
        assert_eq!(ops[0].description, "Full table scan");
    }

    #[test]
    fn test_parse_single_operation() {
        let op_json = serde_json::json!({
            "type": "IndexScan",
            "description": "Use index on Person(name)",
            "estimated_rows": 25,
            "children": [
                {"type": "Filter", "description": "Filter results"}
            ]
        });
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "IndexScan");
        assert_eq!(op.description, "Use index on Person(name)");
        assert_eq!(op.estimated_rows, Some(25));
        assert_eq!(op.children.len(), 1);
        assert_eq!(op.children[0].op_type, "Filter");
    }

    #[test]
    fn test_parse_single_operation_minimal() {
        // An empty object falls back to every default value.
        let op_json = serde_json::json!({});
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "Unknown");
        assert_eq!(op.description, "");
        assert_eq!(op.estimated_rows, None);
        assert!(op.children.is_empty());
    }

    #[test]
    fn test_parse_single_operation_alt_fields() {
        // Alias keys: `op_type`, `desc`, and `rows`.
        let op_json = serde_json::json!({
            "op_type": "Sort",
            "desc": "Sort by name ASC",
            "rows": 100
        });
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "Sort");
        assert_eq!(op.description, "Sort by name ASC");
        assert_eq!(op.estimated_rows, Some(100));
    }

    // ---- DSN redaction ----
    // Cover both URL userinfo (`user:pass@`) and query parameters
    // (`password=`, `pass=`).

    #[test]
    fn test_redact_dsn_url_with_password() {
        let dsn = "quic://admin:secret123@localhost:3141";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        // The username and host stay visible for diagnostics.
        assert!(redacted.contains("admin"));
        assert!(redacted.contains("localhost"));
    }

    #[test]
    fn test_redact_dsn_url_without_password() {
        let dsn = "quic://admin@localhost:3141";
        let redacted = redact_dsn(dsn);
        assert!(!redacted.contains("[REDACTED]"));
        assert!(redacted.contains("admin"));
        assert!(redacted.contains("localhost"));
    }

    #[test]
    fn test_redact_dsn_url_no_auth() {
        let dsn = "quic://localhost:3141";
        let redacted = redact_dsn(dsn);
        assert_eq!(redacted, dsn);
    }

    #[test]
    fn test_redact_dsn_query_param_password() {
        let dsn = "localhost:3141?username=admin&password=secret123";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        assert!(redacted.contains("username=admin"));
    }

    #[test]
    fn test_redact_dsn_query_param_pass() {
        let dsn = "localhost:3141?user=admin&pass=mysecret";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("mysecret"));
    }

    #[test]
    fn test_redact_dsn_simple_no_password() {
        // A DSN without credentials must be returned unchanged.
        let dsn = "localhost:3141?insecure=true";
        let redacted = redact_dsn(dsn);
        assert_eq!(redacted, dsn);
    }

    #[test]
    fn test_redact_dsn_url_with_query_and_password() {
        let dsn = "quic://user:pass@localhost:3141?insecure=true";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains(":pass@"));
        // Non-secret query parameters survive redaction.
        assert!(redacted.contains("insecure=true"));
    }

    // ---- Client::validate ----

    #[test]
    fn test_client_validate_valid() {
        let client = Client::new("localhost", 3141);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_valid_hostname() {
        let client = Client::new("geode.example.com", 3141);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_valid_ipv4() {
        let client = Client::new("192.168.1.1", 8443);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_invalid_hostname_hyphen_start() {
        // Hostnames may not begin with a hyphen.
        let client = Client::new("-invalid", 3141);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_hostname_hyphen_end() {
        // Hostnames may not end with a hyphen.
        let client = Client::new("invalid-", 3141);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_port_zero() {
        let client = Client::new("localhost", 0);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_page_size_zero() {
        let client = Client::new("localhost", 3141).page_size(0);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_page_size_too_large() {
        // 200_000 is above the accepted page-size limit (the exact cap lives in
        // the validation code, not here).
        let client = Client::new("localhost", 3141).page_size(200_000);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_with_all_options() {
        let client = Client::new("geode.example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .username("admin")
            .password("secret")
            .connect_timeout(15)
            .hello_timeout(10)
            .idle_timeout(60);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_extreme_timeout_values() {
        // The builder setters must accept u64::MAX without panicking; nothing is
        // asserted about whether `validate` accepts these values.
        let _client = Client::new("localhost", 3141)
            .connect_timeout(u64::MAX)
            .hello_timeout(u64::MAX)
            .idle_timeout(u64::MAX);
    }

    // ---- proto -> Value conversion ----
    // Edges are exposed with `type`/`start_node`/`end_node` keys instead of the
    // proto's `label`/`from_id`/`to_id`.

    #[test]
    fn test_convert_edge_uses_type_field() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 1,
            to_id: 2,
            label: "KNOWS".to_string(),
            properties: vec![],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("type").unwrap().as_string().unwrap(), "KNOWS");
        assert!(
            obj.get("label").is_none(),
            "edge should not have 'label' field"
        );
    }

    #[test]
    fn test_convert_edge_uses_start_end_node() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 42,
            to_id: 99,
            label: "LIKES".to_string(),
            properties: vec![],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("start_node").unwrap().as_int().unwrap(), 42);
        assert_eq!(obj.get("end_node").unwrap().as_int().unwrap(), 99);
        assert!(obj.get("from_id").is_none());
        assert!(obj.get("to_id").is_none());
    }

    #[test]
    fn test_convert_edge_with_properties() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 1,
            to_id: 2,
            label: "KNOWS".to_string(),
            properties: vec![proto::MapEntry {
                key: "since".to_string(),
                value: Some(proto::Value {
                    kind: Some(proto::value::Kind::IntVal(proto::IntValue {
                        value: 2020,
                        kind: 1,
                    })),
                }),
            }],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        // Proto MapEntry lists become a nested `properties` object.
        let props = obj.get("properties").unwrap().as_object().unwrap();
        assert_eq!(props.get("since").unwrap().as_int().unwrap(), 2020);
    }

    #[test]
    fn test_convert_node_fields() {
        let node = proto::NodeValue {
            id: 42,
            labels: vec!["Person".to_string()],
            properties: vec![proto::MapEntry {
                key: "name".to_string(),
                value: Some(proto::Value {
                    kind: Some(proto::value::Kind::StringVal(proto::StringValue {
                        value: "Alice".to_string(),
                        kind: 1,
                    })),
                }),
            }],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::NodeVal(node)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("id").unwrap().as_int().unwrap(), 42);
        let labels = obj.get("labels").unwrap().as_array().unwrap();
        assert_eq!(labels.len(), 1);
        let props = obj.get("properties").unwrap().as_object().unwrap();
        assert_eq!(props.get("name").unwrap().as_string().unwrap(), "Alice");
    }
}