use axum::{
body::Body,
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use futures::StreamExt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::types::{ErrorResponse, UpsertPointsRequest};
use crate::AppState;
use velesdb_core::{Collection, Point};
const STREAM_BATCH_SIZE: usize = 100;
const STREAM_BATCH_MAX_WAIT: Duration = Duration::from_millis(100);
#[derive(Default)]
struct StreamUpsertStats {
inserted: usize,
malformed: usize,
failed_upserts: usize,
}
fn parse_ndjson_line(line: &str, batch: &mut Vec<Point>, stats: &mut StreamUpsertStats) {
if line.is_empty() {
return;
}
match serde_json::from_str::<Point>(line) {
Ok(point) => batch.push(point),
Err(error) => {
stats.malformed += 1;
tracing::warn!(error = %error, "Skipping malformed NDJSON point");
}
}
}
async fn flush_point_batch(
collection: Collection,
batch: &mut Vec<Point>,
stats: &mut StreamUpsertStats,
) {
if batch.is_empty() {
return;
}
let points = std::mem::take(batch);
let batch_size = points.len();
match tokio::task::spawn_blocking(move || collection.upsert_bulk(&points)).await {
Ok(Ok(inserted)) => stats.inserted += inserted,
Ok(Err(error)) => {
stats.failed_upserts += batch_size;
tracing::error!(
error = %error,
batch_size,
"Failed to upsert streamed batch"
);
}
Err(error) => {
stats.failed_upserts += batch_size;
tracing::error!(
error = %error,
batch_size,
"Stream upsert batch task panicked"
);
}
}
}
#[utoipa::path(
post,
path = "/collections/{name}/points",
tag = "points",
params(
("name" = String, Path, description = "Collection name")
),
request_body = UpsertPointsRequest,
responses(
(status = 200, description = "Points upserted", body = Object),
(status = 404, description = "Collection not found", body = ErrorResponse),
(status = 400, description = "Invalid request", body = ErrorResponse)
)
)]
pub async fn upsert_points(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
Json(req): Json<UpsertPointsRequest>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let points: Vec<Point> = req
.points
.into_iter()
.map(|p| Point::new(p.id, p.vector, p.payload))
.collect();
let result = tokio::task::spawn_blocking(move || collection.upsert_bulk(&points)).await;
match result {
Ok(Ok(inserted)) => Json(serde_json::json!({
"message": "Points upserted",
"count": inserted
}))
.into_response(),
Ok(Err(e)) => (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Task panicked: {e}"),
}),
)
.into_response(),
}
}
#[utoipa::path(
post,
path = "/collections/{name}/points/stream",
tag = "points",
params(
("name" = String, Path, description = "Collection name")
),
request_body(content = String, content_type = "application/x-ndjson", description = "NDJSON stream with one point per line"),
responses(
(status = 200, description = "Stream processed", body = Object),
(status = 404, description = "Collection not found", body = ErrorResponse)
)
)]
pub async fn stream_upsert_points(
State(state): State<Arc<AppState>>,
Path(name): Path<String>,
body: Body,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let mut stream = body.into_data_stream();
let mut buffer = Vec::<u8>::new();
let mut batch = Vec::with_capacity(STREAM_BATCH_SIZE);
let mut stats = StreamUpsertStats::default();
let mut last_flush = Instant::now();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
buffer.extend_from_slice(&chunk);
while let Some(newline_pos) = buffer.iter().position(|byte| *byte == b'\n') {
let line_bytes: Vec<u8> = buffer.drain(..=newline_pos).collect();
let line = String::from_utf8_lossy(&line_bytes);
parse_ndjson_line(line.trim(), &mut batch, &mut stats);
if batch.len() >= STREAM_BATCH_SIZE
|| (!batch.is_empty() && last_flush.elapsed() >= STREAM_BATCH_MAX_WAIT)
{
flush_point_batch(collection.clone(), &mut batch, &mut stats).await;
last_flush = Instant::now();
}
}
}
Err(error) => {
tracing::warn!(error = %error, "Error while reading request body stream");
}
}
}
if !buffer.is_empty() {
let line = String::from_utf8_lossy(&buffer);
parse_ndjson_line(line.trim(), &mut batch, &mut stats);
}
flush_point_batch(collection, &mut batch, &mut stats).await;
Json(serde_json::json!({
"message": "Stream processed",
"inserted": stats.inserted,
"malformed": stats.malformed,
"failed_upserts": stats.failed_upserts
}))
.into_response()
}
#[utoipa::path(
get,
path = "/collections/{name}/points/{id}",
tag = "points",
params(
("name" = String, Path, description = "Collection name"),
("id" = u64, Path, description = "Point ID")
),
responses(
(status = 200, description = "Point found", body = Object),
(status = 404, description = "Point or collection not found", body = ErrorResponse)
)
)]
pub async fn get_point(
State(state): State<Arc<AppState>>,
Path((name, id)): Path<(String, u64)>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
let points = collection.get(&[id]);
match points.into_iter().next().flatten() {
Some(point) => Json(serde_json::json!({
"id": point.id,
"vector": point.vector,
"payload": point.payload
}))
.into_response(),
None => (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Point {} not found", id),
}),
)
.into_response(),
}
}
#[utoipa::path(
delete,
path = "/collections/{name}/points/{id}",
tag = "points",
params(
("name" = String, Path, description = "Collection name"),
("id" = u64, Path, description = "Point ID")
),
responses(
(status = 200, description = "Point deleted", body = Object),
(status = 404, description = "Point or collection not found", body = ErrorResponse)
)
)]
#[allow(clippy::unused_async)]
pub async fn delete_point(
State(state): State<Arc<AppState>>,
Path((name, id)): Path<(String, u64)>,
) -> impl IntoResponse {
let collection = match state.db.get_collection(&name) {
Some(c) => c,
None => {
return (
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("Collection '{}' not found", name),
}),
)
.into_response()
}
};
match collection.delete(&[id]) {
Ok(()) => Json(serde_json::json!({
"message": "Point deleted",
"id": id
}))
.into_response(),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e.to_string(),
}),
)
.into_response(),
}
}