#![cfg(any(feature = "client", feature = "client-rustls"))]
use std::collections::BTreeMap;
use std::fmt::Write as _;
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::connection::DatabaseClient;
use crate::error::{Result, SurqlError};
use super::executor::{extract_rows, flatten_rows};
/// Direction in which a graph edge is followed when building a SurrealQL path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Follow edges leaving the record (`->`).
    Out,
    /// Follow edges arriving at the record (`<-`).
    In,
    /// Follow edges in either direction (`<->`).
    Both,
}

impl Direction {
    /// Returns the SurrealQL arrow token used to render this direction in a path.
    fn arrow(self) -> &'static str {
        match self {
            Direction::Both => "<->",
            Direction::In => "<-",
            Direction::Out => "->",
        }
    }
}
/// Runs `SELECT * FROM {start}{path}` and converts the response rows into `T`.
///
/// `start` and `path` are concatenated verbatim into the query text with no
/// escaping, so callers must only pass trusted record ids / path fragments.
///
/// # Errors
///
/// Propagates any error from the query, and any error `extract_rows` reports
/// while converting the response (presumably a deserialization failure).
pub async fn traverse<T: DeserializeOwned>(
    client: &DatabaseClient,
    start: &str,
    path: &str,
) -> Result<Vec<T>> {
    let response = client
        .query(&format!("SELECT * FROM {start}{path}"))
        .await?;
    extract_rows::<T>(&response)
}
/// Follows `edge_table` from `start` to `target_table` in `direction`, converting
/// each resulting row into `T`.
///
/// The path is rendered as `{arrow}{edge_table}{depth}{arrow}{target_table}` and
/// handed to [`traverse`]. When `depth` is `None`, no depth suffix is written.
///
/// NOTE(review): the depth is appended directly to the edge table name, so
/// `Some(2)` with edge `follows` renders `->follows2->user`. That looks like it
/// addresses a table literally named `follows2` rather than bounding recursion
/// depth. Confirm against the SurrealQL graph-traversal docs and the Python
/// implementation this mirrors.
///
/// # Errors
///
/// Propagates any error from [`traverse`].
pub async fn traverse_with_depth<T: DeserializeOwned>(
    client: &DatabaseClient,
    start: &str,
    edge_table: &str,
    target_table: &str,
    direction: Direction,
    depth: Option<u32>,
) -> Result<Vec<T>> {
    let hop = direction.arrow();
    let suffix = depth.map(|d| d.to_string()).unwrap_or_default();
    let path = format!("{hop}{edge_table}{suffix}{hop}{target_table}");
    traverse(client, start, &path).await
}
/// Runs `SELECT * FROM {start}{path}` and returns the flattened rows as raw JSON.
///
/// Unlike [`traverse`], rows are not deserialized. `start` and `path` are
/// interpolated verbatim into the query text.
///
/// # Errors
///
/// Propagates any error from the underlying query.
pub async fn traverse_raw(client: &DatabaseClient, start: &str, path: &str) -> Result<Vec<Value>> {
    let response = client
        .query(&format!("SELECT * FROM {start}{path}"))
        .await?;
    Ok(flatten_rows(&response))
}
/// Creates an `edge_table` edge from `from_record` to `to_record` with `RELATE`.
///
/// When `data` is `Some`, it is bound as the `$data` query variable and applied
/// with `CONTENT $data`. When `data` is `None`, a bare `RELATE` is issued.
/// `edge_table`, `from_record` and `to_record` are interpolated verbatim into the
/// query text, so they must come from trusted input.
///
/// Returns the first flattened row of the response, or `Value::Null` if the
/// response contained no rows.
///
/// # Errors
///
/// Propagates any error from the underlying query.
pub async fn create_relation(
    client: &DatabaseClient,
    edge_table: &str,
    from_record: &str,
    to_record: &str,
    data: Option<Value>,
) -> Result<Value> {
    let relate = format!("RELATE {from_record}->{edge_table}->{to_record}");
    // Decide the query text and the way it is sent in a single branch, so the
    // `CONTENT $data` clause is only ever emitted alongside a bound `$data`.
    let raw = match data {
        Some(payload) => {
            let surql = format!("{relate} CONTENT $data");
            let mut vars = BTreeMap::new();
            vars.insert("data".to_owned(), payload);
            client.query_with_vars(&surql, vars).await?
        }
        None => client.query(&relate).await?,
    };
    Ok(flatten_rows(&raw).into_iter().next().unwrap_or(Value::Null))
}
/// Deletes the `edge_table` edge rows that go from `from_record` to `to_record`.
///
/// `edge_table`, `from_record` and `to_record` are interpolated verbatim into the
/// query text, so they must come from trusted input.
///
/// # Errors
///
/// Propagates any error from the underlying query.
pub async fn remove_relation(
    client: &DatabaseClient,
    edge_table: &str,
    from_record: &str,
    to_record: &str,
) -> Result<()> {
    // A path ending in `->{to_record}` has the shape this module uses to address
    // the far-end records of a traversal (see `get_related_records`), not the
    // edges themselves. Instead, select the outgoing edge rows of `from_record`
    // and filter on their `out` side.
    let surql = format!("DELETE {from_record}->{edge_table} WHERE out = {to_record}");
    client.query(&surql).await?;
    Ok(())
}
/// Returns the flattened rows of `SELECT * FROM {record}->{edge_table}`, i.e. the
/// `edge_table` edges leaving `record`.
///
/// `record` and `edge_table` are interpolated verbatim into the query text.
///
/// # Errors
///
/// Propagates any error from the underlying query.
pub async fn get_outgoing_edges(
    client: &DatabaseClient,
    record: &str,
    edge_table: &str,
) -> Result<Vec<Value>> {
    let response = client
        .query(&format!("SELECT * FROM {record}->{edge_table}"))
        .await?;
    Ok(flatten_rows(&response))
}
/// Returns the flattened rows of `SELECT * FROM {record}<-{edge_table}`, i.e. the
/// `edge_table` edges arriving at `record`.
///
/// `record` and `edge_table` are interpolated verbatim into the query text.
///
/// # Errors
///
/// Propagates any error from the underlying query.
pub async fn get_incoming_edges(
    client: &DatabaseClient,
    record: &str,
    edge_table: &str,
) -> Result<Vec<Value>> {
    let response = client
        .query(&format!("SELECT * FROM {record}<-{edge_table}"))
        .await?;
    Ok(flatten_rows(&response))
}
/// Returns the `target_table` records reached from `record` through `edge_table`.
///
/// Renders `{record}->{edge_table}->{target_table}` for [`Direction::Out`] and
/// `{record}<-{edge_table}<-{target_table}` for [`Direction::In`]. All
/// identifiers are interpolated verbatim into the query text.
///
/// # Errors
///
/// Returns `SurqlError::Validation` for [`Direction::Both`] without querying.
/// Otherwise propagates any error from the underlying query.
pub async fn get_related_records(
    client: &DatabaseClient,
    record: &str,
    edge_table: &str,
    target_table: &str,
    direction: Direction,
) -> Result<Vec<Value>> {
    let arrow = match direction {
        Direction::Out => "->",
        Direction::In => "<-",
        Direction::Both => {
            return Err(SurqlError::Validation {
                reason: "get_related_records direction must be Out or In".to_string(),
            });
        }
    };
    let response = client
        .query(&format!(
            "SELECT * FROM {record}{arrow}{edge_table}{arrow}{target_table}"
        ))
        .await?;
    Ok(flatten_rows(&response))
}
/// Counts the `edge_table` edges attached to `record` in one direction.
///
/// Issues `SELECT count() FROM {record}->{edge_table} GROUP ALL` (or `<-` for
/// [`Direction::In`]) and reads the `count` field of the first row. Returns `0`
/// when there is no row, or when the field is missing or not representable as
/// `i64`.
///
/// # Errors
///
/// Returns `SurqlError::Validation` for [`Direction::Both`] without querying.
/// Otherwise propagates any error from the underlying query.
pub async fn count_related(
    client: &DatabaseClient,
    record: &str,
    edge_table: &str,
    direction: Direction,
) -> Result<i64> {
    let arrow = match direction {
        Direction::Out => "->",
        Direction::In => "<-",
        Direction::Both => {
            return Err(SurqlError::Validation {
                reason: "count_related direction must be Out or In".to_string(),
            });
        }
    };
    let surql = format!("SELECT count() FROM {record}{arrow}{edge_table} GROUP ALL");
    let raw = client.query(&surql).await?;
    let count = flatten_rows(&raw)
        .first()
        .and_then(|row| row.get("count"))
        .and_then(Value::as_i64)
        .unwrap_or(0);
    Ok(count)
}
/// Searches outward from `from_record` along `edge_table` edges for `to_record`,
/// trying one depth level at a time, up to `max_depth` hops.
///
/// Each hop is rendered as `->{edge_table}->?` (an edge followed by a record of
/// any table). Depths are tried in increasing order. The rows returned are those
/// of the first depth whose traversal contains a record with `id = to_record`.
/// They are the matching target record rows (the query uses `LIMIT 1`), not the
/// edges along the path.
///
/// Returns an empty `Vec` if no depth up to `max_depth` reaches `to_record`.
/// When `max_depth` is `0`, no query is issued at all. Otherwise up to
/// `max_depth` queries are issued sequentially. All identifiers are
/// interpolated verbatim into the query text.
///
/// # Errors
///
/// Propagates the first error returned by any of the underlying queries.
pub async fn shortest_path(
    client: &DatabaseClient,
    from_record: &str,
    to_record: &str,
    edge_table: &str,
    max_depth: u32,
) -> Result<Vec<Value>> {
    let mut path = String::new();
    for _ in 0..max_depth {
        // The path for depth d is the path for depth d-1 plus one hop. Extend it
        // in place instead of re-rendering every hop on each iteration.
        write!(path, "->{edge_table}->?").expect("write to String cannot fail");
        let surql = format!("SELECT * FROM {from_record}{path} WHERE id = {to_record} LIMIT 1");
        let raw = client.query(&surql).await?;
        let rows = flatten_rows(&raw);
        if !rows.is_empty() {
            return Ok(rows);
        }
    }
    Ok(Vec::new())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each direction renders the same arrow token as the Python implementation.
    #[test]
    fn direction_arrow_matches_py_semantics() {
        let cases = [
            (Direction::Out, "->"),
            (Direction::In, "<-"),
            (Direction::Both, "<->"),
        ];
        for (direction, expected) in cases {
            assert_eq!(direction.arrow(), expected);
        }
    }

    /// Mirrors the query shape `traverse` builds: the path is appended to the start
    /// record with no separator.
    #[test]
    fn traverse_path_is_plain_append() {
        let query = format!("SELECT * FROM {}{}", "user:alice", "->likes->post");
        assert_eq!(query, "SELECT * FROM user:alice->likes->post");
    }

    /// Mirrors the path rendering in `traverse_with_depth`: the depth is written
    /// directly after the edge table name.
    #[test]
    fn traverse_with_depth_renders_depth_suffix() {
        let arrow = Direction::Out.arrow();
        let suffix = Some(2u32).map(|d| d.to_string()).unwrap_or_default();
        let rendered = format!("{arrow}follows{suffix}{arrow}user");
        assert_eq!(rendered, "->follows2->user");
    }

    /// Pins only the variant classification used by `count_related`. The function
    /// itself is not called here, because it needs a live `DatabaseClient`.
    #[test]
    fn count_related_rejects_both_direction() {
        let accepted = matches!(Direction::Both, Direction::Out | Direction::In);
        assert!(!accepted);
    }

    /// Mirrors the hop rendering in `shortest_path`: one `->edge->?` per depth level.
    #[test]
    fn shortest_path_renders_chained_wildcard_edges() {
        let edge_table = "follows";
        let rendered: String = (0..3)
            .map(|_| format!("->{edge_table}->?"))
            .collect();
        assert_eq!(rendered, "->follows->?->follows->?->follows->?");
    }
}