use std::collections::BTreeMap;
use std::fmt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::Result;
use super::builder::Query;
/// Keyword selecting what a statement hands back — presumably emitted into a
/// SurrealQL `RETURN` clause by the builder (the consumer is not shown here).
///
/// The serde representation uses the same uppercase spelling as
/// [`ReturnFormat::to_surql`] (`rename_all = "UPPERCASE"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum ReturnFormat {
    /// Rendered as `NONE`.
    None,
    /// Rendered as `DIFF`.
    Diff,
    /// Rendered as `FULL`.
    Full,
    /// Rendered as `BEFORE`.
    Before,
    /// Rendered as `AFTER`.
    After,
}

impl ReturnFormat {
    /// Returns the uppercase SurrealQL keyword for this variant.
    pub fn to_surql(self) -> &'static str {
        match self {
            ReturnFormat::None => "NONE",
            ReturnFormat::Diff => "DIFF",
            ReturnFormat::Full => "FULL",
            ReturnFormat::Before => "BEFORE",
            ReturnFormat::After => "AFTER",
        }
    }
}

impl fmt::Display for ReturnFormat {
    /// Writes the same keyword as [`ReturnFormat::to_surql`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` bypasses padding, so width/fill flags such as `{:>8}` are ignored.
        let keyword = Self::to_surql(*self);
        f.write_str(keyword)
    }
}
/// Distance metric used for vector search.
///
/// Each metric has two spellings:
/// - an uppercase keyword ([`VectorDistanceType::to_surql`]), which is also the
///   serde spelling (`rename_all = "UPPERCASE"`);
/// - a lowercase suffix ([`VectorDistanceType::as_func_suffix`]), presumably
///   appended to a SurrealQL vector-function path by the caller (not shown here).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum VectorDistanceType {
    Cosine,
    Euclidean,
    Manhattan,
    Minkowski,
    Chebyshev,
    Hamming,
    Jaccard,
    Pearson,
    Mahalanobis,
}

impl VectorDistanceType {
    /// Single source of truth: `(uppercase keyword, lowercase function suffix)`.
    fn keyword_pair(self) -> (&'static str, &'static str) {
        match self {
            VectorDistanceType::Cosine => ("COSINE", "cosine"),
            VectorDistanceType::Euclidean => ("EUCLIDEAN", "euclidean"),
            VectorDistanceType::Manhattan => ("MANHATTAN", "manhattan"),
            VectorDistanceType::Minkowski => ("MINKOWSKI", "minkowski"),
            VectorDistanceType::Chebyshev => ("CHEBYSHEV", "chebyshev"),
            VectorDistanceType::Hamming => ("HAMMING", "hamming"),
            VectorDistanceType::Jaccard => ("JACCARD", "jaccard"),
            VectorDistanceType::Pearson => ("PEARSON", "pearson"),
            VectorDistanceType::Mahalanobis => ("MAHALANOBIS", "mahalanobis"),
        }
    }

    /// Returns the uppercase SurrealQL keyword for this metric (e.g. `COSINE`).
    pub fn to_surql(self) -> &'static str {
        let (keyword, _) = self.keyword_pair();
        keyword
    }

    /// Returns the lowercase name of this metric (e.g. `cosine`).
    pub fn as_func_suffix(self) -> &'static str {
        let (_, suffix) = self.keyword_pair();
        suffix
    }
}

