1use alloc::borrow::ToOwned;
2use alloc::boxed::Box;
3use alloc::collections::BTreeMap;
4use alloc::string::String;
5use alloc::vec::Vec;
6use core::any::Any;
7use core::slice;
8
9use moto_sys::ErrorCode;
10use moto_sys::{syscalls::*, url_encode};
11
/// Size class of a channel's shared-memory region.
///
/// Each variant corresponds to one page of the matching `SysMem` page size
/// (see `ChannelSize::size`) and to the `page_type` value (`"small"` / `"mid"`)
/// that the connection types put into the URL they pass to `SysCtl`.
#[derive(Clone, Copy)]
pub enum ChannelSize {
    /// One page of `SysMem::PAGE_SIZE_SMALL` bytes.
    Small,
    /// One page of `SysMem::PAGE_SIZE_MID` bytes.
    Mid,
}
18
19impl ChannelSize {
20 pub fn size(&self) -> usize {
21 match self {
22 ChannelSize::Small => SysMem::PAGE_SIZE_SMALL as usize,
23 ChannelSize::Mid => SysMem::PAGE_SIZE_MID as usize,
24 }
25 }
26}
27
/// A plain (address, size) view of a channel's memory region.
///
/// Instances are produced by the `raw_channel()` methods of the connection
/// types in this file, which copy the connection's `smem_addr` and channel
/// size into it. No `Drop` impl for this type is visible in this file, so it
/// does not appear to own or unmap the region it describes.
pub struct RawChannel {
    /// Address of the first byte of the region.
    addr: usize,
    /// Length of the region in bytes.
    size: usize,
}
35
36impl RawChannel {
37 pub fn size(&self) -> usize {
38 self.size
39 }
40
41 pub unsafe fn get_mut<T: Sized>(&self) -> &mut T {
42 assert!(core::mem::size_of::<T>() <= self.size);
43 (self.addr as *mut T).as_mut().unwrap_unchecked()
44 }
45
46 pub unsafe fn get<T: Sized>(&self) -> &T {
47 assert!(core::mem::size_of::<T>() <= self.size);
48 (self.addr as *const T).as_ref().unwrap_unchecked()
49 }
50
51 pub unsafe fn get_at_mut<T: Sized>(
52 &self,
53 buf: &mut [T; 0],
54 size: usize,
55 ) -> Result<&mut [T], ErrorCode> {
56 let start = buf.as_mut_ptr();
57 let start_addr = start as usize;
58 if (start_addr < self.addr)
59 || ((start_addr + core::mem::size_of::<T>() * size) > (self.addr + self.size))
60 {
61 return Err(ErrorCode::InvalidArgument);
62 }
63
64 Ok(core::slice::from_raw_parts_mut(start, size))
65 }
66
67 pub unsafe fn get_at<T: Sized>(&self, buf: &[T; 0], size: usize) -> Result<&[T], ErrorCode> {
68 let start = buf.as_ptr();
69 let start_addr = start as usize;
70 if (start_addr < self.addr)
71 || ((start_addr + core::mem::size_of::<T>() * size) > (self.addr + self.size))
72 {
73 return Err(ErrorCode::InvalidArgument);
74 }
75
76 Ok(core::slice::from_raw_parts(start, size))
77 }
78
79 pub unsafe fn get_bytes(&self, buf: &[u8; 0], size: usize) -> Result<&[u8], ErrorCode> {
80 let start = buf.as_ptr();
81 let start_addr = start as usize;
82 if (start_addr < self.addr) || ((start_addr + size) > (self.addr + self.size)) {
83 return Err(ErrorCode::InvalidArgument);
84 }
85
86 Ok(core::slice::from_raw_parts(start, size))
87 }
88
89 pub unsafe fn get_bytes_mut(
90 &self,
91 buf: &mut [u8; 0],
92 size: usize,
93 ) -> Result<&mut [u8], ErrorCode> {
94 let start = buf.as_mut_ptr();
95 let start_addr = start as usize;
96 if (start_addr < self.addr) || ((start_addr + size) > (self.addr + self.size)) {
97 return Err(ErrorCode::InvalidArgument);
98 }
99
100 Ok(core::slice::from_raw_parts_mut(start, size))
101 }
102
103 pub unsafe fn put_bytes(&self, src: &[u8], dst: &mut [u8; 0]) -> Result<(), ErrorCode> {
104 let start = dst.as_mut_ptr();
105 let start_addr = start as usize;
106 if (start_addr < self.addr) || ((start_addr + src.len()) > (self.addr + self.size)) {
107 return Err(ErrorCode::InvalidArgument);
108 }
109
110 core::intrinsics::copy_nonoverlapping(src.as_ptr(), start, src.len());
111 Ok(())
112 }
113}
114
/// Connection state tracked by `ClientConnection`.
#[derive(Debug, PartialEq, Eq)]
enum ClientConnectionStatus {
    /// Set by `connect()` once `SysCtl::get` has returned a handle.
    CONNECTED,
    /// Set by `do_rpc()` when the wait fails with `ErrorCode::BadHandle`.
    ERROR,
    /// Initial state, and the state after `disconnect()` releases the handle.
    NONE,
}
121
/// Client side of a shared-memory channel.
///
/// `new()` maps the memory, `connect()` obtains the handle, and dropping the
/// value releases the handle (if any) and unmaps the memory.
pub struct ClientConnection {
    /// Current connection state.
    status: ClientConnectionStatus,
    /// Handle returned by `SysCtl::get`; `SysHandle::NONE` while not connected.
    handle: SysHandle,
    /// Address returned by `SysMem::map` in `new()`; `Drop` skips the unmap
    /// when this is 0.
    smem_addr: u64,
    /// Size class the memory was mapped with.
    channel_size: ChannelSize,
}
128
129impl Drop for ClientConnection {
130 fn drop(&mut self) {
131 if self.handle != SysHandle::NONE {
132 SysCtl::put(self.handle).unwrap();
133 }
134
135 if self.smem_addr == 0 {
136 return;
137 }
138 match self.channel_size {
139 ChannelSize::Small => {
140 SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
141 }
142 ChannelSize::Mid => {
143 SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
144 }
145 }
146 }
147}
148
149impl ClientConnection {
150 pub fn new(channel_size: ChannelSize) -> Result<Self, ErrorCode> {
151 let addr = match channel_size {
152 ChannelSize::Small => SysMem::map(
153 SysHandle::SELF,
154 SysMem::F_READABLE | SysMem::F_WRITABLE,
155 u64::MAX,
156 u64::MAX,
157 SysMem::PAGE_SIZE_SMALL,
158 1,
159 )?,
160 ChannelSize::Mid => SysMem::map(
161 SysHandle::SELF,
162 SysMem::F_READABLE | SysMem::F_WRITABLE,
163 u64::MAX,
164 u64::MAX,
165 SysMem::PAGE_SIZE_MID,
166 1,
167 )?,
168 };
169
170 Ok(Self {
171 status: ClientConnectionStatus::NONE,
172 handle: SysHandle::NONE,
173 smem_addr: addr,
174 channel_size,
175 })
176 }
177
178 pub fn connect(&mut self, url: &str) -> Result<(), ErrorCode> {
179 assert_eq!(self.status, ClientConnectionStatus::NONE);
180 assert_eq!(self.handle, SysHandle::NONE);
181
182 let full_url = alloc::format!(
183 "shared:url={};address={};page_type={};page_num=1",
184 url_encode(url),
185 self.smem_addr,
186 match self.channel_size {
187 ChannelSize::Small => "small",
188 ChannelSize::Mid => "mid",
189 }
190 );
191 self.handle = SysCtl::get(SysHandle::SELF, 0, &full_url)?;
192 self.status = ClientConnectionStatus::CONNECTED;
193 Ok(())
194 }
195
196 pub fn disconnect(&mut self) {
197 if self.handle != SysHandle::NONE {
198 SysCtl::put(self.handle).unwrap();
199 self.handle = SysHandle::NONE;
200 self.status = ClientConnectionStatus::NONE;
201 }
202 }
203
204 pub fn connected(&self) -> bool {
205 self.status == ClientConnectionStatus::CONNECTED
206 }
207
208 pub fn data(&self) -> &[u8] {
209 unsafe {
210 slice::from_raw_parts(
211 self.smem_addr as usize as *const u8,
212 self.channel_size.size(),
213 )
214 }
215 }
216
217 pub fn data_mut(&mut self) -> &mut [u8] {
218 unsafe {
219 slice::from_raw_parts_mut(self.smem_addr as usize as *mut u8, self.channel_size.size())
220 }
221 }
222
223 pub fn do_rpc(
224 &mut self,
225 timeout: Option<moto_sys::time::Instant>,
226 ) -> Result<(), ErrorCode> {
227 if self.connected() {
228 core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
229 let mut handles = [self.handle];
230 let res = SysCpu::wait(&mut handles, self.handle, SysHandle::NONE, timeout);
231
232 if res.is_ok() {
233 core::sync::atomic::fence(core::sync::atomic::Ordering::Acquire);
234 } else if let Err(ErrorCode::BadHandle) = res {
235 assert_eq!(handles[0], self.handle);
236 self.status = ClientConnectionStatus::ERROR;
237 }
238 res
239 } else {
240 Err(ErrorCode::InvalidArgument)
241 }
242 }
243
244 pub fn req<T: Sized>(&mut self) -> &mut T {
245 assert!(core::mem::size_of::<T>() <= self.channel_size.size());
246 unsafe {
247 (self.data_mut().as_mut_ptr() as *mut T)
248 .as_mut()
249 .unwrap_unchecked()
250 }
251 }
252
253 pub fn resp<T: Sized>(&self) -> &T {
254 assert!(core::mem::size_of::<T>() <= self.channel_size.size());
255 unsafe {
256 (self.data().as_ptr() as *const T)
257 .as_ref()
258 .unwrap_unchecked()
259 }
260 }
261
262 pub fn raw_channel(&self) -> RawChannel {
263 RawChannel {
264 addr: self.smem_addr as usize,
265 size: self.channel_size.size(),
266 }
267 }
268}
269
/// Connection state tracked by `LocalServerConnection`.
#[derive(Eq, PartialEq, Debug)]
enum LocalServerConnectionStatus {
    /// Set by `start_listening()` once `SysCtl::create` has returned a handle.
    LISTENING,
    /// Set by `LocalServer::wait()` when a listener's handle is reported back
    /// and the connection is moved to the active set.
    CONNECTED,
    /// Initial state, and the state after `disconnect()`.
    NONE,
}
276
/// Server side of a single shared-memory channel.
///
/// Dropping the value releases the handle (if any) and unmaps the memory.
pub struct LocalServerConnection {
    /// Current connection state.
    status: LocalServerConnectionStatus,
    /// Handle returned by `SysCtl::create`; `SysHandle::NONE` until listening.
    handle: SysHandle,
    /// Address returned by `SysMem::map` in `new()`; `Drop` skips the unmap
    /// when this is 0.
    smem_addr: u64,
    /// Size class the memory was mapped with.
    channel_size: ChannelSize,
    /// Caller-supplied value attached to the connection; starts as `()` and
    /// is replaced via `set_extension()`.
    extension: Box<dyn Any>,
}
284
285impl Drop for LocalServerConnection {
286 fn drop(&mut self) {
287 if self.handle != SysHandle::NONE {
288 SysCtl::put(self.handle).unwrap();
289 }
290
291 if self.smem_addr == 0 {
292 return;
293 }
294 match self.channel_size {
295 ChannelSize::Small => {
296 SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
297 }
298 ChannelSize::Mid => {
299 SysMem::unmap(SysHandle::SELF, 0, u64::MAX, self.smem_addr).unwrap();
300 }
301 }
302 }
303}
304
305impl LocalServerConnection {
306 pub fn new(channel_size: ChannelSize) -> Result<Self, ErrorCode> {
307 let addr = match channel_size {
308 ChannelSize::Small => SysMem::map(
309 SysHandle::SELF,
310 0, u64::MAX,
312 u64::MAX,
313 SysMem::PAGE_SIZE_SMALL,
314 1,
315 )?,
316 ChannelSize::Mid => SysMem::map(
317 SysHandle::SELF,
318 0, u64::MAX,
320 u64::MAX,
321 SysMem::PAGE_SIZE_MID,
322 1,
323 )?,
324 };
325
326 Ok(Self {
327 status: LocalServerConnectionStatus::NONE,
328 handle: SysHandle::NONE,
329 smem_addr: addr,
330 channel_size,
331 extension: Box::new(()),
332 })
333 }
334
335 fn start_listening(&mut self, url: &str) -> Result<(), ErrorCode> {
336 assert_eq!(self.status, LocalServerConnectionStatus::NONE);
337 assert_eq!(self.handle, SysHandle::NONE);
338
339 let full_url = alloc::format!(
340 "shared:url={};address={};page_type={};page_num=1",
341 url_encode(url),
342 self.smem_addr,
343 match self.channel_size {
344 ChannelSize::Small => "small",
345 ChannelSize::Mid => "mid",
346 }
347 );
348 self.handle = SysCtl::create(SysHandle::SELF, 0, &full_url)?;
349 self.status = LocalServerConnectionStatus::LISTENING;
350
351 Ok(())
352 }
353
354 pub fn channel_size(&self) -> usize {
355 match self.channel_size {
356 ChannelSize::Small => SysMem::PAGE_SIZE_SMALL as usize,
357 ChannelSize::Mid => SysMem::PAGE_SIZE_MID as usize,
358 }
359 }
360
361 pub fn data(&self) -> &[u8] {
362 unsafe {
363 slice::from_raw_parts(
364 self.smem_addr as usize as *const u8,
365 self.channel_size.size(),
366 )
367 }
368 }
369
370 pub fn data_mut(&mut self) -> &mut [u8] {
371 unsafe {
372 slice::from_raw_parts_mut(self.smem_addr as usize as *mut u8, self.channel_size.size())
373 }
374 }
375
376 pub fn raw_channel(&self) -> RawChannel {
377 RawChannel {
378 addr: self.smem_addr as usize,
379 size: self.channel_size.size(),
380 }
381 }
382
383 pub fn extension<'a, T: 'static>(&'a self) -> Option<&'a T> {
384 self.extension.downcast_ref::<T>()
385 }
386
387 pub fn extension_mut<'a, T: 'static>(&'a mut self) -> Option<&'a mut T> {
388 self.extension.downcast_mut::<T>()
389 }
390
391 pub fn set_extension<T: Any>(&mut self, ext: Box<T>) {
392 self.extension = ext;
393 }
394
395 pub fn connected(&self) -> bool {
396 self.status == LocalServerConnectionStatus::CONNECTED
397 }
398
399 pub fn disconnect(&mut self) {
400 match self.status {
401 LocalServerConnectionStatus::LISTENING | LocalServerConnectionStatus::CONNECTED => {
402 SysCtl::put(self.handle).unwrap();
403 self.handle = SysHandle::NONE;
404 self.status = LocalServerConnectionStatus::NONE;
405 }
406 LocalServerConnectionStatus::NONE => {}
407 }
408 }
409
410 pub fn finish_rpc(&mut self) -> Result<(), ErrorCode> {
411 core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
412 if self.connected() {
413 SysCpu::wake(self.handle).map_err(|err| {
414 assert_eq!(err, ErrorCode::BadHandle);
415 self.disconnect();
416 err
417 })
418 } else {
419 Err(ErrorCode::InvalidArgument)
420 }
421 }
422
423 pub fn req<T: Sized>(&self) -> &T {
424 assert!(core::mem::size_of::<T>() <= self.channel_size.size());
425 unsafe {
426 (self.data().as_ptr() as *const T)
427 .as_ref()
428 .unwrap_unchecked()
429 }
430 }
431
432 pub fn resp<T: Sized>(&mut self) -> &mut T {
433 assert!(core::mem::size_of::<T>() <= self.channel_size.size());
434 unsafe {
435 (self.data_mut().as_mut_ptr() as *mut T)
436 .as_mut()
437 .unwrap_unchecked()
438 }
439 }
440
441 pub fn handle(&self) -> SysHandle {
442 self.handle
443 }
444}
445
/// A server that keeps a pool of listening channels for one URL and tracks
/// the connections that clients have attached to.
pub struct LocalServer {
    /// Upper bound on listeners plus active connections; `wait()` stops
    /// adding listeners once it is reached.
    max_connections: u64,
    /// Number of listeners created by `new()` and the level `wait()` tops the
    /// listener pool back up to. `new()` asserts it does not exceed
    /// `max_connections`.
    max_listeners: u64,
    /// Size class used for every channel this server creates.
    channel_size: ChannelSize,

    /// URL every listener is created with.
    url: String,

    /// Connections in the listening state, keyed by their handle.
    listeners: BTreeMap<SysHandle, LocalServerConnection>,
    /// Connections moved out of `listeners` by `wait()`, keyed by their handle.
    active_conns: BTreeMap<SysHandle, LocalServerConnection>,
}
457
458impl LocalServer {
459 pub fn new(
460 url: &str,
461 channel_size: ChannelSize,
462 max_connections: u64,
463 max_listeners: u64,
464 ) -> Result<LocalServer, ErrorCode> {
465 assert!(max_connections >= max_listeners);
466
467 let mut self_ = Self {
468 max_connections,
469 max_listeners,
470 channel_size,
471 url: url.to_owned(),
472 listeners: BTreeMap::new(),
473 active_conns: BTreeMap::new(),
474 };
475
476 for _i in 0..self_.max_listeners {
477 self_.add_listener()?;
478 }
479
480 Ok(self_)
481 }
482
483 fn add_listener(&mut self) -> Result<(), ErrorCode> {
484 let mut listener = LocalServerConnection::new(self.channel_size)?;
485 listener.start_listening(self.url.as_str())?;
486 self.listeners.insert(listener.handle.clone(), listener);
487 Ok(())
488 }
489
490 pub fn wait(
491 &mut self,
492 swap_target: SysHandle,
493 extra_waiters: &[SysHandle],
494 ) -> Result<Vec<SysHandle>, Vec<SysHandle>> {
495 while self.listeners.len() < (self.max_listeners as usize)
496 && (self.listeners.len() + self.active_conns.len() < (self.max_connections as usize))
497 {
498 self.add_listener().unwrap();
499 }
500
501 let mut waiters = Vec::with_capacity(
502 self.listeners.len() + self.active_conns.len() + extra_waiters.len(),
503 );
504
505 for k in self.listeners.keys() {
506 waiters.push(k.clone());
507 }
508
509 let mut bad_connections = Vec::new();
510 for k in self.active_conns.keys() {
511 let conn = self.active_conns.get(k).unwrap();
512 if !conn.connected() {
513 bad_connections.push(k.clone());
514 } else {
515 waiters.push(k.clone());
516 }
517 }
518 for k in bad_connections {
519 self.active_conns.remove(&k);
520 }
521
522 for k in extra_waiters {
523 waiters.push(k.clone());
524 }
525
526 core::sync::atomic::fence(core::sync::atomic::Ordering::Release);
527 SysCpu::wait(&mut waiters[..], swap_target, SysHandle::NONE, None).map_err(|err| {
528 assert_eq!(err, ErrorCode::BadHandle);
529 let mut bad_extras = Vec::new();
530 for waiter in &waiters {
531 if *waiter == SysHandle::NONE {
532 continue;
533 }
534 if let Some(mut conn) = self.active_conns.remove(&waiter) {
535 assert!(conn.connected());
536 conn.disconnect();
537 } else if let Some(mut listener) = self.listeners.remove(&waiter) {
538 listener.disconnect();
540 } else {
541 bad_extras.push(*waiter);
542 }
543 }
544 bad_extras
545 })?;
546
547 let mut wakers = Vec::with_capacity(waiters.len());
548 for h in &waiters {
549 if *h == SysHandle::NONE {
550 break;
551 }
552 let handle = h.clone();
553 if let Some(mut conn) = self.listeners.remove(&handle) {
554 assert_eq!(conn.status, LocalServerConnectionStatus::LISTENING);
555 conn.status = LocalServerConnectionStatus::CONNECTED;
556 let prev = self.active_conns.insert(handle.clone(), conn);
557 assert!(prev.is_none());
558 }
559 wakers.push(handle);
560 }
561
562 core::sync::atomic::fence(core::sync::atomic::Ordering::Acquire);
563 Ok(wakers)
564 }
565
566 pub fn get_connection(&mut self, handle: SysHandle) -> Option<&mut LocalServerConnection> {
567 self.active_conns.get_mut(&handle)
568 }
569}