1extern crate alloc;
50
51use alloc::format;
52use alloc::string::{String, ToString};
53use alloc::vec::Vec;
54
55use crate::error::{RpcError, RpcResult};
56use crate::service_mapping::{MethodDef, ParamDirection, ServiceDef, TypeRef};
57
/// Which service mapping layout a generated request/reply type belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLayout {
    /// One request type and one reply type for the whole service, each
    /// carrying a union with one case per method (built by `build_basic_pair`).
    Basic,
    /// One request type, plus one reply type unless the method is oneway,
    /// per method (built by `build_enhanced_pair`).
    Enhanced,
}
66
/// A named member of a generated request, reply, or union-case struct.
#[derive(Debug, Clone, PartialEq)]
pub struct StructMember {
    /// Member name as emitted, e.g. `"header"`, a parameter name, or `"_return"`.
    pub name: String,
    /// Type of this member.
    pub type_ref: MemberType,
}
75
76impl StructMember {
77 #[must_use]
79 pub fn new(name: impl Into<String>, type_ref: MemberType) -> Self {
80 Self {
81 name: name.into(),
82 type_ref,
83 }
84 }
85}
86
/// Type of a [`StructMember`].
///
/// Header and call-union members are synthesized by this module; `Idl`
/// members refer to types taken from the IDL definition.
#[derive(Debug, Clone, PartialEq)]
pub enum MemberType {
    /// Request header; emitted as the first (`header`) member of every request
    /// type built in this module.
    RequestHeader,
    /// Reply header; emitted as the first (`header`) member of every reply type
    /// built in this module.
    ReplyHeader,
    /// Inline union with one case per method (used by the Basic layout).
    CallUnion(CallUnionDef),
    /// A type from the IDL, e.g. a parameter type or a return type.
    Idl(TypeRef),
}
102
/// One case of a [`CallUnionDef`], corresponding to a single service method.
#[derive(Debug, Clone, PartialEq)]
pub struct CallUnionCase {
    /// Name of the method this case represents.
    pub method: String,
    /// Union discriminator value that selects this case (assigned starting at 1).
    pub discriminator: u32,
    /// Name of the case struct type: `<Svc>_<method>_In` in the call union,
    /// `<Svc>_<method>_Out` in the result union.
    pub case_type_name: String,
    /// Members of the case struct. In the call union these are the method's
    /// in-direction parameters. In the result union they are the return value
    /// (`_return`, if any) followed by out-direction parameters.
    pub members: Vec<StructMember>,
}
115
/// Discriminated union with one case per method, used by the Basic layout.
#[derive(Debug, Clone, PartialEq)]
pub struct CallUnionDef {
    /// Union type name: `<Svc>_Call` (request side) or `<Svc>_Result` (reply side).
    pub name: String,
    /// Cases, in the order of the service's methods. The result union omits
    /// oneway methods.
    pub cases: Vec<CallUnionCase>,
}
125
/// A generated request struct type together with the topic it is published on.
#[derive(Debug, Clone, PartialEq)]
pub struct RequestType {
    /// Type name: `<Svc>_Request` (Basic) or `<Svc>_<method>_Request` (Enhanced).
    pub name: String,
    /// Name of the topic carrying this request type.
    pub topic_name: String,
    /// Layout this type was generated for.
    pub layout: ServiceLayout,
    /// Method this type belongs to; `None` for the Basic layout, which covers
    /// all methods in one type.
    pub method: Option<String>,
    /// Struct members; the first one is always the request `header`.
    pub members: Vec<StructMember>,
}
146
/// A generated reply struct type together with the topic it is published on.
#[derive(Debug, Clone, PartialEq)]
pub struct ReplyType {
    /// Type name: `<Svc>_Reply` (Basic) or `<Svc>_<method>_Reply` (Enhanced).
    pub name: String,
    /// Name of the topic carrying this reply type.
    pub topic_name: String,
    /// Layout this type was generated for.
    pub layout: ServiceLayout,
    /// Method this type belongs to; `None` for the Basic layout.
    pub method: Option<String>,
    /// Struct members; the first one is always the reply `header`.
    pub members: Vec<StructMember>,
}
161
/// Enhanced-layout request/reply types generated for one method.
#[derive(Debug, Clone, PartialEq)]
pub struct MethodPair {
    /// Request type for the method.
    pub request: RequestType,
    /// Reply type for the method; `None` when the method is oneway.
    pub reply: Option<ReplyType>,
}
170
171pub fn build_basic_pair(svc: &ServiceDef) -> RpcResult<(RequestType, ReplyType)> {
185 if svc.name.is_empty() {
186 return Err(RpcError::InvalidServiceName(String::new()));
187 }
188 let topics = svc.topic_names()?;
189 let req_union = build_call_union(svc, true, false)?;
190 let rep_union = build_call_union(svc, false, true)?;
191
192 let request = RequestType {
193 name: format!("{}_Request", svc.name),
194 topic_name: topics.request.clone(),
195 layout: ServiceLayout::Basic,
196 method: None,
197 members: alloc::vec![
198 StructMember::new("header", MemberType::RequestHeader),
199 StructMember::new("call", MemberType::CallUnion(req_union)),
200 ],
201 };
202
203 let reply = ReplyType {
204 name: format!("{}_Reply", svc.name),
205 topic_name: topics.reply,
206 layout: ServiceLayout::Basic,
207 method: None,
208 members: alloc::vec![
209 StructMember::new("header", MemberType::ReplyHeader),
210 StructMember::new("result", MemberType::CallUnion(rep_union)),
211 ],
212 };
213
214 Ok((request, reply))
215}
216
217fn build_call_union(
218 svc: &ServiceDef,
219 include_oneway: bool,
220 reply: bool,
221) -> RpcResult<CallUnionDef> {
222 let union_name = if reply {
223 format!("{}_Result", svc.name)
224 } else {
225 format!("{}_Call", svc.name)
226 };
227 let mut cases = Vec::with_capacity(svc.methods.len());
228 let mut discr: u32 = 1;
229 for m in &svc.methods {
230 if m.oneway && !include_oneway {
231 continue;
232 }
233 let case_type = method_struct_name(svc, m, reply);
234 let members = if reply {
235 method_out_members(m)
236 } else {
237 method_in_members(m)
238 };
239 cases.push(CallUnionCase {
240 method: m.name.clone(),
241 discriminator: discr,
242 case_type_name: case_type,
243 members,
244 });
245 discr = discr
246 .checked_add(1)
247 .ok_or_else(|| RpcError::Codec("more than u32::MAX methods in service".to_string()))?;
248 }
249 Ok(CallUnionDef {
250 name: union_name,
251 cases,
252 })
253}
254
255pub fn build_enhanced_pair(svc: &ServiceDef, method: &MethodDef) -> RpcResult<MethodPair> {
267 if svc.name.is_empty() {
268 return Err(RpcError::InvalidServiceName(String::new()));
269 }
270 if method.name.is_empty() {
271 return Err(RpcError::InvalidMethodName(String::new()));
272 }
273 let request_topic = format!(
277 "{}_{}{}",
278 svc.name,
279 method.name,
280 crate::topic_naming::REQUEST_SUFFIX
281 );
282 let reply_topic = format!(
283 "{}_{}{}",
284 svc.name,
285 method.name,
286 crate::topic_naming::REPLY_SUFFIX
287 );
288
289 let mut req_members = alloc::vec![StructMember::new("header", MemberType::RequestHeader)];
290 req_members.extend(method_in_members(method));
291 let request = RequestType {
292 name: format!("{}_{}_Request", svc.name, method.name),
293 topic_name: request_topic,
294 layout: ServiceLayout::Enhanced,
295 method: Some(method.name.clone()),
296 members: req_members,
297 };
298
299 let reply = if method.oneway {
300 None
301 } else {
302 let mut rep_members = alloc::vec![StructMember::new("header", MemberType::ReplyHeader)];
303 rep_members.extend(method_out_members(method));
304 Some(ReplyType {
305 name: format!("{}_{}_Reply", svc.name, method.name),
306 topic_name: reply_topic,
307 layout: ServiceLayout::Enhanced,
308 method: Some(method.name.clone()),
309 members: rep_members,
310 })
311 };
312
313 Ok(MethodPair { request, reply })
314}
315
316pub fn build_enhanced_all(svc: &ServiceDef) -> RpcResult<Vec<MethodPair>> {
321 let mut out = Vec::with_capacity(svc.methods.len());
322 for m in &svc.methods {
323 out.push(build_enhanced_pair(svc, m)?);
324 }
325 Ok(out)
326}
327
328fn method_struct_name(svc: &ServiceDef, m: &MethodDef, reply: bool) -> String {
333 if reply {
334 format!("{}_{}_Out", svc.name, m.name)
335 } else {
336 format!("{}_{}_In", svc.name, m.name)
337 }
338}
339
340fn method_in_members(m: &MethodDef) -> Vec<StructMember> {
341 m.params
342 .iter()
343 .filter(|p| p.direction.is_in())
344 .map(|p| StructMember::new(p.name.clone(), MemberType::Idl(p.type_ref.clone())))
345 .collect()
346}
347
348fn method_out_members(m: &MethodDef) -> Vec<StructMember> {
349 let mut out = Vec::new();
350 if let Some(ret) = &m.return_type {
351 out.push(StructMember::new("_return", MemberType::Idl(ret.clone())));
352 }
353 for p in m.params.iter().filter(|p| p.direction.is_out()) {
354 let _ = ParamDirection::Out;
356 out.push(StructMember::new(
357 p.name.clone(),
358 MemberType::Idl(p.type_ref.clone()),
359 ));
360 }
361 out
362}
363
#[cfg(test)]
// Panicking (unwrap/expect/panic!/unreachable!) is how a unit test reports failure.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::unreachable
)]
mod tests {
    use super::*;
    use crate::annotations::lower_rpc_annotations;
    use crate::service_mapping::{ParamDef, lower_service};
    use zerodds_idl::ast::{
        Annotation, AnnotationParams, Export, Identifier, IntegerType, InterfaceDef, InterfaceKind,
        OpDecl, ParamAttribute, ParamDecl, PrimitiveType, ScopedName, StringType, TypeSpec,
    };
    use zerodds_idl::errors::Span;

