use regex::Regex;
/// Applies the handler-related fix-up passes to `source` and returns the rewritten text.
///
/// The passes run in this order:
/// 1. every ` implements RustLibApi async {` is turned back into ` implements RustLibApi {`;
/// 2. `handler.executeSync(..)` / `handler.executeNormal(..)` wrappers inside functions whose
///    signature mentions a handler callback are unwrapped
///    (`rewrite_handler_calls_in_parameterized_functions`);
/// 3. headers of code that contains `await handler(` are made `async`
///    (`ensure_handler_closures_are_async`).
pub fn fix_handler_executor_calls(source: &str) -> String {
    let without_async_class =
        source.replace(" implements RustLibApi async {", " implements RustLibApi {");
    let unwrapped = rewrite_handler_calls_in_parameterized_functions(&without_async_class);
    ensure_handler_closures_are_async(&unwrapped)
}
/// Unwraps `handler.executeSync(..)` / `handler.executeNormal(..)` calls inside every
/// function whose signature looks like it takes a handler callback.
///
/// The source is scanned line by line:
/// - A line accepted by `is_likely_function_start` begins a candidate function.
/// - `function_extent_end` decides how many lines belong to that candidate.
/// - When `detect_handler_parameter` flags the candidate's signature, the joined text is passed
///   through `rewrite_handler_to_task_executor`; otherwise it is copied verbatim.
/// - All other lines are copied unchanged.
///
/// Lines are re-joined with `\n`. A missing trailing newline in `source` is preserved.
fn rewrite_handler_calls_in_parameterized_functions(source: &str) -> String {
    let lines: Vec<&str> = source.lines().collect();
    let mut result = String::with_capacity(source.len());
    let mut i = 0;
    while i < lines.len() {
        let line = lines[i];
        if is_likely_function_start(line) {
            // `end` is always > i, so the scan always makes progress.
            let end = function_extent_end(&lines, i);
            let func_text = lines[i..end].join("\n");
            if detect_handler_parameter(&lines, i) {
                result.push_str(&rewrite_handler_to_task_executor(&func_text));
            } else {
                result.push_str(&func_text);
            }
            i = end;
        } else {
            result.push_str(line);
            i += 1;
        }
        result.push('\n');
    }
    if !source.ends_with('\n') && result.ends_with('\n') {
        result.pop();
    }
    result
}

