use std::collections::HashMap;
use std::sync::Arc;
use grafeo_common::types::Value;
use grafeo_core::graph::rdf::shacl::{ShaclError, SparqlExecutor, ValidationReport};
use grafeo_core::graph::rdf::{RdfStore, Term};
use crate::session::Session;
pub struct SessionSparqlExecutor<'a> {
session: &'a Session,
graph_name: Option<String>,
}
impl<'a> SessionSparqlExecutor<'a> {
pub fn new(session: &'a Session) -> Self {
Self {
session,
graph_name: None,
}
}
pub fn with_graph(session: &'a Session, graph_name: String) -> Self {
Self {
session,
graph_name: Some(graph_name),
}
}
}
impl SparqlExecutor for SessionSparqlExecutor<'_> {
fn execute(
&self,
query: &str,
this_binding: &Term,
) -> Result<Vec<HashMap<String, Term>>, ShaclError> {
let this_str = match this_binding {
Term::Iri(iri) => {
let iri_str = iri.as_str();
if !is_safe_iri(iri_str) {
return Err(ShaclError::SparqlError(format!(
"IRI contains characters unsafe for SPARQL embedding: {iri_str}"
)));
}
format!("<{iri_str}>")
}
Term::BlankNode(bnode) => format!("_:{}", bnode.id()),
Term::Literal(lit) => {
let escaped = escape_ntriples(lit.value());
if let Some(lang) = lit.language() {
format!("\"{escaped}\"@{}", lang)
} else if lit.datatype() != "http://www.w3.org/2001/XMLSchema#string" {
let dt = lit.datatype();
if !is_safe_iri(dt) {
return Err(ShaclError::SparqlError(format!(
"Datatype IRI contains unsafe characters: {dt}"
)));
}
format!("\"{escaped}\"^^<{dt}>")
} else {
format!("\"{escaped}\"")
}
}
_ => return Ok(Vec::new()),
};
let mut substituted = query.replace("$this", &this_str);
if let Some(ref graph) = self.graph_name {
if !is_safe_iri(graph) {
return Err(ShaclError::SparqlError(format!(
"Graph name contains characters unsafe for SPARQL embedding: {graph}"
)));
}
let upper = substituted.to_uppercase();
if let Some(pos) = upper.find("WHERE") {
substituted.insert_str(pos, &format!("FROM <{graph}> "));
}
}
let result = self
.session
.execute_sparql(&substituted)
.map_err(|e| ShaclError::SparqlError(e.to_string()))?;
let columns = &result.columns;
let mut rows = Vec::new();
for row in result.rows() {
let mut map = HashMap::new();
for (i, col) in columns.iter().enumerate() {
if let Some(value) = row.get(i)
&& let Some(term) = value_to_term(value)
{
map.insert(col.clone(), term);
}
}
rows.push(map);
}
Ok(rows)
}
}
fn is_safe_iri(iri: &str) -> bool {
!iri.chars().any(|c| {
matches!(c, '<' | '>' | '"' | '{' | '}' | '|' | '\\' | '^' | '`') || c.is_whitespace()
})
}
fn escape_ntriples(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'\\' => out.push_str("\\\\"),
'"' => out.push_str("\\\""),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
'\t' => out.push_str("\\t"),
_ => out.push(ch),
}
}
out
}
fn value_to_term(value: &Value) -> Option<Term> {
match value {
Value::Null => None,
Value::String(s) => {
if s.starts_with("http://") || s.starts_with("https://") || s.starts_with("urn:") {
Some(Term::iri(s.as_str()))
} else {
Some(Term::literal(s.as_str()))
}
}
Value::Int64(n) => Some(Term::typed_literal(
n.to_string(),
"http://www.w3.org/2001/XMLSchema#integer",
)),
Value::Float64(f) => Some(Term::typed_literal(
f.to_string(),
"http://www.w3.org/2001/XMLSchema#double",
)),
Value::Bool(b) => Some(Term::typed_literal(
if *b { "true" } else { "false" },
"http://www.w3.org/2001/XMLSchema#boolean",
)),
_ => Some(Term::literal(value.to_string())),
}
}
pub fn validate_shacl(
session: &Session,
rdf_store: &Arc<RdfStore>,
shapes_graph_name: &str,
) -> grafeo_common::utils::error::Result<ValidationReport> {
let shapes_store = rdf_store.graph(shapes_graph_name).ok_or_else(|| {
grafeo_common::utils::error::Error::Internal(format!(
"Named graph '{shapes_graph_name}' not found"
))
})?;
let executor = SessionSparqlExecutor::new(session);
grafeo_core::graph::rdf::shacl::validate(rdf_store, &shapes_store, Some(&executor))
.map_err(|e| grafeo_common::utils::error::Error::Internal(e.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_escape_ntriples_individual_chars() {
assert_eq!(escape_ntriples(r"\"), r"\\");
assert_eq!(escape_ntriples("\""), "\\\"");
assert_eq!(escape_ntriples("\n"), "\\n");
assert_eq!(escape_ntriples("\r"), "\\r");
assert_eq!(escape_ntriples("\t"), "\\t");
assert_eq!(escape_ntriples("\\\"\n\r\t"), "\\\\\\\"\\n\\r\\t");
}
#[test]
fn test_escape_ntriples_passthrough_and_mixed() {
assert_eq!(escape_ntriples(""), "");
assert_eq!(escape_ntriples("hello"), "hello");
assert_eq!(
escape_ntriples("Gus lives in Amsterdam"),
"Gus lives in Amsterdam"
);
assert_eq!(
escape_ntriples("Alix said \"hello\"\npath: C:\\data"),
"Alix said \\\"hello\\\"\\npath: C:\\\\data"
);
}
#[test]
fn test_value_to_term_null_and_primitives() {
assert_eq!(value_to_term(&Value::Null), None);
assert_eq!(
value_to_term(&Value::Int64(42)),
Some(Term::typed_literal(
"42",
"http://www.w3.org/2001/XMLSchema#integer"
))
);
assert_eq!(
value_to_term(&Value::Float64(2.72)),
Some(Term::typed_literal(
"2.72",
"http://www.w3.org/2001/XMLSchema#double"
))
);
assert_eq!(
value_to_term(&Value::Bool(true)),
Some(Term::typed_literal(
"true",
"http://www.w3.org/2001/XMLSchema#boolean"
))
);
}
#[test]
fn test_value_to_term_strings_and_uris() {
assert_eq!(
value_to_term(&Value::String("hello".into())),
Some(Term::literal("hello"))
);
assert_eq!(
value_to_term(&Value::String("Amsterdam".into())),
Some(Term::literal("Amsterdam"))
);
assert_eq!(
value_to_term(&Value::String("http://example.org/Alix".into())),
Some(Term::iri("http://example.org/Alix"))
);
assert_eq!(
value_to_term(&Value::String("https://schema.org/Person".into())),
Some(Term::iri("https://schema.org/Person"))
);
assert_eq!(
value_to_term(&Value::String("urn:isbn:0451450523".into())),
Some(Term::iri("urn:isbn:0451450523"))
);
}
#[test]
fn test_value_to_term_edge_cases() {
assert_eq!(
value_to_term(&Value::Int64(-7)),
Some(Term::typed_literal(
"-7",
"http://www.w3.org/2001/XMLSchema#integer"
))
);
assert_eq!(
value_to_term(&Value::Int64(0)),
Some(Term::typed_literal(
"0",
"http://www.w3.org/2001/XMLSchema#integer"
))
);
assert_eq!(
value_to_term(&Value::Bool(false)),
Some(Term::typed_literal(
"false",
"http://www.w3.org/2001/XMLSchema#boolean"
))
);
assert_eq!(
value_to_term(&Value::Float64(1.0)),
Some(Term::typed_literal(
"1",
"http://www.w3.org/2001/XMLSchema#double"
))
);
let result = value_to_term(&Value::Bytes(vec![0xCA, 0xFE].into()));
assert!(result.is_some());
assert!(result.unwrap().is_literal());
}
}