    fn sp() -> Span {
        Span::SYNTHETIC
    }

    fn ident(t: &str) -> Identifier {
        Identifier::new(t, sp())
    }

    fn long_t() -> TypeSpec {
        TypeSpec::Primitive(PrimitiveType::Integer(IntegerType::Long))
    }

    fn string_t() -> TypeSpec {
        TypeSpec::String(StringType {
            wide: false,
            bound: None,
            span: sp(),
        })
    }

    fn op(name: &str, oneway: bool, ret: Option<TypeSpec>, params: Vec<ParamDecl>) -> OpDecl {
        OpDecl {
            name: ident(name),
            oneway,
            return_type: ret,
            params,
            raises: Vec::new(),
            annotations: Vec::new(),
            span: sp(),
        }
    }

    fn param(name: &str, attr: ParamAttribute, ty: TypeSpec) -> ParamDecl {
        ParamDecl {
            attribute: attr,
            type_spec: ty,
            name: ident(name),
            annotations: Vec::new(),
            span: sp(),
        }
    }

    fn ann_simple(name: &str) -> Annotation {
        Annotation {
            name: ScopedName {
                absolute: false,
                parts: alloc::vec![ident(name)],
                span: sp(),
            },
            params: AnnotationParams::None,
            span: sp(),
        }
    }

