use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use std::fmt;
/// Serialization formats this module can negotiate for SPARQL responses.
///
/// The first four variants are the tabular result formats reported by
/// `is_select_format`; the remaining five are the RDF graph formats reported
/// by `is_graph_format`. Serde (de)serializes each variant under its
/// lowercased name because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SparqlResultFormat {
/// Emitted as `application/sparql-results+json`.
Json,
/// Emitted as `application/sparql-results+xml`.
Xml,
/// Emitted as `text/csv`.
Csv,
/// Emitted as `text/tab-separated-values`.
Tsv,
/// Emitted as `text/plain`; `application/n-triples` is accepted on input.
NTriples,
/// Emitted as `text/turtle`.
Turtle,
/// Emitted as `application/n-quads`.
NQuads,
/// Emitted as `application/trig`.
TriG,
/// Emitted as `application/ld+json`.
JsonLd,
}
impl SparqlResultFormat {
pub fn mime_type(&self) -> &'static str {
match self {
SparqlResultFormat::Json => "application/sparql-results+json",
SparqlResultFormat::Xml => "application/sparql-results+xml",
SparqlResultFormat::Csv => "text/csv",
SparqlResultFormat::Tsv => "text/tab-separated-values",
SparqlResultFormat::NTriples => "text/plain",
SparqlResultFormat::Turtle => "text/turtle",
SparqlResultFormat::NQuads => "application/n-quads",
SparqlResultFormat::TriG => "application/trig",
SparqlResultFormat::JsonLd => "application/ld+json",
}
}
pub fn from_mime(mime: &str) -> Option<Self> {
let base = mime.split(';').next().unwrap_or(mime).trim().to_lowercase();
match base.as_str() {
"application/sparql-results+json" | "application/json" => {
Some(SparqlResultFormat::Json)
}
"application/sparql-results+xml" | "application/xml" => Some(SparqlResultFormat::Xml),
"text/csv" => Some(SparqlResultFormat::Csv),
"text/tab-separated-values" | "text/tsv" => Some(SparqlResultFormat::Tsv),
"text/plain" | "application/n-triples" => Some(SparqlResultFormat::NTriples),
"text/turtle" | "application/turtle" | "application/x-turtle" => {
Some(SparqlResultFormat::Turtle)
}
"application/n-quads" => Some(SparqlResultFormat::NQuads),
"application/trig" => Some(SparqlResultFormat::TriG),
"application/ld+json" => Some(SparqlResultFormat::JsonLd),
_ => None,
}
}
pub fn is_select_format(&self) -> bool {
matches!(
self,
SparqlResultFormat::Json
| SparqlResultFormat::Xml
| SparqlResultFormat::Csv
| SparqlResultFormat::Tsv
)
}
pub fn is_graph_format(&self) -> bool {
matches!(
self,
SparqlResultFormat::NTriples
| SparqlResultFormat::Turtle
| SparqlResultFormat::NQuads
| SparqlResultFormat::TriG
| SparqlResultFormat::JsonLd
)
}
}
/// Displays a format as its MIME type string.
impl fmt::Display for SparqlResultFormat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.mime_type())
    }
}
/// One media range taken from an `Accept` header, with its quality weight.
#[derive(Debug, Clone)]
pub struct AcceptEntry {
    /// Media range as written in the header (trimmed, parameters removed).
    pub mime: String,
    /// Quality weight, always finite and within `0.0..=1.0`.
    /// Defaults to `1.0` when the `q` parameter is absent or not a usable number.
    pub q: f32,
}

