/// Rewrites Rust source text written against the `objc` crate so that it
/// targets `objc2` instead, returning the migrated text.
///
/// The passes run in a fixed order, each consuming the previous pass's output:
///
/// 1. single-line `use objc::{...};` imports are replaced,
/// 2. the `BOOL` type alias becomes a transparent newtype,
/// 3. `objc::runtime::` paths are renamed (`Object` becomes `AnyObject`),
/// 4. `objc::Message` impls and remaining mentions are renamed,
/// 5. `class!(...)` invocations are expanded,
/// 6. `msg_send!` invocations are rewritten,
/// 7. `Encode`/`RefEncode` impls are appended for the declared structs/unions.
pub fn migrate(content: &str) -> String {
    const BOOL_ALIAS: &str = "pub type BOOL = bool;";
    const BOOL_NEWTYPE: &str =
        "#[repr(transparent)]\n#[derive(Debug, Copy, Clone, PartialEq, Eq)]\npub struct BOOL(pub bool);";

    let renamed = replace_objc_import(content)
        .replace(BOOL_ALIAS, BOOL_NEWTYPE)
        // The specific `Object` rename has to run before the generic prefix
        // rename below, which would otherwise leave nothing for it to match.
        .replace("objc::runtime::Object", "objc2::runtime::AnyObject")
        .replace("objc::runtime::", "objc2::runtime::");

    let with_message = transform_message_impls(&renamed).replace("objc::Message", "objc2::Message");
    let with_classes = replace_class_macro(&with_message);
    let with_sends = transform_msg_send(&with_classes);
    add_struct_encode_impls(&with_sends)
}
/// Replaces every single-line `use objc::{...};` import with an
/// `objc2::msg_send` import guarded by `#[allow(unused_imports)]`.
///
/// A line qualifies when, ignoring surrounding whitespace, it starts with
/// `use objc::{` and ends with `};`. The replacement is emitted without the
/// original indentation. All other lines are copied through unchanged, each
/// terminated by `\n` (so the output always ends with a newline unless the
/// input is empty).
fn replace_objc_import(content: &str) -> String {
    const NEW_IMPORT: &str = "#[allow(unused_imports)]\nuse objc2::msg_send;\n";

    content
        .lines()
        .fold(String::with_capacity(content.len()), |mut out, raw_line| {
            let stripped = raw_line.trim();
            let is_braced_import = stripped.starts_with("use objc::{") && stripped.ends_with("};");
            if is_braced_import {
                out.push_str(NEW_IMPORT);
            } else {
                out.push_str(raw_line);
                out.push('\n');
            }
            out
        })
}
/// Rewrites empty `unsafe impl objc::Message for T {}` lines to implement
/// `objc2::Message` instead.
///
/// A line is rewritten only when, after trimming, it has exactly that prefix
/// and ends with ` {}`; the rewritten line is emitted without its original
/// indentation. Every other line is copied verbatim. Each output line is
/// terminated by `\n`.
fn transform_message_impls(content: &str) -> String {
    const OLD_HEADER: &str = "unsafe impl objc::Message for ";
    const EMPTY_BODY: &str = " {}";

    let mut out = String::with_capacity(content.len());
    for raw_line in content.lines() {
        let implementor = raw_line
            .trim()
            .strip_prefix(OLD_HEADER)
            .and_then(|tail| tail.strip_suffix(EMPTY_BODY));
        match implementor {
            Some(type_name) => {
                out.push_str("unsafe impl objc2::Message for ");
                out.push_str(type_name);
                out.push_str(" {}");
            }
            None => out.push_str(raw_line),
        }
        out.push('\n');
    }
    out
}
/// Appends `objc2::encode::RefEncode` and `objc2::encode::Encode` impls for
/// every non-generic `pub struct` / `pub union` declared in `content`.
///
/// Declarations are recognised line by line: a trimmed line must start with
/// `pub struct ` or `pub union `, and lines containing `<` are skipped. The
/// type name runs up to the first `{`, `(`, `;` or space.
///
/// Types declared as `Name(pub id);` get `Encoding::Object`; all others get
/// `Encoding::Struct("Name", &[])`. When no declaration is found the input is
/// returned unchanged; otherwise the impls follow the input after a `\n`.
fn add_struct_encode_impls(content: &str) -> String {
    // Text following the `pub struct ` / `pub union ` keyword on a line, if any.
    fn declaration_tail(line: &str) -> Option<&str> {
        let line = line.trim();
        line.strip_prefix("pub struct ")
            .or_else(|| line.strip_prefix("pub union "))
    }

    let type_names: Vec<&str> = content
        .lines()
        .filter_map(declaration_tail)
        .filter(|tail| !tail.contains('<'))
        .map(|tail| {
            let cut = tail
                .find(|c: char| matches!(c, '{' | '(' | ';' | ' '))
                .unwrap_or(tail.len());
            tail[..cut].trim()
        })
        .filter(|name| !name.is_empty())
        .collect();

    if type_names.is_empty() {
        return content.to_string();
    }

    // Names of tuple structs whose declaration ends in `(pub id);`.
    let id_wrappers: std::collections::HashSet<&str> = content
        .lines()
        .filter_map(declaration_tail)
        .filter(|tail| tail.ends_with("(pub id);"))
        .map(|tail| tail[..tail.find('(').unwrap_or(tail.len())].trim())
        .collect();

    // Identical for every type, so it is built once outside the loop.
    let ref_encoding =
        "objc2::encode::Encoding::Pointer(&<Self as objc2::encode::Encode>::ENCODING)";

    let mut out = String::with_capacity(content.len() + 1);
    out.push_str(content);
    out.push('\n');
    for &name in &type_names {
        let encoding = if id_wrappers.contains(name) {
            String::from("objc2::encode::Encoding::Object")
        } else {
            format!("objc2::encode::Encoding::Struct(\"{name}\", &[])")
        };
        out.push_str(&format!(
            "unsafe impl objc2::encode::RefEncode for {name} {{\n const ENCODING_REF: objc2::encode::Encoding = {ref_encoding};\n}}\n"
        ));
        out.push_str(&format!(
            "unsafe impl objc2::encode::Encode for {name} {{\n const ENCODING: objc2::encode::Encoding = {encoding};\n}}\n"
        ));
    }
    out
}
/// Replaces `class!(Name)` invocations with
/// `objc2::runtime::AnyClass::get(c"Name").unwrap()`.
///
/// Spaces are tolerated between `class`, `!` and `(`, and the class name is
/// trimmed. The argument list ends at the first `)`.
///
/// `class` is only treated as the macro name when it is not the tail of a
/// longer identifier, so `subclass!(Foo)` or `my_class!(Foo)` are left
/// untouched. Occurrences that are not followed by `!(` ... `)` are also
/// copied through unchanged.
fn replace_class_macro(content: &str) -> String {
    const TOKEN: &str = "class";

    // Parses the text right after a `class` token. On success returns the
    // trimmed macro argument and the number of bytes of `tail` the invocation
    // occupies (up to and including the closing parenthesis).
    fn parse_invocation(tail: &str) -> Option<(&str, usize)> {
        let after_bang = tail.trim_start_matches(' ').strip_prefix('!')?;
        let args = after_bang.trim_start_matches(' ').strip_prefix('(')?;
        let close = args.find(')')?;
        let consumed = tail.len() - args.len() + close + 1;
        Some((args[..close].trim(), consumed))
    }

    let mut result = String::with_capacity(content.len());
    // Everything before `copied` has already been written to `result`.
    let mut copied = 0;
    // Where the next search for `class` starts.
    let mut search = 0;

    while let Some(offset) = content[search..].find(TOKEN) {
        let start = search + offset;
        let after = start + TOKEN.len();

        // A preceding identifier character means this is part of another
        // name (e.g. `subclass`), not the `class!` macro itself.
        let inside_identifier = matches!(
            content[..start].chars().next_back(),
            Some(prev) if prev.is_alphanumeric() || prev == '_'
        );
        let invocation = if inside_identifier {
            None
        } else {
            parse_invocation(&content[after..])
        };

        match invocation {
            Some((name, consumed)) => {
                result.push_str(&content[copied..start]);
                result.push_str("objc2::runtime::AnyClass::get(c\"");
                result.push_str(name);
                result.push_str("\").unwrap()");
                copied = after + consumed;
                search = copied;
            }
            None => search = after,
        }
    }

    result.push_str(&content[copied..]);
    result
}
/// Rewrites `msg_send` invocations for `objc2`.
///
/// For each `msg_send` followed (after any spaces / `!`) by `(` or `[`, the
/// matching closing delimiter is located by counting only delimiters of that
/// same kind. The body is then split at its first top-level comma:
///
/// * with a comma, the receiver is trimmed and prefixed with `&*`, and the
///   selector part is passed through `insert_selector_commas`;
/// * without a comma, the body is emitted as is.
///
/// The invocation is always re-emitted as `msg_send!` directly followed by
/// the original delimiter kind.
///
/// All other text, including non-ASCII characters, is copied through
/// verbatim. An invocation whose closing delimiter is missing is left
/// untouched rather than being closed artificially.
fn transform_msg_send(content: &str) -> String {
    const MACRO_NAME: &[u8] = b"msg_send";

    let bytes = content.as_bytes();
    let mut result = String::with_capacity(content.len());
    // Everything before `copied` has already been written to `result`.
    let mut copied = 0;
    let mut i = 0;

    while i < bytes.len() {
        // Compare bytes rather than slicing the `str`: `i` may sit in the
        // middle of a multi-byte character, where `&content[i..]` would panic.
        if !bytes[i..].starts_with(MACRO_NAME) {
            i += 1;
            continue;
        }

        let mut open_at = i + MACRO_NAME.len();
        while open_at < bytes.len() && matches!(bytes[open_at], b' ' | b'!') {
            open_at += 1;
        }
        let (open, close) = match bytes.get(open_at).copied() {
            Some(b'(') => (b'(', b')'),
            Some(b'[') => (b'[', b']'),
            _ => {
                i += 1;
                continue;
            }
        };

        // Locate the delimiter that balances `open`.
        let mut depth = 1usize;
        let mut close_at = None;
        for (offset, &b) in bytes[open_at + 1..].iter().enumerate() {
            if b == open {
                depth += 1;
            } else if b == close {
                depth -= 1;
                if depth == 0 {
                    close_at = Some(open_at + 1 + offset);
                    break;
                }
            }
        }
        let end = match close_at {
            Some(end) => end,
            None => {
                // Unterminated invocation: skip it instead of inventing a
                // closing delimiter at the end of the input.
                i += 1;
                continue;
            }
        };

        // Both bounds sit next to ASCII delimiters, so this slice is always
        // on character boundaries.
        let body = &content[open_at + 1..end];

        // First comma that is not nested inside parentheses or brackets.
        let mut nesting = 0i32;
        let mut comma_at = None;
        for (k, b) in body.bytes().enumerate() {
            match b {
                b'(' | b'[' => nesting += 1,
                b')' | b']' => nesting -= 1,
                b',' if nesting == 0 => {
                    comma_at = Some(k);
                    break;
                }
                _ => {}
            }
        }

        result.push_str(&content[copied..i]);
        result.push_str("msg_send!");
        result.push(char::from(open));
        match comma_at {
            Some(at) => {
                result.push_str("&*");
                result.push_str(body[..at].trim());
                result.push_str(", ");
                result.push_str(&insert_selector_commas(body[at + 1..].trim()));
            }
            None => result.push_str(body),
        }
        result.push(char::from(close));

        i = end + 1;
        copied = i;
    }

    result.push_str(&content[copied..]);
    result
}
/// Inserts the commas between the keyword arguments of a message selector
/// written in spaced-token form.
///
/// The selector is split on `" : "`. With fewer than two separators it is
/// returned unchanged. Otherwise every piece between the first and the last
/// is expected to look like `<argument> <next keyword>`: it is split at its
/// last space and re-joined as `<argument>, <next keyword>`. A middle piece
/// without any space is kept as is.
///
/// For example `setX : a y : b` becomes `setX : a, y : b`.
fn insert_selector_commas(selector: &str) -> String {
    const SEPARATOR: &str = " : ";

    let pieces: Vec<&str> = selector.split(SEPARATOR).collect();
    let (head, middle, tail) = match pieces.as_slice() {
        [head, middle @ .., tail] if !middle.is_empty() => (*head, middle, *tail),
        // Zero or one separator: there is nothing to put a comma between.
        _ => return selector.to_string(),
    };

    let mut out = String::with_capacity(selector.len() + middle.len());
    out.push_str(head);
    out.push_str(SEPARATOR);
    for segment in middle {
        match segment.rsplit_once(' ') {
            Some((argument, next_keyword)) => {
                out.push_str(argument);
                out.push_str(", ");
                out.push_str(next_keyword);
            }
            None => out.push_str(segment),
        }
        out.push_str(SEPARATOR);
    }
    out.push_str(tail);
    out
}