/// Returns the exclusive end index of the candidate function starting at `lines[start]`.
///
/// Body detection:
/// - A `{` opens the body only when it appears outside parentheses. Named-parameter braces
///   (`f({int a})`) and closure bodies passed as call arguments therefore do not count.
/// - Once the body has opened, the candidate ends on the line where the net brace depth
///   (summed with `count_brace_depth`) returns to zero or below. This includes one-line
///   bodies such as `void f() {}`.
/// - Before any body opens, a line ending in `;` with braces and parentheses balanced ends the
///   candidate. Body-less declarations (fields, getters, arrow functions, abstract methods)
///   therefore do not absorb the function that follows them.
/// - If neither condition is met, the candidate runs to the end of `lines`.
///
/// Braces and parentheses are counted without regard to string literals or comments.
fn function_extent_end(lines: &[&str], start: usize) -> usize {
    let mut brace_depth = 0i32;
    let mut paren_depth = 0i32;
    let mut saw_body = false;
    for (idx, &line) in lines.iter().enumerate().skip(start) {
        for ch in line.chars() {
            match ch {
                '(' => paren_depth += 1,
                ')' => paren_depth -= 1,
                '{' if paren_depth <= 0 => saw_body = true,
                _ => {}
            }
        }
        brace_depth += count_brace_depth(line);
        let finished = if saw_body {
            brace_depth <= 0
        } else {
            brace_depth <= 0 && paren_depth <= 0 && line.trim_end().ends_with(';')
        };
        if finished {
            return idx + 1;
        }
    }
    lines.len()
}
/// Heuristically decides whether `line` could begin a function definition.
///
/// A line qualifies when, after trimming whitespace, it:
/// - is non-empty;
/// - is not a `//` comment;
/// - is not an `@` annotation;
/// - does not start with a closing `}`, `]` or `)`;
/// - contains a `(` somewhere.
fn is_likely_function_start(line: &str) -> bool {
    let trimmed = line.trim();
    let Some(first) = trimmed.chars().next() else {
        return false;
    };
    let is_comment = trimmed.starts_with("//");
    let is_annotation = first == '@';
    let is_closing = matches!(first, '}' | ']' | ')');
    !is_comment && !is_annotation && !is_closing && trimmed.contains('(')
}
/// Returns the number of `{` minus the number of `}` in `line`.
///
/// Every brace counts, including braces inside string literals or comments.
fn count_brace_depth(line: &str) -> i32 {
    line.chars().fold(0, |depth, ch| match ch {
        '{' => depth + 1,
        '}' => depth - 1,
        _ => depth,
    })
}
/// Reports whether the signature starting at `lines[idx]` appears to take a handler callback.
///
/// The test is whether the signature text contains both `handler` and `Function`.
///
/// - If `lines[idx]` contains `(`, the signature is that line plus following lines, joined
///   with spaces, until its parentheses balance or drop below zero.
/// - Otherwise up to 20 lines starting at `idx` are checked one at a time. The search stops
///   at the first line containing both `)` and `{`.
///
/// Returns `false` when `idx` is out of range.
fn detect_handler_parameter(lines: &[&str], idx: usize) -> bool {
    let Some(&line) = lines.get(idx) else {
        return false;
    };
    let mentions_handler = |text: &str| text.contains("handler") && text.contains("Function");
    // Signed on purpose: a line may close more parentheses than it opens (e.g. `foo()),`).
    // With an unsigned counter that would underflow, panicking in debug builds and wrapping
    // to a huge depth in release builds.
    let paren_delta = |text: &str| -> i64 {
        text.chars()
            .map(|c| match c {
                '(' => 1,
                ')' => -1,
                _ => 0,
            })
            .sum()
    };

    if !line.contains('(') {
        for &candidate in lines.iter().skip(idx).take(20) {
            if mentions_handler(candidate) {
                return true;
            }
            if candidate.contains(')') && candidate.contains('{') {
                break;
            }
        }
        return false;
    }

    let mut signature = line.to_string();
    let mut depth = paren_delta(line);
    for &next in &lines[idx + 1..] {
        if depth <= 0 {
            break;
        }
        signature.push(' ');
        signature.push_str(next);
        depth += paren_delta(next);
    }
    mentions_handler(&signature)
}
/// Unwraps handler executor calls in `source`, then applies two regex clean-ups.
///
/// Steps:
/// 1. `rewrite_handler_executor_wrappers` turns `handler.executeX(TASK)` into
///    `TASK.executeX()`.
/// 2. Any `),<whitespace>).executeSync()` is collapsed to `).executeSync()`.
/// 3. Any `),<whitespace>).executeNormal()` is collapsed to `).executeNormal()`.
///
/// The clean-ups run in that order; each regex is compiled per call.
///
/// NOTE(review): the clean-up patterns also match when the inner `)` closes a nested call
/// argument, e.g. `SyncTask(a: f(x),\n).executeSync()`. That would drop the task's own closing
/// parenthesis. Presumably inputs never end a task with such an argument; confirm.
fn rewrite_handler_to_task_executor(source: &str) -> String {
    // (pattern, replacement, reason the pattern is known to compile)
    let cleanups: [(&str, &str, &str); 2] = [
        (
            r"(?s)\),\s*\)\.executeSync\(\)",
            ").executeSync()",
            "orphaned paren sync pattern must compile",
        ),
        (
            r"(?s)\),\s*\)\.executeNormal\(\)",
            ").executeNormal()",
            "orphaned paren async pattern must compile",
        ),
    ];
    cleanups.iter().fold(
        rewrite_handler_executor_wrappers(source),
        |text, &(pattern, replacement, reason)| {
            Regex::new(pattern)
                .expect(reason)
                .replace_all(&text, replacement)
                .into_owned()
        },
    )
}
/// Rewrites every `handler.executeSync(TASK)` / `handler.executeNormal(TASK)` call in `source`
/// to `TASK.executeSync()` / `TASK.executeNormal()`.
///
/// `TASK` is the text between the call's parentheses with these parts removed:
/// - surrounding whitespace;
/// - one trailing comma, together with the whitespace before it.
///
/// Scanning stops at the first call whose closing parenthesis cannot be found. That call and
/// everything after it are copied through unchanged.
fn rewrite_handler_executor_wrappers(source: &str) -> String {
    let mut rewritten = String::with_capacity(source.len());
    let mut pos = 0;
    while let Some((offset, method)) = find_next_handler_executor(&source[pos..]) {
        let call_start = pos + offset;
        // Skipping `handler.` and the method name lands on the opening parenthesis.
        let args_open = call_start + "handler.".len() + method.len();
        let Some(args_close) = find_matching_paren(source, args_open) else {
            break;
        };

        rewritten.push_str(&source[pos..call_start]);
        let inner = source[args_open + 1..args_close].trim();
        let task = match inner.strip_suffix(',') {
            Some(without_comma) => without_comma.trim_end(),
            None => inner,
        };
        rewritten.push_str(task);
        rewritten.push('.');
        rewritten.push_str(method);
        rewritten.push_str("()");

        pos = args_close + 1;
    }
    rewritten.push_str(&source[pos..]);
    rewritten
}
/// Finds the earliest `handler.executeSync(` or `handler.executeNormal(` in `source`.
///
/// Returns the byte offset of the match together with the method name, or `None` if neither
/// occurs. On equal offsets `executeSync` is preferred, because it is listed first.
fn find_next_handler_executor(source: &str) -> Option<(usize, &'static str)> {
    const METHODS: [&str; 2] = ["executeSync", "executeNormal"];
    METHODS
        .iter()
        .copied()
        .filter_map(|method| {
            source
                .find(format!("handler.{method}(").as_str())
                .map(|position| (position, method))
        })
        .min_by_key(|&(position, _)| position)
}
/// Returns the byte index of the `)` that balances the parentheses opened from `open_paren`.
///
/// Scanning starts at byte `open_paren` of `source`; callers pass the index of a `(`. The
/// result is `None` in two cases:
/// - a `)` appears before any `(`;
/// - the text ends before the parentheses balance.
///
/// Scanning raw bytes is safe here: `(` and `)` are ASCII and never occur inside a multi-byte
/// UTF-8 sequence.
fn find_matching_paren(source: &str, open_paren: usize) -> Option<usize> {
    let tail = &source[open_paren..];
    let mut nesting = 0usize;
    for (offset, byte) in tail.bytes().enumerate() {
        match byte {
            b'(' => nesting += 1,
            b')' if nesting == 0 => return None,
            b')' => {
                nesting -= 1;
                if nesting == 0 {
                    return Some(open_paren + offset);
                }
            }
            _ => {}
        }
    }
    None
}
/// Adds `async` to the function or closure header that encloses each `await handler(` call.
///
/// For every non-comment line containing `await handler(`:
/// - The text is scanned backwards from the call, pairing `{` with `}`, to find the blocks
///   that enclose the call, innermost first.
/// - Control-flow blocks (`if`, `else`, `for`, `while`, `do`, `switch`, `try`, `catch`,
///   `finally`, `on`, `case`, `default:`) are skipped.
/// - The first remaining opener is taken as the enclosing function/closure header.
/// - That header line is made `async` unless it already contains `async` or is a type
///   declaration (`class`, `mixin`, `enum`, `extension`, optionally preceded by `abstract`,
///   `base`, `final`, `interface` or `sealed`). Unrelated functions and control-flow blocks
///   are never modified.
///
/// A header is rewritten by replacing every `) {` with `) async {`. If the line has no `) {`,
/// a trailing `{` is replaced by ` async {`.
///
/// Braces are counted without regard to string literals. Comment-only lines (`//`) are
/// ignored. A missing trailing newline in `source` is preserved.
fn ensure_handler_closures_are_async(source: &str) -> String {
    const AWAIT_CALL: &str = "await handler(";
    let lines: Vec<&str> = source.lines().collect();
    let mut lines_to_fix: std::collections::HashSet<usize> = std::collections::HashSet::new();
    for (idx, line) in lines.iter().enumerate() {
        if line.trim_start().starts_with("//") {
            continue;
        }
        let Some(col) = line.find(AWAIT_CALL) else {
            continue;
        };
        if let Some(header) = enclosing_function_header_line(&lines, idx, col) {
            lines_to_fix.insert(header);
        }
    }

    let mut result = String::with_capacity(source.len());
    for (idx, line) in lines.iter().enumerate() {
        if lines_to_fix.contains(&idx) {
            result.push_str(&make_header_async(line));
        } else {
            result.push_str(line);
        }
        result.push('\n');
    }
    if !source.ends_with('\n') && result.ends_with('\n') {
        result.pop();
    }
    result
}

