use std::collections::HashMap;
/// Upper bound on the number of lines in a sub body (as produced by
/// `parse_sub_definition`, i.e. after the parameter line is removed) that is
/// still eligible for inlining.
const MAX_BODY_LINES: usize = 50;

/// Why a subroutine could not be inlined at a given call site.
#[derive(Debug, Clone)]
pub enum InlineError {
    /// No `sub NAME {` definition for the requested name exists in the source.
    SubNotFound { name: String },
    /// The sub body contains a call to the sub itself.
    Recursive { name: String },
    /// The sub body has more than [`MAX_BODY_LINES`] lines.
    TooLarge { name: String, line_count: usize },
    /// The sub body contains more than one `return`.
    MultipleReturns { name: String, count: usize },
    /// The call expression could not be parsed into arguments.
    CallSiteParseFailed { message: String },
}

impl std::fmt::Display for InlineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SubNotFound { name } => write!(f, "subroutine '{}' not found in source", name),
            Self::Recursive { name } => {
                write!(f, "cannot inline recursive subroutine '{}'", name)
            }
            Self::TooLarge { name, line_count } => write!(
                f,
                "subroutine '{}' is too large to inline ({} lines, max {})",
                name, line_count, MAX_BODY_LINES
            ),
            Self::MultipleReturns { name, count } => write!(
                f,
                "subroutine '{}' has {} return points; only single-return subs can be inlined",
                name, count
            ),
            Self::CallSiteParseFailed { message } => {
                write!(f, "failed to parse call site: {}", message)
            }
        }
    }
}

impl std::error::Error for InlineError {}
/// Outcome of a successful inlinability analysis: the pieces needed to
/// splice a sub's body into a call site.
#[derive(Debug, Clone)]
pub enum InlineAbility {
    /// The sub passed every inlining check.
    Ok {
        /// Parameter names unpacked from `@_`, without sigils, in order.
        params: Vec<String>,
        /// The sub body with the parameter-unpacking statement removed.
        body: String,
        /// Whether the body matched the side-effect keyword scan.
        has_side_effects: bool,
    },
}
pub fn analyze_sub_for_inlining(
source: &str,
sub_name: &str,
) -> Result<InlineAbility, InlineError> {
let parsed = parse_sub_definition(source, sub_name)
.ok_or_else(|| InlineError::SubNotFound { name: sub_name.to_string() })?;
if body_calls_self(&parsed.body, sub_name) {
return Err(InlineError::Recursive { name: sub_name.to_string() });
}
let body_line_count = parsed.body.lines().count();
if body_line_count > MAX_BODY_LINES {
return Err(InlineError::TooLarge {
name: sub_name.to_string(),
line_count: body_line_count,
});
}
let return_count = count_return_statements(&parsed.body);
if return_count > 1 {
return Err(InlineError::MultipleReturns {
name: sub_name.to_string(),
count: return_count,
});
}
let side_effects = has_side_effects(&parsed.body);
Ok(InlineAbility::Ok {
params: parsed.params,
body: parsed.body,
has_side_effects: side_effects,
})
}
/// Inlines calls to Perl subroutines whose definitions live in `source`.
pub struct SubInliner {
    source: String,
}

impl SubInliner {
    /// Creates an inliner over `source`, the text that holds the sub definitions.
    pub fn new(source: &str) -> Self {
        Self { source: source.to_string() }
    }

    /// Inlines `call_expr` (a call to `sub_name`) and returns the resulting
    /// expression, discarding any warnings.
    ///
    /// # Errors
    /// Any [`InlineError`] from analysing the sub or parsing the call site.
    pub fn inline_call(&self, sub_name: &str, call_expr: &str) -> Result<String, InlineError> {
        let (inlined, _warnings) = self.inline_call_inner(sub_name, call_expr, &[])?;
        Ok(inlined)
    }

    /// Like [`SubInliner::inline_call`], but also returns human-readable
    /// warnings (currently: the body appears to have side effects).
    ///
    /// # Errors
    /// Any [`InlineError`] from analysing the sub or parsing the call site.
    pub fn inline_call_with_warnings(
        &self,
        sub_name: &str,
        call_expr: &str,
    ) -> Result<(String, Vec<String>), InlineError> {
        self.inline_call_inner(sub_name, call_expr, &[])
    }

