alopex-server 0.4.1

Server component for Alopex DB
Documentation
use std::sync::Arc;

use alopex_core::kv::KVTransaction;
use alopex_core::types::TxnMode;
use alopex_core::vector::hnsw::{HnswConfig, HnswIndex, HnswSearchResult, HnswStats};
use alopex_core::vector::Metric;
use alopex_core::KVStore;
use axum::extract::Extension;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};

use crate::error::{Result, ServerError};
use crate::http::{error_response, json_response, RequestContext};
use crate::server::ServerState;

// Value assigned to `HnswConfig::m` for every index created through `create_impl`.
// NOTE(review): in HNSW terminology this is presumably the per-node link count —
// confirm against the alopex_core documentation.
const DEFAULT_M: usize = 16;
// Value assigned to `HnswConfig::ef_construction` for every index created through
// `create_impl`; the create request offers no way to override it.
const DEFAULT_EF_CONSTRUCTION: usize = 200;

/// JSON body accepted by the [`search`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswSearchRequest {
    /// Name of the index to load before searching.
    pub index: String,
    /// Query vector handed to `HnswIndex::search`.
    pub query: Vec<f32>,
    /// Result count handed to `HnswIndex::search`; falls back to `default_k()` (10)
    /// when the field is absent from the request.
    #[serde(default = "default_k")]
    pub k: usize,
}

/// JSON body accepted by the [`upsert`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswUpsertRequest {
    /// Name of the index to load and modify.
    pub index: String,
    /// Key bytes passed to `HnswIndex::upsert`.
    pub key: Vec<u8>,
    /// Vector passed to `HnswIndex::upsert` alongside `key`.
    pub vector: Vec<f32>,
}

/// JSON body accepted by the [`delete`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswDeleteRequest {
    /// Name of the index to load and modify.
    pub index: String,
    /// Key bytes passed to `HnswIndex::delete`.
    pub key: Vec<u8>,
}

/// JSON body accepted by the [`create`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswCreateRequest {
    /// Name under which the new index is created.
    pub index: String,
    /// Copied into `HnswConfig::dimension`.
    pub dim: usize,
    /// Metric name resolved by `parse_metric`; the accepted spellings are
    /// `"cosine"`, `"l2"` and `"ip"` (exact match).
    pub metric: String,
}

/// JSON body accepted by the [`drop`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswDropRequest {
    /// Name of the index to load and then drop.
    pub index: String,
}

/// JSON body accepted by the [`stats`] handler.
#[derive(Debug, Deserialize)]
pub struct HnswStatsRequest {
    /// Name of the index whose statistics are requested.
    pub index: String,
}

/// JSON body returned by a successful [`search`] call.
#[derive(Debug, Serialize)]
pub struct HnswSearchResponse {
    /// First element of the tuple returned by `HnswIndex::search`, passed through
    /// unchanged.
    pub results: Vec<HnswSearchResult>,
}

/// JSON body returned by a successful [`stats`] call.
#[derive(Debug, Serialize)]
pub struct HnswStatsResponse {
    /// Value returned by `HnswIndex::stats` for the requested index.
    pub stats: HnswStats,
}

/// JSON body returned by the mutating handlers (upsert, delete, create, drop).
#[derive(Debug, Serialize)]
pub struct HnswStatusResponse {
    /// Every response built in this file sets this to `true`; failures are
    /// reported through `error_response` instead of `success: false`.
    pub success: bool,
}

