mfio_netfs/net/
server.rs

1use core::cell::RefCell;
2use core::num::NonZeroU16;
3use core::pin::Pin;
4use std::collections::BTreeMap;
5use std::net::SocketAddr;
6use std::path::Path;
7
8use super::{FsRequest, FsResponse, HeaderRouter, Request, Response};
9
10use async_mutex::Mutex as AsyncMutex;
11use cglue::result::IntError;
12use log::*;
13use mfio::backend::IoBackendExt;
14use mfio::error::Error;
15use mfio::io::{NoPos, OwnedPacket, PacketIoExt, PacketView, PacketVtblRef, Read};
16use mfio::stdeq::Seekable;
17use mfio::tarc::BaseArc;
18use mfio_rt::{
19    native::{NativeRtDir, ReadDir},
20    DirHandle, Fs, NativeFile, NativeRt, TcpListenerHandle,
21};
22use parking_lot::Mutex;
23use slab::Slab;
24use tracing::instrument::Instrument;
25
26use debug_ignore::DebugIgnore;
27use futures::{
28    future::FutureExt,
29    stream::{Fuse, FusedStream, FuturesUnordered, StreamExt},
30};
31use mfio_rt::{
32    native::{NativeTcpListener, NativeTcpStream},
33    Tcp,
34};
35
/// A decoded client request, ready to be executed by the process loop in
/// `run_server`.
///
/// Built by the ingress loop from a wire-level `Request`. For `Write` and `Fs`
/// requests, the payload that follows the request header has already been read
/// from the stream (and, for `Fs`, decoded) by the time this value exists.
#[derive(Debug)]
enum Operation {
    /// Read `len` bytes at offset `pos` from the open file `file_id`.
    Read {
        file_id: u32,
        /// Request id that is echoed back in every `Response::Read`.
        packet_id: u32,
        pos: u64,
        len: u64,
    },
    /// Write `buf` at offset `pos` into the open file `file_id`.
    Write {
        file_id: u32,
        /// Request id that is echoed back in every `Response::Write`.
        packet_id: u32,
        pos: u64,
        /// Payload bytes; wrapped in `DebugIgnore` so they are left out of the
        /// `Debug` output used by the trace logs.
        buf: DebugIgnore<Vec<u8>>,
    },
    /// Release the server-side handle of `file_id`.
    FileClose {
        file_id: u32,
    },
    /// Pull up to `count` entries from the directory listing stream `stream_id`.
    ReadDir {
        stream_id: u16,
        count: u16,
    },
    /// A filesystem request relative to directory `dir_id`; `0` selects the
    /// runtime's current directory, any other value is a 1-based directory id.
    Fs {
        req_id: u32,
        dir_id: u16,
        req: FsRequest,
    },
}
63
64#[repr(C)]
65struct ReadPacket {
66    hdr: Packet<Write>,
67    len: u64,
68    shards: Mutex<BTreeMap<u64, BaseArc<Packet<Read>>>>,
69}
70
71impl ReadPacket {
72    pub fn new(capacity: u64) -> Self {
73        unsafe extern "C" fn len(pkt: &Packet<Write>) -> u64 {
74            unsafe {
75                let this = &*(pkt as *const Packet<Write> as *const ReadPacket);
76                this.len
77            }
78        }
79
80        unsafe extern "C" fn get_mut(
81            _: &mut ManuallyDrop<BoundPacketView<Write>>,
82            _: usize,
83            _: &mut MaybeUninit<WritePacketObj>,
84        ) -> bool {
85            false
86        }
87
88        unsafe extern "C" fn transfer_data(obj: &mut PacketView<'_, Write>, src: *const ()) {
89            let this = &*(obj.pkt() as *const Packet<Write> as *const ReadPacket);
90            let len = obj.len();
91            let idx = obj.start();
92            let buf = src as *const u8;
93            let buf = core::slice::from_raw_parts(buf, len as usize);
94            let pkt = Packet::<Read>::copy_from_slice(buf);
95            this.shards.lock().insert(idx, pkt);
96        }
97
98        Self {
99            hdr: unsafe {
100                Packet::new_hdr(PacketVtblRef {
101                    vtbl: &Write {
102                        len,
103                        get_mut,
104                        transfer_data,
105                    },
106                })
107            },
108            len: capacity,
109            shards: Default::default(),
110        }
111    }
112}
113
114impl AsRef<Packet<Write>> for ReadPacket {
115    fn as_ref(&self) -> &Packet<Write> {
116        &self.hdr
117    }
118}
119
120use core::mem::{ManuallyDrop, MaybeUninit};
121use mfio::io::{BoundPacketView, Packet, Write, WritePacketObj};
122
123async fn run_server(stream: NativeTcpStream, fs: &NativeRt) {
124    let stream_raw = &stream;
125    let mut file_handles: Slab<BaseArc<Seekable<NativeFile, u64>>> = Default::default();
126    let file_handles = RefCell::new(&mut file_handles);
127    let mut dir_handles: Slab<BaseArc<NativeRtDir>> = Default::default();
128    let dir_handles = RefCell::new(&mut dir_handles);
129    let mut read_dir_streams: Slab<Pin<BaseArc<AsyncMutex<Fuse<ReadDir>>>>> = Default::default();
130    let read_dir_streams = RefCell::new(&mut read_dir_streams);
131
132    let router = BaseArc::new(HeaderRouter::new(stream_raw));
133
134    let mut futures = FuturesUnordered::new();
135
136    let (tx, rx) = flume::bounded(512);
137
138    let ingress_loop = async {
139        use mfio::traits::IoRead;
140        while let Ok(v) = {
141            let header_span = tracing::span!(tracing::Level::TRACE, "server read Request header");
142            trace!("Queue req read");
143            stream_raw
144                .read::<Request>(NoPos::new())
145                .instrument(header_span)
146                .await
147        } {
148            let end_span = tracing::span!(tracing::Level::TRACE, "server read Request");
149            let op = async {
150                // Verify that the tag is proper, since otherwise we may jump to the wrong place of
151                // code. TODO: use proper deserialization techniques
152                // SAFETY: memunsafe made safe
153                // while adding this check saves us from memory safety bugs, this will probably
154                // still lead to arbitrarily large allocations that make us crash.
155                let tag = unsafe { *(&v as *const _ as *const u8) };
156                assert!(tag < 5, "incoming data tag is invalid {tag}");
157
158                trace!("Receive req: {v:?}");
159
160                match v {
161                    Request::Read {
162                        file_id,
163                        packet_id,
164                        pos,
165                        len,
166                    } => Operation::Read {
167                        file_id,
168                        packet_id,
169                        pos,
170                        len,
171                    },
172                    Request::Write {
173                        file_id,
174                        packet_id,
175                        pos,
176                        len,
177                    } => {
178                        let mut buf = vec![0; len as usize];
179                        stream_raw
180                            .read_all(NoPos::new(), &mut buf[..])
181                            .await
182                            .unwrap();
183                        Operation::Write {
184                            file_id,
185                            packet_id,
186                            pos,
187                            buf: buf.into(),
188                        }
189                    }
190                    Request::Fs {
191                        req_id,
192                        dir_id,
193                        req_len,
194                    } => {
195                        let mut buf = vec![0; req_len as usize];
196                        stream_raw
197                            .read_all(NoPos::new(), &mut buf[..])
198                            .await
199                            .unwrap();
200                        let req: FsRequest = postcard::from_bytes(&buf).unwrap();
201                        Operation::Fs {
202                            req_id,
203                            dir_id,
204                            req,
205                        }
206                    }
207                    Request::ReadDir { stream_id, count } => {
208                        Operation::ReadDir { stream_id, count }
209                    }
210                    Request::FileClose { file_id } => Operation::FileClose { file_id },
211                }
212            }
213            .instrument(end_span)
214            .await;
215
216            if tx.send_async(op).await.is_err() {
217                break;
218            }
219        }
220
221        core::mem::drop(tx);
222
223        trace!("Ingress loop end");
224    }
225    .instrument(tracing::span!(tracing::Level::TRACE, "server ingress_loop"));
226
227    let process_loop = async {
228        loop {
229            match futures::select! {
230                res = rx.recv_async() => {
231                    Ok(res)
232                }
233                res = futures.next() => {
234                    Err(res)
235                }
236                complete => break,
237            } {
238                Ok(Ok(op)) => {
239                    trace!("Io thread op {op:?}");
240                    let fut = async {
241                        trace!("Start process {op:?}");
242                        match op {
243                            Operation::Read {
244                                file_id,
245                                packet_id,
246                                pos,
247                                len,
248                            } => {
249                                let req_span = tracing::span!(
250                                    tracing::Level::TRACE,
251                                    "read request",
252                                    file_id,
253                                    packet_id,
254                                    pos,
255                                    len
256                                );
257                                async {
258                                    /*router.send_bytes(
259                                        |_| Response::Read {
260                                            packet_id,
261                                            idx: 0,
262                                            len,
263                                            err: None,
264                                        },
265                                        vec![0; len].into_boxed_slice(),
266                                    );*/
267                                    let packet = BaseArc::new(ReadPacket::new(len));
268
269                                    let fh = file_handles.borrow().get(file_id as usize).cloned();
270
271                                    if let Some(fh) = fh {
272                                        log::trace!("Read raw {len}");
273
274                                        fh.io_to_fn(
275                                            pos,
276                                            packet.clone().transpose().into_base().unwrap(),
277                                            {
278                                                let router = router.clone();
279                                                move |pkt, err| {
280                                                    if let Some(err) = err.map(Error::into_int_err)
281                                                    {
282                                                        log::trace!("Err {err:?}");
283                                                        router.send_hdr(|_| Response::Read {
284                                                            packet_id,
285                                                            idx: pkt.start(),
286                                                            len: pkt.len(),
287                                                            err: Some(err),
288                                                        });
289                                                    } else {
290                                                        log::trace!("Resp read {}", pkt.len());
291
292                                                        let buf = {
293                                                            let mut shards = packet.shards.lock();
294                                                            shards.remove(&pkt.start()).expect(
295                                                                "Successful packet, but no written shard",
296                                                            )
297                                                        };
298
299                                                        log::trace!(
300                                                            "Send bytes {:?}",
301                                                            buf.simple_slice().map(|v| v.len())
302                                                        );
303
304                                                        router.send_bytes(
305                                                            |_| Response::Read {
306                                                                packet_id,
307                                                                idx: pkt.start(),
308                                                                len: pkt.len(),
309                                                                err: None,
310                                                            },
311                                                            buf,
312                                                        );
313                                                    }
314                                                }
315                                            },
316                                        )
317                                        .await;
318                                    }
319                                }
320                                .instrument(req_span)
321                                .await
322                            }
323                            Operation::Write {
324                                file_id,
325                                packet_id,
326                                pos,
327                                buf,
328                            } => {
329                                let req_span = tracing::span!(
330                                    tracing::Level::TRACE,
331                                    "write request",
332                                    file_id,
333                                    packet_id,
334                                    pos,
335                                    len = buf.len()
336                                );
337                                async {
338                                    let fh = file_handles.borrow().get(file_id as usize).cloned();
339
340                                    if let Some(fh) = fh {
341                                        fh.io_to_fn(
342                                            pos,
343                                            OwnedPacket::<Read>::from(buf.0.into_boxed_slice()),
344                                            {
345                                                let router = router.clone();
346                                                move |pkt, err| {
347                                                    router.send_hdr(|_| Response::Write {
348                                                        packet_id,
349                                                        idx: pkt.start(),
350                                                        len: pkt.len(),
351                                                        err: err.map(Error::into_int_err),
352                                                    })
353                                                }
354                                            },
355                                        )
356                                        .await;
357                                    }
358                                }
359                                .instrument(req_span)
360                                .await
361                            }
362                            Operation::ReadDir { stream_id, count } => {
363                                let read_dir_span = tracing::span!(
364                                    tracing::Level::TRACE,
365                                    "read dir",
366                                    stream_id,
367                                    count,
368                                );
369                                async {
370                                    let stream = read_dir_streams.borrow().get(stream_id as usize).cloned();
371                                    let rm_stream = if let Some(stream) = stream
372                                    {
373                                        let stream_buf = &mut *stream.lock().await;
374                                        // SAFETY: we already ensure the stream is pinned at the
375                                        // storage level.
376                                        let stream =
377                                            unsafe { Pin::new_unchecked(&mut *stream_buf) };
378
379                                        let res = stream
380                                            .take(count as usize)
381                                            .map(|v| v.map_err(Error::into_int_err))
382                                            .collect::<Vec<_>>()
383                                            .await;
384
385                                        let buf = postcard::to_allocvec(&res).unwrap();
386
387                                        // TODO: we need to somehow ensure that this condition
388                                        // never occurs.
389                                        assert_eq!(buf.len() & (1 << 31), 0);
390                                        assert!(buf.len() < u32::MAX as usize);
391
392                                        let mut len = buf.len() as u32;
393
394                                        // SAFETY: we already ensure the stream is pinned at its
395                                        // storage level.
396                                        let stream = unsafe { Pin::new_unchecked(stream_buf) };
397
398                                        if stream.is_terminated() {
399                                            len |= 1 << 31;
400                                        }
401
402                                        router.send_bytes(
403                                            |_| Response::ReadDir { stream_id, len },
404                                            OwnedPacket::from(buf.into_boxed_slice()),
405                                        );
406
407                                        stream.is_terminated()
408                                    } else {
409                                        router
410                                            .send_hdr(|_| Response::ReadDir { stream_id, len: 0 });
411                                        false
412                                    };
413
414                                    if rm_stream {
415                                        read_dir_streams.borrow_mut().remove(stream_id as usize);
416                                    }
417                                }
418                                .instrument(read_dir_span)
419                                .await
420                            }
421                            Operation::Fs {
422                                req_id,
423                                dir_id,
424                                req,
425                            } => {
426                                let req_span = tracing::span!(
427                                    tracing::Level::TRACE,
428                                    "fs request",
429                                    req_id,
430                                    dir_id,
431                                );
432                                async {
433                                    let dh = if dir_id > 0 {
434                                        let ret = Some(
435                                            dir_handles.borrow().get(dir_id as usize - 1).cloned(),
436                                        );
437                                        ret
438                                    } else {
439                                        None
440                                    };
441
442                                    let dh = dh.as_ref().map(|v| v.as_deref());
443                                    let dh = dh.unwrap_or_else(|| Some(fs.current_dir()));
444
445                                    if let Some(dh) = dh {
446                                        let resp = match req {
447                                            FsRequest::Path => {
448                                                trace!("Get path");
449                                                let path = dh
450                                                    .path()
451                                                    .await
452                                                    .map(|p| p.to_string_lossy().into())
453                                                    .map_err(Error::into_int_err);
454                                                FsResponse::Path { path }
455                                            }
456                                            FsRequest::OpenFile { path, options } => {
457                                                trace!("Open file {path}");
458                                                let file_id = match dh
459                                                    .open_file(Path::new(&path), options)
460                                                    .await
461                                                {
462                                                    Ok(file) => {
463                                                        let file_id = file_handles
464                                                            .borrow_mut()
465                                                            .insert(BaseArc::new(file));
466
467                                                        assert!(file_id <= u32::MAX as usize);
468
469                                                        Ok(file_id as u32)
470                                                    }
471                                                    Err(err) => Err(err.into_int_err()),
472                                                };
473                                                trace!("Opened file {file_id:?}");
474
475                                                FsResponse::OpenFile { file_id }
476                                            }
477                                            FsRequest::OpenDir { path } => {
478                                                trace!("Open dir {path}");
479                                                let dir_id = match dh.open_dir(&path).await {
480                                                    Ok(dir) => {
481                                                        let dir_id = dir_handles
482                                                            .borrow_mut()
483                                                            .insert(BaseArc::new(dir))
484                                                            + 1;
485
486                                                        assert!(dir_id <= u16::MAX as usize);
487
488                                                        Ok(NonZeroU16::new(dir_id as u16).unwrap())
489                                                    }
490                                                    Err(err) => Err(err.into_int_err()),
491                                                };
492                                                trace!("Opened dir {dir_id:?}");
493
494                                                FsResponse::OpenDir { dir_id }
495                                            }
496                                            FsRequest::ReadDir => {
497                                                trace!("Read dir");
498                                                let stream_id = match dh.read_dir().await {
499                                                    Ok(stream) => {
500                                                        let stream_id =
501                                                            read_dir_streams.borrow_mut().insert(
502                                                                BaseArc::pin(stream.fuse().into()),
503                                                            );
504
505                                                        assert!(stream_id <= u16::MAX as usize);
506
507                                                        Ok(stream_id as u16)
508                                                    }
509                                                    Err(err) => Err(err.into_int_err()),
510                                                };
511                                                trace!("Opened read handle {stream_id:?}");
512
513                                                FsResponse::ReadDir { stream_id }
514                                            }
515                                            FsRequest::Metadata { path } => {
516                                                trace!("Metadata {path}");
517                                                let metadata = dh
518                                                    .metadata(&path)
519                                                    .await
520                                                    .map_err(Error::into_int_err);
521                                                FsResponse::Metadata { metadata }
522                                            }
523                                            FsRequest::DirOp(op) => {
524                                                trace!("Do dir op");
525                                                FsResponse::DirOp(
526                                                    dh.do_op(op.as_path())
527                                                        .await
528                                                        .map_err(Error::into_int_err)
529                                                        .err(),
530                                                )
531                                            }
532                                        };
533
534                                        let resp = postcard::to_allocvec(&resp).unwrap();
535
536                                        assert!(resp.len() <= u16::MAX as usize);
537
538                                        let resp_len = resp.len() as u16;
539
540                                        router.send_bytes(
541                                            |_| Response::Fs { req_id, resp_len },
542                                            OwnedPacket::from(resp.into_boxed_slice()),
543                                        );
544
545                                        trace!("Written response for {req_id}");
546                                    } else {
547                                        router.send_hdr(|_| Response::Fs {
548                                            req_id,
549                                            resp_len: 0,
550                                        })
551                                    }
552                                }
553                                .instrument(req_span)
554                                .await
555                            }
556                            Operation::FileClose { file_id } => {
557                                trace!("Close {file_id}");
558                                file_handles.borrow_mut().remove(file_id as _);
559                            }
560                        }
561
562                        trace!("Finish processing op");
563                    };
564                    futures.push(fut);
565                }
566                Ok(Err(_)) => break,
567                Err(_) => {}
568            }
569        }
570
571        log::trace!("Process loop done");
572    }
573    .instrument(tracing::span!(tracing::Level::TRACE, "server process_loop"));
574
575    let l1 = async move { futures::join!(process_loop, ingress_loop) }.fuse();
576    l1.await;
577}
578
579fn single_client_server_with(
580    addr: SocketAddr,
581    fs: NativeRt,
582) -> (std::thread::JoinHandle<()>, SocketAddr) {
583    let (tx, rx) = flume::bounded(1);
584
585    let ret = std::thread::spawn(move || {
586        fs.block_on(async {
587            let mut listener = fs.bind(addr).await.unwrap();
588            let _ = tx.send_async(listener.local_addr().unwrap()).await;
589            let (stream, _) = listener.next().await.unwrap();
590            run_server(stream, &fs).await
591        });
592
593        trace!("Server done polling");
594    });
595
596    let addr = rx.recv().unwrap();
597
598    (ret, addr)
599}
600
601pub fn single_client_server(addr: SocketAddr) -> (std::thread::JoinHandle<()>, SocketAddr) {
602    single_client_server_with(addr, NativeRt::default())
603}
604
605pub async fn server_bind(fs: &NativeRt, bind_addr: SocketAddr) {
606    let listener = fs.bind(bind_addr).await.unwrap();
607    server(fs, listener).await
608}
609
610pub async fn server(fs: &NativeRt, listener: NativeTcpListener) {
611    let clients = listener.fuse();
612    futures::pin_mut!(clients);
613
614    // TODO: load balance clients with multiple FS instances per-thread.
615    let mut futures = FuturesUnordered::new();
616
617    loop {
618        match futures::select! {
619            res = clients.next() => {
620                Ok(res)
621            }
622            res = futures.next() => {
623                Err(res)
624            }
625            complete => break,
626        } {
627            Ok(Some((stream, peer))) => futures.push(async move {
628                run_server(stream, fs).await;
629                trace!("{peer:?} finished");
630                peer
631            }),
632            Err(peer) => debug!("{peer:?} finished"),
633            _ => break,
634        }
635    }
636}
637
// TODO: test on miri
#[cfg(all(test, not(miri)))]
mod tests {
    use super::super::client::NetworkFs;
    use super::*;
    use mfio::traits::IoRead;
    use mfio_rt::{Fs, OpenOptions};
    use std::path::Path;

