1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*
 * This file is part of Actix Form Data.
 *
 * Copyright © 2020 Riley Trautman
 *
 * Actix Form Data is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Actix Form Data is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Actix Form Data.  If not, see <http://www.gnu.org/licenses/>.
 */

use crate::{
    types::{Form, Value},
    upload::handle_multipart,
};
use actix_web::{
    dev::{Payload, Service, ServiceRequest, Transform},
    http::StatusCode,
    FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
};
use futures_util::future::LocalBoxFuture;
use std::{
    future::{ready, Ready},
    task::{Context, Poll},
};
use tokio::sync::oneshot::{channel, Receiver};

/// Errors returned by the `FromRequest` implementation for [`Value`].
///
/// Both variants are reported to the client as an internal server error by the
/// `ResponseError` implementation on this type.
#[derive(Debug, thiserror::Error)]
pub enum FromRequestError {
    /// No `Uploaded` entry was present in the request extensions when the
    /// extractor ran.
    #[error("Uploaded guard used without Multipart middleware")]
    MissingMiddleware,
    /// The receiving half of the channel failed, i.e. the sender was dropped
    /// without sending a value.
    #[error("Impossible Error! Middleware exists, didn't fail, and didn't send value")]
    TxDropped,
}

/// Every extraction failure is surfaced as `500 Internal Server Error`.
impl ResponseError for FromRequestError {
    /// Status code associated with this error.
    fn status_code(&self) -> StatusCode {
        // One arm per variant: adding a variant will not compile until a
        // status code has been chosen for it.
        match *self {
            Self::MissingMiddleware => StatusCode::INTERNAL_SERVER_ERROR,
            Self::TxDropped => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Response sent to the client for this error, built with `finish()`
    /// (no explicit body is supplied).
    fn error_response(&self) -> HttpResponse {
        match *self {
            Self::MissingMiddleware => HttpResponse::InternalServerError().finish(),
            Self::TxDropped => HttpResponse::InternalServerError().finish(),
        }
    }
}

/// Entry stored in the request extensions by `MultipartMiddleware` and taken
/// back out by the `FromRequest` implementation for `Value<T>`.
struct Uploaded<T> {
    /// Receiving half of the oneshot channel on which the middleware sends the
    /// value produced by `handle_multipart`.
    rx: Receiver<Value<T>>,
}

/// Service wrapper created by the `Transform` implementation on `Form`.
///
/// For each request it runs `handle_multipart` on the request payload and only
/// then awaits the wrapped service's response.
pub struct MultipartMiddleware<S, T, E> {
    /// Form passed (as a clone) to `handle_multipart` for every request.
    form: Form<T, E>,
    /// The wrapped inner service.
    service: S,
}

impl<T> FromRequest for Value<T>
where
    T: 'static,
{
    type Error = FromRequestError;
    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;

    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        let opt = req.extensions_mut().remove::<Uploaded<T>>();
        Box::pin(async move {
            let fut = opt.ok_or(FromRequestError::MissingMiddleware)?;

            fut.rx.await.map_err(|_| FromRequestError::TxDropped)
        })
    }
}

/// Lets a `Form` be used as middleware: wrapping a service produces a
/// `MultipartMiddleware` holding a clone of the form.
impl<S, T, E> Transform<S, ServiceRequest> for Form<T, E>
where
    S: Service<ServiceRequest, Error = actix_web::Error>,
    S::Future: 'static,
    T: 'static,
    E: ResponseError + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type InitError = ();
    type Transform = MultipartMiddleware<S, T, E>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a `MultipartMiddleware`; the returned future is
    /// immediately ready and always `Ok`.
    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = MultipartMiddleware {
            service,
            form: self.clone(),
        };

        ready(Ok(middleware))
    }
}

/// Runs multipart handling for a request, then hands over to the wrapped
/// service.
impl<S, T, E> Service<ServiceRequest> for MultipartMiddleware<S, T, E>
where
    S: Service<ServiceRequest, Error = actix_web::Error>,
    S::Future: 'static,
    T: 'static,
    E: ResponseError + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = LocalBoxFuture<'static, Result<S::Response, S::Error>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = &self.service;
        inner.poll_ready(cx)
    }

    /// Processes the request payload with `handle_multipart`, sends the result
    /// to the `Uploaded` receiver stored in the request extensions, and then
    /// awaits the wrapped service's response.
    ///
    /// # Errors
    ///
    /// * An inner `Err` from `handle_multipart` is converted with `into()`.
    /// * An outer `Err` from `handle_multipart` is passed through the form's
    ///   `transform_error` callback when one is set, otherwise converted with
    ///   `into()`.
    /// * Otherwise, whatever the wrapped service's future resolves to.
    fn call(&self, mut req: ServiceRequest) -> Self::Future {
        let form = self.form.clone();
        let (sender, receiver) = channel();

        // Store the receiving half in the extensions before the request is
        // handed to the wrapped service.
        req.extensions_mut().insert(Uploaded { rx: receiver });

        let body = req.take_payload();
        let multipart = actix_multipart::Multipart::new(req.headers(), body);

        // The inner call is created here but only awaited after the multipart
        // handling below has succeeded.
        let response = self.service.call(req);

        Box::pin(async move {
            let value = match handle_multipart(multipart, form.clone()).await {
                Ok(Ok(value)) => value,
                Ok(Err(form_err)) => return Err(form_err.into()),
                Err(upload_err) => {
                    return Err(match form.transform_error.clone() {
                        Some(transform) => transform(upload_err),
                        None => upload_err.into(),
                    });
                }
            };

            // A failed send is deliberately ignored.
            let _ = sender.send(value);

            response.await
        })
    }
}