use std::ops::ControlFlow;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use tree_sitter::{Language, ParseOptions, ParseState, Parser, Tree};
use super::error::ParseError;
/// One mebibyte (2^20 bytes); the unit the input-size limits below are written in.
const MIB: usize = 1024 * 1024;

/// Default maximum input size accepted by [`SafeParser`] (10 MiB).
pub const DEFAULT_MAX_SIZE: usize = 10 * MIB;
/// Lower bound enforced by [`SafeParserConfig::with_max_input_size`] (1 MiB).
pub const MIN_MAX_SIZE: usize = MIB;
/// Upper bound enforced by [`SafeParserConfig::with_max_input_size`] (32 MiB).
pub const MAX_MAX_SIZE: usize = 32 * MIB;

/// Default time budget for one parse, in microseconds (2 s).
pub const DEFAULT_TIMEOUT_MICROS: u64 = 2_000_000;
/// Lower bound enforced by [`SafeParserConfig::with_timeout_micros`] (100 ms).
pub const MIN_TIMEOUT_MICROS: u64 = 100_000;
/// Upper bound enforced by [`SafeParserConfig::with_timeout_micros`] (5 s).
pub const MAX_TIMEOUT_MICROS: u64 = 5_000_000;

/// Limits applied by [`SafeParser`]: the maximum input size in bytes and the
/// per-parse time budget in microseconds.
///
/// The builder setters never store out-of-range values; they clamp into
/// `MIN_*..=MAX_*`.
#[derive(Debug, Clone)]
pub struct SafeParserConfig {
    /// Largest accepted input, in bytes.
    max_input_size: usize,
    /// Time budget for a single parse, in microseconds.
    timeout_micros: u64,
}

impl Default for SafeParserConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl SafeParserConfig {
    /// Builds a configuration using [`DEFAULT_MAX_SIZE`] and [`DEFAULT_TIMEOUT_MICROS`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_input_size: DEFAULT_MAX_SIZE,
            timeout_micros: DEFAULT_TIMEOUT_MICROS,
        }
    }

    /// Sets the maximum input size in bytes, clamped to
    /// `MIN_MAX_SIZE..=MAX_MAX_SIZE`.
    #[must_use]
    pub fn with_max_input_size(self, size: usize) -> Self {
        Self {
            max_input_size: size.clamp(MIN_MAX_SIZE, MAX_MAX_SIZE),
            ..self
        }
    }

    /// Sets the parse time budget in microseconds, clamped to
    /// `MIN_TIMEOUT_MICROS..=MAX_TIMEOUT_MICROS`.
    #[must_use]
    pub fn with_timeout_micros(self, timeout: u64) -> Self {
        Self {
            timeout_micros: timeout.clamp(MIN_TIMEOUT_MICROS, MAX_TIMEOUT_MICROS),
            ..self
        }
    }

    /// Maximum accepted input size, in bytes.
    #[must_use]
    pub fn max_input_size(&self) -> usize {
        self.max_input_size
    }

    /// Parse time budget, in microseconds.
    #[must_use]
    pub fn timeout_micros(&self) -> u64 {
        self.timeout_micros
    }
}
/// Cloneable handle for cooperatively cancelling a parse.
///
/// Every clone shares one `AtomicBool` behind an `Arc`, so calling
/// [`CancellationFlag::cancel`] on any clone is visible through all of them.
/// `Default` produces the same not-cancelled state as [`CancellationFlag::new`].
#[derive(Debug, Clone, Default)]
pub struct CancellationFlag {
    cancelled: Arc<AtomicBool>,
}

impl CancellationFlag {
    /// Creates a fresh, not-cancelled flag that shares state with nothing else.
    #[must_use]
    pub fn new() -> Self {
        let shared = Arc::new(AtomicBool::new(false));
        Self { cancelled: shared }
    }

