1use bus_controller::BusController;
2pub use errors::{Error, JoinError, RecvError, SendError};
3pub use ipmb_derive::MessageBox;
4pub use label::{Label, LabelOp};
5pub use memory_registry::MemoryRegistry;
6pub use message::{BytesMessage, Message, MessageBox};
7use once_cell::sync::Lazy;
8pub use options::Options;
9use platform::{look_up, register, EncodedMessage, IoHub, IoMultiplexing, Remote};
10pub use platform::{MemoryRegion, Object};
11use serde::{Deserialize, Serialize};
12use std::{
13 fmt::{Display, Formatter},
14 marker::PhantomData,
15 sync::{
16 mpsc,
17 mpsc::{Receiver, RecvTimeoutError, Sender},
18 Arc, Mutex, RwLock,
19 },
20 thread,
21 time::{Duration, Instant},
22};
23use type_uuid::Bytes;
24use util::EndpointID;
25
26mod bus_controller;
27mod errors;
28mod label;
29mod memory_registry;
30mod message;
31mod options;
32pub mod platform;
33mod util;
34
/// Addressing and metadata header carried by every message.
///
/// NOTE(review): this struct derives `Serialize`/`Deserialize`, so the field
/// order is presumably part of the encoded form — do not reorder fields without
/// confirming how selectors are put on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Selector {
    /// Label expression a receiver's label must satisfy; client receivers check it
    /// with `LabelOp::validate` against their own `options.label`.
    pub label_op: LabelOp,
    /// Delivery mode, see [`SelectorMode`].
    pub mode: SelectorMode,
    /// Type UUID of the payload; receivers pass it to `MessageBox::decode` to pick
    /// the payload type. The public constructors leave it zeroed (presumably filled
    /// in when the message is encoded — TODO confirm).
    uuid: Bytes,
    /// Number of memory regions attached to the message; `EndpointSender::send`
    /// overwrites it from the message's region list before encoding.
    memory_region_count: u16,
    /// Time-to-live of the message. The constructors set it to `Duration::ZERO`;
    /// NOTE(review): what a zero TTL means (no expiry vs. immediate expiry) is not
    /// visible here — confirm in the bus controller.
    pub ttl: Duration,
}
45
46impl Selector {
47 pub fn unicast(label_op: impl Into<LabelOp>) -> Self {
48 Self {
49 label_op: label_op.into(),
50 mode: SelectorMode::Unicast,
51 uuid: [0; 16],
52 memory_region_count: 0,
53 ttl: Duration::ZERO,
54 }
55 }
56
57 pub fn multicast(label_op: impl Into<LabelOp>) -> Self {
58 Self {
59 label_op: label_op.into(),
60 mode: SelectorMode::Multicast,
61 uuid: [0; 16],
62 memory_region_count: 0,
63 ttl: Duration::ZERO,
64 }
65 }
66}
67
/// How a message addressed by a [`Selector`] is delivered.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`; the variant order is
/// presumably part of its encoding, so do not reorder variants without confirming.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub enum SelectorMode {
    /// Presumably delivered to a single matching endpoint — confirm in the bus controller.
    Unicast,
    /// Presumably delivered to every matching endpoint — confirm in the bus controller.
    Multicast,
}
75
76pub fn decode<'de, T: Deserialize<'de>>(data: &'de [u8]) -> Result<T, Error> {
77 let (d, _): (T, _) =
78 bincode::serde::borrow_decode_from_slice(data, bincode::config::standard())
79 .map_err(Error::Decode)?;
80 Ok(d)
81}
82
83pub fn encode<T: Serialize>(t: T) -> Result<Vec<u8>, Error> {
84 let data = bincode::serde::encode_to_vec(t, bincode::config::standard())?;
85 Ok(data)
86}
87
88pub fn join<'de, T: MessageBox, R: MessageBox>(
89 options: Options,
90 timeout: Option<Duration>,
91) -> Result<(EndpointSender<T>, EndpointReceiver<R>), JoinError> {
92 let rule = Arc::new(RwLock::new(Rule::join(
93 options,
94 0,
95 Arc::new(IoMultiplexing::new()),
96 timeout,
97 )?));
98
99 Ok((
100 EndpointSender {
101 rule: rule.clone(),
102 _marker: PhantomData,
103 },
104 EndpointReceiver {
105 rule,
106 _maker: PhantomData,
107 },
108 ))
109}
110
111pub struct EndpointSender<T> {
113 rule: Arc<RwLock<Rule>>,
114 _marker: PhantomData<T>,
115}
116
117impl<T> Clone for EndpointSender<T> {
118 fn clone(&self) -> Self {
119 Self {
120 rule: self.rule.clone(),
121 _marker: PhantomData,
122 }
123 }
124}
125
126impl<T: MessageBox> EndpointSender<T> {
127 pub fn send(&self, mut msg: Message<T>) -> Result<(), SendError> {
128 msg.selector.memory_region_count = msg.memory_regions.len() as _;
129 let mut msg = msg.into_encoded();
130
131 loop {
132 let rule = self.rule.read().unwrap();
133 match &*rule {
134 Rule::Client {
135 endpoint_id: _,
136 options: _,
137 remote,
138 io_hub: _,
139 reader_closed: _,
140 im: _,
141 epoch,
142 } => match msg.send(remote) {
143 Err(Error::Disconnect) => {
144 let epoch = *epoch;
145 drop(rule);
146
147 let mut rule = self.rule.write().unwrap();
148 match &mut *rule {
149 Rule::Client {
150 endpoint_id: _,
151 options,
152 remote: _,
153 io_hub,
154 reader_closed,
155 im,
156 epoch: epoch1,
157 } => {
158 if epoch == *epoch1 {
159 let reader_closed = *reader_closed;
160
161 drop(io_hub.take());
163
164 *rule = Rule::join(
165 options.clone(),
166 epoch.overflowing_add(1).0,
167 im.clone(),
168 None,
169 )?;
170
171 if reader_closed {
172 rule.reader_close();
173 }
174 }
175 }
176 Rule::Server { .. } => {}
177 }
178 }
179 Err(_) => unreachable!(),
180 Ok(_) => break Ok(()),
181 },
182 Rule::Server {
183 endpoint_id: _,
184 bus_sender,
185 receiver: _,
186 im,
187 } => {
188 bus_sender.lock().unwrap().send(msg).unwrap();
189 im.wake();
190 break Ok(());
191 }
192 }
193 }
194 }
195}
196
197pub struct EndpointReceiver<R> {
200 rule: Arc<RwLock<Rule>>,
201 _maker: PhantomData<R>,
202}
203
204impl<'de, R: MessageBox> EndpointReceiver<R> {
205 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<Message<R>, RecvError> {
206 loop {
207 let rule = self.rule.read().unwrap();
208 match &*rule {
209 Rule::Client {
210 endpoint_id: _,
211 options,
212 remote,
213 io_hub,
214 reader_closed,
215 im: _,
216 epoch,
217 } => {
218 if !*reader_closed && io_hub.is_none() {
219 let epoch = *epoch;
220 drop(rule);
221
222 let mut rule = self.rule.write().unwrap();
223 match &mut *rule {
224 Rule::Client {
225 endpoint_id: _,
226 options,
227 remote: _,
228 io_hub,
229 reader_closed,
230 im,
231 epoch: epoch1,
232 } => {
233 if epoch == *epoch1 {
234 let reader_closed = *reader_closed;
235
236 drop(io_hub.take());
238
239 *rule = Rule::join(
240 options.clone(),
241 epoch.overflowing_add(1).0,
242 im.clone(),
243 timeout,
244 )?;
245
246 if reader_closed {
247 rule.reader_close();
248 }
249 }
250
251 continue;
252 }
253 Rule::Server { .. } => continue,
254 }
255 }
256
257 let mut io_hub_guard = io_hub.as_ref().expect("reader closed").lock().unwrap();
258
259 match io_hub_guard.recv(timeout, Some(remote)) {
260 Ok(encoded_msg) => {
261 if encoded_msg.selector.label_op.validate(&options.label) {
262 match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data)
263 {
264 Ok(payload) => {
265 let mut msg = Message::new(encoded_msg.selector, payload);
266 msg.objects = encoded_msg.objects;
267 msg.memory_regions = encoded_msg.memory_regions;
268 break Ok(msg);
269 }
270 Err(Error::TypeUuidNotFound) => {
271 continue;
272 }
273 Err(Error::Decode(err)) => {
274 break Err(RecvError::Decode(err));
275 }
276 Err(_) => unreachable!(),
277 }
278 } else {
279 log::warn!(
280 "Unexpected message label_op: {:?}",
281 encoded_msg.selector.label_op
282 );
283 continue;
284 }
285 }
286 Err(Error::Disconnect) => {
287 let epoch = *epoch;
288 drop(io_hub_guard);
289 drop(rule);
290
291 let mut rule = self.rule.write().unwrap();
292 match &mut *rule {
293 Rule::Client {
294 endpoint_id: _,
295 options,
296 remote: _,
297 io_hub,
298 reader_closed,
299 im,
300 epoch: epoch1,
301 } => {
302 if epoch == *epoch1 {
303 let reader_closed = *reader_closed;
304
305 drop(io_hub.take());
307
308 *rule = Rule::join(
309 options.clone(),
310 epoch.overflowing_add(1).0,
311 im.clone(),
312 timeout,
313 )?;
314
315 if reader_closed {
316 rule.reader_close();
317 }
318 }
319
320 continue;
321 }
322 Rule::Server { .. } => continue,
323 }
324 }
325 Err(Error::Timeout) => {
326 break Err(RecvError::Timeout);
327 }
328 Err(_) => unreachable!(),
329 }
330 }
331 Rule::Server {
332 endpoint_id: _,
333 bus_sender: _,
334 receiver,
335 im: _,
336 } => {
337 let receiver = receiver.as_ref().expect("reader closed").lock().unwrap();
338 break match timeout {
339 Some(timeout) => match receiver.recv_timeout(timeout) {
340 Ok(encoded_msg) => {
341 match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data)
342 {
343 Ok(payload) => {
344 let mut msg = Message::new(encoded_msg.selector, payload);
345 msg.objects = encoded_msg.objects;
346 msg.memory_regions = encoded_msg.memory_regions;
347 Ok(msg)
348 }
349 Err(Error::TypeUuidNotFound) => {
350 continue;
351 }
352 Err(Error::Decode(err)) => Err(RecvError::Decode(err)),
353 Err(_) => unreachable!(),
354 }
355 }
356 Err(RecvTimeoutError::Timeout) => Err(RecvError::Timeout),
357 Err(_) => unreachable!(),
358 },
359 None => {
360 let encoded_msg = receiver.recv().unwrap();
361 match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data) {
362 Ok(payload) => {
363 let mut msg = Message::new(encoded_msg.selector, payload);
364 msg.objects = encoded_msg.objects;
365 msg.memory_regions = encoded_msg.memory_regions;
366 Ok(msg)
367 }
368 Err(Error::TypeUuidNotFound) => {
369 continue;
370 }
371 Err(Error::Decode(err)) => Err(RecvError::Decode(err)),
372 Err(_) => unreachable!(),
373 }
374 }
375 };
376 }
377 }
378 }
379 }
380}
381
382impl<R> Drop for EndpointReceiver<R> {
383 fn drop(&mut self) {
384 let mut rule = self.rule.write().unwrap();
385 rule.reader_close();
386 }
387}
388
389enum Rule {
390 Client {
391 #[allow(dead_code)]
392 endpoint_id: EndpointID,
393 options: Options,
394 remote: Remote,
395 io_hub: Option<Mutex<IoHub>>,
396 reader_closed: bool,
397 im: Arc<IoMultiplexing>,
398 epoch: u32,
399 },
400 Server {
401 #[allow(dead_code)]
402 endpoint_id: EndpointID,
403 bus_sender: Mutex<Sender<EncodedMessage>>,
404 receiver: Option<Mutex<Receiver<EncodedMessage>>>,
405 im: Arc<IoMultiplexing>,
406 },
407}
408
409impl Rule {
410 fn join(
411 options: Options,
412 epoch: u32,
413 im: Arc<IoMultiplexing>,
414 timeout: Option<Duration>,
415 ) -> Result<Self, JoinError> {
416 let end = timeout.map(|timeout| Instant::now() + timeout);
417
418 macro_rules! wait {
419 () => {
420 let mut wait = Duration::from_secs(2);
421 if let Some(end) = end {
422 let remain = end.saturating_duration_since(Instant::now());
423 if remain.is_zero() {
424 return Err(JoinError::Timeout);
425 }
426 wait = wait.min(remain);
427 }
428 thread::sleep(wait);
429 };
430 }
431
432 let mut timeout_count = 0;
433 let mut permission_denied_count = 0;
435
436 let rule = loop {
437 let r = look_up(
438 &options.identifier,
439 options.label.clone(),
440 options.token.clone(),
441 im.clone(),
442 );
443
444 match r {
445 Ok((io_hub, remote, endpoint_id)) => {
446 let rule = Rule::Client {
447 endpoint_id,
448 options,
449 remote,
450 io_hub: Some(Mutex::new(io_hub)),
451 reader_closed: false,
452 im,
453 epoch,
454 };
455 break rule;
456 }
457 Err(Error::IdentifierNotInUse) => {
458 if !options.controller_affinity {
459 log::error!("lookup: controller not found");
460 wait!();
461 continue;
462 }
463
464 let r = register(&options.identifier, im.clone());
465
466 match r {
467 Ok((io_hub, bus_sender, endpoint_id)) => {
468 let (sender, receiver) = mpsc::channel::<EncodedMessage>();
469
470 let im = io_hub.io_multiplexing();
471
472 let bus_controller = BusController::new(
473 endpoint_id,
474 options.label,
475 options.token,
476 sender,
477 io_hub,
478 );
479 bus_controller.run();
480
481 let rule = Rule::Server {
482 endpoint_id,
483 bus_sender: Mutex::new(bus_sender),
484 receiver: Some(Mutex::new(receiver)),
485 im,
486 };
487 break rule;
488 }
489 Err(Error::IdentifierInUse) => {}
490 Err(Error::PermissonDenied) => {
491 permission_denied_count += 1;
492 if permission_denied_count > 5 {
493 return Err(JoinError::PermissonDenied);
494 }
495 wait!();
496 }
497 Err(err) => {
498 log::error!("register: {:?}", err);
499 wait!();
500 }
501 }
502 }
503 Err(Error::VersionMismatch(v, _)) => {
504 return Err(JoinError::VersionMismatch(v));
505 }
506 Err(Error::TokenMismatch) => {
507 return Err(JoinError::TokenMismatch);
508 }
509 Err(Error::PermissonDenied) => {
510 permission_denied_count += 1;
511 if permission_denied_count > 5 {
512 return Err(JoinError::PermissonDenied);
513 }
514 wait!();
515 }
516 Err(Error::Timeout) => {
517 timeout_count += 1;
518
519 if timeout_count > 5 {
520 return Err(JoinError::VersionMismatch(Version((0, 0, 0))));
521 }
522
523 wait!();
524 }
525 Err(err) => {
526 log::error!("look_up: {:?}", err);
527 wait!();
528 }
529 }
530 };
531
532 Ok(rule)
533 }
534}
535
536impl Rule {
537 fn reader_close(&mut self) {
538 match self {
539 Rule::Client {
540 io_hub,
541 reader_closed,
542 ..
543 } => {
544 let _ = io_hub.take();
545 *reader_closed = true;
546 }
547 Rule::Server { receiver, .. } => {
548 let _ = receiver.take();
549 }
550 }
551 }
552}
553
554#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
556pub struct Version((u8, u8, u8));
557
558impl Version {
559 fn compatible(&self, rhs: Self) -> bool {
560 if self.major() == 0 && rhs.major() == 0 {
561 self.minor() == rhs.minor()
562 } else {
563 self.major() == rhs.major()
564 }
565 }
566
567 pub fn major(&self) -> u8 {
568 self.0 .0
569 }
570
571 pub fn minor(&self) -> u8 {
572 self.0 .1
573 }
574
575 pub fn patch(&self) -> u8 {
576 self.0 .2
577 }
578}
579
580impl Display for Version {
581 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
582 write!(f, "{}.{}.{}", self.0 .0, self.0 .1, self.0 .2)
583 }
584}
585
586static VERSION: Lazy<Version> = Lazy::new(|| {
587 let v_major = env!("CARGO_PKG_VERSION_MAJOR");
588 let v_minor = env!("CARGO_PKG_VERSION_MINOR");
589 let v_patch = env!("CARGO_PKG_VERSION_PATCH");
590 Version((
591 v_major.parse().unwrap(),
592 v_minor.parse().unwrap(),
593 v_patch.parse().unwrap(),
594 ))
595});
596static VERSION_PRE: Lazy<&'static str> = Lazy::new(|| env!("CARGO_PKG_VERSION_PRE"));
597
598pub fn version() -> Version {
599 *VERSION
600}
601
602pub fn version_pre() -> String {
603 VERSION_PRE.to_owned()
604}