rocket_cgi/
lib.rs

1#![deny(missing_docs)]
2//! Implement CGI directory handler for Rocket
3
4mod os;
5
6use std::{
7    borrow::Cow,
8    collections::HashMap,
9    future::{ready, Future},
10    io::{self, Error, ErrorKind, Read},
11    path::{Path, PathBuf},
12    pin::Pin,
13    process::Stdio,
14    task::Poll,
15};
16
17use bitfield::bitfield;
18use os::{allowed, has_dot_file, has_setuid, is_writable};
19use rocket::{
20    data::ToByteUnit,
21    http::{uncased, ContentType, Method},
22    log::*,
23    request::Request,
24    response::{Redirect, Responder},
25    route::Outcome,
26};
27use rocket::{http::Status, Data};
28use rocket::{response::Response, route::Handler, Route};
29use tokio::{
30    io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader},
31    process::{Child, Command},
32};
33
/// Route template used for every route a `CGIDir` mounts: a trailing multi-segment path
/// followed by an arbitrary query string. `locate_file` also trims this suffix off the matched
/// route's URI to recover the mount prefix.
const PATH_DEF: &str = "/<path..>?<..>";
35
// Packed boolean switches that control what a `CGIDir` is willing to mount, execute and serve.
// Each entry is `getter, setter: bit_index;` as understood by the `bitfield!` macro.
bitfield! {
    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
    struct CGISettings(u64);
    impl Debug;
    // Split the query into argv even when it contains an unencoded `=` (see `build_process`).
    unencoded_equals, set_unencoded_equals: 0;
    // Allow request paths for which `os::has_dot_file` reports a dot file.
    dot_files, set_dot_files: 1;
    // Documented as Windows-only on `CGIDir::hidden_files`.
    // NOTE(review): not read anywhere in this file; presumably meant for `os::allowed` — confirm.
    hidden_files, set_hidden_files: 2;
    // Allow files for which `os::has_setuid` reports setuid/setgid bits.
    setuid, set_setuid: 3;
    // Allow executing files whose extension has no registered interpreter.
    direct_executable, set_direct_executable: 4;
    // Allow files for which `os::is_writable` reports write permission.
    writable_files, set_writable_files: 5;
    // Mount a POST route.
    allow_post, set_allow_post: 6;
    // Mount GET and HEAD routes.
    allow_get, set_allow_get: 7;
    // Write an extra `\n` to the script's stdin after forwarding a POST body.
    // NOTE(review): no public setter exists in this file, so it stays `false`.
    ensure_newline, set_ensure_newline: 8;
}
50
/// Custom handler to execute CGI scripts
///
/// This handler will execute any script within the directory provided.
/// See examples/cgi.rs for a full usage example
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CGIDir {
    /// Canonicalized root directory; requested files must resolve to a path inside it.
    path: PathBuf,
    /// Feature switches controlling which files and methods are served.
    settings: CGISettings,
    /// File extension (without the leading dot) mapped to the interpreter that runs it.
    file_types: HashMap<Cow<'static, str>, Cow<'static, Path>>,
}
61
62impl CGIDir {
63    /// Generate a CGI script from the associated path
64    ///
65    /// ```rust
66    /// # use rocket::build;
67    /// # use rocket_cgi::CGIDir;
68    /// build().mount("/", CGIDir::new("examples"))
69    /// # ;
70    /// ```
71    pub fn new(path: impl AsRef<Path>) -> Self {
72        let mut settings = CGISettings(0);
73        settings.set_unencoded_equals(false);
74        settings.set_dot_files(false);
75        settings.set_hidden_files(false);
76        settings.set_setuid(false);
77        settings.set_direct_executable(true);
78        settings.set_writable_files(true);
79        settings.set_allow_get(true);
80        settings.set_allow_post(true);
81        settings.set_ensure_newline(false);
82        Self {
83            path: std::fs::canonicalize(path).expect("Path does not exist"),
84            settings,
85            file_types: [("pl", "perl"), ("py", "python"), ("sh", "sh")]
86                .iter()
87                .map(|&(a, b)| (a.into(), Path::new(b).into()))
88                .collect(),
89        }
90    }
91
92    /// Clear file type associations, and disables directly running executables
93    ///
94    /// ```rust
95    /// # use rocket::{build, http::Status};
96    /// # use rocket::local::blocking::Client;
97    /// # use rocket_cgi::CGIDir;
98    /// let rocket = build().mount("/", CGIDir::new("test").clear_file_types());
99    /// let client = Client::tracked(rocket).unwrap();
100    /// let res = client.get("/simple.sh").dispatch();
101    /// assert_eq!(res.status(), Status::InternalServerError);
102    /// // Since the file could not be executed, a 500 error is returned
103    /// ```
104    pub fn clear_file_types(mut self) -> Self {
105        self.file_types.clear();
106        self.settings.set_direct_executable(false);
107        self
108    }
109
110    /// Add a file type association for executing a file. Overrides an existing file type
111    /// association if one exists.
112    ///
113    /// ```rust
114    /// # use rocket::{build, http::Status};
115    /// # use rocket::local::blocking::Client;
116    /// # use rocket_cgi::CGIDir;
117    /// # use std::path::Path;
118    /// let rocket = build().mount("/",
119    ///     CGIDir::new("test")
120    ///         .clear_file_types()// Clear file types
121    ///         .set_file_type("sh", Path::new("sh"))// manually insert `sh`
122    ///     );
123    /// let client = Client::tracked(rocket).unwrap();
124    /// let res = client.get("/simple.sh").dispatch();
125    /// assert_eq!(res.status(), Status::Ok);
126    /// ```
127    pub fn set_file_type(
128        mut self,
129        extension: impl Into<Cow<'static, str>>,
130        executable: impl Into<Cow<'static, Path>>,
131    ) -> Self {
132        self.file_types.insert(extension.into(), executable.into());
133        self
134    }
135
136    /// Only allow executing perl scripts. Disables all filetypes except `.pl`
137    ///
138    /// ```rust
139    /// # use rocket::{build, http::Status};
140    /// # use rocket::local::blocking::Client;
141    /// # use rocket_cgi::CGIDir;
142    /// let rocket = build().mount("/", CGIDir::new("test").only_perl());
143    /// let client = Client::tracked(rocket).unwrap();
144    /// let res = client.get("/simple.pl").dispatch();
145    /// assert_eq!(res.status(), Status::Ok);
146    /// let res = client.get("/simple.sh").dispatch();
147    /// assert_eq!(res.status(), Status::InternalServerError);
148    /// ```
149    pub fn only_perl(mut self) -> Self {
150        self.file_types.retain(|s, _| s == "pl");
151        self.settings.set_direct_executable(false);
152        self
153    }
154
155    /// Only allow executing python scripts. Disables all filetypes except `.py`
156    ///
157    /// ```rust
158    /// # use rocket::{build, http::Status};
159    /// # use rocket::local::blocking::Client;
160    /// # use rocket_cgi::CGIDir;
161    /// let rocket = build().mount("/", CGIDir::new("test").only_python().detect_python3());
162    /// let client = Client::tracked(rocket).unwrap();
163    /// let res = client.get("/simple.py").dispatch();
164    /// assert_eq!(res.status(), Status::Ok);
165    /// let res = client.get("/simple.sh").dispatch();
166    /// assert_eq!(res.status(), Status::InternalServerError);
167    /// ```
168    pub fn only_python(mut self) -> Self {
169        self.file_types.retain(|s, _| s == "py");
170        self.settings.set_direct_executable(false);
171        self
172    }
173
174    /// Automatically detect python executables. This should allow either `python` or `python3` to
175    /// be present on the system
176    ///
177    /// ```rust
178    /// # use rocket::{build, http::Status};
179    /// # use rocket::local::blocking::Client;
180    /// # use rocket_cgi::CGIDir;
181    /// let rocket = build().mount("/", CGIDir::new("test").detect_python3());
182    /// let client = Client::tracked(rocket).unwrap();
183    /// let res = client.get("/simple.py").dispatch();
184    /// assert_eq!(res.status(), Status::Ok);
185    /// ```
186    ///
187    /// # Panics
188    ///
189    /// If python cannot be found on the current Path. If a version of python not on the path is
190    /// desired, it is recommended to explicitly set the path e.g.
191    /// `.set_file_type("py", Path::new("/opt/py/bin/python"))`
192    pub fn detect_python3(self) -> Self {
193        use std::process::Command;
194        match Command::new("python3").arg("-V").spawn() {
195            Ok(_) => return self.set_file_type("py", Path::new("python3")),
196            _ => (),
197        }
198        match Command::new("python").arg("-V").spawn() {
199            Ok(c) => {
200                let mut s = String::new();
201                let _ = c.stdout.unwrap().read_to_string(&mut s);
202                if s.starts_with("Python 3") {
203                    return self.set_file_type("py", Path::new("python"));
204                }
205            }
206            _ => (),
207        }
208        panic!("Python 3 not found")
209    }
210
211    /// Automatically detect python executables. This should allow either `python` or `python2` to
212    /// be present on the system
213    ///
214    /// ```rust
215    /// # use rocket::{build, http::Status};
216    /// # use rocket::local::blocking::Client;
217    /// # use rocket_cgi::CGIDir;
218    /// let rocket = build().mount("/", CGIDir::new("test").detect_python2());
219    /// let client = Client::tracked(rocket).unwrap();
220    /// let res = client.get("/simple.py").dispatch();
221    /// assert_eq!(res.status(), Status::Ok);
222    /// ```
223    ///
224    /// # Panics
225    ///
226    /// If python cannot be found on the current Path. If a version of python not on the path is
227    /// desired, it is recommended to explicitly set the path e.g.
228    /// `.set_file_type("py", Path::new("/opt/py/bin/python"))`
229    pub fn detect_python2(self) -> Self {
230        use std::process::Command;
231        match Command::new("python2").arg("-V").spawn() {
232            Ok(_) => return self.set_file_type("py", Path::new("python3")),
233            _ => (),
234        }
235        match Command::new("python").arg("-V").spawn() {
236            Ok(c) => {
237                let mut s = String::new();
238                let _ = c.stdout.unwrap().read_to_string(&mut s);
239                if s.starts_with("Python 2") {
240                    return self.set_file_type("py", Path::new("python"));
241                }
242            }
243            _ => (),
244        }
245        panic!("Python 2 not found")
246    }
247
248    /// Only allow executing python scripts. Disables all filetypes except `.sh`
249    ///
250    /// ```rust
251    /// # use rocket::{build, http::Status};
252    /// # use rocket::local::blocking::Client;
253    /// # use rocket_cgi::CGIDir;
254    /// let rocket = build().mount("/", CGIDir::new("test").only_sh());
255    /// let client = Client::tracked(rocket).unwrap();
256    /// let res = client.get("/simple.sh").dispatch();
257    /// assert_eq!(res.status(), Status::Ok);
258    /// let res = client.get("/simple.py").dispatch();
259    /// assert_eq!(res.status(), Status::InternalServerError);
260    /// ```
261    pub fn only_sh(mut self) -> Self {
262        self.file_types.retain(|s, _| s == "sh");
263        self.settings.set_direct_executable(false);
264        self
265    }
266
267    /// Sets the shell interpreter. Implicitly enables `.sh` files if they are currently disabled
268    ///
269    /// Default is `sh`
270    ///
271    /// ```rust
272    /// # use rocket::{build, http::Status};
273    /// # use rocket::local::blocking::Client;
274    /// # use rocket_cgi::CGIDir;
275    /// # use std::path::Path;
276    /// let rocket = build().mount("/", CGIDir::new("test").shell_interpreter(Path::new("bash")));
277    /// let client = Client::tracked(rocket).unwrap();
278    /// let res = client.get("/shell.sh").dispatch();
279    /// assert_eq!(res.status(), Status::Ok);
280    /// assert_eq!(res.into_string().unwrap(), "bash\n");
281    /// ```
282    pub fn shell_interpreter(mut self, executable: impl Into<Cow<'static, Path>>) -> Self {
283        self.file_types.insert("sh".into(), executable.into());
284        self
285    }
286
287    /// Adds default Windows Shell Script types:
288    /// - cmd.exe: .cmd, .bat
289    /// - powershell.exe: .ps1
290    /// - cscript.exe: .wsf, .vbs, .js
291    ///
292    /// ```rust
293    /// # use rocket::{build, http::Status};
294    /// # use rocket::local::blocking::Client;
295    /// # use rocket_cgi::CGIDir;
296    /// # use std::path::Path;
297    /// # #[cfg(windows)]
298    /// # fn main() {
299    /// let rocket = build().mount("/", CGIDir::new("test").add_windows_scripts());
300    /// let client = Client::tracked(rocket).unwrap();
301    /// let res = client.get("/simple.cmd").dispatch();
302    /// assert_eq!(res.status(), Status::Ok);
303    /// # }
304    /// # // Empty main to allow testing on non-windows platforms
305    /// # #[cfg(not(windows))] fn main() {}
306    /// ```
307    pub fn add_windows_scripts(mut self) -> Self {
308        [
309            ("cmd", "cmd.exe"),
310            ("bat", "cmd.exe"),
311            ("ps1", "powershell.exe"),
312            ("wsf", "csript.exe"),
313            ("vbs", "csript.exe"),
314            ("js", "csript.exe"), // ?
315        ]
316        .iter()
317        .for_each(|&(a, b)| {
318            self.file_types.insert(a.into(), Path::new(b).into());
319        });
320        self
321    }
322
323    /// Whether to allow directly executable files. This may allow scripts with execute
324    /// permissions and a shebang (`#!`) to be executed, on some systems.
325    ///
326    /// Defaults to true
327    pub fn direct_executables(mut self, allow: bool) -> Self {
328        self.settings.set_direct_executable(allow);
329        self
330    }
331
332    /// Whether to pass parameters that contain unencoded `=`
333    ///
334    /// The CGI spec requires this to be false, which is the default
335    pub fn unencoded_equals(mut self, allow: bool) -> Self {
336        self.settings.set_unencoded_equals(allow);
337        self
338    }
339
340    /// Whether to allow serving unix hidden files (files starting with a `.`)
341    ///
342    /// Defaults to false
343    ///
344    /// ```rust
345    /// # use rocket::{build, http::Status};
346    /// # use rocket::local::blocking::Client;
347    /// # use rocket_cgi::CGIDir;
348    /// let rocket = build().mount("/", CGIDir::new("test").dot_files(false));
349    /// let client = Client::tracked(rocket).unwrap();
350    /// let res = client.get("/.simple.sh").dispatch();
351    /// assert_eq!(res.status(), Status::NotFound);
352    /// ```
353    pub fn dot_files(mut self, allow: bool) -> Self {
354        self.settings.set_dot_files(allow);
355        self
356    }
357
358    /// Whether to allow serving hidden files
359    ///
360    /// Defaults to false, only applies to Windows
361    pub fn hidden_files(mut self, allow: bool) -> Self {
362        self.settings.set_hidden_files(allow);
363        self
364    }
365
    // This is commented out because the ability to detect write permissions is not currently
    // complete. This should likely be checked by attempting to open the file for writing
    // (although this may not be sufficient).
369    // /// Whether to allow serving writable files
370    // ///
371    // /// Defaults to true
372    // pub fn writable_files(mut self, allow: bool) -> Self {
373    //     self.settings.set_writable_files(allow);
374    //     self
375    // }
376
377    /// Whether to allow serving files with setuid & setgid bits set
378    ///
379    /// Defaults to false, only has an effect on Unix systems. Note this does not prevent a script
380    /// from executing a setuid bit binary, but rather only protects against Rocket starting a
381    /// setuid binary
382    pub fn setuid_files(mut self, allow: bool) -> Self {
383        self.settings.set_setuid(allow);
384        self
385    }
386
387    /// Whether to serve GET & HEAD requests
388    ///
389    /// Defaults to true
390    ///
391    /// ```rust
392    /// # use rocket::{build, http::Status};
393    /// # use rocket::local::blocking::Client;
394    /// # use rocket_cgi::CGIDir;
395    /// let rocket = build().mount("/", CGIDir::new("test").serve_get(false));
396    /// let client = Client::tracked(rocket).unwrap();
397    /// let res = client.get("/.simple.sh").dispatch();
398    /// assert_eq!(res.status(), Status::NotFound);
399    /// ```
400    pub fn serve_get(mut self, allow: bool) -> Self {
401        self.settings.set_allow_get(allow);
402        self
403    }
404
405    /// Whether to serve POST requests
406    ///
407    /// Defaults to true
408    ///
409    /// ```rust
410    /// # use rocket::{build, http::Status};
411    /// # use rocket::local::blocking::Client;
412    /// # use rocket_cgi::CGIDir;
413    /// let rocket = build().mount("/", CGIDir::new("test").serve_post(false));
414    /// let client = Client::tracked(rocket).unwrap();
415    /// let res = client.post("/.simple.sh").dispatch();
416    /// assert_eq!(res.status(), Status::NotFound);
417    /// ```
418    pub fn serve_post(mut self, allow: bool) -> Self {
419        self.settings.set_allow_post(allow);
420        self
421    }
422
    /// Resolve the request URI to a file under `self.path`, apply the configured file
    /// policies, and spawn the script via `build_process`.
    ///
    /// Errors of kind `NotFound` make `handle` forward the request; any other error kind is
    /// turned into a 500 by `handle`.
    async fn locate_file<'r>(&self, r: &'r Request<'_>) -> io::Result<Child> {
        let mut path = self.path.to_path_buf();
        // Mount prefix of the matched route: its URI with the `PATH_DEF` template trimmed off.
        // `route()` is assumed to be `Some` because this runs from `Handler::handle`.
        let prefix = r.route().unwrap().uri.as_str().trim_end_matches(PATH_DEF);
        let uri_path = r.uri().path();
        let decoded = uri_path
            .strip_prefix(prefix)
            .unwrap_or(&uri_path) // This shouldn't happen, since the URL matched
            .percent_decode()
            .map_err(|_| Error::new(ErrorKind::InvalidInput, "URL is not valid UTF-8"))?;
        let trailing = decoded.trim_start_matches(|c| c == '/' || c == '\\');
        // Trimming every leading `/` or `\` tolerates repeated separators and keeps the
        // remainder relative, so `push` below appends to the root instead of replacing it.
        let trailing_path = Path::new(trailing);

        if !self.settings.dot_files() && has_dot_file(trailing_path) {
            return Err(io::Error::new(
                ErrorKind::NotFound,
                "Hidden files not permitted",
            ));
        }

        // Anything still absolute after the trim (e.g. a path with a prefix) is rejected.
        if !trailing_path.is_relative() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Absolute paths not permitted",
            ));
        }
        path.push(trailing_path);
        // Sadly this allocates, but I don't think there's a way around it
        // Canonicalizing resolves `..` and symlinks; a missing file fails here, and the
        // resulting io error is propagated unchanged.
        let path = tokio::fs::canonicalize(path).await?;
        // Reject anything that resolved outside the served root (path traversal).
        if !path.starts_with(&self.path) {
            // error_!("Path: {}", path.display());
            return Err(Error::new(
                ErrorKind::NotFound,
                "Files outside directory not permitted",
            ));
        }

        debug_!("Path: {}", path.display());
        let meta = tokio::fs::metadata(&path).await?;
        debug_!("meta: {:?}", meta);
        if !self.settings.setuid() && has_setuid(&meta) {
            return Err(io::Error::new(
                ErrorKind::NotFound,
                "Setuid files not permitted",
            ));
        }

        // `writable_files` is enabled in `CGIDir::new`, so this check is skipped by default.
        if !self.settings.writable_files() && is_writable(&meta) {
            return Err(io::Error::new(
                ErrorKind::Other,
                "Writable files not permitted",
            ));
        }

        if !allowed(&meta) {
            return Err(io::Error::new(ErrorKind::Other, "File not permitted"));
        }

        // Directory requests produce a 500 (kind `Other`) rather than a forward.
        if meta.is_dir() {
            // path.push("index.pl"); // ?
            return Err(io::Error::new(
                ErrorKind::Other,
                "Directories not supported",
            ));
        }

        // The relative request path doubles as SCRIPT_NAME.
        self.build_process(path, trailing, r)
    }
