1use log::{debug, trace, warn};
5use quinn::{ClientConfig, Endpoint};
6use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName};
7use secrecy::{ExposeSecret, SecretString};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::sync::Arc;
12use tokio::time::{Duration, timeout};
13
14use crate::dsn::{Dsn, Transport};
15use crate::error::{Error, Result};
16use crate::proto;
17use crate::types::Value;
18use crate::validate;
19
/// ALPN protocol identifier offered in the TLS handshake of every QUIC connection.
const GEODE_ALPN: &[u8] = b"geode/1";
21
/// Returns a copy of `dsn` with credentials masked so it can be logged safely.
///
/// Two kinds of secrets are replaced with `[REDACTED]`:
/// * the password in the URL userinfo (`scheme://user:password@...` becomes
///   `scheme://user:[REDACTED]@...`). The userinfo ends at the first `@` after `://`.
/// * the value of **every** `password=` / `pass=` parameter, matched
///   ASCII-case-insensitively. The value runs up to the next `&` or the end of the string.
///
/// Matching uses ASCII lowercasing, which never changes byte lengths. Offsets found in the
/// lowered copy are therefore always valid in the original string, even when the DSN
/// contains non-ASCII text.
#[allow(dead_code)]
fn redact_dsn(dsn: &str) -> String {
    const MASK: &str = "[REDACTED]";
    let mut result = dsn.to_string();

    // Userinfo password: everything between the first ':' and the '@'.
    if let Some(scheme_end) = result.find("://") {
        let after_scheme = scheme_end + 3;
        if let Some(at_rel) = result[after_scheme..].find('@') {
            let at_pos = after_scheme + at_rel;
            if let Some(colon_rel) = result[after_scheme..at_pos].find(':') {
                let password_start = after_scheme + colon_rel + 1;
                result.replace_range(password_start..at_pos, MASK);
            }
        }
    }

    // Query-style parameters: mask each occurrence, not just the first.
    for pattern in ["password=", "pass="] {
        let mut search_from = 0;
        loop {
            // Re-lower after every replacement because the string length changes.
            let lowered = result.to_ascii_lowercase();
            let value_start = match lowered[search_from..].find(pattern) {
                Some(rel) => search_from + rel + pattern.len(),
                None => break,
            };
            let value_end = result[value_start..]
                .find('&')
                .map_or(result.len(), |rel| value_start + rel);
            result.replace_range(value_start..value_end, MASK);
            // Continue after the mask so the loop always makes progress.
            search_from = value_start + MASK.len();
        }
    }

    result
}
72
/// A result column, as listed in the server's schema message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
    /// Column name; rows store this column's value under this key.
    pub name: String,
    /// Column type name as reported by the server; serialized under the key `type`.
    #[serde(rename = "type")]
    pub col_type: String,
}
84
/// A page of query results.
#[derive(Debug, Clone)]
pub struct Page {
    /// Column descriptors, in the order the server's schema listed them.
    pub columns: Vec<Column>,
    /// One map per row, keyed by column name; a missing cell is stored as null.
    pub rows: Vec<HashMap<String, Value>>,
    /// Copied from the server page's `ordered` flag (`false` for empty synthesized pages).
    pub ordered: bool,
    /// Copied from the server page's `order_keys`.
    pub order_keys: Vec<String>,
    /// Whether this is the last page; pages merged by the QUIC pagination loop are always final.
    pub final_page: bool,
}
118
/// A named transaction savepoint.
///
/// NOTE(review): no savepoint commands are visible in this section; confirm where this is used.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Savepoint name.
    pub name: String,
}
143
/// A query string together with the `$name` parameters it references.
///
/// Creating one (via `PreparedStatement::new` or `Connection::prepare`) does not contact
/// the server. `execute` checks that every referenced parameter is supplied and then runs
/// the stored query text.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// The query text exactly as given.
    query: String,
    /// Distinct parameter names (without the leading `$`) in first-seen order.
    param_names: Vec<String>,
}
169
170impl PreparedStatement {
171 pub fn new(query: impl Into<String>) -> Self {
175 let query = query.into();
176 let param_names = Self::extract_param_names(&query);
177 Self { query, param_names }
178 }
179
180 fn extract_param_names(query: &str) -> Vec<String> {
182 let mut names = Vec::new();
183 let mut chars = query.chars().peekable();
184
185 while let Some(c) = chars.next() {
186 if c == '$' {
187 let mut name = String::new();
188 while let Some(&next) = chars.peek() {
189 if next.is_ascii_alphanumeric() || next == '_' {
190 name.push(chars.next().unwrap());
191 } else {
192 break;
193 }
194 }
195 if !name.is_empty() && !names.contains(&name) {
196 names.push(name);
197 }
198 }
199 }
200
201 names
202 }
203
204 pub fn query(&self) -> &str {
206 &self.query
207 }
208
209 pub fn param_names(&self) -> &[String] {
211 &self.param_names
212 }
213
214 pub async fn execute(
229 &self,
230 conn: &mut Connection,
231 params: &HashMap<String, crate::types::Value>,
232 ) -> crate::error::Result<(Page, Option<String>)> {
233 for name in &self.param_names {
235 if !params.contains_key(name) {
236 return Err(crate::error::Error::validation(format!(
237 "Missing required parameter: {}",
238 name
239 )));
240 }
241 }
242
243 conn.query_with_params(&self.query, params).await
244 }
245}
246
/// One operator in a query plan tree.
#[derive(Debug, Clone)]
pub struct PlanOperation {
    /// Operator name; the plan parser reads it from the `type` or `op_type` key and falls back to `"Unknown"`.
    pub op_type: String,
    /// Operator description (read from the `description` key by the plan parser).
    pub description: String,
    /// Estimated row count, when the plan provides one.
    pub estimated_rows: Option<u64>,
    /// Child operators.
    pub children: Vec<PlanOperation>,
}
259
/// Execution plan returned by `Connection::explain`.
///
/// NOTE(review): `explain` currently discards the server response and returns an empty plan.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Top-level plan operators.
    pub operations: Vec<PlanOperation>,
    /// Estimated total row count.
    pub estimated_rows: u64,
    /// Raw plan JSON (currently always `{}`).
    pub raw: serde_json::Value,
}
273
/// Profiling result returned by `Connection::profile`.
#[derive(Debug, Clone)]
pub struct QueryProfile {
    /// Plan of the profiled query (currently always empty).
    pub plan: QueryPlan,
    /// Number of rows returned by the `PROFILE` query.
    pub actual_rows: u64,
    /// Execution time in milliseconds (currently always `0.0`).
    pub execution_time_ms: f64,
    /// Raw profile JSON (currently always `{}`).
    pub raw: serde_json::Value,
}
288
/// Connection settings for a Geode server; `Client::connect` opens a `Connection`.
///
/// Build one with `Client::new` or `Client::from_dsn`, then adjust it with the
/// builder-style setters. Cloning copies every setting, including the stored password.
#[derive(Clone)]
pub struct Client {
    /// Wire transport (QUIC or gRPC).
    transport: Transport,
    /// Server host name or address.
    host: String,
    /// Server port.
    port: u16,
    /// TLS flag; only forwarded to the gRPC transport (the QUIC path always uses TLS).
    tls_enabled: bool,
    /// Disables TLS certificate verification (development only).
    skip_verify: bool,
    /// Requested page size, checked by `validate`.
    page_size: usize,
    /// Client name passed to connection setup.
    hello_name: String,
    /// Client version passed to connection setup.
    hello_ver: String,
    /// Requested conformance level passed to connection setup.
    conformance: String,
    /// Optional username for authentication.
    username: Option<String>,
    /// Optional password; exposed only when the connection is opened.
    password: Option<SecretString>,
    /// Connect timeout in seconds.
    connect_timeout_secs: u64,
    /// HELLO timeout in seconds.
    hello_timeout_secs: u64,
    /// QUIC idle timeout in seconds.
    idle_timeout_secs: u64,
}
349
350impl Client {
351 pub fn new(host: impl Into<String>, port: u16) -> Self {
371 Self {
372 transport: Transport::Quic,
373 host: host.into(),
374 port,
375 tls_enabled: true,
376 skip_verify: false,
377 page_size: 1000,
378 hello_name: "geode-rust".to_string(),
379 hello_ver: env!("CARGO_PKG_VERSION").to_string(),
380 conformance: "min".to_string(),
381 username: None,
382 password: None,
383 connect_timeout_secs: 10,
384 hello_timeout_secs: 5,
385 idle_timeout_secs: 30,
386 }
387 }
388
389 pub fn from_dsn(dsn_str: &str) -> Result<Self> {
437 let dsn = Dsn::parse(dsn_str)?;
438
439 Ok(Self {
440 transport: dsn.transport(),
441 host: dsn.host().to_string(),
442 port: dsn.port(),
443 tls_enabled: dsn.tls_enabled(),
444 skip_verify: dsn.skip_verify(),
445 page_size: dsn.page_size(),
446 hello_name: dsn.client_name().to_string(),
447 hello_ver: dsn.client_version().to_string(),
448 conformance: dsn.conformance().to_string(),
449 username: dsn.username().map(String::from),
450 password: dsn.password().map(|p| SecretString::from(p.to_string())),
451 connect_timeout_secs: 10,
452 hello_timeout_secs: 5,
453 idle_timeout_secs: 30,
454 })
455 }
456
457 pub fn transport(&self) -> Transport {
459 self.transport
460 }
461
462 pub fn skip_verify(mut self, skip: bool) -> Self {
474 self.skip_verify = skip;
475 self
476 }
477
478 pub fn page_size(mut self, size: usize) -> Self {
487 self.page_size = size;
488 self
489 }
490
491 pub fn client_name(mut self, name: impl Into<String>) -> Self {
499 self.hello_name = name.into();
500 self
501 }
502
503 pub fn client_version(mut self, version: impl Into<String>) -> Self {
509 self.hello_ver = version.into();
510 self
511 }
512
513 pub fn conformance(mut self, level: impl Into<String>) -> Self {
519 self.conformance = level.into();
520 self
521 }
522
523 pub fn username(mut self, username: impl Into<String>) -> Self {
539 self.username = Some(username.into());
540 self
541 }
542
543 pub fn password(mut self, password: impl Into<String>) -> Self {
553 self.password = Some(SecretString::from(password.into()));
554 self
555 }
556
557 pub fn connect_timeout(mut self, seconds: u64) -> Self {
566 self.connect_timeout_secs = seconds.max(1);
567 self
568 }
569
570 pub fn hello_timeout(mut self, seconds: u64) -> Self {
579 self.hello_timeout_secs = seconds.max(1);
580 self
581 }
582
583 pub fn idle_timeout(mut self, seconds: u64) -> Self {
592 self.idle_timeout_secs = seconds.max(1);
593 self
594 }
595
596 pub fn validate(&self) -> Result<()> {
624 validate::hostname(&self.host)?;
626
627 validate::port(self.port)?;
629
630 validate::page_size(self.page_size)?;
632
633 Ok(())
634 }
635
636 pub async fn connect(&self) -> Result<Connection> {
667 self.validate()?;
669
670 let password_ref = self.password.as_ref().map(|s| s.expose_secret());
672
673 match self.transport {
674 Transport::Quic => {
675 Connection::new_quic(
676 &self.host,
677 self.port,
678 self.skip_verify,
679 self.page_size,
680 &self.hello_name,
681 &self.hello_ver,
682 &self.conformance,
683 self.username.as_deref(),
684 password_ref,
685 self.connect_timeout_secs,
686 self.hello_timeout_secs,
687 self.idle_timeout_secs,
688 )
689 .await
690 }
691 Transport::Grpc => {
692 #[cfg(feature = "grpc")]
693 {
694 Connection::new_grpc(
695 &self.host,
696 self.port,
697 self.tls_enabled,
698 self.skip_verify,
699 self.page_size,
700 self.username.as_deref(),
701 password_ref,
702 )
703 .await
704 }
705 #[cfg(not(feature = "grpc"))]
706 {
707 Err(Error::connection(
708 "gRPC transport requires the 'grpc' feature to be enabled",
709 ))
710 }
711 }
712 }
713 }
714}
715
/// Transport-specific state held by a `Connection`.
#[allow(dead_code)]
enum ConnectionKind {
    /// QUIC connection using one bidirectional stream opened right after connecting.
    Quic {
        /// Underlying QUIC connection handle.
        conn: quinn::Connection,
        /// Sending half of the request/response stream.
        send: quinn::SendStream,
        /// Receiving half of the request/response stream.
        recv: quinn::RecvStream,
        /// Initialised empty. NOTE(review): not read by any method visible in this section.
        buffer: Vec<u8>,
        /// Initialised to 1. NOTE(review): the QUIC pull loop keeps its own counter instead.
        next_request_id: u64,
        /// Session id from the HELLO response; sent with Execute/Begin/Commit/Rollback.
        session_id: String,
    },
    /// gRPC connection (only with the `grpc` feature).
    #[cfg(feature = "grpc")]
    Grpc { client: crate::grpc::GrpcClient },
}
735
/// An open session with a Geode server, created by `Client::connect`.
///
/// Over QUIC, every request and response on a connection share one bidirectional
/// stream, so methods take `&mut self`.
pub struct Connection {
    /// Transport-specific state.
    kind: ConnectionKind,
    /// Page size from the client settings; stored but not read in this section.
    #[allow(dead_code)]
    page_size: usize,
}
786
787impl Connection {
788 #[allow(clippy::too_many_arguments)]
790 async fn new_quic(
791 host: &str,
792 port: u16,
793 skip_verify: bool,
794 page_size: usize,
795 hello_name: &str,
796 hello_ver: &str,
797 conformance: &str,
798 username: Option<&str>,
799 password: Option<&str>,
800 connect_timeout_secs: u64,
801 hello_timeout_secs: u64,
802 idle_timeout_secs: u64,
803 ) -> Result<Self> {
804 let mut last_err: Option<Error> = None;
805
806 for attempt in 1..=3 {
807 match Self::connect_quic_once(
808 host,
809 port,
810 skip_verify,
811 page_size,
812 hello_name,
813 hello_ver,
814 conformance,
815 username,
816 password,
817 connect_timeout_secs,
818 hello_timeout_secs,
819 idle_timeout_secs,
820 )
821 .await
822 {
823 Ok(conn) => return Ok(conn),
824 Err(e) => {
825 last_err = Some(e);
826 if attempt < 3 {
827 debug!("Connection attempt {} failed, retrying...", attempt);
828 tokio::time::sleep(Duration::from_millis(150)).await;
829 }
830 }
831 }
832 }
833
834 Err(last_err.unwrap_or_else(|| Error::connection("Failed to connect")))
835 }
836
837 #[cfg(feature = "grpc")]
839 #[allow(clippy::too_many_arguments)]
840 async fn new_grpc(
841 host: &str,
842 port: u16,
843 tls_enabled: bool,
844 skip_verify: bool,
845 page_size: usize,
846 username: Option<&str>,
847 password: Option<&str>,
848 ) -> Result<Self> {
849 use crate::dsn::Dsn;
850
851 let tls_val = if tls_enabled { "1" } else { "0" };
853 let dsn_str = if let (Some(user), Some(pass)) = (username, password) {
854 format!(
855 "grpc://{}:{}@{}:{}?tls={}&insecure={}",
856 user, pass, host, port, tls_val, skip_verify
857 )
858 } else {
859 format!(
860 "grpc://{}:{}?tls={}&insecure={}",
861 host, port, tls_val, skip_verify
862 )
863 };
864
865 let dsn = Dsn::parse(&dsn_str)?;
866 let client = crate::grpc::GrpcClient::connect(&dsn).await?;
867
868 Ok(Self {
869 kind: ConnectionKind::Grpc { client },
870 page_size,
871 })
872 }
873
874 #[allow(clippy::too_many_arguments)]
875 async fn connect_quic_once(
876 host: &str,
877 port: u16,
878 skip_verify: bool,
879 page_size: usize,
880 _hello_name: &str,
881 _hello_ver: &str,
882 _conformance: &str,
883 username: Option<&str>,
884 password: Option<&str>,
885 connect_timeout_secs: u64,
886 _hello_timeout_secs: u64,
887 idle_timeout_secs: u64,
888 ) -> Result<Self> {
889 debug!("Creating connection to {}:{}", host, port);
890
891 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
893
894 let mut client_crypto = if skip_verify {
896 warn!(
899 "TLS certificate verification DISABLED - connection to {}:{} is vulnerable to MITM attacks. \
900 Do NOT use skip_verify in production!",
901 host, port
902 );
903 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
904 .dangerous()
905 .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
906 .with_no_client_auth()
907 } else {
908 let mut root_store = rustls::RootCertStore::empty();
910
911 let cert_result = rustls_native_certs::load_native_certs();
912
913 for err in &cert_result.errors {
915 warn!("Error loading native certificate: {:?}", err);
916 }
917
918 let mut certs_loaded = 0;
919 let mut certs_failed = 0;
920
921 for cert in cert_result.certs {
922 match root_store.add(cert) {
923 Ok(()) => certs_loaded += 1,
924 Err(_) => certs_failed += 1,
925 }
926 }
927
928 if certs_loaded == 0 {
929 return Err(Error::tls(
930 "No system root certificates found. TLS verification cannot proceed. \
931 Either install system CA certificates or use skip_verify(true) for development only.",
932 ));
933 }
934
935 debug!(
936 "Loaded {} system root certificates ({} failed to parse)",
937 certs_loaded, certs_failed
938 );
939
940 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
941 .with_root_certificates(root_store)
942 .with_no_client_auth()
943 };
944
945 client_crypto.alpn_protocols = vec![GEODE_ALPN.to_vec()];
947
948 let mut client_config = ClientConfig::new(Arc::new(
949 quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)
950 .map_err(|e| Error::connection(format!("Failed to create QUIC config: {}", e)))?,
951 ));
952
953 let mut transport = quinn::TransportConfig::default();
955 let idle_timeout = Duration::from_secs(idle_timeout_secs.min(146_000 * 365 * 24 * 3600));
958 transport.max_idle_timeout(Some(idle_timeout.try_into().map_err(|_| {
959 Error::connection("Idle timeout value too large for QUIC protocol")
960 })?));
961 transport.keep_alive_interval(Some(Duration::from_secs(5)));
962 client_config.transport_config(Arc::new(transport));
963
964 let mut endpoint = Endpoint::client(
967 "0.0.0.0:0"
968 .parse()
969 .expect("0.0.0.0:0 is a valid socket address"),
970 )
971 .map_err(|e| Error::connection(format!("Failed to create endpoint: {}", e)))?;
972 endpoint.set_default_client_config(client_config);
973
974 let mut resolved_addrs = format!("{}:{}", host, port)
976 .to_socket_addrs()
977 .map_err(|e| {
978 Error::connection(format!(
979 "Failed to resolve address {}:{} - {}",
980 host, port, e
981 ))
982 })?;
983
984 let server_addr: SocketAddr = resolved_addrs
985 .find(|addr| matches!(addr, SocketAddr::V4(_) | SocketAddr::V6(_)))
986 .ok_or_else(|| Error::connection("Invalid address: could not resolve host"))?;
987
988 debug!("Connecting to {}", server_addr);
989
990 let server_name = if skip_verify {
993 "localhost" } else {
995 host
996 };
997
998 trace!("Using server name for SNI: {}", server_name);
999
1000 let conn = timeout(
1001 Duration::from_secs(connect_timeout_secs),
1002 endpoint
1003 .connect(server_addr, server_name)
1004 .map_err(|e| Error::connection(format!("Connection failed: {}", e)))?,
1005 )
1006 .await
1007 .map_err(|_| Error::connection("Connection timeout"))?
1008 .map_err(|e| Error::connection(format!("Failed to establish connection: {}", e)))?;
1009
1010 debug!("Connection established to {}:{}", host, port);
1011
1012 let (mut send, mut recv) = conn
1014 .open_bi()
1015 .await
1016 .map_err(|e| Error::connection(format!("Failed to open stream: {}", e)))?;
1017
1018 let hello_req = proto::HelloRequest {
1020 username: username.unwrap_or("").to_string(),
1021 password: password.unwrap_or("").to_string(),
1022 tenant_id: None,
1023 client_name: String::new(),
1024 client_version: String::new(),
1025 wanted_conformance: String::new(),
1026 };
1027 let msg = proto::QuicClientMessage {
1028 msg: Some(proto::quic_client_message::Msg::Hello(hello_req)),
1029 };
1030 let data = proto::encode_with_length_prefix(&msg);
1031
1032 send.write_all(&data)
1033 .await
1034 .map_err(|e| Error::connection(format!("Failed to send HELLO: {}", e)))?;
1035
1036 let mut length_buf = [0u8; 4];
1038 timeout(Duration::from_secs(5), recv.read_exact(&mut length_buf))
1039 .await
1040 .map_err(|_| Error::connection("HELLO response timeout"))?
1041 .map_err(|e| {
1042 Error::connection(format!("Failed to read HELLO response length: {}", e))
1043 })?;
1044
1045 let msg_len = u32::from_be_bytes(length_buf) as usize;
1046 let mut msg_buf = vec![0u8; msg_len];
1047 recv.read_exact(&mut msg_buf)
1048 .await
1049 .map_err(|e| Error::connection(format!("Failed to read HELLO response body: {}", e)))?;
1050
1051 let hello_response = proto::decode_quic_server_message(&msg_buf)?;
1052
1053 let session_id = match hello_response.msg {
1054 Some(proto::quic_server_message::Msg::Hello(ref hello_resp)) => {
1055 if !hello_resp.success {
1056 return Err(Error::connection(format!(
1057 "Authentication failed: {}",
1058 hello_resp.error_message
1059 )));
1060 }
1061 hello_resp.session_id.clone()
1062 }
1063 _ => {
1064 return Err(Error::connection("Expected HELLO response"));
1065 }
1066 };
1067
1068 debug!("HELLO handshake complete, session_id={}", session_id);
1069
1070 Ok(Self {
1071 kind: ConnectionKind::Quic {
1072 conn,
1073 send,
1074 recv,
1075 buffer: Vec::new(),
1076 next_request_id: 1,
1077 session_id,
1078 },
1079 page_size,
1080 })
1081 }
1082
1083 async fn send_proto_quic(
1085 send: &mut quinn::SendStream,
1086 msg: &proto::QuicClientMessage,
1087 ) -> Result<()> {
1088 let data = proto::encode_with_length_prefix(msg);
1089 send.write_all(&data)
1090 .await
1091 .map_err(|e| Error::connection(format!("Failed to send message: {}", e)))?;
1092 Ok(())
1093 }
1094
1095 async fn read_proto_quic(
1097 recv: &mut quinn::RecvStream,
1098 timeout_secs: u64,
1099 ) -> Result<proto::QuicServerMessage> {
1100 let mut length_buf = [0u8; 4];
1102 timeout(
1103 Duration::from_secs(timeout_secs),
1104 recv.read_exact(&mut length_buf),
1105 )
1106 .await
1107 .map_err(|_| Error::timeout())?
1108 .map_err(|e| Error::connection(format!("Failed to read response length: {}", e)))?;
1109
1110 let msg_len = u32::from_be_bytes(length_buf) as usize;
1111 let mut msg_buf = vec![0u8; msg_len];
1112 recv.read_exact(&mut msg_buf)
1113 .await
1114 .map_err(|e| Error::connection(format!("Failed to read response body: {}", e)))?;
1115
1116 proto::decode_quic_server_message(&msg_buf)
1117 }
1118
1119 async fn try_read_proto_quic(
1122 recv: &mut quinn::RecvStream,
1123 ) -> Result<Option<proto::QuicServerMessage>> {
1124 let mut length_buf = [0u8; 4];
1126 let read_result = timeout(
1127 Duration::from_millis(5000),
1128 recv.read_exact(&mut length_buf),
1129 )
1130 .await;
1131
1132 match read_result {
1133 Ok(Ok(())) => {
1134 let msg_len = u32::from_be_bytes(length_buf) as usize;
1135 let mut msg_buf = vec![0u8; msg_len];
1136 recv.read_exact(&mut msg_buf).await.map_err(|e| {
1137 Error::connection(format!("Failed to read response body: {}", e))
1138 })?;
1139 let msg = proto::decode_quic_server_message(&msg_buf)?;
1140 Ok(Some(msg))
1141 }
1142 Ok(Err(e)) => Err(Error::connection(format!("Failed to read response: {}", e))),
1143 Err(_) => Ok(None), }
1145 }
1146
1147 fn parse_proto_rows_static(
1149 proto_rows: &[proto::Row],
1150 columns: &[Column],
1151 ) -> Result<Vec<HashMap<String, Value>>> {
1152 let mut rows = Vec::new();
1153 for proto_row in proto_rows {
1154 let mut row = HashMap::new();
1155 for (i, col) in columns.iter().enumerate() {
1156 let value = if i < proto_row.values.len() {
1157 Self::convert_proto_value_static(&proto_row.values[i])
1158 } else {
1159 Value::null()
1160 };
1161 row.insert(col.name.clone(), value);
1162 }
1163 rows.push(row);
1164 }
1165 Ok(rows)
1166 }
1167
1168 fn convert_proto_value_static(proto_val: &proto::Value) -> Value {
1170 match &proto_val.kind {
1171 Some(proto::value::Kind::NullVal(_)) => Value::null(),
1172 Some(proto::value::Kind::StringVal(s)) => Value::string(s.value.clone()),
1173 Some(proto::value::Kind::IntVal(i)) => Value::int(i.value),
1174 Some(proto::value::Kind::DoubleVal(d)) => {
1175 Value::decimal(rust_decimal::Decimal::from_f64_retain(d.value).unwrap_or_default())
1176 }
1177 Some(proto::value::Kind::BoolVal(b)) => Value::bool(*b),
1178 Some(proto::value::Kind::ListVal(list)) => {
1179 let values: Vec<Value> = list
1180 .values
1181 .iter()
1182 .map(Self::convert_proto_value_static)
1183 .collect();
1184 Value::array(values)
1185 }
1186 Some(proto::value::Kind::MapVal(map)) => {
1187 let mut obj = std::collections::HashMap::new();
1188 for entry in &map.entries {
1189 let val = entry
1190 .value
1191 .as_ref()
1192 .map(Self::convert_proto_value_static)
1193 .unwrap_or_else(Value::null);
1194 obj.insert(entry.key.clone(), val);
1195 }
1196 Value::object(obj)
1197 }
1198 Some(proto::value::Kind::NodeVal(node)) => {
1199 let mut obj = std::collections::HashMap::new();
1200 obj.insert("id".to_string(), Value::int(node.id as i64));
1201 let labels: Vec<Value> = node
1202 .labels
1203 .iter()
1204 .map(|l| Value::string(l.clone()))
1205 .collect();
1206 obj.insert("labels".to_string(), Value::array(labels));
1207 let mut props = std::collections::HashMap::new();
1208 for entry in &node.properties {
1209 let val = entry
1210 .value
1211 .as_ref()
1212 .map(Self::convert_proto_value_static)
1213 .unwrap_or_else(Value::null);
1214 props.insert(entry.key.clone(), val);
1215 }
1216 obj.insert("properties".to_string(), Value::object(props));
1217 Value::object(obj)
1218 }
1219 Some(proto::value::Kind::EdgeVal(edge)) => {
1220 let mut obj = std::collections::HashMap::new();
1221 obj.insert("id".to_string(), Value::int(edge.id as i64));
1222 obj.insert("start_node".to_string(), Value::int(edge.from_id as i64));
1223 obj.insert("end_node".to_string(), Value::int(edge.to_id as i64));
1224 obj.insert("type".to_string(), Value::string(edge.label.clone()));
1225 let mut props = std::collections::HashMap::new();
1226 for entry in &edge.properties {
1227 let val = entry
1228 .value
1229 .as_ref()
1230 .map(Self::convert_proto_value_static)
1231 .unwrap_or_else(Value::null);
1232 props.insert(entry.key.clone(), val);
1233 }
1234 obj.insert("properties".to_string(), Value::object(props));
1235 Value::object(obj)
1236 }
1237 Some(proto::value::Kind::DecimalVal(d)) => {
1238 if let Ok(dec) = d.coeff.parse::<rust_decimal::Decimal>() {
1240 Value::decimal(dec)
1241 } else {
1242 Value::string(d.orig_repr.clone())
1243 }
1244 }
1245 Some(proto::value::Kind::BytesVal(b)) => {
1246 Value::string(format!("\\x{}", hex::encode(&b.value)))
1247 }
1248 _ => Value::null(),
1249 }
1250 }
1251
1252 async fn send_begin_quic(
1254 send: &mut quinn::SendStream,
1255 recv: &mut quinn::RecvStream,
1256 session_id: &str,
1257 ) -> Result<()> {
1258 let msg = proto::QuicClientMessage {
1259 msg: Some(proto::quic_client_message::Msg::Begin(
1260 proto::BeginRequest {
1261 session_id: session_id.to_string(),
1262 ..Default::default()
1263 },
1264 )),
1265 };
1266 Self::send_proto_quic(send, &msg).await?;
1267
1268 let resp = Self::read_proto_quic(recv, 5).await?;
1269 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Begin(_))) {
1270 return Err(Error::protocol("Expected BEGIN response"));
1271 }
1272 Ok(())
1273 }
1274
1275 async fn send_commit_quic(
1277 send: &mut quinn::SendStream,
1278 recv: &mut quinn::RecvStream,
1279 session_id: &str,
1280 ) -> Result<()> {
1281 let msg = proto::QuicClientMessage {
1282 msg: Some(proto::quic_client_message::Msg::Commit(
1283 proto::CommitRequest {
1284 session_id: session_id.to_string(),
1285 },
1286 )),
1287 };
1288 Self::send_proto_quic(send, &msg).await?;
1289
1290 let resp = Self::read_proto_quic(recv, 5).await?;
1291 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Commit(_))) {
1292 return Err(Error::protocol("Expected COMMIT response"));
1293 }
1294 Ok(())
1295 }
1296
1297 async fn send_rollback_quic(
1299 send: &mut quinn::SendStream,
1300 recv: &mut quinn::RecvStream,
1301 session_id: &str,
1302 ) -> Result<()> {
1303 let msg = proto::QuicClientMessage {
1304 msg: Some(proto::quic_client_message::Msg::Rollback(
1305 proto::RollbackRequest {
1306 session_id: session_id.to_string(),
1307 },
1308 )),
1309 };
1310 Self::send_proto_quic(send, &msg).await?;
1311
1312 let resp = Self::read_proto_quic(recv, 5).await?;
1313 if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Rollback(_))) {
1314 return Err(Error::protocol("Expected ROLLBACK response"));
1315 }
1316 Ok(())
1317 }
1318
1319 pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
1349 self.query_with_params(gql, &HashMap::new()).await
1350 }
1351
1352 pub async fn query_with_params(
1392 &mut self,
1393 gql: &str,
1394 params: &HashMap<String, Value>,
1395 ) -> Result<(Page, Option<String>)> {
1396 match &mut self.kind {
1397 ConnectionKind::Quic {
1398 send,
1399 recv,
1400 session_id,
1401 ..
1402 } => Self::query_with_params_quic(send, recv, gql, params, session_id).await,
1403 #[cfg(feature = "grpc")]
1404 ConnectionKind::Grpc { client } => client.query_with_params(gql, params).await,
1405 }
1406 }
1407
1408 async fn query_with_params_quic(
1410 send: &mut quinn::SendStream,
1411 recv: &mut quinn::RecvStream,
1412 gql: &str,
1413 params: &HashMap<String, Value>,
1414 session_id: &str,
1415 ) -> Result<(Page, Option<String>)> {
1416 let (page, cursor) =
1417 Self::query_with_params_quic_inner(send, recv, gql, params, session_id).await?;
1418
1419 if !page.final_page {
1421 let mut all_rows = page.rows;
1422 let columns = page.columns;
1423 let mut ordered = page.ordered;
1424 let mut order_keys = page.order_keys;
1425 let mut request_id: u64 = 0;
1426
1427 loop {
1428 request_id += 1;
1429 let pull_req = proto::QuicClientMessage {
1430 msg: Some(proto::quic_client_message::Msg::Pull(proto::PullRequest {
1431 request_id,
1432 page_size: 1000,
1433 session_id: String::new(),
1434 })),
1435 };
1436 Self::send_proto_quic(send, &pull_req).await?;
1437
1438 let resp = Self::read_proto_quic(recv, 30).await?;
1439
1440 let exec_resp = match &resp.msg {
1442 Some(proto::quic_server_message::Msg::Pull(pull)) => pull.response.as_ref(),
1443 Some(proto::quic_server_message::Msg::Execute(e)) => Some(e),
1444 _ => None,
1445 };
1446
1447 let exec_resp = match exec_resp {
1448 Some(e) => e,
1449 None => break,
1450 };
1451
1452 if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload
1453 {
1454 return Err(Error::Query {
1455 code: err.code.clone(),
1456 message: err.message.clone(),
1457 });
1458 }
1459
1460 if let Some(proto::execution_response::Payload::Page(ref page_data)) =
1461 exec_resp.payload
1462 {
1463 let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
1464 all_rows.extend(rows);
1465 ordered = page_data.ordered;
1466 order_keys = page_data.order_keys.clone();
1467 if page_data.r#final {
1468 break;
1469 }
1470 } else {
1471 break;
1472 }
1473 }
1474
1475 let final_page = Page {
1476 columns,
1477 rows: all_rows,
1478 ordered,
1479 order_keys,
1480 final_page: true,
1481 };
1482 return Ok((final_page, cursor));
1483 }
1484
1485 Ok((page, cursor))
1486 }
1487
    /// Sends one `Execute` request over QUIC and assembles the first page of its result.
    ///
    /// Flow as implemented:
    /// 1. Send `Execute { session_id, query, params }` and wait up to 10 s for an `Execute`
    ///    reply; any other reply is a protocol error.
    /// 2. An `Error` payload triggers a best-effort drain of one follow-up message (up to
    ///    5 s), then returns `Error::Query`.
    /// 3. A `Schema` payload supplies the columns; otherwise the column list stays empty.
    /// 4. Wait up to 5 s for an inline follow-up. If it is an `Execute` message:
    ///    * an error payload is returned as `Error::Query`;
    ///    * a page payload is returned as the first page;
    ///    * anything else yields an empty final page.
    /// 5. Otherwise use a `Page` carried by the first reply. Failing that, read one more
    ///    message (30 s): an `Error` payload is returned as `Error::Query`, and a page
    ///    payload as the first page.
    /// 6. If no page was found, return an empty final page.
    ///
    /// The cursor in the returned tuple is always `None` on this path.
    async fn query_with_params_quic_inner(
        send: &mut quinn::SendStream,
        recv: &mut quinn::RecvStream,
        gql: &str,
        params: &HashMap<String, Value>,
        session_id: &str,
    ) -> Result<(Page, Option<String>)> {
        // Convert each named parameter into its protobuf form.
        let params_proto: Vec<proto::Param> = params
            .iter()
            .map(|(k, v)| proto::Param {
                name: k.clone(),
                value: Some(v.to_proto_value()),
            })
            .collect();

        let exec_req = proto::ExecuteRequest {
            session_id: session_id.to_string(),
            query: gql.to_string(),
            params: params_proto,
        };
        let msg = proto::QuicClientMessage {
            msg: Some(proto::quic_client_message::Msg::Execute(exec_req)),
        };
        // A failed send is surfaced as a query error rather than a connection error.
        Self::send_proto_quic(send, &msg)
            .await
            .map_err(|e| Error::query(format!("{}", e)))?;

        // The first reply must be an Execute response (10 s timeout).
        let resp = Self::read_proto_quic(recv, 10).await?;

        let exec_resp = match resp.msg {
            Some(proto::quic_server_message::Msg::Execute(e)) => e,
            _ => return Err(Error::protocol("Expected Execute response")),
        };

        if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
            // Best-effort drain of one follow-up message (waits up to 5 s); its result,
            // including any error, is deliberately ignored.
            let _ = Self::try_read_proto_quic(recv).await;
            return Err(Error::Query {
                code: err.code.clone(),
                message: err.message.clone(),
            });
        }

        // A Schema payload provides the columns; any other payload leaves them empty.
        let columns: Vec<Column> = match exec_resp.payload {
            Some(proto::execution_response::Payload::Schema(ref s)) => s
                .columns
                .iter()
                .map(|c| Column {
                    name: c.name.clone(),
                    col_type: c.r#type.clone(),
                })
                .collect(),
            _ => Vec::new(),
        };

        trace!("Schema columns: {:?}", columns);

        // Wait up to 5 s for an inline follow-up carrying the first page or an error.
        // NOTE(review): this wait also happens when the first reply already carried a Page
        // payload (checked only further below). A follow-up that is not an Execute message
        // is dropped before the 30 s read. Confirm the server never sends either.
        if let Some(inline_resp) = Self::try_read_proto_quic(recv).await? {
            if let Some(proto::quic_server_message::Msg::Execute(inline_exec)) = inline_resp.msg {
                if let Some(proto::execution_response::Payload::Error(ref err)) =
                    inline_exec.payload
                {
                    return Err(Error::Query {
                        code: err.code.clone(),
                        message: err.message.clone(),
                    });
                }

                if let Some(proto::execution_response::Payload::Page(ref page_data)) =
                    inline_exec.payload
                {
                    let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
                    let page = Page {
                        columns,
                        rows,
                        ordered: page_data.ordered,
                        order_keys: page_data.order_keys.clone(),
                        final_page: page_data.r#final,
                    };
                    return Ok((page, None));
                }

                // An Execute follow-up without a page means an empty, complete result.
                let page = Page {
                    columns,
                    rows: Vec::new(),
                    ordered: false,
                    order_keys: Vec::new(),
                    final_page: true,
                };
                return Ok((page, None));
            }
        }

        // No usable inline follow-up: use a Page carried by the first reply, if any.
        if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload {
            let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
            let page = Page {
                columns,
                rows,
                ordered: page_data.ordered,
                order_keys: page_data.order_keys.clone(),
                final_page: page_data.r#final,
            };
            return Ok((page, None));
        }

        // Otherwise wait up to 30 s for the first page.
        let resp = Self::read_proto_quic(recv, 30).await?;
        if let Some(proto::quic_server_message::Msg::Execute(exec_resp)) = resp.msg {
            if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
                return Err(Error::Query {
                    code: err.code.clone(),
                    message: err.message.clone(),
                });
            }

            if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload
            {
                let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
                let page = Page {
                    columns,
                    rows,
                    ordered: page_data.ordered,
                    order_keys: page_data.order_keys.clone(),
                    final_page: page_data.r#final,
                };
                return Ok((page, None));
            }
        }

        // Nothing usable arrived: report an empty, final page.
        let page = Page {
            columns,
            rows: Vec::new(),
            ordered: false,
            order_keys: Vec::new(),
            final_page: true,
        };

        Ok((page, None))
    }
