use pgwire::api::results::{Response, Tag};
use pgwire::error::PgWireResult;
use crate::bridge::physical_plan::{DocumentOp, VectorOp};
use crate::control::security::identity::AuthenticatedIdentity;
use crate::control::state::SharedState;
use super::super::types::sqlstate_error;
use super::sql_parse::{parse_array_literal, parse_sql_value, split_values};
/// Outcome of parsing a single-row `INSERT INTO` / `UPSERT INTO` statement.
struct ParsedInsert {
    /// Target collection name, lower-cased.
    coll_name: String,
    /// Document id: the value of the `id` column (matched case-insensitively)
    /// with surrounding single quotes stripped, or a generated `uuid_v7()`
    /// when that value is absent or empty.
    doc_id: String,
    /// Every non-`id` column, mapped to the value produced by `parse_sql_value`.
    fields: serde_json::Map<String, serde_json::Value>,
    /// JSON serialization of `fields`, sent as the document's stored value.
    value_bytes: Vec<u8>,
    /// `(column, vector)` for every column whose value parses as an array literal.
    vector_fields: Vec<(String, Vec<f32>)>,
    /// Set when the parser detected a `RETURNING` clause.
    has_returning: bool,
}
/// Parses a single-row `INSERT INTO` / `UPSERT INTO` statement of the form
/// `<keyword> <collection> (col, ...) VALUES (val, ...) [RETURNING ...]`.
///
/// `keyword` is the upper-case ASCII statement prefix including its trailing
/// space (e.g. `"INSERT INTO "`).
///
/// Returns:
/// - `None` when this parser declines the statement: `keyword` is absent, no
///   collection name follows it, or the catalog knows the collection and it
///   has declared fields or is strict/columnar.
/// - `Some(Err(_))` with SQLSTATE `42601` when the column list or `VALUES`
///   clause is malformed, or their lengths differ.
/// - `Some(Ok(_))` otherwise.
fn parse_write_statement(
    state: &SharedState,
    identity: &AuthenticatedIdentity,
    sql: &str,
    keyword: &str,
) -> Option<PgWireResult<ParsedInsert>> {
    // ASCII-only upper-casing keeps every byte offset in `upper` identical to
    // the one in `sql`. `str::to_uppercase` can change the byte length of
    // non-ASCII text (e.g. 'ß' -> "SS"), which would misalign the offsets and
    // could make the `sql[..]` slices below panic.
    let upper = sql.to_ascii_uppercase();
    let kw_pos = upper.find(keyword)?;
    let after_into = sql[kw_pos + keyword.len()..].trim_start();
    // The collection name ends at whitespace or at the `(` opening the column
    // list, so both `coll (a)` and `coll(a)` are accepted.
    let name_len = after_into
        .find(|c: char| c.is_whitespace() || c == '(')
        .unwrap_or(after_into.len());
    if name_len == 0 {
        return None;
    }
    let coll_name = after_into[..name_len].to_lowercase();
    // `after_into` is a suffix of `sql`, so this is the byte offset in `sql`
    // just past the collection name.
    let name_end = sql.len() - after_into.len() + name_len;

    let tenant_id = identity.tenant_id;
    if let Some(catalog) = state.credentials.catalog()
        && let Ok(Some(coll)) = catalog.get_collection(tenant_id.as_u32(), &coll_name)
    {
        // Collections with declared fields, or a strict/columnar type, are
        // declined here (`None`) and left to the caller.
        if !coll.fields.is_empty() {
            return None;
        }
        if coll.collection_type.is_strict() || coll.collection_type.is_columnar() {
            return None;
        }
    }

    let Some(first_open) = sql[name_end..].find('(').map(|p| name_end + p) else {
        return Some(Err(sqlstate_error(
            "42601",
            &format!("missing column list in {}", keyword.trim()),
        )));
    };
    // Look for VALUES only after the column list opens, and only as a whole
    // word, so names like `my_values` are not mistaken for the keyword. This
    // also guarantees `first_open <= values_kw` for the slice below.
    let Some(values_kw) = find_standalone_keyword(&upper, "VALUES", first_open) else {
        return Some(Err(sqlstate_error("42601", "missing VALUES clause")));
    };
    let Some(first_close) = sql[first_open..values_kw]
        .rfind(')')
        .map(|p| first_open + p)
    else {
        return Some(Err(sqlstate_error(
            "42601",
            "missing closing ) for column list",
        )));
    };
    let columns: Vec<&str> = sql[first_open + 1..first_close]
        .split(',')
        .map(str::trim)
        .collect();

    let after_values = sql[values_kw + "VALUES".len()..].trim_start();
    // Byte offset of `after_values` within `sql` (and therefore `upper`).
    let after_values_pos = sql.len() - after_values.len();
    let Some(vals_open) = after_values.find('(') else {
        return Some(Err(sqlstate_error("42601", "missing VALUES (...)")));
    };
    // Only a `)` after the opening `(` counts; an earlier one would give a
    // reversed slice range.
    let Some(vals_close) = after_values[vals_open..]
        .rfind(')')
        .map(|p| vals_open + p)
    else {
        return Some(Err(sqlstate_error("42601", "missing closing ) for VALUES")));
    };
    let values: Vec<&str> = split_values(&after_values[vals_open + 1..vals_close]);
    if columns.len() != values.len() {
        return Some(Err(sqlstate_error(
            "42601",
            &format!(
                "column count ({}) doesn't match value count ({})",
                columns.len(),
                values.len()
            ),
        )));
    }

    let mut doc_id = String::new();
    let mut fields = serde_json::Map::new();
    let mut vector_fields: Vec<(String, Vec<f32>)> = Vec::new();
    for (col, val) in columns.iter().zip(values.iter()) {
        let col = col.trim().trim_matches('"');
        let val = val.trim();
        if col.eq_ignore_ascii_case("id") {
            doc_id = val.trim_matches('\'').to_string();
        } else {
            fields.insert(col.to_string(), parse_sql_value(val));
        }
        // Array literals are additionally collected as vectors; the column is
        // still stored in `fields` above (unless it is `id`).
        if let Some(vec_data) = parse_array_literal(val) {
            vector_fields.push((col.to_string(), vec_data));
        }
    }
    if doc_id.is_empty() {
        doc_id = nodedb_types::id_gen::uuid_v7();
    }
    let value_bytes = serde_json::to_vec(&fields).unwrap_or_default();
    // RETURNING can only follow the VALUES list. Searching there, and only as
    // a whole word, avoids matching text inside a string value.
    let has_returning =
        find_standalone_keyword(&upper, "RETURNING", after_values_pos + vals_close + 1)
            .is_some();
    Some(Ok(ParsedInsert {
        coll_name,
        doc_id,
        fields,
        vector_fields,
        value_bytes,
        has_returning,
    }))
}