/// Finds the line holding the innermost non-control-flow `{` that encloses byte `col` of
/// `lines[call_line]`.
///
/// Returns `None` when:
/// - that header already contains `async`;
/// - the header is a type declaration;
/// - no enclosing opener exists.
fn enclosing_function_header_line(lines: &[&str], call_line: usize, col: usize) -> Option<usize> {
    // Text before the call on its own line comes first, then earlier lines bottom-up.
    let before_call = std::iter::once((call_line, &lines[call_line][..col]));
    let earlier = (0..call_line).rev().map(|idx| (idx, lines[idx]));
    // `}` seen while walking backwards belong to blocks that closed before the call; each
    // must be paired with a `{` before an opener can be considered enclosing.
    let mut unmatched_closes = 0usize;
    for (idx, text) in before_call.chain(earlier) {
        if text.trim_start().starts_with("//") {
            continue;
        }
        for ch in text.chars().rev() {
            match ch {
                '}' => unmatched_closes += 1,
                '{' if unmatched_closes > 0 => unmatched_closes -= 1,
                '{' => {
                    let header = lines[idx];
                    let trimmed = header.trim();
                    if is_control_flow_block_header(trimmed) {
                        // Keep walking outwards towards the function/closure header.
                        continue;
                    }
                    if header.contains("async") || is_type_declaration_header(trimmed) {
                        return None;
                    }
                    return Some(idx);
                }
                _ => {}
            }
        }
    }
    None
}

