//! spark-bert 0.1.0
//!
//! Hybrid vector search using an inverted index and BERT embeddings.
//! Documentation
use anyhow::Result;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use spark_bert::{
    api::{Config, SparkBert},
    util::device,
};
use std::sync::Arc;
use tokio::{net::TcpListener, sync::Mutex};

/// Handle to the single `SparkBert` instance shared by every request handler.
///
/// The instance is wrapped in a `tokio::sync::Mutex` behind an `Arc`; it is
/// installed as the router state in `main` and each handler locks it before
/// calling into `SparkBert`.
type SharedSparkBert = Arc<Mutex<SparkBert>>;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let config = Config {
        use_ram_index: false,
        device: device(false)?,
        index_n_neighbors: 8,
    };
    let spark_bert = Arc::new(Mutex::new(SparkBert::new(config)?));
    let app = Router::new()
        .route("/search", post(search))
        .route("/index", post(index_docs))
        .with_state(spark_bert);

    let listener = TcpListener::bind("0.0.0.0:8000").await?;
    axum::serve(listener, app).await?;
    Ok(())
}

async fn search(
    State(spark_bert): State<SharedSparkBert>,
    Json(payload): Json<SearchRequest>,
) -> Result<Json<SearchResponse>, (StatusCode, String)> {
    const DEFAULT_SEARCH_N_NEIGHBORS: usize = 3;
    const DEFAULT_TOP_K: usize = 10;

    let mut spark_bert = spark_bert.lock().await;
    let results = spark_bert
        .search(
            &payload.query,
            payload
                .search_n_neighbors
                .unwrap_or(DEFAULT_SEARCH_N_NEIGHBORS),
            payload.top_k.unwrap_or(DEFAULT_TOP_K),
        )
        .map_err(internal_error)?;

    let hits = results
        .into_iter()
        .map(|(doc_id, score)| SearchResult { doc_id, score })
        .collect();

    Ok(Json(SearchResponse { results: hits }))
}

/// Converts an `anyhow::Error` into the `(status, body)` pair the handlers
/// use as their error type: status `500 Internal Server Error` with the
/// error's `to_string()` output as the body.
///
/// NOTE(review): the error text is returned to the client verbatim; confirm
/// that exposing internal error messages is acceptable for this service.
fn internal_error(err: anyhow::Error) -> (StatusCode, String) {
    let message = err.to_string();
    (StatusCode::INTERNAL_SERVER_ERROR, message)
}

/// JSON body accepted by the `POST /search` handler.
#[derive(Deserialize)]
struct SearchRequest {
    /// Query text passed as the first argument to `SparkBert::search`.
    query: String,
    /// Optional second argument to `SparkBert::search`; the handler falls
    /// back to `3` when absent.
    /// NOTE(review): exact semantics are defined by `SparkBert::search` — confirm.
    search_n_neighbors: Option<usize>,
    /// Optional third argument to `SparkBert::search`; the handler falls back
    /// to `10` when absent. Presumably the maximum number of hits — confirm.
    top_k: Option<usize>,
}

/// JSON body returned by the `POST /search` handler.
#[derive(Serialize)]
struct SearchResponse {
    /// One entry per `(doc_id, score)` pair returned by `SparkBert::search`,
    /// in the order the engine produced them.
    results: Vec<SearchResult>,
}

/// A single search hit as serialised in [`SearchResponse`].
#[derive(Serialize)]
struct SearchResult {
    /// Identifier of the matching document (first element of the engine's
    /// result tuple).
    doc_id: u64,
    /// Score the engine assigned to the document (second element of the
    /// result tuple).
    /// NOTE(review): scale and ordering direction are not visible here — confirm.
    score: f64,
}

async fn index_docs(
    State(spark_bert): State<SharedSparkBert>,
    Json(payload): Json<IndexRequest>,
) -> Result<Json<IndexResponse>, (StatusCode, String)> {
    let mut spark_bert = spark_bert.lock().await;
    let indexed = payload.docs.len();
    let docs = payload.docs.into_iter().map(|doc| (doc.doc_id, doc.text));
    spark_bert.index(docs, false).map_err(internal_error)?;
    Ok(Json(IndexResponse { indexed }))
}

/// JSON body accepted by the `POST /index` handler.
#[derive(Deserialize)]
struct IndexRequest {
    /// Documents to index; each is forwarded to `SparkBert::index` as a
    /// `(doc_id, text)` pair.
    docs: Vec<Doc>,
}

/// One document inside an [`IndexRequest`].
#[derive(Deserialize)]
struct Doc {
    /// Caller-supplied identifier; search hits report this same kind of id in
    /// `SearchResult::doc_id`.
    doc_id: u64,
    /// Document text handed to the engine for indexing.
    text: String,
}

/// JSON body returned by the `POST /index` handler on success.
#[derive(Serialize)]
struct IndexResponse {
    /// Number of documents contained in the request (`payload.docs.len()`).
    indexed: usize,
}