use regex::Regex;
use std::collections::HashMap;
use crate::McpResult;
use turul_mcp_protocol::McpError;
/// A URI template with `{name}` placeholders, e.g. `file:///user/{user_id}.json`.
///
/// A template can be resolved into a concrete URI from a map of variable values, or matched
/// against a concrete URI to extract those values. Optional per-variable
/// [`VariableValidator`]s are applied in both directions.
#[derive(Debug, Clone)]
pub struct UriTemplate {
    /// The template text exactly as supplied, placeholders included.
    pattern: String,
    /// Anchored (`^...$`) matcher generated from `pattern`; each placeholder is a capture group.
    regex: Regex,
    /// Placeholder names in the order they appear in `pattern`.
    variables: Vec<String>,
    /// Optional validators keyed by variable name.
    validators: HashMap<String, VariableValidator>,
    /// MIME type detected from the pattern's extension or set explicitly; `None` if unknown.
    mime_type: Option<String>,
}
/// Constraint applied to a single template variable's value: a maximum length plus a regex.
#[derive(Debug, Clone)]
pub struct VariableValidator {
    /// Regex the value must match. The built-in constructors anchor it with `^...$`; a pattern
    /// given to `custom` is used as-is, so an unanchored pattern only needs to match a substring.
    pattern: Regex,
    /// Human-readable description of the expected format, used in error messages.
    description: String,
    /// Maximum accepted length, measured in bytes (`str::len`), not characters.
    max_length: usize,
}
impl VariableValidator {
    /// Shared constructor for the built-in validators, whose patterns are string literals
    /// that are known to compile.
    fn from_static(pattern: &str, description: &str, max_length: usize) -> Self {
        Self {
            pattern: Regex::new(pattern).unwrap(),
            description: description.to_string(),
            max_length,
        }
    }

    /// Accepts 1-128 ASCII letters, digits, `_` or `-`.
    pub fn user_id() -> Self {
        Self::from_static(
            r"^[A-Za-z0-9_-]{1,128}$",
            "alphanumeric characters, underscore, and hyphen (1-128 chars)",
            128,
        )
    }

    /// Accepts exactly one of the lowercase extensions `png`, `jpg`, `jpeg`, `webp`, `svg`.
    pub fn image_format() -> Self {
        Self::from_static(
            r"^(png|jpg|jpeg|webp|svg)$",
            "valid image format: png, jpg, jpeg, webp, svg",
            8,
        )
    }

    /// Accepts exactly one of the lowercase extensions `pdf`, `txt`, `md`, `json`, `xml`, `html`.
    pub fn document_format() -> Self {
        Self::from_static(
            r"^(pdf|txt|md|json|xml|html)$",
            "valid document format: pdf, txt, md, json, xml, html",
            8,
        )
    }

    /// Builds a validator from a caller-supplied regex.
    ///
    /// The pattern is used verbatim; add `^`/`$` anchors if the whole value must match.
    ///
    /// # Errors
    ///
    /// Returns a tool-execution error when `pattern` is not a valid regex.
    pub fn custom(pattern: &str, description: String, max_length: usize) -> McpResult<Self> {
        let compiled = Regex::new(pattern)
            .map_err(|err| McpError::tool_execution(&format!("Invalid regex pattern: {}", err)))?;
        Ok(Self {
            pattern: compiled,
            description,
            max_length,
        })
    }

    /// Checks `value` against the length limit first and then against the regex.
    ///
    /// The length is checked first so the regex never runs on oversized input.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first failed check.
    pub fn validate(&self, value: &str) -> Result<(), String> {
        let byte_len = value.len();
        if byte_len > self.max_length {
            Err(format!(
                "Value too long: {} characters (max {})",
                byte_len, self.max_length
            ))
        } else if self.pattern.is_match(value) {
            Ok(())
        } else {
            Err(format!("Invalid format. Expected: {}", self.description))
        }
    }
}
impl UriTemplate {
    /// Creates a template from `pattern`, compiling its `{name}` placeholders into a matcher.
    ///
    /// The MIME type is guessed from the pattern's trailing file extension; use
    /// [`UriTemplate::with_mime_type`] to set it explicitly.
    ///
    /// # Errors
    ///
    /// Returns an error if the generated matching regex cannot be compiled.
    pub fn new(pattern: &str) -> McpResult<Self> {
        let mut template = Self {
            pattern: pattern.to_string(),
            // Temporary value only; `compile` replaces it with the real matcher.
            regex: Regex::new("").unwrap(),
            variables: Vec::new(),
            validators: HashMap::new(),
            mime_type: Self::detect_mime_type(pattern),
        };
        template.compile()?;
        Ok(template)
    }

