use std::sync::Arc;
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use crate::common::{AppState, check_auth, redacted_error};
use super::{RagRequest, RagResponse, RagResult};
/// Handles a RAG retrieval request by calling `pg_ripple.rag_retrieve` and
/// returning the matched entities together with a concatenated context string.
///
/// Flow:
/// 1. Authenticates via `check_auth`; on failure its response is returned as-is.
/// 2. Reads the body (at most [`MAX_BODY_BYTES`]) and parses it as a JSON [`RagRequest`].
/// 3. Acquires a pooled client and runs `rag_retrieve` with the question, optional
///    SPARQL filter, `k`, optional model and output format (`"jsonld"` when requested,
///    otherwise `"jsonb"`).
/// 4. Builds one [`RagResult`] per row and joins each row's context text
///    (`contextText` from `context_json`, falling back to the label) with blank lines.
///
/// # Errors (as HTTP responses)
/// - `413 PAYLOAD_TOO_LARGE` if the body cannot be read within the size limit.
/// - `400 BAD_REQUEST` (`invalid_request`) if the body is not a valid `RagRequest`.
/// - `503 SERVICE_UNAVAILABLE` (`pool_error`) if no database connection is available.
/// - `500 INTERNAL_SERVER_ERROR` (`rag_error`) if the query fails or a returned row
///   cannot be decoded into the expected column types.
pub(crate) async fn rag_post(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    body: Body,
) -> Response {
    /// Upper bound on the accepted request body size (1 MiB).
    const MAX_BODY_BYTES: usize = 1024 * 1024;

    if let Err(r) = check_auth(&state, &headers) {
        return r;
    }

    // NOTE(review): every read failure (not only exceeding the limit) is reported as 413.
    let body_bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
        Ok(b) => b,
        Err(_) => {
            return (StatusCode::PAYLOAD_TOO_LARGE, "request body too large").into_response();
        }
    };

    let req: RagRequest = match serde_json::from_slice(&body_bytes) {
        Ok(r) => r,
        Err(e) => {
            return redacted_error(
                "invalid_request",
                &format!("invalid JSON body: {e}"),
                StatusCode::BAD_REQUEST,
            );
        }
    };

    let client = match state.pool.get().await {
        Ok(c) => c,
        Err(e) => {
            return redacted_error(
                "pool_error",
                &format!("connection pool error: {e}"),
                StatusCode::SERVICE_UNAVAILABLE,
            );
        }
    };

    // Only "jsonld" is forwarded verbatim; any other value falls back to "jsonb".
    let output_format = if req.output_format == "jsonld" {
        "jsonld"
    } else {
        "jsonb"
    };

    // `Option<String>` binds directly as a nullable text parameter, so the request
    // fields are borrowed rather than cloned.
    let rows = match client
        .query(
            "SELECT entity_iri, label, context_json, distance \
             FROM pg_ripple.rag_retrieve($1, sparql_filter := $2::text, \
             k := $3, model := $4::text, output_format := $5)",
            &[
                &req.question,
                &req.sparql_filter,
                &req.k,
                &req.model,
                &output_format,
            ],
        )
        .await
    {
        Ok(r) => r,
        Err(e) => {
            return redacted_error(
                "rag_error",
                &format!("rag_retrieve failed: {e}"),
                StatusCode::INTERNAL_SERVER_ERROR,
            );
        }
    };

    let mut results = Vec::with_capacity(rows.len());
    let mut context_parts = Vec::with_capacity(rows.len());
    for row in &rows {
        // Decode fallibly: a NULL or unexpectedly-typed column must produce an error
        // response rather than panicking inside the request handler.
        let decoded = row.try_get::<_, String>(0).and_then(|entity_iri| {
            let label = row.try_get::<_, String>(1)?;
            let distance = row.try_get::<_, f64>(3)?;
            Ok((entity_iri, label, distance))
        });
        let (entity_iri, label, distance) = match decoded {
            Ok(fields) => fields,
            Err(e) => {
                return redacted_error(
                    "rag_error",
                    &format!("rag_retrieve returned an undecodable row: {e}"),
                    StatusCode::INTERNAL_SERVER_ERROR,
                );
            }
        };

        // context_json is best-effort: a NULL or non-JSON value degrades to `Null`.
        let context_json: serde_json::Value = row
            .try_get::<_, serde_json::Value>(2)
            .unwrap_or(serde_json::Value::Null);

        // Prefer the precomputed context text; fall back to the entity label.
        let context_text = context_json
            .get("contextText")
            .and_then(|v| v.as_str())
            .unwrap_or(&label)
            .to_owned();
        context_parts.push(context_text);

        results.push(RagResult {
            entity_iri,
            label,
            context_json,
            distance,
        });
    }

    let response = RagResponse {
        results,
        context: context_parts.join("\n\n"),
    };
    axum::Json(response).into_response()
}