rust-memex 0.6.5

Operator CLI + MCP server: canonical corpus second: semantic index second to aicx
Documentation
use std::convert::Infallible;
use std::io;
use std::path::PathBuf;
use std::time::Duration;

use axum::{
    Json, Router,
    body::{Body, Bytes},
    extract::{Multipart, State},
    http::{HeaderValue, StatusCode, header},
    response::{
        IntoResponse,
        sse::{Event, Sse},
    },
    routing::post,
};
use futures::StreamExt;
use memex_contracts::progress::SseEvent;
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::sync::mpsc;

use crate::{
    ChunkerKind, ReindexJob, ReprocessJob, SliceMode, default_reindexed_namespace,
    export_namespace_jsonl_stream, import_jsonl_bytes_stream, migrate_namespace_atomic,
    reindex_namespace, reprocess_jsonl_file,
};

use super::HttpState;

/// JSON body accepted by the `POST /sse/reprocess` route.
///
/// Fields annotated with `#[serde(default)]` may be left out of the payload,
/// in which case they take their type's default (`None` / `false`).
#[derive(Debug, Deserialize)]
struct ReprocessRequest {
    /// Path of the input file; the handler turns it into a `PathBuf` and hands
    /// it to the reprocess job. Presumably resolved on the server's filesystem
    /// (verify against `reprocess_jsonl_file`).
    input_path: String,
    /// Namespace forwarded to the job as its target.
    target_namespace: String,
    /// Textual slice mode; the handler parses it into a `SliceMode` and answers
    /// `400 Bad Request` when parsing fails.
    slice_mode: String,
    /// Optional chunker selection, forwarded to the job unchanged.
    #[serde(default)]
    chunker: Option<ChunkerKind>,
    /// Forwarded to the job's `preprocess` flag.
    #[serde(default)]
    preprocess: bool,
    /// Forwarded to the job's `skip_existing` flag.
    #[serde(default)]
    skip_existing: bool,
    /// Forwarded to the job's `allow_duplicates` flag.
    #[serde(default)]
    allow_duplicates: bool,
}

/// JSON body accepted by the `POST /sse/reindex` route.
///
/// Fields annotated with `#[serde(default)]` may be left out of the payload,
/// in which case they take their type's default (`None` / `false`).
#[derive(Debug, Deserialize)]
struct ReindexRequest {
    /// Namespace forwarded to the job as its source.
    source_namespace: String,
    /// Namespace to write into; when `None` the handler falls back to
    /// `default_reindexed_namespace(&source_namespace)`.
    target_namespace: Option<String>,
    /// Textual slice mode; the handler parses it into a `SliceMode` and answers
    /// `400 Bad Request` when parsing fails.
    slice_mode: String,
    /// Optional chunker selection, forwarded to the job unchanged.
    #[serde(default)]
    chunker: Option<ChunkerKind>,
    /// Forwarded to the job's `preprocess` flag.
    #[serde(default)]
    preprocess: bool,
    /// Forwarded to the job's `skip_existing` flag.
    #[serde(default)]
    skip_existing: bool,
    /// Forwarded to the job's `allow_duplicates` flag.
    #[serde(default)]
    allow_duplicates: bool,
}

/// JSON body accepted by the `POST /api/export` route.
#[derive(Debug, Deserialize)]
struct ExportRequest {
    /// Namespace to export; also used (after sanitizing) to name the
    /// downloaded `.jsonl` attachment.
    namespace: String,
    /// Passed through to the export stream; defaults to `false` when omitted.
    #[serde(default)]
    include_embeddings: bool,
}

/// JSON body accepted by the `POST /api/migrate-namespace` route.
#[derive(Debug, Deserialize)]
struct MigrateNamespaceRequest {
    /// Passed to `migrate_namespace_atomic` as its first namespace argument
    /// (the namespace being migrated from).
    from: String,
    /// Passed to `migrate_namespace_atomic` as its second namespace argument
    /// (the namespace being migrated to).
    to: String,
}

/// Builds the router for the long-running maintenance endpoints.
///
/// All five routes accept `POST` only:
/// - `/sse/reprocess` and `/sse/reindex` stream job progress as SSE;
/// - `/api/export`, `/api/import` and `/api/migrate-namespace` are plain
///   request/response endpoints.
pub(super) fn routes() -> Router<HttpState> {
    // Server-sent-event endpoints.
    let router = Router::new();
    let router = router.route("/sse/reprocess", post(sse_reprocess_handler));
    let router = router.route("/sse/reindex", post(sse_reindex_handler));

    // Bulk data endpoints.
    let router = router.route("/api/export", post(export_handler));
    let router = router.route("/api/import", post(import_handler));
    router.route("/api/migrate-namespace", post(migrate_namespace_handler))
}

