Skip to main content

hermes_core/dsl/
query_field_router.rs

//! Query Field Router - Routes queries to specific fields based on regex patterns
//!
//! This module provides functionality to detect if a query matches a regex pattern,
//! extract capture groups, substitute them into a template, and route the result
//! to a specific field instead of (or in addition to) default fields.
//!
//! # Template Language
//!
//! The substitution template supports:
//! - Simple group references: `{0}`, `{1}`, `{2}` etc.
//! - Expression syntax with functions: `{g(1).replace('-', '').lower()}`
//!
//! ## Available Functions
//!
//! - `g(n)` - Get capture group n (0 = entire match, 1+ = capture groups)
//! - `.replace(from, to)` - Replace all occurrences of `from` with `to`
//!   (with a single argument, occurrences of `from` are removed)
//! - `.lower()` (alias `.lowercase()`) - Convert to lowercase
//! - `.upper()` (alias `.uppercase()`) - Convert to uppercase
//! - `.trim()` - Remove leading/trailing whitespace
//! - `.trim_start()` (alias `.ltrim()`) - Remove leading whitespace
//! - `.trim_end()` (alias `.rtrim()`) - Remove trailing whitespace
//!
//! Unknown method names leave the value unchanged.
//!
//! # Example
//!
//! ```text
//! # In SDL:
//! index documents {
//!     field title: text [indexed, stored]
//!     field uri: text [indexed, stored]
//!
//!     # Route DOI queries to uri field exclusively
//!     query_router {
//!         pattern: r"10\.\d{4,}/[^\s]+"
//!         substitution: "doi://{0}"
//!         target_field: uri
//!         mode: exclusive
//!     }
//!
//!     # Route ISBN with hyphen removal
//!     query_router {
//!         pattern: r"^isbn:([\d\-]+)$"
//!         substitution: "isbn://{g(1).replace('-', '')}"
//!         target_field: uri
//!         mode: exclusive
//!     }
//! }
//! ```
46
47use regex::Regex;
48use serde::{Deserialize, Serialize};
49
50/// Routing mode for matched queries
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RoutingMode {
53    /// Query only the target field (replace default fields)
54    #[serde(rename = "exclusive")]
55    Exclusive,
56    /// Query both target field and default fields
57    #[serde(rename = "additional")]
58    #[default]
59    Additional,
60}
61
62/// A single query routing rule
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct QueryRouterRule {
65    /// Regex pattern to match against the query
66    pub pattern: String,
67    /// Substitution template using {0}, {1}, etc. for capture groups
68    /// {0} is the entire match, {1} is first capture group, etc.
69    pub substitution: String,
70    /// Target field name to route the substituted query to
71    pub target_field: String,
72    /// Whether this is exclusive (replaces default) or additional
73    #[serde(default)]
74    pub mode: RoutingMode,
75}
76
77/// Result of applying a routing rule
78#[derive(Debug, Clone)]
79pub struct RoutedQuery {
80    /// The transformed query string
81    pub query: String,
82    /// Target field name
83    pub target_field: String,
84    /// Routing mode
85    pub mode: RoutingMode,
86}
87
88/// Template expression evaluator
89///
90/// Evaluates expressions like `{g(1).replace('-', '').lower()}`
91mod template {
92    use regex::Captures;
93
94    /// Evaluate a substitution template with the given regex captures
95    pub fn evaluate(template: &str, captures: &Captures) -> String {
96        let mut result = String::new();
97        let mut chars = template.chars().peekable();
98
99        while let Some(c) = chars.next() {
100            if c == '{' {
101                // Parse expression until closing brace
102                let mut expr = String::new();
103                let mut brace_depth = 1;
104
105                for c in chars.by_ref() {
106                    if c == '{' {
107                        brace_depth += 1;
108                        expr.push(c);
109                    } else if c == '}' {
110                        brace_depth -= 1;
111                        if brace_depth == 0 {
112                            break;
113                        }
114                        expr.push(c);
115                    } else {
116                        expr.push(c);
117                    }
118                }
119
120                // Evaluate the expression
121                let value = evaluate_expr(&expr, captures);
122                result.push_str(&value);
123            } else {
124                result.push(c);
125            }
126        }
127
128        result
129    }
130
131    /// Evaluate a single expression (content inside {})
132    fn evaluate_expr(expr: &str, captures: &Captures) -> String {
133        let expr = expr.trim();
134
135        // Check for simple numeric reference like "0", "1", "2"
136        if let Ok(group_num) = expr.parse::<usize>() {
137            return captures
138                .get(group_num)
139                .map(|m| m.as_str().to_string())
140                .unwrap_or_default();
141        }
142
143        // Parse expression with function calls
144        parse_and_evaluate(expr, captures)
145    }
146
147    /// Parse and evaluate an expression like `g(1).replace('-', '').lower()`
148    fn parse_and_evaluate(expr: &str, captures: &Captures) -> String {
149        let mut chars = expr.chars().peekable();
150        let mut value = String::new();
151
152        // Skip whitespace
153        while chars.peek() == Some(&' ') {
154            chars.next();
155        }
156
157        // Parse initial value (must start with g(n))
158        if expr.starts_with("g(") {
159            // Parse g(n)
160            chars.next(); // 'g'
161            chars.next(); // '('
162
163            let mut num_str = String::new();
164            while let Some(&c) = chars.peek() {
165                if c == ')' {
166                    chars.next();
167                    break;
168                }
169                num_str.push(c);
170                chars.next();
171            }
172
173            if let Ok(group_num) = num_str.trim().parse::<usize>() {
174                value = captures
175                    .get(group_num)
176                    .map(|m| m.as_str().to_string())
177                    .unwrap_or_default();
178            }
179        } else {
180            // Unknown expression start
181            return expr.to_string();
182        }
183
184        // Parse method chain
185        while chars.peek().is_some() {
186            // Skip whitespace
187            while chars.peek() == Some(&' ') {
188                chars.next();
189            }
190
191            // Expect '.'
192            if chars.peek() != Some(&'.') {
193                break;
194            }
195            chars.next(); // consume '.'
196
197            // Parse method name
198            let mut method_name = String::new();
199            while let Some(&c) = chars.peek() {
200                if c == '(' || c == ' ' {
201                    break;
202                }
203                method_name.push(c);
204                chars.next();
205            }
206
207            // Skip whitespace
208            while chars.peek() == Some(&' ') {
209                chars.next();
210            }
211
212            // Parse arguments if present
213            let args = if chars.peek() == Some(&'(') {
214                chars.next(); // consume '('
215                parse_args(&mut chars)
216            } else {
217                vec![]
218            };
219
220            // Apply method
221            value = apply_method(&value, &method_name, &args);
222        }
223
224        value
225    }
226
227    /// Parse function arguments from the char iterator
228    fn parse_args(chars: &mut std::iter::Peekable<std::str::Chars>) -> Vec<String> {
229        let mut args = Vec::new();
230        let mut current_arg = String::new();
231        let mut in_string = false;
232        let mut string_char = '"';
233
234        for c in chars.by_ref() {
235            if c == ')' && !in_string {
236                // End of arguments
237                let arg = current_arg.trim().to_string();
238                if !arg.is_empty() {
239                    args.push(parse_string_literal(&arg));
240                }
241                break;
242            } else if (c == '"' || c == '\'') && !in_string {
243                in_string = true;
244                string_char = c;
245                current_arg.push(c);
246            } else if c == string_char && in_string {
247                in_string = false;
248                current_arg.push(c);
249            } else if c == ',' && !in_string {
250                let arg = current_arg.trim().to_string();
251                if !arg.is_empty() {
252                    args.push(parse_string_literal(&arg));
253                }
254                current_arg.clear();
255            } else {
256                current_arg.push(c);
257            }
258        }
259
260        args
261    }
262
263    /// Parse a string literal, removing quotes
264    fn parse_string_literal(s: &str) -> String {
265        let s = s.trim();
266        if (s.starts_with('"') && s.ends_with('"')) || (s.starts_with('\'') && s.ends_with('\'')) {
267            s[1..s.len() - 1].to_string()
268        } else {
269            s.to_string()
270        }
271    }
272
273    /// Apply a method to a value
274    fn apply_method(value: &str, method: &str, args: &[String]) -> String {
275        match method {
276            "replace" => {
277                if args.len() >= 2 {
278                    value.replace(&args[0], &args[1])
279                } else if args.len() == 1 {
280                    value.replace(&args[0], "")
281                } else {
282                    value.to_string()
283                }
284            }
285            "lower" | "lowercase" => value.to_lowercase(),
286            "upper" | "uppercase" => value.to_uppercase(),
287            "trim" => value.trim().to_string(),
288            "trim_start" | "ltrim" => value.trim_start().to_string(),
289            "trim_end" | "rtrim" => value.trim_end().to_string(),
290            _ => value.to_string(),
291        }
292    }
293
294    #[cfg(test)]
295    mod tests {
296        use super::*;
297        use regex::Regex;
298
299        fn make_captures<'a>(pattern: &str, text: &'a str) -> Option<Captures<'a>> {
300            Regex::new(pattern).ok()?.captures(text)
301        }
302
303        #[test]
304        fn test_simple_substitution() {
305            let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
306            assert_eq!(evaluate("value: {1}", &caps), "value: 123");
307        }
308
309        #[test]
310        fn test_g_function() {
311            let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
312            assert_eq!(evaluate("{g(1)}", &caps), "123");
313            assert_eq!(evaluate("{g(0)}", &caps), "123");
314        }
315
316        #[test]
317        fn test_replace_function() {
318            let caps = make_captures(r"([\d\-]+)", "isbn:978-3-16-148410-0").unwrap();
319            assert_eq!(evaluate("{g(1).replace('-', '')}", &caps), "9783161484100");
320        }
321
322        #[test]
323        fn test_lower_function() {
324            let caps = make_captures(r"(\w+)", "HELLO").unwrap();
325            assert_eq!(evaluate("{g(1).lower()}", &caps), "hello");
326        }
327
328        #[test]
329        fn test_upper_function() {
330            let caps = make_captures(r"(\w+)", "hello").unwrap();
331            assert_eq!(evaluate("{g(1).upper()}", &caps), "HELLO");
332        }
333
334        #[test]
335        fn test_trim_function() {
336            let caps = make_captures(r"(.+)", "  hello  ").unwrap();
337            assert_eq!(evaluate("{g(1).trim()}", &caps), "hello");
338        }
339
340        #[test]
341        fn test_chained_functions() {
342            let caps = make_captures(r"([\d\-]+)", "978-3-16").unwrap();
343            assert_eq!(evaluate("{g(1).replace('-', '').lower()}", &caps), "978316");
344        }
345
346        #[test]
347        fn test_mixed_template() {
348            let caps = make_captures(r"isbn:([\d\-]+)", "isbn:978-3-16").unwrap();
349            assert_eq!(
350                evaluate("isbn://{g(1).replace('-', '')}", &caps),
351                "isbn://978316"
352            );
353        }
354
355        #[test]
356        fn test_multiple_expressions() {
357            let caps = make_captures(r"(\w+):(\w+)", "key:VALUE").unwrap();
358            assert_eq!(
359                evaluate("{g(1).upper()}={g(2).lower()}", &caps),
360                "KEY=value"
361            );
362        }
363    }
364}
365
366/// Compiled query router rule with pre-compiled regex
367#[derive(Debug, Clone)]
368pub struct CompiledRouterRule {
369    regex: Regex,
370    substitution: String,
371    target_field: String,
372    mode: RoutingMode,
373}
374
375impl CompiledRouterRule {
376    /// Create a new compiled router rule
377    pub fn new(rule: &QueryRouterRule) -> Result<Self, String> {
378        let regex = Regex::new(&rule.pattern)
379            .map_err(|e| format!("Invalid regex pattern '{}': {}", rule.pattern, e))?;
380
381        Ok(Self {
382            regex,
383            substitution: rule.substitution.clone(),
384            target_field: rule.target_field.clone(),
385            mode: rule.mode,
386        })
387    }
388
389    /// Try to match and transform a query
390    pub fn try_match(&self, query: &str) -> Option<RoutedQuery> {
391        let captures = self.regex.captures(query)?;
392
393        // Use the template evaluator for substitution
394        let result = template::evaluate(&self.substitution, &captures);
395
396        Some(RoutedQuery {
397            query: result,
398            target_field: self.target_field.clone(),
399            mode: self.mode,
400        })
401    }
402
403    /// Get the target field name
404    pub fn target_field(&self) -> &str {
405        &self.target_field
406    }
407
408    /// Get the routing mode
409    pub fn mode(&self) -> RoutingMode {
410        self.mode
411    }
412}
413
414/// Query field router that holds multiple routing rules
415#[derive(Debug, Clone, Default)]
416pub struct QueryFieldRouter {
417    rules: Vec<CompiledRouterRule>,
418}
419
420impl QueryFieldRouter {
421    /// Create a new empty router
422    pub fn new() -> Self {
423        Self { rules: Vec::new() }
424    }
425
426    /// Create a router from a list of rules
427    pub fn from_rules(rules: &[QueryRouterRule]) -> Result<Self, String> {
428        let compiled: Result<Vec<_>, _> = rules.iter().map(CompiledRouterRule::new).collect();
429        Ok(Self { rules: compiled? })
430    }
431
432    /// Add a rule to the router
433    pub fn add_rule(&mut self, rule: &QueryRouterRule) -> Result<(), String> {
434        self.rules.push(CompiledRouterRule::new(rule)?);
435        Ok(())
436    }
437
438    /// Check if router has any rules
439    pub fn is_empty(&self) -> bool {
440        self.rules.is_empty()
441    }
442
443    /// Get the number of rules
444    pub fn len(&self) -> usize {
445        self.rules.len()
446    }
447
448    /// Try to route a query, returning the first matching rule's result
449    pub fn route(&self, query: &str) -> Option<RoutedQuery> {
450        for rule in &self.rules {
451            if let Some(routed) = rule.try_match(query) {
452                return Some(routed);
453            }
454        }
455        None
456    }
457
458    /// Try to route a query, returning all matching rules' results
459    pub fn route_all(&self, query: &str) -> Vec<RoutedQuery> {
460        self.rules
461            .iter()
462            .filter_map(|rule| rule.try_match(query))
463            .collect()
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a rule from string slices so each test spells out only its inputs.
    fn rule(
        pattern: &str,
        substitution: &str,
        target_field: &str,
        mode: RoutingMode,
    ) -> QueryRouterRule {
        QueryRouterRule {
            pattern: pattern.to_string(),
            substitution: substitution.to_string(),
            target_field: target_field.to_string(),
            mode,
        }
    }

    #[test]
    fn test_doi_routing() {
        let doi = rule(
            r"(10\.\d{4,}/[^\s]+)",
            "doi://{1}",
            "uri",
            RoutingMode::Exclusive,
        );
        let compiled = CompiledRouterRule::new(&doi).unwrap();

        // A DOI is rewritten and routed.
        let routed = compiled.try_match("10.1234/abc.123").unwrap();
        assert_eq!(routed.query, "doi://10.1234/abc.123");
        assert_eq!(routed.target_field, "uri");
        assert_eq!(routed.mode, RoutingMode::Exclusive);

        // Ordinary text is left alone.
        assert!(compiled.try_match("hello world").is_none());
    }

    #[test]
    fn test_full_match_substitution() {
        let issue = rule(r"^#(\d+)$", "{1}", "issue_number", RoutingMode::Exclusive);
        let compiled = CompiledRouterRule::new(&issue).unwrap();

        let routed = compiled.try_match("#42").unwrap();
        assert_eq!(routed.query, "42");
        assert_eq!(routed.target_field, "issue_number");
    }

    #[test]
    fn test_multiple_capture_groups() {
        let pair = rule(
            r"(\w+):(\w+)",
            "field={1} value={2}",
            "custom",
            RoutingMode::Additional,
        );
        let compiled = CompiledRouterRule::new(&pair).unwrap();

        let routed = compiled.try_match("author:smith").unwrap();
        assert_eq!(routed.query, "field=author value=smith");
        assert_eq!(routed.mode, RoutingMode::Additional);
    }

    #[test]
    fn test_router_with_multiple_rules() {
        let rules = vec![
            rule(
                r"^doi:(10\.\d{4,}/[^\s]+)$",
                "doi://{1}",
                "uri",
                RoutingMode::Exclusive,
            ),
            rule(r"^pmid:(\d+)$", "pubmed://{1}", "uri", RoutingMode::Exclusive),
        ];
        let router = QueryFieldRouter::from_rules(&rules).unwrap();

        // Handled by the first rule.
        let routed = router.route("doi:10.1234/test").unwrap();
        assert_eq!(routed.query, "doi://10.1234/test");

        // Handled by the second rule.
        let routed = router.route("pmid:12345678").unwrap();
        assert_eq!(routed.query, "pubmed://12345678");

        // Handled by neither.
        assert!(router.route("random query").is_none());
    }

    #[test]
    fn test_invalid_regex() {
        let broken = rule(r"[invalid", "{0}", "test", RoutingMode::Exclusive);
        assert!(CompiledRouterRule::new(&broken).is_err());
    }

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Additional);
    }
}