use std::sync::Arc;
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use velesdb_core::api_types::serde_id;
use velesdb_core::collection::graph::GraphEdge;
use velesdb_core::point::Point;
use crate::types::ErrorResponse;
use crate::AppState;
use super::super::helpers::{error_response, get_collection_or_404};
/// Payload key under which [`set_point_ttl`] records a point's absolute expiry
/// time, as whole seconds since the UNIX epoch.
///
/// NOTE(review): the code that reads this key and actually drops expired points
/// is not in this file — confirm where expiry is enforced.
const EXPIRES_AT_KEY: &str = "_veles_expires_at";
/// Request body for `POST /collections/{name}/relations`.
#[derive(Debug, Deserialize, ToSchema)]
pub struct RelateRequest {
    /// ID of the point the relation starts from (decoded by
    /// `serde_id::deserialize_id_from_string_or_number`, i.e. presumably a JSON
    /// string or number).
    #[serde(deserialize_with = "serde_id::deserialize_id_from_string_or_number")]
    #[cfg_attr(feature = "openapi", schema(schema_with = serde_id::id_input_schema))]
    pub source: u64,
    /// ID of the point the relation points to; same encoding as `source`.
    #[serde(deserialize_with = "serde_id::deserialize_id_from_string_or_number")]
    #[cfg_attr(feature = "openapi", schema(schema_with = serde_id::id_input_schema))]
    pub target: u64,
    /// Relation label stored on the edge.
    pub rel_type: String,
    /// Edge properties: must be a JSON object or `null` (other JSON types are
    /// rejected with `400 Bad Request`). Omitting the field means `null`.
    #[serde(default)]
    pub properties: serde_json::Value,
}
/// Response body sent with `201 Created` after a relation is created.
#[derive(Debug, Serialize, ToSchema)]
pub struct RelateResponse {
    /// Server-assigned ID of the new edge, serialized as a JSON string
    /// (presumably so large IDs survive clients that parse numbers as f64).
    #[serde(serialize_with = "serde_id::serialize_id_as_string")]
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub edge_id: u64,
}
/// One outgoing relation of a point, as returned by the relations listing.
#[derive(Debug, Serialize, ToSchema)]
pub struct RelationEdge {
    /// Edge ID, serialized as a JSON string.
    #[serde(serialize_with = "serde_id::serialize_id_as_string")]
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub id: u64,
    /// ID of the point the edge starts from, serialized as a JSON string.
    #[serde(serialize_with = "serde_id::serialize_id_as_string")]
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub source: u64,
    /// ID of the point the edge points to, serialized as a JSON string.
    #[serde(serialize_with = "serde_id::serialize_id_as_string")]
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub target: u64,
    /// Relation label of the edge.
    pub rel_type: String,
    /// Edge properties as JSON.
    pub properties: serde_json::Value,
}
/// Response body listing a point's outgoing relations.
#[derive(Debug, Serialize, ToSchema)]
pub struct RelationsResponse {
    /// The outgoing edges.
    pub edges: Vec<RelationEdge>,
    /// Number of entries in `edges`.
    pub count: usize,
}
/// Request body for `PATCH /collections/{name}/points/{id}/ttl`.
#[derive(Debug, Deserialize, ToSchema)]
pub struct SetTtlRequest {
    /// Seconds from now until the point expires; added to the current time to
    /// form the stored absolute expiry timestamp.
    pub ttl_seconds: u64,
}
/// Create a directed relation (graph edge) between two points.
///
/// The edge runs from `source` to `target` with label `rel_type`. `properties`
/// must be a JSON object (stored as edge properties) or `null`/absent (no
/// properties). The server allocates the edge ID and returns it with
/// `201 Created`.
#[utoipa::path(
    post,
    path = "/collections/{name}/relations",
    params(("name" = String, Path, description = "Collection name")),
    request_body = RelateRequest,
    responses(
        (status = 201, description = "Relation created", body = RelateResponse),
        (status = 400, description = "Invalid request", body = ErrorResponse),
        (status = 404, description = "Collection not found", body = ErrorResponse),
        (status = 500, description = "Internal server error", body = ErrorResponse)
    ),
    tag = "graph"
)]
pub async fn relate_points(
    Path(name): Path<String>,
    State(state): State<Arc<AppState>>,
    Json(mut req): Json<RelateRequest>,
) -> axum::response::Response {
    let coll = match get_collection_or_404(&state, &name) {
        Ok(c) => c,
        Err(r) => return r,
    };
    // Move the properties out of the request rather than deep-cloning every
    // key and value. `insert_edge_with_retry` only reads `source`, `target`
    // and `rel_type`, so the `Value::Null` left behind is never observed.
    let properties: std::collections::HashMap<String, serde_json::Value> =
        match std::mem::take(&mut req.properties) {
            serde_json::Value::Object(map) => map.into_iter().collect(),
            serde_json::Value::Null => std::collections::HashMap::new(),
            _ => {
                return error_response(
                    StatusCode::BAD_REQUEST,
                    "properties must be an object or null".to_string(),
                )
            }
        };
    match insert_edge_with_retry(&coll, &req, properties) {
        Ok(edge_id) => (StatusCode::CREATED, Json(RelateResponse { edge_id })).into_response(),
        Err(r) => r,
    }
}
/// Insert an edge for `req` under the lowest free edge ID above the
/// collection's current maximum, and return that ID.
///
/// Candidate IDs start at `max_edge_id() + 1` (or `1` when `max_edge_id`
/// yields no value). A candidate is skipped when it is already taken, either
/// because `edge_exists` reports it up front or because `add_edge` returns
/// `Error::EdgeExists` (for example, another writer claimed it after the check).
///
/// # Errors
///
/// Returns a ready-to-send error response:
/// - `400 Bad Request` when `GraphEdge::new` rejects the edge definition;
/// - `500 Internal Server Error` when no ID is left to try (`u64::MAX` is
///   taken) or `add_edge` fails for any other reason.
// `Response` is a large error type, but it is handed straight back to the
// handler; boxing it would only add an allocation on the error path.
#[allow(clippy::result_large_err)]
fn insert_edge_with_retry(
    coll: &velesdb_core::collection::AnyCollection,
    req: &RelateRequest,
    properties: std::collections::HashMap<String, serde_json::Value>,
) -> Result<u64, axum::response::Response> {
    /// Error returned when the ID counter would have to wrap past `u64::MAX`.
    fn id_space_exhausted() -> axum::response::Response {
        error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            "failed to create relation: edge id space exhausted".to_string(),
        )
    }

    // `checked_add` rather than `saturating_add`. With saturation, a taken
    // `u64::MAX` would pin the candidate there forever and spin this loop
    // without end.
    let mut next_id = coll
        .max_edge_id()
        .map_or(Some(1), |max| max.checked_add(1))
        .ok_or_else(id_space_exhausted)?;
    loop {
        if coll.edge_exists(next_id) {
            next_id = next_id.checked_add(1).ok_or_else(id_space_exhausted)?;
            continue;
        }
        let edge = match GraphEdge::new(next_id, req.source, req.target, &req.rel_type) {
            Ok(e) => e.with_properties(properties.clone()),
            Err(e) => {
                return Err(error_response(
                    StatusCode::BAD_REQUEST,
                    format!("invalid edge: {e}"),
                ))
            }
        };
        match coll.add_edge(edge) {
            Ok(()) => return Ok(next_id),
            // Lost a race for this ID: move on to the next candidate.
            Err(velesdb_core::Error::EdgeExists(_)) => {
                next_id = next_id.checked_add(1).ok_or_else(id_space_exhausted)?;
            }
            Err(e) => {
                return Err(error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("failed to create relation: {e}"),
                ))
            }
        }
    }
}
#[utoipa::path(
delete,
path = "/collections/{name}/relations/{edge_id}",
params(
("name" = String, Path, description = "Collection name"),
("edge_id" = u64, Path, description = "Edge ID to remove")
),
responses(
(status = 204, description = "Relation removed"),
(status = 404, description = "Collection or edge not found", body = ErrorResponse),
(status = 500, description = "Internal server error", body = ErrorResponse)
),
tag = "graph"
)]
pub async fn unrelate_points(
Path((name, edge_id)): Path<(String, u64)>,
State(state): State<Arc<AppState>>,
) -> axum::response::Response {
let coll = match get_collection_or_404(&state, &name) {
Ok(c) => c,
Err(r) => return r,
};
if coll.remove_edge(edge_id) {
StatusCode::NO_CONTENT.into_response()
} else {
let err = velesdb_core::Error::EdgeNotFound(edge_id);
(
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("{err} in collection '{name}'"),
code: Some(err.code().to_string()),
}),
)
.into_response()
}
}
/// List the outgoing relations of point `id` in collection `name`.
///
/// Each edge returned by `get_outgoing_edges` becomes one [`RelationEdge`], and
/// `count` is the number of edges listed. The point's existence is not checked
/// here. NOTE(review): an unknown point ID presumably yields an empty list —
/// confirm against `get_outgoing_edges`.
#[utoipa::path(
    get,
    path = "/collections/{name}/points/{id}/relations",
    params(
        ("name" = String, Path, description = "Collection name"),
        ("id" = u64, Path, description = "Point ID")
    ),
    responses(
        (status = 200, description = "Outgoing relations", body = RelationsResponse),
        (status = 404, description = "Collection not found", body = ErrorResponse),
        (status = 500, description = "Internal server error", body = ErrorResponse)
    ),
    tag = "graph"
)]
pub async fn get_point_relations(
    Path((name, id)): Path<(String, u64)>,
    State(state): State<Arc<AppState>>,
) -> axum::response::Response {
    let collection = match get_collection_or_404(&state, &name) {
        Ok(found) => found,
        Err(missing) => return missing,
    };
    let mut edges: Vec<RelationEdge> = Vec::new();
    for edge in collection.get_outgoing_edges(id) {
        // A property map that fails to encode degrades to JSON `null` instead
        // of failing the whole listing.
        let properties = serde_json::to_value(edge.properties()).unwrap_or_default();
        edges.push(RelationEdge {
            id: edge.id(),
            source: edge.source(),
            target: edge.target(),
            rel_type: edge.label().to_string(),
            properties,
        });
    }
    let count = edges.len();
    Json(RelationsResponse { edges, count }).into_response()
}
#[utoipa::path(
patch,
path = "/collections/{name}/points/{id}/ttl",
params(
("name" = String, Path, description = "Collection name"),
("id" = u64, Path, description = "Point ID")
),
request_body = SetTtlRequest,
responses(
(status = 204, description = "TTL set successfully"),
(status = 400, description = "Non-object payload", body = ErrorResponse),
(status = 404, description = "Collection or point not found", body = ErrorResponse),
(status = 500, description = "Internal server error", body = ErrorResponse)
),
tag = "points"
)]
pub async fn set_point_ttl(
Path((name, id)): Path<(String, u64)>,
State(state): State<Arc<AppState>>,
Json(req): Json<SetTtlRequest>,
) -> axum::response::Response {
let coll = match get_collection_or_404(&state, &name) {
Ok(c) => c,
Err(r) => return r,
};
let point = match coll.get(&[id]).into_iter().flatten().next() {
Some(p) => p,
None => {
let err = velesdb_core::Error::PointNotFound(id);
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("{err} in collection '{name}'"),
code: Some(err.code().to_string()),
}),
)
.into_response();
}
};
let expires_at = now_secs().saturating_add(req.ttl_seconds);
let updated = match stamp_ttl(point, id, expires_at, &name) {
Ok(p) => p,
Err(r) => return r,
};
match coll.upsert(vec![updated]) {
Ok(()) => StatusCode::NO_CONTENT.into_response(),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to set TTL: {e}"),
),
}
}
/// Return `point` with its payload's `_veles_expires_at` entry set to
/// `expires_at`.
///
/// A missing payload is treated as an empty JSON object. The returned point
/// carries the `id` argument, while its vector and sparse vectors are taken
/// from `point`.
///
/// # Errors
///
/// Returns a `400 Bad Request` response when the existing payload is not a
/// JSON object. `collection` is used only in that error message.
// `Response` is a large error type, but it is returned straight to the
// calling handler; boxing it would only add an allocation on the error path.
#[allow(clippy::result_large_err)]
fn stamp_ttl(
    point: Point,
    id: u64,
    expires_at: u64,
    collection: &str,
) -> Result<Point, axum::response::Response> {
    let Point {
        vector,
        payload,
        sparse_vectors,
        ..
    } = point;
    let mut payload =
        payload.unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
    match payload.as_object_mut() {
        Some(fields) => {
            fields.insert(
                EXPIRES_AT_KEY.to_string(),
                serde_json::Value::from(expires_at),
            );
        }
        None => {
            return Err(error_response(
                StatusCode::BAD_REQUEST,
                format!("point {id} in '{collection}' has a non-object payload"),
            ));
        }
    }
    Ok(Point {
        id,
        vector,
        payload: Some(payload),
        sparse_vectors,
    })
}
/// Current wall-clock time as whole seconds since the UNIX epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    // `UNIX_EPOCH.elapsed()` is `SystemTime::now().duration_since(UNIX_EPOCH)`.
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}