491
492    fn build_process(&self, path: PathBuf, name: &str, r: &Request<'_>) -> io::Result<Child> {
493        let mut builder = if let Some(exe) = path
494            .extension()
495            .and_then(|e| e.to_str())
496            .and_then(|e| self.file_types.get(e))
497        {
498            let mut ret = Command::new(exe.as_os_str());
499            ret.arg(path);
500            ret
501        } else if self.settings.direct_executable() {
502            Command::new(path.as_os_str())
503        } else {
504            return Err(io::Error::new(
505                ErrorKind::Other,
506                "Direct executables not permitted",
507            ));
508        };
509        builder.env_clear();
510
511        if let Some(query) = r.uri().query() {
512            builder.env("QUERY_STRING", query.as_str());
513            if self.settings.unencoded_equals() || !query.as_str().contains('=') {
514                for part in query.split('+') {
515                    if let Ok(decoded) = part.url_decode() {
516                        builder.arg(decoded.as_ref());
517                    }
518                }
519            }
520        }
521        builder.env("AUTH_TYPE", "");
522        // We allow this to be empty (e.g. Transfer-Encoding: chunked), and don't set it if we
523        // don't know. The Spec technically requires it to be set, but we ignore that
524        if let Some(len) = r.headers().get_one("Content-Length") {
525            builder.env("CONTENT_LENGTH", len);
526        }
527        if let Some(ty) = r.content_type() {
528            builder.env("CONTENT_TYPE", ty.to_string());
529        }
530        builder.env("GATEWAY_INTERFACE", "CGI/1.1");
531
532        // We don't support sub-resources
533        // builder.env("PATH_INFO", "");
534        // builder.env("PATH_TRANSLATED", "");
535
536        if let Some(ip) = r.remote() {
537            builder.env("REMOTE_ADDR", format!("{ip}"));
538        }
539        if let Some(host) = r.host() {
540            builder.env("REMOTE_HOST", format!("{host}"));
541        }
542        builder.env("REQUEST_METHOD", r.method().as_str());
543        builder.env("SCRIPT_NAME", name);
544        builder.env("SERVER_NAME", r.rocket().config().address.to_string());
545        builder.env("SERVER_PORT", r.rocket().config().port.to_string());
546        builder.env("SERVER_PROTOCOL", "HTTP/1.1");
547        builder.env("SERVER_SOFTWARE", r.rocket().config().ident.to_string());
548
549        builder.stdin(Stdio::piped());
550        builder.stdout(Stdio::piped());
551        builder.kill_on_drop(true);
552
553        info_!("Command: {:?}", builder);
554        builder.spawn()
555    }
556
    /// Read the CGI response headers from the script's stdout and build a Rocket response.
    ///
    /// Lines are read until the first blank line (or EOF) and are interpreted as follows:
    /// `Content-Type` is parsed flexibly, an absolute `Location` becomes a `Redirect::to`
    /// response and discards the body, a `/`-relative (local) `Location` is rejected with a
    /// 500, and `Status` takes its leading number as the status code. Other headers are logged
    /// and dropped; lines without `:` are ignored.
    ///
    /// Returns the outcome, the stdout reader positioned at the start of the body, and the
    /// child. The caller attaches the body to the response.
    async fn generate_response<'r>(
        mut process: Child,
        request: &'r Request<'_>,
    ) -> (Outcome<'r>, impl AsyncRead + 'static, Child) {
        // stdout is always piped by `build_process`, so it is present here.
        let mut stdout = BufReader::new(process.stdout.take().unwrap());
        let mut res = Response::new();
        let mut buf = String::new();

        res.set_status(Status::Ok);

        // Cleared by a redirect: the script is killed once the headers have been read.
        let mut has_body = true;
        loop {
            match stdout.read_line(&mut buf).await {
                Ok(_) => (),
                Err(_) => {
                    return (
                        Outcome::Failure(Status::InternalServerError),
                        stdout,
                        process,
                    )
                }
            }
            // At EOF nothing is appended, so `buf` stays empty and the loop ends below.
            let line = buf.trim();
            if line == "" {
                break;
            }
            if let Some((key, value)) = line.split_once(':') {
                let key = key.trim();
                let value = value.trim();
                if uncased::eq(key, "Content-Type") {
                    if let Some(content_type) = ContentType::parse_flexible(value) {
                        res.set_header(content_type);
                    }
                } else if uncased::eq(key, "Location") {
                    if value.starts_with("/") {
                        error_!("`local-Location` is not supported");
                        return (
                            Outcome::Failure(Status::InternalServerError),
                            stdout,
                            process,
                        );
                    } else {
                        has_body = false;
                        // Merge the redirect's status and `Location` header into `res`.
                        match Redirect::to(value.to_owned()).respond_to(request) {
                            Ok(r) => res.merge(r),
                            Err(e) => {
                                res.set_status(e);
                                let _ = process.kill().await;
                                return (Outcome::Success(res), stdout, process);
                            }
                        }
                    }
                } else if uncased::eq(key, "Status") {
                    // Accept both `404` and `404 Not Found`; an unparsable value is ignored.
                    if let Ok(code) = value
                        .split_once(char::is_whitespace)
                        .map_or(value, |(n, _)| n)
                        .parse::<u16>()
                    {
                        res.set_status(Status { code });
                    }
                } else {
                    error_!("Extension header `{key}` is not supported");
                    // Unknown headers are ignored, and not sent to the client
                }
            }
            buf.clear();
        }
        if !has_body {
            let _ = process.kill().await;
        }
        // res.set_streamed_body(stdout);
        return (Outcome::Success(res), stdout, process);
    }
