use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use super::CompressionError;
/// Ceiling on the number of compile jobs counted in `ACTIVE_COMPILE_TASKS`; `safe_compile`
/// refuses a new request when the counter has already reached this value.
const MAX_COMPILE_TASKS: usize = 4;
/// Process-wide count of compile jobs that `safe_compile` has admitted and that have not yet
/// finished. Starts at zero.
static ACTIVE_COMPILE_TASKS: AtomicUsize = AtomicUsize::new(0);
/// Compiles `pat` into a `regex::Regex` on Tokio's blocking pool, waiting at most
/// `timeout_ms` milliseconds for the result.
///
/// Concurrency is capped: a call only proceeds when fewer than `MAX_COMPILE_TASKS` compiles
/// are currently counted in `ACTIVE_COMPILE_TASKS`. The compiled program is limited to
/// 64 KiB and its lazy DFA cache to 1 MiB.
///
/// When the timeout fires, this function returns but the spawned compile is not awaited any
/// further; it keeps occupying its slot until the closure itself ends.
///
/// # Errors
///
/// * `CompressionError::CompileTimeout` — every slot was already taken, or the compile did
///   not finish within `timeout_ms`.
/// * `CompressionError::BadPattern` — the regex builder rejected the pattern (the message is
///   the builder's error text), or the blocking task did not run to completion.
pub async fn safe_compile(pat: &str, timeout_ms: u64) -> Result<regex::Regex, CompressionError> {
    /// Owns one increment of `ACTIVE_COMPILE_TASKS` and undoes it when dropped, so the slot
    /// is returned on every exit path: normal completion, early return, a panic unwinding
    /// through the compile closure, or the closure being dropped without ever running.
    struct SlotGuard;

    impl Drop for SlotGuard {
        fn drop(&mut self) {
            ACTIVE_COMPILE_TASKS.fetch_sub(1, Ordering::Relaxed);
        }
    }

    let prev = ACTIVE_COMPILE_TASKS.fetch_add(1, Ordering::Relaxed);
    // From this point the guard is responsible for the matching decrement.
    let slot = SlotGuard;
    if prev >= MAX_COMPILE_TASKS {
        // `slot` is dropped by this return, which rolls the counter back.
        return Err(CompressionError::CompileTimeout);
    }

    let pat = pat.to_owned();
    let join = tokio::task::spawn_blocking(move || {
        // Bind the guard inside the closure so it is released only after `build()` has
        // produced its result (or while unwinding, if it panics).
        let _slot = slot;
        regex::RegexBuilder::new(&pat)
            .size_limit(64 * 1024)
            .dfa_size_limit(1024 * 1024)
            .build()
    });

    match tokio::time::timeout(Duration::from_millis(timeout_ms), join).await {
        Err(_elapsed) => Err(CompressionError::CompileTimeout),
        Ok(Err(_join_err)) => Err(CompressionError::BadPattern("compile task panicked".into())),
        Ok(Ok(Err(regex_err))) => Err(CompressionError::BadPattern(regex_err.to_string())),
        Ok(Ok(Ok(re))) => Ok(re),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed pattern yields a usable regex.
    #[tokio::test]
    async fn compiles_simple_pattern() {
        let compiled = safe_compile(r"\d+", 500).await;
        let digits = compiled.unwrap();
        assert!(digits.is_match("123"));
    }

    /// A syntactically broken pattern is reported as `BadPattern`.
    #[tokio::test]
    async fn rejects_invalid_pattern() {
        let outcome = safe_compile(r"[invalid", 500).await;
        let failure = outcome.unwrap_err();
        assert!(matches!(failure, CompressionError::BadPattern(_)));
    }
}