1637
1638 pub fn query_sync(
1640 &mut self,
1641 gql: &str,
1642 params: Option<HashMap<String, serde_json::Value>>,
1643 ) -> Result<Page> {
1644 let params_map = params.unwrap_or_default();
1645 let params_typed: HashMap<String, Value> = params_map
1646 .into_iter()
1647 .map(|(k, v)| {
1648 let typed_val = crate::types::Value::from_json(v);
1649 (k, typed_val)
1650 })
1651 .collect();
1652
1653 match tokio::runtime::Handle::try_current() {
1654 Ok(handle) => {
1655 let (page, _cursor) =
1656 handle.block_on(self.query_with_params(gql, ¶ms_typed))?;
1657 Ok(page)
1658 }
1659 Err(_) => {
1660 let rt = tokio::runtime::Runtime::new()
1661 .map_err(|e| Error::query(format!("Failed to create runtime: {}", e)))?;
1662 let (page, _cursor) = rt.block_on(self.query_with_params(gql, ¶ms_typed))?;
1663 Ok(page)
1664 }
1665 }
1666 }
1667
1668 pub async fn begin(&mut self) -> Result<()> {
1693 match &mut self.kind {
1694 ConnectionKind::Quic {
1695 send,
1696 recv,
1697 session_id,
1698 ..
1699 } => Self::send_begin_quic(send, recv, session_id).await,
1700 #[cfg(feature = "grpc")]
1701 ConnectionKind::Grpc { client } => client.begin().await,
1702 }
1703 }
1704
1705 pub async fn commit(&mut self) -> Result<()> {
1728 match &mut self.kind {
1729 ConnectionKind::Quic {
1730 send,
1731 recv,
1732 session_id,
1733 ..
1734 } => Self::send_commit_quic(send, recv, session_id).await,
1735 #[cfg(feature = "grpc")]
1736 ConnectionKind::Grpc { client } => client.commit().await,
1737 }
1738 }
1739
1740 pub async fn rollback(&mut self) -> Result<()> {
1764 match &mut self.kind {
1765 ConnectionKind::Quic {
1766 send,
1767 recv,
1768 session_id,
1769 ..
1770 } => Self::send_rollback_quic(send, recv, session_id).await,
1771 #[cfg(feature = "grpc")]
1772 ConnectionKind::Grpc { client } => client.rollback().await,
1773 }
1774 }
1775
1776 pub fn prepare(&self, query: &str) -> Result<PreparedStatement> {
1810 Ok(PreparedStatement::new(query))
1811 }
1812
1813 pub async fn explain(&mut self, gql: &str) -> Result<QueryPlan> {
1846 let explain_query = format!("EXPLAIN {}", gql);
1848 let (_page, _) = self.query(&explain_query).await?;
1849
1850 Ok(QueryPlan {
1853 operations: Vec::new(),
1854 estimated_rows: 0,
1855 raw: serde_json::json!({}),
1856 })
1857 }
1858
1859 pub async fn profile(&mut self, gql: &str) -> Result<QueryProfile> {
1890 let profile_query = format!("PROFILE {}", gql);
1892 let (page, _) = self.query(&profile_query).await?;
1893
1894 let plan = QueryPlan {
1896 operations: Vec::new(),
1897 estimated_rows: 0,
1898 raw: serde_json::json!({}),
1899 };
1900
1901 Ok(QueryProfile {
1902 plan,
1903 actual_rows: page.rows.len() as u64,
1904 execution_time_ms: 0.0,
1905 raw: serde_json::json!({}),
1906 })
1907 }
1908
1909 pub async fn batch(
1947 &mut self,
1948 queries: &[(&str, Option<&HashMap<String, Value>>)],
1949 ) -> Result<Vec<Page>> {
1950 let mut results = Vec::with_capacity(queries.len());
1951
1952 for (query, params) in queries {
1953 let (page, _) = match params {
1954 Some(p) => self.query_with_params(query, p).await?,
1955 None => self.query(query).await?,
1956 };
1957 results.push(page);
1958 }
1959
1960 Ok(results)
1961 }
1962
1963 #[allow(dead_code)]
1966 fn parse_plan_operations(result: &serde_json::Value) -> Vec<PlanOperation> {
1967 let mut operations = Vec::new();
1968
1969 if let Some(ops) = result.get("operations").and_then(|o| o.as_array()) {
1970 for op in ops {
1971 operations.push(Self::parse_single_operation(op));
1972 }
1973 } else if let Some(plan) = result.get("plan") {
1974 operations.push(Self::parse_single_operation(plan));
1976 }
1977
1978 operations
1979 }
1980
1981 #[allow(dead_code)]
1983 fn parse_single_operation(op: &serde_json::Value) -> PlanOperation {
1984 let op_type = op
1985 .get("type")
1986 .or_else(|| op.get("op_type"))
1987 .and_then(|t| t.as_str())
1988 .unwrap_or("Unknown")
1989 .to_string();
1990
1991 let description = op
1992 .get("description")
1993 .or_else(|| op.get("desc"))
1994 .and_then(|d| d.as_str())
1995 .unwrap_or("")
1996 .to_string();
1997
1998 let estimated_rows = op
1999 .get("estimated_rows")
2000 .or_else(|| op.get("rows"))
2001 .and_then(|r| r.as_u64());
2002
2003 let children = op
2004 .get("children")
2005 .and_then(|c| c.as_array())
2006 .map(|arr| arr.iter().map(Self::parse_single_operation).collect())
2007 .unwrap_or_default();
2008
2009 PlanOperation {
2010 op_type,
2011 description,
2012 estimated_rows,
2013 children,
2014 }
2015 }
2016
2017 pub fn close(&mut self) -> Result<()> {
2040 match &mut self.kind {
2041 ConnectionKind::Quic { conn, .. } => {
2042 conn.close(0u32.into(), b"client closing");
2046 Ok(())
2047 }
2048 #[cfg(feature = "grpc")]
2049 ConnectionKind::Grpc { client } => client.close(),
2050 }
2051 }
2052
2053 pub fn is_healthy(&self) -> bool {
2068 match &self.kind {
2069 ConnectionKind::Quic { conn, .. } => {
2070 conn.close_reason().is_none()
2072 }
2073 #[cfg(feature = "grpc")]
2074 ConnectionKind::Grpc { .. } => {
2075 true
2077 }
2078 }
2079 }
2080}
2081
2082#[derive(Debug)]
2084struct SkipServerVerification;
2085
2086impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
2087 fn verify_server_cert(
2088 &self,
2089 _end_entity: &CertificateDer,
2090 _intermediates: &[CertificateDer],
2091 _server_name: &RustlsServerName,
2092 _ocsp_response: &[u8],
2093 _now: rustls::pki_types::UnixTime,
2094 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
2095 Ok(rustls::client::danger::ServerCertVerified::assertion())
2096 }
2097
2098 fn verify_tls12_signature(
2099 &self,
2100 _message: &[u8],
2101 _cert: &CertificateDer,
2102 _dss: &rustls::DigitallySignedStruct,
2103 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
2104 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
2105 }
2106
2107 fn verify_tls13_signature(
2108 &self,
2109 _message: &[u8],
2110 _cert: &CertificateDer,
2111 _dss: &rustls::DigitallySignedStruct,
2112 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
2113 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
2114 }
2115
2116 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
2117 vec![
2118 rustls::SignatureScheme::RSA_PKCS1_SHA256,
2119 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
2120 rustls::SignatureScheme::ED25519,
2121 ]
2122 }
2123}
2124
#[cfg(test)]
mod tests {
    use super::*;