/// Parses an `Accept` header into entries ordered by client preference.
///
/// Entries are sorted by descending `q`; ties are broken by specificity
/// (`type/subtype` before `type/*` before `*/*`). The sort is stable, so
/// otherwise equal entries keep their header order. Empty list elements are
/// skipped, and an empty header yields an empty vector.
///
/// The `q` parameter name is matched case-insensitively. Values that do not
/// parse, or that are not finite (`NaN`, `inf`), fall back to `1.0`; finite
/// values are clamped into `0.0..=1.0`.
pub fn parse_accept_header(header: &str) -> Vec<AcceptEntry> {
    /// Extracts the weight from the parameter section of one header element.
    /// Only the first `q` parameter is considered.
    fn quality(params: &str) -> Option<f32> {
        let raw = params.split(';').find_map(|param| {
            let (key, value) = param.split_once('=')?;
            key.trim().eq_ignore_ascii_case("q").then(|| value.trim())
        })?;
        raw.parse::<f32>()
            .ok()
            // "NaN" and "inf" parse successfully; letting them through would
            // make the comparator below inconsistent.
            .filter(|q| q.is_finite())
            .map(|q| q.clamp(0.0, 1.0))
    }

    /// Ranks a media range: 2 = concrete type, 1 = `type/*`, 0 = `*/*`.
    fn specificity(mime: &str) -> u8 {
        if mime == "*/*" {
            0
        } else if mime.ends_with("/*") {
            1
        } else {
            2
        }
    }

    let mut entries: Vec<AcceptEntry> = header
        .split(',')
        .filter_map(|element| {
            let element = element.trim();
            if element.is_empty() {
                return None;
            }
            let (mime, params) = match element.split_once(';') {
                Some((mime, params)) => (mime, Some(params)),
                None => (element, None),
            };
            Some(AcceptEntry {
                mime: mime.trim().to_string(),
                q: params.and_then(quality).unwrap_or(1.0),
            })
        })
        .collect();

    entries.sort_by(|a, b| {
        // Every `q` is finite here, so `partial_cmp` never yields `None`.
        b.q.partial_cmp(&a.q)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| specificity(&b.mime).cmp(&specificity(&a.mime)))
    });
    entries
}
/// Chooses a tabular result format for the given `Accept` header value.
///
/// Falls back to JSON when negotiation finds no supported format.
pub fn negotiate_select_format(accept: &str) -> SparqlResultFormat {
    // Order matters: the first entry is what `*/*` or an empty header resolves to.
    const CANDIDATES: [SparqlResultFormat; 4] = [
        SparqlResultFormat::Json,
        SparqlResultFormat::Xml,
        SparqlResultFormat::Csv,
        SparqlResultFormat::Tsv,
    ];
    match negotiate_format(accept, &CANDIDATES) {
        Some(chosen) => chosen,
        None => SparqlResultFormat::Json,
    }
}
/// Chooses an RDF graph serialization for the given `Accept` header value.
///
/// Falls back to Turtle when negotiation finds no supported format.
pub fn negotiate_graph_format(accept: &str) -> SparqlResultFormat {
    // Order matters: the first entry is what `*/*` or an empty header resolves to.
    const CANDIDATES: [SparqlResultFormat; 5] = [
        SparqlResultFormat::Turtle,
        SparqlResultFormat::NTriples,
        SparqlResultFormat::JsonLd,
        SparqlResultFormat::NQuads,
        SparqlResultFormat::TriG,
    ];
    match negotiate_format(accept, &CANDIDATES) {
        Some(chosen) => chosen,
        None => SparqlResultFormat::Turtle,
    }
}
/// Picks the first format from `supported` that the `Accept` header allows.
///
/// * An empty or whitespace-only header selects `supported[0]`.
/// * Header entries are tried in the preference order produced by
///   `parse_accept_header`.
/// * Entries with `q=0` are skipped: that weight marks a media range as
///   explicitly not acceptable.
/// * `*/*` selects `supported[0]`; a `type/*` range selects the first
///   supported format whose emitted MIME type has that top-level type.
/// * Any other range is resolved through `SparqlResultFormat::from_mime` and
///   accepted only if the result is in `supported`.
///
/// Returns `None` when no entry matches, or when `supported` is empty.
fn negotiate_format(accept: &str, supported: &[SparqlResultFormat]) -> Option<SparqlResultFormat> {
    if accept.trim().is_empty() {
        return supported.first().copied();
    }
    for entry in parse_accept_header(accept) {
        if entry.q <= 0.0 {
            continue;
        }
        // Media types are case-insensitive, so compare in lowercase.
        let range = entry.mime.to_ascii_lowercase();
        let type_wildcard = range.strip_suffix('*').filter(|prefix| prefix.ends_with('/'));
        let matched = if range == "*/*" {
            supported.first().copied()
        } else if let Some(type_prefix) = type_wildcard {
            // `type_prefix` is e.g. "text/"; match it against what we would emit.
            supported
                .iter()
                .copied()
                .find(|candidate| candidate.mime_type().starts_with(type_prefix))
        } else {
            SparqlResultFormat::from_mime(&range).filter(|format| supported.contains(format))
        };
        if matched.is_some() {
            return matched;
        }
    }
    None
}
/// Kinds of graph store operation.
///
/// NOTE(review): the variant names presumably mirror the HTTP methods
/// (GET, PUT, POST, DELETE, HEAD); nothing else in this file uses this enum,
/// so confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphStoreOperation {
Get,
Put,
Post,
Delete,
Head,
}
/// Result of interpreting a `Content-Type` header as an RDF graph format.
#[derive(Debug, Clone)]
pub struct GraphStoreContentType {
/// The resolved format. `from_header` sets this to Turtle whenever
/// `recognised` is `false`.
pub format: SparqlResultFormat,
/// `true` only when the header named a known graph format.
pub recognised: bool,
}
impl GraphStoreContentType {
    /// Interprets an optional `Content-Type` header value.
    ///
    /// The result is `recognised` only when the header is present, maps to a
    /// known format and that format is a graph format. In every other case
    /// (header missing, unknown media type, or a tabular result format) the
    /// result is Turtle with `recognised == false`.
    pub fn from_header(content_type: Option<&str>) -> Self {
        let graph_format = content_type
            .and_then(SparqlResultFormat::from_mime)
            .filter(SparqlResultFormat::is_graph_format);
        match graph_format {
            Some(format) => Self {
                format,
                recognised: true,
            },
            None => Self {
                format: SparqlResultFormat::Turtle,
                recognised: false,
            },
        }
    }
}
/// A single part extracted from a multipart request body.
#[derive(Debug, Clone)]
pub struct MultipartUpdatePart {
    /// Value of the `name` parameter from the part's `Content-Disposition` header.
    pub name: String,
    /// Value of the part's `Content-Type` header.
    pub content_type: String,
    /// Raw payload bytes of the part.
    pub body: Vec<u8>,
}

