use super::refactor::csharp_unit;
use super::resolve::{resolve, DeclKind};
use crate::v2::ast::{Node, NodeKind, StringKind};
use crate::v2::edit::TextEdit;
use crate::v2::span::Span;
pub fn rename_type(scope: &Node, src: &str, from: &str, to: &str) -> Vec<TextEdit> {
let mut spans = Vec::new();
each_csharp_unit(scope, src, &mut |unit, _owner| {
let r = resolve(unit);
for id in r.find(from, Some(DeclKind::Type)) {
spans.extend(r.references_of(id));
}
});
ps_type_refs(scope, src, from, &mut spans);
super::refactor::edits_from_spans(spans, to)
}
/// Renames member `from` of type `type_name` to `to`.
///
/// On the C# side, a field, method, property or enum member named `from` is
/// renamed when it is declared inside a C# type called `type_name`. It is
/// also renamed when it is declared at top level in a snippet whose owning
/// type (the `-Name` argument of `-MemberDefinition`) is `type_name`.
/// PowerShell static accesses `[type_name]::from` are collected by
/// `ps_static_member_refs`.
pub fn rename_member(
    scope: &Node,
    src: &str,
    type_name: &str,
    from: &str,
    to: &str,
) -> Vec<TextEdit> {
    let mut spans = Vec::new();
    each_csharp_unit(scope, src, &mut |unit, owner| {
        let r = resolve(unit);
        // `owner` is a PowerShell argument. PowerShell-side type matching in
        // this module is ASCII case-insensitive (see `type_text_matches`), so
        // compare it the same way. Otherwise `-Name win32` would get its call
        // sites renamed but not its C# declaration.
        let owner_is_target = owner.is_some_and(|o| o.eq_ignore_ascii_case(type_name));
        for id in r.find(from, None) {
            let Some(d) = r.decl(id) else { continue };
            let is_member = matches!(
                d.kind,
                DeclKind::Field | DeclKind::Method | DeclKind::Property | DeclKind::EnumMember
            );
            if !is_member {
                continue;
            }
            // A member nested in a C# type belongs to that type. A top-level
            // member belongs to the snippet's `-Name` owner, if any.
            let owned = match r.enclosing_type(id) {
                Some(enclosing) => enclosing == type_name,
                None => owner_is_target,
            };
            if owned {
                spans.extend(r.references_of(id));
            }
        }
    });
    ps_static_member_refs(scope, src, type_name, from, &mut spans);
    super::refactor::edits_from_spans(spans, to)
}
fn ps_type_refs(scope: &Node, src: &str, from: &str, out: &mut Vec<Span>) {
scope.walk(&mut |n| match &n.kind {
NodeKind::TypeExpression(_) => {
find_in_interior(src, n.span, from, out);
}
NodeKind::Cast { .. } => {
let rest = &src[n.span.start..n.span.end];
if let Some(close) = matching_bracket(rest) {
find_in_byte_range(src, n.span.start + 1, n.span.start + close, from, out);
}
}
NodeKind::Command { name, elements, .. } => {
let cmd = match &name.kind {
NodeKind::BareWord(w) => w.as_str(),
_ => "",
};
if cmd.eq_ignore_ascii_case("new-object") {
if let Some(node) = new_object_type_arg(elements) {
push_type_name_from_arg(src, node, from, out);
}
} else if cmd.eq_ignore_ascii_case("add-type") {
if let Some(node) = named_parameter_value(elements, "name") {
push_type_name_from_arg(src, node, from, out);
}
}
}
NodeKind::ClassDefinition {
name, name_span, ..
} => {
if name.eq_ignore_ascii_case(from) {
out.push(*name_span);
}
let base_start = name_span.end;
let base_end = type_header_end(src, n.span).max(base_start);
find_type_name_spans_in_header(
safe_inner(src, base_start, base_end),
base_start,
from,
out,
);
}
NodeKind::EnumDefinition {
name, name_span, ..
} if name.eq_ignore_ascii_case(from) => {
out.push(*name_span);
}
_ => {}
});
}
/// Returns the byte offset of the first `{` inside `node`, or `node.end` if
/// there is none.
///
/// This is a plain byte search; a `{` inside a comment or string would also
/// be found.
fn type_header_end(src: &str, node: Span) -> usize {
    let bytes = src.as_bytes();
    (node.start..node.end)
        .find(|&i| bytes[i] == b'{')
        .unwrap_or(node.end)
}
/// Finds `from` in a class-header fragment after blanking out comments.
///
/// Both kinds of PowerShell comment are replaced by one space per byte:
/// - `<# ... #>` block comments (an unterminated one runs to the end);
/// - `# ...` line comments (up to, not including, the newline).
///
/// Offsets into the masked text therefore equal offsets into `text`, and
/// `base` still maps them back into the source.
fn find_type_name_spans_in_header(text: &str, base: usize, from: &str, out: &mut Vec<Span>) {
    if !text.contains('#') {
        // Without a `#` there is no comment of either kind to hide.
        find_type_name_spans(text, base, from, out);
        return;
    }
    let bytes = text.as_bytes();
    let len = bytes.len();
    // Exclusive end of the comment starting at `at`, or `None` when no
    // comment starts there.
    let comment_end = |at: usize| -> Option<usize> {
        if bytes[at..].starts_with(b"<#") {
            let body = at + 2;
            let close = bytes[body..].windows(2).position(|w| w == b"#>");
            Some(close.map_or(len, |p| body + p + 2))
        } else if bytes[at] == b'#' {
            let newline = bytes[at..].iter().position(|&c| c == b'\n');
            Some(newline.map_or(len, |p| at + p))
        } else {
            None
        }
    };
    let mut masked = String::with_capacity(len);
    let mut pos = 0;
    while pos < len {
        if let Some(stop) = comment_end(pos) {
            (pos..stop).for_each(|_| masked.push(' '));
            pos = stop;
        } else {
            // Copy one whole character so `masked` stays valid UTF-8.
            let next = (pos + utf8_len(bytes[pos])).min(len);
            masked.push_str(&text[pos..next]);
            pos = next;
        }
    }
    find_type_name_spans(&masked, base, from, out);
}
/// Collects spans of `from` within the byte range `start..end` of `src`.
///
/// If `start` falls inside a multi-byte character, scanning begins at the
/// next character boundary. Spans are offset from that boundary, so they
/// line up with `src` instead of being shifted by the skipped continuation
/// bytes. An empty range, or one with no boundary inside it, yields nothing.
fn find_in_byte_range(src: &str, start: usize, end: usize, from: &str, out: &mut Vec<Span>) {
    let Some(base) = (start..end).find(|&i| src.is_char_boundary(i)) else {
        return;
    };
    // `base` is a boundary, so the returned slice begins exactly at `base`.
    find_type_name_spans(safe_inner(src, base, end), base, from, out);
}
/// Collects spans of `from` strictly inside `span`, excluding its first and
/// last byte (the enclosing `[`/`]` or quote characters).
fn find_in_interior(src: &str, span: Span, from: &str, out: &mut Vec<Span>) {
    if span.end > span.start + 1 {
        find_in_byte_range(src, span.start + 1, span.end - 1, from, out);
    }
}
/// True for a single- or double-quoted string literal that has no
/// interpolated parts and whose span can hold both quote characters.
fn is_plain_quoted(node: &Node) -> bool {
    let NodeKind::StringLiteral { kind, parts, .. } = &node.kind else {
        return false;
    };
    let simple_quotes = matches!(kind, StringKind::Single | StringKind::Double);
    parts.is_empty() && simple_quotes && node.span.end >= node.span.start + 2
}
/// Returns `src[start..end]`, shrunk inward to character boundaries, without
/// panicking.
///
/// `start` is moved forward and `end` backward to the nearest boundaries.
/// An inverted or out-of-range request, or a range that contains no
/// character boundary at all, yields `""`.
fn safe_inner(src: &str, start: usize, end: usize) -> &str {
    if start > end || end > src.len() {
        return "";
    }
    // First boundary at or after `start`. If the range has none (for example
    // both ends fall inside one multi-byte character), there is nothing to
    // return. Slicing `src[end..end]` there would panic.
    let Some(lo) = (start..end).find(|&i| src.is_char_boundary(i)) else {
        return "";
    };
    // Last boundary in `lo..=end`; `lo` itself always qualifies.
    let hi = (lo..=end)
        .rev()
        .find(|&i| src.is_char_boundary(i))
        .unwrap_or(lo);
    &src[lo..hi]
}
/// Byte length of the UTF-8 sequence introduced by lead byte `b`.
///
/// ASCII maps to 1, `0xC0..=0xDF` to 2 and `0xE0..=0xEF` to 3. Every other
/// byte, including continuation bytes `0x80..=0xBF`, maps to 4.
fn utf8_len(b: u8) -> usize {
    if b < 0x80 {
        1
    } else if (0xC0..0xE0).contains(&b) {
        2
    } else if (0xE0..0xF0).contains(&b) {
        3
    } else {
        4
    }
}
/// Returns the byte index of the `]` that closes the first `[` group in `s`.
///
/// Nesting is tracked by a simple counter. The first `]` that brings it back
/// to zero wins. Unbalanced input yields `None`.
fn matching_bracket(s: &str) -> Option<usize> {
    let mut open = 0i32;
    s.bytes().position(|b| {
        match b {
            b'[' => open += 1,
            b']' => {
                open -= 1;
                return open == 0;
            }
            _ => {}
        }
        false
    })
}
/// Collects spans of `member` in PowerShell static accesses on `type_name`.
///
/// Matches `[T]::member` and `[T]::member(...)` where the receiver is a type
/// literal whose text satisfies `type_text_matches(T, type_name)`. The span
/// pushed is that of the member name following the receiver.
fn ps_static_member_refs(
    scope: &Node,
    src: &str,
    type_name: &str,
    member: &str,
    out: &mut Vec<Span>,
) {
    scope.walk(&mut |n| {
        // Only static accesses are candidates.
        let (receiver, accessed) = match &n.kind {
            NodeKind::MemberAccess {
                target,
                member: accessed,
                is_static: true,
            } => (target, accessed),
            NodeKind::InvokeMember {
                target,
                member: accessed,
                is_static: true,
                ..
            } => (target, accessed),
            _ => return,
        };
        if !accessed.eq_ignore_ascii_case(member) {
            return;
        }
        // The receiver must be a bracketed type literal.
        if !matches!(receiver.kind, NodeKind::TypeExpression(_)) {
            return;
        }
        let (start, end) = (receiver.span.start, receiver.span.end);
        if end <= start + 1 {
            return;
        }
        if type_text_matches(safe_inner(src, start + 1, end - 1), type_name) {
            out.extend(member_span_after(src, end, member));
        }
    });
}
/// Finds the value passed to the command parameter `-param` (matched
/// exactly, ASCII case-insensitively).
///
/// An inline argument (`-param:value`) is preferred. Otherwise the element
/// right after the parameter is used, unless it is itself a parameter.
/// Returns `None` when no matching parameter carries a value.
fn named_parameter_value<'a>(elements: &'a [Node], param: &str) -> Option<&'a Node> {
    for (idx, el) in elements.iter().enumerate() {
        let NodeKind::CommandParameter { name, argument } = &el.kind else {
            continue;
        };
        if !name.eq_ignore_ascii_case(param) {
            continue;
        }
        if let Some(arg) = argument {
            return Some(arg);
        }
        match elements.get(idx + 1) {
            Some(next) if !matches!(next.kind, NodeKind::CommandParameter { .. }) => {
                return Some(next);
            }
            _ => {}
        }
    }
    None
}
/// Records where a command argument names the type `from`.
///
/// A bare word must equal `from` as a whole; its entire span is recorded.
/// A plain quoted string is searched inside its quotes. Any other argument
/// (such as an interpolated string) is ignored.
fn push_type_name_from_arg(src: &str, node: &Node, from: &str, out: &mut Vec<Span>) {
    if let NodeKind::BareWord(word) = &node.kind {
        if word.eq_ignore_ascii_case(from) {
            out.push(node.span);
        }
    } else if is_plain_quoted(node) {
        find_in_interior(src, node.span, from, out);
    }
}
/// Locates the argument that names the .NET type in a `New-Object` call.
///
/// The type can be given in three ways:
/// - by `-TypeName` (any prefix of it) with an inline value;
/// - by `-TypeName` followed by a separate value;
/// - by the first positional bare word or string literal.
///
/// Values belonging to other value-taking parameters (`-ArgumentList x`,
/// `-Property @{}`, `-ErrorAction Stop`, …) are skipped, so they are not
/// mistaken for the positional type name. A call using `-ComObject` yields
/// `None`, because its argument is a COM ProgID rather than a .NET type.
fn new_object_type_arg(elements: &[Node]) -> Option<&Node> {
    // Parameters that take no value: `-Strict`, plus the common `-Verbose`
    // and `-Debug` switches and their `-vb`/`-db` aliases.
    const SWITCHES: [&str; 5] = ["strict", "verbose", "debug", "vb", "db"];
    let mut i = 0;
    while i < elements.len() {
        match &elements[i].kind {
            NodeKind::CommandParameter { name, argument } => {
                if is_prefix_ignore_ascii_case(name, "comobject") {
                    return None;
                }
                let is_type_name = is_prefix_ignore_ascii_case(name, "typename");
                if let Some(arg) = argument {
                    if is_type_name {
                        return Some(arg);
                    }
                } else if !SWITCHES
                    .iter()
                    .any(|s| is_prefix_ignore_ascii_case(name, s))
                {
                    // `-Param value`: the following non-parameter element is
                    // this parameter's value.
                    let value = elements
                        .get(i + 1)
                        .filter(|next| !matches!(next.kind, NodeKind::CommandParameter { .. }));
                    if let Some(value) = value {
                        if is_type_name {
                            return Some(value);
                        }
                        // Consumed by another parameter; not the positional type.
                        i += 1;
                    }
                }
            }
            NodeKind::BareWord(_) | NodeKind::StringLiteral { .. } => return Some(&elements[i]),
            _ => {}
        }
        i += 1;
    }
    None
}
/// True when `prefix` is a non-empty, ASCII case-insensitive prefix of
/// `full`. This is how PowerShell accepts abbreviated parameter names.
fn is_prefix_ignore_ascii_case(prefix: &str, full: &str) -> bool {
    if prefix.is_empty() {
        return false;
    }
    full.as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}
