oxirs_graphrag/sparql/
graph_functions.rs

//! SPARQL extension functions for GraphRAG queries: a registry of `graphrag:` functions, a lightweight call extractor, and a query builder.
2
use crate::{GraphRAGResult, Triple};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Write as _;
use std::sync::OnceLock;
6
7/// GraphRAG SPARQL function definitions
8///
9/// These functions extend SPARQL with GraphRAG capabilities:
10/// - graphrag:query(text) - Execute GraphRAG query
11/// - graphrag:similar(entity, threshold) - Find similar entities
12/// - graphrag:expand(entity, hops) - Expand from entity
13/// - graphrag:community(graph) - Detect communities
14#[derive(Debug, Clone)]
15pub struct GraphRAGFunctions {
16    /// Function registry
17    functions: HashMap<String, FunctionDef>,
18}
19
20/// Function definition
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionDef {
23    /// Function name
24    pub name: String,
25    /// Function URI
26    pub uri: String,
27    /// Parameter types
28    pub params: Vec<ParamDef>,
29    /// Return type
30    pub return_type: ReturnType,
31    /// Description
32    pub description: String,
33}
34
35/// Parameter definition
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ParamDef {
38    pub name: String,
39    pub param_type: ParamType,
40    pub required: bool,
41}
42
43#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
44pub enum ParamType {
45    String,
46    Integer,
47    Float,
48    Uri,
49    Boolean,
50}
51
52#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
53pub enum ReturnType {
54    Binding,
55    Triple,
56    Graph,
57    Scalar,
58}
59
60impl Default for GraphRAGFunctions {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl GraphRAGFunctions {
67    /// Create new function registry with default GraphRAG functions
68    pub fn new() -> Self {
69        let mut functions = HashMap::new();
70
71        // graphrag:query - Main GraphRAG query function
72        functions.insert(
73            "query".to_string(),
74            FunctionDef {
75                name: "query".to_string(),
76                uri: "http://oxirs.io/graphrag#query".to_string(),
77                params: vec![
78                    ParamDef {
79                        name: "text".to_string(),
80                        param_type: ParamType::String,
81                        required: true,
82                    },
83                    ParamDef {
84                        name: "top_k".to_string(),
85                        param_type: ParamType::Integer,
86                        required: false,
87                    },
88                ],
89                return_type: ReturnType::Graph,
90                description: "Execute GraphRAG query and return relevant subgraph".to_string(),
91            },
92        );
93
94        // graphrag:similar - Vector similarity search
95        functions.insert(
96            "similar".to_string(),
97            FunctionDef {
98                name: "similar".to_string(),
99                uri: "http://oxirs.io/graphrag#similar".to_string(),
100                params: vec![
101                    ParamDef {
102                        name: "entity".to_string(),
103                        param_type: ParamType::Uri,
104                        required: true,
105                    },
106                    ParamDef {
107                        name: "threshold".to_string(),
108                        param_type: ParamType::Float,
109                        required: false,
110                    },
111                    ParamDef {
112                        name: "k".to_string(),
113                        param_type: ParamType::Integer,
114                        required: false,
115                    },
116                ],
117                return_type: ReturnType::Binding,
118                description: "Find entities similar to the given entity".to_string(),
119            },
120        );
121
122        // graphrag:expand - Graph expansion
123        functions.insert(
124            "expand".to_string(),
125            FunctionDef {
126                name: "expand".to_string(),
127                uri: "http://oxirs.io/graphrag#expand".to_string(),
128                params: vec![
129                    ParamDef {
130                        name: "entity".to_string(),
131                        param_type: ParamType::Uri,
132                        required: true,
133                    },
134                    ParamDef {
135                        name: "hops".to_string(),
136                        param_type: ParamType::Integer,
137                        required: false,
138                    },
139                    ParamDef {
140                        name: "max_triples".to_string(),
141                        param_type: ParamType::Integer,
142                        required: false,
143                    },
144                ],
145                return_type: ReturnType::Graph,
146                description: "Expand subgraph from entity".to_string(),
147            },
148        );
149
150        // graphrag:community - Community detection
151        functions.insert(
152            "community".to_string(),
153            FunctionDef {
154                name: "community".to_string(),
155                uri: "http://oxirs.io/graphrag#community".to_string(),
156                params: vec![
157                    ParamDef {
158                        name: "graph".to_string(),
159                        param_type: ParamType::Uri,
160                        required: true,
161                    },
162                    ParamDef {
163                        name: "algorithm".to_string(),
164                        param_type: ParamType::String,
165                        required: false,
166                    },
167                ],
168                return_type: ReturnType::Binding,
169                description: "Detect communities in graph".to_string(),
170            },
171        );
172
173        // graphrag:embed - Get entity embedding
174        functions.insert(
175            "embed".to_string(),
176            FunctionDef {
177                name: "embed".to_string(),
178                uri: "http://oxirs.io/graphrag#embed".to_string(),
179                params: vec![ParamDef {
180                    name: "entity".to_string(),
181                    param_type: ParamType::Uri,
182                    required: true,
183                }],
184                return_type: ReturnType::Scalar,
185                description: "Get embedding vector for entity".to_string(),
186            },
187        );
188
189        Self { functions }
190    }
191
192    /// Get function definition by name
193    pub fn get(&self, name: &str) -> Option<&FunctionDef> {
194        self.functions.get(name)
195    }
196
197    /// Get all function definitions
198    pub fn all(&self) -> impl Iterator<Item = &FunctionDef> {
199        self.functions.values()
200    }
201
202    /// Generate SPARQL SERVICE clause for GraphRAG
203    pub fn generate_service_clause(&self, function: &str, args: &[&str]) -> GraphRAGResult<String> {
204        let func_def = self.get(function).ok_or_else(|| {
205            crate::GraphRAGError::SparqlError(format!("Unknown function: {}", function))
206        })?;
207
208        let args_str = args.join(", ");
209        Ok(format!(
210            "SERVICE <{}> {{ ?result graphrag:{}({}) }}",
211            func_def.uri, function, args_str
212        ))
213    }
214
215    /// Parse SPARQL query for GraphRAG function calls
216    pub fn parse_query(&self, sparql: &str) -> Vec<FunctionCall> {
217        let mut calls = Vec::new();
218
219        // Simple regex-based parsing (full implementation would use SPARQL parser)
220        let re = regex::Regex::new(r"graphrag:(\w+)\(([^)]*)\)")
221            .expect("GraphRAG function regex pattern is valid");
222
223        for cap in re.captures_iter(sparql) {
224            if let (Some(func), Some(args)) = (cap.get(1), cap.get(2)) {
225                let func_name = func.as_str().to_string();
226                let args: Vec<String> = args
227                    .as_str()
228                    .split(',')
229                    .map(|s| s.trim().to_string())
230                    .filter(|s| !s.is_empty())
231                    .collect();
232
233                if self.functions.contains_key(&func_name) {
234                    calls.push(FunctionCall {
235                        function: func_name,
236                        arguments: args,
237                    });
238                }
239            }
240        }
241
242        calls
243    }
244}
245
/// A GraphRAG function invocation extracted from SPARQL text by
/// [`GraphRAGFunctions::parse_query`].
#[derive(Clone, Debug)]
pub struct FunctionCall {
    /// Local name of the invoked function (e.g. `similar`), without the `graphrag:` prefix.
    pub function: String,
    /// Trimmed, non-empty argument texts in call order; quoted literals keep their quotes.
    pub arguments: Vec<String>,
}
252
/// GraphRAG SPARQL query builder.
///
/// Assembles a `SELECT` query from prefixes, projection variables, plain
/// triple patterns and GraphRAG extension-function calls. The `graphrag`
/// (`http://oxirs.io/graphrag#`) and `rdfs` prefixes are always declared.
pub struct QueryBuilder {
    /// `(prefix, namespace IRI)` pairs, emitted as `PREFIX` lines in insertion order.
    prefixes: Vec<(String, String)>,
    /// Projection variables; empty means `SELECT *`.
    select_vars: Vec<String>,
    /// Triple patterns, each stored as `subject predicate object` text.
    where_patterns: Vec<String>,
    /// Pre-rendered GraphRAG call patterns.
    graphrag_calls: Vec<String>,
    /// Optional `LIMIT` value.
    limit: Option<usize>,
    /// Optional `OFFSET` value.
    offset: Option<usize>,
}

impl Default for QueryBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl QueryBuilder {
    /// Create an empty builder with the `graphrag` and `rdfs` prefixes declared.
    pub fn new() -> Self {
        Self {
            prefixes: vec![
                (
                    "graphrag".to_string(),
                    "http://oxirs.io/graphrag#".to_string(),
                ),
                (
                    "rdfs".to_string(),
                    "http://www.w3.org/2000/01/rdf-schema#".to_string(),
                ),
            ],
            select_vars: Vec::new(),
            where_patterns: Vec::new(),
            graphrag_calls: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Add a `PREFIX prefix: <uri>` declaration after the existing ones.
    ///
    /// Duplicate prefixes are not detected.
    pub fn prefix(mut self, prefix: &str, uri: &str) -> Self {
        self.prefixes.push((prefix.to_string(), uri.to_string()));
        self
    }

    /// Set the projection variables (e.g. `["?entity", "?score"]`), replacing
    /// any previous selection. An empty slice renders as `SELECT *`.
    pub fn select(mut self, vars: &[&str]) -> Self {
        self.select_vars = vars.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Append a `subject predicate object` triple pattern.
    ///
    /// The terms are inserted verbatim, so they must already be valid SPARQL
    /// (variables, prefixed names, `<IRI>`s or literals).
    pub fn triple(mut self, subject: &str, predicate: &str, object: &str) -> Self {
        self.where_patterns
            .push(format!("{} {} {}", subject, predicate, object));
        self
    }

    /// Append `BIND(graphrag:query("text") AS result_var)`.
    ///
    /// `text` is escaped as a SPARQL string literal. Quotes, backslashes and
    /// line breaks in (possibly user-supplied) text therefore cannot terminate
    /// the literal or inject query syntax. `result_var` is inserted verbatim.
    pub fn graphrag_query(mut self, text: &str, result_var: &str) -> Self {
        self.graphrag_calls.push(format!(
            "BIND(graphrag:query(\"{}\") AS {})",
            Self::escape_literal(text),
            result_var
        ));
        self
    }

    /// Append `result_var graphrag:similar("entity", threshold)`.
    ///
    /// `entity` is escaped as a SPARQL string literal; `result_var` is
    /// inserted verbatim.
    pub fn graphrag_similar(mut self, entity: &str, threshold: f32, result_var: &str) -> Self {
        self.graphrag_calls.push(format!(
            "{} graphrag:similar(\"{}\", {})",
            result_var,
            Self::escape_literal(entity),
            threshold
        ));
        self
    }

    /// Set the `LIMIT` clause.
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Set the `OFFSET` clause.
    pub fn offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Render the query text.
    ///
    /// The output is laid out in this order:
    /// 1. the `PREFIX` lines, followed by a blank line;
    /// 2. `SELECT … WHERE {`;
    /// 3. every triple pattern, then every GraphRAG call, each terminated by ` .`;
    /// 4. `}`;
    /// 5. `LIMIT` and `OFFSET` when set.
    ///
    /// Triple patterns always precede GraphRAG calls, regardless of the order
    /// in which they were added.
    pub fn build(self) -> String {
        let mut query = String::new();
        // `write!` into a `String` cannot fail, so its `fmt::Result` is ignored.

        for (prefix, uri) in &self.prefixes {
            let _ = writeln!(query, "PREFIX {}: <{}>", prefix, uri);
        }
        query.push('\n');

        if self.select_vars.is_empty() {
            query.push_str("SELECT * ");
        } else {
            let _ = write!(query, "SELECT {} ", self.select_vars.join(" "));
        }

        query.push_str("WHERE {\n");
        for pattern in self.where_patterns.iter().chain(&self.graphrag_calls) {
            let _ = writeln!(query, "  {} .", pattern);
        }
        query.push_str("}\n");

        if let Some(limit) = self.limit {
            let _ = writeln!(query, "LIMIT {}", limit);
        }
        if let Some(offset) = self.offset {
            let _ = writeln!(query, "OFFSET {}", offset);
        }

        query
    }

    /// Escape `raw` for use inside a double-quoted SPARQL string literal.
    ///
    /// `\`, `"`, LF and CR are not allowed raw inside such a literal and are
    /// replaced by their SPARQL `ECHAR` escapes; all other characters pass
    /// through.
    fn escape_literal(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                other => escaped.push(other),
            }
        }
        escaped
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_registry() {
        let funcs = GraphRAGFunctions::new();

        for name in ["query", "similar", "expand", "community", "embed"] {
            let expected_uri = format!("http://oxirs.io/graphrag#{}", name);
            let def = funcs.get(name);
            assert!(def.is_some(), "missing built-in function `{}`", name);
            assert_eq!(def.map(|d| d.uri.as_str()), Some(expected_uri.as_str()));
        }
        assert_eq!(funcs.all().count(), 5);
        assert!(funcs.get("unknown").is_none());
    }

    #[test]
    fn test_generate_service_clause() {
        let funcs = GraphRAGFunctions::new();

        let clause = funcs.generate_service_clause("query", &["\"battery\"", "5"]);
        assert_eq!(
            clause.ok().as_deref(),
            Some("SERVICE <http://oxirs.io/graphrag#query> { ?result graphrag:query(\"battery\", 5) }")
        );
        assert!(funcs.generate_service_clause("unknown", &[]).is_err());
    }

    #[test]
    fn test_query_parsing() {
        let funcs = GraphRAGFunctions::new();

        let sparql = r#"
            SELECT ?entity WHERE {
                ?entity graphrag:similar("battery", 0.8) .
                BIND(graphrag:query("safety issues") AS ?result)
            }
        "#;

        let calls = funcs.parse_query(sparql);

        assert_eq!(calls.len(), 2);
        assert!(calls.iter().any(|c| c.function == "query"));

        let similar = calls.iter().find(|c| c.function == "similar");
        assert_eq!(
            similar.map(|c| c.arguments.clone()),
            Some(vec!["\"battery\"".to_string(), "0.8".to_string()])
        );
    }

    #[test]
    fn test_query_parsing_skips_unregistered_functions() {
        let funcs = GraphRAGFunctions::new();

        let calls = funcs.parse_query("SELECT * WHERE { ?x graphrag:nope(1) }");
        assert!(calls.is_empty());
    }

    #[test]
    fn test_query_builder() {
        let query = QueryBuilder::new()
            .select(&["?entity", "?score"])
            .graphrag_similar("http://example.org/Battery", 0.8, "?entity")
            .triple("?entity", "rdfs:label", "?label")
            .limit(10)
            .build();

        assert!(query.starts_with("PREFIX graphrag: <http://oxirs.io/graphrag#>\n"));
        assert!(query.contains("SELECT ?entity ?score"));
        assert!(query.contains("  ?entity rdfs:label ?label .\n"));
        assert!(query.contains("graphrag:similar"));
        assert!(query.contains("LIMIT 10"));
        assert!(!query.contains("OFFSET"));
    }

    #[test]
    fn test_query_builder_defaults_to_select_all() {
        let query = QueryBuilder::default().build();
        assert!(query.contains("SELECT * WHERE {\n}\n"));
    }
}