use super::types::{RoutingConfig, RoutingRule};
/// Maps task names to executor names using an ordered list of glob-style rules.
///
/// Task names and rule patterns are `::`-separated paths. Within a pattern:
/// - a whole segment `**` matches zero or more name segments,
/// - `*` inside a segment matches any (possibly empty) run of characters within
///   that single segment,
/// - any other text must match literally.
///
/// Rules are tried in the order they appear in the config; the first matching
/// rule wins, otherwise the config's default executor is used.
#[derive(Debug, Clone)]
pub struct Router {
    config: RoutingConfig,
}

impl Router {
    /// Creates a router backed by `config`.
    pub fn new(config: RoutingConfig) -> Self {
        Self { config }
    }

    /// Returns the executor for `task_name`.
    ///
    /// This is the executor of the first rule whose pattern matches, or the
    /// config's default executor when no rule matches.
    pub fn resolve(&self, task_name: &str) -> &str {
        for rule in &self.config.rules {
            if Self::matches_pattern(&rule.task_pattern, task_name) {
                return &rule.executor;
            }
        }
        &self.config.default_executor
    }

    /// Returns `true` if the `::`-separated `pattern` matches `task_name`.
    fn matches_pattern(pattern: &str, task_name: &str) -> bool {
        // Fast paths: identical strings, or the catch-all pattern.
        if pattern == task_name || pattern == "**" {
            return true;
        }
        let pattern_parts: Vec<&str> = pattern.split("::").collect();
        let name_parts: Vec<&str> = task_name.split("::").collect();
        Self::match_segments(&pattern_parts, &name_parts)
    }

    /// Matches pattern segments against name segments, backtracking over `**`.
    ///
    /// Each `**` segment may consume zero or more name segments. Because every
    /// `**` is tried both ways, patterns with many `**` segments can backtrack
    /// heavily.
    fn match_segments(pattern_parts: &[&str], name_parts: &[&str]) -> bool {
        match (pattern_parts.first(), name_parts.first()) {
            (None, None) => true,
            (None, Some(_)) => false,
            // Name exhausted: this only succeeds if every remaining pattern
            // segment is `**`, since each of them may match zero segments.
            (Some(_), None) => pattern_parts.iter().all(|part| *part == "**"),
            (Some(&"**"), Some(_)) => {
                // Let `**` match zero segments first, then try consuming one
                // more name segment while keeping `**` active.
                Self::match_segments(&pattern_parts[1..], name_parts)
                    || Self::match_segments(pattern_parts, &name_parts[1..])
            }
            (Some(pattern_seg), Some(name_seg)) => {
                Self::match_glob(pattern_seg, name_seg)
                    && Self::match_segments(&pattern_parts[1..], &name_parts[1..])
            }
        }
    }

    /// Matches a single segment `text` against a single-segment `pattern`
    /// that may contain `*` wildcards.
    fn match_glob(pattern: &str, text: &str) -> bool {
        if pattern == text || pattern == "*" {
            return true;
        }
        pattern.contains('*') && Self::match_wildcard(pattern, text)
    }

    /// Matches `text` against `pattern`, where each `*` stands for any
    /// (possibly empty) substring.
    ///
    /// Several literals must fit in `text` without overlapping one another:
    /// - the literal before the first `*` must be a prefix of `text`,
    /// - the literal after the last `*` must be a suffix of `text`,
    /// - the literals in between must appear in order.
    fn match_wildcard(pattern: &str, text: &str) -> bool {
        let parts: Vec<&str> = pattern.split('*').collect();
        if parts.len() == 1 {
            return pattern == text;
        }
        // `parts.len() >= 2` here, so prefix and suffix are distinct entries.
        let (prefix, suffix) = (parts[0], parts[parts.len() - 1]);
        let mut remaining = match text.strip_prefix(prefix) {
            Some(rest) => rest,
            None => return false,
        };
        for &part in &parts[1..parts.len() - 1] {
            if part.is_empty() {
                continue;
            }
            // Taking the leftmost occurrence leaves the most text for later
            // literals, so this greedy search never rejects a string that some
            // other placement would accept.
            match remaining.find(part) {
                Some(idx) => remaining = &remaining[idx + part.len()..],
                None => return false,
            }
        }
        // Check the suffix against the *unconsumed* text only, so it cannot
        // overlap the prefix or any middle literal (e.g. `a*a` must not match `a`).
        remaining.ends_with(suffix)
    }

    /// Returns the routing configuration this router was built from.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }

    /// Appends `rule` after all existing rules, so it has the lowest priority.
    pub fn add_rule(&mut self, rule: RoutingRule) {
        self.config.rules.push(rule);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router whose fallback executor is `"default"`, with the given
    /// `(pattern, executor)` rules added in slice order.
    fn router_with(rules: &[(&'static str, &'static str)]) -> Router {
        let config = rules.iter().fold(
            RoutingConfig::new("default"),
            |config, &(pattern, executor)| config.with_rule(RoutingRule::new(pattern, executor)),
        );
        Router::new(config)
    }

    /// Asserts that every `(task, expected_executor)` pair resolves as expected.
    fn assert_routes(router: &Router, cases: &[(&str, &str)]) {
        for &(task, expected) in cases {
            assert_eq!(router.resolve(task), expected, "task {:?}", task);
        }
    }

    #[test]
    fn test_exact_match() {
        let router = router_with(&[("ml::train", "gpu")]);
        assert_routes(&router, &[("ml::train", "gpu"), ("ml::predict", "default")]);
    }

    #[test]
    fn test_wildcard_match() {
        let router = router_with(&[("ml::*", "gpu")]);
        assert_routes(
            &router,
            &[
                ("ml::train", "gpu"),
                ("ml::predict", "gpu"),
                ("etl::extract", "default"),
            ],
        );
    }

    #[test]
    fn test_double_wildcard() {
        let router = router_with(&[("**", "catch_all")]);
        assert_routes(
            &router,
            &[("anything", "catch_all"), ("ml::deep::nested", "catch_all")],
        );
    }

    #[test]
    fn test_prefix_wildcard() {
        let router = router_with(&[("heavy_*", "high_memory")]);
        assert_routes(
            &router,
            &[("heavy_compute", "high_memory"), ("light_compute", "default")],
        );
    }

    #[test]
    fn test_suffix_wildcard() {
        let router = router_with(&[("*_gpu", "gpu_executor")]);
        assert_routes(
            &router,
            &[("train_gpu", "gpu_executor"), ("train_cpu", "default")],
        );
    }

    #[test]
    fn test_rule_order_priority() {
        // The specific rule is listed first, so it must shadow the general one.
        let router = router_with(&[("ml::train", "specific"), ("ml::*", "general")]);
        assert_routes(
            &router,
            &[("ml::train", "specific"), ("ml::predict", "general")],
        );
    }

    #[test]
    fn test_namespace_wildcard() {
        let router = router_with(&[("**::heavy_*", "high_compute")]);
        assert_routes(
            &router,
            &[
                ("ml::heavy_train", "high_compute"),
                ("etl::data::heavy_load", "high_compute"),
                ("ml::light_train", "default"),
            ],
        );
    }
}