    #[test]
    fn fs_test() {
        let _ = ::env_logger::builder().is_test(true).try_init();
        // Port 0 lets the OS choose a free port, and `single_client_server`
        // returns the address actually bound. A hard-coded port makes the test
        // fail whenever that port is already in use.
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

        let (server, addr) = single_client_server(addr);

        let fs = mfio_rt::NativeRt::builder().thread(true).build().unwrap();
        let fs = NetworkFs::with_fs(addr, fs.into(), true).unwrap();

        fs.block_on(async {
            println!("Conned");
            let fh = fs
                .open(Path::new("./Cargo.toml"), OpenOptions::new().read(true))
                .await
                .unwrap();
            println!("Got fh");
            let mut out = vec![];
            fh.read_to_end(0, &mut out).await.unwrap();
            println!("Read to end");
            println!("{}", String::from_utf8(out).unwrap());
        });

        println!("Drop fs");

        // Dropping the client closes the connection, which lets the
        // single-client server thread finish so it can be joined below.
        core::mem::drop(fs);

        println!("Dropped fs");

        server.join().unwrap();
    }

    mfio_rt::test_suite!(tests, |test_name, closure| {
        let _ = ::env_logger::builder().is_test(true).try_init();
        use super::{single_client_server_with, NetworkFs, SocketAddr};
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

        let mut rt = mfio_rt::NativeRt::default();
        let dir = TempDir::new(test_name).unwrap();
        rt.set_cwd(dir.path().to_path_buf());
        let (server, addr) = single_client_server_with(addr, rt);

        let rt = mfio_rt::NativeRt::default();

        let mut rt = NetworkFs::with_fs(addr, rt.into(), true).unwrap();

        let fs = staticify(&mut rt);

        pub fn run<'a, Func: FnOnce(&'a NetworkFs) -> F, F: Future>(
            fs: &'a mut NetworkFs,
            func: Func,
        ) -> F::Output {
            fs.block_on(func(fs))
        }

        run(fs, move |rt| {
            let run = TestRun::new(rt, dir);
            closure(run)
        });

        core::mem::drop(rt);

        log::trace!("Joining thread");

        server.join().unwrap();
    });
}
713}