1use async_trait::async_trait;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use futures_util::stream::{Stream, StreamExt};
5use std::io;
6use std::os::unix::io::AsRawFd;
7use std::path::Path;
8use std::thread::JoinHandle;
9use tokio::{
10 fs::OpenOptions, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt,
11 io::ReadBuf, net::UnixStream,
12};
13
14use crate::{nbd, sys};
15
16#[async_trait(?Send)]
18pub trait BlockDevice {
19 async fn read(&mut self, offset: u64, buf: &mut [u8]) -> io::Result<()>;
21 async fn write(&mut self, _offset: u64, _buf: &[u8]) -> io::Result<()> {
23 Err(io::ErrorKind::InvalidInput.into())
24 }
25 async fn flush(&mut self) -> io::Result<()> {
27 Ok(())
28 }
29 async fn trim(&mut self, _offset: u64, _size: usize) -> io::Result<()> {
31 Ok(())
32 }
33}
34
35pub struct Server {
36 do_it_thread: Option<JoinHandle<io::Result<()>>>,
37 file: tokio::fs::File,
38}
39
40impl Drop for Server {
41 fn drop(&mut self) {
42 let _ = sys::disconnect(&self.file);
43 if let Some(do_it_thread) = self.do_it_thread.take() {
44 do_it_thread.join().expect("join thread").unwrap();
45 }
46 }
47}
48
49pub async fn attach_device<P, S>(
51 path: P,
52 socket: S,
53 block_size: u32,
54 block_count: u64,
55 read_only: bool,
56) -> io::Result<Server>
57where
58 P: AsRef<Path>,
59 S: AsRawFd + Send + 'static,
60{
61 let file = OpenOptions::new()
62 .read(true)
63 .write(true)
64 .open(path.as_ref())
65 .await?;
66
67 sys::set_block_size(&file, block_size)?;
68 sys::set_size_blocks(&file, block_count)?;
69 sys::set_timeout(&file, 10)?;
70 sys::clear_sock(&file)?;
71
72 let inner_file = file.try_clone().await?;
73 let do_it_thread = Some(std::thread::spawn(move || -> io::Result<()> {
74 sys::set_sock(&inner_file, socket.as_raw_fd())?;
75 if read_only {
76 sys::set_flags(&inner_file, sys::HAS_FLAGS | sys::READ_ONLY)?;
77 } else {
78 sys::set_flags(&inner_file, 0)?;
79 }
80 sys::do_it(&inner_file)?;
83 let _ = sys::clear_sock(&inner_file);
84 let _ = sys::clear_queue(&inner_file);
85 Ok(())
86 }));
87 Ok(Server { do_it_thread, file })
88}
89
90pub async fn serve_local_nbd<P, B>(
92 path: P,
93 block_size: u32,
94 block_count: u64,
95 read_only: bool,
96 block_device: B,
97) -> io::Result<()>
98where
99 P: AsRef<Path>,
100 B: Unpin + BlockDevice,
101{
102 let (sock, kernel_sock) = UnixStream::pair()?;
103 let _server = attach_device(path, kernel_sock, block_size, block_count, read_only).await?;
104 serve_nbd(block_device, sock).await?;
105 Ok(())
106}
107
108struct RequestStream<C> {
109 client: Option<C>,
110 read_buf: [u8; nbd::SIZE_OF_REQUEST],
111}
112
113pub async fn serve_nbd<B, C>(mut block_device: B, client: C) -> io::Result<()>
115where
116 B: Unpin + BlockDevice,
117 C: AsyncRead + AsyncWrite + Unpin,
118{
119 let mut stream = RequestStream {
120 client: Some(client),
121 read_buf: [0; nbd::SIZE_OF_REQUEST],
122 };
123
124 let mut reply_buf = vec![];
125 let mut write_buf = vec![];
126 while let Some(result) = stream.next().await {
127 let request = result?;
128 let request_handler = match stream.client {
129 Some(ref mut sock) => sock,
130 None => break,
131 };
132 let mut reply = nbd::Reply::from_request(&request);
133 match request.command {
134 nbd::Command::Read => {
135 reply_buf.resize(nbd::SIZE_OF_REPLY + request.len, 0);
136 if let Err(err) = block_device
137 .read(request.from, &mut reply_buf[nbd::SIZE_OF_REPLY..])
138 .await
139 {
140 reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
142 reply_buf.resize(nbd::SIZE_OF_REPLY, 0);
143 }
144 reply.write_to_slice(&mut reply_buf[..])?;
145 }
146 nbd::Command::Write => {
147 write_buf.resize(request.len, 0);
148 request_handler.read_exact(&mut write_buf).await?;
149 if let Err(err) = block_device.write(request.from, &write_buf[..]).await {
150 reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
151 }
152 reply.append_to_vec(&mut reply_buf)?;
153 }
154 nbd::Command::Flush => {
155 if let Err(err) = block_device.flush().await {
156 reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
157 }
158 reply.append_to_vec(&mut reply_buf)?;
159 }
160 nbd::Command::Trim => {
161 if let Err(err) = block_device.trim(request.from, request.len).await {
162 reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
163 }
164 reply.append_to_vec(&mut reply_buf)?;
165 }
166 nbd::Command::Disc => unimplemented!(),
167 nbd::Command::WriteZeroes => unimplemented!(),
168 }
169 request_handler.write_all(&reply_buf).await?;
170 reply_buf.clear();
171 }
172 Ok(())
173}
174
175impl<C> RequestStream<C>
176where
177 C: AsyncRead + AsyncWrite + Unpin,
178{
179 fn read_next(&mut self, cx: &mut Context) -> Poll<Option<io::Result<nbd::Request>>> {
180 let client = match self.client {
181 Some(ref mut client) => client,
182 None => return Poll::Ready(None),
183 };
184 let mut read_buf = ReadBuf::new(&mut self.read_buf);
185 let rc = Pin::new(client).poll_read(cx, &mut read_buf);
186 match rc {
187 Poll::Ready(Ok(())) => {
188 if read_buf.filled().is_empty() {
189 return Poll::Ready(None);
190 }
191 if read_buf.filled().len() != nbd::SIZE_OF_REQUEST {
192 return Poll::Ready(Some(Err(io::Error::from(io::ErrorKind::UnexpectedEof))));
193 }
194 }
195 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
196 Poll::Pending => {
197 return Poll::Pending;
198 }
199 };
200 Poll::Ready(Some(match nbd::Request::try_from_bytes(&self.read_buf) {
201 Ok(req) => Ok(req),
202 Err(err) => Err(err),
203 }))
204 }
205}
206
207impl<C> Stream for RequestStream<C>
208where
209 C: AsyncRead + AsyncWrite + Unpin,
210{
211 type Item = io::Result<nbd::Request>;
212 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
213 self.read_next(cx)
214 }
215}