use heck::{ToPascalCase, ToSnakeCase};
use serde::Deserialize;
/// Errors produced while reading a manifest or generating client code from it.
#[derive(thiserror::Error, Debug)]
pub enum CodegenError {
/// The manifest text could not be deserialized by `serde_json`.
#[error("failed to parse manifest: {0}")]
Parse(#[from] serde_json::Error),
/// The manifest parsed, but some of its content cannot be turned into generated code
/// (for example an interface name from which no client name can be derived).
#[error("invalid manifest: {0}")]
Invalid(String),
}
/// Top-level contract manifest: the grains a service exposes plus identifying metadata.
///
/// Every field is `#[serde(default)]`, so missing keys deserialize to empty values.
#[derive(Debug, Clone, Deserialize)]
pub struct Manifest {
// Service identifier from the manifest; not read by `generate`.
#[serde(default)]
pub service_id: String,
// Cluster identifier from the manifest; not read by `generate`.
#[serde(default)]
pub cluster_id: String,
// Version string of the bridge that produced the manifest; not read by `generate`.
#[serde(default)]
pub bridge_version: String,
// Manifest schema version; not read by `generate`.
#[serde(default)]
pub schema_version: String,
/// One entry per grain interface; `generate` emits one client struct per entry.
#[serde(default)]
pub grains: Vec<GrainContract>,
}
/// Description of one grain interface in the manifest.
#[derive(Debug, Clone, Deserialize)]
pub struct GrainContract {
/// .NET interface name (e.g. `Counter.Abstractions.ICounterGrain`). Required. The client
/// struct name is derived from its last `.`-separated segment, and the full name is
/// passed to `OrleansClient::grain` by the generated constructor.
pub interface_name: String,
/// Grain type name passed to `OrleansClient::grain` by the generated constructor. Required.
pub grain_type: String,
/// Methods to wrap; each becomes an async method on the generated client.
#[serde(default)]
pub methods: Vec<GrainMethod>,
/// Key kinds the grain accepts; `int64` / `guid` select a typed key for the generated
/// constructor, otherwise it takes a string key.
#[serde(default)]
pub supported_key_kinds: Vec<String>,
}
/// One named parameter of a grain method.
#[derive(Debug, Clone, Deserialize)]
pub struct MethodParameter {
/// Parameter name; snake_cased (and keyword-escaped) for the generated Rust argument.
pub name: String,
/// .NET type name of the parameter, read from the JSON key `"type"`.
#[serde(rename = "type")]
pub ty: String,
}
/// One grain method in the manifest.
#[derive(Debug, Clone, Deserialize)]
pub struct GrainMethod {
/// .NET method name. Sent verbatim as the method identifier in the generated
/// `invoke_json` call and snake_cased for the Rust method name.
pub name: String,
/// .NET request type; used only when `parameters` is empty, in which case a non-unit
/// type becomes a single `value` argument.
#[serde(default)]
pub request_type: String,
/// Named parameters; when non-empty they take precedence over `request_type`.
#[serde(default)]
pub parameters: Vec<MethodParameter>,
/// .NET response type; an empty string maps to `()`.
#[serde(default)]
pub response_type: String,
// NOTE(review): not read by the generator in this file; presumably names the wire
// codec (tests use "json") — confirm against the bridge.
#[serde(default)]
pub payload_codec: String,
}
impl Manifest {
pub fn from_json_str(json: &str) -> Result<Self, CodegenError> {
Ok(serde_json::from_str(json)?)
}
}
/// Knobs controlling the generated client code.
#[derive(Debug, Clone)]
pub struct CodegenOptions {
/// Crate path the generated `use` line imports `GrainKey`, `GrainRef`, `OrleansClient`
/// and `OrleansError` from.
pub client_crate: String,
/// When true, every method also gets a `<name>_with_context` variant that returns the
/// response context alongside the result.
pub with_response_context: bool,
}
impl Default for CodegenOptions {
fn default() -> Self {
Self {
client_crate: "orleans_rust_client".to_owned(),
with_response_context: false,
}
}
}
/// Generates Rust client source for every grain in `manifest`.
///
/// The output begins with an `@generated` banner and a `use` line for the crate named by
/// `options.client_crate`, followed by each grain's client code and an extra newline.
///
/// # Errors
/// Stops at, and returns, the first error raised while rendering a grain (for example an
/// interface name from which no client struct name can be derived).
pub fn generate(manifest: &Manifest, options: &CodegenOptions) -> Result<String, CodegenError> {
    let preamble = format!(
        "// @generated by orleans-rust-codegen. Do not edit by hand.\n\
         // Include within a module annotated `#[allow(dead_code, clippy::all)]`.\n\n\
         use {}::{{GrainKey, GrainRef, OrleansClient, OrleansError}};\n\n",
        options.client_crate
    );
    manifest
        .grains
        .iter()
        .try_fold(preamble, |mut code, grain| -> Result<String, CodegenError> {
            code.push_str(&generate_grain(grain, options)?);
            code.push('\n');
            Ok(code)
        })
}
/// Renders the typed client struct and its `impl` block for one grain interface.
///
/// The struct name comes from [`client_struct_name`], the constructor's key type from
/// [`KeyStrategy::from_kinds`], and one wrapper per manifest method from
/// [`generate_method`].
///
/// # Errors
/// Returns [`CodegenError::Invalid`] when no client struct name can be derived from
/// `grain.interface_name`.
fn generate_grain(grain: &GrainContract, options: &CodegenOptions) -> Result<String, CodegenError> {
    let struct_name = client_struct_name(&grain.interface_name)?;
    let key = KeyStrategy::from_kinds(&grain.supported_key_kinds);
    // Manifest strings are external input: `{:?}` renders them as properly escaped Rust
    // string literals, so a quote or backslash cannot corrupt the generated source.
    let interface_lit = format!("{:?}", grain.interface_name);
    let grain_type_lit = format!("{:?}", grain.grain_type);
    let mut s = String::new();
    s.push_str(&format!(
        "/// Typed client for `{}`.\n",
        grain.interface_name
    ));
    s.push_str(&format!(
        "pub struct {struct_name} {{\n inner: GrainRef,\n}}\n\n"
    ));
    s.push_str(&format!("impl {struct_name} {{\n"));
    s.push_str(&format!(
        " /// Construct a client bound to `key`.\n pub fn new(client: OrleansClient, key: {key_param}) -> Self {{\n Self {{\n inner: client.grain(\n {interface},\n {grain_type},\n {key_expr},\n ),\n }}\n }}\n",
        key_param = key.param_type(),
        interface = interface_lit,
        grain_type = grain_type_lit,
        key_expr = key.key_expr(),
    ));
    for method in &grain.methods {
        s.push('\n');
        s.push_str(&generate_method(method, options));
    }
    s.push_str("}\n");
    Ok(s)
}
/// Renders the async wrapper method(s) for one grain method.
///
/// Arguments are taken from the manifest in priority order:
/// * explicit `parameters` become snake_cased, keyword-escaped arguments; a single
///   argument is passed to `invoke_json` by reference, several are passed as one tuple;
/// * otherwise a `request_type` that does not map to `()` becomes a single `value` argument;
/// * otherwise the method takes no arguments and sends `&()`.
///
/// When `options.with_response_context` is set, a `<name>_with_context` variant calling
/// `invoke_json_with_context` is emitted as well.
fn generate_method(method: &GrainMethod, options: &CodegenOptions) -> String {
    let snake_name = method.name.to_snake_case();
    let fn_name = sanitize_ident(&snake_name);
    let response_ty = map_type(&method.response_type);
    let args: Vec<(String, String)> = if !method.parameters.is_empty() {
        method
            .parameters
            .iter()
            .map(|p| (sanitize_ident(&p.name.to_snake_case()), map_type(&p.ty)))
            .collect()
    } else {
        // Map the request type once instead of once for the test and again for the value.
        let request_ty = map_type(&method.request_type);
        if request_ty == "()" {
            Vec::new()
        } else {
            vec![("value".to_owned(), request_ty)]
        }
    };
    let signature_args: String = args
        .iter()
        .map(|(name, ty)| format!(", {name}: {ty}"))
        .collect();
    let call_arg = match args.as_slice() {
        [] => "&()".to_owned(),
        [(name, _)] => format!("&{name}"),
        many => format!(
            "&({})",
            many.iter()
                .map(|(name, _)| name.as_str())
                .collect::<Vec<_>>()
                .join(", ")
        ),
    };
    let orig = &method.name;
    // The method name is external input: `{:?}` renders it as an escaped Rust string
    // literal instead of splicing it raw between quotes.
    let orig_lit = format!("{orig:?}");
    let mut out = format!(
        " /// Invokes `{orig}`.\n pub async fn {fn_name}(&self{signature_args}) -> Result<{response_ty}, OrleansError> {{\n self.inner.invoke_json({orig_lit}, {call_arg}).await\n }}\n"
    );
    if options.with_response_context {
        // Build the variant name from the plain snake_case name so an escaped base such as
        // `r#type` yields `type_with_context` (same identifier, without the raw prefix).
        let ctx_fn_name = sanitize_ident(&format!("{snake_name}_with_context"));
        out.push_str(&format!(
            "\n /// Invokes `{orig}`, also returning the response context.\n pub async fn {ctx_fn_name}(&self{signature_args}) -> Result<({response_ty}, std::collections::HashMap<String, String>), OrleansError> {{\n self.inner.invoke_json_with_context({orig_lit}, {call_arg}).await\n }}\n"
        ));
    }
    out
}
/// How a generated client's constructor accepts and wraps the grain key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum KeyStrategy {
    /// `impl Into<String>`, wrapped as `GrainKey::String`; the fallback strategy.
    String,
    /// `i64`, wrapped as `GrainKey::Int64`.
    Int64,
    /// `uuid::Uuid`, wrapped as `GrainKey::Guid`.
    Guid,
}