    /// Returns `true` once [`CancellationFlag::cancel`] has been called on this
    /// flag or any clone (and not undone by [`CancellationFlag::reset`]).
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        // Relaxed: the flag is a standalone signal and is not used to publish other data.
        self.cancelled.load(Ordering::Relaxed)
    }

    /// Requests cancellation for every holder of this flag.
    pub fn cancel(&self) {
        self.store(true);
    }

    /// Clears a previous cancellation request for every holder of this flag.
    pub fn reset(&self) {
        self.store(false);
    }

    /// Writes the shared state; see `is_cancelled` for why `Relaxed` is used.
    fn store(&self, value: bool) {
        self.cancelled.store(value, Ordering::Relaxed);
    }
}
/// Whether a parse was interrupted, and by what.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TerminationReason {
    /// Neither cancellation nor a timeout was detected.
    None,
    /// The attached cancellation flag was observed as set.
    Cancelled,
    /// The configured time budget was exceeded.
    TimedOut,
}
/// Turns the raw outcome of a tree-sitter parse into a `Result`.
///
/// Fail-closed: when `termination_reason` reports a cancellation or timeout,
/// an error is returned even if `tree` is `Some`, so a tree from an interrupted
/// parse is never handed to the caller. `file` only adds context to log
/// messages and errors; `timeout_micros` is echoed into the timeout error.
///
/// # Errors
///
/// - `ParseError::ParseCancelled` for [`TerminationReason::Cancelled`].
/// - `ParseError::ParseTimedOut` for [`TerminationReason::TimedOut`].
/// - `ParseError::TreeSitterFailed` when no interruption was recorded but
///   `tree` is `None`.
fn finalize_parse_result(
    termination_reason: TerminationReason,
    tree: Option<Tree>,
    file: Option<&Path>,
    timeout_micros: u64,
) -> Result<Tree, ParseError> {
    // Both helpers are closures so they only do work when actually used
    // (the log macros evaluate their arguments only when the level is enabled).
    let file_suffix = || {
        file.map(|f| format!(" (file: {})", f.display()))
            .unwrap_or_default()
    };
    let owned_file = || file.map(Path::to_path_buf);

    match (termination_reason, tree) {
        (TerminationReason::Cancelled, _) => {
            log::warn!("Parse cancelled{}", file_suffix());
            Err(ParseError::ParseCancelled {
                reason: "cancelled during parsing".to_string(),
                file: owned_file(),
            })
        }
        (TerminationReason::TimedOut, _) => {
            log::warn!(
                "Parse timed out after {} ms{}",
                timeout_micros / 1000,
                file_suffix()
            );
            Err(ParseError::ParseTimedOut {
                timeout_micros,
                file: owned_file(),
            })
        }
        (TerminationReason::None, Some(parsed)) => Ok(parsed),
        (TerminationReason::None, None) => {
            log::warn!("Parse failed{}", file_suffix());
            Err(ParseError::TreeSitterFailed)
        }
    }
}
/// tree-sitter wrapper that enforces an input-size limit, a parse time budget,
/// and optional cooperative cancellation.
#[derive(Debug, Clone)]
pub struct SafeParser {
    /// Size and time limits applied to every parse.
    config: SafeParserConfig,
    /// Optional shared flag; when set, parsing stops and reports cancellation.
    cancellation_flag: Option<CancellationFlag>,
}

