use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::types::VELESQL_CONTRACT_VERSION;
use crate::AppState;
/// JSON request body accepted by [`match_query`].
///
/// Only `query` is required; every other field carries `#[serde(default)]`
/// and may be omitted from the payload.
#[derive(Debug, Deserialize)]
pub struct MatchQueryRequest {
/// Query text; the handler passes it to `velesdb_core::velesql::Parser::parse`.
pub query: String,
/// Parameter values forwarded unchanged to the collection's execute call.
/// An absent field deserializes to an empty map.
#[serde(default)]
pub params: HashMap<String, serde_json::Value>,
/// Optional vector. When present the handler calls
/// `execute_match_with_similarity` instead of `execute_match`.
#[serde(default)]
pub vector: Option<Vec<f32>>,
/// Optional threshold. The handler rejects values outside `[0.0, 1.0]` and
/// substitutes `0.0` when a vector is supplied without a threshold.
#[serde(default)]
pub threshold: Option<f32>,
}
/// One entry of the `results` array in a [`MatchQueryResponse`].
///
/// The handler copies all four fields straight from the executor's result row.
#[derive(Debug, Serialize)]
pub struct MatchQueryResultItem {
/// String keys mapped to `u64` values.
/// NOTE(review): presumably pattern variable name -> matched node id; confirm against the executor.
pub bindings: HashMap<String, u64>,
/// Optional score; left out of the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub score: Option<f32>,
/// Depth value reported by the executor.
/// NOTE(review): presumably traversal depth of the match; confirm.
pub depth: u32,
/// Extra values keyed by string; left out of the JSON output when empty.
/// NOTE(review): presumably RETURN-projected properties (the unit tests use
/// keys such as `author.name`); confirm.
#[serde(skip_serializing_if = "HashMap::is_empty")]
pub projected: HashMap<String, serde_json::Value>,
}
/// Success payload returned by [`match_query`].
#[derive(Debug, Serialize)]
pub struct MatchQueryResponse {
/// Result rows converted from the executor output.
pub results: Vec<MatchQueryResultItem>,
/// Milliseconds elapsed since the handler started, as measured by the handler.
pub took_ms: u64,
/// Number of entries in `results` (the handler sets it to `results.len()`).
pub count: usize,
/// Response metadata.
pub meta: MatchQueryMeta,
}
/// Metadata block attached to every [`MatchQueryResponse`].
#[derive(Debug, Serialize)]
pub struct MatchQueryMeta {
/// Contract version string; the handler fills it from `VELESQL_CONTRACT_VERSION`.
pub velesql_contract_version: String,
}
/// Error payload returned by [`match_query`] alongside a non-success status code.
#[derive(Debug, Serialize)]
pub struct MatchQueryError {
/// Human-readable error message.
pub error: String,
/// Error code string, e.g. `PARSE_ERROR` or `COLLECTION_NOT_FOUND`.
pub code: String,
/// Optional remediation hint; left out of the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub hint: Option<String>,
/// Optional structured context; left out of the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
/// Runs a MATCH query against the collection named in the route.
///
/// Processing order:
/// 1. Look up the collection by `collection_name`.
/// 2. Parse `request.query`.
/// 3. Require the parsed query to contain a MATCH clause.
/// 4. Validate `request.threshold`, when supplied, against `[0.0, 1.0]`.
/// 5. Execute: `execute_match_with_similarity` when `request.vector` is
///    present (threshold defaults to `0.0`), otherwise `execute_match`.
///
/// On success the response carries the result rows, their count, the elapsed
/// time in milliseconds and the contract version.
///
/// # Errors
///
/// * `404` / `COLLECTION_NOT_FOUND` – no collection with that name.
/// * `400` / `PARSE_ERROR` – the query text failed to parse.
/// * `400` / `NOT_MATCH_QUERY` – the parsed query has no MATCH clause.
/// * `400` / `INVALID_THRESHOLD` – threshold outside `[0.0, 1.0]`.
/// * `500` / `EXECUTION_ERROR` – the collection reported an execution failure.
pub async fn match_query(
    Path(collection_name): Path<String>,
    State(state): State<Arc<AppState>>,
    Json(request): Json<MatchQueryRequest>,
) -> Result<Json<MatchQueryResponse>, (StatusCode, Json<MatchQueryError>)> {
    // Error half of this handler's return type.
    type Failure = (StatusCode, Json<MatchQueryError>);

    // Pairs a status code with a JSON error body; the hint is always populated.
    fn failure(
        status: StatusCode,
        error: String,
        code: &str,
        hint: &str,
        details: Option<serde_json::Value>,
    ) -> Failure {
        let body = MatchQueryError {
            error,
            code: code.to_string(),
            hint: Some(hint.to_string()),
            details,
        };
        (status, Json(body))
    }

    // Shared mapping for failures reported by either execute call.
    fn execution_failure<E: std::fmt::Display>(cause: E) -> Failure {
        failure(
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Execution error: {}", cause),
            "EXECUTION_ERROR",
            "Validate graph labels/properties and parameter types for this collection",
            None,
        )
    }

    let start = std::time::Instant::now();

    let collection = match state.db.get_collection(&collection_name) {
        Some(found) => found,
        None => {
            return Err(failure(
                StatusCode::NOT_FOUND,
                format!("Collection '{}' not found", collection_name),
                "COLLECTION_NOT_FOUND",
                "Create the collection first or correct the collection name in the route",
                Some(serde_json::json!({ "collection": collection_name })),
            ));
        }
    };

    let query = match velesdb_core::velesql::Parser::parse(&request.query) {
        Ok(parsed) => parsed,
        Err(parse_err) => {
            return Err(failure(
                StatusCode::BAD_REQUEST,
                format!("Parse error: {}", parse_err),
                "PARSE_ERROR",
                "Check MATCH syntax and bound parameters",
                None,
            ));
        }
    };

    let match_clause = match query.match_clause {
        Some(clause) => clause,
        None => {
            return Err(failure(
                StatusCode::BAD_REQUEST,
                "Query is not a MATCH query".to_string(),
                "NOT_MATCH_QUERY",
                "Use MATCH (...) RETURN ... or call /query for SELECT statements",
                None,
            ));
        }
    };

    // Keep only an out-of-range threshold; NaN is never inside the range, so
    // it is rejected as well.
    let rejected = request
        .threshold
        .filter(|candidate| !(0.0..=1.0).contains(candidate));
    if let Some(threshold) = rejected {
        return Err(failure(
            StatusCode::BAD_REQUEST,
            format!(
                "Invalid threshold: {}. Must be between 0.0 and 1.0",
                threshold
            ),
            "INVALID_THRESHOLD",
            "Provide threshold in inclusive range [0.0, 1.0]",
            Some(serde_json::json!({ "threshold": threshold })),
        ));
    }

    // A supplied vector selects the similarity-aware execution path.
    let rows = match &request.vector {
        Some(vector) => collection
            .execute_match_with_similarity(
                &match_clause,
                vector,
                request.threshold.unwrap_or(0.0),
                &request.params,
            )
            .map_err(execution_failure)?,
        None => collection
            .execute_match(&match_clause, &request.params)
            .map_err(execution_failure)?,
    };

    let items: Vec<MatchQueryResultItem> = rows
        .into_iter()
        .map(|row| MatchQueryResultItem {
            bindings: row.bindings,
            score: row.score,
            depth: row.depth,
            projected: row.projected,
        })
        .collect();

    let count = items.len();
    let elapsed = start.elapsed();
    let body = MatchQueryResponse {
        results: items,
        // `as_millis` yields a u128; the cast keeps the low 64 bits.
        took_ms: elapsed.as_millis() as u64,
        count,
        meta: MatchQueryMeta {
            velesql_contract_version: VELESQL_CONTRACT_VERSION.to_string(),
        },
    };
    Ok(Json(body))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `meta` block shared by every response fixture.
    fn contract_meta() -> MatchQueryMeta {
        MatchQueryMeta {
            velesql_contract_version: VELESQL_CONTRACT_VERSION.to_string(),
        }
    }

    /// A body holding `query` plus an empty `params` object deserializes.
    #[test]
    fn test_match_query_request_deserialize() {
        let body = r#"{
"query": "MATCH (a:Person)-[:KNOWS]->(b) RETURN a.name",
"params": {}
}"#;
        let parsed: MatchQueryRequest = serde_json::from_str(body).unwrap();
        assert!(parsed.query.contains("MATCH"));
        assert!(parsed.params.is_empty());
    }

