//! murr 0.2.0 — a columnar in-memory cache for AI/ML inference workloads.
//!
//! HTTP (axum) request handlers for the table API.
use std::io::Cursor;
use std::sync::{Arc, LazyLock};


use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use axum::body::Bytes;
use axum::extract::{Path, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::Deserialize;

use crate::core::{MurrError, TableSchema};
use crate::service::MurrService;

use super::convert::{FetchResponse, WriteRequest};
use super::error::ApiError;

const ARROW_IPC_MIME: &str = "application/vnd.apache.arrow.stream";
const PARQUET_MIME: &str = "application/vnd.apache.parquet";

/// The service's OpenAPI document, parsed from YAML into JSON on first access.
///
/// The YAML text is embedded at compile time via `include_str!`, so a parse
/// failure means the shipped spec itself is broken; `expect` reports that
/// packaging bug loudly instead of serving a bogus document.
static OPENAPI_JSON: LazyLock<serde_json::Value> = LazyLock::new(|| {
    serde_yaml_ng::from_str(include_str!("../../../openapi.yaml"))
        .expect("openapi.yaml must be valid YAML")
});

/// Serves the embedded OpenAPI document as a JSON response.
///
/// The document is parsed once (see `OPENAPI_JSON`); each request receives
/// its own copy of the parsed value.
pub async fn openapi() -> Json<serde_json::Value> {
    let document: &serde_json::Value = &OPENAPI_JSON;
    Json(document.clone())
}

/// Liveness probe: always answers with the plain-text body `OK`.
pub async fn health() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}

pub async fn list_tables(
    State(service): State<Arc<MurrService>>,
) -> Result<Json<std::collections::HashMap<String, TableSchema>>, ApiError> {
    let svc = service.clone();
    let tables = tokio::task::spawn_blocking(move || svc.list_tables())
        .await
        .map_err(join_to_api_error)?;
    Ok(Json(tables))
}

/// Returns the schema of table `name`.
///
/// The lookup runs on tokio's blocking pool. A failed join becomes an
/// `ApiError` via `join_to_api_error`, and a service-level error is converted
/// with `?`.
pub async fn get_schema(
    State(service): State<Arc<MurrService>>,
    Path(name): Path<String>,
) -> Result<Json<TableSchema>, ApiError> {
    let joined = tokio::task::spawn_blocking(move || service.get_schema(&name)).await;
    // Outer `?`: the blocking task itself; inner `?`: the service lookup.
    let schema = joined.map_err(join_to_api_error)??;
    Ok(Json(schema))
}

/// Creates table `name` with the schema given in the JSON body.
///
/// Responds `201 Created` on success. Creation runs on tokio's blocking pool;
/// join failures and service errors are both surfaced as `ApiError`.
pub async fn create_table(
    State(service): State<Arc<MurrService>>,
    Path(name): Path<String>,
    Json(schema): Json<TableSchema>,
) -> Result<StatusCode, ApiError> {
    let outcome = tokio::task::spawn_blocking(move || service.create(&name, schema))
        .await
        .map_err(join_to_api_error)?;
    outcome?;
    Ok(StatusCode::CREATED)
}

/// JSON body of a fetch request: which rows (`keys`) and which `columns` to
/// read from a table.
///
/// Both fields are required; deserialization fails if either is missing.
#[derive(Debug, Deserialize)]
pub struct FetchRequest {
    /// Row keys to look up.
    pub keys: Vec<String>,
    /// Names of the columns to return.
    pub columns: Vec<String>,
}

/// Reads `req.keys` × `req.columns` from table `name`.
///
/// If the `Accept` header mentions [`ARROW_IPC_MIME`], the result is encoded
/// as an Arrow IPC stream with that content type. Otherwise it is returned as
/// JSON built through `FetchResponse`. Both the read and the encoding run on
/// tokio's blocking pool.
pub async fn fetch(
    State(service): State<Arc<MurrService>>,
    Path(name): Path<String>,
    headers: HeaderMap,
    Json(req): Json<FetchRequest>,
) -> Result<Response, ApiError> {
    let accept = headers.get("accept").and_then(|value| value.to_str().ok());
    let wants_arrow = matches!(accept, Some(value) if value.contains(ARROW_IPC_MIME));

    let task = tokio::task::spawn_blocking(move || -> Result<Response, ApiError> {
        let key_refs: Vec<&str> = req.keys.iter().map(String::as_str).collect();
        let column_refs: Vec<&str> = req.columns.iter().map(String::as_str).collect();
        let batch = service.read(&name, &key_refs, &column_refs)?;

        if !wants_arrow {
            let FetchResponse(payload) = FetchResponse::try_from(&batch).map_err(ApiError)?;
            return Ok(Json(payload).into_response());
        }

        let mut encoded = Vec::new();
        let mut writer = StreamWriter::try_new(&mut encoded, &batch.schema())
            .map_err(|e| ApiError(e.into()))?;
        writer.write(&batch).map_err(|e| ApiError(e.into()))?;
        writer.finish().map_err(|e| ApiError(e.into()))?;
        // Release the writer's mutable borrow of `encoded` before handing the bytes off.
        drop(writer);

        let content_type = [(axum::http::header::CONTENT_TYPE, ARROW_IPC_MIME)];
        Ok((content_type, encoded).into_response())
    });

    task.await.map_err(join_to_api_error)?
}

