1use regex::Regex;
48use serde::{Deserialize, Serialize};
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RoutingMode {
53 #[serde(rename = "exclusive")]
55 Exclusive,
56 #[serde(rename = "additional")]
58 #[default]
59 Additional,
60}
61
/// Serializable description of a single routing rule, before its pattern is compiled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRouterRule {
    /// Regular-expression source; compiled by `CompiledRouterRule::new`.
    pub pattern: String,
    /// Template expanded against the regex captures to build the routed query,
    /// e.g. `"doi://{1}"`.
    pub substitution: String,
    /// Name of the field the routed query targets; copied verbatim into `RoutedQuery`.
    pub target_field: String,
    /// Routing mode; falls back to `RoutingMode::default()` when absent from the
    /// deserialized input.
    #[serde(default)]
    pub mode: RoutingMode,
}
76
/// Result of a rule matching a query: the rewritten query plus the rule's routing metadata.
#[derive(Debug, Clone)]
pub struct RoutedQuery {
    /// The rule's substitution template expanded with the match's capture groups.
    pub query: String,
    /// Target field name copied from the rule that matched.
    pub target_field: String,
    /// Routing mode copied from the rule that matched.
    pub mode: RoutingMode,
}
87
88mod template {
92 use regex::Captures;
93
94 pub fn evaluate(template: &str, captures: &Captures) -> String {
96 let mut result = String::new();
97 let mut chars = template.chars().peekable();
98
99 while let Some(c) = chars.next() {
100 if c == '{' {
101 let mut expr = String::new();
103 let mut brace_depth = 1;
104
105 for c in chars.by_ref() {
106 if c == '{' {
107 brace_depth += 1;
108 expr.push(c);
109 } else if c == '}' {
110 brace_depth -= 1;
111 if brace_depth == 0 {
112 break;
113 }
114 expr.push(c);
115 } else {
116 expr.push(c);
117 }
118 }
119
120 let value = evaluate_expr(&expr, captures);
122 result.push_str(&value);
123 } else {
124 result.push(c);
125 }
126 }
127
128 result
129 }
130
131 fn evaluate_expr(expr: &str, captures: &Captures) -> String {
133 let expr = expr.trim();
134
135 if let Ok(group_num) = expr.parse::<usize>() {
137 return captures
138 .get(group_num)
139 .map(|m| m.as_str().to_string())
140 .unwrap_or_default();
141 }
142
143 parse_and_evaluate(expr, captures)
145 }
146
147 fn parse_and_evaluate(expr: &str, captures: &Captures) -> String {
149 let mut chars = expr.chars().peekable();
150 let mut value = String::new();
151
152 while chars.peek() == Some(&' ') {
154 chars.next();
155 }
156
157 if expr.starts_with("g(") {
159 chars.next(); chars.next(); let mut num_str = String::new();
164 while let Some(&c) = chars.peek() {
165 if c == ')' {
166 chars.next();
167 break;
168 }
169 num_str.push(c);
170 chars.next();
171 }
172
173 if let Ok(group_num) = num_str.trim().parse::<usize>() {
174 value = captures
175 .get(group_num)
176 .map(|m| m.as_str().to_string())
177 .unwrap_or_default();
178 }
179 } else {
180 return expr.to_string();
182 }
183
184 while chars.peek().is_some() {
186 while chars.peek() == Some(&' ') {
188 chars.next();
189 }
190
191 if chars.peek() != Some(&'.') {
193 break;
194 }
195 chars.next(); let mut method_name = String::new();
199 while let Some(&c) = chars.peek() {
200 if c == '(' || c == ' ' {
201 break;
202 }
203 method_name.push(c);
204 chars.next();
205 }
206
207 while chars.peek() == Some(&' ') {
209 chars.next();
210 }
211
212 let args = if chars.peek() == Some(&'(') {
214 chars.next(); parse_args(&mut chars)
216 } else {
217 vec![]
218 };
219
220 value = apply_method(&value, &method_name, &args);
222 }
223
224 value
225 }
226
227 fn parse_args(chars: &mut std::iter::Peekable<std::str::Chars>) -> Vec<String> {
229 let mut args = Vec::new();
230 let mut current_arg = String::new();
231 let mut in_string = false;
232 let mut string_char = '"';
233
234 for c in chars.by_ref() {
235 if c == ')' && !in_string {
236 let arg = current_arg.trim().to_string();
238 if !arg.is_empty() {
239 args.push(parse_string_literal(&arg));
240 }
241 break;
242 } else if (c == '"' || c == '\'') && !in_string {
243 in_string = true;
244 string_char = c;
245 current_arg.push(c);
246 } else if c == string_char && in_string {
247 in_string = false;
248 current_arg.push(c);
249 } else if c == ',' && !in_string {
250 let arg = current_arg.trim().to_string();
251 if !arg.is_empty() {
252 args.push(parse_string_literal(&arg));
253 }
254 current_arg.clear();
255 } else {
256 current_arg.push(c);
257 }
258 }
259
260 args
261 }
262
263 fn parse_string_literal(s: &str) -> String {
265 let s = s.trim();
266 if (s.starts_with('"') && s.ends_with('"')) || (s.starts_with('\'') && s.ends_with('\'')) {
267 s[1..s.len() - 1].to_string()
268 } else {
269 s.to_string()
270 }
271 }
272
273 fn apply_method(value: &str, method: &str, args: &[String]) -> String {
275 match method {
276 "replace" => {
277 if args.len() >= 2 {
278 value.replace(&args[0], &args[1])
279 } else if args.len() == 1 {
280 value.replace(&args[0], "")
281 } else {
282 value.to_string()
283 }
284 }
285 "lower" | "lowercase" => value.to_lowercase(),
286 "upper" | "uppercase" => value.to_uppercase(),
287 "trim" => value.trim().to_string(),
288 "trim_start" | "ltrim" => value.trim_start().to_string(),
289 "trim_end" | "rtrim" => value.trim_end().to_string(),
290 _ => value.to_string(),
291 }
292 }
293
294 #[cfg(test)]
295 mod tests {
296 use super::*;
297 use regex::Regex;
298
299 fn make_captures<'a>(pattern: &str, text: &'a str) -> Option<Captures<'a>> {
300 Regex::new(pattern).ok()?.captures(text)
301 }
302
303 #[test]
304 fn test_simple_substitution() {
305 let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
306 assert_eq!(evaluate("value: {1}", &caps), "value: 123");
307 }
308
309 #[test]
310 fn test_g_function() {
311 let caps = make_captures(r"(\d+)", "hello 123 world").unwrap();
312 assert_eq!(evaluate("{g(1)}", &caps), "123");
313 assert_eq!(evaluate("{g(0)}", &caps), "123");
314 }
315
316 #[test]
317 fn test_replace_function() {
318 let caps = make_captures(r"([\d\-]+)", "isbn:978-3-16-148410-0").unwrap();
319 assert_eq!(evaluate("{g(1).replace('-', '')}", &caps), "9783161484100");
320 }
321
322 #[test]
323 fn test_lower_function() {
324 let caps = make_captures(r"(\w+)", "HELLO").unwrap();
325 assert_eq!(evaluate("{g(1).lower()}", &caps), "hello");
326 }
327
328 #[test]
329 fn test_upper_function() {
330 let caps = make_captures(r"(\w+)", "hello").unwrap();
331 assert_eq!(evaluate("{g(1).upper()}", &caps), "HELLO");
332 }
333
334 #[test]
335 fn test_trim_function() {
336 let caps = make_captures(r"(.+)", " hello ").unwrap();
337 assert_eq!(evaluate("{g(1).trim()}", &caps), "hello");
338 }
339
340 #[test]
341 fn test_chained_functions() {
342 let caps = make_captures(r"([\d\-]+)", "978-3-16").unwrap();
343 assert_eq!(evaluate("{g(1).replace('-', '').lower()}", &caps), "978316");
344 }
345
346 #[test]
347 fn test_mixed_template() {
348 let caps = make_captures(r"isbn:([\d\-]+)", "isbn:978-3-16").unwrap();
349 assert_eq!(
350 evaluate("isbn://{g(1).replace('-', '')}", &caps),
351 "isbn://978316"
352 );
353 }
354
355 #[test]
356 fn test_multiple_expressions() {
357 let caps = make_captures(r"(\w+):(\w+)", "key:VALUE").unwrap();
358 assert_eq!(
359 evaluate("{g(1).upper()}={g(2).lower()}", &caps),
360 "KEY=value"
361 );
362 }
363 }
364}
365
/// A routing rule whose pattern has already been compiled into a [`Regex`].
#[derive(Debug, Clone)]
pub struct CompiledRouterRule {
    /// Compiled form of the source rule's `pattern`.
    regex: Regex,
    /// Template expanded against the captures of a successful match.
    substitution: String,
    /// Target field name copied into every `RoutedQuery` this rule produces.
    target_field: String,
    /// Routing mode copied into every `RoutedQuery` this rule produces.
    mode: RoutingMode,
}
374
375impl CompiledRouterRule {
376 pub fn new(rule: &QueryRouterRule) -> Result<Self, String> {
378 let regex = Regex::new(&rule.pattern)
379 .map_err(|e| format!("Invalid regex pattern '{}': {}", rule.pattern, e))?;
380
381 Ok(Self {
382 regex,
383 substitution: rule.substitution.clone(),
384 target_field: rule.target_field.clone(),
385 mode: rule.mode,
386 })
387 }
388
389 pub fn try_match(&self, query: &str) -> Option<RoutedQuery> {
391 let captures = self.regex.captures(query)?;
392
393 let result = template::evaluate(&self.substitution, &captures);
395
396 Some(RoutedQuery {
397 query: result,
398 target_field: self.target_field.clone(),
399 mode: self.mode,
400 })
401 }
402
403 pub fn target_field(&self) -> &str {
405 &self.target_field
406 }
407
408 pub fn mode(&self) -> RoutingMode {
410 self.mode
411 }
412}
413
/// Collection of compiled routing rules, kept in the order they were added.
#[derive(Debug, Clone, Default)]
pub struct QueryFieldRouter {
    /// Compiled rules; `route` and `route_all` try them in this order.
    rules: Vec<CompiledRouterRule>,
}
419
420impl QueryFieldRouter {
421 pub fn new() -> Self {
423 Self { rules: Vec::new() }
424 }
425
426 pub fn from_rules(rules: &[QueryRouterRule]) -> Result<Self, String> {
428 let compiled: Result<Vec<_>, _> = rules.iter().map(CompiledRouterRule::new).collect();
429 Ok(Self { rules: compiled? })
430 }
431
432 pub fn add_rule(&mut self, rule: &QueryRouterRule) -> Result<(), String> {
434 self.rules.push(CompiledRouterRule::new(rule)?);
435 Ok(())
436 }
437
438 pub fn is_empty(&self) -> bool {
440 self.rules.is_empty()
441 }
442
443 pub fn len(&self) -> usize {
445 self.rules.len()
446 }
447
448 pub fn route(&self, query: &str) -> Option<RoutedQuery> {
450 for rule in &self.rules {
451 if let Some(routed) = rule.try_match(query) {
452 return Some(routed);
453 }
454 }
455 None
456 }
457
458 pub fn route_all(&self, query: &str) -> Vec<RoutedQuery> {
460 self.rules
461 .iter()
462 .filter_map(|rule| rule.try_match(query))
463 .collect()
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule from string slices so each test can state it on one line.
    fn rule(
        pattern: &str,
        substitution: &str,
        target_field: &str,
        mode: RoutingMode,
    ) -> QueryRouterRule {
        QueryRouterRule {
            pattern: pattern.to_string(),
            substitution: substitution.to_string(),
            target_field: target_field.to_string(),
            mode,
        }
    }

    #[test]
    fn test_doi_routing() {
        let doi_rule = rule(
            r"(10\.\d{4,}/[^\s]+)",
            "doi://{1}",
            "uri",
            RoutingMode::Exclusive,
        );
        let compiled = CompiledRouterRule::new(&doi_rule).unwrap();

        let routed = compiled.try_match("10.1234/abc.123").unwrap();
        assert_eq!(routed.query, "doi://10.1234/abc.123");
        assert_eq!(routed.target_field, "uri");
        assert_eq!(routed.mode, RoutingMode::Exclusive);

        // Input without a DOI must not be routed.
        assert!(compiled.try_match("hello world").is_none());
    }

    #[test]
    fn test_full_match_substitution() {
        let issue_rule = rule(r"^#(\d+)$", "{1}", "issue_number", RoutingMode::Exclusive);
        let compiled = CompiledRouterRule::new(&issue_rule).unwrap();

        let routed = compiled.try_match("#42").unwrap();
        assert_eq!(routed.query, "42");
        assert_eq!(routed.target_field, "issue_number");
    }

    #[test]
    fn test_multiple_capture_groups() {
        let pair_rule = rule(
            r"(\w+):(\w+)",
            "field={1} value={2}",
            "custom",
            RoutingMode::Additional,
        );
        let compiled = CompiledRouterRule::new(&pair_rule).unwrap();

        let routed = compiled.try_match("author:smith").unwrap();
        assert_eq!(routed.query, "field=author value=smith");
        assert_eq!(routed.mode, RoutingMode::Additional);
    }

    #[test]
    fn test_router_with_multiple_rules() {
        let rules = [
            rule(
                r"^doi:(10\.\d{4,}/[^\s]+)$",
                "doi://{1}",
                "uri",
                RoutingMode::Exclusive,
            ),
            rule(r"^pmid:(\d+)$", "pubmed://{1}", "uri", RoutingMode::Exclusive),
        ];
        let router = QueryFieldRouter::from_rules(&rules).unwrap();

        let doi = router.route("doi:10.1234/test").unwrap();
        assert_eq!(doi.query, "doi://10.1234/test");

        let pmid = router.route("pmid:12345678").unwrap();
        assert_eq!(pmid.query, "pubmed://12345678");

        assert!(router.route("random query").is_none());
    }

    #[test]
    fn test_invalid_regex() {
        let broken = rule(r"[invalid", "{0}", "test", RoutingMode::Exclusive);
        assert!(CompiledRouterRule::new(&broken).is_err());
    }

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Additional);
    }
}