use std::sync::atomic::{AtomicU64, Ordering};
use nodedb_types::error::{NodeDbError, NodeDbResult};
use nodedb_types::protocol::{
AuthMethod, FRAME_HEADER_LEN, MAX_FRAME_SIZE, NativeRequest, NativeResponse, OpCode,
RequestFields, ResponseStatus, TextFields,
};
use nodedb_types::result::QueryResult;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
enum ConnStream {
Plain(TcpStream),
Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
impl AsyncRead for ConnStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
ConnStream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
ConnStream::Tls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for ConnStream {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
match self.get_mut() {
ConnStream::Plain(s) => std::pin::Pin::new(s).poll_write(cx, buf),
ConnStream::Tls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
ConnStream::Plain(s) => std::pin::Pin::new(s).poll_flush(cx),
ConnStream::Tls(s) => std::pin::Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
ConnStream::Plain(s) => std::pin::Pin::new(s).poll_shutdown(cx),
ConnStream::Tls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
}
}
}
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
pub enabled: bool,
pub ca_cert_path: Option<std::path::PathBuf>,
pub server_name: Option<String>,
pub danger_accept_invalid_certs: bool,
}
pub struct NativeConnection {
stream: ConnStream,
seq: AtomicU64,
authenticated: bool,
}
impl NativeConnection {
pub async fn connect(addr: &str) -> NodeDbResult<Self> {
let stream = TcpStream::connect(addr)
.await
.map_err(|e| NodeDbError::sync_connection_failed(format!("connect to {addr}: {e}")))?;
Ok(Self {
stream: ConnStream::Plain(stream),
seq: AtomicU64::new(1),
authenticated: false,
})
}
pub async fn connect_tls(addr: &str, tls: &TlsConfig) -> NodeDbResult<Self> {
let tcp = TcpStream::connect(addr)
.await
.map_err(|e| NodeDbError::sync_connection_failed(format!("connect to {addr}: {e}")))?;
let config = build_tls_client_config(tls)?;
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let server_name = tls
.server_name
.as_deref()
.unwrap_or_else(|| addr.split(':').next().unwrap_or("localhost"));
let sni = tokio_rustls::rustls::pki_types::ServerName::try_from(server_name.to_string())
.map_err(|e| {
NodeDbError::sync_connection_failed(format!(
"invalid server name '{server_name}': {e}"
))
})?;
let tls_stream = connector.connect(sni, tcp).await.map_err(|e| {
NodeDbError::sync_connection_failed(format!("TLS handshake failed: {e}"))
})?;
Ok(Self {
stream: ConnStream::Tls(Box::new(tls_stream)),
seq: AtomicU64::new(1),
authenticated: false,
})
}
pub async fn authenticate(&mut self, method: AuthMethod) -> NodeDbResult<()> {
let resp = self
.send(
OpCode::Auth,
TextFields {
auth: Some(method),
..Default::default()
},
)
.await?;
if resp.status == ResponseStatus::Error {
let msg = resp
.error
.map(|e| e.message)
.unwrap_or_else(|| "auth failed".into());
return Err(NodeDbError::authorization_denied(msg));
}
self.authenticated = true;
Ok(())
}
pub async fn ping(&mut self) -> NodeDbResult<()> {
let resp = self.send(OpCode::Ping, TextFields::default()).await?;
if resp.status == ResponseStatus::Error {
return Err(NodeDbError::internal("ping failed"));
}
Ok(())
}
pub fn is_authenticated(&self) -> bool {
self.authenticated
}
pub async fn execute_sql(&mut self, sql: &str) -> NodeDbResult<QueryResult> {
let resp = self
.send(
OpCode::Sql,
TextFields {
sql: Some(sql.to_string()),
..Default::default()
},
)
.await?;
response_to_query_result(resp)
}
pub async fn execute_ddl(&mut self, sql: &str) -> NodeDbResult<QueryResult> {
let resp = self
.send(
OpCode::Ddl,
TextFields {
sql: Some(sql.to_string()),
..Default::default()
},
)
.await?;
response_to_query_result(resp)
}
pub async fn begin(&mut self) -> NodeDbResult<()> {
let resp = self.send(OpCode::Begin, TextFields::default()).await?;
check_error(resp)
}
pub async fn commit(&mut self) -> NodeDbResult<()> {
let resp = self.send(OpCode::Commit, TextFields::default()).await?;
check_error(resp)
}
pub async fn rollback(&mut self) -> NodeDbResult<()> {
let resp = self.send(OpCode::Rollback, TextFields::default()).await?;
check_error(resp)
}
pub async fn set_parameter(&mut self, key: &str, value: &str) -> NodeDbResult<()> {
let resp = self
.send(
OpCode::Set,
TextFields {
key: Some(key.to_string()),
value: Some(value.to_string()),
..Default::default()
},
)
.await?;
check_error(resp)
}
pub async fn show_parameter(&mut self, key: &str) -> NodeDbResult<String> {
let resp = self
.send(
OpCode::Show,
TextFields {
key: Some(key.to_string()),
..Default::default()
},
)
.await?;
if resp.status == ResponseStatus::Error {
let msg = resp
.error
.map(|e| e.message)
.unwrap_or_else(|| "show failed".into());
return Err(NodeDbError::internal(msg));
}
let value = resp
.rows
.and_then(|rows| rows.into_iter().next())
.and_then(|row| row.into_iter().next())
.and_then(|v| v.as_str().map(String::from))
.unwrap_or_default();
Ok(value)
}
fn next_seq(&self) -> u64 {
self.seq.fetch_add(1, Ordering::Relaxed)
}
pub(crate) async fn send(
&mut self,
op: OpCode,
fields: TextFields,
) -> NodeDbResult<NativeResponse> {
let req = NativeRequest {
op,
seq: self.next_seq(),
fields: RequestFields::Text(fields),
};
let payload = zerompk::to_msgpack_vec(&req)
.map_err(|e| NodeDbError::serialization("msgpack", format!("request encode: {e}")))?;
let len = payload.len() as u32;
self.stream
.write_all(&len.to_be_bytes())
.await
.map_err(io_err)?;
self.stream.write_all(&payload).await.map_err(io_err)?;
self.stream.flush().await.map_err(io_err)?;
let mut combined_rows: Vec<Vec<nodedb_types::Value>> = Vec::new();
let mut final_resp: Option<NativeResponse> = None;
loop {
let mut len_buf = [0u8; FRAME_HEADER_LEN];
self.stream.read_exact(&mut len_buf).await.map_err(io_err)?;
let resp_len = u32::from_be_bytes(len_buf);
if resp_len > MAX_FRAME_SIZE {
return Err(NodeDbError::internal(format!(
"response frame too large: {resp_len}"
)));
}
let mut resp_buf = vec![0u8; resp_len as usize];
self.stream
.read_exact(&mut resp_buf)
.await
.map_err(io_err)?;
let resp: NativeResponse = zerompk::from_msgpack(&resp_buf).map_err(|e| {
NodeDbError::serialization("msgpack", format!("response decode: {e}"))
})?;
if resp.status == ResponseStatus::Partial {
if let Some(rows) = resp.rows {
combined_rows.extend(rows);
}
if final_resp.is_none() {
final_resp = Some(NativeResponse { rows: None, ..resp });
}
} else {
if combined_rows.is_empty() {
final_resp = Some(resp);
} else {
if let Some(ref rows) = resp.rows {
combined_rows.extend(rows.iter().cloned());
}
let mut merged = final_resp.unwrap_or(resp);
merged.rows = Some(combined_rows);
merged.status = ResponseStatus::Ok;
final_resp = Some(merged);
}
break;
}
}
final_resp.ok_or_else(|| NodeDbError::internal("no final response received"))
}
}
fn build_tls_client_config(tls: &TlsConfig) -> NodeDbResult<tokio_rustls::rustls::ClientConfig> {
use tokio_rustls::rustls;
let builder = rustls::ClientConfig::builder();
if tls.danger_accept_invalid_certs {
let config = builder
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoCertVerifier))
.with_no_client_auth();
return Ok(config);
}
if let Some(ref ca_path) = tls.ca_cert_path {
let mut root_store = rustls::RootCertStore::empty();
let cert_file = std::fs::File::open(ca_path).map_err(|e| {
NodeDbError::sync_connection_failed(format!("open CA cert {}: {e}", ca_path.display()))
})?;
let mut reader = std::io::BufReader::new(cert_file);
for cert in rustls_pemfile::certs(&mut reader) {
match cert {
Ok(c) => {
root_store.add(c).map_err(|e| {
NodeDbError::sync_connection_failed(format!("add CA cert: {e}"))
})?;
}
Err(e) => {
return Err(NodeDbError::sync_connection_failed(format!(
"parse CA cert: {e}"
)));
}
}
}
let config = builder
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(config)
} else {
let root_store = rustls::RootCertStore::empty();
let config = builder
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(config)
}
}
#[derive(Debug)]
struct NoCertVerifier;
impl tokio_rustls::rustls::client::danger::ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
_server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: tokio_rustls::rustls::pki_types::UnixTime,
) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
{
Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<
tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
tokio_rustls::rustls::Error,
> {
Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<
tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
tokio_rustls::rustls::Error,
> {
Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
fn io_err(e: std::io::Error) -> NodeDbError {
NodeDbError::sync_connection_failed(format!("I/O: {e}"))
}
fn check_error(resp: NativeResponse) -> NodeDbResult<()> {
if resp.status == ResponseStatus::Error {
let msg = resp
.error
.map(|e| e.message)
.unwrap_or_else(|| "unknown error".into());
return Err(NodeDbError::internal(msg));
}
Ok(())
}
fn response_to_query_result(resp: NativeResponse) -> NodeDbResult<QueryResult> {
if resp.status == ResponseStatus::Error {
let msg = resp
.error
.map(|e| e.message)
.unwrap_or_else(|| "query failed".into());
return Err(NodeDbError::internal(msg));
}
Ok(QueryResult {
columns: resp.columns.unwrap_or_default(),
rows: resp.rows.unwrap_or_default(),
rows_affected: resp.rows_affected.unwrap_or(0),
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn response_to_query_result_ok() {
let resp = NativeResponse::from_query_result(
1,
QueryResult {
columns: vec!["x".into()],
rows: vec![vec![nodedb_types::Value::Integer(42)]],
rows_affected: 0,
},
0,
);
let qr = response_to_query_result(resp).unwrap();
assert_eq!(qr.columns, vec!["x"]);
assert_eq!(qr.rows[0][0].as_i64(), Some(42));
}
#[test]
fn response_to_query_result_error() {
let resp = NativeResponse::error(1, "42P01", "not found");
let err = response_to_query_result(resp).unwrap_err();
assert!(format!("{err}").contains("not found"));
}
#[test]
fn check_error_ok() {
let resp = NativeResponse::ok(1);
assert!(check_error(resp).is_ok());
}
#[test]
fn check_error_fail() {
let resp = NativeResponse::error(1, "XX000", "boom");
assert!(check_error(resp).is_err());
}
}