use std::collections::HashMap;
use std::sync::Arc;
use grafeo_common::types::Value;
use grafeo_core::graph::rdf::shacl::{ShaclError, SparqlExecutor, ValidationReport};
use grafeo_core::graph::rdf::{RdfStore, Term};
use crate::session::Session;
pub struct SessionSparqlExecutor<'a> {
session: &'a Session,
graph_name: Option<String>,
}
impl<'a> SessionSparqlExecutor<'a> {
pub fn new(session: &'a Session) -> Self {
Self {
session,
graph_name: None,
}
}
pub fn with_graph(session: &'a Session, graph_name: String) -> Self {
Self {
session,
graph_name: Some(graph_name),
}
}
}
impl SparqlExecutor for SessionSparqlExecutor<'_> {
fn execute(
&self,
query: &str,
this_binding: &Term,
) -> Result<Vec<HashMap<String, Term>>, ShaclError> {
let this_str = match this_binding {
Term::Iri(iri) => format!("<{}>", iri.as_str()),
Term::BlankNode(bnode) => format!("_:{}", bnode.id()),
Term::Literal(lit) => {
let escaped = escape_ntriples(lit.value());
if let Some(lang) = lit.language() {
format!("\"{escaped}\"@{}", lang)
} else if lit.datatype() != "http://www.w3.org/2001/XMLSchema#string" {
format!("\"{escaped}\"^^<{}>", lit.datatype())
} else {
format!("\"{escaped}\"")
}
}
_ => return Ok(Vec::new()),
};
let mut substituted = query.replace("$this", &this_str);
if let Some(ref graph) = self.graph_name {
if !graph.contains('>') && !graph.contains('<') && !graph.contains('"') {
let upper = substituted.to_uppercase();
if let Some(pos) = upper.find("WHERE") {
substituted.insert_str(pos, &format!("FROM <{graph}> "));
}
}
}
let result = self
.session
.execute_sparql(&substituted)
.map_err(|e| ShaclError::SparqlError(e.to_string()))?;
let columns = &result.columns;
let mut rows = Vec::new();
for row in result.rows() {
let mut map = HashMap::new();
for (i, col) in columns.iter().enumerate() {
if let Some(value) = row.get(i)
&& let Some(term) = value_to_term(value)
{
map.insert(col.clone(), term);
}
}
rows.push(map);
}
Ok(rows)
}
}
fn escape_ntriples(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'\\' => out.push_str("\\\\"),
'"' => out.push_str("\\\""),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
'\t' => out.push_str("\\t"),
_ => out.push(ch),
}
}
out
}
fn value_to_term(value: &Value) -> Option<Term> {
match value {
Value::Null => None,
Value::String(s) => {
if s.starts_with("http://") || s.starts_with("https://") || s.starts_with("urn:") {
Some(Term::iri(s.as_str()))
} else {
Some(Term::literal(s.as_str()))
}
}
Value::Int64(n) => Some(Term::typed_literal(
n.to_string(),
"http://www.w3.org/2001/XMLSchema#integer",
)),
Value::Float64(f) => Some(Term::typed_literal(
f.to_string(),
"http://www.w3.org/2001/XMLSchema#double",
)),
Value::Bool(b) => Some(Term::typed_literal(
if *b { "true" } else { "false" },
"http://www.w3.org/2001/XMLSchema#boolean",
)),
_ => Some(Term::literal(value.to_string())),
}
}
pub fn validate_shacl(
session: &Session,
rdf_store: &Arc<RdfStore>,
shapes_graph_name: &str,
) -> grafeo_common::utils::error::Result<ValidationReport> {
let shapes_store = rdf_store.graph(shapes_graph_name).ok_or_else(|| {
grafeo_common::utils::error::Error::Internal(format!(
"Named graph '{shapes_graph_name}' not found"
))
})?;
let executor = SessionSparqlExecutor::new(session);
grafeo_core::graph::rdf::shacl::validate(rdf_store, &shapes_store, Some(&executor))
.map_err(|e| grafeo_common::utils::error::Error::Internal(e.to_string()))
}