Skip to main content

orleans_rust_codegen/
lib.rs

//! Manifest-driven code generation for typed `orleans-rust-client` grain
//! clients.
//!
//! The generator is intentionally limited in scope (see the repository
//! roadmap). It consumes a manifest emitted by the .NET bridge
//! (`GetManifest` / `OrleansRustBridge.Tools`) and produces one Rust struct
//! per grain contract that wraps an `orleans_rust_client::GrainRef` with
//! typed methods.
//!
//! Type mapping covers the common primitive .NET types plus nullable types,
//! arrays, and the standard generic collections (`List<T>` → `Vec<T>`,
//! `Dictionary<K, V>` → `HashMap<K, V>`, ...); anything unrecognised falls back
//! to `serde_json::Value`, keeping the generator robust against manifests it
//! does not fully understand. Methods with multiple parameters generate
//! multi-argument functions (serialized as a JSON array), and an opt-in mode
//! emits `<method>_with_context` variants that also return the response
//! context.
17
18use heck::{ToPascalCase, ToSnakeCase};
19use serde::Deserialize;
20
21/// Errors produced while generating client code.
22#[derive(thiserror::Error, Debug)]
23pub enum CodegenError {
24    /// The manifest JSON could not be parsed.
25    #[error("failed to parse manifest: {0}")]
26    Parse(#[from] serde_json::Error),
27    /// The manifest was structurally valid but unusable.
28    #[error("invalid manifest: {0}")]
29    Invalid(String),
30}
31
32/// A contract manifest, matching the JSON shape emitted by the bridge.
33#[derive(Debug, Clone, Deserialize)]
34pub struct Manifest {
35    /// Orleans service id the bridge connects to.
36    #[serde(default)]
37    pub service_id: String,
38    /// Orleans cluster id the bridge connects to.
39    #[serde(default)]
40    pub cluster_id: String,
41    /// Bridge version that produced the manifest.
42    #[serde(default)]
43    pub bridge_version: String,
44    /// Manifest schema version.
45    #[serde(default)]
46    pub schema_version: String,
47    /// Grain contracts.
48    #[serde(default)]
49    pub grains: Vec<GrainContract>,
50}
51
52/// A single grain contract.
53#[derive(Debug, Clone, Deserialize)]
54pub struct GrainContract {
55    /// Fully-qualified grain interface name.
56    pub interface_name: String,
57    /// Grain type alias used for dispatch.
58    pub grain_type: String,
59    /// Methods exposed by the grain.
60    #[serde(default)]
61    pub methods: Vec<GrainMethod>,
62    /// Key kinds the grain supports (`string`, `int64`, `guid`).
63    #[serde(default)]
64    pub supported_key_kinds: Vec<String>,
65}
66
67/// A single named method parameter.
68#[derive(Debug, Clone, Deserialize)]
69pub struct MethodParameter {
70    /// Parameter name.
71    pub name: String,
72    /// .NET type name.
73    #[serde(rename = "type")]
74    pub ty: String,
75}
76
77/// A single grain method.
78#[derive(Debug, Clone, Deserialize)]
79pub struct GrainMethod {
80    /// Method name as exposed on the grain interface.
81    pub name: String,
82    /// Request (single-argument) .NET type name, or empty for no argument.
83    /// Ignored when `parameters` is present.
84    #[serde(default)]
85    pub request_type: String,
86    /// Full parameter list. When present, takes precedence over `request_type`
87    /// and enables multi-argument methods.
88    #[serde(default)]
89    pub parameters: Vec<MethodParameter>,
90    /// Response .NET type name, or empty for no return value.
91    #[serde(default)]
92    pub response_type: String,
93    /// Payload codec; only `json` is supported by the generator.
94    #[serde(default)]
95    pub payload_codec: String,
96}
97
98impl Manifest {
99    /// Parse a manifest from a JSON string.
100    ///
101    /// # Errors
102    /// Returns [`CodegenError::Parse`] if the JSON is malformed.
103    pub fn from_json_str(json: &str) -> Result<Self, CodegenError> {
104        Ok(serde_json::from_str(json)?)
105    }
106}
107
/// Knobs that influence the generated source.
#[derive(Clone, Debug)]
pub struct CodegenOptions {
    /// Path of the runtime client crate, used in the generated `use` line.
    pub client_crate: String,
    /// When `true`, every method also gets a `<method>_with_context` sibling
    /// that returns the response-context map next to the value.
    pub with_response_context: bool,
}

impl Default for CodegenOptions {
    /// Targets the `orleans_rust_client` crate and leaves the
    /// `_with_context` variants switched off.
    fn default() -> Self {
        let client_crate = String::from("orleans_rust_client");
        Self {
            client_crate,
            with_response_context: false,
        }
    }
}
126
127/// Generate Rust source for every grain in `manifest`.
128///
129/// # Errors
130/// Returns [`CodegenError::Invalid`] if a grain contract cannot be turned into
131/// a valid Rust identifier.
132pub fn generate(manifest: &Manifest, options: &CodegenOptions) -> Result<String, CodegenError> {
133    let mut out = String::new();
134    out.push_str("// @generated by orleans-rust-codegen. Do not edit by hand.\n");
135    out.push_str("// Include within a module annotated `#[allow(dead_code, clippy::all)]`.\n\n");
136    out.push_str(&format!(
137        "use {client}::{{GrainKey, GrainRef, OrleansClient, OrleansError}};\n\n",
138        client = options.client_crate
139    ));
140
141    for grain in &manifest.grains {
142        out.push_str(&generate_grain(grain, options)?);
143        out.push('\n');
144    }
145
146    Ok(out)
147}
148
149fn generate_grain(grain: &GrainContract, options: &CodegenOptions) -> Result<String, CodegenError> {
150    let struct_name = client_struct_name(&grain.interface_name)?;
151    let key = KeyStrategy::from_kinds(&grain.supported_key_kinds);
152
153    let mut s = String::new();
154    s.push_str(&format!(
155        "/// Typed client for `{}`.\n",
156        grain.interface_name
157    ));
158    s.push_str(&format!(
159        "pub struct {struct_name} {{\n    inner: GrainRef,\n}}\n\n"
160    ));
161    s.push_str(&format!("impl {struct_name} {{\n"));
162    s.push_str(&format!(
163        "    /// Construct a client bound to `key`.\n    pub fn new(client: OrleansClient, key: {key_param}) -> Self {{\n        Self {{\n            inner: client.grain(\n                \"{interface}\",\n                \"{grain_type}\",\n                {key_expr},\n            ),\n        }}\n    }}\n",
164        key_param = key.param_type(),
165        interface = grain.interface_name,
166        grain_type = grain.grain_type,
167        key_expr = key.key_expr(),
168    ));
169
170    for method in &grain.methods {
171        s.push('\n');
172        s.push_str(&generate_method(method, options));
173    }
174
175    s.push_str("}\n");
176    Ok(s)
177}
178
179fn generate_method(method: &GrainMethod, options: &CodegenOptions) -> String {
180    let fn_name = sanitize_ident(&method.name.to_snake_case());
181    let response_ty = map_type(&method.response_type);
182
183    // Resolve the argument list: an explicit `parameters` list wins, otherwise
184    // fall back to the single `request_type`.
185    let args: Vec<(String, String)> = if !method.parameters.is_empty() {
186        method
187            .parameters
188            .iter()
189            .map(|p| (sanitize_ident(&p.name.to_snake_case()), map_type(&p.ty)))
190            .collect()
191    } else if map_type(&method.request_type) != "()" {
192        vec![("value".to_owned(), map_type(&method.request_type))]
193    } else {
194        Vec::new()
195    };
196
197    let signature_args: String = args
198        .iter()
199        .map(|(name, ty)| format!(", {name}: {ty}"))
200        .collect();
201
202    // Serialize 0 args as `&()`, 1 as `&name`, N as a tuple `&(a, b, ...)`
203    // which serde encodes as a JSON array the bridge invoker can decode.
204    let call_arg = match args.as_slice() {
205        [] => "&()".to_owned(),
206        [(name, _)] => format!("&{name}"),
207        many => format!(
208            "&({})",
209            many.iter()
210                .map(|(name, _)| name.clone())
211                .collect::<Vec<_>>()
212                .join(", ")
213        ),
214    };
215
216    let mut out = format!(
217        "    /// Invokes `{orig}`.\n    pub async fn {fn_name}(&self{signature_args}) -> Result<{response_ty}, OrleansError> {{\n        self.inner.invoke_json(\"{orig}\", {call_arg}).await\n    }}\n",
218        orig = method.name,
219    );
220
221    if options.with_response_context {
222        out.push_str(&format!(
223            "\n    /// Invokes `{orig}`, also returning the response context.\n    pub async fn {fn_name}_with_context(&self{signature_args}) -> Result<({response_ty}, std::collections::HashMap<String, String>), OrleansError> {{\n        self.inner.invoke_json_with_context(\"{orig}\", {call_arg}).await\n    }}\n",
224            orig = method.name,
225        ));
226    }
227
228    out
229}
230
/// How the generated constructor accepts the grain key and wraps it into a
/// `GrainKey`.
#[derive(Clone, Copy, Debug)]
enum KeyStrategy {
    String,
    Int64,
    Guid,
}

impl KeyStrategy {
    /// Choose a strategy from the manifest's key kinds.
    ///
    /// The first `int64` or `guid` entry wins. Every other entry (including
    /// `string`) is skipped, and string keys are the fallback when neither
    /// is listed.
    fn from_kinds(kinds: &[String]) -> Self {
        kinds
            .iter()
            .find_map(|kind| match kind.as_str() {
                "int64" => Some(Self::Int64),
                "guid" => Some(Self::Guid),
                _ => None,
            })
            .unwrap_or(Self::String)
    }

    /// `(constructor parameter type, key-wrapping expression)` per strategy.
    fn rendering(self) -> (&'static str, &'static str) {
        match self {
            Self::String => ("impl Into<String>", "GrainKey::String(key.into())"),
            Self::Int64 => ("i64", "GrainKey::Int64(key)"),
            Self::Guid => ("uuid::Uuid", "GrainKey::Guid(key)"),
        }
    }

    /// Rust type of the `key` parameter in the generated constructor.
    fn param_type(self) -> &'static str {
        self.rendering().0
    }

    /// Expression that turns the `key` parameter into a `GrainKey`.
    fn key_expr(self) -> &'static str {
        self.rendering().1
    }
}
266
267fn client_struct_name(interface_name: &str) -> Result<String, CodegenError> {
268    let last = interface_name.rsplit('.').next().unwrap_or(interface_name);
269    let trimmed = last
270        .strip_prefix('I')
271        .filter(|rest| rest.chars().next().is_some_and(char::is_uppercase))
272        .unwrap_or(last);
273    let base = trimmed.to_pascal_case();
274    if base.is_empty() {
275        return Err(CodegenError::Invalid(format!(
276            "cannot derive a client name from interface `{interface_name}`"
277        )));
278    }
279    Ok(format!("{base}Client"))
280}
281
282/// Map a .NET type name to a Rust type. Handles primitives, nullable types,
283/// arrays, and the common generic collections; unknown types fall back to
284/// `serde_json::Value` so generation never fails on an unfamiliar type.
285fn map_type(dotnet: &str) -> String {
286    let normalized = dotnet.trim();
287
288    // Reflection FullName uses a trailing assembly-qualified suffix on generic
289    // arguments (e.g. `[[System.Int64, mscorlib, ...]]`); strip it for matching.
290    if let Some(scalar) = map_scalar(normalized) {
291        return scalar;
292    }
293
294    // Nullable<T> / `T?` -> Option<T>
295    if let Some(inner) = strip_nullable(normalized) {
296        return format!("Option<{}>", map_type(&inner));
297    }
298
299    // byte[] -> Vec<u8>; T[] -> Vec<T>
300    if let Some(element) = normalized.strip_suffix("[]") {
301        return format!("Vec<{}>", map_type(element));
302    }
303
304    // Generic collections.
305    if let Some((base, args)) = parse_generic(normalized) {
306        match (base.as_str(), args.as_slice()) {
307            (
308                "System.Collections.Generic.List"
309                | "System.Collections.Generic.IList"
310                | "System.Collections.Generic.IReadOnlyList"
311                | "System.Collections.Generic.ICollection"
312                | "System.Collections.Generic.IEnumerable"
313                | "List"
314                | "IList"
315                | "IReadOnlyList"
316                | "IEnumerable",
317                [item],
318            ) => return format!("Vec<{}>", map_type(item)),
319            (
320                "System.Collections.Generic.Dictionary"
321                | "System.Collections.Generic.IDictionary"
322                | "System.Collections.Generic.IReadOnlyDictionary"
323                | "Dictionary"
324                | "IDictionary",
325                [key, value],
326            ) => {
327                return format!(
328                    "std::collections::HashMap<{}, {}>",
329                    map_type(key),
330                    map_type(value)
331                );
332            }
333            ("System.Nullable" | "Nullable", [item]) => {
334                return format!("Option<{}>", map_type(item));
335            }
336            _ => {}
337        }
338    }
339
340    "serde_json::Value".to_owned()
341}
342
/// Look up the Rust type for an exact, already-trimmed .NET scalar name.
///
/// Returns `None` for any name that is not in the table; matching is exact
/// and case-sensitive.
fn map_scalar(normalized: &str) -> Option<String> {
    // Each row pairs the accepted .NET spellings with the Rust type.
    const TABLE: &[(&[&str], &str)] = &[
        (&["", "void", "System.Void", "System.Threading.Tasks.Task"], "()"),
        (&["System.String", "string"], "String"),
        (&["System.Boolean", "bool"], "bool"),
        (&["System.SByte", "sbyte"], "i8"),
        (&["System.Byte", "byte"], "u8"),
        (&["System.Int16", "short"], "i16"),
        (&["System.UInt16", "ushort"], "u16"),
        (&["System.Int32", "int"], "i32"),
        (&["System.UInt32", "uint"], "u32"),
        (&["System.Int64", "long"], "i64"),
        (&["System.UInt64", "ulong"], "u64"),
        (&["System.Single", "float"], "f32"),
        (&["System.Double", "double"], "f64"),
        (&["System.Guid"], "uuid::Uuid"),
        (
            &[
                "System.DateTime",
                "System.DateTimeOffset",
                "System.TimeSpan",
                "System.Decimal",
                "decimal",
            ],
            "String",
        ),
        (&["System.Object", "object"], "serde_json::Value"),
    ];

    TABLE
        .iter()
        .find(|(names, _)| names.contains(&normalized))
        .map(|(_, rust)| (*rust).to_owned())
}
369
/// Strip the C# shorthand nullable marker (`T?`), returning the trimmed inner
/// type name.
///
/// Only the trailing-`?` spelling is handled here; the generic
/// `Nullable<T>` spelling is recognised by `map_type` through its generic
/// parsing. Returns `None` when there is no trailing `?` or when nothing but
/// whitespace precedes it, so a bare `?` is not mistaken for "nullable
/// nothing".
fn strip_nullable(normalized: &str) -> Option<String> {
    normalized
        .strip_suffix('?')
        .map(str::trim)
        .filter(|inner| !inner.is_empty())
        .map(str::to_owned)
}
377
/// Parse a generic type name into `(base, [arg, ...])`.
///
/// Two spellings are supported:
/// * C# source form, e.g. `List<System.Int64>`;
/// * reflection form, e.g.
///   ``System.Collections.Generic.List`1[[System.Int64, mscorlib, ...]]``,
///   where each argument may carry an assembly-qualified suffix that is
///   dropped.
///
/// Arguments are returned as type names that can themselves be generic, so
/// nested generics survive intact in both spellings. Returns `None` when
/// `name` is not generic or is malformed.
fn parse_generic(name: &str) -> Option<(String, Vec<String>)> {
    if let Some(open) = name.find('<') {
        if !name.ends_with('>') {
            return None;
        }
        let base = name[..open].trim().to_owned();
        let inner = &name[open + 1..name.len() - 1];
        return Some((base, split_top_level(inner)));
    }

    if let Some(tick) = name.find('`') {
        let base = name[..tick].trim().to_owned();
        let rest = &name[tick..];
        let outer_open = rest.find('[')?;
        let outer = rest[outer_open..].trim();
        let inner = outer.strip_prefix('[')?.strip_suffix(']')?;
        // `inner` is `[Type, asm, ...],[Type, asm, ...]`; each top-level group
        // is an assembly-qualified type.
        let args = split_top_level(inner)
            .into_iter()
            .map(|group| {
                let group = group.trim();
                // Only peel a trailing `]` together with a leading `[`: on an
                // unwrapped group the trailing bracket belongs to the type
                // itself (an array or a nested generic).
                let unwrapped = match group.strip_prefix('[') {
                    Some(rest) => rest.strip_suffix(']').unwrap_or(rest),
                    None => group,
                };
                // The type name ends at the first *top-level* comma; commas
                // inside a nested generic's brackets are part of the name.
                split_top_level(unwrapped)
                    .into_iter()
                    .next()
                    .unwrap_or_default()
            })
            .collect();
        return Some((base, args));
    }

    None
}