/// Reports whether a trimmed line opens a control-flow block.
///
/// Leading `}` and whitespace are ignored, so `} else {` qualifies. A keyword counts only
/// when it is followed by whitespace, `(`, `{` or the end of the line.
fn is_control_flow_block_header(trimmed: &str) -> bool {
    const KEYWORDS: [&str; 12] = [
        "if", "else", "for", "while", "do", "switch", "try", "catch", "finally", "on", "case",
        "default:",
    ];
    let head = trimmed.trim_start_matches(|c: char| c == '}' || c.is_whitespace());
    KEYWORDS.iter().any(|&keyword| match head.strip_prefix(keyword) {
        Some(rest) => {
            rest.is_empty() || rest.starts_with(|c: char| c.is_whitespace() || c == '(' || c == '{')
        }
        None => false,
    })
}

/// Reports whether a trimmed line starts a type declaration.
///
/// A type declaration is `class`, `mixin`, `enum` or `extension`, optionally preceded only by
/// the modifiers `abstract`, `base`, `final`, `interface` or `sealed`.
fn is_type_declaration_header(trimmed: &str) -> bool {
    const MODIFIERS: [&str; 5] = ["abstract", "base", "final", "interface", "sealed"];
    const KINDS: [&str; 4] = ["class", "mixin", "enum", "extension"];
    for token in trimmed.split_whitespace() {
        if KINDS.contains(&token) {
            return true;
        }
        if !MODIFIERS.contains(&token) {
            return false;
        }
    }
    false
}