impl KeyStrategy {
    /// Chooses the key strategy from a grain's `supported_key_kinds`.
    ///
    /// The first `int64` or `guid` entry wins, compared ASCII case-insensitively so that
    /// `Int64` / `Guid` are recognised too. When neither is present (including an empty
    /// list or one that only names `string`), string keys are used.
    fn from_kinds(kinds: &[String]) -> Self {
        kinds
            .iter()
            .find_map(|kind| {
                if kind.eq_ignore_ascii_case("int64") {
                    Some(KeyStrategy::Int64)
                } else if kind.eq_ignore_ascii_case("guid") {
                    Some(KeyStrategy::Guid)
                } else {
                    None
                }
            })
            .unwrap_or(KeyStrategy::String)
    }

    /// Rust type of the generated constructor's `key` parameter.
    fn param_type(self) -> &'static str {
        match self {
            KeyStrategy::String => "impl Into<String>",
            KeyStrategy::Int64 => "i64",
            KeyStrategy::Guid => "uuid::Uuid",
        }
    }

    /// Expression converting the constructor's `key` parameter into a `GrainKey`.
    fn key_expr(self) -> &'static str {
        match self {
            KeyStrategy::String => "GrainKey::String(key.into())",
            KeyStrategy::Int64 => "GrainKey::Int64(key)",
            KeyStrategy::Guid => "GrainKey::Guid(key)",
        }
    }
}
/// Derives the generated client struct name from a .NET interface name.
///
/// Takes the last `.`-separated segment, drops a conventional `I` prefix when it is
/// followed by an uppercase letter (`ICounterGrain` → `CounterGrain`), converts the rest
/// to PascalCase and appends `Client`.
///
/// # Errors
/// Returns [`CodegenError::Invalid`] when the derived base name is empty or starts with a
/// digit; neither would produce a valid Rust identifier.
fn client_struct_name(interface_name: &str) -> Result<String, CodegenError> {
    let last = interface_name.rsplit('.').next().unwrap_or(interface_name);
    let trimmed = last
        .strip_prefix('I')
        .filter(|rest| rest.chars().next().is_some_and(char::is_uppercase))
        .unwrap_or(last);
    let base = trimmed.to_pascal_case();
    match base.chars().next() {
        None => Err(CodegenError::Invalid(format!(
            "cannot derive a client name from interface `{interface_name}`"
        ))),
        // `1FooClient` is not an identifier; report it instead of emitting broken code.
        Some(first) if first.is_ascii_digit() => Err(CodegenError::Invalid(format!(
            "client name derived from interface `{interface_name}` would start with a digit"
        ))),
        Some(_) => Ok(format!("{base}Client")),
    }
}
/// Maps a .NET type name from the manifest to the Rust type used in generated signatures.
///
/// Understands:
/// * C# keyword aliases (`int`, `string`, ...) and `System.*` scalar names;
/// * nullable forms `T?` and `Nullable<T>`;
/// * single-dimensional arrays `T[]`;
/// * common generic collections in both C# syntax (`List<int>`) and reflection syntax
///   (``List`1[[System.Int32, mscorlib]]``), including nested generics;
/// * `Task<T>` / `ValueTask<T>`, unwrapped to `T`.
///
/// Anything unrecognised maps to `serde_json::Value`, so generated code still compiles
/// and the payload is passed through untyped.
fn map_type(dotnet: &str) -> String {
    let normalized = dotnet.trim();
    if let Some(scalar) = map_scalar(normalized) {
        return scalar;
    }
    if let Some(inner) = strip_nullable(normalized) {
        return format!("Option<{}>", map_type(&inner));
    }
    if let Some(element) = normalized.strip_suffix("[]") {
        return format!("Vec<{}>", map_type(element));
    }
    if let Some((base, args)) = parse_generic(normalized) {
        match (base.as_str(), args.as_slice()) {
            // Sequence-like collections travel as JSON arrays; sets become `Vec` too, so
            // uniqueness is not enforced on the Rust side.
            (
                "System.Collections.Generic.List"
                | "System.Collections.Generic.IList"
                | "System.Collections.Generic.IReadOnlyList"
                | "System.Collections.Generic.ICollection"
                | "System.Collections.Generic.IReadOnlyCollection"
                | "System.Collections.Generic.IEnumerable"
                | "System.Collections.Generic.HashSet"
                | "System.Collections.Generic.ISet"
                | "List"
                | "IList"
                | "IReadOnlyList"
                | "ICollection"
                | "IReadOnlyCollection"
                | "IEnumerable"
                | "HashSet"
                | "ISet",
                [item],
            ) => return format!("Vec<{}>", map_type(item)),
            (
                "System.Collections.Generic.Dictionary"
                | "System.Collections.Generic.IDictionary"
                | "System.Collections.Generic.IReadOnlyDictionary"
                | "Dictionary"
                | "IDictionary"
                | "IReadOnlyDictionary",
                [key, value],
            ) => {
                return format!(
                    "std::collections::HashMap<{}, {}>",
                    map_type(key),
                    map_type(value)
                );
            }
            ("System.Nullable" | "Nullable", [item]) => {
                return format!("Option<{}>", map_type(item));
            }
            // A raw async return type: the caller receives the task's result.
            (
                "System.Threading.Tasks.Task"
                | "System.Threading.Tasks.ValueTask"
                | "Task"
                | "ValueTask",
                [item],
            ) => return map_type(item),
            _ => {}
        }
    }
    "serde_json::Value".to_owned()
}