impl Default for SafeParser {
    /// Default limits and no cancellation flag.
    fn default() -> Self {
        Self {
            config: SafeParserConfig::default(),
            cancellation_flag: None,
        }
    }
}
impl SafeParser {
#[must_use]
pub fn new(config: SafeParserConfig) -> Self {
Self {
config,
cancellation_flag: None,
}
}
#[must_use]
pub fn with_defaults() -> Self {
Self::default()
}
#[must_use]
pub fn with_cancellation_flag(mut self, flag: CancellationFlag) -> Self {
self.cancellation_flag = Some(flag);
self
}
#[must_use]
pub fn config(&self) -> &SafeParserConfig {
&self.config
}
pub fn parse(
&self,
language: &Language,
content: &[u8],
file: Option<&Path>,
) -> Result<Tree, ParseError> {
if let Some(ref flag) = self.cancellation_flag
&& flag.is_cancelled()
{
return Err(ParseError::ParseCancelled {
reason: "cancelled before parse started".to_string(),
file: file.map(Path::to_path_buf),
});
}
if content.len() > self.config.max_input_size {
log::warn!(
"Input too large: {} bytes exceeds {} limit{}",
content.len(),
self.config.max_input_size,
file.map(|f| format!(" (file: {})", f.display()))
.unwrap_or_default()
);
return Err(ParseError::InputTooLarge {
size: content.len(),
max: self.config.max_input_size,
file: file.map(Path::to_path_buf),
});
}
let mut parser = Parser::new();
parser
.set_language(language)
.map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
let start_time = Instant::now();
let timeout_micros = self.config.timeout_micros;
let cancellation_flag = self.cancellation_flag.clone();
let mut progress_fn = move |_: &ParseState| -> ControlFlow<()> {
if let Some(ref flag) = cancellation_flag
&& flag.is_cancelled()
{
return ControlFlow::Break(());
}
#[allow(clippy::cast_possible_truncation)]
if start_time.elapsed().as_micros() as u64 > timeout_micros {
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
};
let options = ParseOptions::new().progress_callback(&mut progress_fn);
let tree = parser.parse_with_options(
&mut |i, _| content.get(i..).unwrap_or_default(),
None,
Some(options),
);
#[allow(clippy::cast_possible_truncation)] let elapsed_micros = start_time.elapsed().as_micros() as u64;
let termination_reason = if let Some(ref flag) = self.cancellation_flag
&& flag.is_cancelled()
{
TerminationReason::Cancelled
} else if elapsed_micros > self.config.timeout_micros {
TerminationReason::TimedOut
} else if tree.is_none() && elapsed_micros >= self.config.timeout_micros {
TerminationReason::TimedOut
} else {
TerminationReason::None
};
finalize_parse_result(termination_reason, tree, file, self.config.timeout_micros)
}
pub fn parse_file(
&self,
language: &Language,
content: &[u8],
file: &Path,
) -> Result<Tree, ParseError> {
self.parse(language, content, Some(file))
}
#[allow(clippy::cast_precision_loss)] pub fn log_config(&self) {
log::info!(
"SafeParser configured: max_size={} bytes ({:.1} MiB), timeout={} ms",
self.config.max_input_size,
self.config.max_input_size as f64 / (1024.0 * 1024.0),
self.config.timeout_micros / 1000
);
}
}
/// One-shot convenience: parses `content` with a [`SafeParser`] that uses the
/// default limits and has no cancellation flag.
///
/// # Errors
///
/// Any error returned by [`SafeParser::parse`].
pub fn parse_safe(
    language: &Language,
    content: &[u8],
    file: Option<&Path>,
) -> Result<Tree, ParseError> {
    let parser = SafeParser::with_defaults();
    parser.parse(language, content, file)
}
// Unit tests for the configuration, the cancellation flag, `SafeParser`, and
// `finalize_parse_result`. Parsing tests use the `tree_sitter_rust` grammar.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // ---- SafeParserConfig ----

    #[test]
    fn test_config_defaults() {
        let config = SafeParserConfig::default();
        assert_eq!(config.max_input_size(), DEFAULT_MAX_SIZE);
        assert_eq!(config.timeout_micros(), DEFAULT_TIMEOUT_MICROS);
    }

    // In-range values are stored unchanged.
    #[test]
    fn test_config_builder() {
        let config = SafeParserConfig::new()
            .with_max_input_size(20 * 1024 * 1024)
            .with_timeout_micros(3_000_000);
        assert_eq!(config.max_input_size(), 20 * 1024 * 1024);
        assert_eq!(config.timeout_micros(), 3_000_000);
    }

    // Values below the minimums are raised to MIN_MAX_SIZE / MIN_TIMEOUT_MICROS.
    #[test]
    fn test_config_clamping_min() {
        let config = SafeParserConfig::new()
            .with_max_input_size(100) // below MIN_MAX_SIZE
            .with_timeout_micros(1000); // below MIN_TIMEOUT_MICROS
        assert_eq!(config.max_input_size(), MIN_MAX_SIZE);
        assert_eq!(config.timeout_micros(), MIN_TIMEOUT_MICROS);
    }

    // Values above the maximums are lowered to MAX_MAX_SIZE / MAX_TIMEOUT_MICROS.
    #[test]
    fn test_config_clamping_max() {
        let config = SafeParserConfig::new()
            .with_max_input_size(100 * 1024 * 1024) // above MAX_MAX_SIZE
            .with_timeout_micros(10_000_000); // above MAX_TIMEOUT_MICROS
        assert_eq!(config.max_input_size(), MAX_MAX_SIZE);
        assert_eq!(config.timeout_micros(), MAX_TIMEOUT_MICROS);
    }

    // ---- CancellationFlag ----

    #[test]
    fn test_cancellation_flag() {
        let flag = CancellationFlag::new();
        assert!(!flag.is_cancelled());
        flag.cancel();
        assert!(flag.is_cancelled());
        flag.reset();
        assert!(!flag.is_cancelled());
    }

    // Clones share state: cancelling one is visible through the other.
    #[test]
    fn test_cancellation_flag_clone() {
        let flag1 = CancellationFlag::new();
        let flag2 = flag1.clone();
        flag1.cancel();
        assert!(flag2.is_cancelled());
    }

    // ---- SafeParser construction ----

    #[test]
    fn test_safe_parser_creation() {
        let parser = SafeParser::with_defaults();
        assert_eq!(parser.config().max_input_size(), DEFAULT_MAX_SIZE);
        assert_eq!(parser.config().timeout_micros(), DEFAULT_TIMEOUT_MICROS);
    }

    #[test]
    fn test_safe_parser_with_config() {
        let config = SafeParserConfig::new().with_max_input_size(5 * 1024 * 1024);
        let parser = SafeParser::new(config);
        assert_eq!(parser.config().max_input_size(), 5 * 1024 * 1024);
    }

    #[test]
    fn test_safe_parser_with_cancellation() {
        let flag = CancellationFlag::new();
        let parser = SafeParser::with_defaults().with_cancellation_flag(flag.clone());
        assert!(parser.cancellation_flag.is_some());
    }

    // ---- SafeParser::parse pre-checks ----

    // One byte over the configured limit is rejected before parsing, with no file context.
    #[test]
    fn test_input_too_large_error() {
        let config = SafeParserConfig::new().with_max_input_size(MIN_MAX_SIZE);
        let parser = SafeParser::new(config);
        let large_content = vec![b'x'; MIN_MAX_SIZE + 1];
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse(&language, &large_content, None);
        match result {
            Err(ParseError::InputTooLarge { size, max, file }) => {
                assert_eq!(size, MIN_MAX_SIZE + 1);
                assert_eq!(max, MIN_MAX_SIZE);
                assert!(file.is_none());
            }
            _ => panic!("Expected InputTooLarge error"),
        }
    }

    // `parse_file` carries the path into the InputTooLarge error.
    #[test]
    fn test_input_too_large_with_file() {
        let config = SafeParserConfig::new().with_max_input_size(MIN_MAX_SIZE);
        let parser = SafeParser::new(config);
        let large_content = vec![b'x'; MIN_MAX_SIZE + 1];
        let file_path = PathBuf::from("/path/to/large.rs");
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse_file(&language, &large_content, &file_path);
        match result {
            Err(ParseError::InputTooLarge { file, .. }) => {
                assert_eq!(file, Some(file_path));
            }
            _ => panic!("Expected InputTooLarge error with file path"),
        }
    }

    // A flag that is already set short-circuits with the "before parse started" reason.
    #[test]
    fn test_cancelled_before_parse() {
        let flag = CancellationFlag::new();
        flag.cancel();
        let parser = SafeParser::with_defaults().with_cancellation_flag(flag);
        let content = b"fn main() {}";
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse(&language, content, None);
        match result {
            Err(ParseError::ParseCancelled { reason, .. }) => {
                assert!(reason.contains("before parse started"));
            }
            _ => panic!("Expected ParseCancelled error"),
        }
    }

    // ---- Successful parses ----

    #[test]
    fn test_successful_parse() {
        let parser = SafeParser::with_defaults();
        let content = b"fn main() {}";
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse(&language, content, None);
        assert!(result.is_ok());
        let tree = result.unwrap();
        assert_eq!(tree.root_node().kind(), "source_file");
    }

    #[test]
    fn test_successful_parse_with_file() {
        let parser = SafeParser::with_defaults();
        let content = b"fn main() { let x = 42; }";
        let file_path = PathBuf::from("test.rs");
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse_file(&language, content, &file_path);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_safe_convenience() {
        let content = b"fn foo() {}";
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parse_safe(&language, content, None);
        assert!(result.is_ok());
    }

    // Pins the public limit constants and their MIN < DEFAULT < MAX ordering.
    #[test]
    #[allow(clippy::assertions_on_constants)]
    fn test_constants_sanity() {
        assert!(MIN_MAX_SIZE < DEFAULT_MAX_SIZE);
        assert!(DEFAULT_MAX_SIZE < MAX_MAX_SIZE);
        assert!(MIN_TIMEOUT_MICROS < DEFAULT_TIMEOUT_MICROS);
        assert!(DEFAULT_TIMEOUT_MICROS < MAX_TIMEOUT_MICROS);
        assert_eq!(MIN_MAX_SIZE, 1024 * 1024);
        assert_eq!(DEFAULT_MAX_SIZE, 10 * 1024 * 1024);
        assert_eq!(MAX_MAX_SIZE, 32 * 1024 * 1024);
        assert_eq!(MIN_TIMEOUT_MICROS, 100_000);
        assert_eq!(DEFAULT_TIMEOUT_MICROS, 2_000_000);
        assert_eq!(MAX_TIMEOUT_MICROS, 5_000_000);
    }

    #[test]
    fn test_termination_reason_enum() {
        assert_ne!(TerminationReason::None, TerminationReason::Cancelled);
        assert_ne!(TerminationReason::None, TerminationReason::TimedOut);
        assert_ne!(TerminationReason::Cancelled, TerminationReason::TimedOut);
    }

    // ---- Interrupted parses (timing-dependent) ----

    // Uses the minimum time budget. A small input like this may well finish in
    // time, so Ok is accepted. The test asserts that any interruption surfaces
    // only as ParseTimedOut (echoing the configured budget) or TreeSitterFailed.
    #[test]
    fn test_timeout_returns_error_fail_closed() {
        let config = SafeParserConfig::new().with_timeout_micros(MIN_TIMEOUT_MICROS);
        let parser = SafeParser::new(config);
        let content = br#"
fn complex_function() {
let x = vec![1, 2, 3, 4, 5];
for i in x.iter() {
if *i > 3 {
println!("{}", i);
}
}
}
"#;
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse(&language, content, None);
        match result {
            Ok(_tree) => {
                // Parse finished within the budget.
            }
            Err(ParseError::ParseTimedOut { timeout_micros, .. }) => {
                assert_eq!(timeout_micros, MIN_TIMEOUT_MICROS);
            }
            Err(ParseError::TreeSitterFailed) => {
                // Accepted: tree-sitter returned no tree.
            }
            Err(e) => {
                panic!("Unexpected error type: {e:?}");
            }
        }
    }

    // A second thread sets the flag after a ~10 µs sleep, racing the parse. The
    // outcome depends on timing, so the test only rejects error variants outside
    // the listed set.
    #[test]
    fn test_cancellation_during_parse_fail_closed() {
        use std::thread;
        use std::time::Duration;
        let flag = CancellationFlag::new();
        let flag_clone = flag.clone();
        let config = SafeParserConfig::new().with_timeout_micros(MIN_TIMEOUT_MICROS);
        let parser = SafeParser::new(config).with_cancellation_flag(flag);
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_micros(10));
            flag_clone.cancel();
        });
        let content = br"
