Skip to main content

aptos_sdk/codegen/
types.rs

//! Type mapping between Move and Rust types.
2
3use std::collections::HashMap;
4
/// Describes the Rust type that a Move type is rendered as in generated code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustType {
    /// Full Rust type path, for example "`Vec<u8>`" or "`AccountAddress`".
    pub path: String,
    /// `true` when the value must be BCS-serialized to be passed as an argument.
    pub needs_bcs: bool,
    /// `true` when the type should be taken by reference in argument position.
    pub is_ref: bool,
    /// Optional documentation attached to this type.
    pub doc: Option<String>,
}

impl RustType {
    /// Builds a by-value, undocumented type that requires BCS serialization.
    #[must_use]
    pub fn new(path: impl Into<String>) -> Self {
        // Identical to `primitive` except for the BCS flag.
        Self {
            needs_bcs: true,
            ..Self::primitive(path)
        }
    }

    /// Builds a by-value, undocumented type that does not require BCS serialization.
    #[must_use]
    pub fn primitive(path: impl Into<String>) -> Self {
        let path = path.into();
        Self {
            path,
            needs_bcs: false,
            is_ref: false,
            doc: None,
        }
    }

    /// Returns this type marked as a reference.
    #[must_use]
    pub fn reference(self) -> Self {
        Self {
            is_ref: true,
            ..self
        }
    }

    /// Returns this type with `doc` attached, replacing any previous documentation.
    #[must_use]
    pub fn with_doc(self, doc: impl Into<String>) -> Self {
        Self {
            doc: Some(doc.into()),
            ..self
        }
    }

    /// Renders the type for use in a function argument list.
    ///
    /// Reference types are prefixed with `&`; all others are rendered as the bare path.
    pub fn as_arg_type(&self) -> String {
        let prefix = if self.is_ref { "&" } else { "" };
        format!("{prefix}{}", self.path)
    }