async fn sse_reprocess_handler(
    State(state): State<HttpState>,
    Json(request): Json<ReprocessRequest>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, (StatusCode, String)> {
    let slice_mode: SliceMode = request.slice_mode.parse().map_err(|err| {
        (
            StatusCode::BAD_REQUEST,
            format!("invalid slice_mode '{}': {}", request.slice_mode, err),
        )
    })?;

    let input_path = PathBuf::from(request.input_path.clone());
    let target_namespace = request.target_namespace.clone();
    let slice_mode_name = request.slice_mode.clone();
    let chunker = request.chunker;
    let preprocess = request.preprocess;
    let skip_existing = request.skip_existing;
    let allow_duplicates = request.allow_duplicates;
    let (tx, mut rx) = mpsc::unbounded_channel();
    let rag = state.rag.clone();

    tokio::spawn(async move {
        let _ = tx.send(sse_event(
            "start",
            json!({
                "input_path": input_path.display().to_string(),
                "target_namespace": target_namespace.clone(),
                "slice_mode": slice_mode_name.clone(),
                "chunker": chunker.map(|kind| kind.name()),
                "preprocess": preprocess,
                "skip_existing": skip_existing,
                "allow_duplicates": allow_duplicates,
            }),
        ));

        let result = reprocess_jsonl_file(
            rag,
            ReprocessJob {
                input_path: input_path.clone(),
                target_namespace: target_namespace.clone(),
                slice_mode,
                chunker,
                preprocess,
                skip_existing,
                allow_duplicates,
                dry_run: false,
            },
            |progress| {
                let _ = tx.send(sse_event(
                    "progress",
                    serde_json::to_value(progress).unwrap_or(Value::Null),
                ));
            },
        )
        .await;

        match result {
            Ok(summary) => {
                let _ = tx.send(sse_event(
                    "result",
                    serde_json::to_value(summary).unwrap_or(Value::Null),
                ));
            }
            Err(err) => {
                let _ = tx.send(sse_event("error", json!({ "error": err.to_string() })));
            }
        }
    });

    let stream = async_stream::stream! {
        while let Some(event) = rx.recv().await {
            yield Ok(to_axum_event(event));
        }
    };

    Ok(Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("ping"),
    ))
}

async fn sse_reindex_handler(
    State(state): State<HttpState>,
    Json(request): Json<ReindexRequest>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, (StatusCode, String)> {
    let slice_mode: SliceMode = request.slice_mode.parse().map_err(|err| {
        (
            StatusCode::BAD_REQUEST,
            format!("invalid slice_mode '{}': {}", request.slice_mode, err),
        )
    })?;

    let source_namespace = request.source_namespace.clone();
    let target_namespace = request
        .target_namespace
        .clone()
        .unwrap_or_else(|| default_reindexed_namespace(&source_namespace));
    let slice_mode_name = request.slice_mode.clone();
    let chunker = request.chunker;
    let preprocess = request.preprocess;
    let skip_existing = request.skip_existing;
    let allow_duplicates = request.allow_duplicates;
    let (tx, mut rx) = mpsc::unbounded_channel();
    let rag = state.rag.clone();

    tokio::spawn(async move {
        let _ = tx.send(sse_event(
            "start",
            json!({
                "source_namespace": source_namespace.clone(),
                "target_namespace": target_namespace.clone(),
                "slice_mode": slice_mode_name.clone(),
                "chunker": chunker.map(|kind| kind.name()),
                "preprocess": preprocess,
                "skip_existing": skip_existing,
                "allow_duplicates": allow_duplicates,
            }),
        ));

        let result = reindex_namespace(
            rag,
            ReindexJob {
                source_namespace: source_namespace.clone(),
                target_namespace: target_namespace.clone(),
                slice_mode,
                chunker,
                preprocess,
                skip_existing,
                allow_duplicates,
                dry_run: false,
            },
            |progress| {
                let _ = tx.send(sse_event(
                    "progress",
                    serde_json::to_value(progress).unwrap_or(Value::Null),
                ));
            },
        )
        .await;

        match result {
            Ok(summary) => {
                let _ = tx.send(sse_event(
                    "result",
                    serde_json::to_value(summary).unwrap_or(Value::Null),
                ));
            }
            Err(err) => {
                let _ = tx.send(sse_event("error", json!({ "error": err.to_string() })));
            }
        }
    });

    let stream = async_stream::stream! {
        while let Some(event) = rx.recv().await {
            yield Ok(to_axum_event(event));
        }
    };

    Ok(Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("ping"),
    ))
}

async fn export_handler(
    State(state): State<HttpState>,
    Json(request): Json<ExportRequest>,
) -> impl IntoResponse {
    let namespace = request.namespace.clone();
    let stream = export_namespace_jsonl_stream(
        state.rag.storage_manager(),
        request.namespace.clone(),
        request.include_embeddings,
    )
    .map(|item| item.map(Bytes::from).map_err(io::Error::other));

    let mut headers = axum::http::HeaderMap::new();
    headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/x-ndjson"),
    );
    if let Ok(value) = HeaderValue::from_str(&format!(
        "attachment; filename=\"{}.jsonl\"",
        sanitize_filename(&namespace)
    )) {
        headers.insert(header::CONTENT_DISPOSITION, value);
    }

    (StatusCode::OK, headers, Body::from_stream(stream))
}