/// True when the trimmed type text `inner` equals `target` as a whole, or
/// its last dot-separated segment equals `target`. Comparison is ASCII
/// case-insensitive.
fn type_text_matches(inner: &str, target: &str) -> bool {
    let written = inner.trim();
    let last_segment = written
        .rfind('.')
        .map_or(written, |dot| &written[dot + 1..]);
    [written, last_segment]
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(target))
}
/// Pushes a span (offset by `base`) for every identifier run in `text` whose
/// last dot-separated segment equals `from` (ASCII case-insensitively).
///
/// Identifier characters are ASCII alphanumerics, `_`, and any non-ASCII
/// byte. A run may contain dots (`System.Collections.Logger`). Only the
/// segment after the final dot is compared and reported.
fn find_type_name_spans(text: &str, base: usize, from: &str, out: &mut Vec<Span>) {
    let ident = |c: u8| c == b'_' || c >= 0x80 || c.is_ascii_alphanumeric();
    let bytes = text.as_bytes();
    let mut cursor = 0;
    while cursor < bytes.len() {
        if !ident(bytes[cursor]) {
            cursor += 1;
            continue;
        }
        let run_start = cursor;
        cursor += bytes[cursor..]
            .iter()
            .take_while(|&&c| ident(c) || c == b'.')
            .count();
        let seg_start = text[run_start..cursor]
            .rfind('.')
            .map_or(run_start, |d| run_start + d + 1);
        if text[seg_start..cursor].eq_ignore_ascii_case(from) {
            out.push(Span::new(base + seg_start, base + cursor));
        }
    }
}
/// Returns the span of the member name that follows a static-access receiver.
///
/// Starting at `offset` (normally the end of the `[Type]` receiver), this
/// skips `:`, `.` and ASCII whitespace, then reads one identifier. It returns
/// that identifier's span if it equals `member` (ASCII case-insensitively).
/// Identifier characters are ASCII alphanumerics, `_`, and non-ASCII bytes,
/// as in `find_type_name_spans`. So a non-ASCII member name is found, and an
/// ASCII prefix of a longer identifier (`Foo` in `Foo文`) is not mistaken for
/// `Foo`.
fn member_span_after(src: &str, offset: usize, member: &str) -> Option<Span> {
    let b = src.as_bytes();
    let n = b.len();
    let mut i = offset.min(n);
    while i < n && (b[i] == b':' || b[i] == b'.' || b[i].is_ascii_whitespace()) {
        i += 1;
    }
    let start = i;
    while i < n && (b[i] == b'_' || b[i].is_ascii_alphanumeric() || b[i] >= 0x80) {
        i += 1;
    }
    // `get` returns `None` instead of panicking if `offset` split a character.
    let name = src.get(start..i)?;
    if !name.is_empty() && name.eq_ignore_ascii_case(member) {
        Some(Span::new(start, i))
    } else {
        None
    }
}
/// Calls `f` for every C# unit attached to a command under `scope`.
///
/// The second argument is the owning type name. It is taken from the
/// command's `-Name` value when the C# was supplied via the
/// `memberdefinition` parameter, and is `None` otherwise. Commands whose C#
/// payload does not produce a unit are skipped.
fn each_csharp_unit(
    scope: &Node,
    src: &str,
    f: &mut impl FnMut(&crate::v2::csharp::ast::CsUnit, Option<&str>),
) {
    scope.walk(&mut |n| {
        let NodeKind::Command {
            elements, csharp, ..
        } = &n.kind
        else {
            return;
        };
        let Some(cs_node) = csharp else { return };
        let NodeKind::CSharpMemberDef(def) = &cs_node.kind else {
            return;
        };
        let owner = def
            .parameter
            .eq_ignore_ascii_case("memberdefinition")
            .then(|| {
                named_parameter_value(elements, "name").and_then(|arg| arg_name_text(arg, src))
            })
            .flatten();
        if let Some(unit) = csharp_unit(cs_node, src) {
            f(&unit, owner.as_deref());
        }
    });
}
/// Returns the literal text of a bare-word or plain quoted argument, with the
/// quotes removed. Any other node kind yields `None`.
fn arg_name_text(node: &Node, src: &str) -> Option<String> {
    if let NodeKind::BareWord(word) = &node.kind {
        return Some(word.clone());
    }
    if !is_plain_quoted(node) {
        return None;
    }
    let unquoted = safe_inner(src, node.span.start + 1, node.span.end - 1);
    Some(unquoted.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v2::{apply_edits, parse};

    // Shared fixture: a C# class declared through `Add-Type -TypeDefinition`,
    // then used from PowerShell via a static call, `New-Object`, and a cast.
    const PS: &str = "Add-Type -TypeDefinition @'\npublic class Win32 {\n [DllImport(\"user32.dll\")]\n public static extern int MessageBox(IntPtr h, string t, string c, uint ty);\n}\n'@\n[Win32]::MessageBox(0, 'hi', 'title', 0)\n$inst = New-Object Win32\n[Win32]$casted = $inst\n";

    // Renaming the type rewrites the C# declaration and every PowerShell site
    // in the fixture, leaving no occurrence of the old name.
    #[test]
    fn rename_type_updates_csharp_and_all_powershell_sites() {
        let out = parse(PS);
        let edits = rename_type(&out.script, PS, "Win32", "NativeApi");
        let result = apply_edits(PS, &edits).unwrap();
        assert!(result.contains("public class NativeApi {"));
        assert!(result.contains("[NativeApi]::MessageBox"));
        assert!(result.contains("New-Object NativeApi"));
        assert!(result.contains("[NativeApi]$casted"));
        assert!(!result.contains("Win32"));
    }

    // Array, generic-argument and dotted-last-segment positions are renamed.
    // A longer name that merely starts with `Logger` is not.
    #[test]
    fn rename_type_handles_array_generic_and_multiple_occurrences() {
        let src = "[Logger[]]::new()\n\
                   [System.Collections.Generic.List[Logger]]$x = $null\n\
                   [System.Collections.Generic.Dictionary[Logger,Logger]]::new()\n\
                   [LoggerHelper]::Init()\n\
                   [My.Logger]::X()\n";
        let out = parse(src);
        let edits = rename_type(&out.script, src, "Logger", "Tracer");
        let result = apply_edits(src, &edits).unwrap();
        assert!(result.contains("[Tracer[]]::new()"));
        assert!(result.contains("List[Tracer]"));
        assert!(result.contains("Dictionary[Tracer,Tracer]"));
        assert!(result.contains("[My.Tracer]::X()"));
        assert!(result.contains("[LoggerHelper]::Init()"));
        assert!(!result.contains("[Logger]"));
        assert!(!result.contains("[Logger["));
    }

    // On the C# side, constructor declarations and `new` expressions follow
    // the class rename.
    #[test]
    fn rename_type_reaches_constructors_and_new_expressions() {
        let src = "Add-Type -TypeDefinition @\"\n\
                   public class Logger {\n\
                   public Logger() { }\n\
                   public Logger(int n) { }\n\
                   void M() { var a = new Logger(); var b = new Logger(5); }\n\
                   }\n\
                   \"@\n[Logger]::M()\n";
        let out = parse(src);
        let result = apply_edits(src, &rename_type(&out.script, src, "Logger", "Tracer")).unwrap();
        assert!(result.contains("public Tracer()"));
        assert!(result.contains("public Tracer(int n)"));
        assert!(result.contains("new Tracer()"));
        assert!(result.contains("new Tracer(5)"));
        assert!(result.contains("class Tracer"));
        assert!(!result.contains("Logger"), "no old name should remain");
    }

    // In a `-MemberDefinition` snippet, top-level members are owned by the
    // `-Name` type. Asking for an unrelated owner produces no edits.
    #[test]
    fn rename_member_reaches_member_definition_csharp_side() {
        let src = "Add-Type -MemberDefinition @\"\n\
                   [DllImport(\"user32.dll\")]\n\
                   public static extern int MessageBox(int h, string m, string c, int t);\n\
                   \"@ -Name Win32 -Namespace Native\n\
                   [Native.Win32]::MessageBox(0, 'm', 'c', 0)\n";
        let out = parse(src);
        let result = apply_edits(
            src,
            &rename_member(&out.script, src, "Win32", "MessageBox", "Show"),
        )
        .unwrap();
        assert!(result.contains("extern int Show("), "C# declaration");
        assert!(
            result.contains("[Native.Win32]::Show("),
            "PowerShell call site"
        );
        let none = rename_member(&out.script, src, "Other", "MessageBox", "Show");
        assert!(none.is_empty());
    }

    // `Add-Type -Name` values are renamed. `-Namespace` is untouched, and so
    // is the `-Name` argument of an unrelated command.
    #[test]
    fn rename_type_rewrites_add_type_name_argument() {
        let src = "Add-Type -MemberDefinition @\"\npublic static int F() { return 0; }\n\"@ -Name Win32 -Namespace Native\n[Native.Win32]::F()\n";
        let out = parse(src);
        let result = apply_edits(src, &rename_type(&out.script, src, "Win32", "WinApi")).unwrap();
        assert!(result.contains("-Name WinApi"));
        assert!(result.contains("[Native.WinApi]::"));
        assert!(result.contains("-Namespace Native"), "namespace untouched");
        let other = "Get-Process -Name Win32\nAdd-Type -MemberDefinition $s -Name Win32\n";
        let o2 = parse(other);
        let r2 = apply_edits(other, &rename_type(&o2.script, other, "Win32", "WinApi")).unwrap();
        assert!(r2.contains("Get-Process -Name Win32"));
    }

    // The parser records the span of just the class/enum name.
    #[test]
    fn class_and_enum_name_span_is_populated() {
        let out = parse("class Logger { }\nenum Color { Red }\n");
        let mut found = Vec::new();
        out.script.walk(&mut |n| match &n.kind {
            NodeKind::ClassDefinition {
                name, name_span, ..
            }
            | NodeKind::EnumDefinition {
                name, name_span, ..
            } => {
                found.push((
                    name.clone(),
                    name_span
                        .slice("class Logger { }\nenum Color { Red }\n")
                        .to_string(),
                ));
            }
            _ => {}
        });
        assert_eq!(
            found,
            vec![
                ("Logger".into(), "Logger".into()),
                ("Color".into(), "Color".into())
            ]
        );
    }

    // PowerShell class/enum declarations, base types in a class header, and
    // references are renamed. A block comment before the class name is kept.
    #[test]
    fn rename_type_rewrites_powershell_class_declarations() {
        let src = "class Base { }\nclass Logger : Base { [int]$X }\n[Logger]::M()\nenum Color { Red }\n[Color]::Red\n";
        let out = parse(src);
        let renamed = apply_edits(src, &rename_type(&out.script, src, "Logger", "Tracer")).unwrap();
        assert!(renamed.contains("class Tracer : Base"), "declaration");
        assert!(renamed.contains("[Tracer]::M()"), "reference");
        let base = apply_edits(src, &rename_type(&out.script, src, "Base", "Root")).unwrap();
        assert!(base.contains("class Root"), "base declaration");
        assert!(base.contains(": Root"), "base reference in header");
        let en = apply_edits(src, &rename_type(&out.script, src, "Color", "Hue")).unwrap();
        assert!(en.contains("enum Hue") && en.contains("[Hue]::Red"));
        let cs = "class <# Logger #> Logger { }\n";
        let oc = parse(cs);
        let rc = apply_edits(cs, &rename_type(&oc.script, cs, "Logger", "Tracer")).unwrap();
        assert!(rc.contains("<# Logger #> Tracer"));
    }

    // Smoke test: the rename entry points must not panic on multi-byte
    // characters next to brackets, quotes and here-string delimiters.
    #[test]
    fn rename_paths_do_not_panic_on_multibyte_adjacent_to_delimiters() {
        for src in [
            "[文]::M()\n",
            "[a文]$x = 1\n",
            "New-Object '文'\n",
            "return 0; -d \"@\nq -e \"@$x]=w\n",
            "Add-Type -MemberDefinition @\"\nint F();\n\"@ -Name 文\n",
        ] {
            let out = parse(src);
            let _ = rename_type(&out.script, src, "文", "X");
            let _ = rename_type(&out.script, src, "Logger", "X");
            let _ = rename_member(&out.script, src, "文", "F", "G");
        }
    }

    // Non-ASCII type names are matched as whole names, not as prefixes.
    #[test]
    fn rename_type_handles_non_ascii_identifiers() {
        let src = "class \u{141}ogger { }\n[\u{141}ogger]::M()\nclass \u{141}oggerHelper { }\n";
        let out = parse(src);
        let renamed = apply_edits(
            src,
            &rename_type(&out.script, src, "\u{141}ogger", "Tracer"),
        )
        .unwrap();
        assert!(renamed.contains("class Tracer"));
        assert!(renamed.contains("[Tracer]::M()"));
        assert!(
            renamed.contains("\u{141}oggerHelper"),
            "prefix-only name untouched"
        );
    }

    // Quoted `New-Object` type names are renamed. Interpolated strings and
    // prefix-only names are left alone.
    #[test]
    fn rename_type_handles_quoted_new_object_names() {
        let src = "New-Object -TypeName 'Logger'\n\
                   New-Object \"Logger\"\n\
                   New-Object 'My.Logger'\n\
                   New-Object \"$prefix.Logger\"\n\
                   New-Object 'LoggerHelper'\n";
        let out = parse(src);
        let edits = rename_type(&out.script, src, "Logger", "Tracer");
        let result = apply_edits(src, &edits).unwrap();
        assert!(result.contains("-TypeName 'Tracer'"));
        assert!(result.contains("New-Object \"Tracer\""));
        assert!(result.contains("'My.Tracer'"));
        assert!(
            result.contains("\"$prefix.Logger\""),
            "interpolated untouched"
        );
        assert!(result.contains("'LoggerHelper'"), "prefix-only untouched");
    }

    // A member rename updates both the C# declaration and the
    // `[Type]::Member` call site, but not the type name itself.
    #[test]
    fn rename_member_updates_csharp_decl_and_static_call_site() {
        let out = parse(PS);
        let edits = rename_member(&out.script, PS, "Win32", "MessageBox", "ShowMessage");
        let result = apply_edits(PS, &edits).unwrap();
        assert!(result.contains("extern int ShowMessage("));
        assert!(result.contains("[Win32]::ShowMessage(0, 'hi', 'title', 0)"));
        assert!(result.contains("public class Win32"));
        assert!(!result.contains("MessageBox"));
    }

    // Only references bound to `C.Length` change. `s.Length` on a string
    // parameter keeps its name.
    #[test]
    fn rename_member_leaves_unrelated_receiver_alone() {
        let src = "Add-Type -TypeDefinition @'\npublic class C {\n public int Length;\n public int Measure(string s) { return s.Length + this.Length; }\n}\n'@\n";
        let out = parse(src);
        let edits = rename_member(&out.script, src, "C", "Length", "Size");
        let result = apply_edits(src, &edits).unwrap();
        assert!(result.contains("public int Size;"));
        assert!(result.contains("return s.Length + this.Size;"));
    }
}