    /// Renders the type for use as a return type.
    ///
    /// The reference flag is ignored, so this is always the bare path.
    pub fn as_return_type(&self) -> String {
        self.path.to_owned()
    }
}
69
70/// Maps Move types to Rust types.
71#[derive(Debug, Clone)]
72pub struct MoveTypeMapper {
73    /// Custom type mappings.
74    custom_mappings: HashMap<String, RustType>,
75}
76
77impl Default for MoveTypeMapper {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83impl MoveTypeMapper {
84    /// Creates a new type mapper with default mappings.
85    pub fn new() -> Self {
86        Self {
87            custom_mappings: HashMap::new(),
88        }
89    }
90
91    /// Adds a custom type mapping.
92    pub fn add_mapping(&mut self, move_type: impl Into<String>, rust_type: RustType) {
93        self.custom_mappings.insert(move_type.into(), rust_type);
94    }
95
96    /// Maps a Move type string to a Rust type.
97    pub fn map_type(&self, move_type: &str) -> RustType {
98        // Check custom mappings first
99        if let Some(rust_type) = self.custom_mappings.get(move_type) {
100            return rust_type.clone();
101        }
102
103        // Handle primitive types
104        match move_type {
105            "bool" => RustType::primitive("bool"),
106            "u8" => RustType::primitive("u8"),
107            "u16" => RustType::primitive("u16"),
108            "u32" => RustType::primitive("u32"),
109            "u64" => RustType::primitive("u64"),
110            "u128" => RustType::primitive("u128"),
111            "u256" => RustType::new("U256"),
112            "address" => RustType::new("AccountAddress"),
113            "signer" | "&signer" => RustType::new("AccountAddress")
114                .with_doc("Signer address (automatically set to sender)"),
115            _ => self.map_complex_type(move_type),
116        }
117    }
118
119    /// Maps complex Move types (vectors, structs, etc.)
120    fn map_complex_type(&self, move_type: &str) -> RustType {
121        // Handle vector types
122        if move_type.starts_with("vector<") && move_type.ends_with('>') {
123            let inner = &move_type[7..move_type.len() - 1];
124            let inner_type = self.map_type(inner);
125
126            // Special case: `vector<u8>` -> `Vec<u8>` (bytes)
127            if inner == "u8" {
128                return RustType::new("Vec<u8>").with_doc("Bytes");
129            }
130
131            return RustType::new(format!("Vec<{}>", inner_type.path));
132        }
133
134        // Handle Option types (0x1::option::Option<T>)
135        if move_type.contains("::option::Option<")
136            && let Some(start) = move_type.find("Option<")
137        {
138            let rest = &move_type[start + 7..];
139            if let Some(end) = rest.rfind('>') {
140                let inner = &rest[..end];
141                let inner_type = self.map_type(inner);
142                return RustType::new(format!("Option<{}>", inner_type.path));
143            }
144        }
145
146        // Handle String type
147        if move_type == "0x1::string::String" || move_type.ends_with("::string::String") {
148            return RustType::new("String");
149        }
150
151        // Handle Object types
152        if move_type.contains("::object::Object<") {
153            return RustType::new("AccountAddress").with_doc("Object address");
154        }
155
156        // Handle generic struct types (e.g., 0x1::coin::Coin<0x1::aptos_coin::AptosCoin>)
157        if move_type.contains("::") {
158            // Use rsplit to avoid collecting into Vec - we only need the last part
159            let part_count = move_type.matches("::").count() + 1;
160            if part_count >= 3 {
161                // Get the base struct name (without generics) using rsplit
162                if let Some(struct_name) = move_type.rsplit("::").next() {
163                    let base_name = struct_name.split('<').next().unwrap_or(struct_name);
164
165                    // Create a pascal case name
166                    let rust_name = to_pascal_case(base_name);
167                    return RustType::new(rust_name).with_doc(format!("Move type: {move_type}"));
168                }
169            }
170        }
171
172        // Default: use serde_json::Value for unknown types
173        RustType::new("serde_json::Value").with_doc(format!("Unknown Move type: {move_type}"))
174    }
175
176    /// Maps a Move type to a BCS argument encoding expression.
177    pub fn to_bcs_arg(&self, move_type: &str, var_name: &str) -> String {
178        let rust_type = self.map_type(move_type);
179
180        if !rust_type.needs_bcs {
181            // Primitives that don't need special handling
182            return format!("aptos_bcs::to_bytes(&{var_name}).unwrap()");
183        }
184
185        match move_type {
186            "address" => format!("aptos_bcs::to_bytes(&{var_name}).unwrap()"),
187            _ if move_type.starts_with("vector<u8>") => {
188                format!("aptos_bcs::to_bytes(&{var_name}).unwrap()")
189            }
190            _ if move_type.starts_with("vector<") => {
191                format!("aptos_bcs::to_bytes(&{var_name}).unwrap()")
192            }
193            "0x1::string::String" => format!("aptos_bcs::to_bytes(&{var_name}).unwrap()"),
194            _ if move_type.ends_with("::string::String") => {
195                format!("aptos_bcs::to_bytes(&{var_name}).unwrap()")
196            }
197            _ => format!("aptos_bcs::to_bytes(&{var_name}).unwrap()"),
198        }
199    }
200
201    /// Determines if a parameter should be excluded from the function signature.
202    /// (e.g., &signer is always the sender)
203    pub fn is_signer_param(&self, move_type: &str) -> bool {
204        move_type == "&signer" || move_type == "signer"
205    }
206}
207
/// Converts a `snake_case` or other string to `PascalCase`.
///
/// `_`, `-` and space act as word separators and are dropped. The first
/// character of every word is ASCII-uppercased; the remaining characters are
/// copied unchanged.
pub fn to_pascal_case(s: &str) -> String {
    s.split(|sep: char| sep == '_' || sep == '-' || sep == ' ')
        .fold(String::new(), |mut pascal, word| {
            let mut letters = word.chars();
            // Runs of separators produce empty words, which contribute nothing.
            if let Some(initial) = letters.next() {
                pascal.push(initial.to_ascii_uppercase());
                pascal.push_str(letters.as_str());
            }
            pascal
        })
}
226
/// Converts a `PascalCase` or other string to `snake_case`.
///
/// Every ASCII uppercase letter is lowercased and starts a new word: it is
/// preceded by `_` unless it is the first character or the output already ends
/// with `_`. All other characters are copied unchanged.
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());

    for c in s.chars() {
        if c.is_ascii_uppercase() {
            // Skip the separator at the start of the string and directly after
            // an existing underscore, so `Hello_World` becomes `hello_world`
            // rather than `hello__world`.
            if !result.is_empty() && !result.ends_with('_') {
                result.push('_');
            }
            result.push(c.to_ascii_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
244
// Unit tests for `RustType`, `MoveTypeMapper` and the case-conversion helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // --- MoveTypeMapper: built-in mappings ---

    // Primitive Move types map to the expected Rust paths.
    #[test]
    fn test_primitive_mapping() {
        let mapper = MoveTypeMapper::new();

        assert_eq!(mapper.map_type("bool").path, "bool");
        assert_eq!(mapper.map_type("u8").path, "u8");
        assert_eq!(mapper.map_type("u64").path, "u64");
        assert_eq!(mapper.map_type("u128").path, "u128");
        assert_eq!(mapper.map_type("address").path, "AccountAddress");
    }

    // `vector<T>` maps to `Vec<..>` of the mapped element type.
    #[test]
    fn test_vector_mapping() {
        let mapper = MoveTypeMapper::new();

        assert_eq!(mapper.map_type("vector<u8>").path, "Vec<u8>");
        assert_eq!(
            mapper.map_type("vector<address>").path,
            "Vec<AccountAddress>"
        );
        assert_eq!(mapper.map_type("vector<u64>").path, "Vec<u64>");
    }

    // The Move string type maps to Rust's `String`.
    #[test]
    fn test_string_mapping() {
        let mapper = MoveTypeMapper::new();

        assert_eq!(mapper.map_type("0x1::string::String").path, "String");
    }

    // --- Case-conversion helpers: basic inputs ---

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("hello_world"), "HelloWorld");
        assert_eq!(to_pascal_case("coin"), "Coin");
        assert_eq!(to_pascal_case("aptos_coin"), "AptosCoin");
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("HelloWorld"), "hello_world");
        assert_eq!(to_snake_case("Coin"), "coin");
        assert_eq!(to_snake_case("AptosCoin"), "aptos_coin");
    }

    // --- RustType constructors and builders ---

    // `new` yields a by-value, undocumented type that needs BCS.
    #[test]
    fn test_rust_type_new() {
        let rt = RustType::new("MyType");
        assert_eq!(rt.path, "MyType");
        assert!(rt.needs_bcs);
        assert!(!rt.is_ref);
        assert!(rt.doc.is_none());
    }

    // `primitive` differs from `new` by clearing the BCS flag.
    #[test]
    fn test_rust_type_primitive() {
        let rt = RustType::primitive("u64");
        assert_eq!(rt.path, "u64");
        assert!(!rt.needs_bcs);
    }

    // A reference type is rendered with a leading `&` in argument position.
    #[test]
    fn test_rust_type_reference() {
        let rt = RustType::new("MyType").reference();
        assert!(rt.is_ref);
        assert_eq!(rt.as_arg_type(), "&MyType");
    }

    #[test]
    fn test_rust_type_with_doc() {
        let rt = RustType::new("MyType").with_doc("My documentation");
        assert_eq!(rt.doc, Some("My documentation".to_string()));
    }

    #[test]
    fn test_rust_type_as_return_type() {
        let rt = RustType::new("MyType").reference();
        assert_eq!(rt.as_return_type(), "MyType"); // References don't affect return type
    }

    // --- MoveTypeMapper: construction and custom mappings ---

    // `Default` produces a mapper that still resolves built-in types.
    #[test]
    fn test_mapper_default() {
        let mapper = MoveTypeMapper::default();
        assert_eq!(mapper.map_type("bool").path, "bool");
    }

    // A registered custom mapping is returned for its exact key.
    #[test]
    fn test_mapper_custom_mapping() {
        let mut mapper = MoveTypeMapper::new();
        mapper.add_mapping("MyCustomType", RustType::new("CustomRustType"));
        assert_eq!(mapper.map_type("MyCustomType").path, "CustomRustType");
    }

    #[test]
    fn test_mapper_u256() {
        let mapper = MoveTypeMapper::new();
        assert_eq!(mapper.map_type("u256").path, "U256");
    }

    // Both spellings of the signer type map to an account address.
    #[test]
    fn test_mapper_signer() {
        let mapper = MoveTypeMapper::new();
        assert_eq!(mapper.map_type("&signer").path, "AccountAddress");
        assert_eq!(mapper.map_type("signer").path, "AccountAddress");
    }

    // --- MoveTypeMapper: complex types ---

    // Vector element types are mapped recursively.
    #[test]
    fn test_mapper_nested_vector() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.map_type("vector<vector<u8>>");
        assert_eq!(result.path, "Vec<Vec<u8>>");
    }

    #[test]
    fn test_mapper_option_type() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.map_type("0x1::option::Option<u64>");
        assert_eq!(result.path, "Option<u64>");
    }

    // `Object<T>` is represented by an address, whatever `T` is.
    #[test]
    fn test_mapper_object_type() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.map_type("0x1::object::Object<Token>");
        assert_eq!(result.path, "AccountAddress");
    }

    // Only checks that documentation is attached; the resulting path is not asserted.
    #[test]
    fn test_mapper_unknown_struct() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.map_type("0x1::module::SomeStruct");
        assert!(result.doc.is_some());
    }

    // Anything no rule recognizes maps to `serde_json::Value`.
    #[test]
    fn test_mapper_unknown_type() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.map_type("some_completely_unknown_thing");
        assert_eq!(result.path, "serde_json::Value");
    }

    // --- MoveTypeMapper::to_bcs_arg ---
    // These tests check only that the generated expression contains the
    // serializer call (and, in one case, the variable name).

    #[test]
    fn test_to_bcs_arg_address() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.to_bcs_arg("address", "my_addr");
        assert!(result.contains("aptos_bcs::to_bytes"));
        assert!(result.contains("my_addr"));
    }

    #[test]
    fn test_to_bcs_arg_vector_u8() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.to_bcs_arg("vector<u8>", "my_bytes");
        assert!(result.contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_to_bcs_arg_vector_other() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.to_bcs_arg("vector<u64>", "my_vec");
        assert!(result.contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_to_bcs_arg_string() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.to_bcs_arg("0x1::string::String", "my_string");
        assert!(result.contains("aptos_bcs::to_bytes"));
    }

    // A `::string::String` suffix under a non-`0x1` address is also accepted.
    #[test]
    fn test_to_bcs_arg_other_string() {
        let mapper = MoveTypeMapper::new();
        let result = mapper.to_bcs_arg("0xabc::my_module::string::String", "s");
        assert!(result.contains("aptos_bcs::to_bytes"));
    }

    // --- MoveTypeMapper::is_signer_param ---

    #[test]
    fn test_is_signer_param() {
        let mapper = MoveTypeMapper::new();
        assert!(mapper.is_signer_param("&signer"));
        assert!(mapper.is_signer_param("signer"));
        assert!(!mapper.is_signer_param("address"));
        assert!(!mapper.is_signer_param("u64"));
    }

    // --- Case-conversion helpers: other separators and edge cases ---

    // Spaces are word separators for `to_pascal_case`.
    #[test]
    fn test_to_pascal_case_with_spaces() {
        assert_eq!(to_pascal_case("hello world"), "HelloWorld");
    }

    // Dashes are word separators for `to_pascal_case`.
    #[test]
    fn test_to_pascal_case_with_dashes() {
        assert_eq!(to_pascal_case("hello-world"), "HelloWorld");
    }

    #[test]
    fn test_to_snake_case_single_word() {
        assert_eq!(to_snake_case("hello"), "hello");
    }

    // Input without uppercase letters is returned unchanged.
    #[test]
    fn test_to_snake_case_already_lowercase() {
        assert_eq!(to_snake_case("helloworld"), "helloworld");
    }
}