1pub mod client;
6
7use crate::persistence::tenant::NLQConfig;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum NLQError {
12 #[error("LLM API error: {0}")]
13 ApiError(String),
14 #[error("Configuration error: {0}")]
15 ConfigError(String),
16 #[error("Network error: {0}")]
17 NetworkError(String),
18 #[error("Serialization error: {0}")]
19 SerializationError(String),
20 #[error("Validation error: {0}")]
21 ValidationError(String),
22}
23
24pub type NLQResult<T> = Result<T, NLQError>;
25
26pub struct NLQPipeline {
27 client: client::NLQClient,
28}
29
30impl NLQPipeline {
31 pub fn new(config: NLQConfig) -> NLQResult<Self> {
32 let client = client::NLQClient::new(&config)?;
33 Ok(Self { client })
34 }
35
36 pub async fn text_to_cypher(&self, question: &str, schema_summary: &str) -> NLQResult<String> {
37 let prompt = format!(
38 "You are a Cypher query expert for a graph database. Given this schema:\n\n{}\n\n\
39 Rules:\n\
40 - Follow the Relationship Patterns EXACTLY — do not invent edges between labels that aren't listed\n\
41 - When a question involves two unrelated labels (e.g. Country + DiseaseCategory), join them through a shared node (e.g. Trial)\n\
42 - Use property names from the Key Properties section\n\
43 - Use count(x) not COUNT(DISTINCT x) — DISTINCT inside aggregation is not supported\n\
44 - Return ONLY the Cypher query, no markdown, no explanations\n\n\
45 Question: \"{}\"",
46 schema_summary,
47 question
48 );
49
50 let cypher = self.client.generate_cypher(&prompt).await?;
51
52 let cleaned_cypher = Self::extract_cypher(&cypher);
54
55 if self.is_safe_query(&cleaned_cypher) {
56 Ok(cleaned_cypher)
57 } else {
58 Err(NLQError::ValidationError(
59 "Generated query contains write operations or unsafe keywords".to_string(),
60 ))
61 }
62 }
63
/// Pulls the Cypher statement out of a raw LLM response.
///
/// Extraction is attempted in three stages:
///
/// 1. **Fenced block** – if the response contains a closed Markdown fence
///    (three backticks … three backticks), the content of the first fence is
///    returned. A language tag / info string is removed whether it sits on
///    its own line (`cypher\nMATCH …`) or on the same line as the code
///    (`cypher MATCH …`). Code that starts directly after the opening fence
///    is kept in full.
/// 2. **Line filter** – otherwise, every line whose *first word* is a clause
///    keyword (`MATCH`, `RETURN`, `WITH`, `UNWIND`, `CALL`, `OPTIONAL`,
///    `WHERE`, `ORDER`, `LIMIT`) is kept. A line whose first word is `AND`,
///    `OR`, `XOR`, `SKIP`, `UNION` or `YIELD` is kept when it directly follows
///    a kept line. Kept lines are trimmed and joined with single spaces.
///    Matching is case-insensitive and on whole words, so prose such as
///    "Returns every node." is not mistaken for a `RETURN` clause.
/// 3. **Fallback** – if no line qualifies, the trimmed response is returned
///    with any leading/trailing fence markers stripped.
///
/// Known limitation: a prose line whose first word is exactly a clause
/// keyword (e.g. "Return the names …") is still treated as Cypher.
fn extract_cypher(response: &str) -> String {
    // Words that open a clause and therefore mark a line as Cypher.
    const CLAUSE_KEYWORDS: [&str; 9] = [
        "MATCH", "RETURN", "WITH", "UNWIND", "CALL", "OPTIONAL", "WHERE", "ORDER", "LIMIT",
    ];
    // Words that continue the clause on the previous line. They are only
    // honoured directly after a kept line so that ordinary prose starting
    // with "And"/"Or" elsewhere in the response is ignored.
    const CONTINUATION_KEYWORDS: [&str; 6] = ["AND", "OR", "XOR", "SKIP", "UNION", "YIELD"];

    // Upper-cased leading identifier of `text` (letters, digits, `_`), or an
    // empty string when `text` does not start with an identifier character.
    fn leading_word(text: &str) -> String {
        let end = text
            .find(|c: char| !(c.is_alphanumeric() || c == '_'))
            .unwrap_or(text.len());
        text[..end].to_uppercase()
    }

    // True when the first word of `text` (ignoring leading whitespace) is a
    // clause keyword.
    fn starts_clause(text: &str) -> bool {
        let word = leading_word(text.trim_start());
        CLAUSE_KEYWORDS.iter().any(|kw| word == *kw)
    }

    let trimmed = response.trim();

    // Stage 1: first closed Markdown fence. "```" is ASCII, so the byte
    // offsets used for slicing always fall on character boundaries.
    if let Some(open) = trimmed.find("```") {
        let after_open = &trimmed[open + 3..];
        if let Some(close) = after_open.find("```") {
            let fenced = &after_open[..close];
            let code = if starts_clause(fenced) {
                // Code begins right after the fence: there is no tag to drop.
                fenced
            } else if let Some((_info_string, body)) = fenced.split_once('\n') {
                // The whole opening line is the info string.
                body
            } else {
                // Single-line fence: drop a leading tag only when what
                // follows it really looks like a query.
                match fenced.trim_start().split_once(char::is_whitespace) {
                    Some((_tag, rest)) if starts_clause(rest) => rest,
                    _ => fenced,
                }
            };
            return code.trim().to_string();
        }
    }

    // Stage 2: keep clause lines and their direct continuations.
    let mut query_lines: Vec<&str> = Vec::new();
    let mut previous_kept = false;
    for line in trimmed.lines().map(str::trim) {
        let word = leading_word(line);
        let keep = CLAUSE_KEYWORDS.iter().any(|kw| word == *kw)
            || (previous_kept && CONTINUATION_KEYWORDS.iter().any(|kw| word == *kw));
        if keep {
            query_lines.push(line);
        }
        previous_kept = keep;
    }

    if !query_lines.is_empty() {
        return query_lines.join(" ");
    }

    // Stage 3: nothing recognisable; strip stray fence markers and return.
    trimmed
        .trim_start_matches("```cypher")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim()
        .to_string()
}
104
105 pub fn is_safe_query(&self, query: &str) -> bool {
106 let trimmed = query.trim().to_uppercase();
107 trimmed.starts_with("MATCH")
108 || trimmed.starts_with("RETURN")
109 || trimmed.starts_with("UNWIND")
110 || trimmed.starts_with("CALL")
111 || trimmed.starts_with("WITH")
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::tenant::{LLMProvider, NLQConfig};

    /// Builds a pipeline configured with the mock LLM provider.
    fn make_pipeline() -> NLQPipeline {
        let config = NLQConfig {
            enabled: true,
            provider: LLMProvider::Mock,
            model: "mock".to_string(),
            api_key: None,
            api_base_url: None,
            system_prompt: None,
        };
        NLQPipeline::new(config).unwrap()
    }

    /// Asserts that `haystack` contains every entry of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected {:?} to contain {:?}",
                haystack,
                needle
            );
        }
    }

    // ----- is_safe_query -------------------------------------------------

    #[test]
    fn test_safe_read_queries() {
        let pipeline = make_pipeline();
        let accepted = [
            "MATCH (n:Person) RETURN n.name",
            "MATCH (a)-[:KNOWS]->(b) RETURN a, b",
            "MATCH (n) WHERE n.age > 30 RETURN count(n)",
            "RETURN 1",
            "UNWIND [1,2,3] AS x RETURN x",
            "WITH 1 AS x RETURN x",
            "CALL db.labels()",
            // Write keywords inside string literals must not trip the check.
            "MATCH (n:Person) WHERE n.name = 'SET' RETURN n",
            "MATCH (n) WHERE n.status = 'CREATED' RETURN n",
            // Keywords are matched case-insensitively.
            "match (n) return n",
        ];
        for query in accepted.iter() {
            assert!(pipeline.is_safe_query(query), "should be accepted: {}", query);
        }
    }

    #[test]
    fn test_unsafe_write_queries() {
        let pipeline = make_pipeline();
        let rejected = [
            "CREATE (n:Person {name: 'Alice'})",
            "DELETE n",
            "SET n.name = 'Bob'",
            "MERGE (n:Person {name: 'Alice'})",
            "DROP INDEX my_index",
            "REMOVE n.age",
        ];
        for query in rejected.iter() {
            assert!(!pipeline.is_safe_query(query), "should be rejected: {}", query);
        }
    }

    #[test]
    fn test_pipeline_creation_with_mock() {
        assert!(make_pipeline().is_safe_query("MATCH (n) RETURN n"));
    }

    #[test]
    fn test_is_safe_query_call_prefix() {
        let pipeline = make_pipeline();
        for query in ["CALL algo.pageRank({}) YIELD node", "CALL db.labels()"].iter() {
            assert!(pipeline.is_safe_query(query), "should be accepted: {}", query);
        }
    }

    #[test]
    fn test_is_safe_query_with_prefix() {
        assert!(make_pipeline().is_safe_query("WITH 1 AS x MATCH (n) RETURN n"));
    }

    #[test]
    fn test_is_safe_query_return_only() {
        let pipeline = make_pipeline();
        for query in ["RETURN 42", "RETURN datetime()"].iter() {
            assert!(pipeline.is_safe_query(query), "should be accepted: {}", query);
        }
    }

    #[test]
    fn test_is_safe_query_rejects_create() {
        assert!(!make_pipeline().is_safe_query("CREATE (:Person {name: 'Eve'})"));
    }

    #[test]
    fn test_is_safe_query_rejects_drop() {
        assert!(!make_pipeline().is_safe_query("DROP INDEX myIdx"));
    }

    #[test]
    fn test_is_safe_query_rejects_set_at_start() {
        assert!(!make_pipeline().is_safe_query("SET n.name = 'test'"));
    }

    #[test]
    fn test_is_safe_query_rejects_remove_at_start() {
        assert!(!make_pipeline().is_safe_query("REMOVE n.age"));
    }

    #[test]
    fn test_is_safe_query_whitespace_handling() {
        let pipeline = make_pipeline();
        for query in [" MATCH (n) RETURN n ", " RETURN 1 "].iter() {
            assert!(pipeline.is_safe_query(query), "should be accepted: {:?}", query);
        }
    }

    #[test]
    fn test_is_safe_query_empty_string() {
        assert!(!make_pipeline().is_safe_query(""));
    }

    #[test]
    fn test_is_safe_query_unwind_prefix() {
        assert!(make_pipeline().is_safe_query("UNWIND [1,2,3] AS x RETURN x"));
    }

    // ----- extract_cypher ------------------------------------------------

    #[test]
    fn test_extract_cypher_plain_query() {
        let extracted = NLQPipeline::extract_cypher("MATCH (n:Person) RETURN n.name");
        assert_eq!(extracted, "MATCH (n:Person) RETURN n.name");
    }

    #[test]
    fn test_extract_cypher_markdown_fenced() {
        let response =
            "Here is the query:\n```cypher\nMATCH (n:Person) RETURN n.name\n```\nHope this helps!";
        assert_eq!(
            NLQPipeline::extract_cypher(response),
            "MATCH (n:Person) RETURN n.name"
        );
    }

    #[test]
    fn test_extract_cypher_markdown_no_language_tag() {
        let extracted = NLQPipeline::extract_cypher("```\nMATCH (n) RETURN n\n```");
        assert_eq!(extracted, "MATCH (n) RETURN n");
    }

    #[test]
    fn test_extract_cypher_mixed_with_explanation() {
        let response = "To find all people, use this:\nMATCH (n:Person)\nWHERE n.age > 30\nRETURN n.name\nThis returns names of people over 30.";
        let extracted = NLQPipeline::extract_cypher(response);
        assert_contains_all(
            &extracted,
            &["MATCH (n:Person)", "WHERE n.age > 30", "RETURN n.name"],
        );
        assert!(!extracted.contains("To find all people"));
    }

    #[test]
    fn test_extract_cypher_with_optional_match() {
        let extracted =
            NLQPipeline::extract_cypher("OPTIONAL MATCH (n:Person)-[:KNOWS]->(m)\nRETURN n, m");
        assert_contains_all(&extracted, &["OPTIONAL MATCH", "RETURN"]);
    }

    #[test]
    fn test_extract_cypher_with_order_and_limit() {
        let extracted = NLQPipeline::extract_cypher(
            "MATCH (n:Person)\nRETURN n.name\nORDER BY n.name\nLIMIT 10",
        );
        assert_contains_all(&extracted, &["MATCH", "ORDER BY", "LIMIT 10"]);
    }

    #[test]
    fn test_extract_cypher_whitespace_trimming() {
        let extracted = NLQPipeline::extract_cypher(" \n MATCH (n) RETURN n \n ");
        assert_eq!(extracted, "MATCH (n) RETURN n");
    }

    #[test]
    fn test_extract_cypher_multiple_fenced_blocks() {
        let response = "First block:\n```cypher\nMATCH (a) RETURN a\n```\nSecond:\n```cypher\nMATCH (b) RETURN b\n```";
        // Only the first fenced block is expected back.
        assert_eq!(NLQPipeline::extract_cypher(response), "MATCH (a) RETURN a");
    }

    #[test]
    fn test_extract_cypher_fenced_without_closing() {
        let extracted = NLQPipeline::extract_cypher("Here:\n```cypher\nMATCH (n) RETURN n");
        assert_contains_all(&extracted, &["MATCH", "RETURN"]);
    }

    #[test]
    fn test_extract_cypher_only_non_cypher_text() {
        let response = "I think you should look at the data.";
        assert_eq!(NLQPipeline::extract_cypher(response), response);
    }

    #[test]
    fn test_extract_cypher_unwind_at_start() {
        let extracted = NLQPipeline::extract_cypher("UNWIND range(1, 10) AS i\nRETURN i");
        assert_contains_all(&extracted, &["UNWIND", "RETURN"]);
    }

    #[test]
    fn test_extract_cypher_call_at_start() {
        let extracted = NLQPipeline::extract_cypher("CALL db.labels()");
        assert!(extracted.contains("CALL"));
    }

    #[test]
    fn test_extract_cypher_with_clause_lines() {
        let extracted =
            NLQPipeline::extract_cypher("MATCH (n:Person)\nWITH n.city AS city\nRETURN city");
        assert_contains_all(&extracted, &["MATCH", "WITH", "RETURN"]);
    }

    #[test]
    fn test_extract_cypher_plain_fence_no_lang_tag() {
        assert_eq!(NLQPipeline::extract_cypher("```\nRETURN 42\n```"), "RETURN 42");
    }

    #[test]
    fn test_extract_cypher_mixed_case_keywords() {
        let extracted =
            NLQPipeline::extract_cypher("match (n:Person)\nwhere n.age > 30\nreturn n.name");
        assert!(extracted.contains("match") || extracted.contains("MATCH"));
    }

    // ----- pipeline construction and end-to-end --------------------------

    #[tokio::test]
    async fn test_text_to_cypher_with_mock() {
        let pipeline = make_pipeline();
        let schema = "Labels: Person, Company\nRelationships: WORKS_AT";
        let cypher = pipeline
            .text_to_cypher("Find all people", schema)
            .await
            .expect("text_to_cypher should succeed with the mock provider");
        assert!(cypher.contains("MATCH"));
    }

    #[tokio::test]
    async fn test_text_to_cypher_validates_safety() {
        let outcome = make_pipeline()
            .text_to_cypher("test question", "schema")
            .await;
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_nlq_pipeline_new_with_different_providers() {
        let openai_config = NLQConfig {
            enabled: true,
            provider: LLMProvider::OpenAI,
            model: "gpt-4".to_string(),
            api_key: Some("sk-test".to_string()),
            api_base_url: None,
            system_prompt: None,
        };
        assert!(NLQPipeline::new(openai_config).is_ok());
    }
}