Skip to main content

aptos_sdk/codegen/
types.rs

1//! Type mapping between Move and Rust types.
2
3use std::collections::HashMap;
4
/// A Rust type produced by mapping a Move type.
///
/// Besides the textual type path, it records how the type should be treated
/// when it appears in a generated function signature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustType {
    /// Fully spelled-out Rust type, e.g. "`Vec<u8>`" or "`AccountAddress`".
    pub path: String,
    /// Set when an argument of this type requires BCS serialization.
    pub needs_bcs: bool,
    /// Set when the type is passed by reference (`&T`) in argument position.
    pub is_ref: bool,
    /// Optional human-readable documentation attached to the type.
    pub doc: Option<String>,
}
17
18impl RustType {
19    /// Creates a new Rust type.
20    #[must_use]
21    pub fn new(path: impl Into<String>) -> Self {
22        Self {
23            path: path.into(),
24            needs_bcs: true,
25            is_ref: false,
26            doc: None,
27        }
28    }
29
30    /// Creates a type that doesn't need BCS serialization.
31    #[must_use]
32    pub fn primitive(path: impl Into<String>) -> Self {
33        Self {
34            path: path.into(),
35            needs_bcs: false,
36            is_ref: false,
37            doc: None,
38        }
39    }
40
41    /// Creates a reference type.
42    #[must_use]
43    pub fn reference(mut self) -> Self {
44        self.is_ref = true;
45        self
46    }
47
48    /// Adds documentation.
49    #[must_use]
50    pub fn with_doc(mut self, doc: impl Into<String>) -> Self {
51        self.doc = Some(doc.into());
52        self
53    }
54
55    /// Returns the type as a function argument type.
56    pub fn as_arg_type(&self) -> String {
57        if self.is_ref {
58            format!("&{}", self.path)
59        } else {
60            self.path.clone()
61        }
62    }
63
64    /// Returns the type as a return type.
65    pub fn as_return_type(&self) -> String {
66        self.path.clone()
67    }
68}
69
70/// Maps Move types to Rust types.
71#[derive(Debug, Clone)]
72pub struct MoveTypeMapper {
73    /// Custom type mappings.
74    custom_mappings: HashMap<String, RustType>,
75}
76
77impl Default for MoveTypeMapper {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83impl MoveTypeMapper {
84    /// Creates a new type mapper with default mappings.
85    pub fn new() -> Self {
86        Self {
87            custom_mappings: HashMap::new(),
88        }
89    }
90
91    /// Adds a custom type mapping.
92    pub fn add_mapping(&mut self, move_type: impl Into<String>, rust_type: RustType) {
93        self.custom_mappings.insert(move_type.into(), rust_type);
94    }
95
96    /// Maps a Move type string to a Rust type.
97    pub fn map_type(&self, move_type: &str) -> RustType {
98        // Check custom mappings first
99        if let Some(rust_type) = self.custom_mappings.get(move_type) {
100            return rust_type.clone();
101        }
102
103        // Handle primitive types
104        match move_type {
105            "bool" => RustType::primitive("bool"),
106            "u8" => RustType::primitive("u8"),
107            "u16" => RustType::primitive("u16"),
108            "u32" => RustType::primitive("u32"),
109            "u64" => RustType::primitive("u64"),
110            "u128" => RustType::primitive("u128"),
111            "u256" => RustType::new("U256"),
112            "address" => RustType::new("AccountAddress"),
113            "signer" | "&signer" => RustType::new("AccountAddress")
114                .with_doc("Signer address (automatically set to sender)"),
115            _ => self.map_complex_type(move_type),
116        }
117    }
118
119    /// Maps complex Move types (vectors, structs, etc.)
120    fn map_complex_type(&self, move_type: &str) -> RustType {
121        // Handle vector types
122        if move_type.starts_with("vector<") && move_type.ends_with('>') {
123            let inner = &move_type[7..move_type.len() - 1];
124            let inner_type = self.map_type(inner);
125
126            // Special case: `vector<u8>` -> `Vec<u8>` (bytes)
127            if inner == "u8" {
128                return RustType::new("Vec<u8>").with_doc("Bytes");
129            }
130
131            return RustType::new(format!("Vec<{}>", inner_type.path));
132        }
133
134        // Handle Option types (0x1::option::Option<T>)
135        if move_type.contains("::option::Option<")
136            && let Some(start) = move_type.find("Option<")
137        {
138            let rest = &move_type[start + 7..];
139            if let Some(end) = rest.rfind('>') {
140                let inner = &rest[..end];
141                let inner_type = self.map_type(inner);
142                return RustType::new(format!("Option<{}>", inner_type.path));
143            }
144        }
145
146        // Handle String type
147        if move_type == "0x1::string::String" || move_type.ends_with("::string::String") {
148            return RustType::new("String");
149        }
150
151        // Handle Object types
152        if move_type.contains("::object::Object<") {
153            return RustType::new("AccountAddress").with_doc("Object address");
154        }
155
156        // Handle generic struct types (e.g., 0x1::coin::Coin<0x1::aptos_coin::AptosCoin>)
157        if move_type.contains("::") {
158            // Use rsplit to avoid collecting into Vec - we only need the last part
159            let part_count = move_type.matches("::").count() + 1;
160            if part_count >= 3 {
161                // Get the base struct name (without generics) using rsplit
162                if let Some(struct_name) = move_type.rsplit("::").next() {
163                    let base_name = struct_name.split('<').next().unwrap_or(struct_name);
164
165                    // Create a pascal case name
166                    let rust_name = to_pascal_case(base_name);
167                    return RustType::new(rust_name).with_doc(format!("Move type: {move_type}"));
168                }
169            }
170        }
171
172        // Default: use serde_json::Value for unknown types
173        RustType::new("serde_json::Value").with_doc(format!("Unknown Move type: {move_type}"))
174    }
175
176    /// Maps a Move type to a BCS argument encoding expression.
177    pub fn to_bcs_arg(&self, move_type: &str, var_name: &str) -> String {
178        let rust_type = self.map_type(move_type);
179
180        if !rust_type.needs_bcs {
181            // Primitives that don't need special handling
182            return format!(
183                "aptos_bcs::to_bytes(&{var_name}).map_err(|e| AptosError::Bcs(e.to_string()))?"
184            );
185        }
186
187        // SECURITY: Use error propagation instead of .unwrap() to prevent
188        // panics in generated code if BCS serialization fails.
189        format!("aptos_bcs::to_bytes(&{var_name}).map_err(|e| AptosError::Bcs(e.to_string()))?")
190    }
191
192    /// Determines if a parameter should be excluded from the function signature.
193    /// (e.g., &signer is always the sender)
194    pub fn is_signer_param(&self, move_type: &str) -> bool {
195        move_type == "&signer" || move_type == "signer"
196    }
197}
198
/// Converts a `snake_case`, `kebab-case`, or space-separated string to `PascalCase`.
///
/// The input is split on `_`, `-`, and spaces (which are dropped). The first
/// character of each non-empty piece is ASCII-uppercased; all other characters
/// are kept unchanged.
pub fn to_pascal_case(s: &str) -> String {
    s.split(|ch: char| ch == '_' || ch == '-' || ch == ' ')
        .flat_map(|word| {
            let mut letters = word.chars();
            letters
                .next()
                .map(|head| head.to_ascii_uppercase())
                .into_iter()
                .chain(letters)
        })
        .collect()
}
217
/// Converts a `PascalCase` or `camelCase` string to `snake_case`.
///
/// ASCII uppercase letters are lowercased. A word boundary (`_`) is inserted
/// before an uppercase letter when:
/// - it follows any character that is not an ASCII uppercase letter
///   (`HelloWorld` -> `hello_world`), or
/// - it is the last capital of an acronym that is followed by a lowercase
///   letter (`HTTPServer` -> `http_server`).
///
/// No `_` is inserted when the output already ends in one, so `Foo_Bar`
/// becomes `foo_bar` rather than `foo__bar`. All other characters are copied
/// unchanged.
pub fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_ascii_uppercase() {
            result.push(c);
            continue;
        }
        if i > 0 && !result.ends_with('_') {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).is_some_and(char::is_ascii_lowercase);
            // Start a new word after a non-capital, or at the final capital of an
            // acronym that begins the next lowercase word ("HTTPServer").
            if !prev.is_ascii_uppercase() || next_is_lower {
                result.push('_');
            }
        }
        result.push(c.to_ascii_lowercase());
    }

    result
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Maps `move_type` with a freshly constructed mapper and returns only the
    /// resulting Rust path.
    fn path_of(move_type: &str) -> String {
        MoveTypeMapper::new().map_type(move_type).path
    }

    /// Produces the BCS argument expression for `var` with a fresh mapper.
    fn bcs_expr(move_type: &str, var: &str) -> String {
        MoveTypeMapper::new().to_bcs_arg(move_type, var)
    }

    #[test]
    fn test_primitive_mapping() {
        for (move_ty, expected) in [
            ("bool", "bool"),
            ("u8", "u8"),
            ("u64", "u64"),
            ("u128", "u128"),
            ("address", "AccountAddress"),
        ] {
            assert_eq!(path_of(move_ty), expected);
        }
    }

    #[test]
    fn test_vector_mapping() {
        for (move_ty, expected) in [
            ("vector<u8>", "Vec<u8>"),
            ("vector<address>", "Vec<AccountAddress>"),
            ("vector<u64>", "Vec<u64>"),
        ] {
            assert_eq!(path_of(move_ty), expected);
        }
    }

    #[test]
    fn test_string_mapping() {
        assert_eq!(path_of("0x1::string::String"), "String");
    }

    #[test]
    fn test_to_pascal_case() {
        for (input, expected) in [
            ("hello_world", "HelloWorld"),
            ("coin", "Coin"),
            ("aptos_coin", "AptosCoin"),
        ] {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_to_snake_case() {
        for (input, expected) in [
            ("HelloWorld", "hello_world"),
            ("Coin", "coin"),
            ("AptosCoin", "aptos_coin"),
        ] {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_rust_type_new() {
        let ty = RustType::new("MyType");
        let expected = RustType {
            path: "MyType".to_string(),
            needs_bcs: true,
            is_ref: false,
            doc: None,
        };
        assert_eq!(ty, expected);
    }

    #[test]
    fn test_rust_type_primitive() {
        let ty = RustType::primitive("u64");
        assert_eq!((ty.path.as_str(), ty.needs_bcs), ("u64", false));
    }

    #[test]
    fn test_rust_type_reference() {
        let ty = RustType::new("MyType").reference();
        assert!(ty.is_ref);
        assert_eq!(ty.as_arg_type(), "&MyType");
    }

    #[test]
    fn test_rust_type_with_doc() {
        let ty = RustType::new("MyType").with_doc("My documentation");
        assert_eq!(ty.doc.as_deref(), Some("My documentation"));
    }

    #[test]
    fn test_rust_type_as_return_type() {
        // The reference flag only matters in argument position.
        let ty = RustType::new("MyType").reference();
        assert_eq!(ty.as_return_type(), "MyType");
    }

    #[test]
    fn test_mapper_default() {
        assert_eq!(MoveTypeMapper::default().map_type("bool").path, "bool");
    }

    #[test]
    fn test_mapper_custom_mapping() {
        let mut mapper = MoveTypeMapper::new();
        mapper.add_mapping("MyCustomType", RustType::new("CustomRustType"));
        assert_eq!(mapper.map_type("MyCustomType").path, "CustomRustType");
    }

    #[test]
    fn test_mapper_u256() {
        assert_eq!(path_of("u256"), "U256");
    }

    #[test]
    fn test_mapper_signer() {
        for signer in ["&signer", "signer"] {
            assert_eq!(path_of(signer), "AccountAddress");
        }
    }

    #[test]
    fn test_mapper_nested_vector() {
        assert_eq!(path_of("vector<vector<u8>>"), "Vec<Vec<u8>>");
    }

    #[test]
    fn test_mapper_option_type() {
        assert_eq!(path_of("0x1::option::Option<u64>"), "Option<u64>");
    }

    #[test]
    fn test_mapper_object_type() {
        assert_eq!(path_of("0x1::object::Object<Token>"), "AccountAddress");
    }

    #[test]
    fn test_mapper_unknown_struct() {
        let mapped = MoveTypeMapper::new().map_type("0x1::module::SomeStruct");
        assert!(mapped.doc.is_some());
    }

    #[test]
    fn test_mapper_unknown_type() {
        assert_eq!(path_of("some_completely_unknown_thing"), "serde_json::Value");
    }

    #[test]
    fn test_to_bcs_arg_address() {
        let expr = bcs_expr("address", "my_addr");
        assert!(expr.contains("aptos_bcs::to_bytes"));
        assert!(expr.contains("my_addr"));
    }

    #[test]
    fn test_to_bcs_arg_vector_u8() {
        assert!(bcs_expr("vector<u8>", "my_bytes").contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_to_bcs_arg_vector_other() {
        assert!(bcs_expr("vector<u64>", "my_vec").contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_to_bcs_arg_string() {
        assert!(bcs_expr("0x1::string::String", "my_string").contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_to_bcs_arg_other_string() {
        assert!(bcs_expr("0xabc::my_module::string::String", "s").contains("aptos_bcs::to_bytes"));
    }

    #[test]
    fn test_is_signer_param() {
        let mapper = MoveTypeMapper::new();
        for (move_ty, is_signer) in [
            ("&signer", true),
            ("signer", true),
            ("address", false),
            ("u64", false),
        ] {
            assert_eq!(mapper.is_signer_param(move_ty), is_signer);
        }
    }

    #[test]
    fn test_to_pascal_case_with_spaces() {
        assert_eq!(to_pascal_case("hello world"), "HelloWorld");
    }

    #[test]
    fn test_to_pascal_case_with_dashes() {
        assert_eq!(to_pascal_case("hello-world"), "HelloWorld");
    }

    #[test]
    fn test_to_snake_case_single_word() {
        assert_eq!(to_snake_case("hello"), "hello");
    }

    #[test]
    fn test_to_snake_case_already_lowercase() {
        assert_eq!(to_snake_case("helloworld"), "helloworld");
    }
}