mfio_netfs/net/server.rs
1use core::cell::RefCell;
2use core::num::NonZeroU16;
3use core::pin::Pin;
4use std::collections::BTreeMap;
5use std::net::SocketAddr;
6use std::path::Path;
7
8use super::{FsRequest, FsResponse, HeaderRouter, Request, Response};
9
10use async_mutex::Mutex as AsyncMutex;
11use cglue::result::IntError;
12use log::*;
13use mfio::backend::IoBackendExt;
14use mfio::error::Error;
15use mfio::io::{NoPos, OwnedPacket, PacketIoExt, PacketView, PacketVtblRef, Read};
16use mfio::stdeq::Seekable;
17use mfio::tarc::BaseArc;
18use mfio_rt::{
19 native::{NativeRtDir, ReadDir},
20 DirHandle, Fs, NativeFile, NativeRt, TcpListenerHandle,
21};
22use parking_lot::Mutex;
23use slab::Slab;
24use tracing::instrument::Instrument;
25
26use debug_ignore::DebugIgnore;
27use futures::{
28 future::FutureExt,
29 stream::{Fuse, FusedStream, FuturesUnordered, StreamExt},
30};
31use mfio_rt::{
32 native::{NativeTcpListener, NativeTcpStream},
33 Tcp,
34};
35
36#[derive(Debug)]
37enum Operation {
38 Read {
39 file_id: u32,
40 packet_id: u32,
41 pos: u64,
42 len: u64,
43 },
44 Write {
45 file_id: u32,
46 packet_id: u32,
47 pos: u64,
48 buf: DebugIgnore<Vec<u8>>,
49 },
50 FileClose {
51 file_id: u32,
52 },
53 ReadDir {
54 stream_id: u16,
55 count: u16,
56 },
57 Fs {
58 req_id: u32,
59 dir_id: u16,
60 req: FsRequest,
61 },
62}
63
64#[repr(C)]
65struct ReadPacket {
66 hdr: Packet<Write>,
67 len: u64,
68 shards: Mutex<BTreeMap<u64, BaseArc<Packet<Read>>>>,
69}
70
71impl ReadPacket {
72 pub fn new(capacity: u64) -> Self {
73 unsafe extern "C" fn len(pkt: &Packet<Write>) -> u64 {
74 unsafe {
75 let this = &*(pkt as *const Packet<Write> as *const ReadPacket);
76 this.len
77 }
78 }
79
80 unsafe extern "C" fn get_mut(
81 _: &mut ManuallyDrop<BoundPacketView<Write>>,
82 _: usize,
83 _: &mut MaybeUninit<WritePacketObj>,
84 ) -> bool {
85 false
86 }
87
88 unsafe extern "C" fn transfer_data(obj: &mut PacketView<'_, Write>, src: *const ()) {
89 let this = &*(obj.pkt() as *const Packet<Write> as *const ReadPacket);
90 let len = obj.len();
91 let idx = obj.start();
92 let buf = src as *const u8;
93 let buf = core::slice::from_raw_parts(buf, len as usize);
94 let pkt = Packet::<Read>::copy_from_slice(buf);
95 this.shards.lock().insert(idx, pkt);
96 }
97
98 Self {
99 hdr: unsafe {
100 Packet::new_hdr(PacketVtblRef {
101 vtbl: &Write {
102 len,
103 get_mut,
104 transfer_data,
105 },
106 })
107 },
108 len: capacity,
109 shards: Default::default(),
110 }
111 }
112}
113
114impl AsRef<Packet<Write>> for ReadPacket {
115 fn as_ref(&self) -> &Packet<Write> {
116 &self.hdr
117 }
118}
119
120use core::mem::{ManuallyDrop, MaybeUninit};
121use mfio::io::{BoundPacketView, Packet, Write, WritePacketObj};
122
123async fn run_server(stream: NativeTcpStream, fs: &NativeRt) {
124 let stream_raw = &stream;
125 let mut file_handles: Slab<BaseArc<Seekable<NativeFile, u64>>> = Default::default();
126 let file_handles = RefCell::new(&mut file_handles);
127 let mut dir_handles: Slab<BaseArc<NativeRtDir>> = Default::default();
128 let dir_handles = RefCell::new(&mut dir_handles);
129 let mut read_dir_streams: Slab<Pin<BaseArc<AsyncMutex<Fuse<ReadDir>>>>> = Default::default();
130 let read_dir_streams = RefCell::new(&mut read_dir_streams);
131
132 let router = BaseArc::new(HeaderRouter::new(stream_raw));
133
134 let mut futures = FuturesUnordered::new();
135
136 let (tx, rx) = flume::bounded(512);
137
138 let ingress_loop = async {
139 use mfio::traits::IoRead;
140 while let Ok(v) = {
141 let header_span = tracing::span!(tracing::Level::TRACE, "server read Request header");
142 trace!("Queue req read");
143 stream_raw
144 .read::<Request>(NoPos::new())
145 .instrument(header_span)
146 .await
147 } {
148 let end_span = tracing::span!(tracing::Level::TRACE, "server read Request");
149 let op = async {
150 // Verify that the tag is proper, since otherwise we may jump to the wrong place of
151 // code. TODO: use proper deserialization techniques
152 // SAFETY: memunsafe made safe
153 // while adding this check saves us from memory safety bugs, this will probably
154 // still lead to arbitrarily large allocations that make us crash.
155 let tag = unsafe { *(&v as *const _ as *const u8) };
156 assert!(tag < 5, "incoming data tag is invalid {tag}");
157
158 trace!("Receive req: {v:?}");
159
160 match v {
161 Request::Read {
162 file_id,
163 packet_id,
164 pos,
165 len,
166 } => Operation::Read {
167 file_id,
168 packet_id,
169 pos,
170 len,
171 },
172 Request::Write {
173 file_id,
174 packet_id,
175 pos,
176 len,
177 } => {
178 let mut buf = vec![0; len as usize];
179 stream_raw
180 .read_all(NoPos::new(), &mut buf[..])
181 .await
182 .unwrap();
183 Operation::Write {
184 file_id,
185 packet_id,
186 pos,
187 buf: buf.into(),
188 }
189 }
190 Request::Fs {
191 req_id,
192 dir_id,
193 req_len,
194 } => {
195 let mut buf = vec![0; req_len as usize];
196 stream_raw
197 .read_all(NoPos::new(), &mut buf[..])
198 .await
199 .unwrap();
200 let req: FsRequest = postcard::from_bytes(&buf).unwrap();
201 Operation::Fs {
202 req_id,
203 dir_id,
204 req,
205 }
206 }
207 Request::ReadDir { stream_id, count } => {
208 Operation::ReadDir { stream_id, count }
209 }
210 Request::FileClose { file_id } => Operation::FileClose { file_id },
211 }
212 }
213 .instrument(end_span)
214 .await;
215
216 if tx.send_async(op).await.is_err() {
217 break;
218 }
219 }
220
221 core::mem::drop(tx);
222
223 trace!("Ingress loop end");
224 }
225 .instrument(tracing::span!(tracing::Level::TRACE, "server ingress_loop"));
226
227 let process_loop = async {
228 loop {
229 match futures::select! {
230 res = rx.recv_async() => {
231 Ok(res)
232 }
233 res = futures.next() => {
234 Err(res)
235 }
236 complete => break,
237 } {
238 Ok(Ok(op)) => {
239 trace!("Io thread op {op:?}");
240 let fut = async {
241 trace!("Start process {op:?}");
242 match op {
243 Operation::Read {
244 file_id,
245 packet_id,
246 pos,
247 len,
248 } => {
249 let req_span = tracing::span!(
250 tracing::Level::TRACE,
251 "read request",
252 file_id,
253 packet_id,
254 pos,
255 len
256 );
257 async {
258 /*router.send_bytes(
259 |_| Response::Read {
260 packet_id,
261 idx: 0,
262 len,
263 err: None,
264 },
265 vec![0; len].into_boxed_slice(),
266 );*/
267 let packet = BaseArc::new(ReadPacket::new(len));
268
269 let fh = file_handles.borrow().get(file_id as usize).cloned();
270
271 if let Some(fh) = fh {
272 log::trace!("Read raw {len}");
273
274 fh.io_to_fn(
275 pos,
276 packet.clone().transpose().into_base().unwrap(),
277 {
278 let router = router.clone();
279 move |pkt, err| {
280 if let Some(err) = err.map(Error::into_int_err)
281 {
282 log::trace!("Err {err:?}");
283 router.send_hdr(|_| Response::Read {
284 packet_id,
285 idx: pkt.start(),
286 len: pkt.len(),
287 err: Some(err),
288 });
289 } else {
290 log::trace!("Resp read {}", pkt.len());
291
292 let buf = {
293 let mut shards = packet.shards.lock();
294 shards.remove(&pkt.start()).expect(
295 "Successful packet, but no written shard",
296 )
297 };
298
299 log::trace!(
300 "Send bytes {:?}",
301 buf.simple_slice().map(|v| v.len())
302 );
303
304 router.send_bytes(
305 |_| Response::Read {
306 packet_id,
307 idx: pkt.start(),
308 len: pkt.len(),
309 err: None,
310 },
311 buf,
312 );
313 }
314 }
315 },
316 )
317 .await;
318 }
319 }
320 .instrument(req_span)
321 .await
322 }
323 Operation::Write {
324 file_id,
325 packet_id,
326 pos,
327 buf,
328 } => {
329 let req_span = tracing::span!(
330 tracing::Level::TRACE,
331 "write request",
332 file_id,
333 packet_id,
334 pos,
335 len = buf.len()
336 );
337 async {
338 let fh = file_handles.borrow().get(file_id as usize).cloned();
339
340 if let Some(fh) = fh {
341 fh.io_to_fn(
342 pos,
343 OwnedPacket::<Read>::from(buf.0.into_boxed_slice()),
344 {
345 let router = router.clone();
346 move |pkt, err| {
347 router.send_hdr(|_| Response::Write {
348 packet_id,
349 idx: pkt.start(),
350 len: pkt.len(),
351 err: err.map(Error::into_int_err),
352 })
353 }
354 },
355 )
356 .await;
357 }
358 }
359 .instrument(req_span)
360 .await
361 }
362 Operation::ReadDir { stream_id, count } => {
363 let read_dir_span = tracing::span!(
364 tracing::Level::TRACE,
365 "read dir",
366 stream_id,
367 count,
368 );
369 async {
370 let stream = read_dir_streams.borrow().get(stream_id as usize).cloned();
371 let rm_stream = if let Some(stream) = stream
372 {
373 let stream_buf = &mut *stream.lock().await;
374 // SAFETY: we already ensure the stream is pinned at the
375 // storage level.
376 let stream =
377 unsafe { Pin::new_unchecked(&mut *stream_buf) };
378
379 let res = stream
380 .take(count as usize)
381 .map(|v| v.map_err(Error::into_int_err))
382 .collect::<Vec<_>>()
383 .await;
384
385 let buf = postcard::to_allocvec(&res).unwrap();
386
387 // TODO: we need to somehow ensure that this condition
388 // never occurs.
389 assert_eq!(buf.len() & (1 << 31), 0);
390 assert!(buf.len() < u32::MAX as usize);
391
392 let mut len = buf.len() as u32;
393
394 // SAFETY: we already ensure the stream is pinned at its
395 // storage level.
396 let stream = unsafe { Pin::new_unchecked(stream_buf) };
397
398 if stream.is_terminated() {
399 len |= 1 << 31;
400 }
401
402 router.send_bytes(
403 |_| Response::ReadDir { stream_id, len },
404 OwnedPacket::from(buf.into_boxed_slice()),
405 );
406
407 stream.is_terminated()
408 } else {
409 router
410 .send_hdr(|_| Response::ReadDir { stream_id, len: 0 });
411 false
412 };
413
414 if rm_stream {
415 read_dir_streams.borrow_mut().remove(stream_id as usize);
416 }
417 }
418 .instrument(read_dir_span)
419 .await
420 }
421 Operation::Fs {
422 req_id,
423 dir_id,
424 req,
425 } => {
426 let req_span = tracing::span!(
427 tracing::Level::TRACE,
428 "fs request",
429 req_id,
430 dir_id,
431 );
432 async {
433 let dh = if dir_id > 0 {
434 let ret = Some(
435 dir_handles.borrow().get(dir_id as usize - 1).cloned(),
436 );
437 ret
438 } else {
439 None
440 };
441
442 let dh = dh.as_ref().map(|v| v.as_deref());
443 let dh = dh.unwrap_or_else(|| Some(fs.current_dir()));
444
445 if let Some(dh) = dh {
446 let resp = match req {
447 FsRequest::Path => {
448 trace!("Get path");
449 let path = dh
450 .path()
451 .await
452 .map(|p| p.to_string_lossy().into())
453 .map_err(Error::into_int_err);
454 FsResponse::Path { path }
455 }
456 FsRequest::OpenFile { path, options } => {
457 trace!("Open file {path}");
458 let file_id = match dh
459 .open_file(Path::new(&path), options)
460 .await
461 {
462 Ok(file) => {
463 let file_id = file_handles
464 .borrow_mut()
465 .insert(BaseArc::new(file));
466
467 assert!(file_id <= u32::MAX as usize);
468
469 Ok(file_id as u32)
470 }
471 Err(err) => Err(err.into_int_err()),
472 };
473 trace!("Opened file {file_id:?}");
474
475 FsResponse::OpenFile { file_id }
476 }
477 FsRequest::OpenDir { path } => {
478 trace!("Open dir {path}");
479 let dir_id = match dh.open_dir(&path).await {
480 Ok(dir) => {
481 let dir_id = dir_handles
482 .borrow_mut()
483 .insert(BaseArc::new(dir))
484 + 1;
485
486 assert!(dir_id <= u16::MAX as usize);
487
488 Ok(NonZeroU16::new(dir_id as u16).unwrap())
489 }
490 Err(err) => Err(err.into_int_err()),
491 };
492 trace!("Opened dir {dir_id:?}");
493
494 FsResponse::OpenDir { dir_id }
495 }
496 FsRequest::ReadDir => {
497 trace!("Read dir");
498 let stream_id = match dh.read_dir().await {
499 Ok(stream) => {
500 let stream_id =
501 read_dir_streams.borrow_mut().insert(
502 BaseArc::pin(stream.fuse().into()),
503 );
504
505 assert!(stream_id <= u16::MAX as usize);
506
507 Ok(stream_id as u16)
508 }
509 Err(err) => Err(err.into_int_err()),
510 };
511 trace!("Opened read handle {stream_id:?}");
512
513 FsResponse::ReadDir { stream_id }
514 }
515 FsRequest::Metadata { path } => {
516 trace!("Metadata {path}");
517 let metadata = dh
518 .metadata(&path)
519 .await
520 .map_err(Error::into_int_err);
521 FsResponse::Metadata { metadata }
522 }
523 FsRequest::DirOp(op) => {
524 trace!("Do dir op");
525 FsResponse::DirOp(
526 dh.do_op(op.as_path())
527 .await
528 .map_err(Error::into_int_err)
529 .err(),
530 )
531 }
532 };
533
534 let resp = postcard::to_allocvec(&resp).unwrap();
535
536 assert!(resp.len() <= u16::MAX as usize);
537
538 let resp_len = resp.len() as u16;
539
540 router.send_bytes(
541 |_| Response::Fs { req_id, resp_len },
542 OwnedPacket::from(resp.into_boxed_slice()),
543 );
544
545 trace!("Written response for {req_id}");
546 } else {
547 router.send_hdr(|_| Response::Fs {
548 req_id,
549 resp_len: 0,
550 })
551 }
552 }
553 .instrument(req_span)
554 .await
555 }
556 Operation::FileClose { file_id } => {
557 trace!("Close {file_id}");
558 file_handles.borrow_mut().remove(file_id as _);
559 }
560 }
561
562 trace!("Finish processing op");
563 };
564 futures.push(fut);
565 }
566 Ok(Err(_)) => break,
567 Err(_) => {}
568 }
569 }
570
571 log::trace!("Process loop done");
572 }
573 .instrument(tracing::span!(tracing::Level::TRACE, "server process_loop"));
574
575 let l1 = async move { futures::join!(process_loop, ingress_loop) }.fuse();
576 l1.await;
577}
578
579fn single_client_server_with(
580 addr: SocketAddr,
581 fs: NativeRt,
582) -> (std::thread::JoinHandle<()>, SocketAddr) {
583 let (tx, rx) = flume::bounded(1);
584
585 let ret = std::thread::spawn(move || {
586 fs.block_on(async {
587 let mut listener = fs.bind(addr).await.unwrap();
588 let _ = tx.send_async(listener.local_addr().unwrap()).await;
589 let (stream, _) = listener.next().await.unwrap();
590 run_server(stream, &fs).await
591 });
592
593 trace!("Server done polling");
594 });
595
596 let addr = rx.recv().unwrap();
597
598 (ret, addr)
599}
600
601pub fn single_client_server(addr: SocketAddr) -> (std::thread::JoinHandle<()>, SocketAddr) {
602 single_client_server_with(addr, NativeRt::default())
603}
604
605pub async fn server_bind(fs: &NativeRt, bind_addr: SocketAddr) {
606 let listener = fs.bind(bind_addr).await.unwrap();
607 server(fs, listener).await
608}
609
610pub async fn server(fs: &NativeRt, listener: NativeTcpListener) {
611 let clients = listener.fuse();
612 futures::pin_mut!(clients);
613
614 // TODO: load balance clients with multiple FS instances per-thread.
615 let mut futures = FuturesUnordered::new();
616
617 loop {
618 match futures::select! {
619 res = clients.next() => {
620 Ok(res)
621 }
622 res = futures.next() => {
623 Err(res)
624 }
625 complete => break,
626 } {
627 Ok(Some((stream, peer))) => futures.push(async move {
628 run_server(stream, fs).await;
629 trace!("{peer:?} finished");
630 peer
631 }),
632 Err(peer) => debug!("{peer:?} finished"),
633 _ => break,
634 }
635 }
636}
637
// TODO: test on miri
#[cfg(all(test, not(miri)))]
mod tests {
    use super::super::client::NetworkFs;
    use super::*;
    use mfio::traits::IoRead;
    use mfio_rt::{Fs, OpenOptions};
    use std::path::Path;