    /// `@service interface Calculator { long add(in long a, in long b); oneway void log(in string msg); }`
    fn calc_service() -> ServiceDef {
        let add = op(
            "add",
            false,
            Some(long_t()),
            alloc::vec![
                param("a", ParamAttribute::In, long_t()),
                param("b", ParamAttribute::In, long_t()),
            ],
        );
        let log = op(
            "log",
            true,
            None,
            alloc::vec![param("msg", ParamAttribute::In, string_t())],
        );
        let i = InterfaceDef {
            kind: InterfaceKind::Plain,
            name: ident("Calculator"),
            bases: Vec::new(),
            exports: alloc::vec![Export::Op(add), Export::Op(log)],
            annotations: alloc::vec![ann_simple("service")],
            span: sp(),
        };
        let lowered = lower_rpc_annotations(&i.annotations);
        lower_service(&i, &lowered).unwrap()
    }

    /// Returns the union carried by a `CallUnion` member, panicking otherwise.
    fn union_of(member: &StructMember) -> &CallUnionDef {
        match &member.type_ref {
            MemberType::CallUnion(u) => u,
            other => panic!("expected CallUnion, got {:?}", other),
        }
    }

    #[test]
    fn basic_pair_topic_names() {
        let svc = calc_service();
        let (req, rep) = build_basic_pair(&svc).unwrap();
        assert_eq!(req.topic_name, "Calculator_Request");
        assert_eq!(rep.topic_name, "Calculator_Reply");
    }