    /// Like [`SubInliner::inline_call`], but first renames body-local `my`
    /// variables that collide with the names in `outer_vars`.
    ///
    /// # Errors
    /// Any [`InlineError`] from analysing the sub or parsing the call site.
    pub fn inline_call_with_outer_vars(
        &self,
        sub_name: &str,
        call_expr: &str,
        outer_vars: &[String],
    ) -> Result<String, InlineError> {
        let (inlined, _warnings) = self.inline_call_inner(sub_name, call_expr, outer_vars)?;
        Ok(inlined)
    }

    /// Shared implementation of the public `inline_call*` methods.
    fn inline_call_inner(
        &self,
        sub_name: &str,
        call_expr: &str,
        outer_vars: &[String],
    ) -> Result<(String, Vec<String>), InlineError> {
        let InlineAbility::Ok { params, body, has_side_effects } =
            analyze_sub_for_inlining(&self.source, sub_name)?;

        let mut warnings = Vec::new();
        if has_side_effects {
            warnings.push(format!(
                "subroutine '{}' contains side effects (print/warn/die/I/O); \
                 inlining preserves them but may change semantics",
                sub_name
            ));
        }

        let args = extract_call_args(call_expr, sub_name)?;
        // Pair each parameter with its positional argument. A parameter with no
        // argument is `undef` in Perl; substituting "" would yield invalid code.
        let sub_map: HashMap<String, String> = params
            .into_iter()
            .enumerate()
            .map(|(i, param)| {
                let arg = args.get(i).map_or("", String::as_str);
                (param, Self::as_operand(arg))
            })
            .collect();

        let body = rename_collisions(&body, outer_vars);
        let substituted = substitute_params(&body, &sub_map);
        Ok((extract_return_expr(&substituted), warnings))
    }

    /// Makes a call-site argument safe to paste into an arbitrary expression.
    ///
    /// Atomic terms are returned unchanged: a plain variable or bareword, a
    /// numeric literal, or a simple quoted string. Anything else is wrapped in
    /// parentheses so operator precedence in the body cannot regroup it; for
    /// example `$a * 2` with argument `1 + 2` must become `(1 + 2) * 2`.
    /// An empty argument becomes `undef`.
    fn as_operand(arg: &str) -> String {
        let arg = arg.trim();
        if arg.is_empty() {
            return "undef".to_string();
        }
        let is_word = |c: char| c.is_alphanumeric() || c == '_';
        let name = arg.strip_prefix(['$', '@', '%']).unwrap_or(arg);
        let simple_name = !name.is_empty() && name.chars().all(is_word);
        let number = arg.starts_with(|c: char| c.is_ascii_digit())
            && arg.chars().all(|c| c.is_ascii_digit() || c == '.' || c == '_');
        let plain_string = match arg.chars().next() {
            Some(q @ ('\'' | '"')) if arg.len() >= 2 && arg.ends_with(q) => {
                let inner = &arg[1..arg.len() - 1];
                !inner.contains(q) && !inner.contains('\\')
            }
            _ => false,
        };
        if simple_name || number || plain_string {
            arg.to_string()
        } else {
            format!("({})", arg)
        }
    }
}
/// A sub definition split into its unpacked parameter names and its body.
struct ParsedSub {
    /// Parameter names (sigils stripped) from the `@_` unpacking statement.
    params: Vec<String>,
    /// Body text between the outer braces, minus the unpacking statement.
    body: String,
}

