Skip to main content

apple_bindgen/
objc2.rs

1//! Post-process bindgen output to use objc2 instead of objc 0.2.
2//!
3//! bindgen generates ObjC bindings using `objc` 0.2 syntax.
4//! This module transforms the output to `objc2` 0.6 syntax.
5
6/// Apply all objc → objc2 transformations to generated binding code.
7pub fn migrate(content: &str) -> String {
8    let mut result = content.to_string();
9
10    // 1. Replace `use objc::{...};` import line (items may appear in any order)
11    result = replace_objc_import(&result);
12
13    // 1b. Convert `pub type BOOL = bool;` to a newtype struct.
14    // objc2 intentionally does NOT implement Encode/RefEncode for `bool`.
15    // A newtype struct gets Encode impls from `add_struct_encode_impls` below.
16    result = result.replace(
17        "pub type BOOL = bool;",
18        "#[repr(transparent)]\n#[derive(Debug, Copy, Clone, PartialEq, Eq)]\npub struct BOOL(pub bool);",
19    );
20
21    // 2. Replace type references
22    result = result.replace("objc::runtime::Object", "objc2::runtime::AnyObject");
23    result = result.replace("objc::runtime::", "objc2::runtime::");
24
25    // 3. Transform Message impls (before renaming objc::Message in where clauses)
26    result = transform_message_impls(&result);
27
28    // 4. Replace remaining objc::Message references (in where clauses etc.)
29    result = result.replace("objc::Message", "objc2::Message");
30
31    // 5. Replace class!(Name) → objc2::runtime::AnyClass::get(c"Name").unwrap()
32    result = replace_class_macro(&result);
33
34    // 6. Transform msg_send! invocations
35    result = transform_msg_send(&result);
36
37    // 7. Add RefEncode + Encode impls for all struct types
38    result = add_struct_encode_impls(&result);
39
40    result
41}
42
/// Swap every `use objc::{...};` line for an `objc2::msg_send` import.
///
/// bindgen does not guarantee the order of the imported items, so a line is
/// recognised purely by its shape: after trimming it starts with `use objc::{`
/// and ends with `};`. The replacement is emitted without the original line's
/// indentation. Every output line is terminated with `\n`, including the last.
fn replace_objc_import(content: &str) -> String {
    const REPLACEMENT: &str = "#[allow(unused_imports)]\nuse objc2::msg_send;";

    let mut out = String::with_capacity(content.len());
    for line in content.lines() {
        let stripped = line.trim();
        let is_objc_import = stripped.starts_with("use objc::{") && stripped.ends_with("};");
        out.push_str(if is_objc_import { REPLACEMENT } else { line });
        out.push('\n');
    }
    out
}
60
/// Rewrite `unsafe impl objc::Message for Foo {}` lines to target `objc2::Message`.
///
/// Only lines that, once trimmed, consist of exactly that prefix, a type name
/// and an empty ` {}` body are rewritten; the rewritten line is emitted without
/// its original indentation. All other lines pass through unchanged. Every
/// output line ends with `\n`.
///
/// The matching RefEncode + Encode impls are produced separately by
/// `add_struct_encode_impls`.
fn transform_message_impls(content: &str) -> String {
    const OLD_PREFIX: &str = "unsafe impl objc::Message for ";
    const EMPTY_BODY: &str = " {}";

    let mut out = String::with_capacity(content.len());
    for line in content.lines() {
        let type_name = line
            .trim()
            .strip_prefix(OLD_PREFIX)
            .and_then(|tail| tail.strip_suffix(EMPTY_BODY));
        match type_name {
            Some(name) => {
                out.push_str("unsafe impl objc2::Message for ");
                out.push_str(name);
                out.push_str(" {}");
            }
            None => out.push_str(line),
        }
        out.push('\n');
    }
    out
}
79
/// Append `RefEncode` + `Encode` impls for every non-generic `pub struct` /
/// `pub union` declared in `content`.
///
/// This satisfies objc2's msg_send! trait bounds for both pointer targets
/// (`*mut T` requires `T: RefEncode`) and direct arguments (`T: Encode`).
///
/// The encoding is chosen per declaration line:
/// - ObjC wrapper types (`pub struct NSObject(pub id);`) use `Encoding::Object`.
/// - Other structs (e.g. CGRect) use `Encoding::Struct(name, &[])`, an opaque
///   struct encoding. objc2 treats opaque structs as compatible with any
///   struct of the same name, which passes runtime verification without
///   needing exact field encodings.
/// - Unions use `Encoding::Union(name, &[])` so the emitted encoding describes
///   a union rather than a struct. NOTE(review): presumably objc2 applies the
///   same opaque-container matching to unions — confirm against objc2-encode.
///
/// Declarations whose line contains `<` are skipped. If nothing is found the
/// input is returned unchanged; otherwise the impls are appended after a
/// single separating newline.
fn add_struct_encode_impls(content: &str) -> String {
    // Identical for every type, so build it once outside the loop.
    let ref_encoding =
        "objc2::encode::Encoding::Pointer(&<Self as objc2::encode::Encode>::ENCODING)";

    let mut impls = String::new();
    for line in content.lines() {
        let trimmed = line.trim();
        let (rest, is_union) = match trimmed.strip_prefix("pub struct ") {
            Some(rest) => (rest, false),
            None => match trimmed.strip_prefix("pub union ") {
                Some(rest) => (rest, true),
                None => continue,
            },
        };
        if rest.contains('<') {
            continue;
        }

        // The name runs up to the first delimiter that can follow it.
        let name_end = rest
            .find(|c: char| matches!(c, '{' | '(' | ';' | ' '))
            .unwrap_or(rest.len());
        let name = rest[..name_end].trim();
        if name.is_empty() {
            continue;
        }

        // ObjC wrapper: a tuple struct whose only field is `pub id`.
        let is_objc_wrapper = rest.ends_with("(pub id);")
            && rest[..rest.find('(').unwrap_or(rest.len())].trim() == name;

        let encoding = if is_objc_wrapper {
            "objc2::encode::Encoding::Object".to_string()
        } else if is_union {
            // Opaque union encoding (name only, no members).
            format!("objc2::encode::Encoding::Union(\"{name}\", &[])")
        } else {
            // Opaque struct encoding (name only, no fields).
            format!("objc2::encode::Encoding::Struct(\"{name}\", &[])")
        };

        impls.push_str(&format!(
            "unsafe impl objc2::encode::RefEncode for {name} {{\n    \
             const ENCODING_REF: objc2::encode::Encoding = {ref_encoding};\n\
             }}\n\
             unsafe impl objc2::encode::Encode for {name} {{\n    \
             const ENCODING: objc2::encode::Encoding = {encoding};\n\
             }}\n"
        ));
    }

    if impls.is_empty() {
        return content.to_string();
    }

    let mut result = String::with_capacity(content.len() + 1 + impls.len());
    result.push_str(content);
    result.push('\n');
    result.push_str(&impls);
    result
}
159
/// Replace `class!(ClassName)` / `class ! (ClassName)` with
/// `objc2::runtime::AnyClass::get(c"ClassName").unwrap()`.
///
/// `class` is only treated as the macro name when it is not the tail of a
/// longer identifier, so invocations such as `subclass!(Foo)` or
/// `my_class!(Foo)` are left untouched. Spaces (and only spaces) are allowed
/// around the `!`. An invocation without a closing `)` is left as-is.
fn replace_class_macro(content: &str) -> String {
    const KEYWORD: &str = "class";

    let mut result = String::with_capacity(content.len());
    // `content[..copied]` has already been emitted (verbatim or rewritten).
    let mut copied = 0;
    // Where the next search for `class` starts; always a char boundary.
    let mut cursor = 0;

    while let Some(offset) = content[cursor..].find(KEYWORD) {
        let start = cursor + offset;
        let after = start + KEYWORD.len();
        cursor = after;

        // Reject matches that are the suffix of another identifier.
        let continues_identifier = content[..start]
            .chars()
            .next_back()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');
        if continues_identifier {
            continue;
        }

        if let Some((name, consumed)) = parse_class_macro_invocation(&content[after..]) {
            result.push_str(&content[copied..start]);
            result.push_str(&format!(
                "objc2::runtime::AnyClass::get(c\"{name}\").unwrap()"
            ));
            cursor = after + consumed;
            copied = cursor;
        }
    }

    result.push_str(&content[copied..]);
    result
}