    #[test]
    fn basic_pair_layout_marker() {
        let svc = calc_service();
        let (req, rep) = build_basic_pair(&svc).unwrap();
        assert_eq!(req.layout, ServiceLayout::Basic);
        assert_eq!(rep.layout, ServiceLayout::Basic);
        assert_eq!(req.method, None);
        assert_eq!(rep.method, None);
    }

    #[test]
    fn basic_request_has_header_and_call_union() {
        let svc = calc_service();
        let (req, _) = build_basic_pair(&svc).unwrap();
        assert_eq!(req.members.len(), 2);
        assert_eq!(req.members[0].name, "header");
        assert!(matches!(req.members[0].type_ref, MemberType::RequestHeader));
        assert_eq!(req.members[1].name, "call");
        let call_union = union_of(&req.members[1]);
        assert_eq!(call_union.name, "Calculator_Call");
        // Both methods appear in the request union, oneway `log` included.
        assert_eq!(call_union.cases.len(), 2);
        assert_eq!(call_union.cases[0].method, "add");
        assert_eq!(call_union.cases[0].discriminator, 1);
        assert_eq!(call_union.cases[0].case_type_name, "Calculator_add_In");
        assert_eq!(call_union.cases[1].method, "log");
        assert_eq!(call_union.cases[1].discriminator, 2);
    }

    #[test]
    fn basic_reply_excludes_oneway_methods() {
        let svc = calc_service();
        let (_, rep) = build_basic_pair(&svc).unwrap();
        let result_union = union_of(&rep.members[1]);
        assert_eq!(result_union.name, "Calculator_Result");
        // Oneway `log` has no reply, so only `add` remains.
        assert_eq!(result_union.cases.len(), 1);
        assert_eq!(result_union.cases[0].method, "add");
        assert_eq!(result_union.cases[0].case_type_name, "Calculator_add_Out");
    }

    #[test]
    fn basic_result_discriminators_match_call_discriminators() {
        let svc = calc_service();
        let (req, rep) = build_basic_pair(&svc).unwrap();
        let call_union = union_of(&req.members[1]);
        let result_union = union_of(&rep.members[1]);
        for result_case in &result_union.cases {
            let call_case = call_union
                .cases
                .iter()
                .find(|c| c.method == result_case.method)
                .expect("every result case has a matching call case");
            assert_eq!(call_case.discriminator, result_case.discriminator);
        }
    }

    #[test]
    fn basic_request_in_params_become_case_members() {
        let svc = calc_service();
        let (req, _) = build_basic_pair(&svc).unwrap();
        let add_case = &union_of(&req.members[1]).cases[0];
        assert_eq!(add_case.members.len(), 2);
        assert_eq!(add_case.members[0].name, "a");
        assert_eq!(add_case.members[1].name, "b");
    }

    #[test]
    fn basic_reply_return_value_first_member() {
        let svc = calc_service();
        let (_, rep) = build_basic_pair(&svc).unwrap();
        let add_case = &union_of(&rep.members[1]).cases[0];
        assert_eq!(add_case.members.len(), 1);
        assert_eq!(add_case.members[0].name, "_return");
    }

