ipmb/
lib.rs

1use bus_controller::BusController;
2pub use errors::{Error, JoinError, RecvError, SendError};
3pub use ipmb_derive::MessageBox;
4pub use label::{Label, LabelOp};
5pub use memory_registry::MemoryRegistry;
6pub use message::{BytesMessage, Message, MessageBox};
7use once_cell::sync::Lazy;
8pub use options::Options;
9use platform::{look_up, register, EncodedMessage, IoHub, IoMultiplexing, Remote};
10pub use platform::{MemoryRegion, Object};
11use serde::{Deserialize, Serialize};
12use std::{
13    fmt::{Display, Formatter},
14    marker::PhantomData,
15    sync::{
16        mpsc,
17        mpsc::{Receiver, RecvTimeoutError, Sender},
18        Arc, Mutex, RwLock,
19    },
20    thread,
21    time::{Duration, Instant},
22};
23use type_uuid::Bytes;
24use util::EndpointID;
25
26mod bus_controller;
27mod errors;
28mod label;
29mod memory_registry;
30mod message;
31mod options;
32pub mod platform;
33mod util;
34
/// Describes how a message is routed.
// If this struct is serialized with `encode`, field order is part of the encoding; avoid
// reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Selector {
    /// Label expression checked (via `LabelOp::validate`) against the receiving endpoint's label.
    pub label_op: LabelOp,
    /// Whether the message may be consumed by one endpoint or by several.
    pub mode: SelectorMode,
    // Type UUID used on receipt to pick the payload decoder; the public constructors zero it
    // (presumably it is filled in when the message is encoded — TODO confirm).
    uuid: Bytes,
    // Number of memory regions attached to the message; set by `EndpointSender::send`.
    memory_region_count: u16,
    /// The time to live when a message cannot be routed to any endpoint.
    pub ttl: Duration,
}
45
46impl Selector {
47    pub fn unicast(label_op: impl Into<LabelOp>) -> Self {
48        Self {
49            label_op: label_op.into(),
50            mode: SelectorMode::Unicast,
51            uuid: [0; 16],
52            memory_region_count: 0,
53            ttl: Duration::ZERO,
54        }
55    }
56
57    pub fn multicast(label_op: impl Into<LabelOp>) -> Self {
58        Self {
59            label_op: label_op.into(),
60            mode: SelectorMode::Multicast,
61            uuid: [0; 16],
62            memory_region_count: 0,
63            ttl: Duration::ZERO,
64        }
65    }
66}
67
// NOTE(review): if this enum is serialized with bincode (see `encode`), the variant index is
// part of the encoding, so reordering variants would change the wire format — keep the order.
/// How many endpoints may consume a routed message.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub enum SelectorMode {
    /// The message can only be consumed by one endpoint.
    Unicast,
    /// The message can be consumed by multiple endpoints.
    Multicast,
}
75
76pub fn decode<'de, T: Deserialize<'de>>(data: &'de [u8]) -> Result<T, Error> {
77    let (d, _): (T, _) =
78        bincode::serde::borrow_decode_from_slice(data, bincode::config::standard())
79            .map_err(Error::Decode)?;
80    Ok(d)
81}
82
83pub fn encode<T: Serialize>(t: T) -> Result<Vec<u8>, Error> {
84    let data = bincode::serde::encode_to_vec(t, bincode::config::standard())?;
85    Ok(data)
86}
87
88pub fn join<'de, T: MessageBox, R: MessageBox>(
89    options: Options,
90    timeout: Option<Duration>,
91) -> Result<(EndpointSender<T>, EndpointReceiver<R>), JoinError> {
92    let rule = Arc::new(RwLock::new(Rule::join(
93        options,
94        0,
95        Arc::new(IoMultiplexing::new()),
96        timeout,
97    )?));
98
99    Ok((
100        EndpointSender {
101            rule: rule.clone(),
102            _marker: PhantomData,
103        },
104        EndpointReceiver {
105            rule,
106            _maker: PhantomData,
107        },
108    ))
109}
110
111/// The sending half of endpoint, messages can be sent with [`send`](EndpointSender::send).
112pub struct EndpointSender<T> {
113    rule: Arc<RwLock<Rule>>,
114    _marker: PhantomData<T>,
115}
116
117impl<T> Clone for EndpointSender<T> {
118    fn clone(&self) -> Self {
119        Self {
120            rule: self.rule.clone(),
121            _marker: PhantomData,
122        }
123    }
124}
125
126impl<T: MessageBox> EndpointSender<T> {
127    pub fn send(&self, mut msg: Message<T>) -> Result<(), SendError> {
128        msg.selector.memory_region_count = msg.memory_regions.len() as _;
129        let mut msg = msg.into_encoded();
130
131        loop {
132            let rule = self.rule.read().unwrap();
133            match &*rule {
134                Rule::Client {
135                    endpoint_id: _,
136                    options: _,
137                    remote,
138                    io_hub: _,
139                    reader_closed: _,
140                    im: _,
141                    epoch,
142                } => match msg.send(remote) {
143                    Err(Error::Disconnect) => {
144                        let epoch = *epoch;
145                        drop(rule);
146
147                        let mut rule = self.rule.write().unwrap();
148                        match &mut *rule {
149                            Rule::Client {
150                                endpoint_id: _,
151                                options,
152                                remote: _,
153                                io_hub,
154                                reader_closed,
155                                im,
156                                epoch: epoch1,
157                            } => {
158                                if epoch == *epoch1 {
159                                    let reader_closed = *reader_closed;
160
161                                    // Close reader
162                                    drop(io_hub.take());
163
164                                    *rule = Rule::join(
165                                        options.clone(),
166                                        epoch.overflowing_add(1).0,
167                                        im.clone(),
168                                        None,
169                                    )?;
170
171                                    if reader_closed {
172                                        rule.reader_close();
173                                    }
174                                }
175                            }
176                            Rule::Server { .. } => {}
177                        }
178                    }
179                    Err(_) => unreachable!(),
180                    Ok(_) => break Ok(()),
181                },
182                Rule::Server {
183                    endpoint_id: _,
184                    bus_sender,
185                    receiver: _,
186                    im,
187                } => {
188                    bus_sender.lock().unwrap().send(msg).unwrap();
189                    im.wake();
190                    break Ok(());
191                }
192            }
193        }
194    }
195}
196
197/// The receiving half of endpoint, messages sent to the endpoint can be retrieved using [`recv`](EndpointReceiver::recv), dropping receiver will close underly receving kernel buffer.
198// Don't impl Clone
199pub struct EndpointReceiver<R> {
200    rule: Arc<RwLock<Rule>>,
201    _maker: PhantomData<R>,
202}
203
204impl<'de, R: MessageBox> EndpointReceiver<R> {
205    pub fn recv(&mut self, timeout: Option<Duration>) -> Result<Message<R>, RecvError> {
206        loop {
207            let rule = self.rule.read().unwrap();
208            match &*rule {
209                Rule::Client {
210                    endpoint_id: _,
211                    options,
212                    remote,
213                    io_hub,
214                    reader_closed,
215                    im: _,
216                    epoch,
217                } => {
218                    if !*reader_closed && io_hub.is_none() {
219                        let epoch = *epoch;
220                        drop(rule);
221
222                        let mut rule = self.rule.write().unwrap();
223                        match &mut *rule {
224                            Rule::Client {
225                                endpoint_id: _,
226                                options,
227                                remote: _,
228                                io_hub,
229                                reader_closed,
230                                im,
231                                epoch: epoch1,
232                            } => {
233                                if epoch == *epoch1 {
234                                    let reader_closed = *reader_closed;
235
236                                    // Close reader
237                                    drop(io_hub.take());
238
239                                    *rule = Rule::join(
240                                        options.clone(),
241                                        epoch.overflowing_add(1).0,
242                                        im.clone(),
243                                        timeout,
244                                    )?;
245
246                                    if reader_closed {
247                                        rule.reader_close();
248                                    }
249                                }
250
251                                continue;
252                            }
253                            Rule::Server { .. } => continue,
254                        }
255                    }
256
257                    let mut io_hub_guard = io_hub.as_ref().expect("reader closed").lock().unwrap();
258
259                    match io_hub_guard.recv(timeout, Some(remote)) {
260                        Ok(encoded_msg) => {
261                            if encoded_msg.selector.label_op.validate(&options.label) {
262                                match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data)
263                                {
264                                    Ok(payload) => {
265                                        let mut msg = Message::new(encoded_msg.selector, payload);
266                                        msg.objects = encoded_msg.objects;
267                                        msg.memory_regions = encoded_msg.memory_regions;
268                                        break Ok(msg);
269                                    }
270                                    Err(Error::TypeUuidNotFound) => {
271                                        continue;
272                                    }
273                                    Err(Error::Decode(err)) => {
274                                        break Err(RecvError::Decode(err));
275                                    }
276                                    Err(_) => unreachable!(),
277                                }
278                            } else {
279                                log::warn!(
280                                    "Unexpected message label_op: {:?}",
281                                    encoded_msg.selector.label_op
282                                );
283                                continue;
284                            }
285                        }
286                        Err(Error::Disconnect) => {
287                            let epoch = *epoch;
288                            drop(io_hub_guard);
289                            drop(rule);
290
291                            let mut rule = self.rule.write().unwrap();
292                            match &mut *rule {
293                                Rule::Client {
294                                    endpoint_id: _,
295                                    options,
296                                    remote: _,
297                                    io_hub,
298                                    reader_closed,
299                                    im,
300                                    epoch: epoch1,
301                                } => {
302                                    if epoch == *epoch1 {
303                                        let reader_closed = *reader_closed;
304
305                                        // Close reader
306                                        drop(io_hub.take());
307
308                                        *rule = Rule::join(
309                                            options.clone(),
310                                            epoch.overflowing_add(1).0,
311                                            im.clone(),
312                                            timeout,
313                                        )?;
314
315                                        if reader_closed {
316                                            rule.reader_close();
317                                        }
318                                    }
319
320                                    continue;
321                                }
322                                Rule::Server { .. } => continue,
323                            }
324                        }
325                        Err(Error::Timeout) => {
326                            break Err(RecvError::Timeout);
327                        }
328                        Err(_) => unreachable!(),
329                    }
330                }
331                Rule::Server {
332                    endpoint_id: _,
333                    bus_sender: _,
334                    receiver,
335                    im: _,
336                } => {
337                    let receiver = receiver.as_ref().expect("reader closed").lock().unwrap();
338                    break match timeout {
339                        Some(timeout) => match receiver.recv_timeout(timeout) {
340                            Ok(encoded_msg) => {
341                                match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data)
342                                {
343                                    Ok(payload) => {
344                                        let mut msg = Message::new(encoded_msg.selector, payload);
345                                        msg.objects = encoded_msg.objects;
346                                        msg.memory_regions = encoded_msg.memory_regions;
347                                        Ok(msg)
348                                    }
349                                    Err(Error::TypeUuidNotFound) => {
350                                        continue;
351                                    }
352                                    Err(Error::Decode(err)) => Err(RecvError::Decode(err)),
353                                    Err(_) => unreachable!(),
354                                }
355                            }
356                            Err(RecvTimeoutError::Timeout) => Err(RecvError::Timeout),
357                            Err(_) => unreachable!(),
358                        },
359                        None => {
360                            let encoded_msg = receiver.recv().unwrap();
361                            match R::decode(encoded_msg.selector.uuid, encoded_msg.payload_data) {
362                                Ok(payload) => {
363                                    let mut msg = Message::new(encoded_msg.selector, payload);
364                                    msg.objects = encoded_msg.objects;
365                                    msg.memory_regions = encoded_msg.memory_regions;
366                                    Ok(msg)
367                                }
368                                Err(Error::TypeUuidNotFound) => {
369                                    continue;
370                                }
371                                Err(Error::Decode(err)) => Err(RecvError::Decode(err)),
372                                Err(_) => unreachable!(),
373                            }
374                        }
375                    };
376                }
377            }
378        }
379    }
380}
381
382impl<R> Drop for EndpointReceiver<R> {
383    fn drop(&mut self) {
384        let mut rule = self.rule.write().unwrap();
385        rule.reader_close();
386    }
387}
388
/// Connection state shared (behind `Arc<RwLock<_>>`) by an endpoint's sender and receiver.
// Fields are dropped in declaration order (e.g. `io_hub` before `im`); keep that in mind before
// reordering them.
enum Rule {
    /// The endpoint found an existing controller via `look_up` and joined it.
    Client {
        // Not read anywhere in this file.
        #[allow(dead_code)]
        endpoint_id: EndpointID,
        /// Options the endpoint joined with; reused when reconnecting.
        options: Options,
        /// Destination passed to `EncodedMessage::send` and to `IoHub::recv`.
        remote: Remote,
        /// The reader; `None` once closed, or after a reconnect attempt that failed.
        io_hub: Option<Mutex<IoHub>>,
        /// Set once the receiving half is dropped, so reconnects keep the reader closed.
        reader_closed: bool,
        /// I/O multiplexer handed to `look_up`; reused when reconnecting.
        im: Arc<IoMultiplexing>,
        /// Incremented (wrapping) on each reconnect so racing threads reconnect only once.
        epoch: u32,
    },
    /// The endpoint registered the identifier itself and runs the bus controller.
    Server {
        // Not read anywhere in this file.
        #[allow(dead_code)]
        endpoint_id: EndpointID,
        /// Outgoing messages (channel obtained from `register`); `im` is woken after each send.
        bus_sender: Mutex<Sender<EncodedMessage>>,
        /// Messages the local bus controller forwards to this endpoint; `None` once closed.
        receiver: Option<Mutex<Receiver<EncodedMessage>>>,
        /// The controller's I/O multiplexer (from `IoHub::io_multiplexing`).
        im: Arc<IoMultiplexing>,
    },
}
408
impl Rule {
    /// Joins the bus named by `options.identifier`, retrying until it succeeds, a fatal error
    /// occurs, or `timeout` expires (`None` means no deadline).
    ///
    /// The endpoint first tries `look_up`; on success it becomes a `Rule::Client`. If the
    /// identifier is not in use and `options.controller_affinity` is set, it instead tries to
    /// `register` the identifier, starts a `BusController`, and becomes a `Rule::Server` (using
    /// the registered hub's multiplexer instead of `im`). Other failures are retried after a
    /// pause of at most two seconds, never sleeping past the deadline.
    ///
    /// `epoch` is stored in a resulting client so racing reconnects can detect each other.
    ///
    /// # Errors
    ///
    /// - `JoinError::Timeout` when the deadline has passed at a retry point.
    /// - `JoinError::VersionMismatch` / `JoinError::TokenMismatch` when `look_up` reports a
    ///   version or token mismatch.
    /// - `JoinError::PermissonDenied` after more than five permission-denied results.
    /// - `JoinError::VersionMismatch` with version `0.0.0` after more than five lookup timeouts.
    fn join(
        options: Options,
        epoch: u32,
        im: Arc<IoMultiplexing>,
        timeout: Option<Duration>,
    ) -> Result<Self, JoinError> {
        let end = timeout.map(|timeout| Instant::now() + timeout);

        // Pauses before the next attempt: two seconds, capped at the time left before `end`.
        // Returns `JoinError::Timeout` from `join` once the deadline has passed.
        macro_rules! wait {
            () => {
                let mut wait = Duration::from_secs(2);
                if let Some(end) = end {
                    let remain = end.saturating_duration_since(Instant::now());
                    if remain.is_zero() {
                        return Err(JoinError::Timeout);
                    }
                    wait = wait.min(remain);
                }
                thread::sleep(wait);
            };
        }

        let mut timeout_count = 0;
        // While another endpoint is still registering the service, lookups and registrations may
        // fail with PermissonDenied, so a few of those are tolerated (shared counter).
        let mut permission_denied_count = 0;

        let rule = loop {
            let r = look_up(
                &options.identifier,
                options.label.clone(),
                options.token.clone(),
                im.clone(),
            );

            match r {
                Ok((io_hub, remote, endpoint_id)) => {
                    let rule = Rule::Client {
                        endpoint_id,
                        options,
                        remote,
                        io_hub: Some(Mutex::new(io_hub)),
                        reader_closed: false,
                        im,
                        epoch,
                    };
                    break rule;
                }
                Err(Error::IdentifierNotInUse) => {
                    // Without controller affinity this endpoint never becomes the controller;
                    // keep waiting for someone else to register.
                    if !options.controller_affinity {
                        log::error!("lookup: controller not found");
                        wait!();
                        continue;
                    }

                    let r = register(&options.identifier, im.clone());

                    match r {
                        Ok((io_hub, bus_sender, endpoint_id)) => {
                            // Messages the controller routes to this endpoint arrive on `receiver`.
                            let (sender, receiver) = mpsc::channel::<EncodedMessage>();

                            let im = io_hub.io_multiplexing();

                            let bus_controller = BusController::new(
                                endpoint_id,
                                options.label,
                                options.token,
                                sender,
                                io_hub,
                            );
                            bus_controller.run();

                            let rule = Rule::Server {
                                endpoint_id,
                                bus_sender: Mutex::new(bus_sender),
                                receiver: Some(Mutex::new(receiver)),
                                im,
                            };
                            break rule;
                        }
                        // Another endpoint registered first; look it up again right away.
                        Err(Error::IdentifierInUse) => {}
                        Err(Error::PermissonDenied) => {
                            permission_denied_count += 1;
                            if permission_denied_count > 5 {
                                return Err(JoinError::PermissonDenied);
                            }
                            wait!();
                        }
                        Err(err) => {
                            log::error!("register: {:?}", err);
                            wait!();
                        }
                    }
                }
                Err(Error::VersionMismatch(v, _)) => {
                    return Err(JoinError::VersionMismatch(v));
                }
                Err(Error::TokenMismatch) => {
                    return Err(JoinError::TokenMismatch);
                }
                Err(Error::PermissonDenied) => {
                    permission_denied_count += 1;
                    if permission_denied_count > 5 {
                        return Err(JoinError::PermissonDenied);
                    }
                    wait!();
                }
                Err(Error::Timeout) => {
                    timeout_count += 1;

                    // NOTE(review): repeated lookup timeouts are reported as a version mismatch
                    // against 0.0.0 — presumably because a peer on an incompatible version does
                    // not answer the lookup; confirm before relying on this mapping.
                    if timeout_count > 5 {
                        return Err(JoinError::VersionMismatch(Version((0, 0, 0))));
                    }

                    wait!();
                }
                Err(err) => {
                    log::error!("look_up: {:?}", err);
                    wait!();
                }
            }
        };

        Ok(rule)
    }
}
535
536impl Rule {
537    fn reader_close(&mut self) {
538        match self {
539            Rule::Client {
540                io_hub,
541                reader_closed,
542                ..
543            } => {
544                let _ = io_hub.take();
545                *reader_closed = true;
546            }
547            Rule::Server { receiver, .. } => {
548                let _ = receiver.take();
549            }
550        }
551    }
552}
553
// The three components live inside a single tuple field instead of being three separate
// tuple-struct fields; the original note attributes this to a serialization bug with multiple
// fields. NOTE(review): confirm the workaround is still needed before flattening, since changing
// the shape may change how versions are encoded.
/// A `major.minor.patch` version triple.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Version((u8, u8, u8));
557
558impl Version {
559    fn compatible(&self, rhs: Self) -> bool {
560        if self.major() == 0 && rhs.major() == 0 {
561            self.minor() == rhs.minor()
562        } else {
563            self.major() == rhs.major()
564        }
565    }
566
567    pub fn major(&self) -> u8 {
568        self.0 .0
569    }
570
571    pub fn minor(&self) -> u8 {
572        self.0 .1
573    }
574
575    pub fn patch(&self) -> u8 {
576        self.0 .2
577    }
578}
579
580impl Display for Version {
581    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
582        write!(f, "{}.{}.{}", self.0 .0, self.0 .1, self.0 .2)
583    }
584}
585
586static VERSION: Lazy<Version> = Lazy::new(|| {
587    let v_major = env!("CARGO_PKG_VERSION_MAJOR");
588    let v_minor = env!("CARGO_PKG_VERSION_MINOR");
589    let v_patch = env!("CARGO_PKG_VERSION_PATCH");
590    Version((
591        v_major.parse().unwrap(),
592        v_minor.parse().unwrap(),
593        v_patch.parse().unwrap(),
594    ))
595});
596static VERSION_PRE: Lazy<&'static str> = Lazy::new(|| env!("CARGO_PKG_VERSION_PRE"));
597
598pub fn version() -> Version {
599    *VERSION
600}
601
602pub fn version_pre() -> String {
603    VERSION_PRE.to_owned()
604}