use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use std::sync::Arc;
use crate::types::{
mode_to_ef_search, BatchSearchRequest, BatchSearchResponse, ErrorResponse, HybridSearchRequest,
MultiQuerySearchRequest, SearchRequest, SearchResponse, SearchResultResponse,
TextSearchRequest,
};
use crate::AppState;
#[utoipa::path(
post,
path = "/collections/{name}/search",
tag = "search",
params(
("name" = String, Path, description = "Collection name")
),
request_body = SearchRequest,
responses(
(status = 200, description = "Search results", body = SearchResponse),
(status = 404, description = "Collection not found", body = ErrorResponse),
(status = 400, description = "Invalid request", body = ErrorResponse)
)
)]
#[allow(clippy::unused_async)]
pub async fn search(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<SearchRequest>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let effective_ef = req
.ef_search
.or_else(|| req.mode.as_ref().and_then(|m| mode_to_ef_search(m)));
let search_result = if let Some(ref filter_json) = req.filter {
let filter: velesdb_core::Filter = match serde_json::from_value(filter_json.clone()) {
Ok(f) => f,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid filter: {}", e),
}),
)
.into_response()
}
};
collection.search_with_filter(&req.vector, req.top_k, &filter)
} else if let Some(ef) = effective_ef {
collection.search_with_ef(&req.vector, req.top_k, ef)
} else {
collection.search(&req.vector, req.top_k)
};
match search_result {
Ok(results) => {
let response = SearchResponse {
results: results
.into_iter()
.map(|r| SearchResultResponse {
id: r.point.id,
score: r.score,
payload: r.point.payload,
})
.collect(),
};
Json(response).into_response()
}
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response(),
}
}
#[utoipa::path(
post,
path = "/collections/{name}/search/batch",
tag = "search",
params(
("name" = String, Path, description = "Collection name")
),
request_body = BatchSearchRequest,
responses(
(status = 200, description = "Batch search results", body = BatchSearchResponse),
(status = 404, description = "Collection not found", body = ErrorResponse),
(status = 400, description = "Invalid request", body = ErrorResponse)
)
)]
#[allow(clippy::unused_async)]
pub async fn batch_search(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<BatchSearchRequest>,
) -> impl IntoResponse {
let start = std::time::Instant::now();
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let queries: Vec<&[f32]> = req.searches.iter().map(|s| s.vector.as_slice()).collect();
let filters: Vec<Option<velesdb_core::Filter>> = req
.searches
.iter()
.map(|s| {
s.filter
.as_ref()
.and_then(|f_json| serde_json::from_value(f_json.clone()).ok())
})
.collect();
let top_k = req.searches.first().map_or(10, |s| s.top_k);
let all_results = match collection.search_batch_with_filters(&queries, top_k, &filters) {
Ok(batch_results) => batch_results
.into_iter()
.map(|results| SearchResponse {
results: results
.into_iter()
.map(|r| SearchResultResponse {
id: r.point.id,
score: r.score,
payload: r.point.payload,
})
.collect(),
})
.collect(),
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response()
}
};
let timing_ms = start.elapsed().as_secs_f64() * 1000.0;
Json(BatchSearchResponse {
results: all_results,
timing_ms,
})
.into_response()
}
#[allow(clippy::unused_async)]
pub async fn multi_query_search(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<MultiQuerySearchRequest>,
) -> impl IntoResponse {
use velesdb_core::FusionStrategy;
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let strategy = match req.strategy.to_lowercase().as_str() {
"average" | "avg" => FusionStrategy::Average,
"maximum" | "max" => FusionStrategy::Maximum,
"rrf" => FusionStrategy::RRF { k: req.rrf_k },
"weighted" => FusionStrategy::Weighted {
avg_weight: req.avg_weight,
max_weight: req.max_weight,
hit_weight: req.hit_weight,
},
_ => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!(
"Invalid strategy: {}. Valid: average, maximum, rrf, weighted",
req.strategy
),
}),
)
.into_response()
}
};
let query_refs: Vec<&[f32]> = req.vectors.iter().map(Vec::as_slice).collect();
let results = match collection.multi_query_search(&query_refs, req.top_k, strategy, None) {
Ok(r) => r,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response()
}
};
let response = SearchResponse {
results: results
.into_iter()
.map(|r| SearchResultResponse {
id: r.point.id,
score: r.score,
payload: r.point.payload,
})
.collect(),
};
Json(response).into_response()
}
#[utoipa::path(
post,
path = "/collections/{name}/search/text",
tag = "search",
params(
("name" = String, Path, description = "Collection name")
),
request_body = TextSearchRequest,
responses(
(status = 200, description = "Text search results", body = SearchResponse),
(status = 404, description = "Collection not found", body = ErrorResponse)
)
)]
#[allow(clippy::unused_async)]
pub async fn text_search(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<TextSearchRequest>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let results = if let Some(ref filter_json) = req.filter {
let filter: velesdb_core::Filter = match serde_json::from_value(filter_json.clone()) {
Ok(f) => f,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid filter: {}", e),
}),
)
.into_response()
}
};
collection.text_search_with_filter(&req.query, req.top_k, &filter)
} else {
collection.text_search(&req.query, req.top_k)
};
let response = SearchResponse {
results: results
.into_iter()
.map(|r| SearchResultResponse {
id: r.point.id,
score: r.score,
payload: r.point.payload,
})
.collect(),
};
Json(response).into_response()
}
#[utoipa::path(
post,
path = "/collections/{name}/search/hybrid",
tag = "search",
params(
("name" = String, Path, description = "Collection name")
),
request_body = HybridSearchRequest,
responses(
(status = 200, description = "Hybrid search results", body = SearchResponse),
(status = 404, description = "Collection not found", body = ErrorResponse),
(status = 400, description = "Invalid request", body = ErrorResponse)
)
)]
#[allow(clippy::unused_async)]
pub async fn hybrid_search(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<HybridSearchRequest>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let search_result = if let Some(ref filter_json) = req.filter {
let filter: velesdb_core::Filter = match serde_json::from_value(filter_json.clone()) {
Ok(f) => f,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid filter: {}", e),
}),
)
.into_response()
}
};
collection.hybrid_search_with_filter(
&req.vector,
&req.query,
req.top_k,
Some(req.vector_weight),
&filter,
)
} else {
collection.hybrid_search(&req.vector, &req.query, req.top_k, Some(req.vector_weight))
};
match search_result {
Ok(results) => {
let response = SearchResponse {
results: results
.into_iter()
.map(|r| SearchResultResponse {
id: r.point.id,
score: r.score,
payload: r.point.payload,
})
.collect(),
};
Json(response).into_response()
}
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response(),
}
}