/// Returns the byte offset of the first occurrence of `keyword` in `haystack`
/// at or after `from` that is not part of a longer identifier. "Not part of"
/// means not directly preceded or followed by an ASCII alphanumeric or `_`.
/// `keyword` must be ASCII.
fn find_standalone_keyword(haystack: &str, keyword: &str, from: usize) -> Option<usize> {
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let bytes = haystack.as_bytes();
    let mut start = from;
    while let Some(rel) = haystack.get(start..)?.find(keyword) {
        let pos = start + rel;
        let end = pos + keyword.len();
        let before_ok = pos == 0 || !is_ident(bytes[pos - 1]);
        let after_ok = end >= bytes.len() || !is_ident(bytes[end]);
        if before_ok && after_ok {
            return Some(pos);
        }
        // `keyword` is ASCII, so `pos + 1` is always a char boundary.
        start = pos + 1;
    }
    None
}
/// Builds the single-row result set returned for a `RETURNING` clause.
///
/// The row has one text column, `result`. It holds a JSON object made of
/// `fields` plus an `"id"` entry set to `doc_id`.
///
/// # Errors
/// Propagates any error from encoding the row. The previous code discarded
/// it, which could send a row with a missing column.
fn returning_response(
    doc_id: &str,
    fields: &serde_json::Map<String, serde_json::Value>,
) -> PgWireResult<Vec<Response>> {
    use futures::stream;
    use pgwire::api::results::{DataRowEncoder, QueryResponse};
    let mut result_doc = fields.clone();
    result_doc.insert(
        "id".to_string(),
        serde_json::Value::String(doc_id.to_string()),
    );
    // `Value`'s `Display` writes compact JSON and cannot fail.
    let json_str = serde_json::Value::Object(result_doc).to_string();
    let schema = std::sync::Arc::new(vec![super::super::types::text_field("result")]);
    let mut encoder = DataRowEncoder::new(std::sync::Arc::clone(&schema));
    encoder.encode_field(&json_str)?;
    let row = encoder.take_row();
    Ok(vec![Response::Query(QueryResponse::new(
        schema,
        stream::iter(vec![Ok(row)]),
    ))])
}
pub async fn insert_document(
state: &SharedState,
identity: &AuthenticatedIdentity,
sql: &str,
) -> Option<PgWireResult<Vec<Response>>> {
let parsed = match parse_write_statement(state, identity, sql, "INSERT INTO ")? {
Ok(p) => p,
Err(e) => return Some(Err(e)),
};
let tenant_id = identity.tenant_id;
let vshard_id = crate::types::VShardId::from_key(parsed.doc_id.as_bytes());
let plan = crate::bridge::envelope::PhysicalPlan::Document(DocumentOp::PointPut {
collection: parsed.coll_name.clone(),
document_id: parsed.doc_id.clone(),
value: parsed.value_bytes,
});
if let Err(e) = crate::control::server::dispatch_utils::wal_append_if_write(
&state.wal, tenant_id, vshard_id, &plan,
) {
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
if let Err(e) = crate::control::server::dispatch_utils::dispatch_to_data_plane(
state, tenant_id, vshard_id, plan, 0,
)
.await
{
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
let vec_vshard = crate::types::VShardId::from_collection(&parsed.coll_name);
for (_field_name, vector) in &parsed.vector_fields {
let dim = vector.len();
let vec_plan = crate::bridge::envelope::PhysicalPlan::Vector(VectorOp::Insert {
collection: parsed.coll_name.clone(),
vector: vector.clone(),
dim,
field_name: String::new(),
doc_id: Some(parsed.doc_id.clone()),
});
if let Err(e) = crate::control::server::dispatch_utils::wal_append_if_write(
&state.wal, tenant_id, vec_vshard, &vec_plan,
) {
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
if let Err(e) = crate::control::server::dispatch_utils::dispatch_to_data_plane(
state, tenant_id, vec_vshard, vec_plan, 0,
)
.await
{
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
}
if parsed.has_returning {
return Some(returning_response(&parsed.doc_id, &parsed.fields));
}
Some(Ok(vec![Response::Execution(Tag::new("INSERT"))]))
}
pub async fn upsert_document(
state: &SharedState,
identity: &AuthenticatedIdentity,
sql: &str,
) -> Option<PgWireResult<Vec<Response>>> {
let parsed = match parse_write_statement(state, identity, sql, "UPSERT INTO ")? {
Ok(p) => p,
Err(e) => return Some(Err(e)),
};
let tenant_id = identity.tenant_id;
let vshard_id = crate::types::VShardId::from_key(parsed.doc_id.as_bytes());
let plan = crate::bridge::envelope::PhysicalPlan::Document(DocumentOp::Upsert {
collection: parsed.coll_name.clone(),
document_id: parsed.doc_id.clone(),
value: parsed.value_bytes,
});
if let Err(e) = crate::control::server::dispatch_utils::wal_append_if_write(
&state.wal, tenant_id, vshard_id, &plan,
) {
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
if let Err(e) = crate::control::server::dispatch_utils::dispatch_to_data_plane(
state, tenant_id, vshard_id, plan, 0,
)
.await
{
return Some(Err(sqlstate_error("XX000", &e.to_string())));
}
Some(Ok(vec![Response::Execution(Tag::new("UPSERT"))]))
}