/// Locates `sub <sub_name> { ... }` in `source` and splits it into parameters
/// and body. Returns `None` if the definition is missing or its braces never close.
fn parse_sub_definition(source: &str, sub_name: &str) -> Option<ParsedSub> {
    let sub_start = find_sub_start(source, sub_name)?;
    let brace_offset = source[sub_start..].find('{')?;
    // Start scanning just past the opening brace.
    let raw_body = extract_balanced_braces(source, sub_start + brace_offset + 1)?;
    let (params, body) = extract_params_line(&raw_body);
    Some(ParsedSub { params, body })
}
/// Returns the byte offset of the `sub ` keyword that introduces
/// `sub <sub_name> {`, or `None` if no such definition exists.
///
/// The name must be followed by a non-identifier character (so `foo` does not
/// match `foo_bar`), and then by `{` after optional whitespace.
fn find_sub_start(source: &str, sub_name: &str) -> Option<usize> {
    source
        .match_indices("sub ")
        .map(|(offset, _)| offset)
        .find(|&offset| {
            let candidate = source[offset + 4..].trim_start();
            let Some(after_name) = candidate.strip_prefix(sub_name) else {
                return false;
            };
            let ends_name = !matches!(
                after_name.chars().next(),
                Some(c) if c.is_alphanumeric() || c == '_'
            );
            ends_name && after_name.trim_start().starts_with('{')
        })
}
/// Returns the text from `open_pos` up to (not including) the `}` that closes
/// an already-opened brace. `open_pos` is the byte just after that `{`.
///
/// Returns `None` if the braces never balance. Every `{`/`}` character counts,
/// including those inside strings, regexes, or comments.
fn extract_balanced_braces(source: &str, open_pos: usize) -> Option<String> {
    let tail = &source[open_pos..];
    let mut open_count: usize = 1;
    for (offset, ch) in tail.char_indices() {
        if ch == '{' {
            open_count += 1;
        } else if ch == '}' {
            open_count -= 1;
            if open_count == 0 {
                return Some(tail[..offset].to_string());
            }
        }
    }
    None
}
/// Finds the first `my (LIST) = @_;` statement in `body`.
///
/// Returns the parameter names (sigils stripped, in declaration order) and the
/// body with that statement removed. Only the statement is removed: anything
/// after its `;` on the same line (e.g. the `return` of a one-line sub) is
/// kept, at the original line's indentation. If no such statement exists,
/// returns no params and the body unchanged.
fn extract_params_line(body: &str) -> (Vec<String>, String) {
    let found = body.lines().enumerate().find_map(|(i, line)| {
        split_param_decl(line.trim()).map(|(params, tail)| (i, line, params, tail))
    });
    let Some((decl_idx, decl_line, params, tail)) = found else {
        return (vec![], body.to_string());
    };
    let indent = &decl_line[..decl_line.len() - decl_line.trim_start().len()];
    let remaining = body
        .lines()
        .enumerate()
        .filter_map(|(j, line)| {
            if j != decl_idx {
                Some(line.to_string())
            } else if tail.is_empty() {
                None
            } else {
                Some(format!("{}{}", indent, tail))
            }
        })
        .collect::<Vec<_>>()
        .join("\n");
    (params, remaining)
}

/// Parses a leading `my (LIST) = @_` statement (whitespace optional, `;`
/// optional at end of line). Returns the names and the trimmed text after it.
/// Returns `None` for other shapes such as `@_[1..2]` or `@_ ? a : b`.
fn split_param_decl(stmt: &str) -> Option<(Vec<String>, &str)> {
    let list = stmt.strip_prefix("my")?.trim_start();
    if !list.starts_with('(') {
        return None;
    }
    // Use the first `)`, so parentheses later on the line (for example in a
    // return expression) cannot leak into the parameter list.
    let close = list.find(')')?;
    let rhs = list[close + 1..].trim_start().strip_prefix('=')?.trim_start();
    let after_args = rhs.strip_prefix("@_")?.trim_start();
    let tail = match after_args.strip_prefix(';') {
        Some(rest) => rest.trim(),
        None if after_args.is_empty() => after_args,
        None => return None,
    };
    Some((parse_param_names(&list[..=close]), tail))
}