    // ----------------------------------------------------------------------
    // PreparedStatement
    // ----------------------------------------------------------------------

    #[test]
    fn test_prepared_statement_new() {
        let stmt = PreparedStatement::new("MATCH (n:Person {id: $id}) RETURN n");
        assert_eq!(stmt.query(), "MATCH (n:Person {id: $id}) RETURN n");
        assert_eq!(stmt.param_names(), &["id"]);
    }

    #[test]
    fn test_prepared_statement_multiple_params() {
        let stmt = PreparedStatement::new(
            "MATCH (p:Person {name: $name}) WHERE p.age > $min_age AND p.city = $city RETURN p",
        );
        assert!(stmt.query().contains("$name"));
        let names = stmt.param_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"name".to_string()));
        assert!(names.contains(&"min_age".to_string()));
        assert!(names.contains(&"city".to_string()));
    }

    #[test]
    fn test_prepared_statement_no_params() {
        let stmt = PreparedStatement::new("MATCH (n) RETURN n LIMIT 10");
        assert!(stmt.param_names().is_empty());
    }

    #[test]
    fn test_prepared_statement_duplicate_params() {
        let stmt =
            PreparedStatement::new("MATCH (a {id: $id})-[:KNOWS]->(b {id: $id}) RETURN a, b");
        // A parameter referenced twice is reported only once.
        assert_eq!(stmt.param_names(), &["id"]);
    }

    #[test]
    fn test_prepared_statement_underscore_params() {
        let stmt = PreparedStatement::new("MATCH (n {user_id: $user_id}) RETURN n");
        assert_eq!(stmt.param_names(), &["user_id"]);
    }

    #[test]
    fn test_prepared_statement_numeric_params() {
        let stmt = PreparedStatement::new("RETURN $param1, $param2, $param123");
        let names = stmt.param_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"param1".to_string()));
        assert!(names.contains(&"param2".to_string()));
        assert!(names.contains(&"param123".to_string()));
    }

    // ----------------------------------------------------------------------
    // PlanOperation
    // ----------------------------------------------------------------------

    #[test]
    fn test_plan_operation_struct() {
        let op = PlanOperation {
            op_type: "NodeScan".to_string(),
            description: "Scan Person nodes".to_string(),
            estimated_rows: Some(100),
            children: vec![],
        };
        assert_eq!(op.op_type, "NodeScan");
        assert_eq!(op.description, "Scan Person nodes");
        assert_eq!(op.estimated_rows, Some(100));
        assert!(op.children.is_empty());
    }

    #[test]
    fn test_plan_operation_with_children() {
        let child = PlanOperation {
            op_type: "Filter".to_string(),
            description: "Filter by age".to_string(),
            estimated_rows: Some(50),
            children: vec![],
        };
        let parent = PlanOperation {
            op_type: "Projection".to_string(),
            description: "Project name, age".to_string(),
            estimated_rows: Some(50),
            children: vec![child],
        };
        assert_eq!(parent.children.len(), 1);
        assert_eq!(parent.children[0].op_type, "Filter");
    }

    // ----------------------------------------------------------------------
    // QueryPlan / QueryProfile
    // ----------------------------------------------------------------------

    #[test]
    fn test_query_plan_struct() {
        let plan = QueryPlan {
            operations: vec![PlanOperation {
                op_type: "NodeScan".to_string(),
                description: "Full scan".to_string(),
                estimated_rows: Some(1000),
                children: vec![],
            }],
            estimated_rows: 1000,
            raw: serde_json::json!({"type": "plan"}),
        };
        assert_eq!(plan.operations.len(), 1);
        assert_eq!(plan.estimated_rows, 1000);
    }

    #[test]
    fn test_query_profile_struct() {
        let plan = QueryPlan {
            operations: vec![],
            estimated_rows: 100,
            raw: serde_json::json!({}),
        };
        let profile = QueryProfile {
            plan,
            actual_rows: 95,
            execution_time_ms: 12.5,
            raw: serde_json::json!({"type": "profile"}),
        };
        assert_eq!(profile.actual_rows, 95);
        assert!((profile.execution_time_ms - 12.5).abs() < 0.001);
    }

    // ----------------------------------------------------------------------
    // Page / Column / Savepoint
    // ----------------------------------------------------------------------

    #[test]
    fn test_page_struct() {
        let page = Page {
            columns: vec![Column {
                name: "x".to_string(),
                col_type: "INT".to_string(),
            }],
            rows: vec![],
            ordered: false,
            order_keys: vec![],
            final_page: true,
        };
        assert_eq!(page.columns.len(), 1);
        assert!(page.rows.is_empty());
        assert!(page.final_page);
    }

    #[test]
    fn test_column_struct() {
        let col = Column {
            name: "age".to_string(),
            col_type: "INT".to_string(),
        };
        assert_eq!(col.name, "age");
        assert_eq!(col.col_type, "INT");
    }

    #[test]
    fn test_savepoint_struct() {
        let sp = Savepoint {
            name: "before_update".to_string(),
        };
        assert_eq!(sp.name, "before_update");
    }

    // ----------------------------------------------------------------------
    // Client builder
    // ----------------------------------------------------------------------

    #[test]
    fn test_client_builder_defaults() {
        // A client built from just a host and port must be valid as-is.
        let client = Client::new("localhost", 3141);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_builder_chain() {
        // Smoke test: every builder method chains by value without panicking.
        let _client = Client::new("example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .client_name("test-app")
            .client_version("2.0.0")
            .conformance("full");
    }

    #[test]
    fn test_client_clone() {
        let client = Client::new("localhost", 3141).skip_verify(true);
        let cloned = client.clone();
        // Both the original and the clone must remain valid configurations.
        assert!(client.validate().is_ok());
        assert!(cloned.validate().is_ok());
    }

    // ----------------------------------------------------------------------
    // Plan parsing helpers
    // ----------------------------------------------------------------------

    #[test]
    fn test_parse_plan_operations_empty() {
        let result = serde_json::json!({});
        let ops = Connection::parse_plan_operations(&result);
        assert!(ops.is_empty());
    }

    #[test]
    fn test_parse_plan_operations_array() {
        let result = serde_json::json!({
            "operations": [
                {"type": "NodeScan", "description": "Scan nodes", "estimated_rows": 100},
                {"type": "Filter", "description": "Apply filter", "estimated_rows": 50}
            ]
        });
        let ops = Connection::parse_plan_operations(&result);
        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0].op_type, "NodeScan");
        assert_eq!(ops[1].op_type, "Filter");
    }

    #[test]
    fn test_parse_plan_operations_single_plan() {
        let result = serde_json::json!({
            "plan": {"op_type": "FullScan", "desc": "Full table scan"}
        });
        let ops = Connection::parse_plan_operations(&result);
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].op_type, "FullScan");
        assert_eq!(ops[0].description, "Full table scan");
    }

    #[test]
    fn test_parse_single_operation() {
        let op_json = serde_json::json!({
            "type": "IndexScan",
            "description": "Use index on Person(name)",
            "estimated_rows": 25,
            "children": [
                {"type": "Filter", "description": "Filter results"}
            ]
        });
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "IndexScan");
        assert_eq!(op.description, "Use index on Person(name)");
        assert_eq!(op.estimated_rows, Some(25));
        assert_eq!(op.children.len(), 1);
        assert_eq!(op.children[0].op_type, "Filter");
    }

    #[test]
    fn test_parse_single_operation_minimal() {
        let op_json = serde_json::json!({});
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "Unknown");
        assert_eq!(op.description, "");
        assert_eq!(op.estimated_rows, None);
        assert!(op.children.is_empty());
    }

    #[test]
    fn test_parse_single_operation_alt_fields() {
        let op_json = serde_json::json!({
            "op_type": "Sort",
            "desc": "Sort by name ASC",
            "rows": 100
        });
        let op = Connection::parse_single_operation(&op_json);
        assert_eq!(op.op_type, "Sort");
        assert_eq!(op.description, "Sort by name ASC");
        assert_eq!(op.estimated_rows, Some(100));
    }

    // ----------------------------------------------------------------------
    // DSN redaction
    // ----------------------------------------------------------------------

    #[test]
    fn test_redact_dsn_url_with_password() {
        let dsn = "quic://admin:secret123@localhost:3141";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        assert!(redacted.contains("admin"));
        assert!(redacted.contains("localhost"));
    }

    #[test]
    fn test_redact_dsn_url_without_password() {
        let dsn = "quic://admin@localhost:3141";
        let redacted = redact_dsn(dsn);
        assert!(!redacted.contains("[REDACTED]"));
        assert!(redacted.contains("admin"));
        assert!(redacted.contains("localhost"));
    }

    #[test]
    fn test_redact_dsn_url_no_auth() {
        let dsn = "quic://localhost:3141";
        let redacted = redact_dsn(dsn);
        assert_eq!(redacted, dsn);
    }

    #[test]
    fn test_redact_dsn_query_param_password() {
        let dsn = "localhost:3141?username=admin&password=secret123";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        assert!(redacted.contains("username=admin"));
    }

    #[test]
    fn test_redact_dsn_query_param_pass() {
        let dsn = "localhost:3141?user=admin&pass=mysecret";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("mysecret"));
    }

    #[test]
    fn test_redact_dsn_simple_no_password() {
        let dsn = "localhost:3141?insecure=true";
        let redacted = redact_dsn(dsn);
        assert_eq!(redacted, dsn);
    }

    #[test]
    fn test_redact_dsn_url_with_query_and_password() {
        let dsn = "quic://user:pass@localhost:3141?insecure=true";
        let redacted = redact_dsn(dsn);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains(":pass@"));
        assert!(redacted.contains("insecure=true"));
    }

    // ----------------------------------------------------------------------
    // Client validation
    // ----------------------------------------------------------------------

    #[test]
    fn test_client_validate_valid() {
        let client = Client::new("localhost", 3141);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_valid_hostname() {
        let client = Client::new("geode.example.com", 3141);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_valid_ipv4() {
        let client = Client::new("192.168.1.1", 8443);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_validate_invalid_hostname_hyphen_start() {
        let client = Client::new("-invalid", 3141);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_hostname_hyphen_end() {
        let client = Client::new("invalid-", 3141);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_port_zero() {
        let client = Client::new("localhost", 0);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_page_size_zero() {
        let client = Client::new("localhost", 3141).page_size(0);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_invalid_page_size_too_large() {
        let client = Client::new("localhost", 3141).page_size(200_000);
        assert!(client.validate().is_err());
    }

    #[test]
    fn test_client_validate_with_all_options() {
        let client = Client::new("geode.example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .username("admin")
            .password("secret")
            .connect_timeout(15)
            .hello_timeout(10)
            .idle_timeout(60);
        assert!(client.validate().is_ok());
    }

    #[test]
    fn test_client_extreme_timeout_values() {
        // The builder must accept `u64::MAX` timeouts without panicking
        // (for example, on an overflowing conversion).
        let _client = Client::new("localhost", 3141)
            .connect_timeout(u64::MAX)
            .hello_timeout(u64::MAX)
            .idle_timeout(u64::MAX);
    }

    // ----------------------------------------------------------------------
    // Proto value conversion
    // ----------------------------------------------------------------------

    #[test]
    fn test_convert_edge_uses_type_field() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 1,
            to_id: 2,
            label: "KNOWS".to_string(),
            properties: vec![],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("type").unwrap().as_string().unwrap(), "KNOWS");
        assert!(
            obj.get("label").is_none(),
            "edge should not have 'label' field"
        );
    }

    #[test]
    fn test_convert_edge_uses_start_end_node() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 42,
            to_id: 99,
            label: "LIKES".to_string(),
            properties: vec![],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("start_node").unwrap().as_int().unwrap(), 42);
        assert_eq!(obj.get("end_node").unwrap().as_int().unwrap(), 99);
        assert!(obj.get("from_id").is_none());
        assert!(obj.get("to_id").is_none());
    }

    #[test]
    fn test_convert_edge_with_properties() {
        let edge = proto::EdgeValue {
            id: 100,
            from_id: 1,
            to_id: 2,
            label: "KNOWS".to_string(),
            properties: vec![proto::MapEntry {
                key: "since".to_string(),
                value: Some(proto::Value {
                    kind: Some(proto::value::Kind::IntVal(proto::IntValue {
                        value: 2020,
                        kind: 1,
                    })),
                }),
            }],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(edge)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        let props = obj.get("properties").unwrap().as_object().unwrap();
        assert_eq!(props.get("since").unwrap().as_int().unwrap(), 2020);
    }

    #[test]
    fn test_convert_node_fields() {
        let node = proto::NodeValue {
            id: 42,
            labels: vec!["Person".to_string()],
            properties: vec![proto::MapEntry {
                key: "name".to_string(),
                value: Some(proto::Value {
                    kind: Some(proto::value::Kind::StringVal(proto::StringValue {
                        value: "Alice".to_string(),
                        kind: 1,
                    })),
                }),
            }],
        };
        let proto_val = proto::Value {
            kind: Some(proto::value::Kind::NodeVal(node)),
        };
        let val = Connection::convert_proto_value_static(&proto_val);
        let obj = val.as_object().unwrap();
        assert_eq!(obj.get("id").unwrap().as_int().unwrap(), 42);
        let labels = obj.get("labels").unwrap().as_array().unwrap();
        assert_eq!(labels.len(), 1);
        let props = obj.get("properties").unwrap().as_object().unwrap();
        assert_eq!(props.get("name").unwrap().as_string().unwrap(), "Alice");
    }
}