    #[test]
    fn enhanced_pair_topic_names() {
        let svc = calc_service();
        let pair = build_enhanced_pair(&svc, &svc.methods[0]).unwrap();
        assert_eq!(pair.request.topic_name, "Calculator_add_Request");
        assert_eq!(
            pair.reply.as_ref().unwrap().topic_name,
            "Calculator_add_Reply"
        );
    }

    #[test]
    fn enhanced_pair_layout_marker() {
        let svc = calc_service();
        let pair = build_enhanced_pair(&svc, &svc.methods[0]).unwrap();
        assert_eq!(pair.request.layout, ServiceLayout::Enhanced);
        assert_eq!(pair.request.method, Some("add".to_string()));
    }

    #[test]
    fn enhanced_oneway_has_no_reply() {
        let svc = calc_service();
        let log = svc.methods.iter().find(|m| m.oneway).unwrap();
        let pair = build_enhanced_pair(&svc, log).unwrap();
        assert!(pair.reply.is_none());
        // header + msg
        assert_eq!(pair.request.members.len(), 2);
        assert_eq!(pair.request.members[0].name, "header");
        assert_eq!(pair.request.members[1].name, "msg");
    }

    #[test]
    fn enhanced_pair_request_in_params() {
        let svc = calc_service();
        let pair = build_enhanced_pair(&svc, &svc.methods[0]).unwrap();
        // header + a + b
        assert_eq!(pair.request.members.len(), 3);
        assert_eq!(pair.request.members[0].name, "header");
        assert_eq!(pair.request.members[1].name, "a");
        assert_eq!(pair.request.members[2].name, "b");
    }

    #[test]
    fn enhanced_pair_reply_return_only() {
        let svc = calc_service();
        let pair = build_enhanced_pair(&svc, &svc.methods[0]).unwrap();
        let rep = pair.reply.as_ref().unwrap();
        // header + _return
        assert_eq!(rep.members.len(), 2);
        assert_eq!(rep.members[0].name, "header");
        assert_eq!(rep.members[1].name, "_return");
    }

    #[test]
    fn enhanced_inout_param_appears_in_both_request_and_reply() {
        let svc = ServiceDef {
            name: "Swap".into(),
            methods: alloc::vec![MethodDef {
                name: "swap".into(),
                params: alloc::vec![ParamDef {
                    name: "v".into(),
                    direction: ParamDirection::InOut,
                    type_ref: long_t(),
                }],
                return_type: None,
                oneway: false,
            }],
        };
        let pair = build_enhanced_pair(&svc, &svc.methods[0]).unwrap();
        assert!(pair.request.members.iter().any(|m| m.name == "v"));
        let rep = pair.reply.as_ref().unwrap();
        assert!(rep.members.iter().any(|m| m.name == "v"));
    }

    #[test]
    fn enhanced_all_builds_one_pair_per_method() {
        let svc = calc_service();
        let pairs = build_enhanced_all(&svc).unwrap();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].request.method, Some("add".to_string()));
        assert_eq!(pairs[1].request.method, Some("log".to_string()));
        // `log` is oneway.
        assert!(pairs[1].reply.is_none());
    }

    #[test]
    fn empty_service_yields_empty_unions_in_basic() {
        let svc = ServiceDef {
            name: "Empty".into(),
            methods: Vec::new(),
        };
        let (req, rep) = build_basic_pair(&svc).unwrap();
        assert_eq!(union_of(&req.members[1]).cases.len(), 0);
        assert_eq!(union_of(&rep.members[1]).cases.len(), 0);
    }

    #[test]
    fn invalid_service_name_is_error_in_codegen() {
        let svc = ServiceDef {
            name: String::new(),
            methods: Vec::new(),
        };
        let err = build_basic_pair(&svc).unwrap_err();
        assert!(matches!(err, RpcError::InvalidServiceName(_)));
    }

    #[test]
    fn enhanced_method_with_invalid_name_is_error() {
        let svc = ServiceDef {
            name: "S".into(),
            methods: alloc::vec![MethodDef {
                name: String::new(),
                params: Vec::new(),
                return_type: None,
                oneway: false,
            }],
        };
        let err = build_enhanced_pair(&svc, &svc.methods[0]).unwrap_err();
        assert!(matches!(err, RpcError::InvalidMethodName(_)));
    }
}