1use std::io::{self, BufReader, BufWriter, SeekFrom, prelude::*};
2use std::net::{ToSocketAddrs, TcpStream, TcpListener};
3use std::convert::TryInto;
4
// Unit tests are compiled only for test builds and live in a separate `tests` module file.
#[cfg(test)]
mod tests;
7
// Wire-protocol bytes shared by the `Networked` server and the `NetworkReader` client.

/// Request opcode: reposition the served source.
const OPERATION_SEEK: u8 = 0xFE;
/// Request opcode: read bytes from the served source.
const OPERATION_READ: u8 = 0xFF;

/// Seek "whence" byte selecting `SeekFrom::Start`.
const SEEK_FROM_START: u8 = 0x00;
/// Seek "whence" byte selecting `SeekFrom::End`.
const SEEK_FROM_END: u8 = 0x01;
/// Seek "whence" byte selecting `SeekFrom::Current`.
const SEEK_FROM_CURRENT: u8 = 0x02;

/// Seek reply status byte: the seek succeeded and a position follows.
const RESULT_OK: u8 = 0x00;
/// Seek reply status byte: the seek failed.
const RESULT_ERR: u8 = 0x01;
17
18pub struct Networked<R: Read + Seek> {
19 listener: TcpListener,
20 reader: R,
21}
22
23impl<R: Read + Seek> Networked<R> {
24 pub fn new<S: ToSocketAddrs>(reader: R, socket: S) -> io::Result<Self> {
25 Ok(Self {
26 reader,
27 listener: TcpListener::bind(socket)?
28 })
29 }
30
31 pub fn new_buffered<S: ToSocketAddrs>(reader: R, socket: S) -> io::Result<Networked<BufReader<R>>> {
32 Ok(Networked {
33 reader: BufReader::new(reader),
34 listener: TcpListener::bind(socket)?
35 })
36 }
37
38 pub fn listen(mut self) -> io::Result<()> {
39 for connection in self.listener.incoming() {
40 let mut connection = connection?;
41 let mut buf = [0u8];
42 while connection.read_exact(&mut buf).is_ok() {
43 match buf[0] {
44 OPERATION_SEEK => {
45 let mut buf = [0u8; 9];
46 let pos = match connection.read_exact(&mut buf) {
47 Ok(_) => {
48 let offset = i64::from_be_bytes(buf[1..].try_into().unwrap());
49 match buf[0] {
50 0 => SeekFrom::Start(offset as u64),
51 1 => SeekFrom::End(offset),
52 2 => SeekFrom::Current(offset),
53 _ => continue
54 }
55 },
56 Err(_) => continue,
57 };
58
59 match self.reader.seek(pos) {
60 Ok(ret) => {
61 connection.write_all(&[RESULT_OK])?;
62 connection.write_all(&u64::to_be_bytes(ret))?;
63 }
64 Err(_) => {
65 connection.write_all(&[RESULT_ERR])?;
66 }
67 }
68 connection.flush()?;
69 }
70 OPERATION_READ => {
71 let mut buf = [0u8; 8];
72 let amount = match connection.read_exact(&mut buf) {
73 Ok(_) => u64::from_be_bytes(buf),
74 Err(_) => continue,
75 };
76
77 let mut writer = BufWriter::new(&mut connection);
78 let reader = &mut self.reader;
79 let size = io::copy(&mut reader.take(amount), &mut writer)?;
80
81 io::copy(&mut io::repeat(0).take(amount - size), &mut writer)?;
82 writer.write_all(&size.to_be_bytes())?;
83 writer.flush()?;
84 }
85 _ => continue
86 }
87 }
88 }
89 Ok(())
90 }
91}
92
93pub struct NetworkReader(TcpStream);
94
95impl NetworkReader {
96 pub fn new<Addr: ToSocketAddrs>(addr: Addr) -> io::Result<Self> {
97 TcpStream::connect(addr).map(Self)
98 }
99}
100
101impl Seek for NetworkReader {
102 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
103 self.0.write_all(&[OPERATION_SEEK])?;
104 self.0.write_all(&[match pos {
105 SeekFrom::Start(_) => SEEK_FROM_START,
106 SeekFrom::End(_) => SEEK_FROM_END,
107 SeekFrom::Current(_) => SEEK_FROM_CURRENT,
108 }])?;
109 self.0.write_all(&match pos {
110 SeekFrom::Start(offset) => offset.to_be_bytes(),
111 SeekFrom::End(offset) | SeekFrom::Current(offset) => offset.to_be_bytes(),
112 })?;
113 self.0.flush()?;
114
115 let mut result = [0u8];
116 self.0.read_exact(&mut result)?;
117
118 if result == [RESULT_OK] {
119 let mut val = [0u8; 8];
120 self.0.read_exact(&mut val)?;
121
122 Ok(u64::from_be_bytes(val))
123 } else {
124 Err(io::Error::new(io::ErrorKind::Other, "server returned error"))
125 }
126 }
127}
128
129impl Read for NetworkReader {
130 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
131 self.0.write_all(&[OPERATION_READ])?;
132 self.0.write_all(&(buf.len() as u64).to_be_bytes())?;
133
134 self.0.read_exact(buf)?;
135
136 let mut buf = [0u8; 8];
137 self.0.read_exact(&mut buf)?;
138
139 Ok(u64::from_be_bytes(buf) as usize)
140 }
141}