use std::collections::HashMap;
use async_trait::async_trait;
use nodedb_types::document::Document;
use nodedb_types::error::{ErrorDetails, NodeDbError, NodeDbResult};
use nodedb_types::filter::{EdgeFilter, MetadataFilter};
use nodedb_types::id::{EdgeId, NodeId};
use nodedb_types::protocol::{OpCode, TextFields};
use nodedb_types::result::{QueryResult, SearchResult, SubGraph};
use nodedb_types::value::Value;
use nodedb_types::protocol::Limits;
use super::pool::{Pool, PoolConfig};
use super::response_parse::{json_to_value, parse_search_results, parse_subgraph_response};
use crate::traits::NodeDb;
/// Client whose operations run on connections acquired from an internal [`Pool`].
pub struct NativeClient {
/// Connection pool that every operation acquires its connection from.
pool: Pool,
}
/// Constructors and session-level helpers for [`NativeClient`].
///
/// NOTE(review): every method below acquires its own connection from the pool.
/// Whether consecutive calls (for example `begin` followed by `query` and
/// `commit`) are served by the same server session depends on `Pool`, which is
/// not visible here — confirm before relying on multi-call transactions.
impl NativeClient {
    /// Creates a client backed by a pool built from `config`.
    pub fn new(config: PoolConfig) -> Self {
        let pool = Pool::new(config);
        Self { pool }
    }

    /// Creates a client for the server at `addr`, leaving every other pool
    /// setting at its `PoolConfig::default()` value.
    pub fn connect(addr: &str) -> Self {
        let config = PoolConfig {
            addr: addr.to_owned(),
            ..Default::default()
        };
        Self::new(config)
    }

    /// Runs `sql` through the connection's `execute_sql`.
    ///
    /// If the first attempt fails with an error that `is_connection_error`
    /// accepts, the connection is dropped, a new one is acquired, and the
    /// statement is sent once more; the second outcome is returned unchanged.
    ///
    /// NOTE(review): the retry resends the same statement. If the first attempt
    /// reached the server before failing, a non-idempotent statement could run
    /// twice — confirm this is acceptable for callers.
    ///
    /// # Errors
    ///
    /// Returns any error from acquiring a connection, the first error when it
    /// is not classified as a connection error, or the retry's error.
    pub async fn query(&self, sql: &str) -> NodeDbResult<QueryResult> {
        let mut conn = self.pool.acquire().await?;
        let first_failure = match conn.execute_sql(sql).await {
            Ok(result) => return Ok(result),
            Err(err) => err,
        };
        if !is_connection_error(&first_failure) {
            return Err(first_failure);
        }
        // Let go of the failed connection before asking the pool for another.
        drop(conn);
        let mut retry_conn = self.pool.acquire().await?;
        retry_conn.execute_sql(sql).await
    }

    /// Runs `sql` through the connection's `execute_ddl`, with the same
    /// single-retry policy as [`NativeClient::query`].
    ///
    /// NOTE(review): as with `query`, the retry resends the statement; confirm
    /// the server tolerates a repeated DDL statement.
    ///
    /// # Errors
    ///
    /// Returns any error from acquiring a connection, the first error when it
    /// is not classified as a connection error, or the retry's error.
    pub async fn ddl(&self, sql: &str) -> NodeDbResult<QueryResult> {
        let mut conn = self.pool.acquire().await?;
        let first_failure = match conn.execute_ddl(sql).await {
            Ok(result) => return Ok(result),
            Err(err) => err,
        };
        if !is_connection_error(&first_failure) {
            return Err(first_failure);
        }
        // Let go of the failed connection before asking the pool for another.
        drop(conn);
        let mut retry_conn = self.pool.acquire().await?;
        retry_conn.execute_ddl(sql).await
    }

    /// Acquires a pooled connection and calls `begin` on it.
    pub async fn begin(&self) -> NodeDbResult<()> {
        self.pool.acquire().await?.begin().await
    }

    /// Acquires a pooled connection and calls `commit` on it.
    pub async fn commit(&self) -> NodeDbResult<()> {
        self.pool.acquire().await?.commit().await
    }

    /// Acquires a pooled connection and calls `rollback` on it.
    pub async fn rollback(&self) -> NodeDbResult<()> {
        self.pool.acquire().await?.rollback().await
    }

    /// Acquires a pooled connection and calls `set_parameter(key, value)` on it.
    pub async fn set_parameter(&self, key: &str, value: &str) -> NodeDbResult<()> {
        self.pool.acquire().await?.set_parameter(key, value).await
    }

