use super::refactor::csharp_unit;
use super::resolve::{resolve, DeclKind};
use crate::v2::ast::{Node, NodeKind, StringKind};
use crate::v2::edit::TextEdit;
use crate::v2::span::Span;
pub fn rename_type(scope: &Node, src: &str, from: &str, to: &str) -> Vec<TextEdit> {
let mut spans = Vec::new();
each_csharp_unit(scope, src, &mut |unit| {
let r = resolve(unit);
for id in r.find(from, Some(DeclKind::Type)) {
spans.extend(r.references_of(id));
}
});
ps_type_refs(scope, src, from, &mut spans);
super::refactor::edits_from_spans(spans, to)
}
/// Builds the text edits that rename member `from` of the C# type `type_name` to `to`.
///
/// Inside each embedded C# unit, every field, method, property or enum member named
/// `from` whose enclosing type is `type_name` contributes all spans reported by
/// `references_of`. PowerShell static accesses such as `[type_name]::from(...)` are
/// collected by `ps_static_member_refs`. Every collected span is then replaced with `to`.
pub fn rename_member(
    scope: &Node,
    src: &str,
    type_name: &str,
    from: &str,
    to: &str,
) -> Vec<TextEdit> {
    let mut spans = Vec::new();

    each_csharp_unit(scope, src, &mut |unit| {
        let resolved = resolve(unit);
        for id in resolved.find(from, None) {
            // Declarations the resolver cannot describe are ignored.
            let is_member = match resolved.decl(id) {
                Some(decl) => matches!(
                    decl.kind,
                    DeclKind::Field | DeclKind::Method | DeclKind::Property | DeclKind::EnumMember
                ),
                None => continue,
            };
            // Same-named members of other types must keep their names.
            if !is_member || resolved.enclosing_type(id) != Some(type_name) {
                continue;
            }
            spans.extend(resolved.references_of(id));
        }
    });

    ps_static_member_refs(scope, src, type_name, from, &mut spans);

    super::refactor::edits_from_spans(spans, to)
}
/// Collects PowerShell-side spans that name the type `from` and appends them to `out`.
///
/// Three syntactic sites are inspected:
/// * type literals (`[Logger]`, `[List[Logger]]`, ...): the text between the outer
///   brackets;
/// * casts (`[Logger]$x`): the text up to the bracket that closes the leading `[`;
/// * `New-Object` commands: the type-name argument, when it is a bare word equal to
///   `from` or a non-interpolated single- or double-quoted string.
///
/// Name comparisons are ASCII case-insensitive.
fn ps_type_refs(scope: &Node, src: &str, from: &str, out: &mut Vec<Span>) {
    /// Scans `src[lo..hi]` for type-name segments equal to `from`.
    fn scan(src: &str, lo: usize, hi: usize, from: &str, out: &mut Vec<Span>) {
        find_type_name_spans(&src[lo..hi], lo, from, out);
    }

    scope.walk(&mut |n| {
        let (start, end) = (n.span.start, n.span.end);
        match &n.kind {
            NodeKind::TypeExpression(_) => {
                // The span must at least contain the two bracket characters.
                if end > start + 1 {
                    scan(src, start + 1, end - 1, from, out);
                }
            }
            NodeKind::Cast { .. } => {
                // Only the leading `[...]` is the type; the cast operand follows it.
                let Some(close) = matching_bracket(&src[start..end]) else {
                    return;
                };
                let hi = start + close;
                if hi > start + 1 {
                    scan(src, start + 1, hi, from, out);
                }
            }
            NodeKind::Command { name, elements, .. } => {
                let is_new_object = matches!(
                    &name.kind,
                    NodeKind::BareWord(w) if w.eq_ignore_ascii_case("new-object")
                );
                if !is_new_object {
                    return;
                }
                let Some(arg) = new_object_type_arg(elements) else {
                    return;
                };
                let (lo, hi) = (arg.span.start, arg.span.end);
                match &arg.kind {
                    NodeKind::BareWord(w) if w.eq_ignore_ascii_case(from) => out.push(arg.span),
                    // Strings with interpolated parts (e.g. "$prefix.Logger") are left alone.
                    NodeKind::StringLiteral { kind, parts, .. }
                        if parts.is_empty()
                            && matches!(kind, StringKind::Single | StringKind::Double)
                            && hi >= lo + 2 =>
                    {
                        scan(src, lo + 1, hi - 1, from, out);
                    }
                    _ => {}
                }
            }
            _ => {}
        }
    });
}
/// Returns the byte index of the `]` that brings the bracket nesting depth of `s`
/// back to zero, i.e. the bracket closing the first `[` when `s` starts with one.
///
/// Returns `None` when no such bracket exists.
fn matching_bracket(s: &str) -> Option<usize> {
    let mut depth: i32 = 0;
    s.bytes().position(|byte| {
        match byte {
            b'[' => depth += 1,
            b']' => {
                depth -= 1;
                return depth == 0;
            }
            _ => {}
        }
        false
    })
}
/// Collects PowerShell static member accesses `[type_name]::member` (and the
/// invocation form `[type_name]::member(...)`) and appends the member-name span to `out`.
///
/// The receiver must be a type literal whose text `type_text_matches` `type_name`, and
/// the member name is compared ASCII case-insensitively.
fn ps_static_member_refs(
    scope: &Node,
    src: &str,
    type_name: &str,
    member: &str,
    out: &mut Vec<Span>,
) {
    scope.walk(&mut |node| {
        let (receiver, accessed, is_static) = match &node.kind {
            NodeKind::MemberAccess {
                target,
                member: name,
                is_static,
            }
            | NodeKind::InvokeMember {
                target,
                member: name,
                is_static,
                ..
            } => (target, name, *is_static),
            _ => return,
        };

        if !is_static || !accessed.eq_ignore_ascii_case(member) {
            return;
        }
        if !matches!(&receiver.kind, NodeKind::TypeExpression(_)) {
            return;
        }

        let (open, close) = (receiver.span.start, receiver.span.end);
        // The literal must at least contain its two bracket characters.
        if close <= open + 1 {
            return;
        }
        if type_text_matches(&src[open + 1..close - 1], type_name) {
            out.extend(member_span_after(src, close, member));
        }
    });
}
/// Locates the element holding the type name in a `New-Object` command.
///
/// Resolution order:
/// * An explicit `-TypeName` (or any non-empty prefix of it) wins. Its inline argument
///   (`-TypeName:Foo`) is used if present, otherwise the following element if that
///   element is not itself a parameter.
/// * Otherwise the first positional bare word or string literal is used.
///
/// The value of any other named parameter is skipped, so it is never mistaken for the
/// positional type name (e.g. `-ComObject Word.Application`, `-ArgumentList 'x'`).
/// Only the switches recognised by `is_switch` are assumed not to take a separate value.
fn new_object_type_arg(elements: &[Node]) -> Option<&Node> {
    /// True when `param` is a non-empty prefix of the full parameter name `full`.
    fn is_prefix_of(param: &str, full: &str) -> bool {
        !param.is_empty() && full.starts_with(param)
    }

    /// Switch parameters of `New-Object`, including the common `-Verbose`/`-Debug`
    /// switches and their `-vb`/`-db` aliases; these never consume the next element.
    fn is_switch(param: &str) -> bool {
        ["strict", "verbose", "debug"]
            .iter()
            .any(|full| is_prefix_of(param, full))
            || matches!(param, "vb" | "db")
    }

    let mut i = 0;
    while i < elements.len() {
        match &elements[i].kind {
            NodeKind::CommandParameter { name, argument } => {
                // Tolerate a leading `-` or trailing `:` in case the parser keeps them.
                let param = name
                    .trim_start_matches('-')
                    .trim_end_matches(':')
                    .to_ascii_lowercase();
                let next_value = elements
                    .get(i + 1)
                    .filter(|next| !matches!(next.kind, NodeKind::CommandParameter { .. }));
                if is_prefix_of(&param, "typename") {
                    if let Some(arg) = argument {
                        return Some(arg);
                    }
                    if let Some(next) = next_value {
                        return Some(next);
                    }
                } else if argument.is_none() && next_value.is_some() && !is_switch(&param) {
                    // The next element is this parameter's value, not a positional type name.
                    i += 1;
                }
            }
            NodeKind::BareWord(_) | NodeKind::StringLiteral { .. } => return Some(&elements[i]),
            _ => {}
        }
        i += 1;
    }
    None
}
/// Reports whether the text inside a type literal (`inner`, without the brackets)
/// names the type `target`.
///
/// Surrounding whitespace is ignored and the comparison is ASCII case-insensitive.
/// The text matches when it equals `target` as a whole, or when its last segment does.
/// Segments are separated by `.` (namespace qualifier) or `+` (nested type, as in
/// `[Outer+Win32]`). Splitting on `+` keeps this check consistent with
/// `find_type_name_spans`, which also treats `+` as a separator.
fn type_text_matches(inner: &str, target: &str) -> bool {
    let text = inner.trim();
    if text.eq_ignore_ascii_case(target) {
        return true;
    }
    text.rsplit(|c: char| c == '.' || c == '+')
        .next()
        .is_some_and(|segment| segment.eq_ignore_ascii_case(target))
}
/// Appends to `out` the span of every type-name segment in `text` that equals `from`
/// (ASCII case-insensitively). Each span is offset by `base`, so it indexes the
/// original source rather than `text`.
///
/// `text` is split into dotted identifier runs (e.g. `System.Collections.Generic.List`),
/// and only the final segment of each run is compared. As a result:
/// * namespace qualifiers are never renamed;
/// * `LoggerHelper` is not matched by `Logger`;
/// * characters such as `[`, `]`, `,` and `+` separate independent runs.
///
/// Non-ASCII bytes count as identifier bytes, so a Unicode identifier such as `Lögger`
/// is never split and partially matched. An empty `from` matches nothing.
fn find_type_name_spans(text: &str, base: usize, from: &str, out: &mut Vec<Span>) {
    if from.is_empty() {
        return;
    }
    let bytes = text.as_bytes();
    // Non-ASCII bytes belong to multi-byte UTF-8 characters. Keeping them inside the run
    // means every run starts and ends on a char boundary, so the slices below are valid.
    let is_word = |c: u8| c.is_ascii_alphanumeric() || c == b'_' || !c.is_ascii();
    let mut i = 0;
    while i < bytes.len() {
        if !is_word(bytes[i]) {
            i += 1;
            continue;
        }
        // Consume one dotted run such as `System.Collections.Generic.List`.
        let run_start = i;
        while i < bytes.len() && (is_word(bytes[i]) || bytes[i] == b'.') {
            i += 1;
        }
        let run = &text[run_start..i];
        // Only the last segment is the type name; earlier segments are namespaces.
        let name_start = run.rfind('.').map_or(0, |dot| dot + 1);
        let name = &run[name_start..];
        if name.eq_ignore_ascii_case(from) {
            let abs = base + run_start + name_start;
            out.push(Span::new(abs, abs + name.len()));
        }
    }
}
/// Finds the member identifier that follows a static-access receiver.
///
/// Starting at byte `offset` (clamped to the end of `src`), the function:
/// 1. skips any run of `:`, `.` and ASCII whitespace;
/// 2. reads one identifier;
/// 3. returns its span if it equals `member` (ASCII case-insensitively), otherwise `None`.
///
/// Non-ASCII bytes are read as identifier bytes. This stops a longer Unicode identifier
/// (e.g. `Messageé`) from being truncated and mistaken for `member`, and keeps the
/// identifier's end on a char boundary.
fn member_span_after(src: &str, offset: usize, member: &str) -> Option<Span> {
    let bytes = src.as_bytes();
    let len = bytes.len();
    let mut i = offset.min(len);
    while i < len && (bytes[i] == b':' || bytes[i] == b'.' || bytes[i].is_ascii_whitespace()) {
        i += 1;
    }
    let start = i;
    while i < len && (bytes[i] == b'_' || bytes[i].is_ascii_alphanumeric() || !bytes[i].is_ascii())
    {
        i += 1;
    }
    if i > start && src[start..i].eq_ignore_ascii_case(member) {
        Some(Span::new(start, i))
    } else {
        None
    }
}
/// Calls `f` once for every C# unit that `csharp_unit` can build from a
/// `CSharpMemberDef` node found while walking `scope`.
///
/// Nodes of other kinds, and definitions for which `csharp_unit` returns `None`,
/// are skipped.
fn each_csharp_unit(scope: &Node, src: &str, f: &mut impl FnMut(&crate::v2::csharp::ast::CsUnit)) {
    scope.walk(&mut |node| {
        let NodeKind::CSharpMemberDef(_) = &node.kind else {
            return;
        };
        let Some(unit) = csharp_unit(node, src) else {
            return;
        };
        f(&unit);
    });
}
#[cfg(test)]
// Tests for the cross-language rename helpers, driven through `parse` + `apply_edits`.
mod tests {
use super::*;
use crate::v2::{apply_edits, parse};
// Fixture: a C# class `Win32` embedded via `Add-Type -TypeDefinition`. PowerShell then
// references it three ways: as a static call receiver, as a `New-Object` argument and in a cast.
const PS: &str = "Add-Type -TypeDefinition @'\npublic class Win32 {\n [DllImport(\"user32.dll\")]\n public static extern int MessageBox(IntPtr h, string t, string c, uint ty);\n}\n'@\n[Win32]::MessageBox(0, 'hi', 'title', 0)\n$inst = New-Object Win32\n[Win32]$casted = $inst\n";
// Renaming the type must rewrite the C# declaration and every PowerShell site in `PS`,
// leaving no occurrence of the old name.
#[test]
fn rename_type_updates_csharp_and_all_powershell_sites() {
let out = parse(PS);
let edits = rename_type(&out.script, PS, "Win32", "NativeApi");
let result = apply_edits(PS, &edits).unwrap();
assert!(result.contains("public class NativeApi {"));
assert!(result.contains("[NativeApi]::MessageBox"));
assert!(result.contains("New-Object NativeApi"));
assert!(result.contains("[NativeApi]$casted"));
assert!(!result.contains("Win32"));
}
// The following are all renamed:
// - array (`[]`) and generic (`List[...]`, `Dictionary[...,...]`) type literals;
// - repeated occurrences within one literal;
// - namespace-qualified names (`My.Logger`).
// An identifier that merely starts with the old name (`LoggerHelper`) is not renamed.
#[test]
fn rename_type_handles_array_generic_and_multiple_occurrences() {
let src = "[Logger[]]::new()\n\
[System.Collections.Generic.List[Logger]]$x = $null\n\
[System.Collections.Generic.Dictionary[Logger,Logger]]::new()\n\
[LoggerHelper]::Init()\n\
[My.Logger]::X()\n";
let out = parse(src);
let edits = rename_type(&out.script, src, "Logger", "Tracer");
let result = apply_edits(src, &edits).unwrap();
assert!(result.contains("[Tracer[]]::new()"));
assert!(result.contains("List[Tracer]"));
assert!(result.contains("Dictionary[Tracer,Tracer]"));
assert!(result.contains("[My.Tracer]::X()"));
assert!(result.contains("[LoggerHelper]::Init()"));
assert!(!result.contains("[Logger]"));
assert!(!result.contains("[Logger["));
}
// `New-Object` type names in quoted strings are renamed, whether named (`-TypeName`) or
// positional, single- or double-quoted, qualified or not. Interpolated strings and
// prefix-only matches are left untouched.
#[test]
fn rename_type_handles_quoted_new_object_names() {
let src = "New-Object -TypeName 'Logger'\n\
New-Object \"Logger\"\n\
New-Object 'My.Logger'\n\
New-Object \"$prefix.Logger\"\n\
New-Object 'LoggerHelper'\n";
let out = parse(src);
let edits = rename_type(&out.script, src, "Logger", "Tracer");
let result = apply_edits(src, &edits).unwrap();
assert!(result.contains("-TypeName 'Tracer'"));
assert!(result.contains("New-Object \"Tracer\""));
assert!(result.contains("'My.Tracer'"));
assert!(
result.contains("\"$prefix.Logger\""),
"interpolated untouched"
);
assert!(result.contains("'LoggerHelper'"), "prefix-only untouched");
}
// Renaming a member rewrites its C# declaration and the `[Win32]::` static call site,
// but leaves the type name itself unchanged.
#[test]
fn rename_member_updates_csharp_decl_and_static_call_site() {
let out = parse(PS);
let edits = rename_member(&out.script, PS, "Win32", "MessageBox", "ShowMessage");
let result = apply_edits(PS, &edits).unwrap();
assert!(result.contains("extern int ShowMessage("));
assert!(result.contains("[Win32]::ShowMessage(0, 'hi', 'title', 0)"));
assert!(result.contains("public class Win32"));
assert!(!result.contains("MessageBox"));
}
// Only references resolved to `C.Length` change. `s.Length` on the `string` parameter
// has the same name but belongs to another type, so it keeps its name.
#[test]
fn rename_member_leaves_unrelated_receiver_alone() {
let src = "Add-Type -TypeDefinition @'\npublic class C {\n public int Length;\n public int Measure(string s) { return s.Length + this.Length; }\n}\n'@\n";
let out = parse(src);
let edits = rename_member(&out.script, src, "C", "Length", "Size");
let result = apply_edits(src, &edits).unwrap();
assert!(result.contains("public int Size;"));
assert!(result.contains("return s.Length + this.Size;"));
}
}