network_reader/
lib.rs

1use std::io::{self, BufReader, BufWriter, SeekFrom, prelude::*};
2use std::net::{ToSocketAddrs, TcpStream, TcpListener};
3use std::convert::TryInto;
4
// Unit tests live in a separate `tests` module file and are compiled only for test builds.
#[cfg(test)]
mod tests;
7
// ---- Wire protocol ------------------------------------------------------------
// Every request starts with one opcode byte; all multi-byte integers are big-endian.

/// Opcode for a read request; followed by the requested byte count as a `u64`.
const OPERATION_READ: u8 = 0xff;
/// Opcode for a seek request; followed by a whence byte and an 8-byte offset.
const OPERATION_SEEK: u8 = 0xfe;

/// Whence byte selecting `SeekFrom::Start` (offset is unsigned).
const SEEK_FROM_START: u8 = 0x00;
/// Whence byte selecting `SeekFrom::End` (offset is signed).
const SEEK_FROM_END: u8 = 0x01;
/// Whence byte selecting `SeekFrom::Current` (offset is signed).
const SEEK_FROM_CURRENT: u8 = 0x02;

/// Seek reply status: success, followed by the new position as a `u64`.
const RESULT_OK: u8 = 0x00;
/// Seek reply status: the server-side seek failed; nothing follows.
const RESULT_ERR: u8 = 0x01;
17
/// Server side of the protocol: owns a TCP listener and the data source whose
/// seek/read operations are answered for remote clients (see `listen`).
///
/// No trait bounds are placed on `R` here; the constructors and `listen` require
/// `R: Read + Seek` on their impl block instead.
#[derive(Debug)]
pub struct Networked<R> {
    listener: TcpListener,
    reader: R,
}
22
23impl<R: Read + Seek> Networked<R> {
24    pub fn new<S: ToSocketAddrs>(reader: R, socket: S) -> io::Result<Self> {
25        Ok(Self {
26            reader,
27            listener: TcpListener::bind(socket)?
28        })
29    }
30    
31    pub fn new_buffered<S: ToSocketAddrs>(reader: R, socket: S) -> io::Result<Networked<BufReader<R>>> {
32        Ok(Networked {
33            reader: BufReader::new(reader),
34            listener: TcpListener::bind(socket)?
35        })
36    }
37    
38    pub fn listen(mut self) -> io::Result<()> {
39        for connection in self.listener.incoming() {
40            let mut connection = connection?;
41            let mut buf = [0u8];
42            while connection.read_exact(&mut buf).is_ok() {
43                match buf[0] {
44                    OPERATION_SEEK => {
45                        let mut buf = [0u8; 9];
46                        let pos = match connection.read_exact(&mut buf) {
47                            Ok(_) => {
48                                let offset = i64::from_be_bytes(buf[1..].try_into().unwrap());
49                                match buf[0] {
50                                    0 => SeekFrom::Start(offset as u64),
51                                    1 => SeekFrom::End(offset),
52                                    2 => SeekFrom::Current(offset),
53                                    _ => continue
54                                }
55                            },
56                            Err(_) => continue,
57                        };
58
59                        match self.reader.seek(pos) {
60                            Ok(ret) => {
61                                connection.write_all(&[RESULT_OK])?;
62                                connection.write_all(&u64::to_be_bytes(ret))?;
63                            }
64                            Err(_) => {
65                                connection.write_all(&[RESULT_ERR])?;
66                            }
67                        }
68                        connection.flush()?;
69                    }
70                    OPERATION_READ => {
71                        let mut buf = [0u8; 8];
72                        let amount = match connection.read_exact(&mut buf) {
73                            Ok(_) => u64::from_be_bytes(buf),
74                            Err(_) => continue,
75                        };
76                        
77                        let mut writer = BufWriter::new(&mut connection);
78                        let reader = &mut self.reader;
79                        let size = io::copy(&mut reader.take(amount), &mut writer)?;
80
81                        io::copy(&mut io::repeat(0).take(amount - size), &mut writer)?;
82                        writer.write_all(&size.to_be_bytes())?;
83                        writer.flush()?;
84                    }
85                    _ => continue
86                }
87            }
88        }
89        Ok(())
90    }
91}
92
/// Client side of the protocol: a [`Read`] + [`Seek`] handle whose operations are
/// forwarded over a TCP connection to a [`Networked`] server.
#[derive(Debug)]
pub struct NetworkReader(TcpStream);
94
95impl NetworkReader {
96    pub fn new<Addr: ToSocketAddrs>(addr: Addr) -> io::Result<Self> {
97        TcpStream::connect(addr).map(Self)
98    }
99}
100
101impl Seek for NetworkReader {
102    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
103        self.0.write_all(&[OPERATION_SEEK])?;
104        self.0.write_all(&[match pos {
105            SeekFrom::Start(_) => SEEK_FROM_START,
106            SeekFrom::End(_) => SEEK_FROM_END,
107            SeekFrom::Current(_) => SEEK_FROM_CURRENT,
108        }])?;
109        self.0.write_all(&match pos {
110            SeekFrom::Start(offset) => offset.to_be_bytes(),
111            SeekFrom::End(offset) | SeekFrom::Current(offset) => offset.to_be_bytes(),
112        })?;
113        self.0.flush()?;
114
115        let mut result = [0u8];
116        self.0.read_exact(&mut result)?;
117
118        if result == [RESULT_OK] {
119            let mut val = [0u8; 8];
120            self.0.read_exact(&mut val)?;
121            
122            Ok(u64::from_be_bytes(val))
123        } else {
124            Err(io::Error::new(io::ErrorKind::Other, "server returned error"))
125        }
126    }
127}
128
129impl Read for NetworkReader {
130    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
131        self.0.write_all(&[OPERATION_READ])?;
132        self.0.write_all(&(buf.len() as u64).to_be_bytes())?;
133        
134        self.0.read_exact(buf)?;
135
136        let mut buf = [0u8; 8];
137        self.0.read_exact(&mut buf)?;
138
139        Ok(u64::from_be_bytes(buf) as usize)
140    }
141}