impl MultipartUpdatePart {
    /// Borrows the payload as text.
    ///
    /// # Errors
    ///
    /// Returns the underlying `Utf8Error` when the payload is not valid UTF-8.
    pub fn body_str(&self) -> Result<&str, std::str::Utf8Error> {
        let bytes: &[u8] = self.body.as_slice();
        std::str::from_utf8(bytes)
    }
}
/// Splits a multipart body into its named parts.
///
/// `boundary` is the bare boundary token (without the leading `--`). The body
/// is split on `--<boundary>`; for every part the header block is separated
/// from the payload at the first blank line (`\r\n\r\n`, or `\n\n` if no CRLF
/// pair exists). From the headers only `Content-Disposition` (for its `name`
/// parameter) and `Content-Type` are read, matched case-insensitively.
/// A part without a `Content-Type` header is given
/// `application/sparql-update`.
///
/// Exactly one trailing line break is removed from each payload, because it
/// belongs to the following delimiter; any other trailing characters are
/// payload content and are kept. Parsing stops at the closing delimiter
/// (`--<boundary>--`), so the epilogue is ignored.
///
/// # Errors
///
/// * [`MultipartError::InvalidUtf8`] if `body` is not valid UTF-8.
/// * [`MultipartError::MalformedPart`] if a part has no blank line between
///   headers and payload.
/// * [`MultipartError::MissingName`] if a part has no `name` parameter.
pub fn parse_multipart_update(
    boundary: &str,
    body: &[u8],
) -> Result<Vec<MultipartUpdatePart>, MultipartError> {
    let delimiter = format!("--{boundary}");
    let text =
        std::str::from_utf8(body).map_err(|e| MultipartError::InvalidUtf8(e.to_string()))?;

    let mut parts = Vec::new();
    for raw_section in text.split(delimiter.as_str()) {
        let section = raw_section
            .trim_start_matches("\r\n")
            .trim_start_matches('\n');
        // Text directly after a delimiter that starts with "--" is the tail of
        // the closing delimiter; everything from here on is epilogue.
        if section.starts_with("--") {
            break;
        }
        if section.trim().is_empty() {
            continue;
        }

        let (headers_str, payload) = section
            .split_once("\r\n\r\n")
            .or_else(|| section.split_once("\n\n"))
            .ok_or(MultipartError::MalformedPart)?;

        // The line break right before the next delimiter is part of that
        // delimiter, not of the payload.
        let payload = payload
            .strip_suffix("\r\n")
            .or_else(|| payload.strip_suffix('\n'))
            .unwrap_or(payload);

        let mut name = String::new();
        let mut content_type = "application/sparql-update".to_string();
        for line in headers_str.lines().map(str::trim) {
            if line.is_empty() {
                continue;
            }
            let lower = line.to_lowercase();
            if lower.starts_with("content-disposition:") {
                if let Some(found) = extract_header_param(line, "name") {
                    name = found;
                }
            } else if lower.starts_with("content-type:") {
                if let Some((_key, value)) = line.split_once(':') {
                    content_type = value.trim().to_string();
                }
            }
        }
        if name.is_empty() {
            return Err(MultipartError::MissingName);
        }

        parts.push(MultipartUpdatePart {
            name,
            content_type,
            body: payload.as_bytes().to_vec(),
        });
    }
    Ok(parts)
}
/// Extracts the value of a quoted parameter (`name="value"`) from a header line.
///
/// The parameter name is matched case-insensitively and only at a parameter
/// boundary (start of line, or directly after `;` or whitespace), so asking
/// for `name` does not match inside `filename="..."`. The value is returned
/// with its original casing.
///
/// Returns `None` when the parameter is absent or its closing quote is missing.
fn extract_header_param(header_line: &str, param_name: &str) -> Option<String> {
    // ASCII lowercasing keeps byte offsets identical to `header_line`, so
    // positions found in `haystack` can be used to slice the original.
    let haystack = header_line.to_ascii_lowercase();
    let needle = format!("{}=\"", param_name.to_ascii_lowercase());

    let mut from = 0;
    while let Some(offset) = haystack[from..].find(needle.as_str()) {
        let start = from + offset;
        let at_boundary = haystack[..start]
            .chars()
            .next_back()
            .map_or(true, |c| c == ';' || c.is_whitespace());
        if at_boundary {
            let value = &header_line[start + needle.len()..];
            return value.find('"').map(|end| value[..end].to_string());
        }
        // Matched inside a longer parameter name; keep looking further right.
        from = start + needle.len();
    }
    None
}
/// Extracts the `boundary` parameter from a multipart `Content-Type` value.
///
/// The media type before the first `;` is skipped. The parameter name is
/// matched case-insensitively and surrounding double quotes are removed from
/// the value. An empty boundary is treated as absent, because it would turn
/// the part delimiter into a bare `--`.
///
/// Returns `None` when no usable `boundary` parameter is present.
pub fn extract_multipart_boundary(content_type: &str) -> Option<String> {
    content_type.split(';').skip(1).find_map(|param| {
        let (key, value) = param.trim().split_once('=')?;
        if !key.eq_ignore_ascii_case("boundary") {
            return None;
        }
        let boundary = value.trim_matches('"');
        if boundary.is_empty() {
            None
        } else {
            Some(boundary.to_string())
        }
    })
}
/// Errors produced while parsing a multipart SPARQL update request.
#[derive(Debug, Clone, thiserror::Error)]
pub enum MultipartError {
/// The request body, or a part's payload, is not valid UTF-8.
/// Carries the text of the underlying UTF-8 error.
#[error("Invalid UTF-8 in multipart body: {0}")]
InvalidUtf8(String),
/// A part has no blank line between its headers and payload.
/// `SparqlProtocolHandler::parse_multipart_update_request` also returns this
/// when the `Content-Type` header carries no boundary.
#[error("Malformed part: missing header/body separator")]
MalformedPart,
/// A part's `Content-Disposition` has no `name` parameter.
/// `SparqlProtocolHandler::parse_multipart_update_request` also returns this
/// when no non-empty `update` part was found.
#[error("Part is missing 'name' in Content-Disposition")]
MissingName,
}
/// Holds the default formats used for content negotiation and provides
/// helpers for building protocol responses.
pub struct SparqlProtocolHandler {
/// Format returned by `select_format_from_headers` when the request has no
/// readable `Accept` header.
pub default_select_format: SparqlResultFormat,
/// Format returned by `graph_format_from_headers` when the request has no
/// readable `Accept` header.
pub default_graph_format: SparqlResultFormat,
}
/// The default handler answers with JSON for tabular results and Turtle for graphs.
impl Default for SparqlProtocolHandler {
    fn default() -> Self {
        Self {
            default_select_format: SparqlResultFormat::Json,
            default_graph_format: SparqlResultFormat::Turtle,
        }
    }
}
impl SparqlProtocolHandler {
pub fn new(
default_select_format: SparqlResultFormat,
default_graph_format: SparqlResultFormat,
) -> Self {
SparqlProtocolHandler {
default_select_format,
default_graph_format,
}
}
pub fn select_format_from_headers(&self, headers: &HeaderMap) -> SparqlResultFormat {
match headers.get("accept").and_then(|v| v.to_str().ok()) {
Some(accept) => negotiate_select_format(accept),
None => self.default_select_format,
}
}
pub fn graph_format_from_headers(&self, headers: &HeaderMap) -> SparqlResultFormat {
match headers.get("accept").and_then(|v| v.to_str().ok()) {
Some(accept) => negotiate_graph_format(accept),
None => self.default_graph_format,
}
}
pub fn build_query_response(&self, body: String, format: SparqlResultFormat) -> Response {
use axum::http::{header, HeaderValue};
use axum::response::IntoResponse;
let content_type = format.mime_type();
match HeaderValue::from_str(content_type) {
Ok(ct_value) => {
let mut resp = body.into_response();
resp.headers_mut().insert(header::CONTENT_TYPE, ct_value);
resp
}
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Content-Type error").into_response(),
}
}
pub fn parse_multipart_update_request(
&self,
content_type_header: &str,
body: &[u8],
) -> Result<(String, Vec<String>, Vec<String>), MultipartError> {
let boundary =
extract_multipart_boundary(content_type_header).ok_or(MultipartError::MalformedPart)?;
let parts = parse_multipart_update(&boundary, body)?;
let mut update = String::new();
let mut using_graphs = Vec::new();
let mut using_named = Vec::new();
for part in parts {
match part.name.as_str() {
"update" => {
update = part
.body_str()
.map_err(|e| MultipartError::InvalidUtf8(e.to_string()))?
.to_string();
}
"using-graph-uri" => {
let uri = part
.body_str()
.map_err(|e| MultipartError::InvalidUtf8(e.to_string()))?
.trim()
.to_string();
using_graphs.push(uri);
}
"using-named-graph-uri" => {
let uri = part
.body_str()
.map_err(|e| MultipartError::InvalidUtf8(e.to_string()))?
.trim()
.to_string();
using_named.push(uri);
}
_ => {} }
}
if update.is_empty() {
return Err(MultipartError::MissingName);
}
Ok((update, using_graphs, using_named))
}
pub fn build_graph_store_response(
&self,
body: String,
content_type: &GraphStoreContentType,
) -> Response {
self.build_query_response(body, content_type.format)
}
pub fn not_acceptable_response(&self) -> Response {
use axum::response::IntoResponse;
(
StatusCode::NOT_ACCEPTABLE,
"No supported media type in Accept header",
)
.into_response()
}
pub fn unsupported_media_type_response(&self, provided: &str) -> Response {
use axum::response::IntoResponse;
(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
format!("Unsupported Content-Type: {provided}"),
)
.into_response()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mime_type_strings() {
assert_eq!(
SparqlResultFormat::Json.mime_type(),
"application/sparql-results+json"
);
assert_eq!(
SparqlResultFormat::Xml.mime_type(),
"application/sparql-results+xml"
);
assert_eq!(SparqlResultFormat::Csv.mime_type(), "text/csv");
assert_eq!(
SparqlResultFormat::Tsv.mime_type(),
"text/tab-separated-values"
);
assert_eq!(SparqlResultFormat::Turtle.mime_type(), "text/turtle");
assert_eq!(SparqlResultFormat::NTriples.mime_type(), "text/plain");
}
#[test]
fn test_from_mime_roundtrip() {
for fmt in [
SparqlResultFormat::Json,
SparqlResultFormat::Xml,
SparqlResultFormat::Csv,
SparqlResultFormat::Tsv,
SparqlResultFormat::Turtle,
SparqlResultFormat::NTriples,
SparqlResultFormat::NQuads,
SparqlResultFormat::TriG,
SparqlResultFormat::JsonLd,
] {
let mime = fmt.mime_type();
let decoded = SparqlResultFormat::from_mime(mime);
assert!(decoded.is_some(), "round-trip failed for {mime}");
}
}
#[test]
fn test_from_mime_aliases() {
assert_eq!(
SparqlResultFormat::from_mime("application/json"),
Some(SparqlResultFormat::Json)
);
assert_eq!(
SparqlResultFormat::from_mime("application/xml"),
Some(SparqlResultFormat::Xml)
);
assert_eq!(
SparqlResultFormat::from_mime("application/turtle"),
Some(SparqlResultFormat::Turtle)
);
}
#[test]
fn test_from_mime_ignores_params() {
let fmt = SparqlResultFormat::from_mime("text/turtle; charset=utf-8");
assert_eq!(fmt, Some(SparqlResultFormat::Turtle));
}
#[test]
fn test_select_vs_graph_format() {
assert!(SparqlResultFormat::Json.is_select_format());
assert!(!SparqlResultFormat::Json.is_graph_format());
assert!(SparqlResultFormat::Turtle.is_graph_format());
assert!(!SparqlResultFormat::Turtle.is_select_format());
}
#[test]
fn test_parse_accept_single() {
let entries = parse_accept_header("application/sparql-results+json");
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].mime, "application/sparql-results+json");
assert!((entries[0].q - 1.0).abs() < 1e-6);
}
#[test]
fn test_parse_accept_multiple_with_q() {
let entries =
parse_accept_header("application/sparql-results+json, text/csv; q=0.5, */*; q=0.1");
assert_eq!(entries.len(), 3);
assert_eq!(entries[0].mime, "application/sparql-results+json");
assert!(entries[0].q >= entries[1].q);
assert!(entries[1].q >= entries[2].q);
}
#[test]
fn test_parse_accept_empty() {
assert!(parse_accept_header("").is_empty());
}
#[test]
fn test_negotiate_select_default() {
let fmt = negotiate_select_format("");
assert_eq!(fmt, SparqlResultFormat::Json);
}
#[test]
fn test_negotiate_select_prefers_quality() {
let fmt = negotiate_select_format("text/csv; q=0.9, application/sparql-results+xml; q=0.8");
assert_eq!(fmt, SparqlResultFormat::Csv);
}
#[test]
fn test_negotiate_select_wildcard() {
let fmt = negotiate_select_format("*/*");
assert_eq!(fmt, SparqlResultFormat::Json);
}
#[test]
fn test_negotiate_graph_default() {
let fmt = negotiate_graph_format("");
assert_eq!(fmt, SparqlResultFormat::Turtle);
}
#[test]
fn test_negotiate_graph_picks_jsonld() {
let fmt = negotiate_graph_format("application/ld+json");
assert_eq!(fmt, SparqlResultFormat::JsonLd);
}
#[test]
fn test_negotiate_select_tsv() {
let fmt = negotiate_select_format("text/tab-separated-values");
assert_eq!(fmt, SparqlResultFormat::Tsv);
}
#[test]
fn test_graph_store_ct_turtle() {
let ct = GraphStoreContentType::from_header(Some("text/turtle"));
assert_eq!(ct.format, SparqlResultFormat::Turtle);
assert!(ct.recognised);
}
#[test]
fn test_graph_store_ct_rejects_json_select() {
let ct = GraphStoreContentType::from_header(Some("application/sparql-results+json"));
assert!(!ct.recognised, "SELECT format is not valid for graph store");
}
#[test]
fn test_graph_store_ct_missing() {
let ct = GraphStoreContentType::from_header(None);
assert_eq!(ct.format, SparqlResultFormat::Turtle);
assert!(!ct.recognised);
}
#[test]
fn test_graph_store_ct_nquads() {
let ct = GraphStoreContentType::from_header(Some("application/n-quads"));
assert_eq!(ct.format, SparqlResultFormat::NQuads);
assert!(ct.recognised);
}
#[test]
fn test_extract_boundary() {
let ct = "multipart/form-data; boundary=----WebKitFormBoundary";
let b = extract_multipart_boundary(ct);
assert_eq!(b, Some("----WebKitFormBoundary".to_string()));
}
#[test]
fn test_extract_boundary_quoted() {
let ct = r#"multipart/form-data; boundary="my-boundary-123""#;
let b = extract_multipart_boundary(ct);
assert_eq!(b, Some("my-boundary-123".to_string()));
}
#[test]
fn test_parse_multipart_basic() {
let boundary = "testboundary";
let body = format!(
"--{boundary}\r\nContent-Disposition: form-data; name=\"update\"\r\nContent-Type: application/sparql-update\r\n\r\nINSERT DATA {{ <s> <p> <o> }}\r\n--{boundary}--",
);
let parts = parse_multipart_update(boundary, body.as_bytes()).unwrap();
assert_eq!(parts.len(), 1);
assert_eq!(parts[0].name, "update");
assert!(parts[0].body_str().unwrap().contains("INSERT DATA"));
}
#[test]
fn test_parse_multipart_with_graph_uri() {
let boundary = "b1";
let body = format!(
"--{boundary}\r\nContent-Disposition: form-data; name=\"update\"\r\n\r\nINSERT DATA {{ <s> <p> <o> }}\r\n--{boundary}\r\nContent-Disposition: form-data; name=\"using-graph-uri\"\r\n\r\nhttp://example.org/graph\r\n--{boundary}--",
);
let parts = parse_multipart_update(boundary, body.as_bytes()).unwrap();
assert_eq!(parts.len(), 2);
let graph = parts.iter().find(|p| p.name == "using-graph-uri").unwrap();
assert!(graph
.body_str()
.unwrap()
.contains("http://example.org/graph"));
}
#[test]
fn test_handler_parse_multipart() {
let handler = SparqlProtocolHandler::default();
let boundary = "bound42";
let ct = format!("multipart/form-data; boundary={boundary}");
let body = format!(
"--{boundary}\r\nContent-Disposition: form-data; name=\"update\"\r\n\r\nDELETE WHERE {{ ?s ?p ?o }}\r\n--{boundary}--",
);
let (update, graphs, named) = handler
.parse_multipart_update_request(&ct, body.as_bytes())
.unwrap();
assert!(update.contains("DELETE WHERE"));
assert!(graphs.is_empty());
assert!(named.is_empty());
}
#[test]
fn test_handler_build_response_content_type() {
let handler = SparqlProtocolHandler::default();
let resp = handler.build_query_response("{}".to_string(), SparqlResultFormat::Json);
let ct = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok());
assert_eq!(ct, Some("application/sparql-results+json"));
}
#[test]
fn test_handler_negotiate_from_headers() {
use axum::http::HeaderMap;
use axum::http::HeaderValue;
let handler = SparqlProtocolHandler::default();
let mut headers = HeaderMap::new();
headers.insert("accept", HeaderValue::from_static("text/csv"));
let fmt = handler.select_format_from_headers(&headers);
assert_eq!(fmt, SparqlResultFormat::Csv);
}
#[test]
fn test_not_acceptable_response() {
let handler = SparqlProtocolHandler::default();
let resp = handler.not_acceptable_response();
assert_eq!(resp.status(), StatusCode::NOT_ACCEPTABLE);
}
#[test]
fn test_unsupported_media_type_response() {
let handler = SparqlProtocolHandler::default();
let resp = handler.unsupported_media_type_response("application/octet-stream");
assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
#[test]
fn test_graph_store_response_content_type() {
let handler = SparqlProtocolHandler::default();
let ct = GraphStoreContentType {
format: SparqlResultFormat::NTriples,
recognised: true,
};
let resp = handler.build_graph_store_response("<s> <p> <o> .".to_string(), &ct);
let header_ct = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok());
assert_eq!(header_ct, Some("text/plain"));
}
#[test]
fn test_display() {
let fmt = SparqlResultFormat::Turtle;
assert_eq!(fmt.to_string(), "text/turtle");
}
#[test]
fn test_multipart_missing_boundary_returns_error() {
let handler = SparqlProtocolHandler::default();
let result = handler.parse_multipart_update_request(
"multipart/form-data", b"irrelevant",
);
assert!(result.is_err());
}
}