/// Maps a non-generic .NET type name (or C# keyword alias) to a Rust type.
///
/// Returns `None` for names that are not known scalars, leaving nullables, arrays and
/// generics to [`map_type`].
fn map_scalar(normalized: &str) -> Option<String> {
    let mapped = match normalized {
        // No payload / no result; non-generic `Task` and `ValueTask` are async `void`.
        "" | "void" | "System.Void" | "System.Threading.Tasks.Task"
        | "System.Threading.Tasks.ValueTask" => "()",
        "System.String" | "string" => "String",
        "System.Char" | "char" => "char",
        "System.Boolean" | "bool" => "bool",
        "System.SByte" | "sbyte" => "i8",
        "System.Byte" | "byte" => "u8",
        "System.Int16" | "short" => "i16",
        "System.UInt16" | "ushort" => "u16",
        "System.Int32" | "int" => "i32",
        "System.UInt32" | "uint" => "u32",
        "System.Int64" | "long" => "i64",
        "System.UInt64" | "ulong" => "u64",
        "System.Single" | "float" => "f32",
        "System.Double" | "double" => "f64",
        "System.Guid" => "uuid::Uuid",
        // Carried as strings in generated code.
        // NOTE(review): System.Text.Json writes `decimal` as a JSON number by default;
        // confirm the bridge serializes decimals as strings, or deserializing into
        // `String` will fail.
        "System.DateTime"
        | "System.DateTimeOffset"
        | "System.DateOnly"
        | "System.TimeOnly"
        | "System.TimeSpan"
        | "System.Uri"
        | "System.Decimal"
        | "decimal" => "String",
        "System.Object" | "object" => "serde_json::Value",
        _ => return None,
    };
    Some(mapped.to_owned())
}

