use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::{info, Level};
use uuid::Uuid;
use reasonkit_mem::{
retrieval::HybridRetriever, Chunk, Document, DocumentType, EmbeddingIds, RetrievalConfig,
Source, SourceType,
};
/// Shared state handed to every handler through axum's `State` extractor.
struct AppState {
    // Hybrid (dense + sparse) retriever behind an async RwLock.
    // NOTE(review): every handler in this file only ever takes `.read()`,
    // including `add_document` — confirm `HybridRetriever` mutates via
    // interior mutability; otherwise writes need the write lock.
    retriever: RwLock<HybridRetriever>,
}
/// Request body for `POST /v1/search`.
#[derive(Debug, Deserialize)]
struct SearchRequest {
    /// Query text to search for.
    query: String,
    /// Maximum number of results to return (default 10).
    #[serde(default = "default_top_k")]
    top_k: usize,
    /// Dense-vs-sparse blend weight used in hybrid mode (default 0.7).
    #[serde(default = "default_alpha")]
    alpha: f32,
    /// Retrieval strategy: "sparse"/"bm25", "dense"/"vector", or
    /// "hybrid" (the default; any unrecognized value also runs hybrid).
    #[serde(default = "default_mode")]
    mode: String,
}
/// Serde default for `SearchRequest::top_k`: 10 results.
fn default_top_k() -> usize {
    const TOP_K: usize = 10;
    TOP_K
}
/// Serde default for `SearchRequest::alpha`: weight dense scores at 0.7.
fn default_alpha() -> f32 {
    let alpha: f32 = 0.7;
    alpha
}
/// Serde default for `SearchRequest::mode`: hybrid retrieval.
fn default_mode() -> String {
    String::from("hybrid")
}
/// Response body for `POST /v1/search`.
#[derive(Debug, Serialize)]
struct SearchResponse {
    /// Ranked search hits.
    results: Vec<SearchResultItem>,
    /// The query echoed back to the caller.
    query: String,
    /// The mode echoed back to the caller (as sent, not normalized).
    mode: String,
    /// Timing/count stats; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    stats: Option<SearchStats>,
}
/// One ranked hit in a `SearchResponse`.
#[derive(Debug, Serialize)]
struct SearchResultItem {
    /// Owning document id (stringified UUID).
    doc_id: String,
    /// Chunk id within the document (stringified UUID).
    chunk_id: String,
    /// Chunk text.
    text: String,
    /// Final (possibly fused) relevance score.
    score: f32,
    /// Dense-retrieval component score, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    dense_score: Option<f32>,
    /// Sparse (BM25) component score, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    sparse_score: Option<f32>,
    /// Debug-formatted origin of the match (dense/sparse/both).
    match_source: String,
}
/// Summary statistics attached to a search response.
#[derive(Debug, Serialize)]
struct SearchStats {
    /// Number of results returned.
    total_results: usize,
    /// Wall-clock search duration in milliseconds.
    search_time_ms: u64,
}
/// Request body for `POST /v1/embed`.
#[derive(Debug, Deserialize)]
struct EmbedRequest {
    /// Texts to embed, one vector returned per text.
    texts: Vec<String>,
}
/// Response body for `POST /v1/embed`.
#[derive(Debug, Serialize)]
struct EmbedResponse {
    /// One embedding vector per input text, in request order.
    embeddings: Vec<Vec<f32>>,
    /// Name of the embedding model.
    model: String,
    /// Length of each embedding vector (0 when no texts were given).
    dimension: usize,
}
/// Request body for `POST /v1/documents`.
#[derive(Debug, Deserialize)]
struct AddDocumentRequest {
    /// Raw document text; chunked on blank lines by the handler.
    content: String,
    /// Optional document title.
    #[serde(default)]
    title: Option<String>,
    /// Free-form metadata. Accepted for API-compatibility but not yet
    /// stored — hence the dead_code allowance.
    #[serde(default)]
    #[allow(dead_code)]
    metadata: HashMap<String, String>,
    /// Origin of the document; values starting with "http" are treated
    /// as URLs, everything else as local paths.
    #[serde(default)]
    source: Option<String>,
}
/// Response body for `POST /v1/documents`.
#[derive(Debug, Serialize)]
struct AddDocumentResponse {
    /// Id of the newly created document (stringified UUID).
    id: String,
    /// Number of chunks the content was split into.
    chunks: usize,
    /// Human-readable confirmation message.
    message: String,
}
/// Response body for `GET /v1/stats`, mirroring the retriever's stats.
#[derive(Debug, Serialize)]
struct StatsResponse {
    document_count: usize,
    chunk_count: usize,
    indexed_chunks: usize,
    embedding_count: usize,
    storage_bytes: u64,
    index_bytes: u64,
}
/// Response body for `GET /health`.
#[derive(Debug, Serialize)]
struct HealthResponse {
    /// Always "healthy" when the server can respond at all.
    status: String,
    /// Crate version baked in at compile time.
    version: String,
    /// Seconds since start. Currently always 0 (see `health`).
    uptime_secs: u64,
}
/// JSON error payload returned alongside non-2xx status codes.
#[derive(Debug, Serialize)]
struct ErrorResponse {
    /// Human-readable error message.
    error: String,
    /// Stable machine-readable code, e.g. "SEARCH_ERROR".
    code: String,
}
/// Request body for `POST /v1/openwebui/search` (OpenWebUI external RAG).
#[derive(Debug, Deserialize)]
struct OpenWebUISearchRequest {
    /// Query text.
    query: String,
    /// Maximum number of results (default 5).
    #[serde(default = "default_count")]
    count: usize,
}
/// Serde default for `OpenWebUISearchRequest::count`: 5 results.
fn default_count() -> usize {
    const COUNT: usize = 5;
    COUNT
}
/// One result in the OpenWebUI search response (web-search-like shape).
#[derive(Debug, Serialize)]
struct OpenWebUISearchResult {
    /// Synthetic `rkmem://doc/<id>/chunk/<id>` link identifying the chunk.
    link: String,
    /// First line of the chunk (truncated to 100 chars).
    title: String,
    /// Chunk text, truncated for display.
    snippet: String,
}
/// `GET /health` — liveness probe reporting the crate version.
async fn health() -> Json<HealthResponse> {
    Json(HealthResponse {
        status: "healthy".to_string(),
        // Version comes from Cargo.toml at compile time.
        version: env!("CARGO_PKG_VERSION").to_string(),
        // NOTE(review): uptime is a hard-coded placeholder — wire this to a
        // recorded process start time if real uptime is wanted.
        uptime_secs: 0, })
}
/// `POST /v1/search` — run sparse (BM25), dense (vector), or hybrid retrieval.
///
/// `mode` selects the strategy: "sparse"/"bm25", "dense"/"vector", or
/// "hybrid" (any unrecognized value also falls through to hybrid).
/// On success returns the ranked hits plus timing stats; retrieval failures
/// surface as 500 with code "SEARCH_ERROR".
async fn search(
    State(state): State<Arc<AppState>>,
    Json(req): Json<SearchRequest>,
) -> Result<Json<SearchResponse>, (StatusCode, Json<ErrorResponse>)> {
    // Single error-mapping helper replacing three identical copy-pasted
    // closures in the previous version.
    fn search_err(e: impl std::fmt::Display) -> (StatusCode, Json<ErrorResponse>) {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: e.to_string(),
                code: "SEARCH_ERROR".to_string(),
            }),
        )
    }

    let start = std::time::Instant::now();
    let retriever = state.retriever.read().await;

    let results = match req.mode.as_str() {
        "sparse" | "bm25" => retriever
            .search_sparse(&req.query, req.top_k)
            .await
            .map_err(search_err)?,
        "dense" | "vector" => retriever
            .search_dense(&req.query, req.top_k)
            .await
            .map_err(search_err)?,
        // "hybrid" and any unknown mode both run hybrid retrieval.
        // NOTE(review): unknown modes are silently accepted; consider
        // rejecting them with 400 if strict validation is desired.
        _ => {
            let config = RetrievalConfig {
                top_k: req.top_k,
                alpha: req.alpha,
                ..Default::default()
            };
            retriever
                .search_hybrid(&req.query, None, &config)
                .await
                .map_err(search_err)?
        }
    };

    let elapsed = start.elapsed();
    let response_items: Vec<SearchResultItem> = results
        .into_iter()
        .map(|r| SearchResultItem {
            doc_id: r.doc_id.to_string(),
            chunk_id: r.chunk_id.to_string(),
            text: r.text,
            score: r.score,
            dense_score: r.dense_score,
            sparse_score: r.sparse_score,
            match_source: format!("{:?}", r.match_source),
        })
        .collect();
    let total = response_items.len();

    Ok(Json(SearchResponse {
        results: response_items,
        query: req.query,
        mode: req.mode,
        stats: Some(SearchStats {
            total_results: total,
            search_time_ms: elapsed.as_millis() as u64,
        }),
    }))
}
/// `POST /v1/embed` — embed each input text with the configured pipeline.
///
/// Returns 503 "EMBEDDING_NOT_CONFIGURED" when the server was started
/// without an embedding pipeline, 500 "EMBEDDING_ERROR" if embedding any
/// individual text fails. `dimension` is the length of the first vector
/// (0 for an empty request).
async fn embed(
    State(state): State<Arc<AppState>>,
    Json(req): Json<EmbedRequest>,
) -> Result<Json<EmbedResponse>, (StatusCode, Json<ErrorResponse>)> {
    let retriever = state.retriever.read().await;
    // let-else replaces the previous is_none() check + unwrap() pair.
    let Some(pipeline) = retriever.embedding_pipeline() else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "Embedding pipeline not configured. Start server with embeddings enabled."
                    .to_string(),
                code: "EMBEDDING_NOT_CONFIGURED".to_string(),
            }),
        ));
    };
    let mut embeddings = Vec::with_capacity(req.texts.len());
    for text in &req.texts {
        let embedding = pipeline.embed_text(text).await.map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: e.to_string(),
                    code: "EMBEDDING_ERROR".to_string(),
                }),
            )
        })?;
        embeddings.push(embedding);
    }
    let dimension = embeddings.first().map(|e| e.len()).unwrap_or(0);
    Ok(Json(EmbedResponse {
        embeddings,
        // NOTE(review): model name is hard-coded; confirm it matches the
        // pipeline actually configured on the retriever.
        model: "bge-m3".to_string(),
        dimension,
    }))
}
/// `POST /v1/documents` — ingest a document: split the content into
/// paragraph chunks on blank lines ("\n\n") and hand it to the retriever.
///
/// Returns the new document id and chunk count; indexing failures surface
/// as 500 with code "ADD_DOCUMENT_ERROR".
async fn add_document(
    State(state): State<Arc<AppState>>,
    Json(req): Json<AddDocumentRequest>,
) -> Result<Json<AddDocumentResponse>, (StatusCode, Json<ErrorResponse>)> {
    use chrono::Utc;
    // Anything starting with "http" is treated as a web source, else local.
    let is_url = req
        .source
        .as_ref()
        .map(|s| s.starts_with("http"))
        .unwrap_or(false);
    let source = Source {
        source_type: if is_url {
            SourceType::Website
        } else {
            SourceType::Local
        },
        url: req.source.clone().filter(|s| s.starts_with("http")),
        path: req.source.clone().filter(|s| !s.starts_with("http")),
        arxiv_id: None,
        github_repo: None,
        retrieved_at: Utc::now(),
        version: None,
    };
    let mut doc = Document::new(DocumentType::Note, source).with_content(req.content.clone());
    if let Some(title) = req.title {
        doc.metadata.title = Some(title);
    }
    // Chunk on blank lines, recording each paragraph's real offset into the
    // content. (The previous version hard-coded start_char = 0 for every
    // chunk, so all offsets but the first chunk's were wrong.)
    // NOTE(review): offsets are byte-based, consistent with the original
    // end_char = text.len(); confirm Chunk expects bytes, not char counts.
    let mut chunks: Vec<Chunk> = Vec::new();
    let mut offset = 0usize;
    for (i, text) in req.content.split("\n\n").enumerate() {
        if !text.trim().is_empty() {
            chunks.push(Chunk {
                id: Uuid::new_v4(),
                text: text.to_string(),
                index: i,
                start_char: offset,
                end_char: offset + text.len(),
                token_count: None,
                section: None,
                page: None,
                embedding_ids: EmbeddingIds::default(),
            });
        }
        // Advance past this piece plus the 2-byte "\n\n" separator.
        offset += text.len() + 2;
    }
    let chunk_count = chunks.len();
    doc.chunks = chunks;
    let retriever = state.retriever.read().await;
    retriever.add_document(&doc).await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: e.to_string(),
                code: "ADD_DOCUMENT_ERROR".to_string(),
            }),
        )
    })?;
    Ok(Json(AddDocumentResponse {
        id: doc.id.to_string(),
        chunks: chunk_count,
        message: format!("Document added with {} chunks", chunk_count),
    }))
}
/// `POST /v1/openwebui/search` — OpenWebUI-compatible external RAG endpoint.
///
/// Runs a sparse (BM25) search and shapes each hit like a web-search
/// result: synthetic `rkmem://` link, first-line title, truncated snippet.
async fn openwebui_search(
    State(state): State<Arc<AppState>>,
    Json(req): Json<OpenWebUISearchRequest>,
) -> Result<Json<Vec<OpenWebUISearchResult>>, (StatusCode, Json<ErrorResponse>)> {
    let retriever = state.retriever.read().await;
    let results = retriever
        .search_sparse(&req.query, req.count)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: e.to_string(),
                    code: "SEARCH_ERROR".to_string(),
                }),
            )
        })?;
    let response: Vec<OpenWebUISearchResult> = results
        .into_iter()
        .map(|r| {
            // Truncate long text to ~500 bytes for the snippet, backing up
            // to a char boundary so multi-byte UTF-8 cannot cause a panic.
            // (The previous `&r.text[..497]` panicked when byte 497 fell
            // inside a multi-byte code point.)
            let snippet = if r.text.len() > 500 {
                let mut end = 497;
                while !r.text.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &r.text[..end])
            } else {
                r.text.clone()
            };
            OpenWebUISearchResult {
                link: format!("rkmem://doc/{}/chunk/{}", r.doc_id, r.chunk_id),
                // Title: first line capped at 100 chars; fall back to the
                // chunk id for empty text (lines() yields nothing for "").
                title: r
                    .text
                    .lines()
                    .next()
                    .map(|l| l.chars().take(100).collect::<String>())
                    .unwrap_or_else(|| format!("Chunk {}", r.chunk_id)),
                snippet,
            }
        })
        .collect();
    Ok(Json(response))
}
/// `GET /v1/stats` — report the retriever's storage and index statistics.
///
/// Failures to gather stats surface as 500 with code "STATS_ERROR".
async fn stats(
    State(state): State<Arc<AppState>>,
) -> Result<Json<StatsResponse>, (StatusCode, Json<ErrorResponse>)> {
    let guard = state.retriever.read().await;
    let s = match guard.stats().await {
        Ok(s) => s,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: e.to_string(),
                    code: "STATS_ERROR".to_string(),
                }),
            ))
        }
    };
    Ok(Json(StatsResponse {
        document_count: s.document_count,
        chunk_count: s.chunk_count,
        indexed_chunks: s.indexed_chunks,
        embedding_count: s.embedding_count,
        storage_bytes: s.storage_bytes,
        index_bytes: s.index_bytes,
    }))
}
/// Entry point: configure tracing, build the in-memory retriever and the
/// axum router, then serve until the process is killed.
///
/// Bind address is configurable via the RKMEM_HOST / RKMEM_PORT env vars
/// (defaults: 0.0.0.0:8765).
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt().with_max_level(Level::INFO).init();

    let host = std::env::var("RKMEM_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
    let port: u16 = std::env::var("RKMEM_PORT")
        .unwrap_or_else(|_| "8765".to_string())
        .parse()?;

    info!(
        "Starting ReasonKit-mem HTTP Server v{}",
        env!("CARGO_PKG_VERSION")
    );

    let state = Arc::new(AppState {
        retriever: RwLock::new(HybridRetriever::in_memory()?),
    });

    // NOTE(review): fully-open CORS (any origin/method/header) — acceptable
    // for a local tool, but tighten before exposing this service publicly.
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/health", get(health))
        .route("/v1/search", post(search))
        .route("/v1/embed", post(embed))
        .route("/v1/documents", post(add_document))
        .route("/v1/stats", get(stats))
        .route("/v1/openwebui/search", post(openwebui_search))
        .layer(cors)
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
    info!("Listening on http://{}", addr);
    info!("Endpoints:");
    info!(" POST /v1/search - Hybrid search");
    info!(" POST /v1/embed - Generate embeddings");
    info!(" POST /v1/documents - Add documents");
    info!(" GET /v1/stats - Get statistics");
    info!(" POST /v1/openwebui/search - OpenWebUI external RAG");
    info!(" GET /health - Health check");

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}