use log::{debug, trace, warn};
use quinn::{ClientConfig, Endpoint};
use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use tokio::time::{Duration, timeout};
use crate::dsn::{Dsn, Transport};
use crate::error::{Error, Result};
use crate::proto;
use crate::types::Value;
use crate::validate;
/// ALPN protocol identifier placed in the TLS client config's `alpn_protocols`
/// when a QUIC connection is set up.
const GEODE_ALPN: &[u8] = b"geode/1";
/// Returns a copy of `dsn` that is safe to log: credentials are replaced by
/// `[REDACTED]`.
///
/// Two places are masked:
///
/// * the password in `scheme://user:password@host...` userinfo, and
/// * the value of every `password=` / `pass=` parameter (matched
///   ASCII-case-insensitively), up to the next `&` or the end of the string.
///
/// The userinfo delimiter is searched for only in the part of the DSN that
/// precedes the query (`?`) or fragment (`#`), so an `@` inside a query value
/// (for example an e-mail address) is not mistaken for it. The *last* `@` in
/// that part is used, so a password that itself contains an unescaped `@` is
/// masked in full.
#[allow(dead_code)] // Diagnostic helper; no caller is visible in this file.
fn redact_dsn(dsn: &str) -> String {
    const MASK: &str = "[REDACTED]";
    let mut result = dsn.to_string();

    // --- userinfo password -------------------------------------------------
    if let Some(scheme_end) = result.find("://") {
        let after_scheme = scheme_end + 3;
        let section_end = result[after_scheme..]
            .find(&['?', '#'][..])
            .map_or(result.len(), |i| after_scheme + i);
        let section = &result[after_scheme..section_end];
        if let Some(at_pos) = section.rfind('@') {
            if let Some(colon_pos) = section[..at_pos].find(':') {
                let secret = (after_scheme + colon_pos + 1)..(after_scheme + at_pos);
                result.replace_range(secret, MASK);
            }
        }
    }

    // --- password-like query parameters -------------------------------------
    for pattern in ["password=", "pass="] {
        let mut search_from = 0;
        loop {
            // ASCII-only lowering keeps byte offsets identical to `result`,
            // so indices found here are valid slice bounds in `result`.
            let lower = result.to_ascii_lowercase();
            let found = match lower[search_from..].find(pattern) {
                Some(rel) => search_from + rel,
                None => break,
            };
            let value_start = found + pattern.len();
            let value_end = result[value_start..]
                .find('&')
                .map_or(result.len(), |i| value_start + i);
            result.replace_range(value_start..value_end, MASK);
            // Continue after the mask so repeated parameters are masked too.
            search_from = value_start + MASK.len();
        }
    }

    result
}
/// Describes one column of a query result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
/// Column name; used as the key under which the column's value is stored in
/// each row map.
pub name: String,
/// Column type string; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub col_type: String,
}
/// One page of query results.
#[derive(Debug, Clone)]
pub struct Page {
/// Column metadata for the result set.
pub columns: Vec<Column>,
/// Result rows; each row maps a column name to its value.
pub rows: Vec<HashMap<String, Value>>,
/// `ordered` flag copied from the server's page data.
pub ordered: bool,
/// `order_keys` copied from the server's page data.
pub order_keys: Vec<String>,
/// Whether this is the last page of the result set.
pub final_page: bool,
}
/// A named savepoint. No code visible in this part of the file creates or
/// consumes it.
#[derive(Debug, Clone)]
pub struct Savepoint {
/// Savepoint name.
pub name: String,
}
/// A query string together with the names of the `$name` placeholders found
/// in it.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
// The query text exactly as supplied to `new`.
query: String,
// Distinct placeholder names, in order of first appearance in `query`.
param_names: Vec<String>,
}
impl PreparedStatement {
/// Creates a prepared statement from `query`, recording the distinct
/// `$name` placeholders that occur in it.
pub fn new(query: impl Into<String>) -> Self {
    let text: String = query.into();
    Self {
        param_names: Self::extract_param_names(&text),
        query: text,
    }
}
/// Collects the distinct placeholder names in `query`, in order of first
/// appearance.
///
/// A placeholder is a `$` followed by one or more ASCII alphanumerics or
/// underscores. A `$` not followed by such a character is ignored. The scan
/// is purely textual: it does not recognise string literals or comments.
fn extract_param_names(query: &str) -> Vec<String> {
    let mut seen: Vec<String> = Vec::new();
    // Every piece after the first starts right behind a `$`.
    for segment in query.split('$').skip(1) {
        let ident: String = segment
            .chars()
            .take_while(|ch| ch.is_ascii_alphanumeric() || *ch == '_')
            .collect();
        if ident.is_empty() || seen.contains(&ident) {
            continue;
        }
        seen.push(ident);
    }
    seen
}
/// Returns the query text this statement was created with.
pub fn query(&self) -> &str {
&self.query
}
/// Returns the distinct placeholder names extracted from the query text.
pub fn param_names(&self) -> &[String] {
&self.param_names
}
/// Runs the statement on `conn` with the given parameter values.
///
/// # Errors
///
/// Returns a validation error naming the first placeholder of the statement
/// that has no entry in `params`; otherwise returns whatever
/// `Connection::query_with_params` returns.
pub async fn execute(
    &self,
    conn: &mut Connection,
    params: &HashMap<String, crate::types::Value>,
) -> crate::error::Result<(Page, Option<String>)> {
    let first_missing = self
        .param_names
        .iter()
        .find(|name| !params.contains_key(*name));
    if let Some(missing) = first_missing {
        return Err(crate::error::Error::validation(format!(
            "Missing required parameter: {}",
            missing
        )));
    }
    conn.query_with_params(&self.query, params).await
}
}
/// One node of a query plan tree.
#[derive(Debug, Clone)]
pub struct PlanOperation {
/// Operation type name.
pub op_type: String,
/// Free-text description of the operation.
pub description: String,
/// Estimated row count for this operation, when one is available.
pub estimated_rows: Option<u64>,
/// Child operations of this node.
pub children: Vec<PlanOperation>,
}
/// A query plan as returned by `Connection::explain`.
#[derive(Debug, Clone)]
pub struct QueryPlan {
/// Top-level plan operations.
pub operations: Vec<PlanOperation>,
/// Estimated total row count.
pub estimated_rows: u64,
/// Raw JSON representation of the plan.
pub raw: serde_json::Value,
}
/// Profiling information as returned by `Connection::profile`.
#[derive(Debug, Clone)]
pub struct QueryProfile {
/// The plan associated with the profiled query.
pub plan: QueryPlan,
/// Number of rows actually produced.
pub actual_rows: u64,
/// Execution time in milliseconds.
pub execution_time_ms: f64,
/// Raw JSON representation of the profile.
pub raw: serde_json::Value,
}
/// Connection settings for a server; `connect` turns them into a live
/// `Connection`. Configured through the builder-style methods below.
#[derive(Clone)]
pub struct Client {
// Transport used by `connect` (QUIC or gRPC).
transport: Transport,
// Server host name or address.
host: String,
// Server port.
port: u16,
// Only forwarded to the gRPC transport; the QUIC path does not read it.
tls_enabled: bool,
// When true, the server certificate is not verified.
skip_verify: bool,
// Page size stored on the resulting connection.
page_size: usize,
// Client name / version / conformance level passed to the QUIC connect path.
hello_name: String,
hello_ver: String,
conformance: String,
// Optional credentials; the password is kept in a `SecretString`.
username: Option<String>,
password: Option<SecretString>,
// Timeouts in seconds; the builder setters clamp them to at least 1.
connect_timeout_secs: u64,
hello_timeout_secs: u64,
idle_timeout_secs: u64,
}
impl Client {
/// Creates a client for `host:port` with the defaults: QUIC transport, TLS
/// enabled with certificate verification, page size 1000, no credentials,
/// and 10 s / 5 s / 30 s connect / hello / idle timeouts.
pub fn new(host: impl Into<String>, port: u16) -> Self {
    let host: String = host.into();
    let hello_name = String::from("geode-rust");
    let hello_ver = String::from(env!("CARGO_PKG_VERSION"));
    let conformance = String::from("min");
    Self {
        transport: Transport::Quic,
        host,
        port,
        tls_enabled: true,
        skip_verify: false,
        page_size: 1000,
        hello_name,
        hello_ver,
        conformance,
        username: None,
        password: None,
        connect_timeout_secs: 10,
        hello_timeout_secs: 5,
        idle_timeout_secs: 30,
    }
}
/// Builds a client from a DSN string.
///
/// All settings except the timeouts are taken from the parsed DSN; the
/// timeouts use the same defaults as `new` (10 s / 5 s / 30 s).
///
/// # Errors
///
/// Returns the error produced by `Dsn::parse` when the DSN is invalid.
pub fn from_dsn(dsn_str: &str) -> Result<Self> {
    let parsed = Dsn::parse(dsn_str)?;
    let username = parsed.username().map(String::from);
    let password = parsed
        .password()
        .map(|plain| SecretString::from(plain.to_string()));
    let client = Self {
        transport: parsed.transport(),
        host: parsed.host().to_string(),
        port: parsed.port(),
        tls_enabled: parsed.tls_enabled(),
        skip_verify: parsed.skip_verify(),
        page_size: parsed.page_size(),
        hello_name: parsed.client_name().to_string(),
        hello_ver: parsed.client_version().to_string(),
        conformance: parsed.conformance().to_string(),
        username,
        password,
        connect_timeout_secs: 10,
        hello_timeout_secs: 5,
        idle_timeout_secs: 30,
    };
    Ok(client)
}
/// Returns the transport this client will use when connecting.
pub fn transport(&self) -> Transport {
self.transport
}
pub fn skip_verify(mut self, skip: bool) -> Self {
self.skip_verify = skip;
self
}
pub fn page_size(mut self, size: usize) -> Self {
self.page_size = size;
self
}
pub fn client_name(mut self, name: impl Into<String>) -> Self {
self.hello_name = name.into();
self
}
pub fn client_version(mut self, version: impl Into<String>) -> Self {
self.hello_ver = version.into();
self
}
pub fn conformance(mut self, level: impl Into<String>) -> Self {
self.conformance = level.into();
self
}
pub fn username(mut self, username: impl Into<String>) -> Self {
self.username = Some(username.into());
self
}
pub fn password(mut self, password: impl Into<String>) -> Self {
self.password = Some(SecretString::from(password.into()));
self
}
pub fn connect_timeout(mut self, seconds: u64) -> Self {
self.connect_timeout_secs = seconds.max(1);
self
}
pub fn hello_timeout(mut self, seconds: u64) -> Self {
self.hello_timeout_secs = seconds.max(1);
self
}
pub fn idle_timeout(mut self, seconds: u64) -> Self {
self.idle_timeout_secs = seconds.max(1);
self
}
pub fn validate(&self) -> Result<()> {
validate::hostname(&self.host)?;
validate::port(self.port)?;
validate::page_size(self.page_size)?;
Ok(())
}
/// Validates the settings and opens a connection over the configured
/// transport.
///
/// # Errors
///
/// Returns a validation error when `validate` fails, the transport's
/// connection error otherwise. When the transport is gRPC but the crate was
/// built without the `grpc` feature, a connection error is returned.
pub async fn connect(&self) -> Result<Connection> {
self.validate()?;
// Borrow the plain-text password only for the duration of the connect call.
let password_ref = self.password.as_ref().map(|s| s.expose_secret());
match self.transport {
Transport::Quic => {
// Note: `tls_enabled` is not passed here; the QUIC path takes no such flag.
Connection::new_quic(
&self.host,
self.port,
self.skip_verify,
self.page_size,
&self.hello_name,
&self.hello_ver,
&self.conformance,
self.username.as_deref(),
password_ref,
self.connect_timeout_secs,
self.hello_timeout_secs,
self.idle_timeout_secs,
)
.await
}
Transport::Grpc => {
#[cfg(feature = "grpc")]
{
// Note: the hello fields and the timeouts are not forwarded to gRPC.
Connection::new_grpc(
&self.host,
self.port,
self.tls_enabled,
self.skip_verify,
self.page_size,
self.username.as_deref(),
password_ref,
)
.await
}
#[cfg(not(feature = "grpc"))]
{
Err(Error::connection(
"gRPC transport requires the 'grpc' feature to be enabled",
))
}
}
}
}
}
/// Transport-specific state of a `Connection`.
#[allow(dead_code)]
enum ConnectionKind {
/// QUIC transport: the connection plus the single bidirectional stream
/// opened during the handshake.
Quic {
// The underlying QUIC connection (used for close / health checks).
conn: quinn::Connection,
// Send and receive halves of the stream opened in the handshake.
send: quinn::SendStream,
recv: quinn::RecvStream,
// Initialised empty; not read by the code visible in this file.
buffer: Vec<u8>,
// Initialised to 1; not read by the code visible in this file.
next_request_id: u64,
// Session id returned by the server's HELLO response.
session_id: String,
},
/// gRPC transport (only with the `grpc` feature).
#[cfg(feature = "grpc")]
Grpc { client: crate::grpc::GrpcClient },
}
/// An open connection to the server over QUIC or gRPC.
pub struct Connection {
// Transport-specific state.
kind: ConnectionKind,
// Page size supplied at connect time; not read by the code visible in this
// file (hence the lint allowance).
#[allow(dead_code)]
page_size: usize,
}
impl Connection {
/// Opens a QUIC connection, making up to three attempts with a 150 ms pause
/// between them.
///
/// Every failure is retried regardless of its cause; the error of the last
/// attempt is returned when all attempts fail.
#[allow(clippy::too_many_arguments)]
async fn new_quic(
    host: &str,
    port: u16,
    skip_verify: bool,
    page_size: usize,
    hello_name: &str,
    hello_ver: &str,
    conformance: &str,
    username: Option<&str>,
    password: Option<&str>,
    connect_timeout_secs: u64,
    hello_timeout_secs: u64,
    idle_timeout_secs: u64,
) -> Result<Self> {
    const MAX_ATTEMPTS: u32 = 3;
    let mut attempt: u32 = 1;
    loop {
        let outcome = Self::connect_quic_once(
            host,
            port,
            skip_verify,
            page_size,
            hello_name,
            hello_ver,
            conformance,
            username,
            password,
            connect_timeout_secs,
            hello_timeout_secs,
            idle_timeout_secs,
        )
        .await;
        match outcome {
            Ok(connection) => return Ok(connection),
            // Out of attempts: surface the most recent failure.
            Err(err) if attempt >= MAX_ATTEMPTS => return Err(err),
            Err(_) => {
                debug!("Connection attempt {} failed, retrying...", attempt);
                tokio::time::sleep(Duration::from_millis(150)).await;
                attempt += 1;
            }
        }
    }
}
/// Opens a gRPC connection by assembling a `grpc://` DSN from the arguments
/// and handing it to `GrpcClient::connect`.
///
/// Credentials are included only when both `username` and `password` are
/// present.
// NOTE(review): user and password are interpolated into the DSN without any
// escaping; characters such as '@', ':', '/' or '?' in them may change how
// `Dsn::parse` splits the string — confirm against the DSN parser.
#[cfg(feature = "grpc")]
#[allow(clippy::too_many_arguments)]
async fn new_grpc(
    host: &str,
    port: u16,
    tls_enabled: bool,
    skip_verify: bool,
    page_size: usize,
    username: Option<&str>,
    password: Option<&str>,
) -> Result<Self> {
    use crate::dsn::Dsn;
    let tls_val = match tls_enabled {
        true => "1",
        false => "0",
    };
    let dsn_str = match (username, password) {
        (Some(user), Some(pass)) => format!(
            "grpc://{}:{}@{}:{}?tls={}&insecure={}",
            user, pass, host, port, tls_val, skip_verify
        ),
        _ => format!(
            "grpc://{}:{}?tls={}&insecure={}",
            host, port, tls_val, skip_verify
        ),
    };
    let parsed = Dsn::parse(&dsn_str)?;
    let client = crate::grpc::GrpcClient::connect(&parsed).await?;
    let kind = ConnectionKind::Grpc { client };
    Ok(Self { kind, page_size })
}
#[allow(clippy::too_many_arguments)]
async fn connect_quic_once(
host: &str,
port: u16,
skip_verify: bool,
page_size: usize,
_hello_name: &str,
_hello_ver: &str,
_conformance: &str,
username: Option<&str>,
password: Option<&str>,
connect_timeout_secs: u64,
_hello_timeout_secs: u64,
idle_timeout_secs: u64,
) -> Result<Self> {
debug!("Creating connection to {}:{}", host, port);
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let mut client_crypto = if skip_verify {
warn!(
"TLS certificate verification DISABLED - connection to {}:{} is vulnerable to MITM attacks. \
Do NOT use skip_verify in production!",
host, port
);
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.dangerous()
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
.with_no_client_auth()
} else {
let mut root_store = rustls::RootCertStore::empty();
let cert_result = rustls_native_certs::load_native_certs();
for err in &cert_result.errors {
warn!("Error loading native certificate: {:?}", err);
}
let mut certs_loaded = 0;
let mut certs_failed = 0;
for cert in cert_result.certs {
match root_store.add(cert) {
Ok(()) => certs_loaded += 1,
Err(_) => certs_failed += 1,
}
}
if certs_loaded == 0 {
return Err(Error::tls(
"No system root certificates found. TLS verification cannot proceed. \
Either install system CA certificates or use skip_verify(true) for development only.",
));
}
debug!(
"Loaded {} system root certificates ({} failed to parse)",
certs_loaded, certs_failed
);
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_root_certificates(root_store)
.with_no_client_auth()
};
client_crypto.alpn_protocols = vec![GEODE_ALPN.to_vec()];
let mut client_config = ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)
.map_err(|e| Error::connection(format!("Failed to create QUIC config: {}", e)))?,
));
let mut transport = quinn::TransportConfig::default();
let idle_timeout = Duration::from_secs(idle_timeout_secs.min(146_000 * 365 * 24 * 3600));
transport.max_idle_timeout(Some(idle_timeout.try_into().map_err(|_| {
Error::connection("Idle timeout value too large for QUIC protocol")
})?));
transport.keep_alive_interval(Some(Duration::from_secs(5)));
client_config.transport_config(Arc::new(transport));
let mut endpoint = Endpoint::client(
"0.0.0.0:0"
.parse()
.expect("0.0.0.0:0 is a valid socket address"),
)
.map_err(|e| Error::connection(format!("Failed to create endpoint: {}", e)))?;
endpoint.set_default_client_config(client_config);
let mut resolved_addrs = format!("{}:{}", host, port)
.to_socket_addrs()
.map_err(|e| {
Error::connection(format!(
"Failed to resolve address {}:{} - {}",
host, port, e
))
})?;
let server_addr: SocketAddr = resolved_addrs
.find(|addr| matches!(addr, SocketAddr::V4(_) | SocketAddr::V6(_)))
.ok_or_else(|| Error::connection("Invalid address: could not resolve host"))?;
debug!("Connecting to {}", server_addr);
let server_name = if skip_verify {
"localhost" } else {
host
};
trace!("Using server name for SNI: {}", server_name);
let conn = timeout(
Duration::from_secs(connect_timeout_secs),
endpoint
.connect(server_addr, server_name)
.map_err(|e| Error::connection(format!("Connection failed: {}", e)))?,
)
.await
.map_err(|_| Error::connection("Connection timeout"))?
.map_err(|e| Error::connection(format!("Failed to establish connection: {}", e)))?;
debug!("Connection established to {}:{}", host, port);
let (mut send, mut recv) = conn
.open_bi()
.await
.map_err(|e| Error::connection(format!("Failed to open stream: {}", e)))?;
let hello_req = proto::HelloRequest {
username: username.unwrap_or("").to_string(),
password: password.unwrap_or("").to_string(),
tenant_id: None,
client_name: String::new(),
client_version: String::new(),
wanted_conformance: String::new(),
};
let msg = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Hello(hello_req)),
};
let data = proto::encode_with_length_prefix(&msg);
send.write_all(&data)
.await
.map_err(|e| Error::connection(format!("Failed to send HELLO: {}", e)))?;
let mut length_buf = [0u8; 4];
timeout(Duration::from_secs(5), recv.read_exact(&mut length_buf))
.await
.map_err(|_| Error::connection("HELLO response timeout"))?
.map_err(|e| {
Error::connection(format!("Failed to read HELLO response length: {}", e))
})?;
let msg_len = u32::from_be_bytes(length_buf) as usize;
let mut msg_buf = vec![0u8; msg_len];
recv.read_exact(&mut msg_buf)
.await
.map_err(|e| Error::connection(format!("Failed to read HELLO response body: {}", e)))?;
let hello_response = proto::decode_quic_server_message(&msg_buf)?;
let session_id = match hello_response.msg {
Some(proto::quic_server_message::Msg::Hello(ref hello_resp)) => {
if !hello_resp.success {
return Err(Error::connection(format!(
"Authentication failed: {}",
hello_resp.error_message
)));
}
hello_resp.session_id.clone()
}
_ => {
return Err(Error::connection("Expected HELLO response"));
}
};
debug!("HELLO handshake complete, session_id={}", session_id);
Ok(Self {
kind: ConnectionKind::Quic {
conn,
send,
recv,
buffer: Vec::new(),
next_request_id: 1,
session_id,
},
page_size,
})
}
async fn send_proto_quic(
send: &mut quinn::SendStream,
msg: &proto::QuicClientMessage,
) -> Result<()> {
let data = proto::encode_with_length_prefix(msg);
send.write_all(&data)
.await
.map_err(|e| Error::connection(format!("Failed to send message: {}", e)))?;
Ok(())
}
async fn read_proto_quic(
recv: &mut quinn::RecvStream,
timeout_secs: u64,
) -> Result<proto::QuicServerMessage> {
timeout(Duration::from_secs(timeout_secs), async {
let mut length_buf = [0u8; 4];
recv.read_exact(&mut length_buf)
.await
.map_err(|e| Error::connection(format!("Failed to read response length: {}", e)))?;
let msg_len = u32::from_be_bytes(length_buf) as usize;
let mut msg_buf = vec![0u8; msg_len];
recv.read_exact(&mut msg_buf)
.await
.map_err(|e| Error::connection(format!("Failed to read response body: {}", e)))?;
proto::decode_quic_server_message(&msg_buf)
})
.await
.map_err(|_| Error::timeout())?
}
async fn try_read_proto_quic(
recv: &mut quinn::RecvStream,
) -> Result<Option<proto::QuicServerMessage>> {
let read_result = timeout(Duration::from_millis(5000), async {
let mut length_buf = [0u8; 4];
recv.read_exact(&mut length_buf)
.await
.map_err(|e| Error::connection(format!("Failed to read response: {}", e)))?;
let msg_len = u32::from_be_bytes(length_buf) as usize;
let mut msg_buf = vec![0u8; msg_len];
recv.read_exact(&mut msg_buf)
.await
.map_err(|e| Error::connection(format!("Failed to read response body: {}", e)))?;
proto::decode_quic_server_message(&msg_buf)
})
.await;
match read_result {
Ok(Ok(msg)) => Ok(Some(msg)),
Ok(Err(e)) => Err(e),
Err(_) => Ok(None), }
}
/// Converts protocol rows into maps keyed by column name.
///
/// Values are matched to columns by position; a column with no value at its
/// position in a row gets a null value, and values beyond the last column
/// are ignored. The function currently always returns `Ok`.
fn parse_proto_rows_static(
    proto_rows: &[proto::Row],
    columns: &[Column],
) -> Result<Vec<HashMap<String, Value>>> {
    let rows = proto_rows
        .iter()
        .map(|proto_row| {
            columns
                .iter()
                .enumerate()
                .map(|(idx, column)| {
                    let value = proto_row
                        .values
                        .get(idx)
                        .map_or_else(Value::null, Self::convert_proto_value_static);
                    (column.name.clone(), value)
                })
                .collect::<HashMap<String, Value>>()
        })
        .collect();
    Ok(rows)
}
/// Converts a protocol value into the client's `Value` type, recursing into
/// lists, maps, and node/edge properties.
///
/// Kinds not listed below, and a missing kind, become a null value.
fn convert_proto_value_static(proto_val: &proto::Value) -> Value {
match &proto_val.kind {
Some(proto::value::Kind::NullVal(_)) => Value::null(),
Some(proto::value::Kind::StringVal(s)) => Value::string(s.value.clone()),
Some(proto::value::Kind::IntVal(i)) => Value::int(i.value),
// Doubles are stored as decimals; when the conversion yields `None` the
// default decimal is used instead.
// NOTE(review): presumably NaN / infinities hit that fallback — confirm.
Some(proto::value::Kind::DoubleVal(d)) => {
Value::decimal(rust_decimal::Decimal::from_f64_retain(d.value).unwrap_or_default())
}
Some(proto::value::Kind::BoolVal(b)) => Value::bool(*b),
Some(proto::value::Kind::ListVal(list)) => {
let values: Vec<Value> = list
.values
.iter()
.map(Self::convert_proto_value_static)
.collect();
Value::array(values)
}
// Map entries without a value become null.
Some(proto::value::Kind::MapVal(map)) => {
let mut obj = std::collections::HashMap::new();
for entry in &map.entries {
let val = entry
.value
.as_ref()
.map(Self::convert_proto_value_static)
.unwrap_or_else(Value::null);
obj.insert(entry.key.clone(), val);
}
Value::object(obj)
}
// A node becomes an object with the keys "id", "labels" and "properties".
Some(proto::value::Kind::NodeVal(node)) => {
let mut obj = std::collections::HashMap::new();
// NOTE(review): `as i64` on the id; if the id is unsigned, large values
// would wrap — confirm the proto field type.
obj.insert("id".to_string(), Value::int(node.id as i64));
let labels: Vec<Value> = node
.labels
.iter()
.map(|l| Value::string(l.clone()))
.collect();
obj.insert("labels".to_string(), Value::array(labels));
let mut props = std::collections::HashMap::new();
for entry in &node.properties {
let val = entry
.value
.as_ref()
.map(Self::convert_proto_value_static)
.unwrap_or_else(Value::null);
props.insert(entry.key.clone(), val);
}
obj.insert("properties".to_string(), Value::object(props));
Value::object(obj)
}
// An edge becomes an object with the keys "id", "start_node", "end_node",
// "type" and "properties".
Some(proto::value::Kind::EdgeVal(edge)) => {
let mut obj = std::collections::HashMap::new();
obj.insert("id".to_string(), Value::int(edge.id as i64));
obj.insert("start_node".to_string(), Value::int(edge.from_id as i64));
obj.insert("end_node".to_string(), Value::int(edge.to_id as i64));
obj.insert("type".to_string(), Value::string(edge.label.clone()));
let mut props = std::collections::HashMap::new();
for entry in &edge.properties {
let val = entry
.value
.as_ref()
.map(Self::convert_proto_value_static)
.unwrap_or_else(Value::null);
props.insert(entry.key.clone(), val);
}
obj.insert("properties".to_string(), Value::object(props));
Value::object(obj)
}
// The decimal is parsed from `coeff`; if that fails, the original textual
// representation is returned as a string value instead.
Some(proto::value::Kind::DecimalVal(d)) => {
if let Ok(dec) = d.coeff.parse::<rust_decimal::Decimal>() {
Value::decimal(dec)
} else {
Value::string(d.orig_repr.clone())
}
}
// Bytes are rendered as a string: a literal `\x` followed by the hex digits.
Some(proto::value::Kind::BytesVal(b)) => {
Value::string(format!("\\x{}", hex::encode(&b.value)))
}
_ => Value::null(),
}
}
async fn send_begin_quic(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
session_id: &str,
) -> Result<()> {
let msg = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Begin(
proto::BeginRequest {
session_id: session_id.to_string(),
..Default::default()
},
)),
};
Self::send_proto_quic(send, &msg).await?;
let resp = Self::read_proto_quic(recv, 5).await?;
if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Begin(_))) {
return Err(Error::protocol("Expected BEGIN response"));
}
Ok(())
}
async fn send_commit_quic(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
session_id: &str,
) -> Result<()> {
let msg = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Commit(
proto::CommitRequest {
session_id: session_id.to_string(),
},
)),
};
Self::send_proto_quic(send, &msg).await?;
let resp = Self::read_proto_quic(recv, 5).await?;
if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Commit(_))) {
return Err(Error::protocol("Expected COMMIT response"));
}
Ok(())
}
async fn send_rollback_quic(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
session_id: &str,
) -> Result<()> {
let msg = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Rollback(
proto::RollbackRequest {
session_id: session_id.to_string(),
},
)),
};
Self::send_proto_quic(send, &msg).await?;
let resp = Self::read_proto_quic(recv, 5).await?;
if !matches!(resp.msg, Some(proto::quic_server_message::Msg::Rollback(_))) {
return Err(Error::protocol("Expected ROLLBACK response"));
}
Ok(())
}
pub async fn query(&mut self, gql: &str) -> Result<(Page, Option<String>)> {
self.query_with_params(gql, &HashMap::new()).await
}
/// Runs `gql` with the given parameters over whichever transport this
/// connection uses and returns the result page plus an optional cursor.
///
/// # Errors
///
/// Returns whatever the transport-specific implementation returns.
pub async fn query_with_params(
&mut self,
gql: &str,
params: &HashMap<String, Value>,
) -> Result<(Page, Option<String>)> {
match &mut self.kind {
// QUIC: run the query on the stream opened during the handshake.
ConnectionKind::Quic {
send,
recv,
session_id,
..
} => Self::query_with_params_quic(send, recv, gql, params, session_id).await,
// gRPC: delegate to the gRPC client.
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { client } => client.query_with_params(gql, params).await,
}
}
async fn query_with_params_quic(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
gql: &str,
params: &HashMap<String, Value>,
session_id: &str,
) -> Result<(Page, Option<String>)> {
let (page, cursor) =
Self::query_with_params_quic_inner(send, recv, gql, params, session_id).await?;
if !page.final_page {
let mut all_rows = page.rows;
let columns = page.columns;
let mut ordered = page.ordered;
let mut order_keys = page.order_keys;
let mut request_id: u64 = 0;
loop {
request_id += 1;
let pull_req = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Pull(proto::PullRequest {
request_id,
page_size: 1000,
session_id: String::new(),
})),
};
Self::send_proto_quic(send, &pull_req).await?;
let resp = Self::read_proto_quic(recv, 30).await?;
let exec_resp = match &resp.msg {
Some(proto::quic_server_message::Msg::Pull(pull)) => pull.response.as_ref(),
Some(proto::quic_server_message::Msg::Execute(e)) => Some(e),
_ => None,
};
let exec_resp = match exec_resp {
Some(e) => e,
None => break,
};
if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload
{
return Err(Error::Query {
code: err.code.clone(),
message: err.message.clone(),
});
}
if let Some(proto::execution_response::Payload::Page(ref page_data)) =
exec_resp.payload
{
let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
all_rows.extend(rows);
ordered = page_data.ordered;
order_keys = page_data.order_keys.clone();
if page_data.r#final {
break;
}
} else {
break;
}
}
let final_page = Page {
columns,
rows: all_rows,
ordered,
order_keys,
final_page: true,
};
return Ok((final_page, cursor));
}
Ok((page, cursor))
}
/// Sends an EXECUTE request and reads the server's replies until the first
/// result page (or an error) is available.
///
/// The cursor in the returned tuple is always `None`.
///
/// # Errors
///
/// Returns `Error::Query` when the server reports an error payload, a
/// protocol error when the first reply is not an Execute response, and
/// propagates send/read/decode errors.
async fn query_with_params_quic_inner(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
gql: &str,
params: &HashMap<String, Value>,
session_id: &str,
) -> Result<(Page, Option<String>)> {
// Convert the parameter map into protocol parameters.
let params_proto: Vec<proto::Param> = params
.iter()
.map(|(k, v)| proto::Param {
name: k.clone(),
value: Some(v.to_proto_value()),
})
.collect();
let exec_req = proto::ExecuteRequest {
session_id: session_id.to_string(),
query: gql.to_string(),
params: params_proto,
};
let msg = proto::QuicClientMessage {
msg: Some(proto::quic_client_message::Msg::Execute(exec_req)),
};
// A send failure is re-wrapped as a query error carrying the same text.
Self::send_proto_quic(send, &msg)
.await
.map_err(|e| Error::query(format!("{}", e)))?;
// First reply: waited for up to 10 seconds and must be an Execute response.
let resp = Self::read_proto_quic(recv, 10).await?;
let exec_resp = match resp.msg {
Some(proto::quic_server_message::Msg::Execute(e)) => e,
_ => return Err(Error::protocol("Expected Execute response")),
};
if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
// Best-effort read of one further message before reporting the error; its
// outcome is ignored.
let _ = Self::try_read_proto_quic(recv).await;
return Err(Error::Query {
code: err.code.clone(),
message: err.message.clone(),
});
}
// Column metadata comes from a Schema payload; any other payload yields no
// columns.
let columns: Vec<Column> = match exec_resp.payload {
Some(proto::execution_response::Payload::Schema(ref s)) => s
.columns
.iter()
.map(|c| Column {
name: c.name.clone(),
col_type: c.r#type.clone(),
})
.collect(),
_ => Vec::new(),
};
trace!("Schema columns: {:?}", columns);
// Look for a follow-up message carrying the page data.
// NOTE(review): when no follow-up arrives, this waits for the whole
// `try_read_proto_quic` timeout before continuing — confirm that is intended.
if let Some(inline_resp) = Self::try_read_proto_quic(recv).await? {
if let Some(proto::quic_server_message::Msg::Execute(inline_exec)) = inline_resp.msg {
if let Some(proto::execution_response::Payload::Error(ref err)) =
inline_exec.payload
{
return Err(Error::Query {
code: err.code.clone(),
message: err.message.clone(),
});
}
if let Some(proto::execution_response::Payload::Page(ref page_data)) =
inline_exec.payload
{
let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
let page = Page {
columns,
rows,
ordered: page_data.ordered,
order_keys: page_data.order_keys.clone(),
final_page: page_data.r#final,
};
return Ok((page, None));
}
// An Execute follow-up that is neither an error nor a page is treated as an
// empty, final result.
let page = Page {
columns,
rows: Vec::new(),
ordered: false,
order_keys: Vec::new(),
final_page: true,
};
return Ok((page, None));
}
}
// No usable follow-up: the first reply itself may have carried the page.
if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload {
let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
let page = Page {
columns,
rows,
ordered: page_data.ordered,
order_keys: page_data.order_keys.clone(),
final_page: page_data.r#final,
};
return Ok((page, None));
}
// Otherwise wait up to 30 seconds for one more message.
let resp = Self::read_proto_quic(recv, 30).await?;
if let Some(proto::quic_server_message::Msg::Execute(exec_resp)) = resp.msg {
if let Some(proto::execution_response::Payload::Error(ref err)) = exec_resp.payload {
return Err(Error::Query {
code: err.code.clone(),
message: err.message.clone(),
});
}
if let Some(proto::execution_response::Payload::Page(ref page_data)) = exec_resp.payload
{
let rows = Self::parse_proto_rows_static(&page_data.rows, &columns)?;
let page = Page {
columns,
rows,
ordered: page_data.ordered,
order_keys: page_data.order_keys.clone(),
final_page: page_data.r#final,
};
return Ok((page, None));
}
}
// Nothing carried a page: return an empty, final result.
let page = Page {
columns,
rows: Vec::new(),
ordered: false,
order_keys: Vec::new(),
final_page: true,
};
Ok((page, None))
}
pub fn query_sync(
&mut self,
gql: &str,
params: Option<HashMap<String, serde_json::Value>>,
) -> Result<Page> {
let params_map = params.unwrap_or_default();
let params_typed: HashMap<String, Value> = params_map
.into_iter()
.map(|(k, v)| {
let typed_val = crate::types::Value::from_json(v);
(k, typed_val)
})
.collect();
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
let (page, _cursor) =
handle.block_on(self.query_with_params(gql, ¶ms_typed))?;
Ok(page)
}
Err(_) => {
let rt = tokio::runtime::Runtime::new()
.map_err(|e| Error::query(format!("Failed to create runtime: {}", e)))?;
let (page, _cursor) = rt.block_on(self.query_with_params(gql, ¶ms_typed))?;
Ok(page)
}
}
}
/// Begins a transaction on this connection.
///
/// # Errors
///
/// Returns whatever the transport-specific implementation returns.
pub async fn begin(&mut self) -> Result<()> {
match &mut self.kind {
ConnectionKind::Quic {
send,
recv,
session_id,
..
} => Self::send_begin_quic(send, recv, session_id).await,
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { client } => client.begin().await,
}
}
/// Commits the current transaction on this connection.
///
/// # Errors
///
/// Returns whatever the transport-specific implementation returns.
pub async fn commit(&mut self) -> Result<()> {
match &mut self.kind {
ConnectionKind::Quic {
send,
recv,
session_id,
..
} => Self::send_commit_quic(send, recv, session_id).await,
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { client } => client.commit().await,
}
}
/// Rolls back the current transaction on this connection.
///
/// # Errors
///
/// Returns whatever the transport-specific implementation returns.
pub async fn rollback(&mut self) -> Result<()> {
match &mut self.kind {
ConnectionKind::Quic {
send,
recv,
session_id,
..
} => Self::send_rollback_quic(send, recv, session_id).await,
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { client } => client.rollback().await,
}
}
/// Creates a `PreparedStatement` for `query`.
///
/// This is purely client-side: nothing is sent to the server, and the call
/// currently always returns `Ok`.
pub fn prepare(&self, query: &str) -> Result<PreparedStatement> {
Ok(PreparedStatement::new(query))
}
/// Runs `EXPLAIN <gql>` on the server.
///
/// The server's reply is currently discarded: the returned plan has no
/// operations, zero estimated rows and an empty JSON object as `raw`.
///
/// # Errors
///
/// Returns whatever `query` returns for the EXPLAIN statement.
pub async fn explain(&mut self, gql: &str) -> Result<QueryPlan> {
    let statement = format!("EXPLAIN {}", gql);
    self.query(&statement).await?;
    let empty_plan = QueryPlan {
        operations: Vec::new(),
        estimated_rows: 0,
        raw: serde_json::json!({}),
    };
    Ok(empty_plan)
}
/// Runs `PROFILE <gql>` on the server.
///
/// Only the number of returned rows is taken from the reply; the plan is
/// empty, the execution time is reported as `0.0` and `raw` is an empty JSON
/// object.
///
/// # Errors
///
/// Returns whatever `query` returns for the PROFILE statement.
pub async fn profile(&mut self, gql: &str) -> Result<QueryProfile> {
    let statement = format!("PROFILE {}", gql);
    let (page, _cursor) = self.query(&statement).await?;
    let row_count = page.rows.len() as u64;
    Ok(QueryProfile {
        plan: QueryPlan {
            operations: Vec::new(),
            estimated_rows: 0,
            raw: serde_json::json!({}),
        },
        actual_rows: row_count,
        execution_time_ms: 0.0,
        raw: serde_json::json!({}),
    })
}
/// Runs the given queries one after another and returns one page per query,
/// in input order.
///
/// # Errors
///
/// Stops at the first failing query and returns its error; pages of earlier
/// queries are dropped.
pub async fn batch(
    &mut self,
    queries: &[(&str, Option<&HashMap<String, Value>>)],
) -> Result<Vec<Page>> {
    let mut pages = Vec::with_capacity(queries.len());
    for &(gql, maybe_params) in queries {
        let outcome = if let Some(bound) = maybe_params {
            self.query_with_params(gql, bound).await
        } else {
            self.query(gql).await
        };
        let (page, _cursor) = outcome?;
        pages.push(page);
    }
    Ok(pages)
}
/// Extracts plan operations from a JSON plan document.
///
/// Uses the `operations` array when present; otherwise a single `plan`
/// object; otherwise returns an empty list.
#[allow(dead_code)]
fn parse_plan_operations(result: &serde_json::Value) -> Vec<PlanOperation> {
    let listed = result
        .get("operations")
        .and_then(serde_json::Value::as_array);
    if let Some(ops) = listed {
        ops.iter().map(Self::parse_single_operation).collect()
    } else if let Some(plan) = result.get("plan") {
        vec![Self::parse_single_operation(plan)]
    } else {
        Vec::new()
    }
}
/// Builds a `PlanOperation` from one JSON object, recursing into `children`.
///
/// Each field accepts two spellings, the first one present being used:
/// `type`/`op_type` (default `"Unknown"`), `description`/`desc` (default
/// empty) and `estimated_rows`/`rows` (default `None`).
#[allow(dead_code)]
fn parse_single_operation(op: &serde_json::Value) -> PlanOperation {
    // Looks up `primary`, falling back to `fallback` only when `primary` is absent.
    let lookup = |primary: &str, fallback: &str| op.get(primary).or_else(|| op.get(fallback));

    let op_type = lookup("type", "op_type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("Unknown")
        .to_string();
    let description = lookup("description", "desc")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("")
        .to_string();
    let estimated_rows = lookup("estimated_rows", "rows").and_then(serde_json::Value::as_u64);
    let children: Vec<PlanOperation> =
        match op.get("children").and_then(serde_json::Value::as_array) {
            Some(items) => items.iter().map(Self::parse_single_operation).collect(),
            None => Vec::new(),
        };

    PlanOperation {
        op_type,
        description,
        estimated_rows,
        children,
    }
}
/// Closes the connection.
///
/// For QUIC the connection is closed with error code 0 and the reason
/// `client closing`, and `Ok` is returned; for gRPC the call is delegated to
/// the gRPC client.
pub fn close(&mut self) -> Result<()> {
match &mut self.kind {
ConnectionKind::Quic { conn, .. } => {
conn.close(0u32.into(), b"client closing");
Ok(())
}
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { client } => client.close(),
}
}
/// Reports whether the connection looks usable.
///
/// For QUIC this is true while the connection has no close reason. For gRPC
/// no check is performed and the result is always `true`.
pub fn is_healthy(&self) -> bool {
match &self.kind {
ConnectionKind::Quic { conn, .. } => {
conn.close_reason().is_none()
}
#[cfg(feature = "grpc")]
ConnectionKind::Grpc { .. } => {
true
}
}
}
}
/// Certificate verifier that accepts every server certificate and every
/// handshake signature without checking anything.
///
/// Installed only when `skip_verify` is enabled. It removes all protection
/// against man-in-the-middle attacks and must not be used in production.
#[derive(Debug)]
struct SkipServerVerification;

impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer,
        _intermediates: &[CertificateDer],
        _server_name: &RustlsServerName,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes offered to the server.
    ///
    /// Since signatures are never checked, every common scheme is offered.
    /// The RSA-PSS entries matter most: TLS 1.3 servers holding an RSA
    /// certificate sign the handshake with RSA-PSS, so without them such
    /// servers cannot complete the handshake even though verification is
    /// disabled.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // PreparedStatement: placeholder extraction
    // ------------------------------------------------------------------

    /// The query text is kept verbatim and the single `$id` placeholder is reported.
    #[test]
    fn test_prepared_statement_new() {
        let text = "MATCH (n:Person {id: $id}) RETURN n";
        let prepared = PreparedStatement::new(text);
        assert_eq!(prepared.query(), text);
        assert_eq!(prepared.param_names(), &["id"]);
    }

    /// Three distinct placeholders yield three parameter names.
    #[test]
    fn test_prepared_statement_multiple_params() {
        let prepared = PreparedStatement::new(
            "MATCH (p:Person {name: $name}) WHERE p.age > $min_age AND p.city = $city RETURN p",
        );
        assert!(prepared.query().contains("$name"));

        let found = prepared.param_names();
        assert_eq!(found.len(), 3);
        for expected in ["name", "min_age", "city"] {
            assert!(found.contains(&expected.to_string()));
        }
    }

    /// A query without placeholders reports no parameter names.
    #[test]
    fn test_prepared_statement_no_params() {
        let prepared = PreparedStatement::new("MATCH (n) RETURN n LIMIT 10");
        let found = prepared.param_names();
        assert!(found.is_empty());
    }

    /// A placeholder used twice is reported once.
    #[test]
    fn test_prepared_statement_duplicate_params() {
        let text = "MATCH (a {id: $id})-[:KNOWS]->(b {id: $id}) RETURN a, b";
        let prepared = PreparedStatement::new(text);
        assert_eq!(prepared.param_names(), &["id"]);
    }

    /// Underscores are part of a placeholder name.
    #[test]
    fn test_prepared_statement_underscore_params() {
        let text = "MATCH (n {user_id: $user_id}) RETURN n";
        let prepared = PreparedStatement::new(text);
        assert_eq!(prepared.param_names(), &["user_id"]);
    }

    /// Trailing digits are part of a placeholder name.
    #[test]
    fn test_prepared_statement_numeric_params() {
        let prepared = PreparedStatement::new("RETURN $param1, $param2, $param123");
        let found = prepared.param_names();
        assert_eq!(found.len(), 3);
        for expected in ["param1", "param2", "param123"] {
            assert!(found.contains(&expected.to_string()));
        }
    }

    // ------------------------------------------------------------------
    // Plain data structs
    // ------------------------------------------------------------------

    /// Fields of a leaf `PlanOperation` read back as written.
    #[test]
    fn test_plan_operation_struct() {
        let scan = PlanOperation {
            op_type: String::from("NodeScan"),
            description: String::from("Scan Person nodes"),
            estimated_rows: Some(100),
            children: Vec::new(),
        };
        let PlanOperation {
            op_type,
            description,
            estimated_rows,
            children,
        } = &scan;
        assert_eq!(op_type, "NodeScan");
        assert_eq!(description, "Scan Person nodes");
        assert_eq!(*estimated_rows, Some(100));
        assert!(children.is_empty());
    }

    /// A nested child operation is reachable through `children`.
    #[test]
    fn test_plan_operation_with_children() {
        let projection = PlanOperation {
            op_type: String::from("Projection"),
            description: String::from("Project name, age"),
            estimated_rows: Some(50),
            children: vec![PlanOperation {
                op_type: String::from("Filter"),
                description: String::from("Filter by age"),
                estimated_rows: Some(50),
                children: Vec::new(),
            }],
        };
        assert_eq!(projection.children.len(), 1);
        assert_eq!(projection.children[0].op_type, "Filter");
    }

    /// A `QueryPlan` keeps its operations and row estimate.
    #[test]
    fn test_query_plan_struct() {
        let full_scan = PlanOperation {
            op_type: String::from("NodeScan"),
            description: String::from("Full scan"),
            estimated_rows: Some(1000),
            children: Vec::new(),
        };
        let plan = QueryPlan {
            operations: vec![full_scan],
            estimated_rows: 1000,
            raw: serde_json::json!({"type": "plan"}),
        };
        assert_eq!(plan.operations.len(), 1);
        assert_eq!(plan.estimated_rows, 1000);
    }

    /// A `QueryProfile` keeps its row count and execution time.
    #[test]
    fn test_query_profile_struct() {
        let profile = QueryProfile {
            plan: QueryPlan {
                operations: Vec::new(),
                estimated_rows: 100,
                raw: serde_json::json!({}),
            },
            actual_rows: 95,
            execution_time_ms: 12.5,
            raw: serde_json::json!({"type": "profile"}),
        };
        assert_eq!(profile.actual_rows, 95);
        // Floating-point field: compare with a small tolerance.
        let drift = (profile.execution_time_ms - 12.5).abs();
        assert!(drift < 0.001);
    }

    /// A `Page` with one column and no rows reads back as written.
    #[test]
    fn test_page_struct() {
        let only_column = Column {
            name: String::from("x"),
            col_type: String::from("INT"),
        };
        let page = Page {
            columns: vec![only_column],
            rows: Vec::new(),
            ordered: false,
            order_keys: Vec::new(),
            final_page: true,
        };
        assert_eq!(page.columns.len(), 1);
        assert!(page.rows.is_empty());
        assert!(page.final_page);
    }

    /// `Column` fields read back as written.
    #[test]
    fn test_column_struct() {
        let column = Column {
            name: String::from("age"),
            col_type: String::from("INT"),
        };
        assert_eq!(column.name, "age");
        assert_eq!(column.col_type, "INT");
    }

    /// `Savepoint` keeps its name.
    #[test]
    fn test_savepoint_struct() {
        let name = String::from("before_update");
        let savepoint = Savepoint { name };
        assert_eq!(savepoint.name, "before_update");
    }

    // ------------------------------------------------------------------
    // Client builder
    // ------------------------------------------------------------------

    /// Constructing a client with only host and port must not panic.
    #[test]
    fn test_client_builder_defaults() {
        let _built = Client::new("localhost", 3141);
    }

    /// The builder methods can be chained.
    #[test]
    fn test_client_builder_chain() {
        let _built = Client::new("example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .client_name("test-app")
            .client_version("2.0.0")
            .conformance("full");
    }

    /// A configured client can be cloned.
    #[test]
    fn test_client_clone() {
        let original = Client::new("localhost", 3141).skip_verify(true);
        let _copy = original.clone();
    }

    // ------------------------------------------------------------------
    // Plan parsing
    // ------------------------------------------------------------------

    /// An empty JSON object yields no operations.
    #[test]
    fn test_parse_plan_operations_empty() {
        let payload = serde_json::json!({});
        let parsed = Connection::parse_plan_operations(&payload);
        assert!(parsed.is_empty());
    }

    /// An `operations` array is parsed element by element, in order.
    #[test]
    fn test_parse_plan_operations_array() {
        let payload = serde_json::json!({
            "operations": [
                {"type": "NodeScan", "description": "Scan nodes", "estimated_rows": 100},
                {"type": "Filter", "description": "Apply filter", "estimated_rows": 50}
            ]
        });
        let parsed = Connection::parse_plan_operations(&payload);
        // Checks both the count and the per-position operation type.
        let kinds: Vec<&str> = parsed.iter().map(|op| op.op_type.as_str()).collect();
        assert_eq!(kinds, ["NodeScan", "Filter"]);
    }

    /// A single `plan` object is parsed into exactly one operation.
    #[test]
    fn test_parse_plan_operations_single_plan() {
        let payload = serde_json::json!({
            "plan": {"op_type": "FullScan", "desc": "Full table scan"}
        });
        let parsed = Connection::parse_plan_operations(&payload);
        assert_eq!(parsed.len(), 1);
        let only = &parsed[0];
        assert_eq!(only.op_type, "FullScan");
        assert_eq!(only.description, "Full table scan");
    }

    /// All primary field names, including nested children, are picked up.
    #[test]
    fn test_parse_single_operation() {
        let payload = serde_json::json!({
            "type": "IndexScan",
            "description": "Use index on Person(name)",
            "estimated_rows": 25,
            "children": [
                {"type": "Filter", "description": "Filter results"}
            ]
        });
        let parsed = Connection::parse_single_operation(&payload);
        assert_eq!(parsed.op_type, "IndexScan");
        assert_eq!(parsed.description, "Use index on Person(name)");
        assert_eq!(parsed.estimated_rows, Some(25));
        assert_eq!(parsed.children.len(), 1);
        assert_eq!(parsed.children[0].op_type, "Filter");
    }

    /// Missing fields fall back to "Unknown", an empty description, no estimate, no children.
    #[test]
    fn test_parse_single_operation_minimal() {
        let payload = serde_json::json!({});
        let parsed = Connection::parse_single_operation(&payload);
        assert_eq!(parsed.op_type, "Unknown");
        assert_eq!(parsed.description, "");
        assert_eq!(parsed.estimated_rows, None);
        assert!(parsed.children.is_empty());
    }

    /// The alternative keys `op_type`, `desc` and `rows` are accepted.
    #[test]
    fn test_parse_single_operation_alt_fields() {
        let payload = serde_json::json!({
            "op_type": "Sort",
            "desc": "Sort by name ASC",
            "rows": 100
        });
        let parsed = Connection::parse_single_operation(&payload);
        assert_eq!(parsed.op_type, "Sort");
        assert_eq!(parsed.description, "Sort by name ASC");
        assert_eq!(parsed.estimated_rows, Some(100));
    }

    // ------------------------------------------------------------------
    // DSN redaction
    // ------------------------------------------------------------------

    /// A `user:password@` section loses the password but keeps user and host.
    #[test]
    fn test_redact_dsn_url_with_password() {
        let redacted = redact_dsn("quic://admin:secret123@localhost:3141");
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        for kept in ["admin", "localhost"] {
            assert!(redacted.contains(kept));
        }
    }

    /// A user without a password is left untouched.
    #[test]
    fn test_redact_dsn_url_without_password() {
        let redacted = redact_dsn("quic://admin@localhost:3141");
        assert!(!redacted.contains("[REDACTED]"));
        for kept in ["admin", "localhost"] {
            assert!(redacted.contains(kept));
        }
    }

    /// A URL without credentials passes through unchanged.
    #[test]
    fn test_redact_dsn_url_no_auth() {
        let input = "quic://localhost:3141";
        assert_eq!(redact_dsn(input), input);
    }

    /// A `password=` query parameter is redacted; other parameters survive.
    #[test]
    fn test_redact_dsn_query_param_password() {
        let redacted = redact_dsn("localhost:3141?username=admin&password=secret123");
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("secret123"));
        assert!(redacted.contains("username=admin"));
    }

    /// A `pass=` query parameter is redacted.
    #[test]
    fn test_redact_dsn_query_param_pass() {
        let redacted = redact_dsn("localhost:3141?user=admin&pass=mysecret");
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("mysecret"));
    }

    /// A DSN with unrelated query parameters passes through unchanged.
    #[test]
    fn test_redact_dsn_simple_no_password() {
        let input = "localhost:3141?insecure=true";
        assert_eq!(redact_dsn(input), input);
    }

    /// URL credentials are redacted while the query string is preserved.
    #[test]
    fn test_redact_dsn_url_with_query_and_password() {
        let redacted = redact_dsn("quic://user:pass@localhost:3141?insecure=true");
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains(":pass@"));
        assert!(redacted.contains("insecure=true"));
    }

    // ------------------------------------------------------------------
    // Client::validate
    // ------------------------------------------------------------------

    /// `localhost` with a non-zero port validates.
    #[test]
    fn test_client_validate_valid() {
        let outcome = Client::new("localhost", 3141).validate();
        assert!(outcome.is_ok());
    }

    /// A dotted hostname validates.
    #[test]
    fn test_client_validate_valid_hostname() {
        let outcome = Client::new("geode.example.com", 3141).validate();
        assert!(outcome.is_ok());
    }

    /// An IPv4 literal validates.
    #[test]
    fn test_client_validate_valid_ipv4() {
        let outcome = Client::new("192.168.1.1", 8443).validate();
        assert!(outcome.is_ok());
    }

    /// A hostname starting with a hyphen is rejected.
    #[test]
    fn test_client_validate_invalid_hostname_hyphen_start() {
        let outcome = Client::new("-invalid", 3141).validate();
        assert!(outcome.is_err());
    }

    /// A hostname ending with a hyphen is rejected.
    #[test]
    fn test_client_validate_invalid_hostname_hyphen_end() {
        let outcome = Client::new("invalid-", 3141).validate();
        assert!(outcome.is_err());
    }

    /// Port zero is rejected.
    #[test]
    fn test_client_validate_invalid_port_zero() {
        let outcome = Client::new("localhost", 0).validate();
        assert!(outcome.is_err());
    }

    /// A page size of zero is rejected.
    #[test]
    fn test_client_validate_invalid_page_size_zero() {
        let outcome = Client::new("localhost", 3141).page_size(0).validate();
        assert!(outcome.is_err());
    }

    /// A page size of 200_000 is rejected.
    #[test]
    fn test_client_validate_invalid_page_size_too_large() {
        let outcome = Client::new("localhost", 3141).page_size(200_000).validate();
        assert!(outcome.is_err());
    }

    /// A fully configured client validates.
    #[test]
    fn test_client_validate_with_all_options() {
        let configured = Client::new("geode.example.com", 8443)
            .skip_verify(true)
            .page_size(500)
            .username("admin")
            .password("secret")
            .connect_timeout(15)
            .hello_timeout(10)
            .idle_timeout(60);
        let outcome = configured.validate();
        assert!(outcome.is_ok());
    }

    /// Setting every timeout to `u64::MAX` must not panic in the builder.
    #[test]
    fn test_client_extreme_timeout_values() {
        let _built = Client::new("localhost", 3141)
            .connect_timeout(u64::MAX)
            .hello_timeout(u64::MAX)
            .idle_timeout(u64::MAX);
    }

    // ------------------------------------------------------------------
    // Proto value conversion
    // ------------------------------------------------------------------

    /// An edge's label is exposed under the `type` key, not `label`.
    #[test]
    fn test_convert_edge_uses_type_field() {
        let wire = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(proto::EdgeValue {
                id: 100,
                from_id: 1,
                to_id: 2,
                label: String::from("KNOWS"),
                properties: Vec::new(),
            })),
        };
        let converted = Connection::convert_proto_value_static(&wire);
        let fields = converted.as_object().unwrap();
        assert_eq!(fields.get("type").unwrap().as_string().unwrap(), "KNOWS");
        assert!(
            fields.get("label").is_none(),
            "edge should not have 'label' field"
        );
    }

    /// Edge endpoints are exposed as `start_node`/`end_node`, not `from_id`/`to_id`.
    #[test]
    fn test_convert_edge_uses_start_end_node() {
        let wire = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(proto::EdgeValue {
                id: 100,
                from_id: 42,
                to_id: 99,
                label: String::from("LIKES"),
                properties: Vec::new(),
            })),
        };
        let converted = Connection::convert_proto_value_static(&wire);
        let fields = converted.as_object().unwrap();
        assert_eq!(fields.get("start_node").unwrap().as_int().unwrap(), 42);
        assert_eq!(fields.get("end_node").unwrap().as_int().unwrap(), 99);
        assert!(fields.get("from_id").is_none());
        assert!(fields.get("to_id").is_none());
    }

    /// Edge properties are converted into a nested `properties` object.
    #[test]
    fn test_convert_edge_with_properties() {
        let since = proto::MapEntry {
            key: String::from("since"),
            value: Some(proto::Value {
                kind: Some(proto::value::Kind::IntVal(proto::IntValue {
                    value: 2020,
                    kind: 1,
                })),
            }),
        };
        let wire = proto::Value {
            kind: Some(proto::value::Kind::EdgeVal(proto::EdgeValue {
                id: 100,
                from_id: 1,
                to_id: 2,
                label: String::from("KNOWS"),
                properties: vec![since],
            })),
        };
        let converted = Connection::convert_proto_value_static(&wire);
        let fields = converted.as_object().unwrap();
        let properties = fields.get("properties").unwrap().as_object().unwrap();
        assert_eq!(properties.get("since").unwrap().as_int().unwrap(), 2020);
    }

    /// A node exposes `id`, `labels` and a nested `properties` object.
    #[test]
    fn test_convert_node_fields() {
        let name = proto::MapEntry {
            key: String::from("name"),
            value: Some(proto::Value {
                kind: Some(proto::value::Kind::StringVal(proto::StringValue {
                    value: String::from("Alice"),
                    kind: 1,
                })),
            }),
        };
        let wire = proto::Value {
            kind: Some(proto::value::Kind::NodeVal(proto::NodeValue {
                id: 42,
                labels: vec![String::from("Person")],
                properties: vec![name],
            })),
        };
        let converted = Connection::convert_proto_value_static(&wire);
        let fields = converted.as_object().unwrap();
        assert_eq!(fields.get("id").unwrap().as_int().unwrap(), 42);
        let labels = fields.get("labels").unwrap().as_array().unwrap();
        assert_eq!(labels.len(), 1);
        let properties = fields.get("properties").unwrap().as_object().unwrap();
        assert_eq!(properties.get("name").unwrap().as_string().unwrap(), "Alice");
    }
}