    /// Like [`UriTemplate::new`], but overrides the detected MIME type with `mime_type`.
    ///
    /// # Errors
    ///
    /// Same as [`UriTemplate::new`].
    pub fn with_mime_type(pattern: &str, mime_type: &str) -> McpResult<Self> {
        let mut template = Self::new(pattern)?;
        template.mime_type = Some(mime_type.to_string());
        Ok(template)
    }

    /// Attaches `validator` to `variable`, replacing any validator previously set for it.
    ///
    /// Validators are consulted by [`UriTemplate::resolve`] and [`UriTemplate::extract`] only
    /// for names that occur in the pattern.
    pub fn with_validator(mut self, variable: &str, validator: VariableValidator) -> Self {
        self.validators.insert(variable.to_string(), validator);
        self
    }

    /// Builds the anchored matching regex and records placeholder names in order.
    ///
    /// Literal text between placeholders is regex-escaped. Each `{name}` becomes a `([^/]+)`
    /// capture group, so a variable matches one or more characters other than `/`. The pattern
    /// is walked in a single pass so one placeholder's text can never be rewritten as part of
    /// another's.
    fn compile(&mut self) -> McpResult<()> {
        let placeholder = Regex::new(r"\{([^}]+)\}").unwrap();
        let mut regex_pattern = String::with_capacity(self.pattern.len() + 2);
        regex_pattern.push('^');
        let mut literal_start = 0;
        for captures in placeholder.captures_iter(&self.pattern) {
            let whole = captures.get(0).expect("group 0 is present for every match");
            let name = captures.get(1).expect("group 1 is mandatory in the placeholder regex");
            regex_pattern.push_str(&regex::escape(&self.pattern[literal_start..whole.start()]));
            regex_pattern.push_str("([^/]+)");
            self.variables.push(name.as_str().to_string());
            literal_start = whole.end();
        }
        regex_pattern.push_str(&regex::escape(&self.pattern[literal_start..]));
        regex_pattern.push('$');
        self.regex = Regex::new(&regex_pattern)
            .map_err(|e| McpError::tool_execution(&format!("Failed to compile template: {}", e)))?;
        Ok(())
    }

    /// Maps the pattern's trailing file extension to a MIME type.
    ///
    /// The extension is the text after the last `.`, cut at the first `}`. It is matched
    /// ASCII case-insensitively (so `.JSON` is recognized too). Returns `None` when there is no
    /// `.` or the extension is not recognized.
    fn detect_mime_type(pattern: &str) -> Option<String> {
        let (_, ext) = pattern.rsplit_once('.')?;
        let ext = ext.split('}').next().unwrap_or(ext).to_ascii_lowercase();
        let mime = match ext.as_str() {
            "json" => "application/json",
            "txt" => "text/plain",
            "md" => "text/markdown",
            "html" => "text/html",
            "xml" => "application/xml",
            "pdf" => "application/pdf",
            "png" => "image/png",
            "jpg" | "jpeg" => "image/jpeg",
            "webp" => "image/webp",
            "svg" => "image/svg+xml",
            _ => return None,
        };
        Some(mime.to_string())
    }

    /// Substitutes `variables` into the pattern, producing a concrete URI.
    ///
    /// Values are inserted verbatim. They are not percent-encoded, whereas
    /// [`UriTemplate::extract`] percent-decodes.
    ///
    /// # Errors
    ///
    /// Variables are checked in pattern order, and the first failure is returned:
    /// - a template variable is missing from `variables`;
    /// - a value fails its variable's validator.
    pub fn resolve(&self, variables: &HashMap<String, String>) -> McpResult<String> {
        // Look up and validate every value before building output, in placeholder order.
        let mut values = Vec::with_capacity(self.variables.len());
        for var_name in &self.variables {
            let value = variables
                .get(var_name)
                .ok_or_else(|| McpError::missing_param(var_name))?;
            if let Some(validator) = self.validators.get(var_name) {
                validator.validate(value).map_err(|e| {
                    McpError::invalid_param_type(var_name, &validator.description, &e)
                })?;
            }
            values.push(value.as_str());
        }

        // Single left-to-right pass over the pattern. Inserted values are never re-scanned, so
        // a value containing literal `{other}` text is emitted as-is instead of being expanded
        // by a later substitution. `self.variables` is in pattern order, so each placeholder
        // is the next occurrence after the previous one.
        let mut result = String::with_capacity(self.pattern.len());
        let mut rest = self.pattern.as_str();
        for (var_name, value) in self.variables.iter().zip(values) {
            let placeholder = format!("{{{}}}", var_name);
            if let Some(pos) = rest.find(&placeholder) {
                result.push_str(&rest[..pos]);
                result.push_str(value);
                rest = &rest[pos + placeholder.len()..];
            }
        }
        result.push_str(rest);
        Ok(result)
    }

