1use std::time::Duration;
2
3use naia_socket_shared::{LinkConditionerConfig, SocketConfig};
4
5use crate::{
6 connection::compression_config::CompressionConfig,
7 messages::{
8 channels::{
9 channel::{Channel, ChannelDirection, ChannelMode, ChannelSettings},
10 channel_kinds::ChannelKinds,
11 default_channels::DefaultChannelsPlugin,
12 },
13 fragment::FragmentedMessage,
14 message::Message,
15 message_kinds::MessageKinds,
16 },
17 protocol_id::ProtocolId,
18 world::{
19 component::{component_kinds::ComponentKinds, replicate::Replicate},
20 resource::ResourceKinds,
21 },
22 Request, RequestOrResponse,
23};
24
/// A reusable bundle of registrations that can be applied to a [`Protocol`].
///
/// Plugins are applied through [`Protocol::add_plugin`], which checks that the
/// protocol is not yet locked and then hands it to [`ProtocolPlugin::build`].
pub trait ProtocolPlugin {
    /// Applies this plugin's configuration to `protocol`.
    fn build(&self, protocol: &mut Protocol);
}
30
/// Shared description of everything a client and server must agree on:
/// registered channels, messages, components and resources, plus socket,
/// tick and compression settings.
///
/// Configure it with the builder-style methods, then call `lock()` to freeze it
/// and compute its protocol id.
#[derive(Clone)]
pub struct Protocol {
    /// Registered channel types and their settings.
    pub channel_kinds: ChannelKinds,
    /// Registered message types (built-in kinds are added by `Default`).
    pub message_kinds: MessageKinds,
    /// Registered replicated component types (resources are registered here too).
    pub component_kinds: ComponentKinds,
    /// Registered resource types, keyed by their component kind.
    pub resource_kinds: ResourceKinds,
    /// Socket configuration (link conditioner, RTC endpoint path, ...).
    pub socket: SocketConfig,
    /// Tick interval; 50 ms unless overridden via `tick_interval()`.
    pub tick_interval: Duration,
    /// Optional compression settings; `None` unless set via `compression()`.
    pub compression: Option<CompressionConfig>,
    /// Whether client-authoritative entities are enabled; `false` by default.
    pub client_authoritative_entities: bool,
    /// Protocol id computed once by `lock()` and returned by `protocol_id()`.
    cached_protocol_id: Option<ProtocolId>,
    /// Set by `lock()`; once `true`, `check_lock()` panics on further mutation.
    locked: bool,
}
58
59impl Default for Protocol {
60 fn default() -> Self {
61 let mut message_kinds = MessageKinds::new();
62 message_kinds.add_message::<FragmentedMessage>();
63 message_kinds.add_message::<RequestOrResponse>();
64
65 let channel_kinds = ChannelKinds::new();
66
67 Self {
68 channel_kinds,
69 message_kinds,
70 component_kinds: ComponentKinds::new(),
71 resource_kinds: ResourceKinds::new(),
72 socket: SocketConfig::new(None, None),
73 tick_interval: Duration::from_millis(50),
74 compression: None,
75 client_authoritative_entities: false,
76 cached_protocol_id: None,
77 locked: false,
78 }
79 }
80}
81
82impl Protocol {
83 pub fn builder() -> Self {
85 Self::default()
86 }
87
88 pub fn add_plugin<P: ProtocolPlugin>(&mut self, plugin: P) -> &mut Self {
90 self.check_lock();
91 plugin.build(self);
92 self
93 }
94
95 pub fn link_condition(&mut self, config: LinkConditionerConfig) -> &mut Self {
97 self.check_lock();
98 self.socket.link_condition = Some(config);
99 self
100 }
101
102 pub fn rtc_endpoint(&mut self, path: String) -> &mut Self {
104 self.check_lock();
105 self.socket.rtc_endpoint_path = path;
106 self
107 }
108
109 pub fn get_rtc_endpoint(&self) -> String {
111 self.socket.rtc_endpoint_path.clone()
112 }
113
114 pub fn tick_interval(&mut self, duration: Duration) -> &mut Self {
116 self.check_lock();
117 self.tick_interval = duration;
118 self
119 }
120
121 pub fn compression(&mut self, config: CompressionConfig) -> &mut Self {
123 self.check_lock();
124 self.compression = Some(config);
125 self
126 }
127
128 pub fn enable_client_authoritative_entities(&mut self) -> &mut Self {
130 self.check_lock();
131 self.client_authoritative_entities = true;
132 self
133 }
134
135 pub fn add_default_channels(&mut self) -> &mut Self {
137 self.check_lock();
138 let plugin = DefaultChannelsPlugin;
139 plugin.build(self);
140 self
141 }
142
143 pub fn add_channel<C: Channel>(
145 &mut self,
146 direction: ChannelDirection,
147 mode: ChannelMode,
148 ) -> &mut Self {
149 self.check_lock();
150 self.channel_kinds
151 .add_channel::<C>(ChannelSettings::new(mode, direction));
152 self
153 }
154
155 pub fn add_channel_settings<C: Channel>(&mut self, settings: ChannelSettings) -> &mut Self {
159 self.check_lock();
160 self.channel_kinds.add_channel::<C>(settings);
161 self
162 }
163
164 pub fn add_message<M: Message>(&mut self) -> &mut Self {
166 self.check_lock();
167 self.message_kinds.add_message::<M>();
168 self
169 }
170
171 pub fn add_request<Q: Request>(&mut self) -> &mut Self {
173 self.check_lock();
174 self.message_kinds.add_message::<Q>();
176 self.message_kinds.add_message::<Q::Response>();
177 self
178 }
179
180 pub fn add_component<C: Replicate>(&mut self) -> &mut Self {
182 self.check_lock();
183 self.component_kinds.add_component::<C>();
184 self
185 }
186
187 pub fn add_resource<R: Replicate>(&mut self) -> &mut Self {
202 self.check_lock();
203 self.component_kinds.add_component::<R>();
205 let kind = crate::ComponentKind::of::<R>();
207 self.resource_kinds.register::<R>(kind);
208 self
209 }
210
211 pub fn lock(&mut self) {
213 self.check_lock();
214 self.cached_protocol_id = Some(self.compute_protocol_id());
215 self.locked = true;
216 }
217
218 pub fn check_lock(&self) {
220 if self.locked {
221 panic!("Protocol already locked!");
222 }
223 }
224
225 pub fn build(&mut self) -> Self {
227 std::mem::take(self)
228 }
229
230 pub fn protocol_id(&self) -> ProtocolId {
232 self.cached_protocol_id
233 .expect("Protocol must be locked before calling protocol_id()")
234 }
235
236 fn compute_protocol_id(&self) -> ProtocolId {
238 let mut hasher = blake3::Hasher::new();
239
240 for name in self.channel_kinds.all_names() {
242 hasher.update(name.as_bytes());
243 }
244 for name in self.message_kinds.all_names() {
246 hasher.update(name.as_bytes());
247 }
248 for name in self.component_kinds.all_names() {
250 hasher.update(name.as_bytes());
251 }
252 hasher.update(b"naia:resources:");
258 let mut resource_count = 0u32;
259 for _ in self.resource_kinds.iter() {
260 resource_count += 1;
261 }
262 hasher.update(&resource_count.to_le_bytes());
263
264 let hash = hasher.finalize();
265 let mut bytes = [0u8; 8];
266 bytes.copy_from_slice(&hash.as_bytes()[..8]);
267 ProtocolId::new(u64::from_le_bytes(bytes))
268 }
269}