rspamd_client/backend/
async_client.rs

1use crate::backend::traits::*;
2use crate::config::{Config, EnvelopeData};
3use crate::error::RspamdError;
4use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
5use crate::protocol::encryption::{httpcrypt_decrypt, httpcrypt_encrypt, make_key_header};
6use crate::protocol::RspamdScanReply;
7use bytes::{Bytes, BytesMut};
8use reqwest::header::{HeaderName, HeaderValue};
9use reqwest::Client;
10use std::collections::HashMap;
11use std::str::FromStr;
12use std::time::Duration;
13use url::Url;
14use zstd::zstd_safe::WriteBuf;
15
16pub struct AsyncClient<'a> {
17    config: &'a Config,
18    inner: Client,
19}
20
/// Builds an [`AsyncClient`] from `options`.
///
/// The resulting client applies the configured request timeout, routes all
/// schemes through `proxy_config` when set, and trusts an additional PEM root
/// certificate read from `tls_settings.ca_path` when set.
///
/// # Errors
///
/// * [`RspamdError::ConfigError`] if `timeout` is not a valid duration
///   (negative, NaN or too large), or if the CA file cannot be resolved or read.
/// * [`RspamdError::HttpError`] if the proxy URL or the certificate is invalid,
///   or the underlying client cannot be built.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient<'_>, RspamdError> {
    // Validate instead of using `from_secs_f64`, which panics on bad input.
    // This is user configuration, so it should surface as an error.
    let timeout = Duration::try_from_secs_f64(options.timeout)
        .map_err(|e| RspamdError::ConfigError(format!("invalid timeout: {e}")))?;
    let mut builder = Client::builder().timeout(timeout);

    if let Some(ref proxy) = options.proxy_config {
        let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.proxy(proxy);
    }

    if let Some(ca_path) = options
        .tls_settings
        .as_ref()
        .and_then(|tls| tls.ca_path.as_ref())
    {
        // A missing or unreadable CA path is a configuration mistake: report it
        // rather than panicking.
        let ca_path = std::fs::canonicalize(ca_path.as_str())
            .map_err(|e| RspamdError::ConfigError(e.to_string()))?;
        let pem = std::fs::read(ca_path).map_err(|e| RspamdError::ConfigError(e.to_string()))?;
        let cert = reqwest::Certificate::from_pem(&pem)
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.add_root_certificate(cert);
    }

    let inner = builder
        .build()
        .map_err(|e| RspamdError::HttpError(e.to_string()))?;

    Ok(AsyncClient {
        inner,
        config: options,
    })
}
55
/// Temporary structure for making a single request to an Rspamd endpoint.
///
/// Owns the client and the message body; it is consumed by
/// [`Request::response`], which performs the HTTP exchange.
pub struct ReqwestRequest<'a, B> {
    /// Target endpoint: provides the URL path and whether the command needs a body.
    endpoint: RspamdEndpoint<'a>,
    /// Client (and its borrowed config) used to send the request.
    client: AsyncClient<'a>,
    /// Raw message bytes; sent only when the endpoint needs a body and no
    /// `File` header is present in the envelope data.
    body: B,
    /// Envelope data turned into extra request headers; `Some` until
    /// `response` takes it.
    envelope_data: Option<EnvelopeData>,
}
63
64#[maybe_async::maybe_async]
65impl<'a, B: AsRef<[u8]> + Send> Request for ReqwestRequest<'a, B> {
66    type Body = Bytes;
67    type HeaderMap = reqwest::header::HeaderMap;
68
69    async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
70        let mut retry_cnt = self.client.config.retries;
71        let mut maybe_sk = Default::default();
72        let extra_hdrs: HashMap<String, String> =
73            HashMap::from_iter(self.envelope_data.take().unwrap());
74
75        let response = loop {
76            // Check if File header is present - if so, we don't need to send the body
77            let has_file_header = extra_hdrs.contains_key("File");
78            let need_body = self.endpoint.need_body && !has_file_header;
79            let method = if need_body {
80                reqwest::Method::POST
81            } else {
82                reqwest::Method::GET
83            };
84
85            let mut url = Url::from_str(self.client.config.base_url.as_str())
86                .map_err(|e| RspamdError::HttpError(e.to_string()))?;
87            url.set_path(self.endpoint.url);
88            let mut req = self.client.inner.request(method, url.clone());
89
90            if let Some(ref password) = self.client.config.password {
91                req = req.header("Password", password);
92            }
93
94            if self.client.config.zstd && need_body {
95                req = req.header("Content-Encoding", "zstd");
96                req = req.header("Compression", "zstd");
97            }
98
99            for (k, v) in extra_hdrs.iter() {
100                req = req.header(k, v);
101            }
102
103            if let Some(ref encryption_key) = self.client.config.encryption_key {
104                let inner_req = req
105                    .build()
106                    .map_err(|e| RspamdError::HttpError(e.to_string()))?;
107                let body = if need_body {
108                    if self.client.config.zstd {
109                        zstd::encode_all(self.body.as_ref(), 0)?
110                    } else {
111                        self.body.as_ref().to_vec()
112                    }
113                } else {
114                    Vec::new()
115                };
116                let encrypted = httpcrypt_encrypt(
117                    url.path(),
118                    body.as_slice(),
119                    inner_req.headers(),
120                    encryption_key.as_bytes(),
121                )?;
122                req = self.client.inner.request(reqwest::Method::POST, url);
123                let key_header =
124                    make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
125                req = req.header("Key", key_header);
126                req = req.body(encrypted.body);
127                maybe_sk = Some(encrypted.shared_key);
128            } else if need_body {
129                req = if self.client.config.zstd {
130                    req.body(reqwest::Body::from(zstd::encode_all(
131                        self.body.as_ref(),
132                        0,
133                    )?))
134                } else {
135                    req.body(Bytes::copy_from_slice(self.body.as_ref()))
136                };
137            }
138
139            let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
140            let req = req
141                .build()
142                .map_err(|e| RspamdError::HttpError(e.to_string()))?;
143
144            match self.client.inner.execute(req).await {
145                Ok(v) => break Ok(v),
146                Err(e) => {
147                    if (retry_cnt - 1) == 0 {
148                        break Err(e);
149                    }
150                    retry_cnt -= 1;
151                    let delay = Duration::from_secs_f64(self.client.config.timeout);
152                    tokio::time::sleep(delay).await;
153                    continue;
154                }
155            };
156        }
157        .map_err(|e| RspamdError::HttpError(e.to_string()))?;
158
159        if !response.status().is_success() {
160            return Err(RspamdError::HttpError(format!(
161                "Status: {}",
162                response.status()
163            )));
164        }
165
166        if let Some(sk) = maybe_sk {
167            let mut body = BytesMut::from(
168                response
169                    .bytes()
170                    .await
171                    .map_err(|e| RspamdError::HttpError(e.to_string()))?,
172            );
173            let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
174            let mut hdrs = [httparse::EMPTY_HEADER; 64];
175            let mut parsed = httparse::Response::new(&mut hdrs);
176
177            let body_offset = parsed
178                .parse(&body.as_slice()[decrypted_offset..])
179                .map_err(|s| RspamdError::HttpError(s.to_string()))?;
180            let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
181            for hdr in parsed.headers.iter_mut() {
182                output_hdrs.insert(
183                    HeaderName::from_str(hdr.name)?,
184                    HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?,
185                );
186            }
187            let body = if output_hdrs
188                .get("Compression")
189                .is_some_and(|hv| hv == "zstd")
190            {
191                zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
192            } else {
193                body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
194            };
195            Ok((output_hdrs, body.into()))
196        } else {
197            Ok((response.headers().clone(), response.bytes().await?))
198        }
199    }
200}
201
202#[maybe_async::maybe_async]
203impl<'a, B: AsRef<[u8]> + Send> ReqwestRequest<'a, B> {
204    pub async fn new(
205        client: AsyncClient<'a>,
206        body: B,
207        command: RspamdCommand,
208        envelope_data: EnvelopeData,
209    ) -> Result<ReqwestRequest<'a, B>, RspamdError> {
210        Ok(Self {
211            endpoint: RspamdEndpoint::from_command(command),
212            client,
213            body,
214            envelope_data: Some(envelope_data),
215        })
216    }
217}
218
219/// Scan an email asynchronously, returning the parsed reply or error.
220/// Example:
221/// ```rust
222/// use rspamd_client::config::Config;
223/// use rspamd_client::scan_async;
224/// use rspamd_client::error::RspamdError;
225/// use bytes::Bytes;
226/// use std::str::FromStr;
227///
228///	#[tokio::main]
229/// async fn main() -> Result<(), RspamdError> {
230/// 	let config = Config::builder()
231/// 		.base_url("http://localhost:11333".to_string())
232/// 		.build();
233/// 	let envelope = Default::default();
234/// 	let email = "...";
235/// 	let response = scan_async(&config, email, envelope).await?;
236/// 	Ok(())
237/// }
238/// ```
239#[maybe_async::maybe_async]
240pub async fn scan_async<B: AsRef<[u8]> + Send>(
241    options: &Config,
242    body: B,
243    envelope_data: EnvelopeData,
244) -> Result<RspamdScanReply, RspamdError> {
245    let client = async_client(options)?;
246    let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
247    let (headers, body) = request
248        .response()
249        .await
250        .map_err(|e| RspamdError::HttpError(e.to_string()))?;
251
252    // Check for Message-Offset header to handle body_block feature
253    let response = if let Some(offset_header) = headers.get("Message-Offset") {
254        let offset = offset_header
255            .to_str()
256            .map_err(|e| RspamdError::HttpError(format!("Invalid Message-Offset header: {}", e)))?
257            .parse::<usize>()
258            .map_err(|e| RspamdError::HttpError(format!("Invalid Message-Offset value: {}", e)))?;
259
260        if offset < body.len() {
261            // Split body into JSON part and rewritten body part
262            let json_part = &body[..offset];
263            let body_part = &body[offset..];
264
265            let mut response = serde_json::from_slice::<RspamdScanReply>(json_part)?;
266            response.rewritten_body = Some(body_part.to_vec());
267            response
268        } else {
269            // Offset is out of bounds, parse entire body as JSON
270            serde_json::from_slice::<RspamdScanReply>(body.as_ref())?
271        }
272    } else {
273        // No Message-Offset header, parse entire body as JSON
274        serde_json::from_slice::<RspamdScanReply>(body.as_ref())?
275    };
276
277    Ok(response)
278}