/// Writes rows into table `name`, decoding the body according to `Content-Type`.
///
/// - [`ARROW_IPC_MIME`]: an Arrow IPC stream. Every batch in the stream is
///   concatenated into a single write.
/// - [`PARQUET_MIME`]: a Parquet file. Every decoded batch is concatenated
///   into a single write.
/// - anything else, including a missing header: a JSON `WriteRequest`,
///   converted against the table's current schema.
///
/// Decoding and the write itself run on tokio's blocking pool. Responds
/// `200 OK` on success.
///
/// # Errors
///
/// Returns an `ApiError` when the body cannot be decoded, when an Arrow or
/// Parquet payload contains no record batches, when the batches cannot be
/// concatenated, or when the service rejects the schema lookup or the write.
pub async fn write_table(
    State(service): State<Arc<MurrService>>,
    Path(name): Path<String>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<StatusCode, ApiError> {
    let content_type = headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_string();

    tokio::task::spawn_blocking(move || -> Result<StatusCode, ApiError> {
        let batch = if content_type.contains(ARROW_IPC_MIME) {
            let reader = StreamReader::try_new(Cursor::new(&body), None)
                .map_err(|e| ApiError(e.into()))?;
            // Consume the whole stream. Clients may split rows across several
            // IPC messages, and keeping only the first would silently drop the rest.
            let batches = reader
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError(e.into()))?;
            merge_batches(batches, "empty Arrow IPC stream")?
        } else if content_type.contains(PARQUET_MIME) {
            let reader = ParquetRecordBatchReaderBuilder::try_new(body)
                .map_err(|e| ApiError(MurrError::TableError(format!("invalid Parquet: {e}"))))?
                .build()
                .map_err(|e| ApiError(MurrError::TableError(format!("invalid Parquet: {e}"))))?;
            let batches = reader
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError(e.into()))?;
            // A file that decodes to zero batches is rejected with a clear
            // error instead of panicking on `batches[0]`.
            merge_batches(batches, "empty Parquet file")?
        } else {
            let write: WriteRequest = serde_json::from_slice(&body)
                .map_err(|e| ApiError(MurrError::TableError(format!("invalid JSON: {e}"))))?;
            let schema = service.get_schema(&name)?;
            write.into_record_batch(&schema).map_err(ApiError)?
        };

        service.write(&name, &batch)?;
        Ok(StatusCode::OK)
    })
    .await
    .map_err(join_to_api_error)?
}

/// Collapses decoded upload batches into the single batch passed to `write`.
///
/// Zero batches is rejected with `TableError(empty_msg)`. A lone batch is
/// returned unchanged, with no copy. Several batches are concatenated using
/// the first batch's schema.
fn merge_batches(
    mut batches: Vec<arrow::record_batch::RecordBatch>,
    empty_msg: &str,
) -> Result<arrow::record_batch::RecordBatch, ApiError> {
    match batches.len() {
        0 => Err(ApiError(MurrError::TableError(empty_msg.into()))),
        1 => Ok(batches.swap_remove(0)),
        _ => {
            let schema = batches[0].schema();
            arrow::compute::concat_batches(&schema, &batches).map_err(|e| ApiError(e.into()))
        }
    }
}

/// Converts a failed `spawn_blocking` join (e.g. the blocking closure
/// panicked) into an `ApiError` carrying a `MurrError::IoError`.
fn join_to_api_error(join_err: tokio::task::JoinError) -> ApiError {
    let message = format!("blocking task failed: {join_err}");
    ApiError(MurrError::IoError(message))
}