use crate::error::IonError;
use crate::lexer::Lexer;
use crate::token::{SpannedToken, Token};
/// Errors produced while rewriting a top-level `let` binding with [`replace_global`].
#[derive(Debug, Clone)]
pub enum RewriteError {
    /// The original source could not be tokenized.
    Lex(IonError),
    /// No top-level `let` binding with the requested name was found; carries that name.
    NotFound(String),
    /// A `let` binding (or the requested rewrite) did not have the expected shape;
    /// carries a human-readable description of the problem.
    Malformed(String),
    /// The source produced by the rewrite failed to tokenize or parse.
    InvalidReplacement(IonError),
}
impl std::fmt::Display for RewriteError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RewriteError::Lex(e) => write!(f, "lex error: {}", e.message),
RewriteError::NotFound(name) => {
write!(f, "no top-level `let {}` binding found", name)
}
RewriteError::Malformed(msg) => write!(f, "malformed let binding: {}", msg),
RewriteError::InvalidReplacement(e) => {
write!(f, "rewritten source is invalid: {}", e.message)
}
}
}
}
impl std::error::Error for RewriteError {}
pub fn replace_global(
source: &str,
name: &str,
new_value_src: &str,
) -> Result<String, RewriteError> {
let tokens = Lexer::new(source)
.tokenize()
.map_err(RewriteError::Lex)?;
let line_starts = line_start_offsets(source);
let (value_start, value_end) = find_global_value_span(&tokens, name)?;
let start_byte = byte_offset_of(&tokens[value_start], &line_starts);
let end_byte = byte_offset_of(&tokens[value_end], &line_starts);
let mut out = String::with_capacity(source.len() + new_value_src.len());
out.push_str(&source[..start_byte]);
out.push_str(new_value_src);
out.push_str(&source[end_byte..]);
let new_tokens = Lexer::new(&out)
.tokenize()
.map_err(RewriteError::InvalidReplacement)?;
crate::parser::Parser::new(new_tokens)
.parse_program()
.map_err(RewriteError::InvalidReplacement)?;
Ok(out)
}
/// Locates the first `let` at bracket depth 0 that binds `name`. Returns the token-index
/// span `(first value token, terminating ';')` reported by `try_match_let`.
///
/// Depth counts `{`, `[`, `(` and `#{` as openers and `}`, `]`, `)` as closers. Bindings
/// nested inside function bodies, blocks or literals are therefore never considered.
/// Any error from `try_match_let` aborts the search; if nothing matches, the result is
/// `RewriteError::NotFound(name)`.
fn find_global_value_span(
    tokens: &[SpannedToken],
    name: &str,
) -> Result<(usize, usize), RewriteError> {
    let mut nesting: i32 = 0;
    for (idx, spanned) in tokens.iter().enumerate() {
        match spanned.token {
            Token::LBrace | Token::LBracket | Token::LParen | Token::HashBrace => nesting += 1,
            Token::RBrace | Token::RBracket | Token::RParen => nesting -= 1,
            Token::Let if nesting == 0 => {
                if let Some(span) = try_match_let(tokens, idx, name)? {
                    return Ok(span);
                }
            }
            _ => {}
        }
    }
    Err(RewriteError::NotFound(name.to_string()))
}
/// Checks whether the `let` at `tokens[let_idx]` binds `name`. If it does, returns the
/// token span of its initializer.
///
/// Recognised shape: `let [mut] <ident> [: <type>] = <expr> ;`. Nesting of `()`, `[]`,
/// `{}` and `#{` inside `<expr>` is tracked, so a `;` inside a block does not end the value.
///
/// Returns:
/// - `Ok(Some((value_start, semicolon_idx)))`: the index of the first value token and of
///   the terminating `;`, which is the exclusive end of the value.
/// - `Ok(None)`: this `let` does not bind `name` as a plain identifier. It may bind another
///   name or use a pattern such as `(a, b)`. The caller should keep searching.
///
/// # Errors
///
/// [`RewriteError::Malformed`] when the binding does name `name` and then either:
/// - lacks its `=`, its expression, or its terminating `;`; or
/// - has a value that closes a bracket it never opened.
fn try_match_let(
    tokens: &[SpannedToken],
    let_idx: usize,
    name: &str,
) -> Result<Option<(usize, usize)>, RewriteError> {
    let mut j = let_idx + 1;
    if matches!(tokens.get(j).map(|t| &t.token), Some(Token::Mut)) {
        j += 1;
    }
    // Bail out before any validation when this binding is not the one we want. Otherwise a
    // different or unusual `let` (e.g. a top-level `if let Some(v) = ...`) would abort the
    // whole search. The errors below also mention `name`, which is only accurate here.
    match tokens.get(j).map(|t| &t.token) {
        Some(Token::Ident(n)) if n == name => {}
        _ => return Ok(None),
    }
    j += 1;
    // Skip an optional `: <type>` annotation up to the `=`.
    if matches!(tokens.get(j).map(|t| &t.token), Some(Token::Colon)) {
        j += 1;
        while let Some(tok) = tokens.get(j) {
            match tok.token {
                Token::Eq | Token::Semicolon | Token::Eof => break,
                _ => j += 1,
            }
        }
    }
    if !matches!(tokens.get(j).map(|t| &t.token), Some(Token::Eq)) {
        return Err(RewriteError::Malformed(format!(
            "expected `=` after `let {}`",
            name
        )));
    }
    let value_start = j + 1;
    if value_start >= tokens.len()
        || matches!(tokens[value_start].token, Token::Eof | Token::Semicolon)
    {
        return Err(RewriteError::Malformed(
            "expected expression after `=`".to_string(),
        ));
    }
    // The value ends at the first `;` that is not nested inside brackets opened by the value.
    let mut local: i32 = 0;
    for (k, spanned) in tokens.iter().enumerate().skip(value_start) {
        match spanned.token {
            Token::LBrace | Token::LBracket | Token::LParen | Token::HashBrace => local += 1,
            Token::RBrace | Token::RBracket | Token::RParen => {
                local -= 1;
                if local < 0 {
                    return Err(RewriteError::Malformed(format!(
                        "unbalanced closing bracket in value of `let {}`",
                        name
                    )));
                }
            }
            Token::Semicolon if local == 0 => return Ok(Some((value_start, k))),
            Token::Eof => break,
            _ => {}
        }
    }
    Err(RewriteError::Malformed(
        "unterminated `let` binding (missing `;`)".to_string(),
    ))
}
/// Returns the byte offset at which each line of `source` begins.
///
/// The first entry is always `0`. Every `\n` adds an entry for the byte just after it, so
/// a trailing newline produces a final (empty) line start equal to `source.len()`.
fn line_start_offsets(source: &str) -> Vec<usize> {
    let after_newlines = source
        .bytes()
        .enumerate()
        .filter_map(|(pos, byte)| (byte == b'\n').then(|| pos + 1));
    std::iter::once(0).chain(after_newlines).collect()
}
/// Converts a token's line/column position into a byte offset, using the per-line start
/// offsets computed by `line_start_offsets`.
///
/// `line` and `col` are treated as 1-based, and a value of 0 is clamped to 1. A line
/// number past the last known line is clamped to the last line.
/// NOTE(review): this assumes `col` counts bytes. If the lexer counts characters, then
/// non-ASCII text earlier on the same line shifts the result. Confirm against the lexer.
fn byte_offset_of(tok: &SpannedToken, line_starts: &[usize]) -> usize {
    let last_line_idx = line_starts.len() - 1;
    let line_idx = tok.line.saturating_sub(1).min(last_line_idx);
    let col_offset = tok.col.saturating_sub(1);
    line_starts[line_idx] + col_offset
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `replace_global`, panicking with the error if the rewrite fails.
    fn rewrite(src: &str, name: &str, value: &str) -> String {
        replace_global(src, name, value).unwrap()
    }

    /// Runs `replace_global`, panicking if the rewrite unexpectedly succeeds.
    fn rewrite_err(src: &str, name: &str, value: &str) -> RewriteError {
        replace_global(src, name, value).unwrap_err()
    }

    #[test]
    fn replaces_simple_int() {
        assert_eq!(rewrite("let x = 1;", "x", "42"), "let x = 42;");
    }

    #[test]
    fn preserves_surrounding_code() {
        let before = "fn pre() { 1 }\nlet threshold = 10;\nfn post() { threshold }\n";
        let after = "fn pre() { 1 }\nlet threshold = 99;\nfn post() { threshold }\n";
        assert_eq!(rewrite(before, "threshold", "99"), after);
    }

    #[test]
    fn handles_mutable_global() {
        assert_eq!(
            rewrite("let mut counter = 0;", "counter", "100"),
            "let mut counter = 100;"
        );
    }

    #[test]
    fn handles_type_annotation() {
        assert_eq!(
            rewrite("let name: string = \"old\";", "name", "\"new\""),
            "let name: string = \"new\";"
        );
    }

    #[test]
    fn handles_list_value() {
        assert_eq!(
            rewrite("let xs = [1, 2, 3];", "xs", "[4, 5, 6, 7]"),
            "let xs = [4, 5, 6, 7];"
        );
    }

    #[test]
    fn handles_dict_value_with_nested_semicolons_impossible_but_braces_ok() {
        assert_eq!(
            rewrite("let cfg = #{\"a\": 1, \"b\": [2, 3]};", "cfg", "#{\"a\": 9}"),
            "let cfg = #{\"a\": 9};"
        );
    }

    #[test]
    fn skips_bindings_inside_function_bodies() {
        assert_eq!(
            rewrite("fn f() { let x = 1; x }\nlet x = 99;", "x", "7"),
            "fn f() { let x = 1; x }\nlet x = 7;"
        );
    }

    #[test]
    fn not_found_returns_error() {
        let err = rewrite_err("let y = 1;", "missing", "0");
        assert!(matches!(err, RewriteError::NotFound(_)));
    }

    #[test]
    fn rejects_invalid_replacement() {
        let err = rewrite_err("let x = 1;", "x", "}{ not valid");
        assert!(matches!(err, RewriteError::InvalidReplacement(_)));
    }

    #[test]
    fn preserves_trailing_newline_and_comments() {
        assert_eq!(
            rewrite("// config\nlet port = 8080; // default\n", "port", "9090"),
            "// config\nlet port = 9090; // default\n"
        );
    }

    #[test]
    fn first_top_level_binding_wins() {
        assert_eq!(
            rewrite("let x = 1;\nlet x = 2;", "x", "9"),
            "let x = 9;\nlet x = 2;"
        );
    }
}