    /// Matches `uri` against the template and returns each variable's percent-decoded value.
    ///
    /// If a variable name occurs more than once in the pattern, the last capture wins.
    ///
    /// # Errors
    ///
    /// - `uri` does not match the template;
    /// - a captured value is not valid UTF-8 after percent-decoding;
    /// - a decoded value fails its variable's validator.
    pub fn extract(&self, uri: &str) -> McpResult<HashMap<String, String>> {
        let captures = self
            .regex
            .captures(uri)
            .ok_or_else(|| McpError::invalid_param_type("uri", "URI matching template", uri))?;
        let mut variables = HashMap::new();
        // Capture group i + 1 corresponds to self.variables[i] (group 0 is the whole match).
        for (i, var_name) in self.variables.iter().enumerate() {
            if let Some(value) = captures.get(i + 1) {
                let value = urlencoding::decode(value.as_str())
                    .map_err(|e| {
                        McpError::invalid_param_type(
                            var_name,
                            "valid UTF-8 after percent-decoding",
                            &e.to_string(),
                        )
                    })?
                    .into_owned();
                // Validate the decoded value, i.e. what the caller actually receives.
                if let Some(validator) = self.validators.get(var_name) {
                    validator.validate(&value).map_err(|e| {
                        McpError::invalid_param_type(var_name, &validator.description, &e)
                    })?;
                }
                variables.insert(var_name.clone(), value);
            }
        }
        Ok(variables)
    }

    /// Returns `true` if `uri` has the template's shape. Validators are not consulted.
    pub fn matches(&self, uri: &str) -> bool {
        self.regex.is_match(uri)
    }

    /// The detected or explicitly set MIME type, if any.
    pub fn mime_type(&self) -> Option<&str> {
        self.mime_type.as_deref()
    }

    /// The template text as originally supplied.
    pub fn pattern(&self) -> &str {
        &self.pattern
    }

