// async_tftp/server/handlers/dir.rs

1use blocking::{unblock, Unblock};
2use log::trace;
3use std::fs::{self, File};
4use std::io;
5use std::net::SocketAddr;
6use std::path::Component;
7use std::path::{Path, PathBuf};
8
9use crate::error::{Error, Result};
10use crate::packet;
11
/// Handler that serves TFTP read and/or write requests from a single
/// directory.
///
/// Which request kinds are accepted is chosen by [`DirHandlerMode`] when the
/// handler is created with `DirHandler::new`.
#[derive(Debug)]
pub struct DirHandler {
    /// Root directory that request paths are resolved against
    /// (canonicalized in `DirHandler::new`).
    dir: PathBuf,
    /// Whether read requests (RRQ) are accepted.
    serve_rrq: bool,
    /// Whether write requests (WRQ) are accepted.
    serve_wrq: bool,
}
18
/// Selects which TFTP request kinds a [`DirHandler`] will serve.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DirHandlerMode {
    /// Serve only read requests.
    ReadOnly,
    /// Serve only write requests.
    WriteOnly,
    /// Serve read and write requests.
    ReadWrite,
}
27
28impl DirHandler {
29    /// Create new handler for directory.
30    pub fn new<P>(dir: P, flags: DirHandlerMode) -> Result<Self>
31    where
32        P: AsRef<Path>,
33    {
34        let dir = fs::canonicalize(dir.as_ref())?;
35
36        if !dir.is_dir() {
37            return Err(Error::NotDir(dir));
38        }
39
40        trace!("TFTP directory: {}", dir.display());
41
42        let serve_rrq = match flags {
43            DirHandlerMode::ReadOnly => true,
44            DirHandlerMode::WriteOnly => false,
45            DirHandlerMode::ReadWrite => true,
46        };
47
48        let serve_wrq = match flags {
49            DirHandlerMode::ReadOnly => false,
50            DirHandlerMode::WriteOnly => true,
51            DirHandlerMode::ReadWrite => true,
52        };
53
54        Ok(DirHandler {
55            dir,
56            serve_rrq,
57            serve_wrq,
58        })
59    }
60}
61
62#[crate::async_trait]
63impl crate::server::Handler for DirHandler {
64    type Reader = Unblock<File>;
65    type Writer = Unblock<File>;
66
67    async fn read_req_open(
68        &mut self,
69        _client: &SocketAddr,
70        path: &Path,
71    ) -> Result<(Self::Reader, Option<u64>), packet::Error> {
72        if !self.serve_rrq {
73            return Err(packet::Error::IllegalOperation);
74        }
75
76        let path = secure_path(&self.dir, path)?;
77
78        // Send only regular files
79        if !path.is_file() {
80            return Err(packet::Error::FileNotFound);
81        }
82
83        let path_clone = path.clone();
84        let (file, len) = unblock(move || open_file_ro(path_clone)).await?;
85        let reader = Unblock::new(file);
86
87        trace!("TFTP sending file: {}", path.display());
88
89        Ok((reader, len))
90    }
91
92    async fn write_req_open(
93        &mut self,
94        _client: &SocketAddr,
95        path: &Path,
96        size: Option<u64>,
97    ) -> Result<Self::Writer, packet::Error> {
98        if !self.serve_wrq {
99            return Err(packet::Error::IllegalOperation);
100        }
101
102        let path = secure_path(&self.dir, path)?;
103
104        let path_clone = path.clone();
105        let file = unblock(move || open_file_wo(path_clone, size)).await?;
106        let writer = Unblock::new(file);
107
108        trace!("TFTP receiving file: {}", path.display());
109
110        Ok(writer)
111    }
112}
113
114fn secure_path(
115    restricted_dir: &Path,
116    path: &Path,
117) -> Result<PathBuf, packet::Error> {
118    // Strip `/` and `./` prefixes
119    let path = path
120        .strip_prefix("/")
121        .or_else(|_| path.strip_prefix("./"))
122        .unwrap_or(path);
123
124    // Avoid directory traversal attack by filtering `../`.
125    if path.components().any(|x| x == Component::ParentDir) {
126        return Err(packet::Error::PermissionDenied);
127    }
128
129    // Path should not start from root dir or have any Windows prefixes.
130    // i.e. We accept only normal path components.
131    match path.components().next() {
132        Some(Component::Normal(_)) => {}
133        _ => return Err(packet::Error::PermissionDenied),
134    }
135
136    Ok(restricted_dir.join(path))
137}
138
/// Opens `path` read-only and reports its length.
///
/// The length is `None` when the file's metadata cannot be queried; that case
/// is not treated as an error. This performs blocking I/O; in this file it is
/// invoked through `unblock`.
fn open_file_ro(path: PathBuf) -> io::Result<(File, Option<u64>)> {
    File::open(path).map(|handle| {
        let size = match handle.metadata() {
            Ok(meta) => Some(meta.len()),
            Err(_) => None,
        };
        (handle, size)
    })
}
144
/// Creates `path` for writing, truncating any existing file.
///
/// If `size` is given, the file is immediately resized to exactly that many
/// bytes; the extended region reads as zeros. This performs blocking I/O; in
/// this file it is invoked through `unblock`.
fn open_file_wo(path: PathBuf, size: Option<u64>) -> io::Result<File> {
    let handle = File::create(path)?;
    match size {
        Some(len) => handle.set_len(len).map(|()| handle),
        None => Ok(handle),
    }
}