/// Parse the `!( Name )` part that follows the `class` keyword, returning the
/// trimmed class name and the number of bytes consumed through the `)`.
fn parse_class_macro_invocation(after_keyword: &str) -> Option<(&str, usize)> {
    let tail = after_keyword.trim_start_matches(' ').strip_prefix('!')?;
    let tail = tail.trim_start_matches(' ').strip_prefix('(')?;
    let close = tail.find(')')?;
    let consumed = after_keyword.len() - tail.len() + close + 1;
    Some((tail[..close].trim(), consumed))
}
196
197/// Transform msg_send! invocations:
198/// - Wrap receiver with `&*` for objc2's MessageReceiver trait
199/// - Add commas between multi-arg selector parts
200fn transform_msg_send(content: &str) -> String {
201    let mut result = String::with_capacity(content.len());
202    let bytes = content.as_bytes();
203    let mut i = 0;
204
205    while i < bytes.len() {
206        if i + 8 <= bytes.len() && &content[i..i + 8] == "msg_send" {
207            let mut j = i + 8;
208            while j < bytes.len() && (bytes[j] == b' ' || bytes[j] == b'!') {
209                j += 1;
210            }
211            if j < bytes.len() && (bytes[j] == b'(' || bytes[j] == b'[') {
212                let open = bytes[j];
213                let close = if open == b'(' { b')' } else { b']' };
214
215                let mut depth = 1i32;
216                let mut end = j + 1;
217                while end < bytes.len() && depth > 0 {
218                    if bytes[end] == open {
219                        depth += 1;
220                    }
221                    if bytes[end] == close {
222                        depth -= 1;
223                    }
224                    if depth > 0 {
225                        end += 1;
226                    }
227                }
228
229                let body = &content[j + 1..end];
230
231                let mut d = 0i32;
232                let mut recv_end = None;
233                for (k, &b) in body.as_bytes().iter().enumerate() {
234                    match b {
235                        b'(' | b'[' => d += 1,
236                        b')' | b']' => d -= 1,
237                        b',' if d == 0 => {
238                            recv_end = Some(k);
239                            break;
240                        }
241                        _ => {}
242                    }
243                }
244
245                if let Some(re) = recv_end {
246                    let receiver = body[..re].trim();
247                    let selector = body[re + 1..].trim();
248                    let new_receiver = format!("&*{receiver}");
249                    let new_selector = insert_selector_commas(selector);
250
251                    result.push_str("msg_send!");
252                    result.push(open as char);
253                    result.push_str(&new_receiver);
254                    result.push_str(", ");
255                    result.push_str(&new_selector);
256                    result.push(close as char);
257                } else {
258                    result.push_str("msg_send!");
259                    result.push(open as char);
260                    result.push_str(body);
261                    result.push(close as char);
262                }
263
264                i = end + 1;
265                continue;
266            }
267        }
268        result.push(bytes[i] as char);
269        i += 1;
270    }
271    result
272}
273
/// Separate the keyword parts of a multi-argument selector with commas.
///
/// `getValue : value size : size` → `getValue : value, size : size`
///
/// The text is split on `" : "`. Selectors with at most one such separator are
/// returned unchanged. For every piece between the first and the last, the
/// text after its final space is taken to be the next keyword and a `", "` is
/// inserted in front of it; a piece without a space is copied verbatim.
fn insert_selector_commas(selector: &str) -> String {
    const SEPARATOR: &str = " : ";

    let pieces: Vec<&str> = selector.split(SEPARATOR).collect();
    if pieces.len() <= 2 {
        return selector.to_owned();
    }

    let last_index = pieces.len() - 1;
    let mut out = String::from(pieces[0]);
    for middle in &pieces[1..last_index] {
        out.push_str(SEPARATOR);
        match middle.rsplit_once(' ') {
            Some((argument, next_keyword)) => {
                out.push_str(argument);
                out.push_str(", ");
                out.push_str(next_keyword);
            }
            None => out.push_str(middle),
        }
    }
    out.push_str(SEPARATOR);
    out.push_str(pieces[last_index]);
    out
}