Skip to main content

oxirs_core/federation/
executor.rs

1//! Federation executor for SERVICE clause execution
2
3use super::client::FederationClient;
4use super::results::SparqlResultsParser;
5use crate::model::Term;
6use crate::query::{
7    NamedNodePattern, SparqlGraphPattern as GraphPattern, SparqlTermPattern as TermPattern,
8};
9use crate::OxirsError;
10use std::collections::HashMap;
11use tracing::{debug, info};
12
13/// Federation executor for executing SERVICE clauses
14pub struct FederationExecutor {
15    client: FederationClient,
16}
17
18impl FederationExecutor {
19    /// Create a new federation executor
20    pub fn new() -> Result<Self, OxirsError> {
21        let client = FederationClient::new()?;
22        Ok(Self { client })
23    }
24
25    /// Execute a SERVICE clause
26    ///
27    /// # Arguments
28    /// * `endpoint` - The SERVICE endpoint (IRI or variable)
29    /// * `pattern` - The graph pattern to execute at the endpoint
30    /// * `silent` - If true, suppress errors and return empty results
31    /// * `bindings` - Current variable bindings from local query
32    ///
33    /// # Returns
34    /// Vector of solution bindings from the remote endpoint
35    pub async fn execute_service(
36        &self,
37        endpoint: &NamedNodePattern,
38        pattern: &GraphPattern,
39        silent: bool,
40        bindings: &[HashMap<String, Term>],
41    ) -> Result<Vec<HashMap<String, Term>>, OxirsError> {
42        // Extract endpoint URL
43        let endpoint_url = match endpoint {
44            NamedNodePattern::NamedNode(node) => node.as_str().to_string(),
45            NamedNodePattern::Variable(_) => {
46                return Err(OxirsError::Federation(
47                    "Variable endpoints are not yet supported".to_string(),
48                ))
49            }
50        };
51
52        info!("Executing SERVICE clause at endpoint: {}", endpoint_url);
53        debug!("Pattern: {:?}", pattern);
54        debug!("Current bindings: {} solutions", bindings.len());
55
56        // Convert graph pattern to SPARQL query string
57        let sparql_query = self.pattern_to_sparql(pattern)?;
58        debug!("Generated SPARQL query: {}", sparql_query);
59
60        // Execute query at remote endpoint
61        let json_response = self
62            .client
63            .execute_query(&endpoint_url, &sparql_query, silent)
64            .await?;
65
66        // Parse results
67        let remote_bindings = SparqlResultsParser::parse(&json_response)?;
68
69        info!(
70            "Received {} solutions from remote endpoint",
71            remote_bindings.len()
72        );
73
74        Ok(remote_bindings)
75    }
76
77    /// Convert a graph pattern to a SPARQL SELECT query
78    fn pattern_to_sparql(&self, pattern: &GraphPattern) -> Result<String, OxirsError> {
79        // Extract variables from the pattern
80        let variables = self.extract_variables(pattern);
81
82        // Build SELECT clause
83        let select_clause = if variables.is_empty() {
84            "SELECT *".to_string()
85        } else {
86            format!("SELECT {}", variables.join(" "))
87        };
88
89        // Convert pattern to WHERE clause
90        let where_clause = Self::pattern_to_where_clause(pattern)?;
91
92        Ok(format!("{} WHERE {{ {} }}", select_clause, where_clause))
93    }
94
95    /// Extract variables from a graph pattern
96    fn extract_variables(&self, pattern: &GraphPattern) -> Vec<String> {
97        let mut vars = Vec::new();
98        Self::collect_variables(pattern, &mut vars);
99        vars.sort();
100        vars.dedup();
101        vars.into_iter().map(|v| format!("?{}", v)).collect()
102    }
103
104    /// Recursively collect variables from pattern
105    fn collect_variables(pattern: &GraphPattern, vars: &mut Vec<String>) {
106        match pattern {
107            GraphPattern::Bgp { patterns } => {
108                for tp in patterns {
109                    // Extract variables from triple pattern
110                    if let TermPattern::Variable(v) = &tp.subject {
111                        vars.push(v.name().to_string());
112                    }
113                    if let TermPattern::Variable(v) = &tp.predicate {
114                        vars.push(v.name().to_string());
115                    }
116                    if let TermPattern::Variable(v) = &tp.object {
117                        vars.push(v.name().to_string());
118                    }
119                }
120            }
121            GraphPattern::Join { left, right } | GraphPattern::Union { left, right } => {
122                Self::collect_variables(left, vars);
123                Self::collect_variables(right, vars);
124            }
125            GraphPattern::Filter { inner, .. }
126            | GraphPattern::Distinct { inner }
127            | GraphPattern::Reduced { inner }
128            | GraphPattern::Extend { inner, .. }
129            | GraphPattern::Group { inner, .. }
130            | GraphPattern::Project { inner, .. } => {
131                Self::collect_variables(inner, vars);
132            }
133            GraphPattern::LeftJoin { left, right, .. } => {
134                Self::collect_variables(left, vars);
135                Self::collect_variables(right, vars);
136            }
137            GraphPattern::Service { inner, .. } => {
138                Self::collect_variables(inner, vars);
139            }
140            _ => {}
141        }
142    }
143
144    /// Convert graph pattern to WHERE clause string
145    fn pattern_to_where_clause(pattern: &GraphPattern) -> Result<String, OxirsError> {
146        match pattern {
147            GraphPattern::Bgp { patterns } => {
148                let mut clauses = Vec::new();
149                for tp in patterns {
150                    let s = Self::term_pattern_to_string(&tp.subject);
151                    let p = Self::term_pattern_to_string(&tp.predicate);
152                    let o = Self::term_pattern_to_string(&tp.object);
153                    clauses.push(format!("{} {} {}", s, p, o));
154                }
155                Ok(clauses.join(" . "))
156            }
157            GraphPattern::Join { left, right } => {
158                let left_str = Self::pattern_to_where_clause(left)?;
159                let right_str = Self::pattern_to_where_clause(right)?;
160                Ok(format!("{} . {}", left_str, right_str))
161            }
162            GraphPattern::Union { left, right } => {
163                let left_str = Self::pattern_to_where_clause(left)?;
164                let right_str = Self::pattern_to_where_clause(right)?;
165                Ok(format!("{{ {} }} UNION {{ {} }}", left_str, right_str))
166            }
167            GraphPattern::Filter { expr: _, inner } => {
168                let inner_str = Self::pattern_to_where_clause(inner)?;
169                // Simplified filter expression (full implementation would handle all expression types)
170                Ok(format!("{} FILTER(?var)", inner_str))
171            }
172            _ => {
173                // For other patterns, use a simplified representation
174                Ok("?s ?p ?o".to_string())
175            }
176        }
177    }
178
179    /// Convert a term pattern to SPARQL string
180    fn term_pattern_to_string(term: &TermPattern) -> String {
181        match term {
182            TermPattern::Variable(v) => format!("?{}", v.name()),
183            TermPattern::NamedNode(n) => format!("<{}>", n.as_str()),
184            TermPattern::BlankNode(b) => format!("_:{}", b.as_str()),
185            TermPattern::Literal(l) => {
186                if let Some(lang) = l.language() {
187                    format!("\"{}\"@{}", l.value(), lang)
188                } else if l.datatype().as_str() != "http://www.w3.org/2001/XMLSchema#string" {
189                    format!("\"{}\"^^<{}>", l.value(), l.datatype().as_str())
190                } else {
191                    format!("\"{}\"", l.value())
192                }
193            }
194            #[cfg(feature = "sparql-12")]
195            TermPattern::Triple(triple) => {
196                format!(
197                    "<< {} {} {} >>",
198                    Self::term_pattern_to_string(&triple.subject),
199                    Self::term_pattern_to_string(&triple.predicate),
200                    Self::term_pattern_to_string(&triple.object)
201                )
202            }
203        }
204    }
205
206    /// Merge local and remote bindings
207    pub fn merge_bindings(
208        &self,
209        local: Vec<HashMap<String, Term>>,
210        remote: Vec<HashMap<String, Term>>,
211    ) -> Vec<HashMap<String, Term>> {
212        if local.is_empty() {
213            return remote;
214        }
215        if remote.is_empty() {
216            return local;
217        }
218
219        // Find common variables
220        let local_vars: std::collections::HashSet<_> = local[0].keys().cloned().collect();
221        let remote_vars: std::collections::HashSet<_> = remote[0].keys().cloned().collect();
222        let common_vars: Vec<_> = local_vars.intersection(&remote_vars).cloned().collect();
223
224        debug!(
225            "Merging bindings with {} common variables",
226            common_vars.len()
227        );
228
229        if common_vars.is_empty() {
230            // Cartesian product if no common variables
231            let mut result = Vec::new();
232            for l in &local {
233                for r in &remote {
234                    let mut merged = l.clone();
235                    merged.extend(r.clone());
236                    result.push(merged);
237                }
238            }
239            result
240        } else {
241            // Hash join on common variables
242            let mut result = Vec::new();
243            for l in &local {
244                for r in &remote {
245                    // Check if bindings are compatible
246                    let mut compatible = true;
247                    for var in &common_vars {
248                        if let (Some(l_val), Some(r_val)) = (l.get(var), r.get(var)) {
249                            if l_val != r_val {
250                                compatible = false;
251                                break;
252                            }
253                        }
254                    }
255
256                    if compatible {
257                        let mut merged = l.clone();
258                        merged.extend(r.clone());
259                        result.push(merged);
260                    }
261                }
262            }
263            result
264        }
265    }
266}
267
268impl Default for FederationExecutor {
269    fn default() -> Self {
270        Self::new().expect("Failed to create default federation executor")
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::NamedNode;

    /// Builds a named-node term from an IRI that is known to be valid.
    fn iri(value: &str) -> Term {
        Term::NamedNode(NamedNode::new(value).expect("valid IRI"))
    }

    /// Builds a single solution mapping from `(variable, term)` pairs.
    fn solution(entries: Vec<(&str, Term)>) -> HashMap<String, Term> {
        entries
            .into_iter()
            .map(|(name, term)| (name.to_string(), term))
            .collect()
    }

    /// Constructing an executor succeeds.
    #[tokio::test]
    async fn test_executor_creation() {
        let executor = FederationExecutor::new();
        assert!(executor.is_ok());
    }

    /// Solutions without shared variables are combined pairwise.
    #[test]
    fn test_merge_bindings_no_common_vars() {
        let executor = FederationExecutor::new().expect("construction should succeed");

        let local = vec![solution(vec![("x", iri("http://example.org/a"))])];
        let remote = vec![solution(vec![("y", iri("http://example.org/b"))])];

        let result = executor.merge_bindings(local, remote);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
    }

    /// Solutions agreeing on the shared variable are joined into one row.
    #[test]
    fn test_merge_bindings_with_common_vars() {
        let executor = FederationExecutor::new().expect("construction should succeed");

        let node = iri("http://example.org/same");

        let local = vec![solution(vec![
            ("x", node.clone()),
            ("y", iri("http://example.org/a")),
        ])];
        let remote = vec![solution(vec![
            ("x", node.clone()),
            ("z", iri("http://example.org/b")),
        ])];

        let result = executor.merge_bindings(local, remote);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3); // x, y, z
    }
}