1use crate::{GraphRAGResult, Triple};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Registry of GraphRAG function definitions, looked up by function name.
#[derive(Debug, Clone)]
pub struct GraphRAGFunctions {
    /// Registered definitions keyed by their short name (for example `"query"`).
    functions: HashMap<String, FunctionDef>,
}
19
/// Metadata describing a single GraphRAG function: its name, URI, declared
/// parameters, declared return type and a human-readable description.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    /// Short function name; the built-in registry uses it as the lookup key.
    pub name: String,
    /// Full URI of the function (built-ins use the `http://oxirs.io/graphrag#` namespace).
    pub uri: String,
    /// Declared parameters, in declaration order.
    pub params: Vec<ParamDef>,
    /// Declared kind of result.
    pub return_type: ReturnType,
    /// Free-text description of what the function does.
    pub description: String,
}
34
/// Declaration of one parameter of a GraphRAG function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamDef {
    /// Parameter name (for example `"text"` or `"top_k"`).
    pub name: String,
    /// Declared value type of the parameter.
    pub param_type: ParamType,
    /// Marks the parameter as mandatory.
    /// NOTE(review): nothing visible in this file enforces the flag — confirm where it is checked.
    pub required: bool,
}
42
/// Value types a GraphRAG function parameter can be declared with.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ParamType {
    /// Used by the built-in `text` and `algorithm` parameters.
    String,
    /// Used by the built-in `top_k`, `k`, `hops` and `max_triples` parameters.
    Integer,
    /// Used by the built-in `threshold` parameter.
    Float,
    /// Used by the built-in `entity` and `graph` parameters.
    Uri,
    /// Not used by any built-in function registered in this file.
    Boolean,
}
51
/// Kinds of result a GraphRAG function can be declared to return.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ReturnType {
    /// Declared by the built-in `similar` and `community` functions.
    Binding,
    /// Not declared by any built-in function registered in this file.
    Triple,
    /// Declared by the built-in `query` and `expand` functions.
    Graph,
    /// Declared by the built-in `embed` function.
    Scalar,
}
59
60impl Default for GraphRAGFunctions {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl GraphRAGFunctions {
67 pub fn new() -> Self {
69 let mut functions = HashMap::new();
70
71 functions.insert(
73 "query".to_string(),
74 FunctionDef {
75 name: "query".to_string(),
76 uri: "http://oxirs.io/graphrag#query".to_string(),
77 params: vec![
78 ParamDef {
79 name: "text".to_string(),
80 param_type: ParamType::String,
81 required: true,
82 },
83 ParamDef {
84 name: "top_k".to_string(),
85 param_type: ParamType::Integer,
86 required: false,
87 },
88 ],
89 return_type: ReturnType::Graph,
90 description: "Execute GraphRAG query and return relevant subgraph".to_string(),
91 },
92 );
93
94 functions.insert(
96 "similar".to_string(),
97 FunctionDef {
98 name: "similar".to_string(),
99 uri: "http://oxirs.io/graphrag#similar".to_string(),
100 params: vec![
101 ParamDef {
102 name: "entity".to_string(),
103 param_type: ParamType::Uri,
104 required: true,
105 },
106 ParamDef {
107 name: "threshold".to_string(),
108 param_type: ParamType::Float,
109 required: false,
110 },
111 ParamDef {
112 name: "k".to_string(),
113 param_type: ParamType::Integer,
114 required: false,
115 },
116 ],
117 return_type: ReturnType::Binding,
118 description: "Find entities similar to the given entity".to_string(),
119 },
120 );
121
122 functions.insert(
124 "expand".to_string(),
125 FunctionDef {
126 name: "expand".to_string(),
127 uri: "http://oxirs.io/graphrag#expand".to_string(),
128 params: vec![
129 ParamDef {
130 name: "entity".to_string(),
131 param_type: ParamType::Uri,
132 required: true,
133 },
134 ParamDef {
135 name: "hops".to_string(),
136 param_type: ParamType::Integer,
137 required: false,
138 },
139 ParamDef {
140 name: "max_triples".to_string(),
141 param_type: ParamType::Integer,
142 required: false,
143 },
144 ],
145 return_type: ReturnType::Graph,
146 description: "Expand subgraph from entity".to_string(),
147 },
148 );
149
150 functions.insert(
152 "community".to_string(),
153 FunctionDef {
154 name: "community".to_string(),
155 uri: "http://oxirs.io/graphrag#community".to_string(),
156 params: vec![
157 ParamDef {
158 name: "graph".to_string(),
159 param_type: ParamType::Uri,
160 required: true,
161 },
162 ParamDef {
163 name: "algorithm".to_string(),
164 param_type: ParamType::String,
165 required: false,
166 },
167 ],
168 return_type: ReturnType::Binding,
169 description: "Detect communities in graph".to_string(),
170 },
171 );
172
173 functions.insert(
175 "embed".to_string(),
176 FunctionDef {
177 name: "embed".to_string(),
178 uri: "http://oxirs.io/graphrag#embed".to_string(),
179 params: vec![ParamDef {
180 name: "entity".to_string(),
181 param_type: ParamType::Uri,
182 required: true,
183 }],
184 return_type: ReturnType::Scalar,
185 description: "Get embedding vector for entity".to_string(),
186 },
187 );
188
189 Self { functions }
190 }
191
192 pub fn get(&self, name: &str) -> Option<&FunctionDef> {
194 self.functions.get(name)
195 }
196
197 pub fn all(&self) -> impl Iterator<Item = &FunctionDef> {
199 self.functions.values()
200 }
201
202 pub fn generate_service_clause(&self, function: &str, args: &[&str]) -> GraphRAGResult<String> {
204 let func_def = self.get(function).ok_or_else(|| {
205 crate::GraphRAGError::SparqlError(format!("Unknown function: {}", function))
206 })?;
207
208 let args_str = args.join(", ");
209 Ok(format!(
210 "SERVICE <{}> {{ ?result graphrag:{}({}) }}",
211 func_def.uri, function, args_str
212 ))
213 }
214
215 pub fn parse_query(&self, sparql: &str) -> Vec<FunctionCall> {
217 let mut calls = Vec::new();
218
219 let re = regex::Regex::new(r"graphrag:(\w+)\(([^)]*)\)")
221 .expect("GraphRAG function regex pattern is valid");
222
223 for cap in re.captures_iter(sparql) {
224 if let (Some(func), Some(args)) = (cap.get(1), cap.get(2)) {
225 let func_name = func.as_str().to_string();
226 let args: Vec<String> = args
227 .as_str()
228 .split(',')
229 .map(|s| s.trim().to_string())
230 .filter(|s| !s.is_empty())
231 .collect();
232
233 if self.functions.contains_key(&func_name) {
234 calls.push(FunctionCall {
235 function: func_name,
236 arguments: args,
237 });
238 }
239 }
240 }
241
242 calls
243 }
244}
245
/// One GraphRAG function invocation found in a SPARQL query string.
#[derive(Debug, Clone)]
pub struct FunctionCall {
    /// Name of the invoked function, without the `graphrag:` prefix.
    pub function: String,
    /// Argument texts as written in the query, trimmed of surrounding whitespace.
    pub arguments: Vec<String>,
}
252
/// Fluent builder for SPARQL `SELECT` queries that contain GraphRAG function calls.
///
/// Every method consumes and returns the builder; call [`QueryBuilder::build`]
/// last to render the query text.
#[derive(Debug, Clone)]
pub struct QueryBuilder {
    /// `(prefix, namespace URI)` pairs rendered as `PREFIX` declarations.
    prefixes: Vec<(String, String)>,
    /// Projected variables; when empty the query uses `SELECT *`.
    select_vars: Vec<String>,
    /// Plain triple patterns, rendered before the GraphRAG calls.
    where_patterns: Vec<String>,
    /// Pre-rendered GraphRAG call patterns.
    graphrag_calls: Vec<String>,
    /// Optional `LIMIT` value.
    limit: Option<usize>,
    /// Optional `OFFSET` value.
    offset: Option<usize>,
}

impl Default for QueryBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl QueryBuilder {
    /// Creates a builder with the `graphrag` and `rdfs` prefixes pre-declared.
    pub fn new() -> Self {
        Self {
            prefixes: vec![
                (
                    "graphrag".to_string(),
                    "http://oxirs.io/graphrag#".to_string(),
                ),
                (
                    "rdfs".to_string(),
                    "http://www.w3.org/2000/01/rdf-schema#".to_string(),
                ),
            ],
            select_vars: Vec::new(),
            where_patterns: Vec::new(),
            graphrag_calls: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Escapes `raw` for use inside a double-quoted SPARQL string literal.
    ///
    /// Backslash, double quote, line feed and carriage return are the
    /// characters that may not appear unescaped in such a literal.
    fn escape_literal(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Appends an additional `PREFIX prefix: <uri>` declaration.
    pub fn prefix(mut self, prefix: &str, uri: &str) -> Self {
        self.prefixes.push((prefix.to_string(), uri.to_string()));
        self
    }

    /// Sets the projected variables, replacing any previously selected ones.
    pub fn select(mut self, vars: &[&str]) -> Self {
        self.select_vars = vars.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Adds the triple pattern `subject predicate object` to the `WHERE` clause.
    ///
    /// The three terms are inserted verbatim.
    pub fn triple(mut self, subject: &str, predicate: &str, object: &str) -> Self {
        self.where_patterns
            .push(format!("{} {} {}", subject, predicate, object));
        self
    }

    /// Adds `BIND(graphrag:query("<text>") AS <result_var>)`.
    ///
    /// `text` is escaped so that quotes, backslashes and line breaks cannot
    /// terminate the string literal; `result_var` is inserted verbatim.
    pub fn graphrag_query(mut self, text: &str, result_var: &str) -> Self {
        self.graphrag_calls.push(format!(
            "BIND(graphrag:query(\"{}\") AS {})",
            Self::escape_literal(text),
            result_var
        ));
        self
    }

    /// Adds `<result_var> graphrag:similar("<entity>", <threshold>)`.
    ///
    /// `entity` is escaped like any other string literal; `result_var` is
    /// inserted verbatim.
    pub fn graphrag_similar(mut self, entity: &str, threshold: f32, result_var: &str) -> Self {
        self.graphrag_calls.push(format!(
            "{} graphrag:similar(\"{}\", {})",
            result_var,
            Self::escape_literal(entity),
            threshold
        ));
        self
    }

    /// Sets the `LIMIT` of the query.
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Sets the `OFFSET` of the query.
    pub fn offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Renders the query text.
    ///
    /// Layout: the `PREFIX` lines, a blank line, the `SELECT … WHERE {` header,
    /// the triple patterns followed by the GraphRAG calls (one per line, each
    /// ending in ` .`), the closing brace, then `LIMIT` and `OFFSET` if set.
    pub fn build(self) -> String {
        let mut query = String::new();

        for (prefix, uri) in &self.prefixes {
            query.push_str(&format!("PREFIX {}: <{}>\n", prefix, uri));
        }
        query.push('\n');

        if self.select_vars.is_empty() {
            query.push_str("SELECT * ");
        } else {
            query.push_str("SELECT ");
            query.push_str(&self.select_vars.join(" "));
            query.push(' ');
        }

        query.push_str("WHERE {\n");

        // Triple patterns always precede GraphRAG calls, whatever order the
        // builder methods were invoked in.
        for line in self.where_patterns.iter().chain(&self.graphrag_calls) {
            query.push_str(&format!(" {} .\n", line));
        }

        query.push_str("}\n");

        if let Some(limit) = self.limit {
            query.push_str(&format!("LIMIT {}\n", limit));
        }
        if let Some(offset) = self.offset {
            query.push_str(&format!("OFFSET {}\n", offset));
        }

        query
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Built-in names resolve; an unregistered name does not.
    #[test]
    fn test_function_registry() {
        let registry = GraphRAGFunctions::new();

        for known in ["query", "similar", "expand"].iter() {
            assert!(registry.get(known).is_some());
        }
        assert!(registry.get("unknown").is_none());
    }

    /// Both GraphRAG calls embedded in a query are discovered.
    #[test]
    fn test_query_parsing() {
        let registry = GraphRAGFunctions::new();

        let sparql = r#"
 SELECT ?entity WHERE {
 ?entity graphrag:similar("battery", 0.8) .
 BIND(graphrag:query("safety issues") AS ?result)
 }
 "#;

        let found = registry.parse_query(sparql);
        let names: Vec<&str> = found.iter().map(|call| call.function.as_str()).collect();

        assert_eq!(found.len(), 2);
        assert!(names.contains(&"similar"));
        assert!(names.contains(&"query"));
    }

    /// The rendered query carries the projection, the GraphRAG call and the limit.
    #[test]
    fn test_query_builder() {
        let rendered = QueryBuilder::new()
            .select(&["?entity", "?score"])
            .graphrag_similar("http://example.org/Battery", 0.8, "?entity")
            .triple("?entity", "rdfs:label", "?label")
            .limit(10)
            .build();

        let expected_fragments = ["SELECT ?entity ?score", "graphrag:similar", "LIMIT 10"];
        for fragment in expected_fragments.iter() {
            assert!(rendered.contains(fragment));
        }
    }
}