Skip to main content

graphmind/nlq/
mod.rs

1//! Natural Language Querying (NLQ)
2//!
3//! Implements Text-to-Cypher translation using LLMs.
4
5pub mod client;
6
7use crate::persistence::tenant::NLQConfig;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum NLQError {
12    #[error("LLM API error: {0}")]
13    ApiError(String),
14    #[error("Configuration error: {0}")]
15    ConfigError(String),
16    #[error("Network error: {0}")]
17    NetworkError(String),
18    #[error("Serialization error: {0}")]
19    SerializationError(String),
20    #[error("Validation error: {0}")]
21    ValidationError(String),
22}
23
24pub type NLQResult<T> = Result<T, NLQError>;
25
26pub struct NLQPipeline {
27    client: client::NLQClient,
28}
29
30impl NLQPipeline {
31    pub fn new(config: NLQConfig) -> NLQResult<Self> {
32        let client = client::NLQClient::new(&config)?;
33        Ok(Self { client })
34    }
35
36    pub async fn text_to_cypher(&self, question: &str, schema_summary: &str) -> NLQResult<String> {
37        let prompt = format!(
38            "You are a Cypher query expert for a graph database. Given this schema:\n\n{}\n\n\
39            Rules:\n\
40            - Follow the Relationship Patterns EXACTLY — do not invent edges between labels that aren't listed\n\
41            - When a question involves two unrelated labels (e.g. Country + DiseaseCategory), join them through a shared node (e.g. Trial)\n\
42            - Use property names from the Key Properties section\n\
43            - Use count(x) not COUNT(DISTINCT x) — DISTINCT inside aggregation is not supported\n\
44            - Return ONLY the Cypher query, no markdown, no explanations\n\n\
45            Question: \"{}\"",
46            schema_summary,
47            question
48        );
49
50        let cypher = self.client.generate_cypher(&prompt).await?;
51
52        // Extract Cypher from LLM response — handle markdown fences and explanations
53        let cleaned_cypher = Self::extract_cypher(&cypher);
54
55        if self.is_safe_query(&cleaned_cypher) {
56            Ok(cleaned_cypher)
57        } else {
58            Err(NLQError::ValidationError(
59                "Generated query contains write operations or unsafe keywords".to_string(),
60            ))
61        }
62    }
63
64    /// Extract a Cypher query from an LLM response that may contain markdown
65    /// fences, explanations, or multiple code blocks.
66    fn extract_cypher(response: &str) -> String {
67        let trimmed = response.trim();
68
69        // If response contains a fenced code block, extract the first one
70        if let Some(start) = trimmed.find("```") {
71            let after_fence = &trimmed[start + 3..];
72            // Skip language tag (e.g. "cypher\n")
73            let code_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
74            if let Some(end) = after_fence[code_start..].find("```") {
75                return after_fence[code_start..code_start + end].trim().to_string();
76            }
77        }
78
79        // No fences — take lines that look like Cypher (start with MATCH/RETURN/WITH/etc.)
80        let cypher_keywords = ["MATCH", "RETURN", "WITH", "UNWIND", "CALL", "OPTIONAL"];
81        let lines: Vec<&str> = trimmed
82            .lines()
83            .filter(|line| {
84                let upper = line.trim().to_uppercase();
85                cypher_keywords.iter().any(|kw| upper.starts_with(kw))
86                    || upper.starts_with("WHERE")
87                    || upper.starts_with("ORDER")
88                    || upper.starts_with("LIMIT")
89            })
90            .collect();
91
92        if !lines.is_empty() {
93            return lines.join(" ");
94        }
95
96        // Fallback: strip outer fences and return as-is
97        trimmed
98            .trim_start_matches("```cypher")
99            .trim_start_matches("```")
100            .trim_end_matches("```")
101            .trim()
102            .to_string()
103    }
104
105    pub fn is_safe_query(&self, query: &str) -> bool {
106        let trimmed = query.trim().to_uppercase();
107        trimmed.starts_with("MATCH")
108            || trimmed.starts_with("RETURN")
109            || trimmed.starts_with("UNWIND")
110            || trimmed.starts_with("CALL")
111            || trimmed.starts_with("WITH")
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::tenant::{LLMProvider, NLQConfig};

    /// Builds a pipeline configured with `LLMProvider::Mock`.
    fn make_pipeline() -> NLQPipeline {
        NLQPipeline::new(NLQConfig {
            enabled: true,
            provider: LLMProvider::Mock,
            model: "mock".to_string(),
            api_key: None,
            api_base_url: None,
            system_prompt: None,
        })
        .unwrap()
    }

    // --- is_safe_query tests (via pipeline) ---

    #[test]
    fn test_safe_read_queries() {
        let pipeline = make_pipeline();
        let safe_queries = [
            "MATCH (n:Person) RETURN n.name",
            "MATCH (a)-[:KNOWS]->(b) RETURN a, b",
            "MATCH (n) WHERE n.age > 30 RETURN count(n)",
            "RETURN 1",
            "UNWIND [1,2,3] AS x RETURN x",
            "WITH 1 AS x RETURN x",
            "CALL db.labels()",
            // Regression: property value containing write keyword must be safe
            "MATCH (n:Person) WHERE n.name = 'SET' RETURN n",
            "MATCH (n) WHERE n.status = 'CREATED' RETURN n",
            // Keywords are matched case-insensitively
            "match (n) return n",
        ];
        for query in safe_queries {
            assert!(pipeline.is_safe_query(query), "expected safe: {query}");
        }
    }

    #[test]
    fn test_unsafe_write_queries() {
        let pipeline = make_pipeline();
        let unsafe_queries = [
            "CREATE (n:Person {name: 'Alice'})",
            "DELETE n",
            "SET n.name = 'Bob'",
            "MERGE (n:Person {name: 'Alice'})",
            "DROP INDEX my_index",
            "REMOVE n.age",
        ];
        for query in unsafe_queries {
            assert!(!pipeline.is_safe_query(query), "expected unsafe: {query}");
        }
    }

    // --- extract_cypher tests ---

    #[test]
    fn test_extract_cypher_plain_query() {
        let input = "MATCH (n:Person) RETURN n.name";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n:Person) RETURN n.name");
    }

    #[test]
    fn test_extract_cypher_markdown_fenced() {
        let input =
            "Here is the query:\n```cypher\nMATCH (n:Person) RETURN n.name\n```\nHope this helps!";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n:Person) RETURN n.name");
    }

    #[test]
    fn test_extract_cypher_markdown_no_language_tag() {
        let input = "```\nMATCH (n) RETURN n\n```";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n) RETURN n");
    }

    #[test]
    fn test_extract_cypher_mixed_with_explanation() {
        // Prose before and after the query must be dropped; clause lines are
        // joined with single spaces.
        let input = "To find all people, use this:\nMATCH (n:Person)\nWHERE n.age > 30\nRETURN n.name\nThis returns names of people over 30.";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
    }

    #[test]
    fn test_extract_cypher_with_optional_match() {
        let input = "OPTIONAL MATCH (n:Person)-[:KNOWS]->(m)\nRETURN n, m";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "OPTIONAL MATCH (n:Person)-[:KNOWS]->(m) RETURN n, m");
    }

    #[test]
    fn test_extract_cypher_with_order_and_limit() {
        let input = "MATCH (n:Person)\nRETURN n.name\nORDER BY n.name\nLIMIT 10";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(
            result,
            "MATCH (n:Person) RETURN n.name ORDER BY n.name LIMIT 10"
        );
    }

    #[test]
    fn test_extract_cypher_whitespace_trimming() {
        let input = "  \n  MATCH (n) RETURN n  \n  ";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n) RETURN n");
    }

    // ========== Coverage batch: additional NLQ pipeline tests ==========

    #[test]
    fn test_pipeline_creation_with_mock() {
        let pipeline = make_pipeline();
        // Just verify it was created successfully (constructor exercises NLQClient::new)
        assert!(pipeline.is_safe_query("MATCH (n) RETURN n"));
    }

    #[test]
    fn test_is_safe_query_call_prefix() {
        let pipeline = make_pipeline();
        assert!(pipeline.is_safe_query("CALL algo.pageRank({}) YIELD node"));
        assert!(pipeline.is_safe_query("CALL db.labels()"));
    }

    #[test]
    fn test_is_safe_query_with_prefix() {
        let pipeline = make_pipeline();
        assert!(pipeline.is_safe_query("WITH 1 AS x MATCH (n) RETURN n"));
    }

    #[test]
    fn test_is_safe_query_return_only() {
        let pipeline = make_pipeline();
        assert!(pipeline.is_safe_query("RETURN 42"));
        assert!(pipeline.is_safe_query("RETURN datetime()"));
    }

    #[test]
    fn test_is_safe_query_rejects_create() {
        let pipeline = make_pipeline();
        assert!(!pipeline.is_safe_query("CREATE (:Person {name: 'Eve'})"));
    }

    #[test]
    fn test_is_safe_query_rejects_drop() {
        let pipeline = make_pipeline();
        assert!(!pipeline.is_safe_query("DROP INDEX myIdx"));
    }

    #[test]
    fn test_is_safe_query_rejects_set_at_start() {
        let pipeline = make_pipeline();
        assert!(!pipeline.is_safe_query("SET n.name = 'test'"));
    }

    #[test]
    fn test_is_safe_query_rejects_remove_at_start() {
        let pipeline = make_pipeline();
        assert!(!pipeline.is_safe_query("REMOVE n.age"));
    }

    #[test]
    fn test_is_safe_query_whitespace_handling() {
        let pipeline = make_pipeline();
        assert!(pipeline.is_safe_query("  MATCH (n) RETURN n  "));
        assert!(pipeline.is_safe_query("  RETURN 1  "));
    }

    #[test]
    fn test_is_safe_query_empty_string() {
        let pipeline = make_pipeline();
        assert!(!pipeline.is_safe_query(""));
    }

    #[test]
    fn test_extract_cypher_multiple_fenced_blocks() {
        // extract_cypher should return the first fenced code block
        let input = "First block:\n```cypher\nMATCH (a) RETURN a\n```\nSecond:\n```cypher\nMATCH (b) RETURN b\n```";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (a) RETURN a");
    }

    #[test]
    fn test_extract_cypher_fenced_without_closing() {
        // With no closing ```, extraction falls back to the line-based path,
        // which keeps only the Cypher line.
        let input = "Here:\n```cypher\nMATCH (n) RETURN n";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n) RETURN n");
    }

    #[test]
    fn test_extract_cypher_only_non_cypher_text() {
        // No cypher keywords at line starts => fallback to trimmed input
        let input = "I think you should look at the data.";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "I think you should look at the data.");
    }

    #[test]
    fn test_extract_cypher_unwind_at_start() {
        let input = "UNWIND range(1, 10) AS i\nRETURN i";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "UNWIND range(1, 10) AS i RETURN i");
    }

    #[test]
    fn test_extract_cypher_call_at_start() {
        let input = "CALL db.labels()";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "CALL db.labels()");
    }

    #[test]
    fn test_extract_cypher_with_clause_lines() {
        let input = "MATCH (n:Person)\nWITH n.city AS city\nRETURN city";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "MATCH (n:Person) WITH n.city AS city RETURN city");
    }

    #[tokio::test]
    async fn test_text_to_cypher_with_mock() {
        let pipeline = make_pipeline();
        let schema = "Labels: Person, Company\nRelationships: WORKS_AT";
        let result = pipeline.text_to_cypher("Find all people", schema).await;
        // Mock returns "MATCH (n) RETURN n LIMIT 10", which starts with MATCH => safe
        assert!(result.is_ok());
        let cypher = result.unwrap();
        assert!(cypher.contains("MATCH"));
    }

    #[tokio::test]
    async fn test_text_to_cypher_validates_safety() {
        // The mock always returns "MATCH (n) RETURN n LIMIT 10" which is safe,
        // so the rejection path cannot be reached here. What can be checked
        // end-to-end is that anything returned as Ok passes the safety check.
        let pipeline = make_pipeline();
        let result = pipeline.text_to_cypher("test question", "schema").await;
        assert!(result.is_ok());
        let cypher = result.unwrap();
        assert!(pipeline.is_safe_query(&cypher));
    }

    #[test]
    fn test_extract_cypher_plain_fence_no_lang_tag() {
        let input = "```\nRETURN 42\n```";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "RETURN 42");
    }

    #[test]
    fn test_extract_cypher_mixed_case_keywords() {
        // Keywords are detected case-insensitively and the original casing of
        // each line is preserved in the output.
        let input = "match (n:Person)\nwhere n.age > 30\nreturn n.name";
        let result = NLQPipeline::extract_cypher(input);
        assert_eq!(result, "match (n:Person) where n.age > 30 return n.name");
    }

    #[test]
    fn test_nlq_pipeline_new_with_different_providers() {
        // Test with OpenAI provider
        let config = NLQConfig {
            enabled: true,
            provider: LLMProvider::OpenAI,
            model: "gpt-4".to_string(),
            api_key: Some("sk-test".to_string()),
            api_base_url: None,
            system_prompt: None,
        };
        let pipeline = NLQPipeline::new(config);
        assert!(pipeline.is_ok());
    }

    #[test]
    fn test_is_safe_query_unwind_prefix() {
        let pipeline = make_pipeline();
        assert!(pipeline.is_safe_query("UNWIND [1,2,3] AS x RETURN x"));
    }
}
390}