    /// A response with one scored item serializes its bindings and score.
    #[test]
    fn test_match_query_response_serialize() {
        let mut bindings = HashMap::new();
        bindings.insert("a".to_string(), 123);
        let item = MatchQueryResultItem {
            bindings,
            score: Some(0.95),
            depth: 1,
            projected: HashMap::new(),
        };
        let reply = MatchQueryResponse {
            results: vec![item],
            took_ms: 15,
            count: 1,
            meta: contract_meta(),
        };
        let encoded = serde_json::to_string(&reply).unwrap();
        assert!(encoded.contains("bindings"));
        assert!(encoded.contains("0.95"));
    }

    /// A non-empty `projected` map appears in the serialized output.
    #[test]
    fn test_match_query_response_with_projected_properties() {
        let projected = HashMap::from([(
            "author.name".to_string(),
            serde_json::json!("John Doe"),
        )]);
        let mut bindings = HashMap::new();
        bindings.insert("author".to_string(), 42);
        let item = MatchQueryResultItem {
            bindings,
            score: Some(0.92),
            depth: 1,
            projected,
        };
        let reply = MatchQueryResponse {
            results: vec![item],
            took_ms: 10,
            count: 1,
            meta: contract_meta(),
        };
        let encoded = serde_json::to_string(&reply).unwrap();
        assert!(encoded.contains("John Doe"));
        assert!(encoded.contains("author.name"));
    }
}