c2rust-refactor 0.15.0

C2Rust refactoring tool implementation
require "utils"

Variable = {}

function Variable.new(used, id, locl, binding, ident)
    self = {}
    self.used = used
    self.id = id
    self.locl = locl
    self.binding = binding
    self.ident = ident
    self.shadowed = false

    setmetatable(self, Variable)
    Variable.__index = Variable

    return self
end

Visitor = {}

function Visitor.new()
    self = {}
    self.args = {}
    self.locals = {}
    self.variables = {}

    setmetatable(self, Visitor)
    Visitor.__index = Visitor

    return self
end

function Visitor:visit_stmt(stmt)
    if stmt.kind == "Local" then
        debug("Found local in visitor")
        if stmt.pat.kind == "Ident" then
            used = false
            locl = true
            id = stmt.pat.id
            ident = stmt.pat.ident
            binding = "ByValueImmutable"

            -- Find a shadowed variable
            self:find_variable(ident,
                function(var)
                    var.shadowed = true
                end
            )

            self.variables[id] = Variable.new(used, id, locl, binding, ident)

            table.insert(self.locals, stmt)
        else
            debug("Skipping unsupported local type")
        end
    end

    return true
end

function is_simple_expr_path(expr)
    if expr.kind ~= "Path" or #expr.segments ~= 1 then
        return false
    end

    return true
end

function set_by_value_mutable(var)
    var.binding = "ByValueMutable"
end

function set_used(var)
    var.used = true
end

function Visitor:visit_expr(expr)
    debug("Visiting Expr: " .. expr.kind)
    if is_simple_expr_path(expr) then
        self:find_variable(expr.segments[1].ident, set_used)
    elseif expr.kind == "Assign" or expr.kind == "AssignOp" then
        ident = nil

        -- Here we might find a path, ie `a += 1` or an index path, ie `a[1] += 1`
        if is_simple_expr_path(expr.lhs) then
            ident = expr.lhs.segments[1].ident
        elseif expr.lhs.kind == "Index" then
            if is_simple_expr_path(expr.lhs.indexed) then
                ident = expr.lhs.indexed.segments[1].ident
            end
        end

        -- Find the variable and mark it as mutable
        if ident then
            self:find_variable(ident, set_by_value_mutable)
        end
    elseif expr.kind == "InlineAsm" then
        for _, input in ipairs(expr.inputs) do
            if is_simple_expr_path(input.expr) then
                self:find_variable(input.expr.segments[1].ident, set_used)
            end
        end

        for _, output in ipairs(expr.outputs) do
            if is_simple_expr_path(output.expr) then
                self:find_variable(output.expr.segments[1].ident, set_by_value_mutable)
            end
        end
    elseif expr.kind == "MethodCall" then
        -- This may need to be more complex. What if self isn't an ident but say an
        -- index? ie `x[1].as_mut_ptr()` x still needs to be mutable?
        if is_simple_expr_path(expr.args[1]) and expr.caller_is == "ref_mut" then
            self:find_variable(expr.args[1].segments[1].ident, set_by_value_mutable)
        end
    end

    return true
end

function Visitor:find_variable(ident, mutator)
    for _, variable in pairs(self.variables) do
        if variable.ident == ident then
            mutator(variable)
        end
    end
end

function Visitor:visit_fn_like(fn_like)
    -- Skip foreign functions - we only want functions with bodies
    if fn_like.kind == "Foreign" then
        return
    end

    -- Most trait methods don't have default impls, though they can
    if fn_like.kind == "TraitMethod" and not fn_like.block then
        return
    end

    debug("FnLike name: " .. fn_like.ident)

    args = fn_like.decl.args
    stmts = fn_like.block.stmts

    for _, arg in ipairs(args) do
        -- TODO: Pattern might not be an ident
        used = false
        id = arg.id
        locl = false
        binding = "ByValueImmutable"
        ident = arg.pat.ident

        self.variables[id] = Variable.new(used, id, locl, binding, ident)

        table.insert(self.args, arg)

        debug("Arg: " .. ident .. " of id " .. id .. " binding " .. binding)
    end

    return true
end

function update_pattern(pattern, variable)
    -- TODO: Pattern might not be an ident
    debug("variable " .. pattern.ident .. " used: " .. tostring(variable.used))
    if not variable.used then
        pattern.binding = "ByValueImmutable"

        -- If the argument doesn't already have an underscore
        -- prefix, we should add one as it is idomatic rust
        if not starts_with(pattern.ident, '_') then
            pattern.ident = '_' .. pattern.ident
        end
    else
        pattern.binding = variable.binding
    end
end

function Visitor:finish()
    -- Iterate over args
    for _, arg in ipairs(self.args) do
        variable = self.variables[arg.id]

        update_pattern(arg.pat, variable)
    end

    -- Iterate over locals
    for _, stmt in ipairs(self.locals) do
        variable = self.variables[stmt.pat.id]

        update_pattern(stmt.pat, variable)
    end
end

refactor:transform(
    function(transform_ctx)
        return transform_ctx:visit_fn_like(Visitor.new())
    end
)
refactor:save_crate()

print("Finished cleanup_params_locals.lua")