fn foo() { let x = 1; }
fn bar() { let y = 2; }
fn baz() { let z = 3; }
";
        let language = tree_sitter_rust::LANGUAGE.into();
        let result = parser.parse(&language, content, None);
        handle.join().unwrap();
        match result {
            Ok(_)
            | Err(
                ParseError::ParseCancelled { .. }
                | ParseError::ParseTimedOut { .. }
                | ParseError::TreeSitterFailed,
            ) => {
                // Any of these outcomes is acceptable given the race.
            }
            Err(e) => {
                panic!("Unexpected error type: {e:?}");
            }
        }
    }

    // ---- finalize_parse_result (deterministic) ----

    // Parses a minimal Rust program to get a real `Tree` for the
    // `finalize_parse_result` tests below.
    fn create_test_tree() -> Tree {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .unwrap();
        parser.parse(b"fn main() {}", None).unwrap()
    }

    // Fail-closed: a timeout is an error even when a tree is present.
    #[test]
    fn test_finalize_timeout_with_tree_returns_error() {
        let tree = create_test_tree();
        let result =
            finalize_parse_result(TerminationReason::TimedOut, Some(tree), None, 2_000_000);
        match result {
            Err(ParseError::ParseTimedOut {
                timeout_micros,
                file,
            }) => {
                assert_eq!(timeout_micros, 2_000_000);
                assert!(file.is_none());
            }
            _ => panic!("Expected ParseTimedOut, got {result:?}"),
        }
    }

    // Fail-closed: a cancellation is an error even when a tree is present.
    #[test]
    fn test_finalize_cancelled_with_tree_returns_error() {
        let tree = create_test_tree();
        let result =
            finalize_parse_result(TerminationReason::Cancelled, Some(tree), None, 2_000_000);
        match result {
            Err(ParseError::ParseCancelled { reason, file }) => {
                assert!(reason.contains("cancelled"));
                assert!(file.is_none());
            }
            _ => panic!("Expected ParseCancelled, got {result:?}"),
        }
    }

    #[test]
    fn test_finalize_timeout_without_tree_returns_error() {
        let result = finalize_parse_result(TerminationReason::TimedOut, None, None, 2_000_000);
        match result {
            Err(ParseError::ParseTimedOut { .. }) => {}
            _ => panic!("Expected ParseTimedOut, got {result:?}"),
        }
    }

    #[test]
    fn test_finalize_cancelled_without_tree_returns_error() {
        let result = finalize_parse_result(TerminationReason::Cancelled, None, None, 2_000_000);
        match result {
            Err(ParseError::ParseCancelled { .. }) => {}
            _ => panic!("Expected ParseCancelled, got {result:?}"),
        }
    }

    // No interruption plus a tree is the only success path.
    #[test]
    fn test_finalize_success_with_tree() {
        let tree = create_test_tree();
        let result = finalize_parse_result(TerminationReason::None, Some(tree), None, 2_000_000);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().root_node().kind(), "source_file");
    }

    // No interruption and no tree maps to TreeSitterFailed.
    #[test]
    fn test_finalize_failure_without_tree() {
        let result = finalize_parse_result(TerminationReason::None, None, None, 2_000_000);
        match result {
            Err(ParseError::TreeSitterFailed) => {}
            _ => panic!("Expected TreeSitterFailed, got {result:?}"),
        }
    }

    // The timeout error echoes both the given budget and the file path.
    #[test]
    fn test_finalize_timeout_includes_file_path() {
        let tree = create_test_tree();
        let file_path = Path::new("/path/to/test.rs");
        let result = finalize_parse_result(
            TerminationReason::TimedOut,
            Some(tree),
            Some(file_path),
            1_500_000,
        );
        match result {
            Err(ParseError::ParseTimedOut {
                timeout_micros,
                file,
            }) => {
                assert_eq!(timeout_micros, 1_500_000);
                assert_eq!(file, Some(PathBuf::from("/path/to/test.rs")));
            }
            _ => panic!("Expected ParseTimedOut with file path, got {result:?}"),
        }
    }

    // The cancellation error carries the file path.
    #[test]
    fn test_finalize_cancelled_includes_file_path() {
        let tree = create_test_tree();
        let file_path = Path::new("/some/code.rs");
        let result = finalize_parse_result(
            TerminationReason::Cancelled,
            Some(tree),
            Some(file_path),
            2_000_000,
        );
        match result {
            Err(ParseError::ParseCancelled { file, .. }) => {
                assert_eq!(file, Some(PathBuf::from("/some/code.rs")));
            }
            _ => panic!("Expected ParseCancelled with file path, got {result:?}"),
        }
    }
}