1use facet::{Facet, Shape};
2use facet_path::{Path, walk_shape};
3use std::collections::HashMap;
4use std::sync::{Mutex, OnceLock};
5
6use crate::channel;
7
8pub struct RpcPlan {
13 pub shape: &'static Shape,
15
16 pub channel_locations: &'static [ChannelLocation],
18}
19
20pub struct ChannelLocation {
22 pub path: Path,
24
25 pub kind: ChannelKind,
27
28 pub initial_credit: u32,
30}
31
/// Kind of channel endpoint recorded in a [`ChannelLocation`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ChannelKind {
    /// A shape accepted by `channel::is_rx`.
    Rx,
    /// A shape accepted by `channel::is_tx`.
    Tx,
}
38
39impl RpcPlan {
40 fn from_shape(shape: &'static Shape) -> Self {
41 let mut visitor = ChannelDiscovery {
42 locations: Vec::new(),
43 };
44 walk_shape(shape, &mut visitor);
45
46 RpcPlan {
47 shape,
48 channel_locations: visitor.locations.leak(),
49 }
50 }
51
52 pub fn for_shape(shape: &'static Shape) -> &'static Self {
54 static CACHE: OnceLock<Mutex<HashMap<usize, &'static RpcPlan>>> = OnceLock::new();
55 let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
56
57 let key = shape as *const Shape as usize;
58
59 let mut guard = cache
60 .lock()
61 .expect("rpc plan cache mutex should not be poisoned");
62 if let Some(plan) = guard.get(&key) {
63 return plan;
64 }
65
66 let plan = Box::leak(Box::new(Self::from_shape(shape)));
67 guard.insert(key, plan);
68 plan
69 }
70
71 pub fn for_type<T: Facet<'static>>() -> &'static Self {
73 Self::for_shape(T::SHAPE)
74 }
75}
76
77fn extract_initial_credit(shape: &'static Shape) -> u32 {
79 shape
80 .const_params
81 .iter()
82 .find(|cp| cp.name == "N")
83 .map(|cp| cp.value as u32)
84 .unwrap_or(16)
85}
86
87struct ChannelDiscovery {
91 locations: Vec<ChannelLocation>,
92}
93
94impl facet_path::ShapeVisitor for ChannelDiscovery {
95 fn enter(&mut self, path: &Path, shape: &'static Shape) -> facet_path::VisitDecision {
96 if channel::is_tx(shape) {
97 self.locations.push(ChannelLocation {
98 path: path.clone(),
99 kind: ChannelKind::Tx,
100 initial_credit: extract_initial_credit(shape),
101 });
102 return facet_path::VisitDecision::SkipChildren;
103 }
104
105 if channel::is_rx(shape) {
106 self.locations.push(ChannelLocation {
107 path: path.clone(),
108 kind: ChannelKind::Rx,
109 initial_credit: extract_initial_credit(shape),
110 });
111 return facet_path::VisitDecision::SkipChildren;
112 }
113
114 if matches!(
116 shape.def,
117 facet::Def::List(_) | facet::Def::Array(_) | facet::Def::Map(_) | facet::Def::Set(_)
118 ) {
119 return facet_path::VisitDecision::SkipChildren;
120 }
121
122 facet_path::VisitDecision::Recurse
123 }
124}