impl fmt::Display for VectorDistanceType {
    /// Writes the uppercase keyword from [`VectorDistanceType::to_surql`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(VectorDistanceType::to_surql(*self))
    }
}
/// Record data keyed by field name: a `BTreeMap` (keys kept in sorted order)
/// from field names to JSON values, as accepted by [`insert`], [`update`],
/// [`upsert`] and [`relate`].
pub type DataMap = BTreeMap<String, Value>;
/// Starts a new `SELECT` query.
///
/// `None` projects every field (`SELECT *`). `Some(fields)` projects only the
/// listed fields, e.g. `SELECT name, email`. Follow up with [`from_table`] (or
/// [`Query::from_table`]) to set the source table.
// `#[must_use]` carries a message so it does not trip clippy::double_must_use
// should `Query` itself already be `#[must_use]`.
#[must_use = "a query builder does nothing until it is rendered or executed"]
pub fn select(fields: Option<Vec<String>>) -> Query {
    Query::new().select(fields)
}
/// Sets the `FROM` target of `query`.
///
/// For example, `from_table(select(None), "user")` renders `SELECT * FROM user`.
///
/// # Errors
/// Propagates any error reported by [`Query::from_table`] for the given table.
pub fn from_table(query: Query, table: impl Into<String>) -> Result<Query> {
    let table_name: String = table.into();
    query.from_table(table_name)
}
/// Adds a `WHERE` condition to `query`.
///
/// The condition is raw SurrealQL text. `"age > 18"` renders as
/// `WHERE (age > 18)`, i.e. wrapped in parentheses. It is presumably not
/// escaped or parameterised (not visible here), so never build it from
/// untrusted input.
// `#[must_use]` carries a message so it does not trip clippy::double_must_use
// should `Query` itself already be `#[must_use]`.
#[must_use = "a query builder does nothing until it is rendered or executed"]
pub fn where_(query: Query, condition: impl Into<String>) -> Query {
    query.where_str(condition)
}
/// Adds an `ORDER BY <field> <direction>` clause to `query`.
///
/// For example, `order_by(q, "name", "ASC")` renders `... ORDER BY name ASC`.
///
/// # Errors
/// Propagates any error from [`Query::order_by`]. Presumably it rejects
/// directions it does not recognise; that check is not visible here.
pub fn order_by(
    query: Query,
    field: impl Into<String>,
    direction: impl Into<String>,
) -> Result<Query> {
    let sort_field: String = field.into();
    let sort_direction: String = direction.into();
    query.order_by(sort_field, sort_direction)
}
/// Caps the number of rows `query` returns, rendered as `LIMIT n`.
///
/// # Errors
/// Propagates any error from [`Query::limit`]. Presumably it rejects values
/// the builder considers invalid (e.g. negative); that is not visible here.
pub fn limit(query: Query, n: i64) -> Result<Query> {
    let row_cap = n;
    query.limit(row_cap)
}
/// Skips the first `n` rows of `query`'s result.
///
/// SurrealQL spells this clause `START`, so `offset(q, 20)` renders `... START 20`.
///
/// # Errors
/// Propagates any error from [`Query::offset`].
pub fn offset(query: Query, n: i64) -> Result<Query> {
    let skipped_rows = n;
    query.offset(skipped_rows)
}
/// Builds a statement that creates a record in `table` with `data` as its content.
///
/// Despite the name, this renders `CREATE <table> CONTENT {...}`, not `INSERT`.
///
/// # Errors
/// Propagates any error from [`Query::insert`].
pub fn insert(table: impl Into<String>, data: DataMap) -> Result<Query> {
    let builder = Query::new();
    builder.insert(table, data)
}
/// Builds an `UPDATE ... SET ...` statement assigning each entry of `data`.
///
/// For example, `{status: "active"}` on `user:alice` renders
/// `UPDATE user:alice SET status = 'active'`.
///
/// # Errors
/// Propagates any error from [`Query::update`].
pub fn update(target: impl Into<String>, data: DataMap) -> Result<Query> {
    let record: String = target.into();
    Query::new().update(record, data)
}
/// Builds an `UPSERT <target> CONTENT {...}` statement from `data`.
///
/// For example, `{status: "active"}` on `user:alice` renders
/// `UPSERT user:alice CONTENT {status: 'active'}`.
///
/// # Errors
/// Propagates any error from [`Query::upsert`].
pub fn upsert(target: impl Into<String>, data: DataMap) -> Result<Query> {
    let builder = Query::new();
    builder.upsert(target, data)
}
/// Builds a `DELETE <target>` statement, e.g. `DELETE user:alice`.
///
/// # Errors
/// Propagates any error from [`Query::delete`].
pub fn delete(target: impl Into<String>) -> Result<Query> {
    let record: String = target.into();
    Query::new().delete(record)
}
/// Builds a `RELATE <from>-><edge>-><to>` statement.
///
/// For example, `relate("likes", "user:alice", "post:123", None)` renders
/// `RELATE user:alice->likes->post:123`. When `data` is `Some`, it is passed on
/// to [`Query::relate`]. It presumably becomes the edge's content; only the
/// `None` case is visible here.
///
/// # Errors
/// Propagates any error from [`Query::relate`].
pub fn relate(
    edge_table: impl Into<String>,
    from_record: impl Into<String>,
    to_record: impl Into<String>,
    data: Option<DataMap>,
) -> Result<Query> {
    let builder = Query::new();
    builder.relate(edge_table, from_record, to_record, data)
}
/// Builds `SELECT <fields> FROM <table>` plus a k-nearest-neighbour condition on `field`.
///
/// With `fields = None`, `k = 10`, [`VectorDistanceType::Cosine`] and
/// `threshold = Some(0.7)` on a field named `embedding`, the output starts with
/// `SELECT * FROM <table>` and contains `embedding <|10,COSINE,0.7|>`.
///
/// Argument order: `fields` comes before `threshold` here, whereas
/// [`similarity_search_query`] takes `threshold` first.
///
/// # Errors
/// Propagates errors from [`Query::from_table`] and [`Query::vector_search`].
pub fn vector_search_query(
    table: impl Into<String>,
    field: impl Into<String>,
    vector: Vec<f64>,
    k: i64,
    distance: VectorDistanceType,
    fields: Option<Vec<String>>,
    threshold: Option<f64>,
) -> Result<Query> {
    let projection = Query::new().select(fields);
    let scoped = projection.from_table(table)?;
    scoped.vector_search(field, vector, k, distance, threshold)
}
/// Like [`vector_search_query`], but also requests a similarity score between
/// `field` and `vector` under the name `alias`, via [`Query::similarity_score`].
/// The score is presumably added as an extra projected column (not visible here).
///
/// Argument order: `threshold` comes before `fields` here, unlike
/// [`vector_search_query`].
///
/// # Errors
/// Propagates errors from [`Query::from_table`] and [`Query::vector_search`].
// Eight positional parameters mirror the builder calls this composes; kept flat
// for API compatibility instead of introducing an options struct.
#[allow(clippy::too_many_arguments)]
pub fn similarity_search_query(
    table: impl Into<String>,
    field: impl Into<String>,
    vector: Vec<f64>,
    k: i64,
    distance: VectorDistanceType,
    threshold: Option<f64>,
    fields: Option<Vec<String>>,
    alias: impl Into<String>,
) -> Result<Query> {
    // The field name and vector are used twice: first borrowed for the score,
    // then moved into the KNN condition.
    let field_name: String = field.into();
    let scoped = Query::new().select(fields).from_table(table)?;
    let scored = scoped.similarity_score(&field_name, &vector, distance, alias);
    scored.vector_search(field_name, vector, k, distance, threshold)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ReturnFormat` variant, so per-variant checks cannot silently skip one.
    const ALL_RETURN_FORMATS: [ReturnFormat; 5] = [
        ReturnFormat::None,
        ReturnFormat::Diff,
        ReturnFormat::Full,
        ReturnFormat::Before,
        ReturnFormat::After,
    ];

    /// Every `VectorDistanceType` variant.
    const ALL_DISTANCES: [VectorDistanceType; 9] = [
        VectorDistanceType::Cosine,
        VectorDistanceType::Euclidean,
        VectorDistanceType::Manhattan,
        VectorDistanceType::Minkowski,
        VectorDistanceType::Chebyshev,
        VectorDistanceType::Hamming,
        VectorDistanceType::Jaccard,
        VectorDistanceType::Pearson,
        VectorDistanceType::Mahalanobis,
    ];

    #[test]
    fn return_format_to_surql() {
        assert_eq!(ReturnFormat::None.to_surql(), "NONE");
        assert_eq!(ReturnFormat::Diff.to_surql(), "DIFF");
        assert_eq!(ReturnFormat::Full.to_surql(), "FULL");
        assert_eq!(ReturnFormat::Before.to_surql(), "BEFORE");
        assert_eq!(ReturnFormat::After.to_surql(), "AFTER");
    }

    #[test]
    fn return_format_display_matches_surql() {
        assert_eq!(ReturnFormat::Diff.to_string(), "DIFF");
    }

    #[test]
    fn return_format_display_matches_surql_for_all_variants() {
        for variant in ALL_RETURN_FORMATS.iter().copied() {
            assert_eq!(variant.to_string(), variant.to_surql());
        }
    }

    #[test]
    fn return_format_serde_spelling_matches_surql() {
        // serde's `rename_all` and `to_surql` are two independent spelling
        // tables; this keeps them from drifting apart.
        for variant in ALL_RETURN_FORMATS.iter().copied() {
            let json = serde_json::to_value(variant).unwrap();
            assert_eq!(json, Value::String(variant.to_surql().to_owned()));
            let back: ReturnFormat = serde_json::from_value(json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn vector_distance_uppercase() {
        assert_eq!(VectorDistanceType::Cosine.to_surql(), "COSINE");
        assert_eq!(VectorDistanceType::Euclidean.to_surql(), "EUCLIDEAN");
        assert_eq!(VectorDistanceType::Manhattan.to_surql(), "MANHATTAN");
        assert_eq!(VectorDistanceType::Minkowski.to_surql(), "MINKOWSKI");
        assert_eq!(VectorDistanceType::Chebyshev.to_surql(), "CHEBYSHEV");
        assert_eq!(VectorDistanceType::Hamming.to_surql(), "HAMMING");
        assert_eq!(VectorDistanceType::Jaccard.to_surql(), "JACCARD");
        assert_eq!(VectorDistanceType::Pearson.to_surql(), "PEARSON");
        assert_eq!(VectorDistanceType::Mahalanobis.to_surql(), "MAHALANOBIS");
    }

    #[test]
    fn vector_distance_func_suffix_is_lowercase() {
        assert_eq!(VectorDistanceType::Cosine.as_func_suffix(), "cosine");
        assert_eq!(VectorDistanceType::Euclidean.as_func_suffix(), "euclidean");
    }

    #[test]
    fn vector_distance_display_and_suffix_match_surql_for_all_variants() {
        for distance in ALL_DISTANCES.iter().copied() {
            assert_eq!(distance.to_string(), distance.to_surql());
            assert_eq!(
                distance.as_func_suffix(),
                distance.to_surql().to_ascii_lowercase()
            );
        }
    }

    #[test]
    fn vector_distance_serde_spelling_matches_surql() {
        for distance in ALL_DISTANCES.iter().copied() {
            let json = serde_json::to_value(distance).unwrap();
            assert_eq!(json, Value::String(distance.to_surql().to_owned()));
            let back: VectorDistanceType = serde_json::from_value(json).unwrap();
            assert_eq!(back, distance);
        }
    }

    #[test]
    fn select_helper_is_star_by_default() {
        let q = select(None).from_table("user").unwrap();
        assert_eq!(q.to_surql().unwrap(), "SELECT * FROM user");
    }

    #[test]
    fn select_helper_projects_fields() {
        let q = select(Some(vec!["name".into(), "email".into()]))
            .from_table("user")
            .unwrap();
        assert_eq!(q.to_surql().unwrap(), "SELECT name, email FROM user");
    }

    #[test]
    fn from_table_helper_sets_table() {
        let q = from_table(select(None), "user").unwrap();
        assert_eq!(q.to_surql().unwrap(), "SELECT * FROM user");
    }

    #[test]
    fn where_helper_adds_condition() {
        let q = where_(select(None).from_table("user").unwrap(), "age > 18");
        assert_eq!(q.to_surql().unwrap(), "SELECT * FROM user WHERE (age > 18)");
    }

    #[test]
    fn order_by_helper() {
        let q = order_by(select(None).from_table("user").unwrap(), "name", "ASC").unwrap();
        assert_eq!(
            q.to_surql().unwrap(),
            "SELECT * FROM user ORDER BY name ASC"
        );
    }

    #[test]
    fn limit_helper() {
        let q = limit(select(None).from_table("user").unwrap(), 10).unwrap();
        assert_eq!(q.to_surql().unwrap(), "SELECT * FROM user LIMIT 10");
    }

    #[test]
    fn offset_helper_renders_start() {
        let q = offset(select(None).from_table("user").unwrap(), 20).unwrap();
        assert_eq!(q.to_surql().unwrap(), "SELECT * FROM user START 20");
    }

    #[test]
    fn insert_helper_constructs_query() {
        let mut data = DataMap::new();
        data.insert("name".into(), Value::String("Alice".into()));
        let q = insert("user", data).unwrap();
        let sql = q.to_surql().unwrap();
        assert!(sql.starts_with("CREATE user CONTENT"));
        assert!(sql.contains("name: 'Alice'"));
    }

    #[test]
    fn update_helper_constructs_query() {
        let mut data = DataMap::new();
        data.insert("status".into(), Value::String("active".into()));
        let q = update("user:alice", data).unwrap();
        assert_eq!(
            q.to_surql().unwrap(),
            "UPDATE user:alice SET status = 'active'"
        );
    }

    #[test]
    fn upsert_helper_constructs_query() {
        let mut data = DataMap::new();
        data.insert("status".into(), Value::String("active".into()));
        let q = upsert("user:alice", data).unwrap();
        assert_eq!(
            q.to_surql().unwrap(),
            "UPSERT user:alice CONTENT {status: 'active'}"
        );
    }

    #[test]
    fn delete_helper_constructs_query() {
        let q = delete("user:alice").unwrap();
        assert_eq!(q.to_surql().unwrap(), "DELETE user:alice");
    }

    #[test]
    fn relate_helper_constructs_query() {
        let q = relate("likes", "user:alice", "post:123", None).unwrap();
        assert_eq!(q.to_surql().unwrap(), "RELATE user:alice->likes->post:123");
    }

    #[test]
    fn vector_search_query_helper() {
        let q = vector_search_query(
            "documents",
            "embedding",
            vec![0.1, 0.2, 0.3],
            10,
            VectorDistanceType::Cosine,
            None,
            Some(0.7),
        )
        .unwrap();
        let sql = q.to_surql().unwrap();
        assert!(sql.starts_with("SELECT * FROM documents"));
        assert!(sql.contains("embedding <|10,COSINE,0.7|>"));
    }
}