moto_ipc/
sync.rs

1use alloc::borrow::ToOwned;
2use alloc::boxed::Box;
3use alloc::collections::BTreeMap;
4use alloc::string::String;
5use alloc::vec::Vec;
6use core::any::Any;
7use core::slice;
8
9use moto_sys::ErrorCode;
10use moto_sys::{syscalls::*, url_encode};
11
12// ChannelSize: Small: 4K; Mid: 2M.
13#[derive(Clone, Copy)]
14pub enum ChannelSize {
15    Small,
16    Mid,
17}
18
19impl ChannelSize {
20    pub fn size(&self) -> usize {
21        match self {
22            ChannelSize::Small => SysMem::PAGE_SIZE_SMALL as usize,
23            ChannelSize::Mid => SysMem::PAGE_SIZE_MID as usize,
24        }
25    }
26}
27
28// Rust's borrow checker inferferes with direct memory access to the shared mem
29// while holding references to connections; exposing RawChannel goes around
30// this problem.
31pub struct RawChannel {
32    addr: usize,
33    size: usize,
34}
35
36impl RawChannel {
37    pub fn size(&self) -> usize {
38        self.size
39    }
40
41    pub unsafe fn get_mut<T: Sized>(&self) -> &mut T {
42        assert!(core::mem::size_of::<T>() <= self.size);
43        (self.addr as *mut T).as_mut().unwrap_unchecked()
44    }
45
46    pub unsafe fn get<T: Sized>(&self) -> &T {
47        assert!(core::mem::size_of::<T>() <= self.size);
48        (self.addr as *const T).as_ref().unwrap_unchecked()
49    }
50
51    pub unsafe fn get_at_mut<T: Sized>(
52        &self,
53        buf: &mut [T; 0],
54        size: usize,
55    ) -> Result<&mut [T], ErrorCode> {
56        let start = buf.as_mut_ptr();
57        let start_addr = start as usize;
58        if (start_addr < self.addr)
59            || ((start_addr + core::mem::size_of::<T>() * size) > (self.addr + self.size))
60        {
61            return Err(ErrorCode::InvalidArgument);
62        }
63
64        Ok(core::slice::from_raw_parts_mut(start, size))
65    }
66
67    pub unsafe fn get_at<T: Sized>(&self, buf: &[T; 0], size: usize) -> Result<&[T], ErrorCode> {
68        let start = buf.as_ptr();
69        let start_addr = start as usize;
70        if (start_addr < self.addr)
71            || ((start_addr + core::mem::size_of::<T>() * size) > (self.addr + self.size))
72        {
73            return Err(ErrorCode::InvalidArgument);
74        }
75
76        Ok(core::slice::from_raw_parts(start, size))
77    }
78
79    pub unsafe fn get_bytes(&self, buf: &[u8; 0], size: usize) -> Result<&[u8], ErrorCode> {
80        let start = buf.as_ptr();
81        let start_addr = start as usize;
82        if (start_addr < self.addr) || ((start_addr + size) > (self.addr + self.size)) {
83            return Err(ErrorCode::InvalidArgument);
84        }
85
86        Ok(core::slice::from_raw_parts(start, size))
87    }
88
89    pub unsafe fn get_bytes_mut(
90        &self,
91        buf: &mut [u8; 0],
92        size: usize,
93    ) -> Result<&mut [u8], ErrorCode> {
94        let start = buf.as_mut_ptr();
95        let start_addr = start as usize;
96        if (start_addr < self.addr) || ((start_addr + size) > (self.addr + self.size)) {
97            return Err(ErrorCode::InvalidArgument);
98        }
99
100        Ok(core::slice::from_raw_parts_mut(start, size))
101    }
102
103    pub unsafe fn put_bytes(&self, src: &[u8], dst: &mut [u8; 0]) -> Result<(), ErrorCode> {
104        let start = dst.as_mut_ptr();
105        let start_addr = start as usize;
106        if (start_addr < self.addr) || ((start_addr + src.len()) > (self.addr + self.size)) {
107            return Err(ErrorCode::InvalidArgument);
108        }
109
110        core::intrinsics::copy_nonoverlapping(src.as_ptr(), start, src.len());
111        Ok(())
112    }
113}
114
/// Lifecycle state of a `ClientConnection`.
#[derive(PartialEq, Eq, Debug)]
enum ClientConnectionStatus {
    /// `connect()` succeeded and no RPC has reported a bad handle since.
    CONNECTED,
    /// An RPC failed with `ErrorCode::BadHandle`; the connection is unusable.
    ERROR,
    /// Not connected: freshly created, or after `disconnect()`.
    NONE,
}
121
122pub struct ClientConnection {
123    status: ClientConnectionStatus,
124    handle: SysHandle,
125    smem_addr: u64,
126    channel_size: ChannelSize,
127}
128
129impl Drop for ClientConnection {
130    fn drop(&mut self) {
131        if self.handle != SysHandle::NONE {
132            SysCtl::put(self.handle).unwrap();
133        }
134
135        if self.smem_addr == 0 {
136            return;
137        }
138        match self.channel_size {
139            ChannelSize::Small => {
140                SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
141            }
142            ChannelSize::Mid => {
143                SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
144            }
145        }
146    }
147}
148
149impl ClientConnection {
150    pub fn new(channel_size: ChannelSize) -> Result<Self, ErrorCode> {
151        let addr = match channel_size {
152            ChannelSize::Small => SysMem::map(
153                SysHandle::SELF,
154                SysMem::F_READABLE | SysMem::F_WRITABLE,
155                u64::MAX,
156                u64::MAX,
157                SysMem::PAGE_SIZE_SMALL,
158                1,
159            )?,
160            ChannelSize::Mid => SysMem::map(
161                SysHandle::SELF,
162                SysMem::F_READABLE | SysMem::F_WRITABLE,
163                u64::MAX,
164                u64::MAX,
165                SysMem::PAGE_SIZE_MID,
166                1,
167            )?,
168        };
169
170        Ok(Self {
171            status: ClientConnectionStatus::NONE,
172            handle: SysHandle::NONE,
173            smem_addr: addr,
174            channel_size,
175        })
176    }
177
178    pub fn connect(&mut self, url: &str) -> Result<(), ErrorCode> {
179        assert_eq!(self.status, ClientConnectionStatus::NONE);
180        assert_eq!(self.handle, SysHandle::NONE);
181
182        let full_url = alloc::format!(
183            "shared:url={};address={};page_type={};page_num=1",
184            url_encode(url),
185            self.smem_addr,
186            match self.channel_size {
187                ChannelSize::Small => "small",
188                ChannelSize::Mid => "mid",
189            }
190        );
191        self.handle = SysCtl::get(SysHandle::SELF, 0, &full_url)?;
192        self.status = ClientConnectionStatus::CONNECTED;
193        Ok(())
194    }
195
196    pub fn disconnect(&mut self) {
197        if self.handle != SysHandle::NONE {
198            SysCtl::put(self.handle).unwrap();
199            self.handle = SysHandle::NONE;
200            self.status = ClientConnectionStatus::NONE;
201        }
202    }
203
204    pub fn connected(&self) -> bool {
205        self.status == ClientConnectionStatus::CONNECTED
206    }
207
208    pub fn data(&self) -> &[u8] {
209        unsafe {
210            slice::from_raw_parts(
211                self.smem_addr as usize as *const u8,
212                self.channel_size.size(),
213            )
214        }
215    }
216
217    pub fn data_mut(&mut self) -> &mut [u8] {
218        unsafe {
219            slice::from_raw_parts_mut(self.smem_addr as usize as *mut u8, self.channel_size.size())
220        }
221    }
222
223    pub fn do_rpc(
224        &mut self,
225        timeout: Option<moto_sys::time::Instant>,
226    ) -> Result<(), ErrorCode> {
227        if self.connected() {
228            core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
229            let mut handles = [self.handle];
230            let res = SysCpu::wait(&mut handles, self.handle, SysHandle::NONE, timeout);
231
232            if res.is_ok() {
233                core::sync::atomic::fence(core::sync::atomic::Ordering::Acquire);
234            } else if let Err(ErrorCode::BadHandle) = res {
235                assert_eq!(handles[0], self.handle);
236                self.status = ClientConnectionStatus::ERROR;
237            }
238            res
239        } else {
240            Err(ErrorCode::InvalidArgument)
241        }
242    }
243
244    pub fn req<T: Sized>(&mut self) -> &mut T {
245        assert!(core::mem::size_of::<T>() <= self.channel_size.size());
246        unsafe {
247            (self.data_mut().as_mut_ptr() as *mut T)
248                .as_mut()
249                .unwrap_unchecked()
250        }
251    }
252
253    pub fn resp<T: Sized>(&self) -> &T {
254        assert!(core::mem::size_of::<T>() <= self.channel_size.size());
255        unsafe {
256            (self.data().as_ptr() as *const T)
257                .as_ref()
258                .unwrap_unchecked()
259        }
260    }
261
262    pub fn raw_channel(&self) -> RawChannel {
263        RawChannel {
264            addr: self.smem_addr as usize,
265            size: self.channel_size.size(),
266        }
267    }
268}
269
/// Lifecycle state of a `LocalServerConnection`.
#[derive(PartialEq, Eq, Debug)]
enum LocalServerConnectionStatus {
    /// Registered via `SysCtl::create`, waiting for a client.
    LISTENING,
    /// A client has connected (set by `LocalServer::wait`).
    CONNECTED,
    /// No handle held: freshly created, or after `disconnect()`.
    NONE,
}
276
277pub struct LocalServerConnection {
278    status: LocalServerConnectionStatus,
279    handle: SysHandle,
280    smem_addr: u64,
281    channel_size: ChannelSize,
282    extension: Box<dyn Any>,
283}
284
285impl Drop for LocalServerConnection {
286    fn drop(&mut self) {
287        if self.handle != SysHandle::NONE {
288            SysCtl::put(self.handle).unwrap();
289        }
290
291        if self.smem_addr == 0 {
292            return;
293        }
294        match self.channel_size {
295            ChannelSize::Small => {
296                SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
297            }
298            ChannelSize::Mid => {
299                SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
300            }
301        }
302    }
303}
304
305impl LocalServerConnection {
306    pub fn new(channel_size: ChannelSize) -> Result<Self, ErrorCode> {
307        let addr = match channel_size {
308            ChannelSize::Small => SysMem::map(
309                SysHandle::SELF,
310                0, // Not mapped to a physical frame.
311                u64::MAX,
312                u64::MAX,
313                SysMem::PAGE_SIZE_SMALL,
314                1,
315            )?,
316            ChannelSize::Mid => SysMem::map(
317                SysHandle::SELF,
318                0, // Not mapped to a physical frame.
319                u64::MAX,
320                u64::MAX,
321                SysMem::PAGE_SIZE_MID,
322                1,
323            )?,
324        };
325
326        Ok(Self {
327            status: LocalServerConnectionStatus::NONE,
328            handle: SysHandle::NONE,
329            smem_addr: addr,
330            channel_size,
331            extension: Box::new(()),
332        })
333    }
334
335    fn start_listening(&mut self, url: &str) -> Result<(), ErrorCode> {
336        assert_eq!(self.status, LocalServerConnectionStatus::NONE);
337        assert_eq!(self.handle, SysHandle::NONE);
338
339        let full_url = alloc::format!(
340            "shared:url={};address={};page_type={};page_num=1",
341            url_encode(url),
342            self.smem_addr,
343            match self.channel_size {
344                ChannelSize::Small => "small",
345                ChannelSize::Mid => "mid",
346            }
347        );
348        self.handle = SysCtl::create(SysHandle::SELF, 0, &full_url)?;
349        self.status = LocalServerConnectionStatus::LISTENING;
350
351        Ok(())
352    }
353
354    pub fn channel_size(&self) -> usize {
355        match self.channel_size {
356            ChannelSize::Small => SysMem::PAGE_SIZE_SMALL as usize,
357            ChannelSize::Mid => SysMem::PAGE_SIZE_MID as usize,
358        }
359    }
360
361    pub fn data(&self) -> &[u8] {
362        unsafe {
363            slice::from_raw_parts(
364                self.smem_addr as usize as *const u8,
365                self.channel_size.size(),
366            )
367        }
368    }
369
370    pub fn data_mut(&mut self) -> &mut [u8] {
371        unsafe {
372            slice::from_raw_parts_mut(self.smem_addr as usize as *mut u8, self.channel_size.size())
373        }
374    }
375
376    pub fn raw_channel(&self) -> RawChannel {
377        RawChannel {
378            addr: self.smem_addr as usize,
379            size: self.channel_size.size(),
380        }
381    }
382
383    pub fn extension<'a, T: 'static>(&'a self) -> Option<&'a T> {
384        self.extension.downcast_ref::<T>()
385    }
386
387    pub fn extension_mut<'a, T: 'static>(&'a mut self) -> Option<&'a mut T> {
388        self.extension.downcast_mut::<T>()
389    }
390
391    pub fn set_extension<T: Any>(&mut self, ext: Box<T>) {
392        self.extension = ext;
393    }
394
395    pub fn connected(&self) -> bool {
396        self.status == LocalServerConnectionStatus::CONNECTED
397    }
398
399    pub fn disconnect(&mut self) {
400        match self.status {
401            LocalServerConnectionStatus::LISTENING | LocalServerConnectionStatus::CONNECTED => {
402                SysCtl::put(self.handle).unwrap();
403                self.handle = SysHandle::NONE;
404                self.status = LocalServerConnectionStatus::NONE;
405            }
406            LocalServerConnectionStatus::NONE => {}
407        }
408    }
409
410    pub fn finish_rpc(&mut self) -> Result<(), ErrorCode> {
411        core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
412        if self.connected() {
413            SysCpu::wake(self.handle).map_err(|err| {
414                assert_eq!(err, ErrorCode::BadHandle);
415                self.disconnect();
416                err
417            })
418        } else {
419            Err(ErrorCode::InvalidArgument)
420        }
421    }
422
423    pub fn req<T: Sized>(&self) -> &T {
424        assert!(core::mem::size_of::<T>() <= self.channel_size.size());
425        unsafe {
426            (self.data().as_ptr() as *const T)
427                .as_ref()
428                .unwrap_unchecked()
429        }
430    }
431
432    pub fn resp<T: Sized>(&mut self) -> &mut T {
433        assert!(core::mem::size_of::<T>() <= self.channel_size.size());
434        unsafe {
435            (self.data_mut().as_mut_ptr() as *mut T)
436                .as_mut()
437                .unwrap_unchecked()
438        }
439    }
440
441    pub fn handle(&self) -> SysHandle {
442        self.handle
443    }
444}
445
446// LocalServer: not Send/Sync.
447pub struct LocalServer {
448    max_connections: u64,
449    max_listeners: u64,
450    channel_size: ChannelSize,
451
452    url: String,
453
454    listeners: BTreeMap<SysHandle, LocalServerConnection>,
455    active_conns: BTreeMap<SysHandle, LocalServerConnection>,
456}
457
458impl LocalServer {
459    pub fn new(
460        url: &str,
461        channel_size: ChannelSize,
462        max_connections: u64,
463        max_listeners: u64,
464    ) -> Result<LocalServer, ErrorCode> {
465        assert!(max_connections >= max_listeners);
466
467        let mut self_ = Self {
468            max_connections,
469            max_listeners,
470            channel_size,
471            url: url.to_owned(),
472            listeners: BTreeMap::new(),
473            active_conns: BTreeMap::new(),
474        };
475
476        for _i in 0..self_.max_listeners {
477            self_.add_listener()?;
478        }
479
480        Ok(self_)
481    }
482
483    fn add_listener(&mut self) -> Result<(), ErrorCode> {
484        let mut listener = LocalServerConnection::new(self.channel_size)?;
485        listener.start_listening(self.url.as_str())?;
486        self.listeners.insert(listener.handle.clone(), listener);
487        Ok(())
488    }
489
490    pub fn wait(
491        &mut self,
492        swap_target: SysHandle,
493        extra_waiters: &[SysHandle],
494    ) -> Result<Vec<SysHandle>, Vec<SysHandle>> {
495        while self.listeners.len() < (self.max_listeners as usize)
496            && (self.listeners.len() + self.active_conns.len() < (self.max_connections as usize))
497        {
498            self.add_listener().unwrap();
499        }
500
501        let mut waiters = Vec::with_capacity(
502            self.listeners.len() + self.active_conns.len() + extra_waiters.len(),
503        );
504
505        for k in self.listeners.keys() {
506            waiters.push(k.clone());
507        }
508
509        let mut bad_connections = Vec::new();
510        for k in self.active_conns.keys() {
511            let conn = self.active_conns.get(k).unwrap();
512            if !conn.connected() {
513                bad_connections.push(k.clone());
514            } else {
515                waiters.push(k.clone());
516            }
517        }
518        for k in bad_connections {
519            self.active_conns.remove(&k);
520        }
521
522        for k in extra_waiters {
523            waiters.push(k.clone());
524        }
525
526        core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
527        SysCpu::wait(&mut waiters[..], swap_target, SysHandle::NONE, None).map_err(|err| {
528            assert_eq!(err, ErrorCode::BadHandle);
529            let mut bad_extras = Vec::new();
530            for waiter in &waiters {
531                if *waiter == SysHandle::NONE {
532                    continue;
533                }
534                if let Some(mut conn) = self.active_conns.remove(&waiter) {
535                    assert!(conn.connected());
536                    conn.disconnect();
537                } else if let Some(mut listener) = self.listeners.remove(&waiter) {
538                    // A remote process can connect to the listener and then drop.
539                    listener.disconnect();
540                } else {
541                    bad_extras.push(*waiter);
542                }
543            }
544            bad_extras
545        })?;
546
547        let mut wakers = Vec::with_capacity(waiters.len());
548        for h in &waiters {
549            if *h == SysHandle::NONE {
550                break;
551            }
552            let handle = h.clone();
553            if let Some(mut conn) = self.listeners.remove(&handle) {
554                assert_eq!(conn.status, LocalServerConnectionStatus::LISTENING);
555                conn.status = LocalServerConnectionStatus::CONNECTED;
556                let prev = self.active_conns.insert(handle.clone(), conn);
557                assert!(prev.is_none());
558            }
559            wakers.push(handle);
560        }
561
562        core::sync::atomic::fence(core::sync::atomic::Ordering::Acquire);
563        Ok(wakers)
564    }
565
566    pub fn get_connection(&mut self, handle: SysHandle) -> Option<&mut LocalServerConnection> {
567        self.active_conns.get_mut(&handle)
568    }
569}