/// Axum handler for HNSW search requests.
///
/// Hands the decoded body to `search_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn search(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswSearchRequest>,
) -> Response {
    let outcome = search_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

/// Axum handler for HNSW upsert requests.
///
/// Hands the decoded body to `upsert_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn upsert(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswUpsertRequest>,
) -> Response {
    let outcome = upsert_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

/// Axum handler for HNSW delete requests.
///
/// Hands the decoded body to `delete_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn delete(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswDeleteRequest>,
) -> Response {
    let outcome = delete_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

/// Axum handler for HNSW index creation requests.
///
/// Hands the decoded body to `create_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn create(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswCreateRequest>,
) -> Response {
    let outcome = create_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

/// Axum handler for HNSW index drop requests.
///
/// Hands the decoded body to `drop_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn drop(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswDropRequest>,
) -> Response {
    let outcome = drop_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

/// Axum handler for HNSW index statistics requests.
///
/// Hands the decoded body to `stats_impl` and renders the outcome: a success is
/// serialised by `json_response` (bounded by `state.config.max_response_size`),
/// a failure is rendered by `error_response`.
pub async fn stats(
    Extension(state): Extension<Arc<ServerState>>,
    Extension(ctx): Extension<RequestContext>,
    Json(request): Json<HnswStatsRequest>,
) -> Response {
    let outcome = stats_impl(Arc::clone(&state), request);
    match outcome {
        Err(failure) => error_response(failure, &ctx),
        Ok(body) => json_response(body, state.config.max_response_size, &ctx),
    }
}

fn search_impl(state: Arc<ServerState>, request: HnswSearchRequest) -> Result<HnswSearchResponse> {
    let mut txn = state.store.begin(TxnMode::ReadOnly)?;
    let index = HnswIndex::load(&request.index, &mut txn).map_err(map_core_error)?;
    let (results, _) = index
        .search(&request.query, request.k, None)
        .map_err(map_core_error)?;
    txn.commit_self()?;
    Ok(HnswSearchResponse { results })
}

fn upsert_impl(state: Arc<ServerState>, request: HnswUpsertRequest) -> Result<HnswStatusResponse> {
    let mut txn = state.store.begin(TxnMode::ReadWrite)?;
    let mut index = HnswIndex::load(&request.index, &mut txn).map_err(map_core_error)?;
    index
        .upsert(&request.key, &request.vector, &[])
        .map_err(map_core_error)?;
    index.save(&mut txn).map_err(map_core_error)?;
    txn.commit_self()?;
    Ok(HnswStatusResponse { success: true })
}

fn delete_impl(state: Arc<ServerState>, request: HnswDeleteRequest) -> Result<HnswStatusResponse> {
    let mut txn = state.store.begin(TxnMode::ReadWrite)?;
    let mut index = HnswIndex::load(&request.index, &mut txn).map_err(map_core_error)?;
    index.delete(&request.key).map_err(map_core_error)?;
    index.save(&mut txn).map_err(map_core_error)?;
    txn.commit_self()?;
    Ok(HnswStatusResponse { success: true })
}

fn create_impl(state: Arc<ServerState>, request: HnswCreateRequest) -> Result<HnswStatusResponse> {
    let metric = parse_metric(&request.metric)?;
    let config = HnswConfig {
        dimension: request.dim,
        metric,
        m: DEFAULT_M,
        ef_construction: DEFAULT_EF_CONSTRUCTION,
    };
    config.validate().map_err(map_core_error)?;

    let index = HnswIndex::create(&request.index, config).map_err(map_core_error)?;
    let mut txn = state.store.begin(TxnMode::ReadWrite)?;
    index.save(&mut txn).map_err(map_core_error)?;
    txn.commit_self()?;
    Ok(HnswStatusResponse { success: true })
}

fn drop_impl(state: Arc<ServerState>, request: HnswDropRequest) -> Result<HnswStatusResponse> {
    let mut txn = state.store.begin(TxnMode::ReadWrite)?;
    let index = HnswIndex::load(&request.index, &mut txn).map_err(map_core_error)?;
    index.drop(&mut txn).map_err(map_core_error)?;
    txn.commit_self()?;
    Ok(HnswStatusResponse { success: true })
}

fn stats_impl(state: Arc<ServerState>, request: HnswStatsRequest) -> Result<HnswStatsResponse> {
    let mut txn = state.store.begin(TxnMode::ReadOnly)?;
    let index = HnswIndex::load(&request.index, &mut txn).map_err(map_core_error)?;
    let stats = index.stats();
    txn.commit_self()?;
    Ok(HnswStatsResponse { stats })
}

/// Resolves a metric name from a create request into a `Metric`.
///
/// Matching is exact: only `"cosine"`, `"l2"` and `"ip"` are recognised.
///
/// # Errors
/// Returns `ServerError::BadRequest` carrying the offending string for any
/// other input.
fn parse_metric(raw: &str) -> Result<Metric> {
    let metric = match raw {
        "cosine" => Metric::Cosine,
        "l2" => Metric::L2,
        "ip" => Metric::InnerProduct,
        other => {
            return Err(ServerError::BadRequest(format!("unknown metric: {other}")));
        }
    };
    Ok(metric)
}

/// Serde fallback for `HnswSearchRequest::k`, used when a search request
/// omits the field.
fn default_k() -> usize {
    const FALLBACK_K: usize = 10;
    FALLBACK_K
}

fn map_core_error(err: alopex_core::Error) -> ServerError {
    match err {
        alopex_core::Error::NotFound => ServerError::NotFound("index not found".into()),
        alopex_core::Error::InvalidParameter { param, reason } => {
            ServerError::BadRequest(format!("invalid parameter {param}: {reason}"))
        }
        other => ServerError::Core(other),
    }
}