/// Split a comma-separated generic argument list, respecting nested brackets.
///
/// Commas inside `<...>` or `[...]` do not split. Every part is trimmed; a
/// trailing part that is empty after trimming is dropped.
fn split_top_level(input: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut depth = 0i32;
    let mut current = String::new();
    for ch in input.chars() {
        match ch {
            '<' | '[' => {
                depth += 1;
                current.push(ch);
            }
            '>' | ']' => {
                depth -= 1;
                current.push(ch);
            }
            ',' if depth == 0 => {
                parts.push(current.trim().to_owned());
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        parts.push(current.trim().to_owned());
    }
    parts
}
441
/// Make `name` usable as an identifier in generated Rust code.
///
/// * `self`, `Self`, `super` and `crate` cannot be written as raw
///   identifiers, so they get a trailing underscore (`self_`).
/// * Every other strict or reserved keyword is emitted as a raw identifier
///   (`r#type`). The list spans all editions; a raw identifier for a word
///   that is not a keyword in the consumer's edition names the same item as
///   the plain word.
/// * Any other name is returned unchanged.
fn sanitize_ident(name: &str) -> String {
    // Keywords that are rejected even with the `r#` prefix.
    const NOT_RAW_CAPABLE: &[&str] = &["crate", "self", "Self", "super"];
    // Strict and reserved keywords that are valid behind `r#`.
    const RAW_CAPABLE: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "do",
        "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in",
        "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
        "return", "static", "struct", "trait", "true", "try", "type", "typeof", "unsafe",
        "unsized", "use", "virtual", "where", "while", "yield",
    ];

    if NOT_RAW_CAPABLE.contains(&name) {
        format!("{name}_")
    } else if RAW_CAPABLE.contains(&name) {
        format!("r#{name}")
    } else {
        name.to_owned()
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a JSON-codec method described by `request_type` only (no
    /// explicit parameter list).
    fn method(name: &str, request: &str, response: &str) -> GrainMethod {
        GrainMethod {
            name: name.to_owned(),
            request_type: request.to_owned(),
            parameters: Vec::new(),
            response_type: response.to_owned(),
            payload_codec: "json".to_owned(),
        }
    }

    /// Build a manifest holding exactly one grain contract; the fixture
    /// helpers below only differ in the contract they describe.
    fn manifest_with(
        interface_name: &str,
        grain_type: &str,
        kinds: &[&str],
        methods: Vec<GrainMethod>,
    ) -> Manifest {
        Manifest {
            service_id: "s".into(),
            cluster_id: "c".into(),
            bridge_version: "0.1.0".into(),
            schema_version: "1".into(),
            grains: vec![GrainContract {
                interface_name: interface_name.to_owned(),
                grain_type: grain_type.to_owned(),
                supported_key_kinds: kinds.iter().map(|kind| (*kind).to_owned()).collect(),
                methods,
            }],
        }
    }

    /// String-keyed counter grain exposing `methods`.
    fn grain(methods: Vec<GrainMethod>) -> Manifest {
        manifest_with(
            "Counter.Abstractions.ICounterGrain",
            "counter",
            &["string"],
            methods,
        )
    }

    /// Sample grain with caller-chosen key kinds.
    fn grain_with_keys(kinds: Vec<&str>, methods: Vec<GrainMethod>) -> Manifest {
        manifest_with("Sample.IThingGrain", "thing", &kinds, methods)
    }

    #[test]
    fn derives_client_name() {
        assert_eq!(
            client_struct_name("Counter.Abstractions.ICounterGrain").unwrap(),
            "CounterGrainClient"
        );
        assert_eq!(
            client_struct_name("ICounterGrain").unwrap(),
            "CounterGrainClient"
        );
    }

    #[test]
    fn maps_primitive_types() {
        assert_eq!(map_type("System.Int64"), "i64");
        assert_eq!(map_type(""), "()");
        assert_eq!(map_type("Some.Custom.Type"), "serde_json::Value");
    }

    #[test]
    fn maps_collections_and_options() {
        assert_eq!(map_type("System.String?"), "Option<String>");
        assert_eq!(map_type("System.Byte[]"), "Vec<u8>");
        assert_eq!(map_type("System.Int32[]"), "Vec<i32>");
        assert_eq!(map_type("List<System.Int64>"), "Vec<i64>");
        assert_eq!(
            map_type("Dictionary<System.String, System.Int32>"),
            "std::collections::HashMap<String, i32>"
        );
    }

    #[test]
    fn maps_nested_source_generics() {
        assert_eq!(map_type("List<List<System.Int64>>"), "Vec<Vec<i64>>");
        assert_eq!(
            map_type("Dictionary<System.String, List<System.Int32>>"),
            "std::collections::HashMap<String, Vec<i32>>"
        );
    }

    #[test]
    fn maps_reflection_generic_names() {
        assert_eq!(
            map_type("System.Collections.Generic.List`1[[System.Int64, System.Private.CoreLib]]"),
            "Vec<i64>"
        );
        assert_eq!(
            map_type(
                "System.Collections.Generic.Dictionary`2[[System.String, mscorlib],[System.Int32, mscorlib]]"
            ),
            "std::collections::HashMap<String, i32>"
        );
    }

    #[test]
    fn generates_counter_client() {
        let manifest = grain(vec![
            method("Get", "", "System.Int64"),
            method("Add", "System.Int64", "System.Int64"),
        ]);

        let code = generate(&manifest, &CodegenOptions::default()).unwrap();
        assert!(code.contains("pub struct CounterGrainClient"));
        assert!(code.contains("pub async fn get(&self) -> Result<i64, OrleansError>"));
        assert!(code.contains("pub async fn add(&self, value: i64) -> Result<i64, OrleansError>"));
    }

    #[test]
    fn generates_multi_argument_method() {
        let mut transfer = method("Transfer", "", "System.Boolean");
        transfer.parameters = vec![
            MethodParameter {
                name: "destination".into(),
                ty: "System.String".into(),
            },
            MethodParameter {
                name: "amount".into(),
                ty: "System.Int64".into(),
            },
        ];

        let code = generate(&grain(vec![transfer]), &CodegenOptions::default()).unwrap();
        assert!(code.contains(
            "pub async fn transfer(&self, destination: String, amount: i64) -> Result<bool, OrleansError>"
        ));
        assert!(code.contains("invoke_json(\"Transfer\", &(destination, amount))"));
    }

    #[test]
    fn parameters_take_precedence_over_request_type() {
        // Both are set; only the explicit parameter list may shape the call.
        let mut add = method("Add", "System.Int64", "System.Int64");
        add.parameters = vec![MethodParameter {
            name: "amount".into(),
            ty: "System.Int32".into(),
        }];

        let code = generate(&grain(vec![add]), &CodegenOptions::default()).unwrap();
        assert!(code.contains("pub async fn add(&self, amount: i32) -> Result<i64, OrleansError>"));
        assert!(code.contains("invoke_json(\"Add\", &amount)"));
    }

    #[test]
    fn generates_response_context_variant() {
        let options = CodegenOptions {
            with_response_context: true,
            ..Default::default()
        };
        let code = generate(&grain(vec![method("Get", "", "System.Int64")]), &options).unwrap();
        assert!(code.contains(
            "pub async fn get_with_context(&self) -> Result<(i64, std::collections::HashMap<String, String>), OrleansError>"
        ));
        assert!(code.contains("invoke_json_with_context(\"Get\", &())"));
    }

    #[test]
    fn generates_int64_key_constructor() {
        let code = generate(
            &grain_with_keys(vec!["int64"], vec![method("Get", "", "System.Int64")]),
            &CodegenOptions::default(),
        )
        .unwrap();
        assert!(code.contains("pub fn new(client: OrleansClient, key: i64) -> Self"));
        assert!(code.contains("GrainKey::Int64(key)"));
    }

    #[test]
    fn generates_guid_key_constructor() {
        let code = generate(
            &grain_with_keys(vec!["guid"], vec![method("Get", "", "System.Int64")]),
            &CodegenOptions::default(),
        )
        .unwrap();
        assert!(code.contains("pub fn new(client: OrleansClient, key: uuid::Uuid) -> Self"));
        assert!(code.contains("GrainKey::Guid(key)"));
    }

    #[test]
    fn sanitizes_reserved_method_names() {
        // "Type" -> snake_case "type" (a Rust keyword) -> raw identifier.
        let code = generate(
            &grain(vec![method("Type", "", "System.String")]),
            &CodegenOptions::default(),
        )
        .unwrap();
        assert!(code.contains("pub async fn r#type(&self)"));
    }

    #[test]
    fn empty_interface_name_is_an_error() {
        let mut manifest = grain_with_keys(vec!["string"], vec![method("Get", "", "")]);
        manifest.grains[0].interface_name = String::new();
        let err = generate(&manifest, &CodegenOptions::default()).unwrap_err();
        assert!(matches!(err, CodegenError::Invalid(_)));
    }

    #[test]
    fn maps_additional_scalars() {
        assert_eq!(map_type("System.DateTime"), "String");
        assert_eq!(map_type("System.Decimal"), "String");
        assert_eq!(map_type("System.Object"), "serde_json::Value");
        assert_eq!(map_type("System.Boolean"), "bool");
        assert_eq!(map_type("System.Guid"), "uuid::Uuid");
    }

    #[test]
    fn maps_nullable_reflection_form() {
        assert_eq!(
            map_type("System.Nullable`1[[System.Int32, System.Private.CoreLib]]"),
            "Option<i32>"
        );
        assert_eq!(
            map_type("System.Collections.Generic.IReadOnlyList`1[[System.String, mscorlib]]"),
            "Vec<String>"
        );
    }

    #[test]
    fn parses_manifest_from_json() {
        let json = r#"{"service_id":"s","grains":[{"interface_name":"X.IY","grain_type":"y",
            "supported_key_kinds":["string"],
            "methods":[{"name":"Get","response_type":"System.Int64"}]}]}"#;
        let manifest = Manifest::from_json_str(json).unwrap();
        assert_eq!(manifest.grains.len(), 1);
        let code = generate(&manifest, &CodegenOptions::default()).unwrap();
        assert!(code.contains("pub struct YClient"));
    }

    #[test]
    fn malformed_json_is_a_parse_error() {
        // Truncated document.
        let err = Manifest::from_json_str("{").unwrap_err();
        assert!(matches!(err, CodegenError::Parse(_)));

        // Well-formed JSON that lacks the required `interface_name`.
        let err = Manifest::from_json_str(r#"{"grains":[{"grain_type":"y"}]}"#).unwrap_err();
        assert!(matches!(err, CodegenError::Parse(_)));
    }
}