async fn import_handler(
    State(state): State<HttpState>,
    mut multipart: Multipart,
) -> Result<Json<Value>, (StatusCode, String)> {
    let mut namespace = None;
    let mut skip_existing = false;
    let mut imported_count = None;

    while let Some(field) = multipart.next_field().await.map_err(internal_error)? {
        let field_name = field.name().unwrap_or_default().to_string();
        match field_name.as_str() {
            "namespace" => {
                namespace = Some(field.text().await.map_err(internal_error)?);
            }
            "skip_existing" => {
                let value = field.text().await.map_err(internal_error)?;
                skip_existing = parse_bool_field(&value)?;
            }
            "file" => {
                let namespace = namespace.clone().ok_or_else(|| {
                    (
                        StatusCode::BAD_REQUEST,
                        "missing multipart field 'namespace' before file upload".to_string(),
                    )
                })?;
                let mut field = field;
                let stream = async_stream::stream! {
                    loop {
                        match field.chunk().await {
                            Ok(Some(chunk)) => yield Ok::<Bytes, String>(chunk),
                            Ok(None) => break,
                            Err(err) => {
                                yield Err(err.to_string());
                                break;
                            }
                        }
                    }
                };
                futures::pin_mut!(stream);
                let outcome =
                    import_jsonl_bytes_stream(state.rag.clone(), namespace, skip_existing, stream)
                        .await
                        .map_err(internal_error)?;
                imported_count = Some(outcome.imported_count);
                break;
            }
            _ => {}
        }
    }

    if namespace.is_none() {
        return Err((
            StatusCode::BAD_REQUEST,
            "missing multipart field 'namespace'".to_string(),
        ));
    }

    let imported_count = imported_count.ok_or_else(|| {
        (
            StatusCode::BAD_REQUEST,
            "missing multipart file field 'file'".to_string(),
        )
    })?;

    Ok(Json(json!({ "imported_count": imported_count })))
}

async fn migrate_namespace_handler(
    State(state): State<HttpState>,
    Json(request): Json<MigrateNamespaceRequest>,
) -> Result<Json<Value>, (StatusCode, String)> {
    let storage = state.rag.storage_manager();
    let outcome = migrate_namespace_atomic(storage.as_ref(), &request.from, &request.to)
        .await
        .map_err(internal_error)?;
    Ok(Json(json!({ "migrated_chunks": outcome.migrated_chunks })))
}

/// Wraps a JSON payload in an [`SseEvent`] with the given event name and no id.
fn sse_event(event: &str, data: Value) -> SseEvent {
    let name = String::from(event);
    SseEvent {
        id: None,
        data,
        event: name,
    }
}

fn to_axum_event(event: SseEvent) -> Event {
    let mut axum_event = Event::default()
        .event(event.event)
        .data(event.data.to_string());
    if let Some(id) = event.id {
        axum_event = axum_event.id(id);
    }
    axum_event
}

/// Produces a filename-safe rendering of `namespace`.
///
/// ASCII letters, ASCII digits, `-` and `_` are kept as they are; every other
/// character (including non-ASCII ones) is replaced by a single `_`.
fn sanitize_filename(namespace: &str) -> String {
    // Each output char is at most as wide as its input char, so the input's
    // byte length is a safe upper bound for the allocation.
    let mut cleaned = String::with_capacity(namespace.len());
    for ch in namespace.chars() {
        let keep = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_';
        cleaned.push(if keep { ch } else { '_' });
    }
    cleaned
}

/// Interprets a multipart text field as a boolean.
///
/// The value is trimmed and compared case-insensitively (ASCII):
/// `1`, `true`, `yes`, `on` mean `true`; `0`, `false`, `no`, `off` and the
/// empty string mean `false`. Anything else yields `400 Bad Request` quoting
/// the normalized value.
fn parse_bool_field(value: &str) -> Result<bool, (StatusCode, String)> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 5] = ["0", "false", "no", "off", ""];

    let normalized = value.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        return Ok(true);
    }
    if FALSY.contains(&normalized.as_str()) {
        return Ok(false);
    }

    Err((
        StatusCode::BAD_REQUEST,
        format!("invalid boolean field value '{}'", normalized),
    ))
}

/// Turns any displayable error into a `500 Internal Server Error` rejection
/// whose body is the error's `Display` text.
fn internal_error(err: impl std::fmt::Display) -> (StatusCode, String) {
    let status = StatusCode::INTERNAL_SERVER_ERROR;
    let message = err.to_string();
    (status, message)
}