use crate::{error::Error, state::AppState, Result};
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post, put},
Json, Router,
};
use ruvector_core::{SearchQuery, SearchResult, VectorEntry};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// JSON request body accepted by the `upsert_points` handler.
#[derive(Debug, Deserialize)]
pub struct UpsertPointsRequest {
/// Entries passed, as a single batch, to the collection's `insert_batch` call.
pub points: Vec<VectorEntry>,
}
/// JSON request body accepted by the `search_points` handler.
#[derive(Debug, Deserialize)]
pub struct SearchRequest {
/// Query vector, forwarded as `SearchQuery::vector`.
pub vector: Vec<f32>,
/// Number of results requested, forwarded as `SearchQuery::k`.
/// Falls back to `default_limit()` when the field is omitted from the body.
#[serde(default = "default_limit")]
pub k: usize,
/// When present, results whose `score` is strictly below this value are
/// removed after the search has run, so fewer than `k` results may be
/// returned.
// NOTE(review): this assumes a higher score means a better match; confirm
// against the metric used by the collection.
pub score_threshold: Option<f32>,
/// Optional filter map forwarded unchanged as `SearchQuery::filter`; its
/// matching semantics are defined by `ruvector_core`, not by this module.
pub filter: Option<HashMap<String, serde_json::Value>>,
}
/// Serde default for `SearchRequest::k`: the number of results returned when
/// the request body does not specify `k`.
fn default_limit() -> usize {
    const DEFAULT_K: usize = 10;
    DEFAULT_K
}
/// JSON response body produced by the `search_points` handler.
#[derive(Debug, Serialize)]
pub struct SearchResponse {
/// Search hits that remain after the optional score-threshold filter.
pub results: Vec<SearchResult>,
}
/// JSON response body produced by the `upsert_points` handler.
#[derive(Debug, Serialize)]
pub struct UpsertResponse {
/// Identifiers returned by the collection's `insert_batch` call.
pub ids: Vec<String>,
}
/// Builds the router for point-level operations on a collection.
///
/// Registered endpoints:
/// - `PUT  /collections/:name/points`        → `upsert_points`
/// - `POST /collections/:name/points/search` → `search_points`
/// - `GET  /collections/:name/points/:id`    → `get_point`
pub fn routes() -> Router<AppState> {
    let upsert = put(upsert_points);
    let search = post(search_points);
    let fetch = get(get_point);

    Router::new()
        .route("/collections/:name/points", upsert)
        .route("/collections/:name/points/search", search)
        .route("/collections/:name/points/:id", fetch)
}
async fn upsert_points(
State(state): State<AppState>,
Path(name): Path<String>,
Json(req): Json<UpsertPointsRequest>,
) -> Result<impl IntoResponse> {
let db = state
.get_collection(&name)
.ok_or_else(|| Error::CollectionNotFound(name.clone()))?;
let ids = db.insert_batch(req.points).map_err(Error::Core)?;
Ok((StatusCode::OK, Json(UpsertResponse { ids })))
}
async fn search_points(
State(state): State<AppState>,
Path(name): Path<String>,
Json(req): Json<SearchRequest>,
) -> Result<impl IntoResponse> {
let db = state
.get_collection(&name)
.ok_or_else(|| Error::CollectionNotFound(name))?;
let query = SearchQuery {
vector: req.vector,
k: req.k,
filter: req.filter,
ef_search: None,
};
let mut results = db.search(query).map_err(Error::Core)?;
if let Some(threshold) = req.score_threshold {
results.retain(|r| r.score >= threshold);
}
Ok(Json(SearchResponse { results }))
}
async fn get_point(
State(state): State<AppState>,
Path((name, id)): Path<(String, String)>,
) -> Result<impl IntoResponse> {
let db = state
.get_collection(&name)
.ok_or_else(|| Error::CollectionNotFound(name))?;
let entry = db.get(&id).map_err(Error::Core)?;
Ok(Json(entry))
}