// hermes_core/dsl/query_field_router.rs

//! Query Field Router - Routes queries to specific fields based on regex patterns
//!
//! This module provides functionality to detect if a query matches a regex pattern,
//! extract capture groups, substitute them into a template, and route the result
//! to a specific field instead of (or in addition to) default fields.
//!
//! # Template Language
//!
//! The substitution template supports:
//! - Simple group references: `{0}`, `{1}`, `{2}` etc.
//! - Expression syntax with functions: `{g(1).replace('-', '').lower()}`
//!
//! ## Available Functions
//!
//! - `g(n)` - Get capture group n (0 = entire match, 1+ = capture groups)
//! - `.replace(from, to)` - Replace all occurrences of `from` with `to`
//!   (with a single argument, occurrences of `from` are removed)
//! - `.lower()` - Convert to lowercase (alias: `.lowercase()`)
//! - `.upper()` - Convert to uppercase (alias: `.uppercase()`)
//! - `.trim()` - Remove leading/trailing whitespace
//! - `.trim_start()` / `.ltrim()` and `.trim_end()` / `.rtrim()` - Trim one side only
//!
//! Unknown methods leave the value unchanged, and a capture group that does not exist
//! (or did not participate in the match) yields an empty string.
//!
//! # Example
//!
//! ```text
//! # In SDL:
//! index documents {
//!     field title: text [indexed, stored]
//!     field uri: text [indexed, stored]
//!
//!     # Route DOI queries to uri field exclusively
//!     query_router {
//!         pattern: r"10\.\d{4,}/[^\s]+"
//!         substitution: "doi://{0}"
//!         target_field: uri
//!         mode: exclusive
//!     }
//!
//!     # Route ISBN with hyphen removal
//!     query_router {
//!         pattern: r"^isbn:([\d\-]+)$"
//!         substitution: "isbn://{g(1).replace('-', '')}"
//!         target_field: uri
//!         mode: exclusive
//!     }
//! }
//! ```
46
47use regex::Regex;
48use serde::{Deserialize, Serialize};
49
50/// Routing mode for matched queries
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RoutingMode {
53    /// Query only the target field (replace default fields)
54    #[serde(rename = "exclusive")]
55    Exclusive,
56    /// Query both target field and default fields
57    #[serde(rename = "additional")]
58    #[default]
59    Additional,
60}
61
62/// A single query routing rule
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct QueryRouterRule {
65    /// Regex pattern to match against the query
66    pub pattern: String,
67    /// Substitution template using {0}, {1}, etc. for capture groups
68    /// {0} is the entire match, {1} is first capture group, etc.
69    pub substitution: String,
70    /// Target field name to route the substituted query to
71    pub target_field: String,
72    /// Whether this is exclusive (replaces default) or additional
73    #[serde(default)]
74    pub mode: RoutingMode,
75}
76
77/// Result of applying a routing rule
78#[derive(Debug, Clone)]
79pub struct RoutedQuery {
80    /// The transformed query string
81    pub query: String,
82    /// Target field name
83    pub target_field: String,
84    /// Routing mode
85    pub mode: RoutingMode,
86}
87
88/// Template expression evaluator
89/// 
90/// Evaluates expressions like `{g(1).replace('-', '').lower()}`
91mod template {
92    use regex::Captures;
93
94    /// Evaluate a substitution template with the given regex captures
95    pub fn evaluate(template: &str, captures: &Captures) -> String {
96        let mut result = String::new();
97        let mut chars = template.chars().peekable();
98
99        while let Some(c) = chars.next() {
100            if c == '{' {
101                // Parse expression until closing brace
102                let mut expr = String::new();
103                let mut brace_depth = 1;
104                
105                while let Some(c) = chars.next() {
106                    if c == '{' {
107                        brace_depth += 1;
108                        expr.push(c);
109                    } else if c == '}' {
110                        brace_depth -= 1;
111                        if brace_depth == 0 {
112                            break;
113                        }
114                        expr.push(c);
115                    } else {
116                        expr.push(c);
117                    }
118                }
119
120                // Evaluate the expression
121                let value = evaluate_expr(&expr, captures);
122                result.push_str(&value);
123            } else {
124                result.push(c);
125            }
126        }
127
128        result
129    }
130
131    /// Evaluate a single expression (content inside {})
132    fn evaluate_expr(expr: &str, captures: &Captures) -> String {
133        let expr = expr.trim();
134
135        // Check for simple numeric reference like "0", "1", "2"
136        if let Ok(group_num) = expr.parse::<usize>() {
137            return captures
138                .get(group_num)
139                .map(|m| m.as_str().to_string())
140                .unwrap_or_default();
141        }
142
143        // Parse expression with function calls
144        parse_and_evaluate(expr, captures)
145    }
146
147    /// Parse and evaluate an expression like `g(1).replace('-', '').lower()`
148    fn parse_and_evaluate(expr: &str, captures: &Captures) -> String {
149        let mut chars = expr.chars().peekable();
150        let mut value = String::new();
151
152        // Skip whitespace
153        while chars.peek() == Some(&' ') {
154            chars.next();
155        }
156
157        // Parse initial value (must start with g(n))
158        if expr.starts_with("g(") {
159            // Parse g(n)
160            chars.next(); // 'g'
161            chars.next(); // '('
162            
163            let mut num_str = String::new();
164            while let Some(&c) = chars.peek() {
165                if c == ')' {
166                    chars.next();
167                    break;
168                }
169                num_str.push(c);
170                chars.next();
171            }
172
173            if let Ok(group_num) = num_str.trim().parse::<usize>() {
174                value = captures
175                    .get(group_num)
176                    .map(|m| m.as_str().to_string())
177                    .unwrap_or_default();
178            }
179        } else {
180            // Unknown expression start
181            return expr.to_string();
182        }
183
184        // Parse method chain
185        while chars.peek().is_some() {
186            // Skip whitespace
187            while chars.peek() == Some(&' ') {
188                chars.next();
189            }
190
191            // Expect '.'
192            if chars.peek() != Some(&'.') {
193                break;
194            }
195            chars.next(); // consume '.'
196
197            // Parse method name
198            let mut method_name = String::new();
199            while let Some(&c) = chars.peek() {
200                if c == '(' || c == ' ' {
201                    break;
202                }
203                method_name.push(c);
204                chars.next();
205            }
206
207            // Skip whitespace
208            while chars.peek() == Some(&' ') {
209                chars.next();
210            }
211
212            // Parse arguments if present
213            let args = if chars.peek() == Some(&'(') {
214                chars.next(); // consume '('
215                parse_args(&mut chars)
216            } else {
217                vec![]
218            };
219
220            // Apply method
221            value = apply_method(&value, &method_name, &args);
222        }
223
224        value
225    }
226
227    /// Parse function arguments from the char iterator
228    fn parse_args(chars: &mut std::iter::Peekable<std::str::Chars>) -> Vec<String> {
229        let mut args = Vec::new();
230        let mut current_arg = String::new();
231        let mut in_string = false;
232        let mut string_char = '"';
233
234        while let Some(c) = chars.next() {
235            if c == ')' && !in_string {
236                // End of arguments
237                let arg = current_arg.trim().to_string();
238                if !arg.is_empty() {
239                    args.push(parse_string_literal(&arg));
240                }
241                break;
242            } else if (c == '"' || c == '\'') && !in_string {
243                in_string = true;
244                string_char = c;
245                current_arg.push(c);
246            } else if c == string_char && in_string {
247                in_string = false;
248                current_arg.push(c);
249            } else if c == ',' && !in_string {
250                let arg = current_arg.trim().to_string();
251                if !arg.is_empty() {
252                    args.push(parse_string_literal(&arg));
253                }
254                current_arg.clear();
255            } else {
256                current_arg.push(c);
257            }
258        }
259
260        args
261    }
262
263    /// Parse a string literal, removing quotes
264    fn parse_string_literal(s: &str) -> String {
265        let s = s.trim();
266        if (s.starts_with('"') && s.ends_with('"')) || (s.starts_with('\'') && s.ends_with('\'')) {
267            s[1..s.len() - 1].to_string()
268        } else {
269            s.to_string()
270        }
271    }
272
273    /// Apply a method to a value
274    fn apply_method(value: &str, method: &str, args: &[String]) -> String {
275        match method {
276            "replace" => {
277                if args.len() >= 2 {
278                    value.replace(&args[0], &args[1])
279                } else if args.len() == 1 {
280                    value.replace(&args[0], "")
281                } else {
282                    value.to_string()
283                }
284            }
285            "lower" | "lowercase" => value.to_lowercase(),
286            "upper" | "uppercase" => value.to_uppercase(),
287            "trim" => value.trim().to_string(),
288            "trim_start" | "ltrim" => value.trim_start().to_string(),
289            "trim_end" | "rtrim" => value.trim_end().to_string(),
290            _ => value.to_string(),
291        }
292    }
293
294    #[cfg(test)]
295    mod tests {
296        use super::*;
297        use regex::Regex;
298
299        fn make_captures<'a>(pattern: &str, text: &'a str) -> Option<Captures<'a>> {
300            Regex::new(pattern).ok()?.captures(text)
301        }
302
303        #[test]
304        fn test_simple_substitution() {
305            let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
306            assert_eq!(evaluate("value: {1}", &caps), "value: 123");
307        }
308
309        #[test]
310        fn test_g_function() {
311            let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
312            assert_eq!(evaluate("{g(1)}", &caps), "123");
313            assert_eq!(evaluate("{g(0)}", &caps), "123");
314        }
315
316        #[test]
317        fn test_replace_function() {
318            let caps = make_captures(r"([\d\-]+)", "isbn:978-3-16-148410-0").unwrap();
319            assert_eq!(evaluate("{g(1).replace('-', '')}", &caps), "9783161484100");
320        }
321
322        #[test]
323        fn test_lower_function() {
324            let caps = make_captures(r"(\w+)", "HELLO").unwrap();
325            assert_eq!(evaluate("{g(1).lower()}", &caps), "hello");
326        }
327
328        #[test]
329        fn test_upper_function() {
330            let caps = make_captures(r"(\w+)", "hello").unwrap();
331            assert_eq!(evaluate("{g(1).upper()}", &caps), "HELLO");
332        }
333
334        #[test]
335        fn test_trim_function() {
336            let caps = make_captures(r"(.+)", "  hello  ").unwrap();
337            assert_eq!(evaluate("{g(1).trim()}", &caps), "hello");
338        }
339
340        #[test]
341        fn test_chained_functions() {
342            let caps = make_captures(r"([\d\-]+)", "978-3-16").unwrap();
343            assert_eq!(
344                evaluate("{g(1).replace('-', '').lower()}", &caps),
345                "978316"
346            );
347        }
348
349        #[test]
350        fn test_mixed_template() {
351            let caps = make_captures(r"isbn:([\d\-]+)", "isbn:978-3-16").unwrap();
352            assert_eq!(
353                evaluate("isbn://{g(1).replace('-', '')}", &caps),
354                "isbn://978316"
355            );
356        }
357
358        #[test]
359        fn test_multiple_expressions() {
360            let caps = make_captures(r"(\w+):(\w+)", "key:VALUE").unwrap();
361            assert_eq!(
362                evaluate("{g(1).upper()}={g(2).lower()}", &caps),
363                "KEY=value"
364            );
365        }
366    }
367}
368
369/// Compiled query router rule with pre-compiled regex
370#[derive(Debug, Clone)]
371pub struct CompiledRouterRule {
372    regex: Regex,
373    substitution: String,
374    target_field: String,
375    mode: RoutingMode,
376}
377
378impl CompiledRouterRule {
379    /// Create a new compiled router rule
380    pub fn new(rule: &QueryRouterRule) -> Result<Self, String> {
381        let regex = Regex::new(&rule.pattern)
382            .map_err(|e| format!("Invalid regex pattern '{}': {}", rule.pattern, e))?;
383
384        Ok(Self {
385            regex,
386            substitution: rule.substitution.clone(),
387            target_field: rule.target_field.clone(),
388            mode: rule.mode,
389        })
390    }
391
392    /// Try to match and transform a query
393    pub fn try_match(&self, query: &str) -> Option<RoutedQuery> {
394        let captures = self.regex.captures(query)?;
395
396        // Use the template evaluator for substitution
397        let result = template::evaluate(&self.substitution, &captures);
398
399        Some(RoutedQuery {
400            query: result,
401            target_field: self.target_field.clone(),
402            mode: self.mode,
403        })
404    }
405
406    /// Get the target field name
407    pub fn target_field(&self) -> &str {
408        &self.target_field
409    }
410
411    /// Get the routing mode
412    pub fn mode(&self) -> RoutingMode {
413        self.mode
414    }
415}
416
417/// Query field router that holds multiple routing rules
418#[derive(Debug, Clone, Default)]
419pub struct QueryFieldRouter {
420    rules: Vec<CompiledRouterRule>,
421}
422
423impl QueryFieldRouter {
424    /// Create a new empty router
425    pub fn new() -> Self {
426        Self { rules: Vec::new() }
427    }
428
429    /// Create a router from a list of rules
430    pub fn from_rules(rules: &[QueryRouterRule]) -> Result<Self, String> {
431        let compiled: Result<Vec<_>, _> = rules.iter().map(CompiledRouterRule::new).collect();
432        Ok(Self { rules: compiled? })
433    }
434
435    /// Add a rule to the router
436    pub fn add_rule(&mut self, rule: &QueryRouterRule) -> Result<(), String> {
437        self.rules.push(CompiledRouterRule::new(rule)?);
438        Ok(())
439    }
440
441    /// Check if router has any rules
442    pub fn is_empty(&self) -> bool {
443        self.rules.is_empty()
444    }
445
446    /// Get the number of rules
447    pub fn len(&self) -> usize {
448        self.rules.len()
449    }
450
451    /// Try to route a query, returning the first matching rule's result
452    pub fn route(&self, query: &str) -> Option<RoutedQuery> {
453        for rule in &self.rules {
454            if let Some(routed) = rule.try_match(query) {
455                return Some(routed);
456            }
457        }
458        None
459    }
460
461    /// Try to route a query, returning all matching rules' results
462    pub fn route_all(&self, query: &str) -> Vec<RoutedQuery> {
463        self.rules
464            .iter()
465            .filter_map(|rule| rule.try_match(query))
466            .collect()
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`QueryRouterRule`] from string slices.
    fn rule(
        pattern: &str,
        substitution: &str,
        target_field: &str,
        mode: RoutingMode,
    ) -> QueryRouterRule {
        QueryRouterRule {
            pattern: pattern.to_string(),
            substitution: substitution.to_string(),
            target_field: target_field.to_string(),
            mode,
        }
    }

    #[test]
    fn test_doi_routing() {
        let doi = rule(
            r"(10\.\d{4,}/[^\s]+)",
            "doi://{1}",
            "uri",
            RoutingMode::Exclusive,
        );
        let compiled = CompiledRouterRule::new(&doi).unwrap();

        // A DOI is rewritten and routed exclusively to `uri`.
        let routed = compiled.try_match("10.1234/abc.123").unwrap();
        assert_eq!(routed.query, "doi://10.1234/abc.123");
        assert_eq!(routed.target_field, "uri");
        assert_eq!(routed.mode, RoutingMode::Exclusive);

        // Text without a DOI is not routed.
        assert!(compiled.try_match("hello world").is_none());
    }

    #[test]
    fn test_full_match_substitution() {
        let issue = rule(r"^#(\d+)$", "{1}", "issue_number", RoutingMode::Exclusive);
        let compiled = CompiledRouterRule::new(&issue).unwrap();

        let routed = compiled.try_match("#42").unwrap();
        assert_eq!(routed.query, "42");
        assert_eq!(routed.target_field, "issue_number");
    }

    #[test]
    fn test_multiple_capture_groups() {
        let pair = rule(
            r"(\w+):(\w+)",
            "field={1} value={2}",
            "custom",
            RoutingMode::Additional,
        );
        let compiled = CompiledRouterRule::new(&pair).unwrap();

        let routed = compiled.try_match("author:smith").unwrap();
        assert_eq!(routed.query, "field=author value=smith");
        assert_eq!(routed.mode, RoutingMode::Additional);
    }

    #[test]
    fn test_router_with_multiple_rules() {
        let router = QueryFieldRouter::from_rules(&[
            rule(
                r"^doi:(10\.\d{4,}/[^\s]+)$",
                "doi://{1}",
                "uri",
                RoutingMode::Exclusive,
            ),
            rule(r"^pmid:(\d+)$", "pubmed://{1}", "uri", RoutingMode::Exclusive),
        ])
        .unwrap();

        // Each prefix selects its own rule.
        let doi = router.route("doi:10.1234/test").unwrap();
        assert_eq!(doi.query, "doi://10.1234/test");

        let pmid = router.route("pmid:12345678").unwrap();
        assert_eq!(pmid.query, "pubmed://12345678");

        // Neither rule applies.
        assert!(router.route("random query").is_none());
    }

    #[test]
    fn test_invalid_regex() {
        let broken = rule(r"[invalid", "{0}", "test", RoutingMode::Exclusive);
        assert!(CompiledRouterRule::new(&broken).is_err());
    }

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Additional);
    }
}