use std::collections::BTreeMap;
use std::fmt::Write;
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::connection::DatabaseClient;
use crate::error::Result;
use crate::query::builder::Query;
use crate::query::executor::flatten_rows;
use crate::query::expressions::Expression;
use crate::query::results::{record, RecordResult};
use crate::types::operators::{Operator, OperatorExpr};
use crate::types::record_id::RecordID;
/// Creates a single record in `table` from the JSON payload `data`.
///
/// The payload is bound as the `$data` query variable rather than spliced into the
/// statement text. `table` itself *is* interpolated verbatim, so it must be a trusted
/// identifier (NOTE(review): no escaping is applied here).
///
/// The first flattened row returned by the `CREATE` statement (if any) is handed to
/// [`record`] together with a flag telling whether such a row existed.
///
/// # Errors
///
/// Propagates any error from the client's `query_with_vars` call.
pub async fn create_record(
    client: &DatabaseClient,
    table: &str,
    data: Value,
) -> Result<RecordResult<Value>> {
    let statement = format!("CREATE {table} CONTENT $data");
    let bindings = BTreeMap::from([("data".to_owned(), data)]);
    let response = client.query_with_vars(&statement, bindings).await?;
    let created = flatten_rows(&response).into_iter().next();
    let found = created.is_some();
    Ok(record(created, found))
}
/// Creates one record in `table` per element of `data`, issuing one `CREATE`
/// statement per element in input order.
///
/// Each element is bound as the `$data` query variable. Elements for which the
/// statement returns no row are skipped, so the returned vector may be shorter
/// than `data`. The statements are not wrapped in a transaction by this function.
///
/// # Errors
///
/// Returns the first query error encountered; records created by earlier
/// iterations are left in place.
pub async fn create_records(
    client: &DatabaseClient,
    table: &str,
    data: Vec<Value>,
) -> Result<Vec<Value>> {
    // The statement text is identical for every element, so render it once
    // instead of re-formatting it on each iteration.
    let surql = format!("CREATE {table} CONTENT $data");
    let mut out = Vec::with_capacity(data.len());
    for item in data {
        let mut vars = BTreeMap::new();
        vars.insert("data".to_owned(), item);
        let raw = client.query_with_vars(&surql, vars).await?;
        // Only the first returned row is kept for each CREATE.
        if let Some(row) = flatten_rows(&raw).into_iter().next() {
            out.push(row);
        }
    }
    Ok(out)
}
/// Fetches the record identified by `record_id` using `SELECT * FROM <id>`.
///
/// Returns the first flattened row, or `Ok(None)` when the query yields no rows.
///
/// # Errors
///
/// Propagates any error from the client's `query` call.
pub async fn get_record<T>(
    client: &DatabaseClient,
    record_id: &RecordID<T>,
) -> Result<Option<Value>> {
    let id = record_id.to_string();
    let response = client.query(&format!("SELECT * FROM {id}")).await?;
    let mut rows = flatten_rows(&response).into_iter();
    Ok(rows.next())
}
/// Runs `UPDATE <id> CONTENT $data` against the record identified by `record_id`.
///
/// Thin wrapper that renders the id to a string and delegates to
/// [`update_record_target`].
///
/// # Errors
///
/// Propagates any error from [`update_record_target`].
pub async fn update_record<T>(
    client: &DatabaseClient,
    record_id: &RecordID<T>,
    data: Value,
) -> Result<Value> {
    update_record_target(client, &record_id.to_string(), data).await
}
/// Runs `UPDATE <target> CONTENT $data`, where `target` is spliced verbatim into
/// the statement and `data` is bound as the `$data` variable.
///
/// Returns the first flattened row, or `Value::Null` if the statement produced none.
///
/// # Errors
///
/// Propagates any error from the client's `query_with_vars` call.
pub async fn update_record_target(
    client: &DatabaseClient,
    target: &str,
    data: Value,
) -> Result<Value> {
    let statement = format!("UPDATE {target} CONTENT $data");
    let bindings = BTreeMap::from([("data".to_owned(), data)]);
    let response = client.query_with_vars(&statement, bindings).await?;
    let updated = flatten_rows(&response).into_iter().next();
    Ok(updated.unwrap_or(Value::Null))
}
/// Runs `UPDATE <id> MERGE $patch` against the record identified by `record_id`,
/// binding `patch` as the `$patch` variable.
///
/// Returns the first flattened row, or `Value::Null` if the statement produced none.
///
/// # Errors
///
/// Propagates any error from the client's `query_with_vars` call.
pub async fn merge_record<T>(
    client: &DatabaseClient,
    record_id: &RecordID<T>,
    patch: Value,
) -> Result<Value> {
    let id = record_id.to_string();
    let bindings = BTreeMap::from([("patch".to_owned(), patch)]);
    let response = client
        .query_with_vars(&format!("UPDATE {id} MERGE $patch"), bindings)
        .await?;
    let merged = flatten_rows(&response).into_iter().next();
    Ok(merged.unwrap_or(Value::Null))
}
/// Runs `UPSERT <id> CONTENT $data` against the record identified by `record_id`.
///
/// Thin wrapper that renders the id to a string and delegates to
/// [`upsert_record_target`].
///
/// # Errors
///
/// Propagates any error from [`upsert_record_target`].
pub async fn upsert_record<T>(
    client: &DatabaseClient,
    record_id: &RecordID<T>,
    data: Value,
) -> Result<Value> {
    let rendered = record_id.to_string();
    upsert_record_target(client, rendered.as_str(), data).await
}
/// Runs `UPSERT <target> CONTENT $data`, where `target` is spliced verbatim into
/// the statement and `data` is bound as the `$data` variable.
///
/// Returns the first flattened row, or `Value::Null` if the statement produced none.
///
/// # Errors
///
/// Propagates any error from the client's `query_with_vars` call.
pub async fn upsert_record_target(
    client: &DatabaseClient,
    target: &str,
    data: Value,
) -> Result<Value> {
    let bindings = BTreeMap::from([("data".to_owned(), data)]);
    let response = client
        .query_with_vars(&format!("UPSERT {target} CONTENT $data"), bindings)
        .await?;
    let upserted = flatten_rows(&response).into_iter().next();
    Ok(upserted.unwrap_or(Value::Null))
}
/// Runs `DELETE <id>` for the record identified by `record_id`.
///
/// Whatever rows the statement returns are discarded; only errors are reported.
///
/// # Errors
///
/// Propagates any error from the client's `query` call.
pub async fn delete_record<T>(client: &DatabaseClient, record_id: &RecordID<T>) -> Result<()> {
    let id = record_id.to_string();
    let _ = client.query(&format!("DELETE {id}")).await?;
    Ok(())
}
/// Deletes rows from `table`, optionally filtered by `where_`, and returns how many
/// rows the statement reported.
///
/// The statement uses `RETURN BEFORE`, and the result is the number of flattened
/// rows in the response. Without a filter, the statement targets the whole table.
///
/// # Errors
///
/// Propagates any error from the client's `query` call.
pub async fn delete_records(
    client: &DatabaseClient,
    table: &str,
    where_: Option<&Operator>,
) -> Result<u64> {
    let surql = match where_ {
        Some(filter) => format!("DELETE {table} WHERE ({}) RETURN BEFORE", filter.to_surql()),
        None => format!("DELETE {table} RETURN BEFORE"),
    };
    let response = client.query(&surql).await?;
    let removed = flatten_rows(&response).len();
    // usize -> u64 is lossless on all targets Rust currently supports.
    Ok(removed as u64)
}
/// Executes `query` and deserializes every returned row into `T`.
///
/// Delegates to the executor's `fetch_all`.
///
/// # Errors
///
/// Propagates any error from `fetch_all`, including deserialization failures if it
/// reports them (NOTE(review): confirm against the executor).
pub async fn query_records<T: DeserializeOwned>(
    client: &DatabaseClient,
    query: &Query,
) -> Result<Vec<T>> {
    let rows: Vec<T> = super::executor::fetch_all(client, query).await?;
    Ok(rows)
}
/// Counts the rows in `table`, optionally filtered by `where_`.
///
/// Issues `SELECT count() FROM <table> [WHERE (...)] GROUP ALL` and reads the
/// integer `count` field of the first returned row. Returns `0` when there is no
/// row, or when the field is missing or not representable as an `i64`.
///
/// # Errors
///
/// Propagates any error from the client's `query` call.
pub async fn count_records(
    client: &DatabaseClient,
    table: &str,
    where_: Option<&Operator>,
) -> Result<i64> {
    let mut surql = format!("SELECT count() FROM {table}");
    if let Some(filter) = where_ {
        // fmt::Write into a String never fails; the Result exists only for the trait.
        write!(surql, " WHERE ({})", filter.to_surql()).expect("write to String cannot fail");
    }
    surql.push_str(" GROUP ALL");
    let response = client.query(&surql).await?;
    let total = flatten_rows(&response)
        .into_iter()
        .next()
        .and_then(|row| row.get("count").and_then(Value::as_i64));
    Ok(total.unwrap_or(0))
}
/// Reports whether the record identified by `record_id` exists.
///
/// Implemented on top of [`get_record`], so the full record is fetched and then
/// discarded.
///
/// # Errors
///
/// Propagates any error from [`get_record`].
pub async fn exists<T>(client: &DatabaseClient, record_id: &RecordID<T>) -> Result<bool> {
    let found = get_record(client, record_id).await?;
    Ok(found.is_some())
}
/// Returns the first row of `query` deserialized into `T`, or `None`.
///
/// If the query carries no limit, a copy of it is run with `LIMIT 1`. A limit the
/// caller already set is kept as-is. The original `query` is never modified.
///
/// # Errors
///
/// Propagates errors from `Query::limit` and from the executor's `fetch_one`.
pub async fn first<T: DeserializeOwned>(
    client: &DatabaseClient,
    query: &Query,
) -> Result<Option<T>> {
    let mut bounded = query.clone();
    if bounded.limit_value.is_none() {
        bounded = bounded.limit(1)?;
    }
    super::executor::fetch_one(client, &bounded).await
}
/// Returns the "last" row of `query` deserialized into `T`, or `None`.
///
/// Works on a copy of `query` in which every ORDER BY direction is flipped: an
/// `ASC` entry (compared case-insensitively) becomes `"DESC"`, and any other value
/// becomes `"ASC"`. `LIMIT 1` is added if no limit was set.
///
/// Caveat: if `query` has no ORDER BY entries there is nothing to flip, and the
/// result is the same row [`first`] would return.
///
/// # Errors
///
/// Propagates errors from `Query::limit` and from the executor's `fetch_one`.
pub async fn last<T: DeserializeOwned>(
    client: &DatabaseClient,
    query: &Query,
) -> Result<Option<T>> {
    let mut reversed = query.clone();
    for field in &mut reversed.order_fields {
        let flipped = if field.direction.eq_ignore_ascii_case("ASC") {
            "DESC"
        } else {
            "ASC"
        };
        field.direction = flipped.to_owned();
    }
    let reversed = if reversed.limit_value.is_some() {
        reversed
    } else {
        reversed.limit(1)?
    };
    super::executor::fetch_one(client, &reversed).await
}
/// Options consumed by [`build_aggregate_query`] and [`aggregate_records`].
///
/// `Default` produces an empty `select`, which `build_aggregate_query` rejects,
/// so at least `select` must be filled in.
#[derive(Debug, Clone, Default)]
pub struct AggregateOpts {
/// Projections as `(alias, expression)` pairs, each rendered as `<expr> AS <alias>`.
/// Must be non-empty.
pub select: Vec<(String, Expression)>,
/// Fields for `GROUP BY`. Ignored when `group_all` is `true`.
pub group_by: Vec<String>,
/// Optional filter, rendered with `to_surql()` and passed to `Query::where_str`.
pub where_: Option<Operator>,
/// When `true`, emits `GROUP ALL`; this takes precedence over `group_by`.
pub group_all: bool,
/// `(field, direction)` pairs passed, in order, to `Query::order_by`
/// (e.g. `("count", "DESC")`).
pub order_by: Vec<(String, String)>,
/// Optional row limit forwarded to `Query::limit`, which returns a `Result` and may
/// reject the value.
pub limit: Option<i64>,
}
/// Builds (but does not run) an aggregate `SELECT` over `table` from `opts`.
///
/// Clauses are applied in this order: projections (`<expr> AS <alias>`), `FROM`,
/// optional `WHERE`, grouping, each `ORDER BY` entry, then an optional `LIMIT`.
/// For grouping, `group_all` wins over `group_by`. An empty `group_by` with
/// `group_all == false` adds no grouping clause.
///
/// # Errors
///
/// Returns `SurqlError::Query` when `opts.select` is empty. Also propagates errors
/// from `Query::from_table`, `Query::order_by` and `Query::limit`.
pub fn build_aggregate_query(table: &str, opts: &AggregateOpts) -> Result<Query> {
    if opts.select.is_empty() {
        return Err(crate::error::SurqlError::Query {
            reason: "aggregate_records requires at least one select entry".into(),
        });
    }

    let projections = opts
        .select
        .iter()
        .map(|(alias, expr)| format!("{} AS {alias}", expr.to_surql()))
        .collect::<Vec<String>>();
    let mut query = Query::new().select(Some(projections)).from_table(table)?;

    if let Some(filter) = &opts.where_ {
        query = query.where_str(filter.to_surql());
    }

    query = match (opts.group_all, opts.group_by.is_empty()) {
        (true, _) => query.group_all(),
        (false, false) => query.group_by(opts.group_by.iter().cloned()),
        (false, true) => query,
    };

    query = opts
        .order_by
        .iter()
        .try_fold(query, |acc, (field, direction)| {
            acc.order_by(field.clone(), direction.clone())
        })?;

    if let Some(n) = opts.limit {
        query = query.limit(n)?;
    }
    Ok(query)
}
/// Builds an aggregate query with [`build_aggregate_query`], executes it, and
/// returns the flattened result rows.
///
/// # Errors
///
/// Propagates errors from [`build_aggregate_query`] and from the executor's
/// `execute_query`.
pub async fn aggregate_records(
    client: &DatabaseClient,
    table: &str,
    opts: AggregateOpts,
) -> Result<Vec<Value>> {
    let aggregate = build_aggregate_query(table, &opts)?;
    let response = super::executor::execute_query(client, &aggregate).await?;
    let rows = flatten_rows(&response);
    Ok(rows)
}
// Unit tests for the statement-building parts of this module. None of them
// talk to a database; the async CRUD helpers are not exercised here.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::operators::eq;
use serde_json::json;
// Re-renders the `delete_records` WHERE format string locally rather than calling
// `delete_records` (which needs a live client). It pins the `eq(..).to_surql()`
// output shape.
// NOTE(review): it would not catch a change to the format string inside
// `delete_records` itself.
#[test]
fn delete_records_renders_where_clause() {
let op = eq("status", "inactive");
let rendered = format!("DELETE user WHERE ({}) RETURN BEFORE", op.to_surql());
assert_eq!(
rendered,
"DELETE user WHERE (status = 'inactive') RETURN BEFORE"
);
}
// Sanity check on serde_json serialization of a payload of the kind bound as `$data`.
#[test]
fn json_payload_serializes_stably() {
let v = json!({"name": "Alice", "age": 30});
let rendered = serde_json::to_string(&v).unwrap();
assert!(rendered.contains("\"name\":\"Alice\""));
assert!(rendered.contains("\"age\":30"));
}
// An empty `select` (the `Default`) must be rejected with `SurqlError::Query`.
#[test]
fn build_aggregate_query_rejects_empty_select() {
let err = build_aggregate_query("memory_entry", &AggregateOpts::default());
assert!(matches!(err, Err(crate::error::SurqlError::Query { .. })));
}
// Projections are rendered as `<expr> AS <alias>` and `group_by` produces `GROUP BY`.
#[test]
fn build_aggregate_query_renders_select_group_by() {
use crate::query::expressions::{count_all, math_sum};
let opts = AggregateOpts {
select: vec![
("count".to_string(), count_all()),
("total".to_string(), math_sum("strength")),
],
group_by: vec!["network".into()],
..Default::default()
};
let q = build_aggregate_query("memory_entry", &opts).unwrap();
assert_eq!(
q.to_surql().unwrap(),
"SELECT count() AS count, math::sum(strength) AS total FROM memory_entry \
GROUP BY network",
);
}
// `group_all = true` produces `GROUP ALL`.
#[test]
fn build_aggregate_query_renders_group_all() {
use crate::query::expressions::{count_all, math_mean};
let opts = AggregateOpts {
select: vec![
("total".to_string(), count_all()),
("mean".to_string(), math_mean("strength")),
],
group_all: true,
..Default::default()
};
let q = build_aggregate_query("memory_entry", &opts).unwrap();
assert_eq!(
q.to_surql().unwrap(),
"SELECT count() AS total, math::mean(strength) AS mean FROM memory_entry GROUP ALL",
);
}
// WHERE (wrapped in parentheses by the rendered query), ORDER BY and LIMIT
// appear in that order.
#[test]
fn build_aggregate_query_renders_where_order_limit() {
use crate::query::expressions::count_all;
let opts = AggregateOpts {
select: vec![("count".to_string(), count_all())],
where_: Some(eq("status", "active")),
order_by: vec![("count".to_string(), "DESC".into())],
limit: Some(5),
..Default::default()
};
let q = build_aggregate_query("user", &opts).unwrap();
assert_eq!(
q.to_surql().unwrap(),
"SELECT count() AS count FROM user WHERE (status = 'active') \
ORDER BY count DESC LIMIT 5",
);
}
}