    /// Reads `./Cargo.toml` through a `NetworkFs` client connected to a single-client server.
    #[test]
    fn fs_test() {
        let _ = ::env_logger::builder().is_test(true).try_init();
        // Port 0 lets the OS pick a free port, so the test cannot collide with other
        // processes or with tests running in parallel; the server reports the real address.
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

        let (server, addr) = single_client_server(addr);

        let fs = mfio_rt::NativeRt::builder().thread(true).build().unwrap();
        let fs = NetworkFs::with_fs(addr, fs.into(), true).unwrap();

        fs.block_on(async {
            println!("Conned");
            let fh = fs
                .open(Path::new("./Cargo.toml"), OpenOptions::new().read(true))
                .await
                .unwrap();
            println!("Got fh");
            let mut out = vec![];
            fh.read_to_end(0, &mut out).await.unwrap();
            println!("Read to end");
            println!("{}", String::from_utf8(out).unwrap());
        });

        println!("Drop fs");

        // Dropping the client before joining lets the server thread finish.
        core::mem::drop(fs);

        println!("Dropped fs");

        server.join().unwrap();
    }

    mfio_rt::test_suite!(tests, |test_name, closure| {
        let _ = ::env_logger::builder().is_test(true).try_init();
        use super::{single_client_server_with, NetworkFs, SocketAddr};
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

        // The server side works inside a per-test temporary directory.
        let mut rt = mfio_rt::NativeRt::default();
        let dir = TempDir::new(test_name).unwrap();
        rt.set_cwd(dir.path().to_path_buf());
        let (server, addr) = single_client_server_with(addr, rt);

        let rt = mfio_rt::NativeRt::default();

        let mut rt = NetworkFs::with_fs(addr, rt.into(), true).unwrap();

        let fs = staticify(&mut rt);

        pub fn run<'a, Func: FnOnce(&'a NetworkFs) -> F, F: Future>(
            fs: &'a mut NetworkFs,
            func: Func,
        ) -> F::Output {
            fs.block_on(func(fs))
        }

        run(fs, move |rt| {
            let run = TestRun::new(rt, dir);
            closure(run)
        });

        core::mem::drop(rt);

        log::trace!("Joining thread");

        server.join().unwrap();
    });
}