1use regex::Regex;
48use serde::{Deserialize, Serialize};
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RoutingMode {
53 #[serde(rename = "exclusive")]
55 Exclusive,
56 #[serde(rename = "additional")]
58 #[default]
59 Additional,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct QueryRouterRule {
65 pub pattern: String,
67 pub substitution: String,
70 pub target_field: String,
72 #[serde(default)]
74 pub mode: RoutingMode,
75}
76
77#[derive(Debug, Clone)]
79pub struct RoutedQuery {
80 pub query: String,
82 pub target_field: String,
84 pub mode: RoutingMode,
86}
87
88mod template {
92 use regex::Captures;
93
94 pub fn evaluate(template: &str, captures: &Captures) -> String {
96 let mut result = String::new();
97 let mut chars = template.chars().peekable();
98
99 while let Some(c) = chars.next() {
100 if c == '{' {
101 let mut expr = String::new();
103 let mut brace_depth = 1;
104
105 while let Some(c) = chars.next() {
106 if c == '{' {
107 brace_depth += 1;
108 expr.push(c);
109 } else if c == '}' {
110 brace_depth -= 1;
111 if brace_depth == 0 {
112 break;
113 }
114 expr.push(c);
115 } else {
116 expr.push(c);
117 }
118 }
119
120 let value = evaluate_expr(&expr, captures);
122 result.push_str(&value);
123 } else {
124 result.push(c);
125 }
126 }
127
128 result
129 }
130
131 fn evaluate_expr(expr: &str, captures: &Captures) -> String {
133 let expr = expr.trim();
134
135 if let Ok(group_num) = expr.parse::<usize>() {
137 return captures
138 .get(group_num)
139 .map(|m| m.as_str().to_string())
140 .unwrap_or_default();
141 }
142
143 parse_and_evaluate(expr, captures)
145 }
146
147 fn parse_and_evaluate(expr: &str, captures: &Captures) -> String {
149 let mut chars = expr.chars().peekable();
150 let mut value = String::new();
151
152 while chars.peek() == Some(&' ') {
154 chars.next();
155 }
156
157 if expr.starts_with("g(") {
159 chars.next(); chars.next(); let mut num_str = String::new();
164 while let Some(&c) = chars.peek() {
165 if c == ')' {
166 chars.next();
167 break;
168 }
169 num_str.push(c);
170 chars.next();
171 }
172
173 if let Ok(group_num) = num_str.trim().parse::<usize>() {
174 value = captures
175 .get(group_num)
176 .map(|m| m.as_str().to_string())
177 .unwrap_or_default();
178 }
179 } else {
180 return expr.to_string();
182 }
183
184 while chars.peek().is_some() {
186 while chars.peek() == Some(&' ') {
188 chars.next();
189 }
190
191 if chars.peek() != Some(&'.') {
193 break;
194 }
195 chars.next(); let mut method_name = String::new();
199 while let Some(&c) = chars.peek() {
200 if c == '(' || c == ' ' {
201 break;
202 }
203 method_name.push(c);
204 chars.next();
205 }
206
207 while chars.peek() == Some(&' ') {
209 chars.next();
210 }
211
212 let args = if chars.peek() == Some(&'(') {
214 chars.next(); parse_args(&mut chars)
216 } else {
217 vec![]
218 };
219
220 value = apply_method(&value, &method_name, &args);
222 }
223
224 value
225 }
226
227 fn parse_args(chars: &mut std::iter::Peekable<std::str::Chars>) -> Vec<String> {
229 let mut args = Vec::new();
230 let mut current_arg = String::new();
231 let mut in_string = false;
232 let mut string_char = '"';
233
234 while let Some(c) = chars.next() {
235 if c == ')' && !in_string {
236 let arg = current_arg.trim().to_string();
238 if !arg.is_empty() {
239 args.push(parse_string_literal(&arg));
240 }
241 break;
242 } else if (c == '"' || c == '\'') && !in_string {
243 in_string = true;
244 string_char = c;
245 current_arg.push(c);
246 } else if c == string_char && in_string {
247 in_string = false;
248 current_arg.push(c);
249 } else if c == ',' && !in_string {
250 let arg = current_arg.trim().to_string();
251 if !arg.is_empty() {
252 args.push(parse_string_literal(&arg));
253 }
254 current_arg.clear();
255 } else {
256 current_arg.push(c);
257 }
258 }
259
260 args
261 }
262
/// Returns `s` with surrounding whitespace removed and, when the remainder is
/// wrapped in a matching pair of `"` or `'`, with that pair removed as well.
///
/// Input that is not wrapped in a matching pair is returned trimmed but
/// otherwise unchanged. This includes a lone quote character, which has no
/// partner to pair with.
fn parse_string_literal(s: &str) -> String {
    let s = s.trim();

    // `strip_prefix` followed by `strip_suffix` on the remainder needs two
    // distinct quote characters, so a one-character `'` or `"` falls through
    // instead of producing an inverted slice range.
    let unquoted = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .or_else(|| s.strip_prefix('\'').and_then(|rest| rest.strip_suffix('\'')));

    unquoted.unwrap_or(s).to_string()
}
272
/// Applies the template method named `method` to `value`.
///
/// * `replace` — with two or more arguments, replaces every occurrence of the
///   first with the second (further arguments are ignored); with one
///   argument, removes every occurrence of it; with none, returns `value`
///   unchanged.
/// * `lower` / `lowercase`, `upper` / `uppercase` — case conversion.
/// * `trim`, `trim_start` / `ltrim`, `trim_end` / `rtrim` — whitespace
///   trimming.
/// * Any other name returns `value` unchanged.
fn apply_method(value: &str, method: &str, args: &[String]) -> String {
    match method {
        "replace" => match args {
            [] => value.to_string(),
            [needle] => value.replace(needle.as_str(), ""),
            [needle, replacement, ..] => value.replace(needle.as_str(), replacement),
        },
        "lower" | "lowercase" => value.to_lowercase(),
        "upper" | "uppercase" => value.to_uppercase(),
        "trim" => value.trim().to_string(),
        "trim_start" | "ltrim" => value.trim_start().to_string(),
        "trim_end" | "rtrim" => value.trim_end().to_string(),
        _ => value.to_string(),
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use regex::Regex;

    /// Compiles `pattern` and returns its first match in `text`; `None` when
    /// the pattern is invalid or does not match.
    fn make_captures<'a>(pattern: &str, text: &'a str) -> Option<Captures<'a>> {
        let regex = Regex::new(pattern).ok()?;
        regex.captures(text)
    }

    /// Matches `pattern` against `text` and renders `template` with the
    /// resulting captures. Panics when there is no match.
    fn render(pattern: &str, text: &str, template: &str) -> String {
        let captures = make_captures(pattern, text).unwrap();
        evaluate(template, &captures)
    }

    #[test]
    fn test_simple_substitution() {
        assert_eq!(
            render(r"(\d+)", "hello 123 world", "value: {1}"),
            "value: 123"
        );
    }

    #[test]
    fn test_g_function() {
        let captures = make_captures(r"(\d+)", "hello 123 world").unwrap();
        for template in ["{g(1)}", "{g(0)}"].iter() {
            assert_eq!(evaluate(template, &captures), "123");
        }
    }

    #[test]
    fn test_replace_function() {
        assert_eq!(
            render(
                r"([\d\-]+)",
                "isbn:978-3-16-148410-0",
                "{g(1).replace('-', '')}"
            ),
            "9783161484100"
        );
    }

    #[test]
    fn test_lower_function() {
        assert_eq!(render(r"(\w+)", "HELLO", "{g(1).lower()}"), "hello");
    }

    #[test]
    fn test_upper_function() {
        assert_eq!(render(r"(\w+)", "hello", "{g(1).upper()}"), "HELLO");
    }

    #[test]
    fn test_trim_function() {
        assert_eq!(render(r"(.+)", " hello ", "{g(1).trim()}"), "hello");
    }

    #[test]
    fn test_chained_functions() {
        assert_eq!(
            render(r"([\d\-]+)", "978-3-16", "{g(1).replace('-', '').lower()}"),
            "978316"
        );
    }

    #[test]
    fn test_mixed_template() {
        assert_eq!(
            render(
                r"isbn:([\d\-]+)",
                "isbn:978-3-16",
                "isbn://{g(1).replace('-', '')}"
            ),
            "isbn://978316"
        );
    }

    #[test]
    fn test_multiple_expressions() {
        assert_eq!(
            render(r"(\w+):(\w+)", "key:VALUE", "{g(1).upper()}={g(2).lower()}"),
            "KEY=value"
        );
    }
}
367}
368
369#[derive(Debug, Clone)]
371pub struct CompiledRouterRule {
372 regex: Regex,
373 substitution: String,
374 target_field: String,
375 mode: RoutingMode,
376}
377
378impl CompiledRouterRule {
379 pub fn new(rule: &QueryRouterRule) -> Result<Self, String> {
381 let regex = Regex::new(&rule.pattern)
382 .map_err(|e| format!("Invalid regex pattern '{}': {}", rule.pattern, e))?;
383
384 Ok(Self {
385 regex,
386 substitution: rule.substitution.clone(),
387 target_field: rule.target_field.clone(),
388 mode: rule.mode,
389 })
390 }
391
392 pub fn try_match(&self, query: &str) -> Option<RoutedQuery> {
394 let captures = self.regex.captures(query)?;
395
396 let result = template::evaluate(&self.substitution, &captures);
398
399 Some(RoutedQuery {
400 query: result,
401 target_field: self.target_field.clone(),
402 mode: self.mode,
403 })
404 }
405
406 pub fn target_field(&self) -> &str {
408 &self.target_field
409 }
410
411 pub fn mode(&self) -> RoutingMode {
413 self.mode
414 }
415}
416
417#[derive(Debug, Clone, Default)]
419pub struct QueryFieldRouter {
420 rules: Vec<CompiledRouterRule>,
421}
422
423impl QueryFieldRouter {
424 pub fn new() -> Self {
426 Self { rules: Vec::new() }
427 }
428
429 pub fn from_rules(rules: &[QueryRouterRule]) -> Result<Self, String> {
431 let compiled: Result<Vec<_>, _> = rules.iter().map(CompiledRouterRule::new).collect();
432 Ok(Self { rules: compiled? })
433 }
434
435 pub fn add_rule(&mut self, rule: &QueryRouterRule) -> Result<(), String> {
437 self.rules.push(CompiledRouterRule::new(rule)?);
438 Ok(())
439 }
440
441 pub fn is_empty(&self) -> bool {
443 self.rules.is_empty()
444 }
445
446 pub fn len(&self) -> usize {
448 self.rules.len()
449 }
450
451 pub fn route(&self, query: &str) -> Option<RoutedQuery> {
453 for rule in &self.rules {
454 if let Some(routed) = rule.try_match(query) {
455 return Some(routed);
456 }
457 }
458 None
459 }
460
461 pub fn route_all(&self, query: &str) -> Vec<RoutedQuery> {
463 self.rules
464 .iter()
465 .filter_map(|rule| rule.try_match(query))
466 .collect()
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule from string slices so each test only spells out the
    /// values it cares about.
    fn rule(
        pattern: &str,
        substitution: &str,
        target_field: &str,
        mode: RoutingMode,
    ) -> QueryRouterRule {
        QueryRouterRule {
            pattern: pattern.to_string(),
            substitution: substitution.to_string(),
            target_field: target_field.to_string(),
            mode,
        }
    }

    #[test]
    fn test_doi_routing() {
        let doi_rule = rule(
            r"(10\.\d{4,}/[^\s]+)",
            "doi://{1}",
            "uri",
            RoutingMode::Exclusive,
        );
        let compiled = CompiledRouterRule::new(&doi_rule).unwrap();

        // A DOI-shaped query is rewritten and keeps the rule's metadata.
        let routed = compiled.try_match("10.1234/abc.123").unwrap();
        assert_eq!(routed.query, "doi://10.1234/abc.123");
        assert_eq!(routed.target_field, "uri");
        assert_eq!(routed.mode, RoutingMode::Exclusive);

        // Anything else is left unrouted.
        assert!(compiled.try_match("hello world").is_none());
    }

    #[test]
    fn test_full_match_substitution() {
        let issue_rule = rule(r"^#(\d+)$", "{1}", "issue_number", RoutingMode::Exclusive);
        let compiled = CompiledRouterRule::new(&issue_rule).unwrap();

        let routed = compiled.try_match("#42").unwrap();
        assert_eq!(routed.query, "42");
        assert_eq!(routed.target_field, "issue_number");
    }

    #[test]
    fn test_multiple_capture_groups() {
        let pair_rule = rule(
            r"(\w+):(\w+)",
            "field={1} value={2}",
            "custom",
            RoutingMode::Additional,
        );
        let compiled = CompiledRouterRule::new(&pair_rule).unwrap();

        let routed = compiled.try_match("author:smith").unwrap();
        assert_eq!(routed.query, "field=author value=smith");
        assert_eq!(routed.mode, RoutingMode::Additional);
    }

    #[test]
    fn test_router_with_multiple_rules() {
        let rules = vec![
            rule(
                r"^doi:(10\.\d{4,}/[^\s]+)$",
                "doi://{1}",
                "uri",
                RoutingMode::Exclusive,
            ),
            rule(
                r"^pmid:(\d+)$",
                "pubmed://{1}",
                "uri",
                RoutingMode::Exclusive,
            ),
        ];
        let router = QueryFieldRouter::from_rules(&rules).unwrap();

        let doi = router.route("doi:10.1234/test").unwrap();
        assert_eq!(doi.query, "doi://10.1234/test");

        let pmid = router.route("pmid:12345678").unwrap();
        assert_eq!(pmid.query, "pubmed://12345678");

        assert!(router.route("random query").is_none());
    }

    #[test]
    fn test_invalid_regex() {
        let broken = rule(r"[invalid", "{0}", "test", RoutingMode::Exclusive);
        assert!(CompiledRouterRule::new(&broken).is_err());
    }

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Additional);
    }
}
575}