630}
631
632struct Paired<R, F> {
633    child: Option<Child>,
634    reader: R,
635    future: Option<F>,
636}
637
638impl<R, F> Paired<R, F> {
639    fn new(child: Child, reader: R, future: F) -> Self {
640        Self {
641            child: Some(child),
642            reader,
643            future: Some(future),
644        }
645    }
646}
647
648impl<R: AsyncRead, F: Future> AsyncRead for Paired<R, F> {
649    fn poll_read(
650        self: std::pin::Pin<&mut Self>,
651        cx: &mut std::task::Context<'_>,
652        buf: &mut tokio::io::ReadBuf<'_>,
653    ) -> std::task::Poll<io::Result<()>> {
654        // SAFETY: We immediately repin interior references, except self.child
655        let this = unsafe { self.get_unchecked_mut() };
656        if let Some(future) = this.future.as_mut() {
657            // This is a repin, and therefore meets the requirements of pinning
658            match unsafe { Pin::new_unchecked(future) }.poll(cx) {
659                Poll::Pending => (),
660                Poll::Ready(_) => this.future = None,
661            }
662        }
663        // This is a repin, and therefore meets the requirements of pinning
664        match unsafe { Pin::new_unchecked(&mut this.reader) }.poll_read(cx, buf) {
665            Poll::Pending => return Poll::Pending,
666            Poll::Ready(res) => {
667                // Pinning the child is not structural
668                if let Some(mut child) = this.child.take() {
669                    tokio::spawn(async move { child.kill().await });
670                }
671                return Poll::Ready(res);
672            }
673        }
674    }
675}
676
#[rocket::async_trait]
impl Handler for CGIDir {
    /// Locate and spawn the requested script, then answer according to the request method.
    ///
    /// `NotFound` errors from `locate_file` forward the request; any other error is a 500.
    async fn handle<'r>(&self, request: &'r Request<'_>, data: Data<'r>) -> Outcome<'r> {
        let mut process = match self.locate_file(request).await {
            Ok(p) => p,
            Err(e) if e.kind() == ErrorKind::NotFound => return Outcome::Forward(data),
            Err(e) => {
                error_!("Error: {e}");
                return Outcome::Failure(Status::InternalServerError);
            }
        };
        // stdin is always piped by `build_process`, so it is present here.
        let mut body = process.stdin.take().unwrap();

        // Maximum POST body forwarded to the script: the `cgi` limit, or 1 MiB if unset.
        let limit = request
            .rocket()
            .config()
            .limits
            .find(["cgi"])
            .unwrap_or(1.mebibytes());

        // Created once here; awaited directly for GET/HEAD, or pinned and raced against the
        // POST body writer below.
        let generate_response = Self::generate_response(process, request);

        if request.method() == Method::Head {
            // Close stdin, read only the headers, and stop the script: HEAD has no body.
            drop(body);
            let (res, _, mut process) = generate_response.await;
            let _ = process.kill().await;
            res
        } else if request.method() == Method::Get {
            // Close stdin so the script sees EOF, then stream its stdout as the body.
            drop(body);
            let (res, stdout, process) = generate_response.await;
            res.map(|mut res| {
                res.set_streamed_body(Paired::new(process, stdout, std::future::ready(())));
                res
            })
        } else if request.method() == Method::Post {
            let ensure_newline = self.settings.ensure_newline();
            // Copy the request body into the script's stdin; stdin is closed when this
            // future completes and drops `body`.
            let mut write_post_data = Box::pin(async move {
                let _ = tokio::io::copy(&mut data.open(limit), &mut body).await;
                if ensure_newline {
                    let _ = body.write_all(b"\n").await;
                }
            });
            tokio::pin!(generate_response);
            // Race the body writer against header parsing, so a script that starts answering
            // before it has read all of its input is not blocked on the writer.
            tokio::select! {
                // Ideally we want to take the first path, so we should always try the first one
                biased;
                _ = &mut write_post_data => {
                    let (res, stdout, process) = generate_response.await;
                    res.map(|mut res| { res.set_streamed_body(Paired::new(process, stdout, ready(()))); res })
                },
                res = &mut generate_response => {
                    // Headers arrived first: keep writing the body while the output streams.
                    let (res, stdout, process) = res;
                    res.map(|mut res| { res.set_streamed_body(Paired::new(process, stdout, write_post_data)); res })
                }
            }
        } else {
            // Only GET, HEAD and POST routes are mounted (see the `Vec<Route>` conversion).
            unreachable!("Only Get, Head & Post supported")
        }
    }
}
737
738impl Into<Vec<Route>> for CGIDir {
739    fn into(self) -> Vec<Route> {
740        let mut ret = Vec::with_capacity(3);
741        if self.settings.allow_get() {
742            ret.push(Route::ranked(9, Method::Get, PATH_DEF, self.clone()));
743            ret.push(Route::ranked(9, Method::Head, PATH_DEF, self.clone()));
744        }
745        if self.settings.allow_post() {
746            ret.push(Route::ranked(9, Method::Post, PATH_DEF, self.clone()));
747        }
748        ret
749    }
750}
751
#[cfg(test)]
mod tests {
    use rocket::local::asynchronous::Client;