    /// Placeholder names in the order they appear in the pattern.
    pub fn variables(&self) -> &[String] {
        &self.variables
    }
}
/// An ordered collection of [`UriTemplate`]s that can be searched by URI or by pattern.
#[derive(Debug, Default)]
pub struct UriTemplateRegistry {
    /// Registered templates in registration order; lookups return the first hit.
    templates: Vec<UriTemplate>,
}
impl UriTemplateRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            templates: Vec::new(),
        }
    }

    /// Appends `template`. Earlier registrations take precedence in lookups.
    pub fn register(&mut self, template: UriTemplate) {
        self.templates.push(template);
    }

    /// Returns the first registered template whose shape matches `uri`.
    ///
    /// Only the structural match is checked; validators run later, during extraction.
    pub fn find_matching(&self, uri: &str) -> Option<&UriTemplate> {
        self.templates
            .iter()
            .find(|candidate| candidate.matches(uri))
    }

    /// All registered templates, in registration order.
    pub fn templates(&self) -> &[UriTemplate] {
        self.templates.as_slice()
    }

    /// Locates the first template whose pattern text equals `pattern` exactly.
    fn find_by_pattern(&self, pattern: &str) -> Option<&UriTemplate> {
        self.templates
            .iter()
            .find(|candidate| candidate.pattern() == pattern)
    }

    /// Resolves the registered template with pattern text `pattern` using `variables`.
    ///
    /// # Errors
    ///
    /// Fails when no template with that exact pattern is registered, or when
    /// [`UriTemplate::resolve`] fails.
    pub fn resolve_pattern(
        &self,
        pattern: &str,
        variables: &HashMap<String, String>,
    ) -> McpResult<String> {
        let selected = self.find_by_pattern(pattern).ok_or_else(|| {
            McpError::invalid_param_type("pattern", "registered template pattern", pattern)
        })?;
        selected.resolve(variables)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `file:///user/{user_id}.json` template with the `user_id` validator attached.
    fn user_json_template() -> UriTemplate {
        UriTemplate::new("file:///user/{user_id}.json")
            .unwrap()
            .with_validator("user_id", VariableValidator::user_id())
    }

    /// Extracts `uri` against `pattern` and returns the decoded value of `name`.
    fn extract_one(pattern: &str, uri: &str, name: &str) -> Option<String> {
        UriTemplate::new(pattern)
            .unwrap()
            .extract(uri)
            .unwrap()
            .remove(name)
    }

    #[test]
    fn test_user_id_validator() {
        let validator = VariableValidator::user_id();
        for accepted in ["user123", "user_id", "user-name", "ABC123"] {
            assert!(
                validator.validate(accepted).is_ok(),
                "expected {:?} to be accepted",
                accepted
            );
        }
        let too_long = "a".repeat(129);
        for rejected in ["user@example.com", "user with spaces", "", too_long.as_str()] {
            assert!(
                validator.validate(rejected).is_err(),
                "expected {:?} to be rejected",
                rejected
            );
        }
    }

    #[test]
    fn test_image_format_validator() {
        let validator = VariableValidator::image_format();
        for accepted in ["png", "jpg", "jpeg", "webp", "svg"] {
            assert!(
                validator.validate(accepted).is_ok(),
                "expected {:?} to be accepted",
                accepted
            );
        }
        for rejected in ["gif", "PNG", "pdf"] {
            assert!(
                validator.validate(rejected).is_err(),
                "expected {:?} to be rejected",
                rejected
            );
        }
    }

    #[test]
    fn test_uri_template_creation() {
        let template = UriTemplate::new("file:///user/{user_id}.json").unwrap();
        assert_eq!(template.pattern(), "file:///user/{user_id}.json");
        assert_eq!(template.variables(), &["user_id"]);
        assert_eq!(template.mime_type(), Some("application/json"));
    }

    #[test]
    fn test_uri_template_resolution() {
        let vars: HashMap<String, String> =
            std::iter::once(("user_id".to_string(), "alice123".to_string())).collect();
        assert_eq!(
            user_json_template().resolve(&vars).unwrap(),
            "file:///user/alice123.json"
        );
    }

    #[test]
    fn test_uri_template_extraction() {
        let vars = user_json_template()
            .extract("file:///user/alice123.json")
            .unwrap();
        assert_eq!(vars.get("user_id").map(String::as_str), Some("alice123"));
    }

    #[test]
    fn test_uri_template_validation_failure() {
        assert!(user_json_template()
            .extract("file:///user/invalid@user.json")
            .is_err());
    }

    #[test]
    fn test_multiple_variables() {
        let vars = UriTemplate::new("file:///user/{user_id}/avatar.{format}")
            .unwrap()
            .with_validator("user_id", VariableValidator::user_id())
            .with_validator("format", VariableValidator::image_format())
            .extract("file:///user/alice123/avatar.png")
            .unwrap();
        assert_eq!(vars.get("user_id").map(String::as_str), Some("alice123"));
        assert_eq!(vars.get("format").map(String::as_str), Some("png"));
    }

    #[test]
    fn test_registry() {
        let mut registry = UriTemplateRegistry::new();
        for pattern in [
            "file:///user/{user_id}.json",
            "file:///user/{user_id}/avatar.{format}",
        ] {
            registry.register(UriTemplate::new(pattern).unwrap());
        }
        let found = registry
            .find_matching("file:///user/alice123.json")
            .expect("a registered template should match");
        assert_eq!(found.pattern(), "file:///user/{user_id}.json");
    }

    #[test]
    fn test_mime_type_detection() {
        let cases = [
            ("file.json", Some("application/json")),
            ("file.pdf", Some("application/pdf")),
            ("file.png", Some("image/png")),
            ("file.txt", Some("text/plain")),
            ("file.unknown", None),
            ("file", None),
        ];
        for (pattern, expected) in cases {
            assert_eq!(
                UriTemplate::detect_mime_type(pattern).as_deref(),
                expected,
                "pattern {:?}",
                pattern
            );
        }
    }

    #[test]
    fn test_extract_percent_encoded_hash() {
        assert_eq!(
            extract_one(
                "custom://items/{item_id}",
                "custom://items/PREFIX%23some-value",
                "item_id"
            ),
            Some("PREFIX#some-value".to_string()),
            "Percent-encoded '#' should be decoded to '#'"
        );
    }

    #[test]
    fn test_extract_unencoded_values_unchanged() {
        assert_eq!(
            extract_one(
                "file:///user/{user_id}.json",
                "file:///user/alice123.json",
                "user_id"
            ),
            Some("alice123".to_string()),
            "Plain values should pass through unchanged"
        );
    }

    #[test]
    fn test_extract_percent_encoded_space() {
        assert_eq!(
            extract_one("file:///docs/{name}", "file:///docs/my%20document", "name"),
            Some("my document".to_string()),
            "Percent-encoded space should be decoded"
        );
    }

    #[test]
    fn test_extract_percent_encoded_special_chars() {
        assert_eq!(
            extract_one(
                "data://records/{record_id}",
                "data://records/user%40host%26extra",
                "record_id"
            ),
            Some("user@host&extra".to_string()),
            "Multiple percent-encoded chars should all be decoded"
        );
    }
}