// ggen_cli_lib/cmds/graph/query.rs

1use clap::Args;
2use ggen_utils::error::Result;
3use oxigraph::sparql::QueryResults as OxigraphQueryResults;
4use serde::Serialize;
5use std::collections::HashMap;
6
7#[derive(Args, Debug)]
8pub struct QueryArgs {
9    /// SPARQL query to execute
10    pub query: String,
11
12    /// Output format (json, csv, table)
13    #[arg(long, default_value = "table")]
14    pub format: String,
15
16    /// RDF graph file to query
17    #[arg(long)]
18    pub graph: Option<String>,
19}
20
21#[cfg_attr(test, mockall::automock)]
22pub trait SparqlExecutor {
23    fn execute(&self, query: String, graph: Option<String>) -> Result<QueryResults>;
24}
25
26#[derive(Debug, Clone, Serialize)]
27pub struct QueryResults {
28    pub bindings: Vec<HashMap<String, String>>,
29    pub variables: Vec<String>,
30}
31
32/// Validate and sanitize SPARQL query input
33fn validate_sparql_query(query: &str) -> Result<()> {
34    // Validate query is not empty
35    if query.trim().is_empty() {
36        return Err(ggen_utils::error::Error::new(
37            "SPARQL query cannot be empty",
38        ));
39    }
40
41    // Validate query length
42    if query.len() > 10000 {
43        return Err(ggen_utils::error::Error::new(
44            "SPARQL query too long (max 10000 characters)",
45        ));
46    }
47
48    // Basic SPARQL syntax validation
49    let query_upper = query.to_uppercase();
50    if !query_upper.contains("SELECT")
51        && !query_upper.contains("ASK")
52        && !query_upper.contains("CONSTRUCT")
53        && !query_upper.contains("DESCRIBE")
54    {
55        return Err(ggen_utils::error::Error::new(
56            "Invalid SPARQL query: must contain SELECT, ASK, CONSTRUCT, or DESCRIBE",
57        ));
58    }
59
60    // Check for potentially dangerous patterns
61    if query.contains("DROP")
62        || query.contains("DELETE")
63        || query.contains("INSERT")
64        || query.contains("CREATE")
65    {
66        return Err(ggen_utils::error::Error::new(
67            "Write operations not allowed: only SELECT, ASK, CONSTRUCT, and DESCRIBE queries are permitted",
68        ));
69    }
70
71    Ok(())
72}
73
74/// Validate and sanitize output format input
75fn validate_output_format(format: &str) -> Result<()> {
76    // Validate format is not empty
77    if format.trim().is_empty() {
78        return Err(ggen_utils::error::Error::new(
79            "Output format cannot be empty",
80        ));
81    }
82
83    // Validate format length
84    if format.len() > 20 {
85        return Err(ggen_utils::error::Error::new(
86            "Output format too long (max 20 characters)",
87        ));
88    }
89
90    // Validate against known formats
91    let valid_formats = ["json", "csv", "table"];
92    if !valid_formats.contains(&format.to_lowercase().as_str()) {
93        return Err(ggen_utils::error::Error::new(
94            "Unsupported output format: supported formats are json, csv, table",
95        ));
96    }
97
98    Ok(())
99}
100
101/// Validate and sanitize graph file path input (if provided)
102fn validate_graph_path(graph: &Option<String>) -> Result<()> {
103    if let Some(graph) = graph {
104        // Validate graph path is not empty
105        if graph.trim().is_empty() {
106            return Err(ggen_utils::error::Error::new(
107                "Graph file path cannot be empty",
108            ));
109        }
110
111        // Validate graph path length
112        if graph.len() > 1000 {
113            return Err(ggen_utils::error::Error::new(
114                "Graph file path too long (max 1000 characters)",
115            ));
116        }
117
118        // Basic path traversal protection
119        if graph.contains("..") {
120            return Err(ggen_utils::error::Error::new(
121                "Path traversal detected: graph file path cannot contain '..'",
122            ));
123        }
124
125        // Validate graph path format (basic pattern check)
126        if !graph.chars().all(|c| {
127            c.is_alphanumeric() || c == '.' || c == '/' || c == '-' || c == '_' || c == '\\'
128        }) {
129            return Err(ggen_utils::error::Error::new(
130                "Invalid graph file path format: only alphanumeric characters, dots, slashes, dashes, underscores, and backslashes allowed",
131            ));
132        }
133    }
134
135    Ok(())
136}
137
138/// Convert SPARQL results to our QueryResults format
139fn convert_sparql_results(results: OxigraphQueryResults) -> Result<QueryResults> {
140    match results {
141        OxigraphQueryResults::Solutions(solutions) => {
142            let variables: Vec<String> = solutions
143                .variables()
144                .iter()
145                .map(|v| v.to_string())
146                .collect();
147            let mut bindings = Vec::new();
148
149            for solution in solutions {
150                let solution = solution.map_err(|e| {
151                    ggen_utils::error::Error::new(&format!("SPARQL solution error: {}", e))
152                })?;
153                let mut binding = HashMap::new();
154                for variable in &variables {
155                    if let Some(value) = solution.get(variable.as_str()) {
156                        binding.insert(variable.clone(), value.to_string());
157                    }
158                }
159                bindings.push(binding);
160            }
161
162            Ok(QueryResults {
163                bindings,
164                variables,
165            })
166        }
167        OxigraphQueryResults::Boolean(_) => {
168            // For ASK queries, return a simple boolean result
169            Ok(QueryResults {
170                bindings: vec![HashMap::new()],
171                variables: vec!["result".to_string()],
172            })
173        }
174        OxigraphQueryResults::Graph(_) => {
175            // For CONSTRUCT/DESCRIBE queries, return basic info
176            Ok(QueryResults {
177                bindings: vec![HashMap::new()],
178                variables: vec!["triple".to_string()],
179            })
180        }
181    }
182}
183
184pub async fn run(args: &QueryArgs) -> Result<()> {
185    // Validate inputs
186    validate_sparql_query(&args.query)?;
187    validate_output_format(&args.format)?;
188    validate_graph_path(&args.graph)?;
189
190    println!("šŸ” Executing SPARQL query...");
191
192    // Load graph if provided
193    let graph = if let Some(graph_path) = &args.graph {
194        println!("šŸ“Š Loading graph from: {}", graph_path);
195        ggen_core::Graph::load_from_file(graph_path)
196            .map_err(|e| ggen_utils::error::Error::new(&format!("Failed to load graph: {}", e)))?
197    } else {
198        println!("šŸ“Š Using empty graph for query");
199        ggen_core::Graph::new().map_err(|e| {
200            ggen_utils::error::Error::new(&format!("Failed to create empty graph: {}", e))
201        })?
202    };
203
204    // Execute SPARQL query
205    let results = graph
206        .query(&args.query)
207        .map_err(|e| ggen_utils::error::Error::new(&format!("SPARQL query failed: {}", e)))?;
208
209    // Convert results to our format
210    let query_results = convert_sparql_results(results)?;
211
212    // Output results in requested format
213    match args.format.as_str() {
214        "json" => {
215            let json = serde_json::to_string_pretty(&query_results.bindings).map_err(|e| {
216                ggen_utils::error::Error::new(&format!("JSON serialization failed: {}", e))
217            })?;
218            println!("{}", json);
219        }
220        "csv" => {
221            // Print CSV header
222            println!("{}", query_results.variables.join(","));
223            // Print rows
224            for binding in &query_results.bindings {
225                let row: Vec<String> = query_results
226                    .variables
227                    .iter()
228                    .map(|var| binding.get(var).cloned().unwrap_or_default())
229                    .collect();
230                println!("{}", row.join(","));
231            }
232        }
233        "table" => {
234            // Print table header
235            println!("{}", query_results.variables.join(" | "));
236            println!("{}", "-".repeat(query_results.variables.len() * 20));
237            // Print rows
238            for binding in &query_results.bindings {
239                let row: Vec<String> = query_results
240                    .variables
241                    .iter()
242                    .map(|var| binding.get(var).cloned().unwrap_or_default())
243                    .collect();
244                println!("{}", row.join(" | "));
245            }
246        }
247        _ => {
248            return Err(ggen_utils::error::Error::new(&format!(
249                "Unsupported output format: {}. Supported formats: json, csv, table",
250                args.format
251            )));
252        }
253    }
254
255    println!("\nšŸ“Š {} results", query_results.bindings.len());
256    Ok(())
257}
258
259pub async fn run_with_deps(args: &QueryArgs, executor: &dyn SparqlExecutor) -> Result<()> {
260    // Validate inputs
261    validate_sparql_query(&args.query)?;
262    validate_output_format(&args.format)?;
263    validate_graph_path(&args.graph)?;
264
265    // Show progress for query execution
266    println!("šŸ” Executing SPARQL query...");
267
268    let results = executor.execute(args.query.clone(), args.graph.clone())?;
269
270    // Show progress for large result sets
271    if results.bindings.len() > 100 {
272        println!("šŸ“Š Processing {} results...", results.bindings.len());
273    }
274
275    match args.format.as_str() {
276        "json" => {
277            let json = serde_json::to_string_pretty(&results.bindings)
278                .map_err(ggen_utils::error::Error::from)?;
279            println!("{}", json);
280        }
281        "csv" => {
282            // Print CSV header
283            println!("{}", results.variables.join(","));
284            // Print rows
285            for binding in &results.bindings {
286                let row: Vec<String> = results
287                    .variables
288                    .iter()
289                    .map(|var| binding.get(var).cloned().unwrap_or_default())
290                    .collect();
291                println!("{}", row.join(","));
292            }
293        }
294        "table" => {
295            // Print table header
296            println!("{}", results.variables.join(" | "));
297            println!("{}", "-".repeat(results.variables.len() * 20));
298            // Print rows
299            for binding in &results.bindings {
300                let row: Vec<String> = results
301                    .variables
302                    .iter()
303                    .map(|var| binding.get(var).cloned().unwrap_or_default())
304                    .collect();
305                println!("{}", row.join(" | "));
306            }
307        }
308        _ => {
309            return Err(ggen_utils::error::Error::new(&format!(
310                "Unsupported output format: {}. Supported formats: json, csv, table",
311                args.format
312            )));
313        }
314    }
315
316    println!("\nšŸ“Š {} results", results.bindings.len());
317    Ok(())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::predicate::*;

    /// Builds one result row from `(variable, value)` pairs.
    fn row(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        let mut binding = HashMap::new();
        for &(variable, value) in pairs {
            binding.insert(variable.to_string(), value.to_string());
        }
        binding
    }

    /// Builds the column list from variable names.
    fn columns(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds command arguments without going through clap.
    fn query_args(query: &str, format: &str, graph: Option<&str>) -> QueryArgs {
        QueryArgs {
            query: query.to_string(),
            format: format.to_string(),
            graph: graph.map(|path| path.to_string()),
        }
    }

    /// The executor receives the query and graph exactly as given, is called
    /// once, and table output succeeds.
    #[tokio::test]
    async fn test_query_executes_sparql() {
        let mut mock_executor = MockSparqlExecutor::new();
        mock_executor
            .expect_execute()
            .with(
                eq("SELECT ?s ?p ?o WHERE { ?s ?p ?o }".to_string()),
                eq(Some("data.ttl".to_string())),
            )
            .times(1)
            .returning(|_, _| {
                Ok(QueryResults {
                    bindings: vec![row(&[
                        ("s", "ex:Subject"),
                        ("p", "ex:predicate"),
                        ("o", "ex:Object"),
                    ])],
                    variables: columns(&["s", "p", "o"]),
                })
            });

        let args = query_args(
            "SELECT ?s ?p ?o WHERE { ?s ?p ?o }",
            "table",
            Some("data.ttl"),
        );

        let outcome = run_with_deps(&args, &mock_executor).await;
        assert!(outcome.is_ok());
    }

    /// JSON output succeeds for an empty result set.
    #[tokio::test]
    async fn test_query_json_format() {
        let mut mock_executor = MockSparqlExecutor::new();
        mock_executor.expect_execute().times(1).returning(|_, _| {
            Ok(QueryResults {
                bindings: Vec::new(),
                variables: Vec::new(),
            })
        });

        let args = query_args("SELECT * WHERE { ?s ?p ?o }", "json", None);

        let outcome = run_with_deps(&args, &mock_executor).await;
        assert!(outcome.is_ok());
    }

    /// CSV output succeeds for a single two-column row.
    #[tokio::test]
    async fn test_query_csv_format() {
        let mut mock_executor = MockSparqlExecutor::new();
        mock_executor.expect_execute().times(1).returning(|_, _| {
            Ok(QueryResults {
                bindings: vec![row(&[("name", "Alice"), ("age", "30")])],
                variables: columns(&["name", "age"]),
            })
        });

        let args = query_args(
            "SELECT ?name ?age WHERE { ?person :name ?name ; :age ?age }",
            "csv",
            None,
        );

        let outcome = run_with_deps(&args, &mock_executor).await;
        assert!(outcome.is_ok());
    }
}