1use heck::{ToPascalCase, ToSnakeCase};
19use serde::Deserialize;
20
21#[derive(thiserror::Error, Debug)]
23pub enum CodegenError {
24 #[error("failed to parse manifest: {0}")]
26 Parse(#[from] serde_json::Error),
27 #[error("invalid manifest: {0}")]
29 Invalid(String),
30}
31
32#[derive(Debug, Clone, Deserialize)]
34pub struct Manifest {
35 #[serde(default)]
37 pub service_id: String,
38 #[serde(default)]
40 pub cluster_id: String,
41 #[serde(default)]
43 pub bridge_version: String,
44 #[serde(default)]
46 pub schema_version: String,
47 #[serde(default)]
49 pub grains: Vec<GrainContract>,
50}
51
52#[derive(Debug, Clone, Deserialize)]
54pub struct GrainContract {
55 pub interface_name: String,
57 pub grain_type: String,
59 #[serde(default)]
61 pub methods: Vec<GrainMethod>,
62 #[serde(default)]
64 pub supported_key_kinds: Vec<String>,
65}
66
67#[derive(Debug, Clone, Deserialize)]
69pub struct MethodParameter {
70 pub name: String,
72 #[serde(rename = "type")]
74 pub ty: String,
75}
76
77#[derive(Debug, Clone, Deserialize)]
79pub struct GrainMethod {
80 pub name: String,
82 #[serde(default)]
85 pub request_type: String,
86 #[serde(default)]
89 pub parameters: Vec<MethodParameter>,
90 #[serde(default)]
92 pub response_type: String,
93 #[serde(default)]
95 pub payload_codec: String,
96}
97
98impl Manifest {
99 pub fn from_json_str(json: &str) -> Result<Self, CodegenError> {
104 Ok(serde_json::from_str(json)?)
105 }
106}
107
/// Settings that shape the generated client code.
#[derive(Clone, Debug)]
pub struct CodegenOptions {
    /// Path of the crate the generated code imports `GrainKey`, `GrainRef`, `OrleansClient`
    /// and `OrleansError` from.
    pub client_crate: String,
    /// When `true`, every method also gets a `<name>_with_context` variant that returns the
    /// response context map alongside the result.
    pub with_response_context: bool,
}

impl Default for CodegenOptions {
    /// Imports from `orleans_rust_client` and emits no `_with_context` variants.
    fn default() -> Self {
        let client_crate = String::from("orleans_rust_client");
        CodegenOptions {
            client_crate,
            with_response_context: false,
        }
    }
}
126
127pub fn generate(manifest: &Manifest, options: &CodegenOptions) -> Result<String, CodegenError> {
133 let mut out = String::new();
134 out.push_str("// @generated by orleans-rust-codegen. Do not edit by hand.\n");
135 out.push_str("// Include within a module annotated `#[allow(dead_code, clippy::all)]`.\n\n");
136 out.push_str(&format!(
137 "use {client}::{{GrainKey, GrainRef, OrleansClient, OrleansError}};\n\n",
138 client = options.client_crate
139 ));
140
141 for grain in &manifest.grains {
142 out.push_str(&generate_grain(grain, options)?);
143 out.push('\n');
144 }
145
146 Ok(out)
147}
148
149fn generate_grain(grain: &GrainContract, options: &CodegenOptions) -> Result<String, CodegenError> {
150 let struct_name = client_struct_name(&grain.interface_name)?;
151 let key = KeyStrategy::from_kinds(&grain.supported_key_kinds);
152
153 let mut s = String::new();
154 s.push_str(&format!(
155 "/// Typed client for `{}`.\n",
156 grain.interface_name
157 ));
158 s.push_str(&format!(
159 "pub struct {struct_name} {{\n inner: GrainRef,\n}}\n\n"
160 ));
161 s.push_str(&format!("impl {struct_name} {{\n"));
162 s.push_str(&format!(
163 " /// Construct a client bound to `key`.\n pub fn new(client: OrleansClient, key: {key_param}) -> Self {{\n Self {{\n inner: client.grain(\n \"{interface}\",\n \"{grain_type}\",\n {key_expr},\n ),\n }}\n }}\n",
164 key_param = key.param_type(),
165 interface = grain.interface_name,
166 grain_type = grain.grain_type,
167 key_expr = key.key_expr(),
168 ));
169
170 for method in &grain.methods {
171 s.push('\n');
172 s.push_str(&generate_method(method, options));
173 }
174
175 s.push_str("}\n");
176 Ok(s)
177}
178
179fn generate_method(method: &GrainMethod, options: &CodegenOptions) -> String {
180 let fn_name = sanitize_ident(&method.name.to_snake_case());
181 let response_ty = map_type(&method.response_type);
182
183 let args: Vec<(String, String)> = if !method.parameters.is_empty() {
186 method
187 .parameters
188 .iter()
189 .map(|p| (sanitize_ident(&p.name.to_snake_case()), map_type(&p.ty)))
190 .collect()
191 } else if map_type(&method.request_type) != "()" {
192 vec![("value".to_owned(), map_type(&method.request_type))]
193 } else {
194 Vec::new()
195 };
196
197 let signature_args: String = args
198 .iter()
199 .map(|(name, ty)| format!(", {name}: {ty}"))
200 .collect();
201
202 let call_arg = match args.as_slice() {
205 [] => "&()".to_owned(),
206 [(name, _)] => format!("&{name}"),
207 many => format!(
208 "&({})",
209 many.iter()
210 .map(|(name, _)| name.clone())
211 .collect::<Vec<_>>()
212 .join(", ")
213 ),
214 };
215
216 let mut out = format!(
217 " /// Invokes `{orig}`.\n pub async fn {fn_name}(&self{signature_args}) -> Result<{response_ty}, OrleansError> {{\n self.inner.invoke_json(\"{orig}\", {call_arg}).await\n }}\n",
218 orig = method.name,
219 );
220
221 if options.with_response_context {
222 out.push_str(&format!(
223 "\n /// Invokes `{orig}`, also returning the response context.\n pub async fn {fn_name}_with_context(&self{signature_args}) -> Result<({response_ty}, std::collections::HashMap<String, String>), OrleansError> {{\n self.inner.invoke_json_with_context(\"{orig}\", {call_arg}).await\n }}\n",
224 orig = method.name,
225 ));
226 }
227
228 out
229}
230
/// Determines how a generated client's constructor accepts and wraps its grain key.
#[derive(Clone, Copy, Debug)]
enum KeyStrategy {
    String,
    Int64,
    Guid,
}

impl KeyStrategy {
    /// Returns the strategy for the first `"int64"` or `"guid"` entry in `kinds`, skipping
    /// every other entry; falls back to `String` when neither occurs.
    fn from_kinds(kinds: &[String]) -> Self {
        kinds
            .iter()
            .find_map(|kind| match kind.as_str() {
                "int64" => Some(Self::Int64),
                "guid" => Some(Self::Guid),
                _ => None,
            })
            .unwrap_or(Self::String)
    }

    /// Rust type of the `key` parameter in the generated `new`.
    fn param_type(self) -> &'static str {
        self.rendering().0
    }

    /// Expression turning `key` into a `GrainKey` inside the generated `new`.
    fn key_expr(self) -> &'static str {
        self.rendering().1
    }

    /// `(parameter type, conversion expression)` pair for this strategy.
    fn rendering(self) -> (&'static str, &'static str) {
        match self {
            Self::String => ("impl Into<String>", "GrainKey::String(key.into())"),
            Self::Int64 => ("i64", "GrainKey::Int64(key)"),
            Self::Guid => ("uuid::Uuid", "GrainKey::Guid(key)"),
        }
    }
}
266
267fn client_struct_name(interface_name: &str) -> Result<String, CodegenError> {
268 let last = interface_name.rsplit('.').next().unwrap_or(interface_name);
269 let trimmed = last
270 .strip_prefix('I')
271 .filter(|rest| rest.chars().next().is_some_and(char::is_uppercase))
272 .unwrap_or(last);
273 let base = trimmed.to_pascal_case();
274 if base.is_empty() {
275 return Err(CodegenError::Invalid(format!(
276 "cannot derive a client name from interface `{interface_name}`"
277 )));
278 }
279 Ok(format!("{base}Client"))
280}
281
282fn map_type(dotnet: &str) -> String {
286 let normalized = dotnet.trim();
287
288 if let Some(scalar) = map_scalar(normalized) {
291 return scalar;
292 }
293
294 if let Some(inner) = strip_nullable(normalized) {
296 return format!("Option<{}>", map_type(&inner));
297 }
298
299 if let Some(element) = normalized.strip_suffix("[]") {
301 return format!("Vec<{}>", map_type(element));
302 }
303
304 if let Some((base, args)) = parse_generic(normalized) {
306 match (base.as_str(), args.as_slice()) {
307 (
308 "System.Collections.Generic.List"
309 | "System.Collections.Generic.IList"
310 | "System.Collections.Generic.IReadOnlyList"
311 | "System.Collections.Generic.ICollection"
312 | "System.Collections.Generic.IEnumerable"
313 | "List"
314 | "IList"
315 | "IReadOnlyList"
316 | "IEnumerable",
317 [item],
318 ) => return format!("Vec<{}>", map_type(item)),
319 (
320 "System.Collections.Generic.Dictionary"
321 | "System.Collections.Generic.IDictionary"
322 | "System.Collections.Generic.IReadOnlyDictionary"
323 | "Dictionary"
324 | "IDictionary",
325 [key, value],
326 ) => {
327 return format!(
328 "std::collections::HashMap<{}, {}>",
329 map_type(key),
330 map_type(value)
331 );
332 }
333 ("System.Nullable" | "Nullable", [item]) => {
334 return format!("Option<{}>", map_type(item));
335 }
336 _ => {}
337 }
338 }
339
340 "serde_json::Value".to_owned()
341}
342
/// Maps a .NET scalar type (full `System.*` name or C# keyword) to its Rust counterpart.
///
/// Returns `None` when `normalized` is not a recognised scalar. `void` and the non-generic
/// `Task`/`ValueTask` map to `()`. `char`, date/time and `decimal` values are kept as
/// `String` (passed through textually rather than parsed into richer Rust types).
fn map_scalar(normalized: &str) -> Option<String> {
    let mapped = match normalized {
        ""
        | "void"
        | "System.Void"
        | "System.Threading.Tasks.Task"
        | "System.Threading.Tasks.ValueTask" => "()",
        "System.String" | "string" => "String",
        // A .NET `char` is a UTF-16 code unit; `String` also covers lone surrogates that a
        // Rust `char` could not hold.
        "System.Char" | "char" => "String",
        "System.Boolean" | "bool" => "bool",
        "System.SByte" | "sbyte" => "i8",
        "System.Byte" | "byte" => "u8",
        "System.Int16" | "short" => "i16",
        "System.UInt16" | "ushort" => "u16",
        "System.Int32" | "int" => "i32",
        "System.UInt32" | "uint" => "u32",
        "System.Int64" | "long" => "i64",
        "System.UInt64" | "ulong" => "u64",
        "System.Single" | "float" => "f32",
        "System.Double" | "double" => "f64",
        "System.Guid" => "uuid::Uuid",
        "System.DateTime"
        | "System.DateTimeOffset"
        | "System.TimeSpan"
        | "System.Decimal"
        | "decimal" => "String",
        "System.Object" | "object" => "serde_json::Value",
        _ => return None,
    };
    Some(mapped.to_owned())
}
369
/// Recognises the C# `T?` shorthand: returns the trimmed `T` when `normalized` ends in `?`,
/// otherwise `None`.
fn strip_nullable(normalized: &str) -> Option<String> {
    normalized
        .strip_suffix('?')
        .map(|underlying| underlying.trim().to_owned())
}
377
378fn parse_generic(name: &str) -> Option<(String, Vec<String>)> {
382 if let Some(open) = name.find('<') {
383 if !name.ends_with('>') {
384 return None;
385 }
386 let base = name[..open].trim().to_owned();
387 let inner = &name[open + 1..name.len() - 1];
388 return Some((base, split_top_level(inner)));
389 }
390
391 if let Some(tick) = name.find('`') {
392 let base = name[..tick].trim().to_owned();
393 let rest = &name[tick..];
394 let outer_open = rest.find('[')?;
395 let outer = rest[outer_open..].trim();
396 let inner = outer.strip_prefix('[')?.strip_suffix(']')?;
397 let args = split_top_level(inner)
400 .into_iter()
401 .map(|group| {
402 let group = group.trim();
403 let group = group.strip_prefix('[').unwrap_or(group);
404 let group = group.strip_suffix(']').unwrap_or(group);
405 group.split(',').next().unwrap_or(group).trim().to_owned()
406 })
407 .collect();
408 return Some((base, args));
409 }
410
411 None
412}
413
/// Splits `input` on commas that are not nested inside `<…>` or `[…]`, trimming each piece.
///
/// Empty pieces between separators are kept, but an empty trailing piece is dropped, so
/// `""` yields an empty list.
fn split_top_level(input: &str) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut nesting: i32 = 0;
    let mut start = 0usize;

    for (idx, ch) in input.char_indices() {
        match ch {
            '<' | '[' => nesting += 1,
            '>' | ']' => nesting -= 1,
            ',' if nesting == 0 => {
                pieces.push(input[start..idx].trim().to_owned());
                start = idx + ch.len_utf8();
            }
            _ => {}
        }
    }

    let tail = input[start..].trim();
    if !tail.is_empty() {
        pieces.push(tail.to_owned());
    }
    pieces
}
441
/// Makes `name` usable as a Rust identifier in generated code.
///
/// Handles three cases:
/// - Strict and reserved keywords (including reserved-for-future ones such as `yield`,
///   `try`, `do` and `gen`) are escaped as raw identifiers (`r#type`).
/// - `self`, `Self`, `super`, `crate` and `_` cannot be raw identifiers, so they get a
///   trailing underscore instead (`self_`).
/// - Any other name is returned unchanged.
fn sanitize_ident(name: &str) -> String {
    // Keywords that are legal as raw identifiers.
    const RAW_ESCAPABLE: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
        "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    // Names the compiler rejects even in raw form (`r#self` is an error).
    const NOT_RAW_ESCAPABLE: &[&str] = &["crate", "self", "Self", "super", "_"];

    if NOT_RAW_ESCAPABLE.contains(&name) {
        format!("{name}_")
    } else if RAW_ESCAPABLE.contains(&name) {
        format!("r#{name}")
    } else {
        name.to_owned()
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON-codec method whose request shape comes from `request` (no explicit parameters).
    fn method(name: &str, request: &str, response: &str) -> GrainMethod {
        GrainMethod {
            name: String::from(name),
            request_type: String::from(request),
            parameters: Vec::new(),
            response_type: String::from(response),
            payload_codec: String::from("json"),
        }
    }

    /// A manifest holding exactly one grain, with fixed service metadata.
    fn single_grain_manifest(
        interface_name: &str,
        grain_type: &str,
        kinds: &[&str],
        methods: Vec<GrainMethod>,
    ) -> Manifest {
        Manifest {
            service_id: "s".into(),
            cluster_id: "c".into(),
            bridge_version: "0.1.0".into(),
            schema_version: "1".into(),
            grains: vec![GrainContract {
                interface_name: interface_name.into(),
                grain_type: grain_type.into(),
                methods,
                supported_key_kinds: kinds.iter().map(|kind| (*kind).to_owned()).collect(),
            }],
        }
    }

    /// The string-keyed `counter` grain behind `Counter.Abstractions.ICounterGrain`.
    fn grain(methods: Vec<GrainMethod>) -> Manifest {
        single_grain_manifest(
            "Counter.Abstractions.ICounterGrain",
            "counter",
            &["string"],
            methods,
        )
    }

    /// The `thing` grain behind `Sample.IThingGrain`, with the given key kinds.
    fn grain_with_keys(kinds: Vec<&str>, methods: Vec<GrainMethod>) -> Manifest {
        single_grain_manifest("Sample.IThingGrain", "thing", &kinds, methods)
    }

    /// Generates code with default options, failing the test on error.
    fn render(manifest: &Manifest) -> String {
        generate(manifest, &CodegenOptions::default()).expect("code generation should succeed")
    }

    /// Checks every `(dotnet, rust)` pair against `map_type`.
    fn assert_mappings(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(map_type(input), expected, "mapping `{input}`");
        }
    }

    #[test]
    fn derives_client_name() {
        for interface in ["Counter.Abstractions.ICounterGrain", "ICounterGrain"] {
            assert_eq!(client_struct_name(interface).unwrap(), "CounterGrainClient");
        }
    }

    #[test]
    fn maps_primitive_types() {
        assert_mappings(&[
            ("System.Int64", "i64"),
            ("", "()"),
            ("Some.Custom.Type", "serde_json::Value"),
        ]);
    }

    #[test]
    fn maps_collections_and_options() {
        assert_mappings(&[
            ("System.String?", "Option<String>"),
            ("System.Byte[]", "Vec<u8>"),
            ("System.Int32[]", "Vec<i32>"),
            ("List<System.Int64>", "Vec<i64>"),
            (
                "Dictionary<System.String, System.Int32>",
                "std::collections::HashMap<String, i32>",
            ),
        ]);
    }

    #[test]
    fn maps_reflection_generic_names() {
        assert_mappings(&[
            (
                "System.Collections.Generic.List`1[[System.Int64, System.Private.CoreLib]]",
                "Vec<i64>",
            ),
            (
                "System.Collections.Generic.Dictionary`2[[System.String, mscorlib],[System.Int32, mscorlib]]",
                "std::collections::HashMap<String, i32>",
            ),
        ]);
    }

    #[test]
    fn generates_counter_client() {
        let code = render(&grain(vec![
            method("Get", "", "System.Int64"),
            method("Add", "System.Int64", "System.Int64"),
        ]));
        for expected in [
            "pub struct CounterGrainClient",
            "pub async fn get(&self) -> Result<i64, OrleansError>",
            "pub async fn add(&self, value: i64) -> Result<i64, OrleansError>",
        ] {
            assert!(code.contains(expected), "missing `{expected}` in:\n{code}");
        }
    }

    #[test]
    fn generates_multi_argument_method() {
        let transfer = GrainMethod {
            parameters: vec![
                MethodParameter {
                    name: "destination".into(),
                    ty: "System.String".into(),
                },
                MethodParameter {
                    name: "amount".into(),
                    ty: "System.Int64".into(),
                },
            ],
            ..method("Transfer", "", "System.Boolean")
        };

        let code = render(&grain(vec![transfer]));
        assert!(code.contains(
            "pub async fn transfer(&self, destination: String, amount: i64) -> Result<bool, OrleansError>"
        ));
        assert!(code.contains("invoke_json(\"Transfer\", &(destination, amount))"));
    }

    #[test]
    fn generates_response_context_variant() {
        let options = CodegenOptions {
            with_response_context: true,
            ..Default::default()
        };
        let code = generate(&grain(vec![method("Get", "", "System.Int64")]), &options).unwrap();
        assert!(code.contains(
            "pub async fn get_with_context(&self) -> Result<(i64, std::collections::HashMap<String, String>), OrleansError>"
        ));
        assert!(code.contains("invoke_json_with_context(\"Get\", &())"));
    }

    #[test]
    fn generates_int64_key_constructor() {
        let code = render(&grain_with_keys(
            vec!["int64"],
            vec![method("Get", "", "System.Int64")],
        ));
        assert!(code.contains("pub fn new(client: OrleansClient, key: i64) -> Self"));
        assert!(code.contains("GrainKey::Int64(key)"));
    }

    #[test]
    fn generates_guid_key_constructor() {
        let code = render(&grain_with_keys(
            vec!["guid"],
            vec![method("Get", "", "System.Int64")],
        ));
        assert!(code.contains("pub fn new(client: OrleansClient, key: uuid::Uuid) -> Self"));
        assert!(code.contains("GrainKey::Guid(key)"));
    }

    #[test]
    fn sanitizes_reserved_method_names() {
        let code = render(&grain(vec![method("Type", "", "System.String")]));
        assert!(code.contains("pub async fn r#type(&self)"));
    }

    #[test]
    fn empty_interface_name_is_an_error() {
        let mut manifest = grain_with_keys(vec!["string"], vec![method("Get", "", "")]);
        manifest.grains[0].interface_name.clear();
        let err = generate(&manifest, &CodegenOptions::default()).unwrap_err();
        assert!(matches!(err, CodegenError::Invalid(_)));
    }

    #[test]
    fn maps_additional_scalars() {
        assert_mappings(&[
            ("System.DateTime", "String"),
            ("System.Decimal", "String"),
            ("System.Object", "serde_json::Value"),
            ("System.Boolean", "bool"),
            ("System.Guid", "uuid::Uuid"),
        ]);
    }

    #[test]
    fn maps_nullable_reflection_form() {
        assert_mappings(&[
            (
                "System.Nullable`1[[System.Int32, System.Private.CoreLib]]",
                "Option<i32>",
            ),
            (
                "System.Collections.Generic.IReadOnlyList`1[[System.String, mscorlib]]",
                "Vec<String>",
            ),
        ]);
    }

    #[test]
    fn parses_manifest_from_json() {
        let json = r#"{"service_id":"s","grains":[{"interface_name":"X.IY","grain_type":"y",
            "supported_key_kinds":["string"],
            "methods":[{"name":"Get","response_type":"System.Int64"}]}]}"#;
        let manifest = Manifest::from_json_str(json).unwrap();
        assert_eq!(manifest.grains.len(), 1);
        let code = render(&manifest);
        assert!(code.contains("pub struct YClient"));
    }
}