/// Returns the inner type name of a C# nullable shorthand (`T?`), trimmed.
fn strip_nullable(normalized: &str) -> Option<String> {
    normalized
        .strip_suffix('?')
        .map(|inner| inner.trim().to_owned())
}

/// Splits a generic type name into its base name and type-argument names.
///
/// Accepts C# syntax (`Dictionary<string, List<int>>`) and both reflection syntaxes:
/// ``List`1[System.Int32]`` and the assembly-qualified
/// ``List`1[[System.Int32, mscorlib, Version=...]]``. Returns `None` when `name` is not
/// generic or its argument list is not closed at the end of the string.
fn parse_generic(name: &str) -> Option<(String, Vec<String>)> {
    if let Some(open) = name.find('<') {
        if !name.ends_with('>') {
            return None;
        }
        let base = name[..open].trim().to_owned();
        let inner = &name[open + 1..name.len() - 1];
        return Some((base, split_top_level(inner)));
    }
    if let Some(tick) = name.find('`') {
        let base = name[..tick].trim().to_owned();
        let rest = &name[tick..];
        let outer_open = rest.find('[')?;
        let outer = rest[outer_open..].trim();
        let inner = outer.strip_prefix('[')?.strip_suffix(']')?;
        let args = split_top_level(inner)
            .iter()
            .map(|group| reflection_type_name(group))
            .collect();
        return Some((base, args));
    }
    None
}

