Skip to main content

rust2go_common/
common.rs

1// Copyright 2024 ihciah. All Rights Reserved.
2
3use crate::{g2r::G2RTraitRepr, r2g::R2GTraitRepr};
4use heck::{
5    ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTitleCase,
6    ToTrainCase, ToUpperCamelCase,
7};
8use proc_macro2::{Span, TokenStream};
9use quote::{format_ident, quote, ToTokens};
10use std::collections::HashMap;
11use syn::parse::Parser;
12use syn::{
13    Attribute, Error, Expr, ExprLit, File, Ident, Item, Lit, Meta, MetaNameValue, PathSegment,
14    Result, Type,
15};
16
/// A Rust source file parsed into a `syn` syntax tree.
///
/// The methods on this type walk the file's top-level items (structs and
/// traits) to generate the Rust and Go binding code.
pub struct RawRsFile {
    // Syntax tree produced by `syn::parse_file` in `RawRsFile::new`.
    file: File,
}
20
21impl RawRsFile {
22    pub fn new<S: AsRef<str>>(src: S) -> Self {
23        let src = src.as_ref();
24        let syntax = syn::parse_file(src).expect("Unable to parse file");
25        RawRsFile { file: syntax }
26    }
27
28    pub fn go_internal_drop() -> &'static str {
29        r#"
30const void c_rust2go_internal_drop(void*);
31"#
32    }
33
34    pub fn go_shm_include() -> &'static str {
35        r#"
36typedef struct QueueMeta {
37    uintptr_t buffer_ptr;
38    uintptr_t buffer_len;
39    uintptr_t head_ptr;
40    uintptr_t tail_ptr;
41    uintptr_t working_ptr;
42    uintptr_t stuck_ptr;
43    int32_t working_fd;
44    int32_t unstuck_fd;
45    } QueueMeta;
46"#
47    }
48
49    pub fn go_shm_ring_init() -> &'static str {
50        r#"
51        func ringsInit(crr, crw C.QueueMeta, fns []func(unsafe.Pointer, *ants.MultiPool, func(interface{}, []byte, uint))) {
52            const MULTIPOOL_SIZE = 8
53            const SIZE_PER_POOL = -1
54
55            type Storage struct {
56                resp   interface{}
57                buffer []byte
58            }
59
60            type Payload struct {
61                Ptr          uint
62                UserData     uint
63                NextUserData uint
64                CallId       uint32
65                Flag         uint32
66            }
67
68            const CALL = 0b0101
69            const REPLY = 0b1110
70            const DROP = 0b1000
71
72            queueMetaCvt := func(cq C.QueueMeta) mem_ring.QueueMeta {
73                return mem_ring.QueueMeta{
74                    BufferPtr:  uintptr(cq.buffer_ptr),
75                    BufferLen:  uintptr(cq.buffer_len),
76                    HeadPtr:    uintptr(cq.head_ptr),
77                    TailPtr:    uintptr(cq.tail_ptr),
78                    WorkingPtr: uintptr(cq.working_ptr),
79                    StuckPtr:   uintptr(cq.stuck_ptr),
80                    WorkingFd:  int32(cq.working_fd),
81                    UnstuckFd:  int32(cq.unstuck_fd),
82                }
83            }
84
85            rr := queueMetaCvt(crr)
86            rw := queueMetaCvt(crw)
87
88            rrq := mem_ring.NewQueue[Payload](rr)
89            rwq := mem_ring.NewQueue[Payload](rw)
90
91            gr := rwq.Read()
92            gw := rrq.Write()
93
94            slab := mem_ring.NewMultiSlab[Storage]()
95            pool, _ := ants.NewMultiPool(MULTIPOOL_SIZE, SIZE_PER_POOL, ants.RoundRobin)
96
97            gr.RunHandler(func(p Payload) {
98                if p.Flag == CALL {
99                    post_func := func(resp interface{}, buffer []byte, offset uint) {
100                        if resp == nil {
101                            payload := Payload{
102                                Ptr:          0,
103                                UserData:     p.UserData,
104                                NextUserData: 0,
105                                CallId:       p.CallId,
106                                Flag:         DROP,
107                            }
108                            gw.Push(payload)
109                            return
110                        }
111
112                        // Use slab to hold reference of resp and buffer
113                        sid := slab.Push(Storage{
114                            resp,
115                            buffer,
116                        })
117                        payload := Payload{
118                            Ptr:          uint(uintptr(unsafe.Pointer(&buffer[offset]))),
119                            UserData:     p.UserData,
120                            NextUserData: sid,
121                            CallId:       p.CallId,
122                            Flag:         REPLY,
123                        }
124                        gw.Push(payload)
125                    }
126                    fns[p.CallId](unsafe.Pointer(uintptr(p.Ptr)), pool, post_func)
127                } else if p.Flag == DROP {
128                    // drop memory instantly
129                    slab.Pop(p.UserData)
130                }
131            })
132        }
133        "#
134    }
135
136    // The returned mapping is struct OriginalType -> RefType.
137    pub fn convert_structs_to_ref(&self) -> Result<(HashMap<Ident, Ident>, TokenStream)> {
138        let mut name_mapping = HashMap::new();
139
140        // Add these to generated code to make golang have C structs of string.
141        let mut out = quote! {
142            #[repr(C)]
143            pub struct StringRef {
144                pub ptr: *const u8,
145                pub len: usize,
146            }
147            #[repr(C)]
148            pub struct ListRef {
149                pub ptr: *const (),
150                pub len: usize,
151            }
152        };
153        name_mapping.insert(
154            Ident::new("String", Span::call_site()),
155            Ident::new("StringRef", Span::call_site()),
156        );
157        name_mapping.insert(
158            Ident::new("Vec", Span::call_site()),
159            Ident::new("ListRef", Span::call_site()),
160        );
161
162        for item in self.file.items.iter() {
163            match item {
164                // for example, convert
165                // pub struct DemoRequest {
166                //     pub name: String,
167                //     pub age: u8,
168                // }
169                // to
170                // #[repr(C)]
171                // pub struct DemoRequestRef {
172                //    pub name: StringRef,
173                //    pub age: u8,
174                // }
175                Item::Struct(s) => {
176                    let struct_name = s.ident.clone();
177                    let struct_name_ref = format_ident!("{}Ref", struct_name);
178                    name_mapping.insert(struct_name, struct_name_ref.clone());
179                    let mut field_names = Vec::with_capacity(s.fields.len());
180                    let mut field_types = Vec::with_capacity(s.fields.len());
181                    for field in s.fields.iter() {
182                        let field_name = field
183                            .clone()
184                            .ident
185                            .ok_or_else(|| serr!("only named fields are supported"))?;
186                        let field_type = ParamType::try_from(&field.ty)?;
187                        field_names.push(field_name);
188                        field_types.push(field_type.to_rust_ref(None));
189                    }
190                    out.extend(quote! {
191                        #[repr(C)]
192                        pub struct #struct_name_ref {
193                            #(pub #field_names: #field_types,)*
194                        }
195                    });
196                }
197                _ => continue,
198            }
199        }
200        Ok((name_mapping, out))
201    }
202
203    // go structs define and newStruct/refStruct function impl.
204    pub fn convert_structs_to_go(
205        &self,
206        levels: &HashMap<Ident, u8>,
207        go118: bool,
208    ) -> Result<String> {
209        const GO118CODE: &str = r#"
210        // An alternative impl of unsafe.String for go1.18
211        func unsafeString(ptr *byte, length int) string {
212            sliceHeader := &reflect.SliceHeader{
213                Data: uintptr(unsafe.Pointer(ptr)),
214                Len:  length,
215                Cap:  length,
216            }
217            return *(*string)(unsafe.Pointer(sliceHeader))
218        }
219
220        // An alternative impl of unsafe.StringData for go1.18
221        func unsafeStringData(s string) *byte {
222            return (*byte)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data))
223        }
224        func newString(s_ref C.StringRef) string {
225            return unsafeString((*byte)(unsafe.Pointer(s_ref.ptr)), int(s_ref.len))
226        }
227        func refString(s *string, _ *[]byte) C.StringRef {
228            return C.StringRef{
229                ptr: (*C.uint8_t)(unsafeStringData(*s)),
230                len: C.uintptr_t(len(*s)),
231            }
232        }
233        "#;
234
235        const GO121CODE: &str = r#"
236        func newString(s_ref C.StringRef) string {
237            return unsafe.String((*byte)(unsafe.Pointer(s_ref.ptr)), s_ref.len)
238        }
239        func refString(s *string, _ *[]byte) C.StringRef {
240            return C.StringRef{
241                ptr: (*C.uint8_t)(unsafe.StringData(*s)),
242                len: C.uintptr_t(len(*s)),
243            }
244        }
245        "#;
246
247        let mut out = if go118 {
248            GO118CODE.to_string()
249        } else {
250            GO121CODE.to_string()
251        } + r#"
252        func ownString(s_ref C.StringRef) string {
253            return string(unsafe.Slice((*byte)(unsafe.Pointer(s_ref.ptr)), int(s_ref.len)))
254        }
255        func cntString(_ *string, _ *uint) [0]C.StringRef { return [0]C.StringRef{} }
256        func new_list_mapper[T1, T2 any](f func(T1) T2) func(C.ListRef) []T2 {
257            return func(x C.ListRef) []T2 {
258                input := unsafe.Slice((*T1)(unsafe.Pointer(x.ptr)), x.len)
259                output := make([]T2, len(input))
260                for i, v := range input {
261                    output[i] = f(v)
262                }
263                return output
264            }
265        }
266        func new_list_mapper_primitive[T1, T2 any](_ func(T1) T2) func(C.ListRef) []T2 {
267            return func(x C.ListRef) []T2 {
268                return unsafe.Slice((*T2)(unsafe.Pointer(x.ptr)), x.len)
269            }
270        }
271        // only handle non-primitive type T
272        func cnt_list_mapper[T, R any](f func(s *T, cnt *uint)[0]R) func(s *[]T, cnt *uint) [0]C.ListRef {
273            return func(s *[]T, cnt *uint) [0]C.ListRef {
274                for _, v := range *s {
275                    f(&v, cnt)
276                }
277                *cnt += uint(len(*s)) * size_of[R]()
278                return [0]C.ListRef{}
279            }
280        }
281
282        // only handle primitive type T
283        func cnt_list_mapper_primitive[T, R any](_ func(s *T, cnt *uint)[0]R) func(s *[]T, cnt *uint) [0]C.ListRef {
284            return func(s *[]T, cnt *uint) [0]C.ListRef {return [0]C.ListRef{}}
285        }
286        // only handle non-primitive type T
287        func ref_list_mapper[T, R any](f func(s *T, buffer *[]byte) R) func(s *[]T, buffer *[]byte) C.ListRef {
288            return func(s *[]T, buffer *[]byte) C.ListRef {
289                if len(*buffer) == 0 {
290                    return C.ListRef{
291                        ptr: unsafe.Pointer(nil),
292                        len: C.uintptr_t(len(*s)),
293                    }
294                }
295                ret := C.ListRef{
296                    ptr: unsafe.Pointer(&(*buffer)[0]),
297                    len: C.uintptr_t(len(*s)),
298                }
299                children_bytes := int(size_of[R]()) * len(*s)
300                children := (*buffer)[:children_bytes]
301                *buffer = (*buffer)[children_bytes:]
302                for _, v := range *s {
303                    child := f(&v, buffer)
304                    len := unsafe.Sizeof(child)
305                    copy(children, unsafe.Slice((*byte)(unsafe.Pointer(&child)), len))
306                    children = children[len:]
307                }
308                return ret
309            }
310        }
311        // only handle primitive type T
312        func ref_list_mapper_primitive[T, R any](_ func(s *T, buffer *[]byte) R) func(s *[]T, buffer *[]byte) C.ListRef {
313            return func(s *[]T, buffer *[]byte) C.ListRef {
314                if len(*s) == 0 {
315                    return C.ListRef{
316                        ptr: unsafe.Pointer(nil),
317                        len: C.uintptr_t(0),
318                    }
319                }
320                return C.ListRef{
321                    ptr: unsafe.Pointer(&(*s)[0]),
322                    len: C.uintptr_t(len(*s)),
323                }
324            }
325        }
326        func size_of[T any]() uint {
327            var t T
328            return uint(unsafe.Sizeof(t))
329        }
330        func cvt_ref[R, CR any](cnt_f func(s *R, cnt *uint) [0]CR, ref_f func(p *R, buffer *[]byte) CR) func(p *R) (CR, []byte) {
331            return func(p *R) (CR, []byte) {
332                var cnt uint
333                cnt_f(p, &cnt)
334                buffer := make([]byte, cnt)
335                return ref_f(p, &buffer), buffer
336            }
337        }
338        func cvt_ref_cap[R, CR any](cnt_f func(s *R, cnt *uint) [0]CR, ref_f func(p *R, buffer *[]byte) CR, add_cap uint) func(p *R) (CR, []byte) {
339            return func(p *R) (CR, []byte) {
340                var cnt uint
341                cnt_f(p, &cnt)
342                buffer := make([]byte, cnt, cnt + add_cap)
343                return ref_f(p, &buffer), buffer
344            }
345        }
346
347        func newC_uint8_t(n C.uint8_t) uint8    { return uint8(n) }
348        func newC_uint16_t(n C.uint16_t) uint16 { return uint16(n) }
349        func newC_uint32_t(n C.uint32_t) uint32 { return uint32(n) }
350        func newC_uint64_t(n C.uint64_t) uint64 { return uint64(n) }
351        func newC_int8_t(n C.int8_t) int8       { return int8(n) }
352        func newC_int16_t(n C.int16_t) int16    { return int16(n) }
353        func newC_int32_t(n C.int32_t) int32    { return int32(n) }
354        func newC_int64_t(n C.int64_t) int64    { return int64(n) }
355        func newC_bool(n C.bool) bool           { return bool(n) }
356        func newC_uintptr_t(n C.uintptr_t) uint { return uint(n) }
357        func newC_intptr_t(n C.intptr_t) int    { return int(n) }
358        func newC_float(n C.float) float32      { return float32(n) }
359        func newC_double(n C.double) float64    { return float64(n) }
360
361        func cntC_uint8_t(_ *uint8, _ *uint) [0]C.uint8_t    { return [0]C.uint8_t{} }
362        func cntC_uint16_t(_ *uint16, _ *uint) [0]C.uint16_t { return [0]C.uint16_t{} }
363        func cntC_uint32_t(_ *uint32, _ *uint) [0]C.uint32_t { return [0]C.uint32_t{} }
364        func cntC_uint64_t(_ *uint64, _ *uint) [0]C.uint64_t { return [0]C.uint64_t{} }
365        func cntC_int8_t(_ *int8, _ *uint) [0]C.int8_t       { return [0]C.int8_t{} }
366        func cntC_int16_t(_ *int16, _ *uint) [0]C.int16_t    { return [0]C.int16_t{} }
367        func cntC_int32_t(_ *int32, _ *uint) [0]C.int32_t    { return [0]C.int32_t{} }
368        func cntC_int64_t(_ *int64, _ *uint) [0]C.int64_t    { return [0]C.int64_t{} }
369        func cntC_bool(_ *bool, _ *uint) [0]C.bool           { return [0]C.bool{} }
370        func cntC_uintptr_t(_ *uint, _ *uint) [0]C.uintptr_t { return [0]C.uintptr_t{} }
371        func cntC_intptr_t(_ *int, _ *uint) [0]C.intptr_t    { return [0]C.intptr_t{} }
372        func cntC_float(_ *float32, _ *uint) [0]C.float      { return [0]C.float{} }
373        func cntC_double(_ *float64, _ *uint) [0]C.double    { return [0]C.double{} }
374
375        func refC_uint8_t(p *uint8, _ *[]byte) C.uint8_t    { return C.uint8_t(*p) }
376        func refC_uint16_t(p *uint16, _ *[]byte) C.uint16_t { return C.uint16_t(*p) }
377        func refC_uint32_t(p *uint32, _ *[]byte) C.uint32_t { return C.uint32_t(*p) }
378        func refC_uint64_t(p *uint64, _ *[]byte) C.uint64_t { return C.uint64_t(*p) }
379        func refC_int8_t(p *int8, _ *[]byte) C.int8_t       { return C.int8_t(*p) }
380        func refC_int16_t(p *int16, _ *[]byte) C.int16_t    { return C.int16_t(*p) }
381        func refC_int32_t(p *int32, _ *[]byte) C.int32_t    { return C.int32_t(*p) }
382        func refC_int64_t(p *int64, _ *[]byte) C.int64_t    { return C.int64_t(*p) }
383        func refC_bool(p *bool, _ *[]byte) C.bool           { return C.bool(*p) }
384        func refC_uintptr_t(p *uint, _ *[]byte) C.uintptr_t { return C.uintptr_t(*p) }
385        func refC_intptr_t(p *int, _ *[]byte) C.intptr_t    { return C.intptr_t(*p) }
386        func refC_float(p *float32, _ *[]byte) C.float      { return C.float(*p) }
387        func refC_double(p *float64, _ *[]byte) C.double    { return C.double(*p) }
388        "#;
389        for item in self.file.items.iter() {
390            match item {
391                // for example, convert
392                // pub struct DemoRequest {
393                //     pub name: String,
394                //     pub age: u8,
395                // }
396                // to
397                // type DemoRequest struct {
398                //     name String
399                //     age uint8
400                // }
401                // func newDemoRequest(p C.DemoRequestRef) DemoRequest {
402                //     return DemoRequest {
403                //         name: newString(p.name),
404                //         age: uint8(p.age),
405                //     }
406                // }
407                // func refDemoRequest(p DemoRequest) C.DemoRequestRef {
408                //     return C.DemoRequestRef {
409                //         name: refString(p.name),
410                //         age: C.uint8_t(p.age),
411                //     }
412                // }
413                Item::Struct(s) => {
414                    let go_struct_tag = Self::go_struct_tag(&s.attrs)?;
415                    let struct_name = s.ident.to_string();
416                    out.push_str(&format!("type {struct_name} struct {{\n"));
417                    for field in s.fields.iter() {
418                        let field_name = field
419                            .ident
420                            .as_ref()
421                            .ok_or_else(|| serr!("only named fields are supported"))?
422                            .to_string();
423                        let field_type = ParamType::try_from(&field.ty)?;
424                        out.push_str(&format!(
425                            "    {} {} {}\n",
426                            field_name,
427                            field_type.to_go(),
428                            Self::gen_tag(&field_name, &go_struct_tag)
429                        ));
430                    }
431                    out.push_str("}\n");
432
433                    // newStruct
434                    out.push_str(&format!(
435                        "func new{struct_name}(p C.{struct_name}Ref) {struct_name}{{\nreturn {struct_name}{{\n"
436                    ));
437                    for field in s.fields.iter() {
438                        let field_name = field.ident.as_ref().unwrap().to_string();
439                        let field_type = ParamType::try_from(&field.ty)?;
440                        let (new_f, _) = field_type.c_to_go_field_converter(levels);
441                        out.push_str(&format!("{field_name}: {new_f}(p.{field_name}),\n",));
442                    }
443                    out.push_str("}\n}\n");
444
445                    // ownStruct
446                    out.push_str(&format!(
447                        "func own{struct_name}(p C.{struct_name}Ref) {struct_name}{{\nreturn {struct_name}{{\n"
448                    ));
449                    for field in s.fields.iter() {
450                        let field_name = field.ident.as_ref().unwrap().to_string();
451                        let field_type = ParamType::try_from(&field.ty)?;
452                        let own_f = field_type.c_to_go_field_converter_owned();
453                        out.push_str(&format!("{field_name}: {own_f}(p.{field_name}),\n",));
454                    }
455                    out.push_str("}\n}\n");
456
457                    // cntStruct
458                    let level = *levels.get(&s.ident).unwrap();
459                    out.push_str(&format!(
460                        "func cnt{struct_name}(s *{struct_name}, cnt *uint) [0]C.{struct_name}Ref {{\n"
461                    ));
462                    let mut used = false;
463                    if level == 2 {
464                        for field in s.fields.iter() {
465                            let field_name = field.ident.as_ref().unwrap().to_string();
466                            let field_type = ParamType::try_from(&field.ty)?;
467                            let (counter_f, level) = field_type.go_to_c_field_counter(levels);
468                            if level == 2 {
469                                out.push_str(&format!("{counter_f}(&s.{field_name}, cnt)\n"));
470                                used = true;
471                            }
472                        }
473                    }
474                    if !used {
475                        out.push_str("_ = s\n_ = cnt\n");
476                    }
477                    out.push_str(&format!("return [0]C.{struct_name}Ref{{}}\n"));
478                    out.push_str("}\n");
479
480                    // refStruct
481                    out.push_str(&format!(
482                        "func ref{struct_name}(p *{struct_name}, buffer *[]byte) C.{struct_name}Ref{{\nreturn C.{struct_name}Ref{{\n"
483                    ));
484                    for field in s.fields.iter() {
485                        let field_name = field.ident.as_ref().unwrap().to_string();
486                        let field_type = ParamType::try_from(&field.ty)?;
487                        let (ref_f, _) = field_type.go_to_c_field_converter(levels);
488                        out.push_str(&format!(
489                            "{field_name}: {ref_f}(&p.{field_name}, buffer),\n",
490                        ));
491                    }
492                    out.push_str("}\n}\n");
493                }
494                _ => continue,
495            }
496        }
497        Ok(out)
498    }
499
500    pub fn convert_r2g_trait(&self) -> Result<Vec<R2GTraitRepr>> {
501        let out: Vec<R2GTraitRepr> = self
502            .file
503            .items
504            .iter()
505            .filter_map(|item| match item {
506                Item::Trait(t)
507                    if t.attrs
508                        .iter()
509                        .any(|attr| attr.meta.path().segments.last().unwrap().ident == "r2g") =>
510                {
511                    Some(t)
512                }
513                _ => None,
514            })
515            .map(|trat| trat.try_into())
516            .collect::<Result<Vec<R2GTraitRepr>>>()?;
517        Ok(out)
518    }
519
520    pub fn convert_g2r_trait(&self) -> Result<Vec<G2RTraitRepr>> {
521        let out: Vec<G2RTraitRepr> = self
522            .file
523            .items
524            .iter()
525            .filter_map(|item| match item {
526                Item::Trait(t)
527                    if t.attrs
528                        .iter()
529                        .any(|attr| attr.meta.path().segments.last().unwrap().ident == "g2r") =>
530                {
531                    Some(t)
532                }
533                _ => None,
534            })
535            .map(|trat| trat.try_into())
536            .collect::<Result<Vec<G2RTraitRepr>>>()?;
537        Ok(out)
538    }
539
540    // 0->Primitive
541    // 1->SimpleWrapper
542    // 2->Complex
543    pub fn convert_structs_levels(&self) -> Result<HashMap<Ident, u8>> {
544        enum Node {
545            List(Box<Node>),
546            NamedStruct(Ident),
547            Primitive,
548        }
549        fn type_to_node(ty: &Type) -> Result<Node> {
550            let seg = type_to_segment(ty)?;
551            match seg.ident.to_string().as_str() {
552                "Vec" | "Option" => {
553                    let inside = match &seg.arguments {
554                        syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
555                            syn::GenericArgument::Type(ty) => ty,
556                            _ => panic!("list generic must be a type"),
557                        },
558                        _ => panic!("list type must have angle bracketed arguments"),
559                    };
560                    Ok(Node::List(Box::new(type_to_node(inside)?)))
561                }
562                "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64" | "isize"
563                | "bool" | "char" | "f32" | "f64" => Ok(Node::Primitive),
564                _ => Ok(Node::NamedStruct(seg.ident.clone())),
565            }
566        }
567        fn node_level(
568            node: &Node,
569            items: &HashMap<Ident, Vec<Node>>,
570            out: &mut HashMap<Ident, u8>,
571        ) -> u8 {
572            match node {
573                Node::List(inner) => (1 + node_level(inner, items, out)).min(2),
574                Node::NamedStruct(ident) if ident.to_string().as_str() == "String" => 1,
575                Node::NamedStruct(name) => {
576                    if let Some(lv) = out.get(name) {
577                        return *lv;
578                    }
579                    let lv = items
580                        .get(name)
581                        .map(|nodes| {
582                            nodes
583                                .iter()
584                                .map(|n| node_level(n, items, out))
585                                .max()
586                                .unwrap_or(0)
587                        })
588                        .unwrap();
589                    out.insert(name.clone(), lv);
590                    lv
591                }
592                Node::Primitive => 0,
593            }
594        }
595        let mut items = HashMap::<Ident, Vec<Node>>::new();
596        for item in self.file.items.iter() {
597            match item {
598                Item::Struct(s) => {
599                    let mut fields = Vec::new();
600                    for field in &s.fields {
601                        fields.push(type_to_node(&field.ty)?);
602                    }
603                    items.insert(s.ident.clone(), fields);
604                }
605                _ => continue,
606            }
607        }
608
609        let mut out = HashMap::new();
610        for name in items.keys() {
611            let lv = node_level(&Node::NamedStruct(name.clone()), &items, &mut out);
612            out.insert(name.clone(), lv);
613        }
614        out.insert(Ident::new("String", Span::call_site()), 1);
615        Ok(out)
616    }
617
618    fn is_r2g_struct_tag(attr: &Attribute) -> bool {
619        if attr.path().is_ident("r2g_struct_tag") {
620            return true;
621        }
622
623        let segments: Vec<_> = attr
624            .path()
625            .segments
626            .iter()
627            .map(|seg| seg.ident.to_string())
628            .collect();
629
630        if segments.len() == 2 && segments[0] == "rust2go" && segments[1] == "r2g_struct_tag" {
631            return true;
632        }
633
634        false
635    }
636    fn go_struct_tag(attrs: &[Attribute]) -> Result<Vec<(String, String)>> {
637        let mut hash_set = vec![];
638
639        for attr in attrs {
640            if Self::is_r2g_struct_tag(attr) {
641                let meta_list = match &attr.meta {
642                    Meta::List(meta_list) => meta_list,
643                    _ => continue,
644                };
645
646                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
647                let metas = parser.parse2(meta_list.tokens.clone())?;
648
649                for meta in metas {
650                    if let Meta::NameValue(MetaNameValue {
651                        path,
652                        value:
653                            Expr::Lit(ExprLit {
654                                lit: Lit::Str(lit_str),
655                                ..
656                            }),
657                        ..
658                    }) = meta
659                    {
660                        if let Some(ident) = path.get_ident() {
661                            let key = ident.to_string();
662                            let value = lit_str.value();
663                            hash_set.push((key, value));
664                        }
665                    }
666                }
667            }
668        }
669
670        Ok(hash_set)
671    }
672
673    fn gen_tag(field_name: &str, tag_list: &[(String, String)]) -> String {
674        let mut tags = vec![];
675        for (key, heck_type) in tag_list {
676            tags.push(format!(
677                "{}:{:?}",
678                key,
679                Self::heck_field_name(field_name, heck_type)
680            ));
681        }
682        if tags.is_empty() {
683            return String::new();
684        }
685        format!("`{}`", tags.join(" "))
686    }
687
688    fn heck_field_name(field_name: &str, heck_type: &str) -> String {
689        match heck_type {
690            "snake_case" => field_name.to_snake_case(),
691            "lowerCamelCase" => field_name.to_lower_camel_case(),
692            "UpperCamelCase" => field_name.to_upper_camel_case(),
693            "kebab-case" => field_name.to_kebab_case(),
694            "SHOUTY_SNAKE_CASE" => field_name.to_shouty_snake_case(),
695            "SHOUTY-KEBAB-CASE" => field_name.to_shouty_kebab_case(),
696            "Title Case" => field_name.to_title_case(),
697            "Train-Case" => field_name.to_train_case(),
698            _ => panic!("unknown heck type"),
699        }
700    }
701}
702
/// A named parameter: an identifier paired with its parsed type.
pub struct Param {
    /// The parameter's identifier.
    pub name: Ident,
    /// The parameter's parsed type.
    pub ty: ParamType,
}
707
impl Param {
    /// Returns a shared reference to this parameter's type.
    pub fn ty(&self) -> &ParamType {
        &self.ty
    }
}
713
/// A supported parameter or field type, as produced by
/// `ParamType::try_from(&syn::Type)`.
pub struct ParamType {
    /// The type's category, with any outer `&` stripped.
    pub inner: ParamTypeInner,
    /// `true` when the source type was written as a reference.
    pub is_reference: bool,
}
718
/// The category of a `ParamType`.
pub enum ParamTypeInner {
    /// A built-in integer, `bool`, `char` or float type; holds its identifier.
    Primitive(Ident),
    /// Any other path type without generic arguments; holds its identifier.
    Custom(Ident),
    /// A `Vec<..>` or `Option<..>`; holds the full type including its arguments.
    List(Type),
}
724
725impl ToTokens for ParamType {
726    fn to_tokens(&self, tokens: &mut TokenStream) {
727        if self.is_reference {
728            tokens.extend(quote! {&});
729        }
730        match &self.inner {
731            ParamTypeInner::Primitive(ty) => ty.to_tokens(tokens),
732            ParamTypeInner::Custom(ty) => ty.to_tokens(tokens),
733            ParamTypeInner::List(ty) => ty.to_tokens(tokens),
734        }
735    }
736}
737
738impl TryFrom<&Type> for ParamType {
739    type Error = Error;
740
741    fn try_from(mut ty: &Type) -> Result<Self> {
742        let mut is_reference = false;
743        if let Type::Reference(r) = ty {
744            is_reference = true;
745            ty = &r.elem;
746        }
747
748        // TypePath -> ParamType
749        let seg = type_to_segment(ty)?;
750        let param_type_inner = match seg.ident.to_string().as_str() {
751            "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "usize" | "isize"
752            | "bool" | "char" | "f32" | "f64" => {
753                if !seg.arguments.is_none() {
754                    sbail!("primitive types with arguments are not supported")
755                }
756                ParamTypeInner::Primitive(seg.ident.clone())
757            }
758            "Vec" | "Option" => ParamTypeInner::List(ty.clone()),
759            _ => {
760                if !seg.arguments.is_none() {
761                    sbail!("custom types with arguments are not supported")
762                }
763                ParamTypeInner::Custom(seg.ident.clone())
764            }
765        };
766        Ok(ParamType {
767            inner: param_type_inner,
768            is_reference,
769        })
770    }
771}
772
773impl ParamType {
774    pub fn to_c(&self, with_struct: bool) -> String {
775        let struct_ = if with_struct { "struct " } else { "" };
776        match &self.inner {
777            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
778                "u8" => "uint8_t",
779                "u16" => "uint16_t",
780                "u32" => "uint32_t",
781                "u64" => "uint64_t",
782                "i8" => "int8_t",
783                "i16" => "int16_t",
784                "i32" => "int32_t",
785                "i64" => "int64_t",
786                "bool" => "bool",
787                "char" => "uint32_t",
788                "usize" => "uintptr_t",
789                "isize" => "intptr_t",
790                "f32" => "float",
791                "f64" => "double",
792                _ => panic!("unreconigzed rust primitive type {name}"),
793            }
794            .to_string(),
795            ParamTypeInner::Custom(c) => format!("{struct_}{c}Ref"),
796            ParamTypeInner::List(_) => format!("{struct_}ListRef"),
797        }
798    }
799
800    pub fn to_go(&self) -> String {
801        match &self.inner {
802            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
803                "u8" => "uint8",
804                "u16" => "uint16",
805                "u32" => "uint32",
806                "u64" => "uint64",
807                "i8" => "int8",
808                "i16" => "int16",
809                "i32" => "int32",
810                "i64" => "int64",
811                "bool" => "bool",
812                "char" => "rune",
813                "usize" => "uint",
814                "isize" => "int",
815                "f32" => "float32",
816                "f64" => "float64",
817                _ => panic!("unreconigzed rust primitive type {name}"),
818            }
819            .to_string(),
820            ParamTypeInner::Custom(c) => {
821                let s = c.to_string();
822                match s.as_str() {
823                    "String" => "string".to_string(),
824                    _ => s,
825                }
826            }
827            ParamTypeInner::List(inner) => {
828                let seg = type_to_segment(inner).unwrap();
829                let inside = match &seg.arguments {
830                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
831                        syn::GenericArgument::Type(ty) => ty,
832                        _ => panic!("list generic must be a type"),
833                    },
834                    _ => panic!("list type must have angle bracketed arguments"),
835                };
836                format!(
837                    "[]{}",
838                    ParamType::try_from(inside)
839                        .expect("unable to convert list type")
840                        .to_go()
841                )
842            }
843        }
844    }
845
846    // f: StructRef -> Struct
847    pub fn c_to_go_field_converter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
848        match &self.inner {
849            ParamTypeInner::Primitive(name) => (
850                match name.to_string().as_str() {
851                    "u8" => "newC_uint8_t",
852                    "u16" => "newC_uint16_t",
853                    "u32" => "newC_uint32_t",
854                    "u64" => "newC_uint64_t",
855                    "i8" => "newC_int8_t",
856                    "i16" => "newC_int16_t",
857                    "i32" => "newC_int32_t",
858                    "i64" => "newC_int64_t",
859                    "bool" => "newC_bool",
860                    "usize" => "newC_uintptr_t",
861                    "isize" => "newC_intptr_t",
862                    "f32" => "newC_float",
863                    "f64" => "newC_double",
864                    _ => panic!("unrecognized rust primitive type {name}"),
865                }
866                .to_string(),
867                0,
868            ),
869            ParamTypeInner::Custom(c) => (
870                format!("new{}", c.to_string().as_str()),
871                *mapping.get(c).unwrap(),
872            ),
873            ParamTypeInner::List(inner) => {
874                let seg = type_to_segment(inner).unwrap();
875                let inside = match &seg.arguments {
876                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
877                        syn::GenericArgument::Type(ty) => ty,
878                        _ => panic!("list generic must be a type"),
879                    },
880                    _ => panic!("list type must have angle bracketed arguments"),
881                };
882                let (inner, inner_level) = ParamType::try_from(inside)
883                    .expect("unable to convert list type")
884                    .c_to_go_field_converter(mapping);
885                if inner_level == 0 {
886                    (format!("new_list_mapper_primitive({inner})"), 1)
887                } else {
888                    (format!("new_list_mapper({inner})"), 2.min(inner_level + 1))
889                }
890            }
891        }
892    }
893
894    // f: StructRef -> Struct with fully ownership
895    pub fn c_to_go_field_converter_owned(&self) -> String {
896        match &self.inner {
897            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
898                "u8" => "newC_uint8_t",
899                "u16" => "newC_uint16_t",
900                "u32" => "newC_uint32_t",
901                "u64" => "newC_uint64_t",
902                "i8" => "newC_int8_t",
903                "i16" => "newC_int16_t",
904                "i32" => "newC_int32_t",
905                "i64" => "newC_int64_t",
906                "bool" => "newC_bool",
907                "usize" => "newC_uintptr_t",
908                "isize" => "newC_intptr_t",
909                "f32" => "newC_float",
910                "f64" => "newC_double",
911                _ => panic!("unrecognized rust primitive type {name}"),
912            }
913            .to_string(),
914            ParamTypeInner::Custom(c) => format!("own{}", c.to_string().as_str()),
915            ParamTypeInner::List(inner) => {
916                let seg = type_to_segment(inner).unwrap();
917                let inside = match &seg.arguments {
918                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
919                        syn::GenericArgument::Type(ty) => ty,
920                        _ => panic!("list generic must be a type"),
921                    },
922                    _ => panic!("list type must have angle bracketed arguments"),
923                };
924                let inner = ParamType::try_from(inside)
925                    .expect("unable to convert list type")
926                    .c_to_go_field_converter_owned();
927                format!("new_list_mapper({inner})")
928            }
929        }
930    }
931
932    pub fn go_to_c_field_counter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
933        match &self.inner {
934            ParamTypeInner::Primitive(name) => (
935                match name.to_string().as_str() {
936                    "u8" => "cntC_uint8_t",
937                    "u16" => "cntC_uint16_t",
938                    "u32" => "cntC_uint32_t",
939                    "u64" => "cntC_uint64_t",
940                    "i8" => "cntC_int8_t",
941                    "i16" => "cntC_int16_t",
942                    "i32" => "cntC_int32_t",
943                    "i64" => "cntC_int64_t",
944                    "bool" => "cntC_bool",
945                    "usize" => "cntC_uintptr_t",
946                    "isize" => "cntC_intptr_t",
947                    "f32" => "cntC_float",
948                    "f64" => "cntC_double",
949                    _ => panic!("unrecognized rust primitive type {name}"),
950                }
951                .to_string(),
952                0,
953            ),
954            ParamTypeInner::Custom(c) => (
955                format!("cnt{}", c.to_string().as_str()),
956                *mapping.get(c).unwrap(),
957            ),
958            ParamTypeInner::List(inner) => {
959                let seg = type_to_segment(inner).unwrap();
960                let inside = match &seg.arguments {
961                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
962                        syn::GenericArgument::Type(ty) => ty,
963                        _ => panic!("list generic must be a type"),
964                    },
965                    _ => panic!("list type must have angle bracketed arguments"),
966                };
967                let (inner, inner_level) = ParamType::try_from(inside)
968                    .expect("unable to convert list type")
969                    .go_to_c_field_counter(mapping);
970                if inner_level == 0 {
971                    (format!("cnt_list_mapper_primitive({inner})"), 1)
972                } else {
973                    (format!("cnt_list_mapper({inner})"), 2.min(inner_level + 1))
974                }
975            }
976        }
977    }
978
979    // f: Struct -> StructRef
980    pub fn go_to_c_field_converter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
981        match &self.inner {
982            ParamTypeInner::Primitive(name) => (
983                match name.to_string().as_str() {
984                    "u8" => "refC_uint8_t",
985                    "u16" => "refC_uint16_t",
986                    "u32" => "refC_uint32_t",
987                    "u64" => "refC_uint64_t",
988                    "i8" => "refC_int8_t",
989                    "i16" => "refC_int16_t",
990                    "i32" => "refC_int32_t",
991                    "i64" => "refC_int64_t",
992                    "bool" => "refC_bool",
993                    "usize" => "refC_uintptr_t",
994                    "isize" => "refC_intptr_t",
995                    "f32" => "refC_float",
996                    "f64" => "refC_double",
997                    _ => panic!("unreconigzed rust primitive type {name}"),
998                }
999                .to_string(),
1000                0,
1001            ),
1002            ParamTypeInner::Custom(c) => (
1003                format!("ref{}", c.to_string().as_str()),
1004                *mapping.get(c).unwrap(),
1005            ),
1006            ParamTypeInner::List(inner) => {
1007                let seg = type_to_segment(inner).unwrap();
1008                let inside = match &seg.arguments {
1009                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
1010                        syn::GenericArgument::Type(ty) => ty,
1011                        _ => panic!("list generic must be a type"),
1012                    },
1013                    _ => panic!("list type must have angle bracketed arguments"),
1014                };
1015                let (inner, inner_level) = ParamType::try_from(inside)
1016                    .expect("unable to convert list type")
1017                    .go_to_c_field_converter(mapping);
1018                if inner_level == 0 {
1019                    (format!("ref_list_mapper_primitive({inner})"), 1)
1020                } else {
1021                    (format!("ref_list_mapper({inner})"), 2.min(inner_level + 1))
1022                }
1023            }
1024        }
1025    }
1026
1027    pub fn to_rust_ref(&self, prefix: Option<&TokenStream>) -> TokenStream {
1028        match &self.inner {
1029            ParamTypeInner::Primitive(name) => quote!(#name),
1030            ParamTypeInner::Custom(name) => {
1031                let ident = format_ident!("{}Ref", name);
1032                quote!(#prefix #ident)
1033            }
1034            ParamTypeInner::List(_) => {
1035                let ident = format_ident!("ListRef");
1036                quote!(#prefix #ident)
1037            }
1038        }
1039    }
1040}
1041
1042pub(crate) fn type_to_segment(ty: &Type) -> Result<&PathSegment> {
1043    let field_type = match ty {
1044        Type::Path(p) => p,
1045        _ => sbail!("only path types are supported"),
1046    };
1047    let path = &field_type.path;
1048    // Leading colon is not allow
1049    if path.leading_colon.is_some() {
1050        sbail!("types with leading colons are not supported");
1051    }
1052    // We only accept single-segment path
1053    if path.segments.len() != 1 {
1054        sbail!("types with multiple segments are not supported");
1055    }
1056    Ok(path.segments.first().unwrap())
1057}
1058
#[cfg(test)]
mod tests {
    use super::RawRsFile;

    /// Smoke test: parses a small demo source and prints the generated Go
    /// structs, interface and exports along with the struct levels. It only
    /// checks that none of the conversion steps fail.
    #[test]
    fn it_works() {
        let raw = r#"
        pub struct DemoRequest {
            pub name: String,
            pub age: u8,
        }
        pub struct DemoResponse {
            pub pass: bool,
        }
        pub trait DemoCall {
            fn demo_check(req: DemoRequest) -> DemoResponse;
            fn demo_check_async(req: DemoRequest) -> impl std::future::Future<Output = DemoResponse>;
        }
        "#;
        let raw_file = RawRsFile::new(raw);
        let traits = raw_file.convert_r2g_trait().unwrap();
        let levels = raw_file.convert_structs_levels().unwrap();

        let structs_go = raw_file.convert_structs_to_go(&levels, false).unwrap();
        println!("structs gen: {}", structs_go);

        traits.into_iter().for_each(|r2g| {
            println!("if gen: {}", r2g.generate_go_interface());
            println!("go export gen: {}", r2g.generate_go_exports(&levels));
        });

        // The levels are computed a second time before being dumped.
        let recomputed_levels = raw_file.convert_structs_levels().unwrap();
        for (name, level) in recomputed_levels.iter() {
            println!("{}: {}", name, level);
        }
    }
}