nbd_async/
device.rs

1use async_trait::async_trait;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use futures_util::stream::{Stream, StreamExt};
5use std::io;
6use std::os::unix::io::AsRawFd;
7use std::path::Path;
8use std::thread::JoinHandle;
9use tokio::{
10    fs::OpenOptions, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt,
11    io::ReadBuf, net::UnixStream,
12};
13
14use crate::{nbd, sys};
15
16/// A block device.
17#[async_trait(?Send)]
18pub trait BlockDevice {
19    /// Read a block from offset.
20    async fn read(&mut self, offset: u64, buf: &mut [u8]) -> io::Result<()>;
21    /// Write a block of data at offset.
22    async fn write(&mut self, _offset: u64, _buf: &[u8]) -> io::Result<()> {
23        Err(io::ErrorKind::InvalidInput.into())
24    }
25    /// Flushes write buffers to the underlying storage medium
26    async fn flush(&mut self) -> io::Result<()> {
27        Ok(())
28    }
29    /// Marks blocks as unused
30    async fn trim(&mut self, _offset: u64, _size: usize) -> io::Result<()> {
31        Ok(())
32    }
33}
34
35pub struct Server {
36    do_it_thread: Option<JoinHandle<io::Result<()>>>,
37    file: tokio::fs::File,
38}
39
40impl Drop for Server {
41    fn drop(&mut self) {
42        let _ = sys::disconnect(&self.file);
43        if let Some(do_it_thread) = self.do_it_thread.take() {
44            do_it_thread.join().expect("join thread").unwrap();
45        }
46    }
47}
48
49/// Attach a socket to a NBD device file.
50pub async fn attach_device<P, S>(
51    path: P,
52    socket: S,
53    block_size: u32,
54    block_count: u64,
55    read_only: bool,
56) -> io::Result<Server>
57where
58    P: AsRef<Path>,
59    S: AsRawFd + Send + 'static,
60{
61    let file = OpenOptions::new()
62        .read(true)
63        .write(true)
64        .open(path.as_ref())
65        .await?;
66
67    sys::set_block_size(&file, block_size)?;
68    sys::set_size_blocks(&file, block_count)?;
69    sys::set_timeout(&file, 10)?;
70    sys::clear_sock(&file)?;
71
72    let inner_file = file.try_clone().await?;
73    let do_it_thread = Some(std::thread::spawn(move || -> io::Result<()> {
74        sys::set_sock(&inner_file, socket.as_raw_fd())?;
75        if read_only {
76            sys::set_flags(&inner_file, sys::HAS_FLAGS | sys::READ_ONLY)?;
77        } else {
78            sys::set_flags(&inner_file, 0)?;
79        }
80        // The do_it ioctl will block until device is disconnected, hence
81        // the separate thread.
82        sys::do_it(&inner_file)?;
83        let _ = sys::clear_sock(&inner_file);
84        let _ = sys::clear_queue(&inner_file);
85        Ok(())
86    }));
87    Ok(Server { do_it_thread, file })
88}
89
90/// Serve a local block device through a NBD dev file.
91pub async fn serve_local_nbd<P, B>(
92    path: P,
93    block_size: u32,
94    block_count: u64,
95    read_only: bool,
96    block_device: B,
97) -> io::Result<()>
98where
99    P: AsRef<Path>,
100    B: Unpin + BlockDevice,
101{
102    let (sock, kernel_sock) = UnixStream::pair()?;
103    let _server = attach_device(path, kernel_sock, block_size, block_count, read_only).await?;
104    serve_nbd(block_device, sock).await?;
105    Ok(())
106}
107
108struct RequestStream<C> {
109    client: Option<C>,
110    read_buf: [u8; nbd::SIZE_OF_REQUEST],
111}
112
113/// Serve a block device using a read/write client.
114pub async fn serve_nbd<B, C>(mut block_device: B, client: C) -> io::Result<()>
115where
116    B: Unpin + BlockDevice,
117    C: AsyncRead + AsyncWrite + Unpin,
118{
119    let mut stream = RequestStream {
120        client: Some(client),
121        read_buf: [0; nbd::SIZE_OF_REQUEST],
122    };
123
124    let mut reply_buf = vec![];
125    let mut write_buf = vec![];
126    while let Some(result) = stream.next().await {
127        let request = result?;
128        let request_handler = match stream.client {
129            Some(ref mut sock) => sock,
130            None => break,
131        };
132        let mut reply = nbd::Reply::from_request(&request);
133        match request.command {
134            nbd::Command::Read => {
135                reply_buf.resize(nbd::SIZE_OF_REPLY + request.len, 0);
136                if let Err(err) = block_device
137                    .read(request.from, &mut reply_buf[nbd::SIZE_OF_REPLY..])
138                    .await
139                {
140                    // On error we shall reply with error code but no payload.
141                    reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
142                    reply_buf.resize(nbd::SIZE_OF_REPLY, 0);
143                }
144                reply.write_to_slice(&mut reply_buf[..])?;
145            }
146            nbd::Command::Write => {
147                write_buf.resize(request.len, 0);
148                request_handler.read_exact(&mut write_buf).await?;
149                if let Err(err) = block_device.write(request.from, &write_buf[..]).await {
150                    reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
151                }
152                reply.append_to_vec(&mut reply_buf)?;
153            }
154            nbd::Command::Flush => {
155                if let Err(err) = block_device.flush().await {
156                    reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
157                }
158                reply.append_to_vec(&mut reply_buf)?;
159            }
160            nbd::Command::Trim => {
161                if let Err(err) = block_device.trim(request.from, request.len).await {
162                    reply.error = err.raw_os_error().unwrap_or(nix::errno::Errno::EIO as i32);
163                }
164                reply.append_to_vec(&mut reply_buf)?;
165            }
166            nbd::Command::Disc => unimplemented!(),
167            nbd::Command::WriteZeroes => unimplemented!(),
168        }
169        request_handler.write_all(&reply_buf).await?;
170        reply_buf.clear();
171    }
172    Ok(())
173}
174
175impl<C> RequestStream<C>
176where
177    C: AsyncRead + AsyncWrite + Unpin,
178{
179    fn read_next(&mut self, cx: &mut Context) -> Poll<Option<io::Result<nbd::Request>>> {
180        let client = match self.client {
181            Some(ref mut client) => client,
182            None => return Poll::Ready(None),
183        };
184        let mut read_buf = ReadBuf::new(&mut self.read_buf);
185        let rc = Pin::new(client).poll_read(cx, &mut read_buf);
186        match rc {
187            Poll::Ready(Ok(())) => {
188                if read_buf.filled().is_empty() {
189                    return Poll::Ready(None);
190                }
191                if read_buf.filled().len() != nbd::SIZE_OF_REQUEST {
192                    return Poll::Ready(Some(Err(io::Error::from(io::ErrorKind::UnexpectedEof))));
193                }
194            }
195            Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
196            Poll::Pending => {
197                return Poll::Pending;
198            }
199        };
200        Poll::Ready(Some(match nbd::Request::try_from_bytes(&self.read_buf) {
201            Ok(req) => Ok(req),
202            Err(err) => Err(err),
203        }))
204    }
205}
206
207impl<C> Stream for RequestStream<C>
208where
209    C: AsyncRead + AsyncWrite + Unpin,
210{
211    type Item = io::Result<nbd::Request>;
212    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
213        self.read_next(cx)
214    }
215}