/// Inserts `async` into a header line.
///
/// Every `) {` becomes `) async {`. Otherwise, if the line ends with `{`, the trailing `{`
/// run is replaced by ` async {`. Any other line is returned unchanged.
fn make_header_async(line: &str) -> String {
    if line.contains(") {") {
        return line.replace(") {", ") async {");
    }
    let trimmed = line.trim_end();
    if trimmed.ends_with('{') {
        format!("{} async {{", trimmed.trim_end_matches('{').trim_end())
    } else {
        line.to_string()
    }
}
/// Removes the declarations of the functions named in `exclude_functions` from `source`,
/// together with the comment lines immediately preceding each removed declaration.
///
/// Matching:
/// - Names are snake_case and are converted with `snake_to_camel`.
/// - A line starts an excluded declaration when it contains `" <camelName>("` and its
///   trimmed text is non-empty and does not start with `class` or `enum`.
///
/// What is removed (see `excluded_declaration_end`): the span from that line to either
/// - the first `;` outside all parentheses and braces (single- or multi-line body-less
///   declarations), or
/// - the `}` closing a block body.
///
/// Comment handling: lines starting with `//`, `/*`, or `*` (except `**/`) are buffered. They
/// are emitted only if the next non-comment line is kept.
///
/// Output: if `exclude_functions` is empty, `source` is returned unchanged. Otherwise every
/// emitted line is terminated with `\n`.
pub fn filter_excluded_functions(source: &str, exclude_functions: &std::collections::HashSet<&str>) -> String {
    if exclude_functions.is_empty() {
        return source.to_string();
    }
    // The patterns depend only on `exclude_functions`, so build them once.
    let patterns: Vec<String> = exclude_functions
        .iter()
        .map(|excluded| format!(" {}(", snake_to_camel(excluded)))
        .collect();

    let lines: Vec<&str> = source.lines().collect();
    let mut result = String::with_capacity(source.len());
    let mut doc_buffer: Vec<&str> = Vec::new();
    let mut i = 0;
    while i < lines.len() {
        let line = lines[i];
        let trimmed = line.trim_start();
        // `/*` is buffered too, so removing a `/** ... */` doc block cannot leave its opener
        // behind as an unterminated comment.
        let is_comment = trimmed.starts_with("//")
            || trimmed.starts_with("/*")
            || (trimmed.starts_with('*') && !trimmed.starts_with("**/"));
        if is_comment {
            doc_buffer.push(line);
            i += 1;
            continue;
        }

        let is_excluded = !trimmed.is_empty()
            && !trimmed.starts_with("class")
            && !trimmed.starts_with("enum")
            && patterns.iter().any(|pattern| line.contains(pattern.as_str()));
        if is_excluded {
            doc_buffer.clear();
            // Starts at the matching line itself, so a one-line declaration removes only
            // that line.
            i = excluded_declaration_end(&lines, i);
        } else {
            for doc_line in doc_buffer.drain(..) {
                result.push_str(doc_line);
                result.push('\n');
            }
            result.push_str(line);
            result.push('\n');
            i += 1;
        }
    }
    for doc_line in doc_buffer {
        result.push_str(doc_line);
        result.push('\n');
    }
    result
}

/// Returns the exclusive end index of the declaration that starts at `lines[start]`.
///
/// Parentheses and braces are counted from the start line, without regard to strings or
/// comments. The declaration ends on the first line containing either:
/// - a `;` outside all parentheses and braces (abstract method, arrow function, field); or
/// - the `}` that closes a body opened by a `{` outside parentheses.
///
/// Braces inside parentheses, such as named-parameter lists or closure arguments, do not
/// open the body. If neither condition occurs, the declaration runs to the end of `lines`.
fn excluded_declaration_end(lines: &[&str], start: usize) -> usize {
    let mut paren_depth = 0i32;
    let mut brace_depth = 0i32;
    let mut has_body = false;
    for (idx, line) in lines.iter().enumerate().skip(start) {
        for ch in line.chars() {
            match ch {
                '(' => paren_depth += 1,
                ')' => paren_depth -= 1,
                '{' => {
                    if paren_depth <= 0 && brace_depth <= 0 {
                        has_body = true;
                    }
                    brace_depth += 1;
                }
                '}' => {
                    brace_depth -= 1;
                    if has_body && brace_depth <= 0 {
                        return idx + 1;
                    }
                }
                ';' if paren_depth <= 0 && brace_depth <= 0 => return idx + 1,
                _ => {}
            }
        }
    }
    lines.len()
}
/// Converts a snake_case identifier to lowerCamelCase.
///
/// Rules:
/// - Underscores are dropped, and the character that follows a run of underscores is
///   upper-cased.
/// - The first emitted character is lower-cased unless it follows an underscore (so
///   `_private` becomes `Private`).
/// - All other characters are copied unchanged.
fn snake_to_camel(name: &str) -> String {
    let mut camel = String::with_capacity(name.len());
    let mut upper_next = false;
    for ch in name.chars() {
        match ch {
            '_' => upper_next = true,
            _ if upper_next => {
                camel.extend(ch.to_uppercase());
                upper_next = false;
            }
            _ if camel.is_empty() => camel.extend(ch.to_lowercase()),
            _ => camel.push(ch),
        }
    }
    camel
}
/// Placeholder pass: returns an unmodified owned copy of `source`.
pub fn make_struct_fields_with_defaults_optional(source: &str) -> String {
    String::from(source)
}