use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{delete, get, post, put};
use serde::{Deserialize, Serialize};
use tower_http::cors::CorsLayer;
use tower_http::limit::RequestBodyLimitLayer;
use utoipa::OpenApi;
pub mod middleware;
use crate::arrow_encode::flight_do_get;
use crate::common::{AppState, check_auth};
pub(crate) mod datalog_admin;
pub(crate) mod datalog_handlers;
pub(crate) mod datalog_inference;
use self::datalog_handlers as datalog;
#[derive(OpenApi)]
#[openapi(
info(
title = "pg_ripple_http",
version = "0.16.0",
description = "SPARQL 1.1 Protocol HTTP endpoint and Datalog REST API for pg_ripple",
license(name = "Apache-2.0")
),
paths(
admin_handlers::openapi_spec,
),
tags(
(name = "sparql", description = "SPARQL 1.1 Query and Update Protocol"),
(name = "datalog", description = "Datalog inference and rule management"),
(name = "health", description = "Health and observability"),
(name = "metadata", description = "Dataset and service metadata"),
)
)]
pub struct ApiDoc;
pub(crate) const CT_SPARQL_JSON: &str = "application/sparql-results+json";
pub(crate) const CT_SPARQL_XML: &str = "application/sparql-results+xml";
pub(crate) const CT_CSV: &str = "text/csv";
pub(crate) const CT_TSV: &str = "text/tab-separated-values";
pub(crate) const CT_TURTLE: &str = "text/turtle";
pub(crate) const CT_NTRIPLES: &str = "application/n-triples";
pub(crate) const CT_JSONLD: &str = "application/ld+json";
pub(crate) const CT_SPARQL_QUERY: &str = "application/sparql-query";
pub(crate) const CT_SPARQL_UPDATE: &str = "application/sparql-update";
pub(crate) const CT_FORM: &str = "application/x-www-form-urlencoded";
#[derive(Deserialize, Default)]
pub(crate) struct SparqlParams {
query: Option<String>,
update: Option<String>,
}
#[derive(Deserialize)]
pub(crate) struct RagRequest {
question: String,
sparql_filter: Option<String>,
#[serde(default = "default_k")]
k: i32,
model: Option<String>,
#[serde(default = "default_output_format")]
output_format: String,
}
fn default_k() -> i32 {
5
}
fn default_output_format() -> String {
"jsonb".to_owned()
}
#[derive(Serialize)]
pub(crate) struct RagResult {
entity_iri: String,
label: String,
context_json: serde_json::Value,
distance: f64,
}
#[derive(Serialize)]
pub(crate) struct RagResponse {
results: Vec<RagResult>,
context: String,
}
pub(crate) mod admin_handlers;
pub(crate) mod confidence_handlers;
pub(crate) mod conflict_handler;
pub(crate) mod explain_handler;
pub(crate) mod hypothetical_handler;
pub(crate) mod pagerank_handlers;
pub(crate) mod rag_handler;
pub(crate) mod sparql_handlers;
pub(crate) use sparql_handlers::{
format_ask_result, format_graph_results, format_select_results, json_response_http,
};
pub(crate) fn build_router(state: Arc<AppState>, max_body_bytes: usize, cors: CorsLayer) -> Router {
Router::new()
.route(
"/sparql",
get(sparql_handlers::sparql_get).post(sparql_handlers::sparql_post),
)
.route("/sparql/stream", post(sparql_handlers::sparql_stream_post))
.route("/rag", post(rag_handler::rag_post))
.route("/health", get(admin_handlers::health))
.route("/ready", get(admin_handlers::ready))
.route("/health/ready", get(admin_handlers::health_ready))
.route("/metrics", get(admin_handlers::metrics_endpoint))
.route(
"/metrics/extension",
get(admin_handlers::extension_metrics_endpoint),
)
.route("/void", get(admin_handlers::void_endpoint))
.route("/service", get(admin_handlers::service_description))
.route("/openapi.yaml", get(admin_handlers::openapi_spec))
.route("/datalog/rules", get(datalog::list_rules))
.route(
"/datalog/rules/{rule_set}",
post(datalog::load_rules).delete(datalog::drop_rules),
)
.route(
"/datalog/rules/{rule_set}/builtin",
post(datalog::load_builtin),
)
.route("/datalog/rules/{rule_set}/add", post(datalog::add_rule))
.route(
"/datalog/rules/{rule_set}/{rule_id}",
delete(datalog::remove_rule),
)
.route(
"/datalog/rules/{rule_set}/enable",
put(datalog::enable_rule_set),
)
.route(
"/datalog/rules/{rule_set}/disable",
put(datalog::disable_rule_set),
)
.route("/datalog/infer/{rule_set}", post(datalog::infer))
.route(
"/datalog/infer/{rule_set}/stats",
post(datalog::infer_with_stats),
)
.route("/datalog/infer/{rule_set}/agg", post(datalog::infer_agg))
.route("/datalog/infer/{rule_set}/wfs", post(datalog::infer_wfs))
.route(
"/datalog/infer/{rule_set}/demand",
post(datalog::infer_demand),
)
.route(
"/datalog/infer/{rule_set}/lattice",
post(datalog::infer_lattice),
)
.route("/datalog/query/{rule_set}", post(datalog::query_goal))
.route("/datalog/constraints", get(datalog::check_constraints_all))
.route(
"/datalog/constraints/{rule_set}",
get(datalog::check_constraints),
)
.route("/datalog/stats/cache", get(datalog::cache_stats))
.route("/datalog/stats/tabling", get(datalog::tabling_stats))
.route(
"/datalog/lattices",
get(datalog::list_lattices).post(datalog::create_lattice),
)
.route(
"/datalog/views",
get(datalog::list_views).post(datalog::create_view),
)
.route("/datalog/views/{name}", delete(datalog::drop_view))
.route("/explorer", get(admin_handlers::explorer_page))
.route("/flight/do_get", post(flight_do_get))
.route("/subscribe/{subscription_id}", get(sparql_subscription_sse))
.route(
"/confidence/load",
post(confidence_handlers::load_with_confidence),
)
.route(
"/confidence/shacl-score",
get(confidence_handlers::shacl_score),
)
.route(
"/confidence/shacl-report",
get(confidence_handlers::shacl_report_scored),
)
.route(
"/confidence/vacuum",
post(confidence_handlers::vacuum_confidence),
)
.route("/pagerank/run", post(pagerank_handlers::pagerank_run))
.route(
"/pagerank/results",
get(pagerank_handlers::pagerank_results),
)
.route("/pagerank/status", get(pagerank_handlers::pagerank_status))
.route(
"/pagerank/vacuum-dirty",
post(pagerank_handlers::vacuum_dirty),
)
.route("/pagerank/export", get(pagerank_handlers::pagerank_export))
.route(
"/pagerank/explain/{node_iri}",
get(pagerank_handlers::pagerank_explain),
)
.route(
"/pagerank/queue-stats",
get(pagerank_handlers::pagerank_queue_stats),
)
.route("/centrality/run", post(pagerank_handlers::centrality_run))
.route(
"/centrality/results",
get(pagerank_handlers::centrality_results),
)
.route(
"/pagerank/find-duplicates",
post(pagerank_handlers::find_duplicates),
)
.route(
"/explain",
post(explain_handler::explain_post).get(explain_handler::explain_get),
)
.route(
"/hypothetical",
post(hypothetical_handler::hypothetical_post),
)
.route(
"/rule-conflicts/{ruleset}",
get(conflict_handler::rule_conflicts_get),
)
.layer(RequestBodyLimitLayer::new(max_body_bytes))
.layer(cors)
.with_state(state)
}
async fn sparql_subscription_sse(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
axum::extract::Path(subscription_id): axum::extract::Path<String>,
) -> Response {
if let Err(resp) = check_auth(&state, &headers) {
return resp;
}
if !subscription_id
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
{
return (
StatusCode::BAD_REQUEST,
"invalid subscription_id: only alphanumeric, hyphen and underscore allowed",
)
.into_response();
}
if subscription_id.len() > 63 {
return (
StatusCode::BAD_REQUEST,
"invalid subscription_id: maximum length is 63 characters",
)
.into_response();
}
let (tx, rx) = tokio::sync::mpsc::channel::<String>(32);
let pool = state.pool.clone();
let sub_id = subscription_id.clone();
tokio::spawn(async move {
let channel = format!("pg_ripple_subscription_{sub_id}");
let client = match pool.get().await {
Ok(c) => c,
Err(e) => {
let _ = tx
.send(format!("event: error\ndata: {{\"error\":\"{e}\"}}\n\n"))
.await;
return;
}
};
if let Err(e) = client.execute(&format!("LISTEN \"{channel}\""), &[]).await {
let _ = tx
.send(format!("event: error\ndata: {{\"error\":\"{e}\"}}\n\n"))
.await;
return;
}
if tx
.send(format!(
"event: subscribed\ndata: {{\"subscription_id\":\"{sub_id}\"}}\n\n"
))
.await
.is_err()
{
return;
}
let mut keepalive_tick: u64 = 0;
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
if tx.is_closed() {
break;
}
keepalive_tick += 1;
if keepalive_tick.is_multiple_of(3) {
if tx.send(": keepalive\n\n".to_string()).await.is_err() {
break;
}
}
}
});
use tokio_stream::StreamExt as _;
use tokio_stream::wrappers::ReceiverStream;
let body_stream = ReceiverStream::new(rx)
.map(|chunk: String| Ok::<_, std::convert::Infallible>(axum::body::Bytes::from(chunk)));
(
StatusCode::OK,
[
("content-type", "text/event-stream"),
("cache-control", "no-cache"),
("x-accel-buffering", "no"),
],
Body::from_stream(body_stream),
)
.into_response()
}