/// Extracts the type name from one reflection type-argument group.
///
/// A group is either a bare name (`System.Int32`, `System.Byte[]`, a nested generic) or a
/// bracketed assembly-qualified name (`[System.Int32, mscorlib, Version=...]`). Only a
/// matching outer bracket pair is removed, and the assembly qualifier is dropped at the
/// first *top-level* comma so nested generic arguments keep their own brackets and commas.
fn reflection_type_name(group: &str) -> String {
    let group = group.trim();
    let unwrapped = group
        .strip_prefix('[')
        .and_then(|g| g.strip_suffix(']'))
        .unwrap_or(group);
    split_top_level(unwrapped)
        .into_iter()
        .next()
        .unwrap_or_default()
}

/// Splits `input` on commas that are not nested inside `<...>` or `[...]`, trimming each
/// part. Empty parts between commas are kept; a trailing empty part is dropped.
fn split_top_level(input: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut depth = 0i32;
    let mut current = String::new();
    for ch in input.chars() {
        match ch {
            '<' | '[' => {
                depth += 1;
                current.push(ch);
            }
            '>' | ']' => {
                depth -= 1;
                current.push(ch);
            }
            ',' if depth == 0 => {
                parts.push(current.trim().to_owned());
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        parts.push(current.trim().to_owned());
    }
    parts
}
/// Makes a snake_cased `name` usable as a Rust identifier in generated code.
///
/// * Strict and reserved keywords that may be raw identifiers (`type`, `async`, `do`,
///   `yield`, ...) get an `r#` prefix.
/// * `self`, `Self`, `super` and `crate` cannot be raw identifiers, so they get a
///   trailing `_` instead.
/// * A name starting with an ASCII digit gets a leading `_`; an empty name becomes
///   `_unnamed`.
///
/// Other characters are not validated; callers pass output of `to_snake_case`.
fn sanitize_ident(name: &str) -> String {
    // Strict keywords plus reserved-for-future-use keywords (through edition 2024).
    const RAW_ESCAPABLE: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "do",
        "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl",
        "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub",
        "ref", "return", "static", "struct", "trait", "true", "try", "type", "typeof",
        "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    // The Rust reference forbids these four as raw identifiers (`r#self` does not compile).
    const NOT_RAW_ESCAPABLE: &[&str] = &["self", "Self", "super", "crate"];
    if NOT_RAW_ESCAPABLE.contains(&name) {
        format!("{name}_")
    } else if RAW_ESCAPABLE.contains(&name) {
        format!("r#{name}")
    } else if name.is_empty() {
        "_unnamed".to_owned()
    } else if name.starts_with(|c: char| c.is_ascii_digit()) {
        format!("_{name}")
    } else {
        name.to_owned()
    }
}
// Unit tests covering manifest parsing, type mapping, client-name derivation and the
// shape of the generated client source.
#[cfg(test)]
mod tests {
use super::*;
// Builds a method without explicit `parameters`, so its argument comes from `request`.
fn method(name: &str, request: &str, response: &str) -> GrainMethod {
GrainMethod {
name: name.to_owned(),
request_type: request.to_owned(),
parameters: Vec::new(),
response_type: response.to_owned(),
payload_codec: "json".to_owned(),
}
}
// Wraps `methods` in a one-grain manifest for `Counter.Abstractions.ICounterGrain`
// with string keys.
fn grain(methods: Vec<GrainMethod>) -> Manifest {
Manifest {
service_id: "s".into(),
cluster_id: "c".into(),
bridge_version: "0.1.0".into(),
schema_version: "1".into(),
grains: vec![GrainContract {
interface_name: "Counter.Abstractions.ICounterGrain".into(),
grain_type: "counter".into(),
supported_key_kinds: vec!["string".into()],
methods,
}],
}
}
// Namespace and the leading `I` are stripped; `Client` is appended.
#[test]
fn derives_client_name() {
assert_eq!(
client_struct_name("Counter.Abstractions.ICounterGrain").unwrap(),
"CounterGrainClient"
);
assert_eq!(
client_struct_name("ICounterGrain").unwrap(),
"CounterGrainClient"
);
}
#[test]
fn maps_primitive_types() {
assert_eq!(map_type("System.Int64"), "i64");
assert_eq!(map_type(""), "()");
assert_eq!(map_type("Some.Custom.Type"), "serde_json::Value");
}
#[test]
fn maps_collections_and_options() {
assert_eq!(map_type("System.String?"), "Option<String>");
assert_eq!(map_type("System.Byte[]"), "Vec<u8>");
assert_eq!(map_type("System.Int32[]"), "Vec<i32>");
assert_eq!(map_type("List<System.Int64>"), "Vec<i64>");
assert_eq!(
map_type("Dictionary<System.String, System.Int32>"),
"std::collections::HashMap<String, i32>"
);
}
// Reflection (backtick-arity, assembly-qualified) generic names.
#[test]
fn maps_reflection_generic_names() {
assert_eq!(
map_type("System.Collections.Generic.List`1[[System.Int64, System.Private.CoreLib]]"),
"Vec<i64>"
);
assert_eq!(
map_type(
"System.Collections.Generic.Dictionary`2[[System.String, mscorlib],[System.Int32, mscorlib]]"
),
"std::collections::HashMap<String, i32>"
);
}
// No-argument and `request_type`-argument methods.
#[test]
fn generates_counter_client() {
let manifest = grain(vec![
method("Get", "", "System.Int64"),
method("Add", "System.Int64", "System.Int64"),
]);
let code = generate(&manifest, &CodegenOptions::default()).unwrap();
assert!(code.contains("pub struct CounterGrainClient"));
assert!(code.contains("pub async fn get(&self) -> Result<i64, OrleansError>"));
assert!(code.contains("pub async fn add(&self, value: i64) -> Result<i64, OrleansError>"));
}
// Several explicit parameters are sent together as a tuple.
#[test]
fn generates_multi_argument_method() {
let mut transfer = method("Transfer", "", "System.Boolean");
transfer.parameters = vec![
MethodParameter {
name: "destination".into(),
ty: "System.String".into(),
},
MethodParameter {
name: "amount".into(),
ty: "System.Int64".into(),
},
];
let code = generate(&grain(vec![transfer]), &CodegenOptions::default()).unwrap();
assert!(code.contains(
"pub async fn transfer(&self, destination: String, amount: i64) -> Result<bool, OrleansError>"
));
assert!(code.contains("invoke_json(\"Transfer\", &(destination, amount))"));
}
// `with_response_context` adds a `_with_context` variant per method.
#[test]
fn generates_response_context_variant() {
let options = CodegenOptions {
with_response_context: true,
..Default::default()
};
let code = generate(&grain(vec![method("Get", "", "System.Int64")]), &options).unwrap();
assert!(code.contains(
"pub async fn get_with_context(&self) -> Result<(i64, std::collections::HashMap<String, String>), OrleansError>"
));
assert!(code.contains("invoke_json_with_context(\"Get\", &())"));
}
// Like `grain`, but for `Sample.IThingGrain` with caller-chosen key kinds.
fn grain_with_keys(kinds: Vec<&str>, methods: Vec<GrainMethod>) -> Manifest {
Manifest {
service_id: "s".into(),
cluster_id: "c".into(),
bridge_version: "0.1.0".into(),
schema_version: "1".into(),
grains: vec![GrainContract {
interface_name: "Sample.IThingGrain".into(),
grain_type: "thing".into(),
supported_key_kinds: kinds.into_iter().map(str::to_owned).collect(),
methods,
}],
}
}
#[test]
fn generates_int64_key_constructor() {
let code = generate(
&grain_with_keys(vec!["int64"], vec![method("Get", "", "System.Int64")]),
&CodegenOptions::default(),
)
.unwrap();
assert!(code.contains("pub fn new(client: OrleansClient, key: i64) -> Self"));
assert!(code.contains("GrainKey::Int64(key)"));
}
#[test]
fn generates_guid_key_constructor() {
let code = generate(
&grain_with_keys(vec!["guid"], vec![method("Get", "", "System.Int64")]),
&CodegenOptions::default(),
)
.unwrap();
assert!(code.contains("pub fn new(client: OrleansClient, key: uuid::Uuid) -> Self"));
assert!(code.contains("GrainKey::Guid(key)"));
}
// A method named after a Rust keyword is emitted as a raw identifier.
#[test]
fn sanitizes_reserved_method_names() {
let code = generate(
&grain(vec![method("Type", "", "System.String")]),
&CodegenOptions::default(),
)
.unwrap();
assert!(code.contains("pub async fn r#type(&self)"));
}
#[test]
fn empty_interface_name_is_an_error() {
let mut manifest = grain_with_keys(vec!["string"], vec![method("Get", "", "")]);
manifest.grains[0].interface_name = String::new();
let err = generate(&manifest, &CodegenOptions::default()).unwrap_err();
assert!(matches!(err, CodegenError::Invalid(_)));
}
#[test]
fn maps_additional_scalars() {
assert_eq!(map_type("System.DateTime"), "String");
assert_eq!(map_type("System.Decimal"), "String");
assert_eq!(map_type("System.Object"), "serde_json::Value");
assert_eq!(map_type("System.Boolean"), "bool");
assert_eq!(map_type("System.Guid"), "uuid::Uuid");
}
#[test]
fn maps_nullable_reflection_form() {
assert_eq!(
map_type("System.Nullable`1[[System.Int32, System.Private.CoreLib]]"),
"Option<i32>"
);
assert_eq!(
map_type("System.Collections.Generic.IReadOnlyList`1[[System.String, mscorlib]]"),
"Vec<String>"
);
}
// End to end: JSON manifest -> generated client.
#[test]
fn parses_manifest_from_json() {
let json = r#"{"service_id":"s","grains":[{"interface_name":"X.IY","grain_type":"y",
"supported_key_kinds":["string"],
"methods":[{"name":"Get","response_type":"System.Int64"}]}]}"#;
let manifest = Manifest::from_json_str(json).unwrap();
assert_eq!(manifest.grains.len(), 1);
let code = generate(&manifest, &CodegenOptions::default()).unwrap();
assert!(code.contains("pub struct YClient"));
}
}