    /// Acquires a pooled connection and returns what `show_parameter(key)`
    /// reports on it.
    pub async fn show_parameter(&self, key: &str) -> NodeDbResult<String> {
        self.pool.acquire().await?.show_parameter(key).await
    }

    /// Acquires a pooled connection and calls `ping` on it.
    pub async fn ping(&self) -> NodeDbResult<()> {
        self.pool.acquire().await?.ping().await
    }
}
/// [`NodeDb`] implementation that performs each operation on a connection
/// acquired from the client's pool.
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl NodeDb for NativeClient {
    /// Protocol version from the pool's negotiated metadata, or `0` when
    /// `negotiated_meta()` yields no value.
    fn proto_version(&self) -> u16 {
        self.pool
            .negotiated_meta()
            .map(|m| m.proto_version)
            .unwrap_or(0)
    }

    /// Capability bits from the pool's negotiated metadata, or `0` when
    /// `negotiated_meta()` yields no value.
    fn capabilities(&self) -> u64 {
        self.pool
            .negotiated_meta()
            .map(|m| m.capabilities)
            .unwrap_or(0)
    }

    /// Server version string from the pool's negotiated metadata, or the
    /// default (empty) string when `negotiated_meta()` yields no value.
    fn server_version(&self) -> String {
        self.pool
            .negotiated_meta()
            .map(|m| m.server_version)
            .unwrap_or_default()
    }

    /// Limits from the pool's negotiated metadata, or `Limits::default()` when
    /// `negotiated_meta()` yields no value.
    fn limits(&self) -> Limits {
        self.pool
            .negotiated_meta()
            .map(|m| m.limits)
            .unwrap_or_default()
    }

    /// Sends an `OpCode::VectorSearch` request for the `k` nearest entries to
    /// `query` in `collection` and parses the response into search results.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection, sending the request, or
    /// parsing the response.
    async fn vector_search(
        &self,
        collection: &str,
        query: &[f32],
        k: usize,
        filter: Option<&MetadataFilter>,
    ) -> NodeDbResult<Vec<SearchResult>> {
        let mut conn = self.pool.acquire().await?;
        let resp = conn
            .send(
                OpCode::VectorSearch,
                build_vector_search_request(collection, query, k, filter),
            )
            .await?;
        parse_search_results(&resp)
    }

    /// Inserts a vector by executing an `INSERT` statement built from the
    /// arguments.
    ///
    /// `metadata` fields are serialized to JSON (`{}` when `None`). The
    /// collection name is quoted as an identifier and the id and JSON text as
    /// string literals, so embedded quotes cannot terminate them early.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when the metadata cannot be encoded, or
    /// any error from acquiring a connection or executing the statement.
    async fn vector_insert(
        &self,
        collection: &str,
        id: &str,
        embedding: &[f32],
        metadata: Option<Document>,
    ) -> NodeDbResult<()> {
        let meta_json = match metadata {
            Some(d) => {
                let obj: HashMap<String, Value> = d.fields;
                sonic_rs::to_string(&obj).map_err(|e| {
                    NodeDbError::serialization("json", format!("vector_insert metadata: {e}"))
                })?
            }
            None => "{}".to_string(),
        };
        let sql = format!(
            "INSERT INTO {} (id, embedding, metadata) VALUES ({}, {}, {})",
            sql_quote_identifier(collection),
            sql_quote_string_literal(id),
            format_f32_array(embedding),
            sql_quote_string_literal(&meta_json),
        );
        let mut conn = self.pool.acquire().await?;
        conn.execute_sql(&sql).await?;
        Ok(())
    }

    /// Deletes the row whose `id` column equals `id` by executing a `DELETE`
    /// statement against `collection`.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection or executing the statement.
    async fn vector_delete(&self, collection: &str, id: &str) -> NodeDbResult<()> {
        let sql = format!(
            "DELETE FROM {} WHERE id = {}",
            sql_quote_identifier(collection),
            sql_quote_string_literal(id),
        );
        let mut conn = self.pool.acquire().await?;
        conn.execute_sql(&sql).await?;
        Ok(())
    }

    /// Sends an `OpCode::GraphHop` request starting at `start` with the given
    /// `depth` and parses the response into a sub-graph.
    ///
    /// Only the first label of `edge_filter` (if any) is forwarded as
    /// `edge_label`; no other part of the filter is read here.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection, sending the request, or
    /// parsing the response.
    async fn graph_traverse(
        &self,
        collection: &str,
        start: &NodeId,
        depth: u8,
        edge_filter: Option<&EdgeFilter>,
    ) -> NodeDbResult<SubGraph> {
        let mut conn = self.pool.acquire().await?;
        let resp = conn
            .send(
                OpCode::GraphHop,
                TextFields {
                    collection: Some(collection.to_string()),
                    start_node: Some(start.as_str().to_string()),
                    depth: Some(u32::from(depth)),
                    edge_label: edge_filter.and_then(|f| f.labels.first().cloned()),
                    ..Default::default()
                },
            )
            .await?;
        parse_subgraph_response(&resp)
    }

    /// Sends an `OpCode::EdgePut` request for an edge `from -> to` of type
    /// `edge_type`, then builds the returned id with `EdgeId::try_first`.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when `properties` cannot be converted to
    /// JSON; in that case nothing is sent, so an edge is never written with
    /// its properties silently missing. Also returns errors from acquiring a
    /// connection or sending the request, and a storage error when
    /// `EdgeId::try_first` rejects the arguments.
    ///
    /// NOTE(review): the `EdgeId` is constructed after the request has been
    /// sent, so a rejection there is reported after the write — confirm
    /// whether that validation should happen first.
    async fn graph_insert_edge(
        &self,
        collection: &str,
        from: &NodeId,
        to: &NodeId,
        edge_type: &str,
        properties: Option<Document>,
    ) -> NodeDbResult<EdgeId> {
        // Surface conversion failures instead of dropping the properties.
        let props_json = properties
            .map(|d| serde_json::to_value(d.fields))
            .transpose()
            .map_err(|e| {
                NodeDbError::serialization("json", format!("graph_insert_edge properties: {e}"))
            })?;
        let mut conn = self.pool.acquire().await?;
        conn.send(
            OpCode::EdgePut,
            TextFields {
                collection: Some(collection.to_string()),
                from_node: Some(from.as_str().to_string()),
                to_node: Some(to.as_str().to_string()),
                edge_type: Some(edge_type.to_string()),
                properties: props_json,
                ..Default::default()
            },
        )
        .await?;
        EdgeId::try_first(from.clone(), to.clone(), edge_type)
            .map_err(|e| NodeDbError::storage(format!("invalid edge label: {e}")))
    }

    /// Sends an `OpCode::EdgeDelete` request identified by the source,
    /// destination, and label stored in `edge_id`.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection or sending the request.
    async fn graph_delete_edge(&self, collection: &str, edge_id: &EdgeId) -> NodeDbResult<()> {
        let mut conn = self.pool.acquire().await?;
        conn.send(
            OpCode::EdgeDelete,
            TextFields {
                collection: Some(collection.to_string()),
                from_node: Some(edge_id.src.as_str().to_string()),
                to_node: Some(edge_id.dst.as_str().to_string()),
                edge_type: Some(edge_id.label.clone()),
                ..Default::default()
            },
        )
        .await?;
        Ok(())
    }

    /// Sends an `OpCode::PointGet` request and rebuilds a [`Document`] from the
    /// first column of the first returned row.
    ///
    /// Returns `Ok(None)` when the response carries no rows. A first column
    /// that is missing or not a string is read as `{}`, and JSON text that
    /// fails to parse yields a document with no fields set rather than an
    /// error.
    ///
    /// NOTE(review): the lenient parse means a corrupt payload is
    /// indistinguishable from an empty document — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection or sending the request.
    async fn document_get(&self, collection: &str, id: &str) -> NodeDbResult<Option<Document>> {
        let mut conn = self.pool.acquire().await?;
        let resp = conn
            .send(
                OpCode::PointGet,
                TextFields {
                    collection: Some(collection.to_string()),
                    document_id: Some(id.to_string()),
                    ..Default::default()
                },
            )
            .await?;
        let rows = resp.rows.unwrap_or_default();
        if rows.is_empty() {
            return Ok(None);
        }
        let json_text = rows[0].first().and_then(|v| v.as_str()).unwrap_or("{}");
        let mut doc = Document::new(id);
        if let Ok(obj) = sonic_rs::from_str::<HashMap<String, serde_json::Value>>(json_text) {
            for (k, v) in obj {
                doc.set(&k, json_to_value(v));
            }
        }
        Ok(Some(doc))
    }

    /// Serializes `doc.fields` to JSON bytes and sends them in an
    /// `OpCode::PointPut` request keyed by `doc.id`.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when the fields cannot be encoded, or any
    /// error from acquiring a connection or sending the request.
    async fn document_put(&self, collection: &str, doc: Document) -> NodeDbResult<()> {
        let data = sonic_rs::to_vec(&doc.fields)
            .map_err(|e| NodeDbError::serialization("json", format!("doc serialize: {e}")))?;
        let mut conn = self.pool.acquire().await?;
        conn.send(
            OpCode::PointPut,
            TextFields {
                collection: Some(collection.to_string()),
                // `doc` is owned and not used afterwards, so the id is moved.
                document_id: Some(doc.id),
                data: Some(data),
                ..Default::default()
            },
        )
        .await?;
        Ok(())
    }

    /// Sends an `OpCode::PointDelete` request for document `id` in `collection`.
    ///
    /// # Errors
    ///
    /// Returns errors from acquiring a connection or sending the request.
    async fn document_delete(&self, collection: &str, id: &str) -> NodeDbResult<()> {
        let mut conn = self.pool.acquire().await?;
        conn.send(
            OpCode::PointDelete,
            TextFields {
                collection: Some(collection.to_string()),
                document_id: Some(id.to_string()),
                ..Default::default()
            },
        )
        .await?;
        Ok(())
    }

    /// Runs `query` with `params` through the connection's
    /// `execute_sql_with_params`.
    ///
    /// If the first attempt fails with an error that `is_connection_error`
    /// accepts, the connection is dropped, a new one is acquired, and the
    /// statement is sent once more; the second outcome is returned unchanged.
    ///
    /// NOTE(review): the retry resends the same statement. If the first attempt
    /// reached the server before failing, a non-idempotent statement could run
    /// twice — confirm this is acceptable for callers.
    ///
    /// # Errors
    ///
    /// Returns any error from acquiring a connection, the first error when it
    /// is not classified as a connection error, or the retry's error.
    async fn execute_sql(&self, query: &str, params: &[Value]) -> NodeDbResult<QueryResult> {
        let mut conn = self.pool.acquire().await?;
        match conn.execute_sql_with_params(query, params).await {
            Ok(r) => Ok(r),
            Err(e) if is_connection_error(&e) => {
                // Let go of the failed connection before acquiring another.
                drop(conn);
                let mut conn = self.pool.acquire().await?;
                conn.execute_sql_with_params(query, params).await
            }
            Err(e) => Err(e),
        }
    }
}
/// Builds the `TextFields` payload for a vector search.
///
/// `collection`, a copy of `query`, and `k` are always set. `k` is clamped to
/// `u32::MAX` instead of being wrapped by a narrowing cast, so an oversized
/// value can never turn into a small (or zero) `top_k`.
///
/// When `filter` is given it is serialized with `sonic_rs::to_vec`. If that
/// fails, a warning is logged and the request is built with `filters` unset.
///
/// NOTE(review): a request built after a failed filter serialization carries
/// no filter at all — confirm callers are fine with that fallback.
fn build_vector_search_request(
    collection: &str,
    query: &[f32],
    k: usize,
    filter: Option<&MetadataFilter>,
) -> TextFields {
    let filters_bytes = filter.and_then(|f| match sonic_rs::to_vec(f) {
        Ok(b) => Some(b),
        Err(e) => {
            tracing::warn!(error = %e, "failed to serialize metadata filter for native request");
            None
        }
    });
    // Saturate rather than wrap: on 64-bit targets `k as u32` would map
    // `1 << 32` to 0.
    let top_k = k.min(u32::MAX as usize) as u32;
    TextFields {
        collection: Some(collection.to_string()),
        query_vector: Some(query.to_vec()),
        top_k: Some(top_k),
        filters: filters_bytes,
        ..Default::default()
    }
}
/// Renders `arr` as `ARRAY[v0,v1,...]`, writing each element with its
/// `Display` form and separating elements with a bare comma.
///
/// An empty slice yields `ARRAY[]`. Non-finite values are written exactly as
/// `Display` prints them (for example `NaN`).
fn format_f32_array(arr: &[f32]) -> String {
    let mut rendered = String::from("ARRAY[");
    for (position, component) in arr.iter().enumerate() {
        if position > 0 {
            rendered.push(',');
        }
        rendered.push_str(&component.to_string());
    }
    rendered.push(']');
    rendered
}
/// Wraps `name` in double quotes, doubling every double quote inside it.
///
/// All other characters are copied through unchanged.
fn sql_quote_identifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // An embedded quote is written twice so it cannot end the identifier.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Wraps `s` in single quotes, doubling every single quote inside it.
///
/// All other characters, backslashes included, are copied through unchanged.
///
/// NOTE(review): this presumes the server does not treat backslash as an
/// escape character inside string literals — confirm against the SQL dialect.
fn sql_quote_string_literal(s: &str) -> String {
    let mut literal = String::with_capacity(s.len() + 2);
    literal.push('\'');
    for (position, segment) in s.split('\'').enumerate() {
        if position > 0 {
            // Each split point was a quote in the input; emit it doubled.
            literal.push_str("''");
        }
        literal.push_str(segment);
    }
    literal.push('\'');
    literal
}
/// Returns `true` when the details of `e` are `SyncConnectionFailed` or any
/// `Storage { .. }` variant; callers use this to decide whether to retry on a
/// fresh connection.
///
/// NOTE(review): every `Storage` error is treated as retryable here —
/// presumably because transport failures surface under that variant; confirm
/// that genuine storage failures should also trigger a resend.
fn is_connection_error(e: &NodeDbError) -> bool {
    let details = e.details();
    matches!(
        details,
        ErrorDetails::SyncConnectionFailed | ErrorDetails::Storage { .. }
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a filter the request carries the collection, vector, and
    /// `top_k`, and leaves `filters` unset.
    #[test]
    fn vector_search_request_without_filter_omits_filter_bytes() {
        let req = build_vector_search_request("docs", &[0.1, 0.2], 5, None);
        assert_eq!(req.collection.as_deref(), Some("docs"));
        assert_eq!(req.query_vector.as_deref(), Some(&[0.1f32, 0.2][..]));
        assert_eq!(req.top_k, Some(5));
        assert!(
            req.filters.is_none(),
            "no-filter case must leave TextFields::filters empty"
        );
    }

    /// A supplied filter ends up as non-empty bytes in `filters`.
    #[test]
    fn vector_search_request_serializes_metadata_filter() {
        let filter = MetadataFilter::eq("category", Value::String("ai".into()));
        let req = build_vector_search_request("docs", &[0.1], 3, Some(&filter));
        assert!(
            req.filters.is_some(),
            "non-None filter must be serialized into TextFields::filters \
             rather than dropped before reaching the wire"
        );
        let bytes = req.filters.expect("filters bytes recorded");
        assert!(
            !bytes.is_empty(),
            "serialized filter bytes must not be empty"
        );
    }

    /// Round-trips a parameter list through the `zerompk` MessagePack codec.
    ///
    /// This exercises the codec only; it does not call `execute_sql`.
    #[test]
    fn execute_sql_encodes_params_into_sql_params_field() {
        let params = vec![
            Value::Null,
            Value::Bool(true),
            Value::Integer(42),
            Value::String("alice".into()),
        ];
        let bytes = zerompk::to_msgpack_vec(&params).expect("encode params");
        let decoded: Vec<Value> =
            zerompk::from_msgpack(&bytes).expect("decode round-trips on same codec");
        assert_eq!(decoded.len(), 4);
        assert!(matches!(decoded[0], Value::Null));
        assert!(matches!(decoded[1], Value::Bool(true)));
        assert!(matches!(decoded[2], Value::Integer(42)));
        match &decoded[3] {
            Value::String(s) => assert_eq!(s, "alice"),
            other => panic!("expected Value::String('alice'), got {other:?}"),
        }
    }

    /// The rendered array has the `ARRAY[` prefix, the `]` suffix, and the
    /// element text in between.
    #[test]
    fn format_f32_array_works() {
        let arr = [0.1f32, 0.2, 0.3];
        let s = format_f32_array(&arr);
        assert!(s.starts_with("ARRAY["));
        assert!(s.contains("0.1"));
        assert!(s.ends_with(']'));
    }

    /// Identifiers are wrapped in double quotes and embedded quotes doubled.
    #[test]
    fn sql_quote_identifier_wraps_and_escapes_double_quotes() {
        assert_eq!(sql_quote_identifier("foo"), "\"foo\"");
        assert_eq!(sql_quote_identifier("a\"b"), "\"a\"\"b\"");
    }

    /// String literals are wrapped in single quotes and embedded quotes
    /// doubled, including in an injection-shaped input.
    #[test]
    fn sql_quote_string_literal_escapes_single_quotes() {
        assert_eq!(sql_quote_string_literal("plain"), "'plain'");
        assert_eq!(sql_quote_string_literal("O'Reilly"), "'O''Reilly'");
        assert_eq!(
            sql_quote_string_literal("'; DROP TABLE x; --"),
            "'''; DROP TABLE x; --'"
        );
    }

    /// JSON text passes through untouched apart from doubled single quotes.
    #[test]
    fn sql_quote_string_literal_passes_through_json() {
        let json = r#"{"name":"O'Reilly","ok":true}"#;
        let quoted = sql_quote_string_literal(json);
        assert_eq!(quoted, "'{\"name\":\"O''Reilly\",\"ok\":true}'");
    }
}