    use super::*;

    // Build a tracked client with a `CGIDir` for the crate's `test` directory mounted at `/`.
    async fn generate_client() -> Client {
        let dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let rocket = rocket::build().mount("/", CGIDir::new(format!("{dir}/test")));
        Client::tracked(rocket).await.unwrap()
    }

    // A plain script: status, content type and body all come from the script's output.
    #[rocket::async_test]
    async fn simple_script() {
        let client = generate_client().await;
        let res = client.get("/simple.sh").dispatch().await;
        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::Text));
        assert_eq!(res.into_string().await.unwrap(), "simple output\n");
    }

    // `redirect.sh` is expected to produce a 303 See Other pointing at `simple.sh`.
    #[rocket::async_test]
    async fn redirect() {
        let client = generate_client().await;
        let res = client.get("/redirect.sh").dispatch().await;
        assert_eq!(res.status(), Status::SeeOther);
        assert_eq!(
            res.headers().get_one("Location").unwrap(),
            "http://localhost:8000/simple.sh"
        );
    }

    // The query becomes command-line arguments only when it contains no unencoded `=`.
    #[rocket::async_test]
    async fn params() {
        let client = generate_client().await;
        let res = client.get("/params.sh?world").dispatch().await;
        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::Text));
        assert_eq!(res.into_string().await.unwrap(), "Hello 'world'!\n");

        // Unencoded equals
        let res = client.get("/params.sh?world=hello").dispatch().await;
        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::Text));
        assert_eq!(res.into_string().await.unwrap(), "Hello ''!\n");

        // Encoded equals
        let res = client.get("/params.sh?world%3dhello").dispatch().await;
        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::Text));
        assert_eq!(res.into_string().await.unwrap(), "Hello 'world=hello'!\n");
    }

    // Each request passes a meta-variable name as the query; `env_vars.sh` presumably
    // prints that variable's value. Unset variables are expected to print as empty.
    #[rocket::async_test]
    async fn env_vars() {
        let client = generate_client().await;
        macro_rules! var {
            ($var:literal, $val:literal) => {{
                let res = client.get(concat!("/env_vars.sh?", $var)).dispatch().await;
                assert_eq!(res.status(), Status::Ok);
                assert_eq!(res.content_type(), Some(ContentType::Text));
                assert_eq!(res.into_string().await.unwrap().trim(), $val);
            }};
        }

        var!("AUTH_TYPE", "");
        var!("CONTENT_LENGTH", "");
        var!("CONTENT_TYPE", "");
        var!("GATEWAY_INTERFACE", "CGI/1.1");
        var!("PATH_INFO", "");
        var!("PATH_TRANSLATED", "");
        var!("QUERY_STRING", "QUERY_STRING");
        var!("REMOTE_IDENT", "");
        var!("REMOTE_USER", "");
        var!("REQUEST_METHOD", "GET");
        var!("SCRIPT_NAME", "env_vars.sh");
        var!("SERVER_NAME", "127.0.0.1");
        var!("SERVER_PORT", "8000");
        var!("SERVER_PROTOCOL", "HTTP/1.1");
        var!("SERVER_SOFTWARE", "Rocket");
    }

    // The POST body is written to the script's stdin; `post.sh` presumably echoes it back.
    #[rocket::async_test]
    async fn post_body() {
        let client = generate_client().await;
        let res = client.post("/post.sh").body("something").dispatch().await;
        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::Text));
        assert_eq!(res.into_string().await.unwrap(), "val: something\n");
    }
}