use std::time::Instant;
use axum::Json;
use axum::extract::{FromRequest, Multipart, State};
use mnem_ingest::{
ChunkerAuto, ChunkerKind, IngestConfig, Ingester, NerConfig, SourceKind, auto_chunker,
};
use serde::Deserialize;
use serde_json::{Value, json};
use crate::error::Error;
use crate::state::AppState;
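/// Upper bound on the per-request `max_tokens` chunking parameter.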
pub(crate) const MAX_INGEST_TOKENS: u32 = 8192;
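/// Default cap on ingest payload size (32 MiB); override with `MNEM_HTTP_INGEST_MAX_BYTES`.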
pub(crate) const DEFAULT_MAX_INGEST_BYTES: u64 = 32 * 1024 * 1024;
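/// Resolves the ingest byte cap from `MNEM_HTTP_INGEST_MAX_BYTES`, falling
/// back to [`DEFAULT_MAX_INGEST_BYTES`] when the variable is unset or unparsable.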
fn max_ingest_bytes() -> u64 {
std::env::var("MNEM_HTTP_INGEST_MAX_BYTES")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(DEFAULT_MAX_INGEST_BYTES)
}
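/// JSON body accepted by the ingest endpoint. Only `text` and `author` are
/// required; the remaining fields fall back to the same defaults as the
/// multipart form (`kind` = text, `chunker` = auto, `max_tokens` = 512,
/// `overlap` = 32).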
#[derive(Deserialize, Debug)]
pub(crate) struct IngestJsonBody {
pub text: String,
#[serde(default)]
pub kind: Option<String>,
#[serde(default)]
pub ntype: Option<String>,
#[serde(default)]
pub chunker: Option<String>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub overlap: Option<u32>,
pub author: String,
#[serde(default)]
pub message: Option<String>,
#[serde(default)]
pub extractor: Option<String>,
#[serde(default)]
pub ner_provider: Option<String>,
}
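/// `/v1/ingest` handler. Dispatches on the `Content-Type` header: bodies
/// declared as `application/json` are parsed as [`IngestJsonBody`], anything
/// else is decoded as a multipart form.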
pub(crate) async fn ingest(
State(state): State<AppState>,
multipart_or_json: axum::http::Request<axum::body::Body>,
) -> Result<Json<Value>, Error> {
let content_type = multipart_or_json
.headers()
.get(axum::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_ascii_lowercase();
if content_type.starts_with("application/json") {
let bytes = axum::body::to_bytes(multipart_or_json.into_body(), usize::MAX)
.await
.map_err(|e| Error::bad_request(format!("reading body: {e}")))?;
let body: IngestJsonBody = serde_json::from_slice(&bytes)
.map_err(|e| Error::bad_request(format!("malformed JSON body: {e}")))?;
ingest_json(state, body).await
} else {
let multipart = Multipart::from_request(multipart_or_json, &state)
.await
.map_err(|e| Error::bad_request(format!("multipart decode: {e}")))?;
ingest_multipart(state, multipart).await
}
}
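/// Multipart path: collects the `file` field plus optional settings fields,
/// enforces the byte cap on the upload, infers the [`SourceKind`] from the
/// uploaded file name, and hands off to [`run_ingest`].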
async fn ingest_multipart(state: AppState, mut multipart: Multipart) -> Result<Json<Value>, Error> {
let mut file_bytes: Option<Vec<u8>> = None;
let mut file_name: Option<String> = None;
let mut ntype: Option<String> = None;
let mut chunker_str: Option<String> = None;
let mut max_tokens: Option<u32> = None;
let mut overlap: Option<u32> = None;
let mut author: Option<String> = None;
let mut message: Option<String> = None;
let mut extractor: Option<String> = None;
let mut ner_provider: Option<String> = None;
let max_bytes = max_ingest_bytes();
while let Some(field) = multipart
.next_field()
.await
.map_err(|e| Error::bad_request(format!("multipart field: {e}")))?
{
let name = field.name().unwrap_or("").to_string();
match name.as_str() {
"file" => {
file_name = field.file_name().map(ToString::to_string);
let data = field
.bytes()
.await
.map_err(|e| Error::bad_request(format!("reading file field: {e}")))?;
if data.len() as u64 > max_bytes {
return Err(Error::bad_request(format!(
"file field is {} bytes; exceeds the {max_bytes}-byte cap \
(raise MNEM_HTTP_INGEST_MAX_BYTES if legitimate)",
data.len()
)));
}
file_bytes = Some(data.to_vec());
}
"ntype" => ntype = Some(field_text(field).await?),
"chunker" => chunker_str = Some(field_text(field).await?),
"max_tokens" => {
let s = field_text(field).await?;
max_tokens = Some(
s.parse::<u32>()
.map_err(|e| Error::bad_request(format!("max_tokens: {e}")))?,
);
}
"overlap" => {
let s = field_text(field).await?;
overlap = Some(
s.parse::<u32>()
.map_err(|e| Error::bad_request(format!("overlap: {e}")))?,
);
}
"author" => author = Some(field_text(field).await?),
"message" => message = Some(field_text(field).await?),
"extractor" => extractor = Some(field_text(field).await?),
"ner_provider" => ner_provider = Some(field_text(field).await?),
other => {
tracing::debug!(field = %other, "ignoring unknown multipart field on /v1/ingest");
let _ = field.bytes().await;
}
}
}
let bytes =
file_bytes.ok_or_else(|| Error::bad_request("missing `file` field in multipart body"))?;
let kind = file_name.as_ref().map_or(SourceKind::Text, |n| {
Ingester::source_kind_for_path(std::path::Path::new(n))
});
let author =
author.ok_or_else(|| Error::bad_request("missing `author` field in multipart body"))?;
run_ingest(
&state,
&bytes,
kind,
IngestParams {
ntype: ntype.unwrap_or_else(|| "Doc".into()),
chunker: chunker_str.unwrap_or_else(|| "auto".into()),
max_tokens: max_tokens.unwrap_or(512),
overlap: overlap.unwrap_or(32),
author,
message: message.unwrap_or_else(|| "mnem http ingest".into()),
extractor,
ner_provider,
},
)
}
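/// JSON path: enforces the byte cap on `text`, maps the optional `kind`
/// string onto a [`SourceKind`], and hands off to [`run_ingest`].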
#[allow(clippy::unused_async)]
async fn ingest_json(state: AppState, body: IngestJsonBody) -> Result<Json<Value>, Error> {
let max_bytes = max_ingest_bytes();
if body.text.len() as u64 > max_bytes {
return Err(Error::bad_request(format!(
"text body is {} bytes; exceeds the {max_bytes}-byte cap \
(raise MNEM_HTTP_INGEST_MAX_BYTES if legitimate)",
body.text.len()
)));
}
let kind = match body.kind.as_deref() {
Some("markdown" | "md") => SourceKind::Markdown,
Some("pdf") => SourceKind::Pdf,
Some("conversation" | "json" | "jsonl") => SourceKind::Conversation,
Some("text" | "txt") | None => SourceKind::Text,
Some(other) => {
return Err(Error::bad_request(format!(
"unknown `kind`: {other}; want one of markdown|text|pdf|conversation"
)));
}
};
let bytes = body.text.into_bytes();
run_ingest(
&state,
&bytes,
kind,
IngestParams {
ntype: body.ntype.unwrap_or_else(|| "Doc".into()),
chunker: body.chunker.unwrap_or_else(|| "auto".into()),
max_tokens: body.max_tokens.unwrap_or(512),
overlap: body.overlap.unwrap_or(32),
author: body.author,
message: body.message.unwrap_or_else(|| "mnem http ingest".into()),
extractor: body.extractor,
ner_provider: body.ner_provider,
},
)
}
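/// Ingest settings normalized from either the JSON or the multipart path.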
struct IngestParams {
ntype: String,
chunker: String,
max_tokens: u32,
overlap: u32,
author: String,
message: String,
extractor: Option<String>,
ner_provider: Option<String>,
}
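/// Validates the request parameters, assembles the ingest pipeline (chunker,
/// NER config, optional KeyBERT extractor), runs it inside a single repo
/// transaction, commits, records metrics, and returns the `mnem.v1.ingest`
/// response document.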
fn run_ingest(
state: &AppState,
bytes: &[u8],
kind: SourceKind,
mut params: IngestParams,
) -> Result<Json<Value>, Error> {
if params.max_tokens > MAX_INGEST_TOKENS {
return Err(Error::bad_request(format!(
"max_tokens {} exceeds the {MAX_INGEST_TOKENS} cap",
params.max_tokens
)));
}
if params.author.trim().is_empty() {
return Err(Error::bad_request("author is required"));
}
if params.message.trim().is_empty() {
params.message = "mnem http ingest".into();
}
let ner = match params.ner_provider.as_deref() {
Some("none") => NerConfig::None,
Some("rule") | None => state.ner_cfg.clone().unwrap_or(NerConfig::Rule),
Some(other) => {
return Err(Error::bad_request(format!(
"unknown `ner_provider`: {other}; want one of rule|none"
)));
}
};
let chunker = resolve_chunker(&params.chunker, kind, params.max_tokens, params.overlap)?;
let config = IngestConfig {
chunker,
ntype: params.ntype,
max_tokens: params.max_tokens,
overlap: params.overlap,
ner,
};
let mut ing = Ingester::new(config);
match params.extractor.as_deref() {
None | Some("" | "none") => {}
Some("keybert") => {
let pc = state.embed_cfg.as_ref().ok_or_else(|| {
Error::bad_request(
"extractor=keybert requires an [embed] provider configured on the server \
(MNEM_EMBED_PROVIDER / config.toml); none resolved",
)
})?;
let boxed = mnem_embed_providers::open(pc).map_err(|e| {
Error::bad_request(format!("opening embed provider for keybert: {e}"))
})?;
let arc: std::sync::Arc<dyn mnem_embed_providers::Embedder> =
std::sync::Arc::from(boxed);
ing = ing.with_extractor(Box::new(mnem_ingest::KeyBertAdapter::new(arc, "Keyword")));
}
Some(other) => {
return Err(Error::bad_request(format!(
"unknown `extractor`: {other}; want one of none|keybert"
)));
}
}
let started = Instant::now();
let mut guard = state.repo.lock().map_err(|_| Error::locked())?;
let mut tx = guard.start_transaction();
let result = ing
.ingest(&mut tx, bytes, kind)
.map_err(|e| Error::bad_request(format!("ingest failed: {e}")))?;
let commit_start = Instant::now();
let new_repo = tx.commit(&params.author, &params.message)?;
state
.metrics
.commit_duration
.observe(commit_start.elapsed().as_secs_f64());
let op_id = new_repo.op_id().to_string();
let commit_cid = new_repo
.view()
.heads
.first()
.map_or_else(|| "<none>".to_string(), ToString::to_string);
*guard = new_repo;
let elapsed = started.elapsed().as_secs_f64();
state.metrics.ingest_duration.observe(elapsed);
state.metrics.ingest_chunks.inc_by(result.chunk_count);
Ok(Json(json!({
"schema": "mnem.v1.ingest",
"op_id": op_id,
"commit_cid": commit_cid,
"node_count": result.node_count,
"chunk_count": result.chunk_count,
"entity_count": result.entity_count,
"relation_count": result.relation_count,
"elapsed_ms": result.elapsed_ms,
})))
}
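/// Maps the requested `chunker` name onto a [`ChunkerKind`], delegating to
/// `auto_chunker` for `auto`.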
fn resolve_chunker(
choice: &str,
kind: SourceKind,
max_tokens: u32,
overlap: u32,
) -> Result<ChunkerKind, Error> {
Ok(match choice.to_ascii_lowercase().as_str() {
"auto" => auto_chunker(
kind,
ChunkerAuto {
max_tokens: Some(max_tokens),
overlap: Some(overlap),
max_messages: None,
},
),
"paragraph" => ChunkerKind::Paragraph,
"recursive" => ChunkerKind::Recursive {
max_tokens,
overlap,
},
"session" => ChunkerKind::Session { max_messages: 10 },
other => {
return Err(Error::bad_request(format!(
"chunker must be one of auto|paragraph|recursive|session; got `{other}`"
)));
}
})
}
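/// Reads a multipart field as UTF-8 text, mapping decode failures to a
/// bad-request error.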
async fn field_text(field: axum::extract::multipart::Field<'_>) -> Result<String, Error> {
field
.text()
.await
.map_err(|e| Error::bad_request(format!("decoding text field: {e}")))
}