/// Extracts variable names from the parenthesised list in `line`, stripping
/// `$`/`@`/`%` sigils and dropping empty entries.
fn parse_param_names(line: &str) -> Vec<String> {
    let (Some(open), Some(close)) = (line.find('('), line.rfind(')')) else {
        return vec![];
    };
    if close <= open {
        return vec![];
    }
    line[open + 1..close]
        .split(',')
        .map(|s| s.trim().trim_start_matches(['$', '@', '%']).to_string())
        .filter(|s| !s.is_empty())
        .collect()
}
/// Counts `return` keywords in `body` that appear outside single- or
/// double-quoted strings.
///
/// A match must be a whole word. It must not be preceded by a `$`/`@`/`%`
/// sigil, so a variable such as `$return` is not counted. Backslash escapes
/// inside strings are skipped.
fn count_return_statements(body: &str) -> usize {
    let bytes = body.as_bytes();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut count = 0usize;
    let mut pos = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    while pos < bytes.len() {
        match bytes[pos] {
            b'\\' if in_single_quote || in_double_quote => {
                // Skip the backslash and the whole escaped character. A fixed
                // `+= 2` could stop inside a multi-byte character and make the
                // next `body[pos..]` slice panic.
                pos += 1 + body[pos + 1..].chars().next().map_or(0, char::len_utf8);
                continue;
            }
            b'\'' if !in_double_quote => {
                in_single_quote = !in_single_quote;
                pos += 1;
                continue;
            }
            b'"' if !in_single_quote => {
                in_double_quote = !in_double_quote;
                pos += 1;
                continue;
            }
            _ => {}
        }
        if !in_single_quote && !in_double_quote && body[pos..].starts_with("return") {
            let before_ok = pos == 0 || {
                let prev = bytes[pos - 1];
                !is_word(prev) && !matches!(prev, b'$' | b'@' | b'%')
            };
            let after_ok = !bytes.get(pos + 6).is_some_and(|&next| is_word(next));
            if before_ok && after_ok {
                count += 1;
            }
            pos += 6;
            continue;
        }
        pos += body[pos..].chars().next().map_or(1, char::len_utf8);
    }
    count
}
/// Heuristically reports whether `body` performs I/O or other observable
/// effects. It looks for any of a fixed set of Perl builtins used as a whole word.
///
/// Words preceded by a `$`/`@`/`%` sigil are variables (e.g. `$print`) and are
/// ignored. Method calls such as `$fh->print(...)` still count. Strings and
/// comments are not excluded, so this can still over-report.
fn has_side_effects(body: &str) -> bool {
    const SIDE_EFFECT_KEYWORDS: &[&str] = &[
        "print", "warn", "die", "open", "close", "read", "write", "seek", "sysread",
        "syswrite", "printf", "say",
    ];
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    let mut word_start: Option<usize> = None;
    // A trailing sentinel flushes a word that ends at end of input.
    for (i, c) in body.char_indices().chain(std::iter::once((body.len(), ' '))) {
        if is_word(c) {
            word_start.get_or_insert(i);
            continue;
        }
        if let Some(start) = word_start.take() {
            let word = &body[start..i];
            let is_variable = body[..start]
                .chars()
                .next_back()
                .is_some_and(|p| matches!(p, '$' | '@' | '%'));
            if !is_variable && SIDE_EFFECT_KEYWORDS.contains(&word) {
                return true;
            }
        }
    }
    false
}
/// Returns `true` if `body` contains a call `sub_name(...)` outside quoted
/// strings. Whitespace is allowed before `(`, as in `sub_name ($x)`.
///
/// The name must start a word: `checksum(` is not a call to `sum`, and
/// `$sum(` is a variable, not a call. Method calls such as `->sum(` and
/// `&sum(` still count. Backslash escapes inside strings are skipped.
fn body_calls_self(body: &str, sub_name: &str) -> bool {
    let bytes = body.as_bytes();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut pos = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    while pos < bytes.len() {
        match bytes[pos] {
            b'\\' if in_single_quote || in_double_quote => {
                // Skip the whole escaped character so `pos` stays on a UTF-8
                // boundary; the slices below would panic otherwise.
                pos += 1 + body[pos + 1..].chars().next().map_or(0, char::len_utf8);
                continue;
            }
            b'\'' if !in_double_quote => {
                in_single_quote = !in_single_quote;
                pos += 1;
                continue;
            }
            b'"' if !in_single_quote => {
                in_double_quote = !in_double_quote;
                pos += 1;
                continue;
            }
            _ => {}
        }
        if !in_single_quote && !in_double_quote && body[pos..].starts_with(sub_name) {
            let starts_word = pos == 0 || {
                let prev = bytes[pos - 1];
                !is_word(prev) && !matches!(prev, b'$' | b'@' | b'%')
            };
            let followed_by_paren = body[pos + sub_name.len()..].trim_start().starts_with('(');
            if starts_word && followed_by_paren {
                return true;
            }
        }
        pos += body[pos..].chars().next().map_or(1, char::len_utf8);
    }
    false
}
/// Extracts the comma-separated argument texts from a call to `sub_name`
/// inside `call_expr`, e.g. `my $r = add(1, f(2, 3));` gives `["1", "f(2, 3)"]`.
///
/// The first occurrence of `sub_name` that is a whole word, and not a
/// `$`/`@`/`%` variable, is taken as the call. So `$sum = sum(1)` and
/// `add_one(add(1, 2))` find the right call. A call without parentheses, or
/// with empty ones, yields no arguments.
///
/// # Errors
/// `CallSiteParseFailed` if no whole-word occurrence of `sub_name` exists, or
/// if its argument list's parenthesis is never closed.
fn extract_call_args(call_expr: &str, sub_name: &str) -> Result<Vec<String>, InlineError> {
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    let sub_pos = call_expr
        .match_indices(sub_name)
        .map(|(i, _)| i)
        .find(|&i| {
            let before = call_expr[..i].chars().next_back();
            let after = call_expr[i + sub_name.len()..].chars().next();
            !before.is_some_and(|c| is_word(c) || matches!(c, '$' | '@' | '%'))
                && !after.is_some_and(is_word)
        })
        .ok_or_else(|| InlineError::CallSiteParseFailed {
            message: format!("call expression does not contain sub name '{}'", sub_name),
        })?;

    let rest = call_expr[sub_pos + sub_name.len()..].trim_start();
    if !rest.starts_with('(') {
        return Ok(vec![]);
    }
    // `rest` is a suffix of `call_expr`, so this is the absolute index of `(`.
    let open_abs = call_expr.len() - rest.len();
    let close_abs = find_matching_paren(call_expr, open_abs).ok_or_else(|| {
        InlineError::CallSiteParseFailed {
            message: "unmatched parenthesis in call expression".to_string(),
        }
    })?;
    let args_str = &call_expr[open_abs + 1..close_abs];
    if args_str.trim().is_empty() {
        return Ok(vec![]);
    }
    Ok(split_args(args_str))
}
/// Returns the byte index of the `)` that closes the `(` at byte `open`.
///
/// Parentheses inside single- or double-quoted strings are ignored, and
/// backslash escapes inside strings are skipped. This matches how `split_args`
/// treats quotes, so `f(")", 1)` is parsed correctly. Returns `None` if `open`
/// is not a `(` or the parenthesis is never closed.
fn find_matching_paren(s: &str, open: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    if bytes.get(open) != Some(&b'(') {
        return None;
    }
    let mut depth = 0usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut i = open;
    while i < bytes.len() {
        let quoted = in_single_quote || in_double_quote;
        match bytes[i] {
            // Byte-level skip is safe here: UTF-8 continuation bytes never equal
            // an ASCII delimiter, and we never slice at `i`.
            b'\\' if quoted => i += 1,
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'(' if !quoted => depth += 1,
            b')' if !quoted => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
        i += 1;
    }
    None
}
/// Splits an argument list on top-level commas.
///
/// Commas inside `()`, `[]`, `{}` or inside quoted strings do not split.
/// Backslash escapes inside strings are copied verbatim. Each piece is trimmed.
/// Empty pieces between commas are kept, but a trailing empty piece is dropped.
fn split_args(args_str: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut buf = String::new();
    let mut nesting = 0usize;
    let (mut single, mut double) = (false, false);
    let mut iter = args_str.chars();
    while let Some(ch) = iter.next() {
        let quoted = single || double;
        if ch == '\\' && quoted {
            buf.push(ch);
            if let Some(escaped) = iter.next() {
                buf.push(escaped);
            }
            continue;
        }
        if ch == '\'' && !double {
            single = !single;
        } else if ch == '"' && !single {
            double = !double;
        } else if !quoted && matches!(ch, '(' | '[' | '{') {
            nesting += 1;
        } else if !quoted && matches!(ch, ')' | ']' | '}') {
            nesting = nesting.saturating_sub(1);
        } else if !quoted && ch == ',' && nesting == 0 {
            pieces.push(buf.trim().to_string());
            buf.clear();
            continue;
        }
        buf.push(ch);
    }
    let last = buf.trim();
    if !last.is_empty() {
        pieces.push(last.to_string());
    }
    pieces
}
/// Replaces every `$param` in `body` with its argument text from `sub_map`.
///
/// This is a single left-to-right pass, so inserted argument text is never
/// rescanned. With params `($a, $b)` and call `f($b, 1)`, `$a + $b` correctly
/// becomes `$b + 1`. Replacing one parameter after another could turn an
/// inserted `$b` into `1`, depending on HashMap iteration order.
/// A name matches only when followed by a non-identifier character, so at most
/// one parameter can match at any position.
fn substitute_params(body: &str, sub_map: &HashMap<String, String>) -> String {
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    let mut out = String::with_capacity(body.len());
    let mut rest = body;
    while let Some(ch) = rest.chars().next() {
        if let Some(after_sigil) = rest.strip_prefix('$') {
            let hit = sub_map.iter().find(|(name, _)| {
                !name.is_empty()
                    && after_sigil
                        .strip_prefix(name.as_str())
                        .is_some_and(|tail| !tail.chars().next().is_some_and(is_word))
            });
            if let Some((name, arg)) = hit {
                out.push_str(arg);
                rest = &after_sigil[name.len()..];
                continue;
            }
        }
        out.push(ch);
        rest = &rest[ch.len_utf8()..];
    }
    out
}
/// Renames body-local `my $name` variables whose names collide with any of
/// `outer_vars`. `outer_vars` may include a `$`/`@`/`%` sigil. Each colliding
/// `$name` in the body becomes `$name_inlined`.
///
/// A collision requires a whole-word `my $name` declaration. `my $tmpx` does
/// not count as declaring `$tmp`, which would otherwise wrongly rename
/// references to an outer `$tmp`. Empty names are ignored.
fn rename_collisions(body: &str, outer_vars: &[String]) -> String {
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    let mut result = body.to_string();
    for outer in outer_vars {
        let bare = outer.trim_start_matches(['$', '@', '%']);
        if bare.is_empty() {
            continue;
        }
        let my_decl = format!("my ${}", bare);
        let declared = result.match_indices(my_decl.as_str()).any(|(i, _)| {
            let before_ok = !result[..i].chars().next_back().is_some_and(is_word);
            let after_ok = !result[i + my_decl.len()..].chars().next().is_some_and(is_word);
            before_ok && after_ok
        });
        if !declared {
            continue;
        }
        let renamed_bare = format!("{}_inlined", bare);
        let renamed_decl = format!("my ${}", renamed_bare);
        result = replace_whole_var(&result, &my_decl, &renamed_decl);
        let var = format!("${}", bare);
        let renamed_var = format!("${}", renamed_bare);
        result = replace_whole_var(&result, &var, &renamed_var);
    }
    result
}
/// Replaces each occurrence of `var` in `text` with `replacement`, but only
/// when the occurrence is not immediately followed by an identifier character.
/// So `$a` matches in `$a + 1` and `$a[0]`, but not in `$ab` or `$a_x`.
///
/// An empty `var` returns `text` unchanged. Previously an empty `var` matched
/// at every position without advancing and looped forever.
fn replace_whole_var(text: &str, var: &str, replacement: &str) -> String {
    if var.is_empty() {
        return text.to_string();
    }
    let mut result = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(ch) = rest.chars().next() {
        if let Some(after) = rest.strip_prefix(var) {
            let continues_ident = after
                .chars()
                .next()
                .is_some_and(|c| c.is_alphanumeric() || c == '_');
            if !continues_ident {
                result.push_str(replacement);
                rest = after;
                continue;
            }
        }
        result.push(ch);
        rest = &rest[ch.len_utf8()..];
    }
    result
}
/// Turns a substituted sub body into a call-site expression.
///
/// The first line whose trimmed text starts with `return ` supplies the
/// expression. The keyword and any trailing `;` are removed, and the rest is
/// wrapped in parentheses. If no such line exists, the whole trimmed body is
/// returned as-is.
///
/// NOTE(review): statements before the `return` line are discarded, and only
/// that one line is read. Bodies that set up locals, or whose return
/// expression spans several lines, therefore do not inline faithfully.
fn extract_return_expr(body: &str) -> String {
    body.lines()
        .map(str::trim)
        .find(|line| line.starts_with("return "))
        .map_or_else(
            || body.trim().to_string(),
            |line| {
                let expr = line
                    .trim_start_matches("return ")
                    .trim_end_matches(';')
                    .trim();
                format!("({})", expr)
            },
        )
}