rspamd_client/backend/
async_client.rs1use crate::backend::traits::*;
2use crate::config::{Config, EnvelopeData};
3use crate::error::RspamdError;
4use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
5use crate::protocol::encryption::{httpcrypt_decrypt, httpcrypt_encrypt, make_key_header};
6use crate::protocol::RspamdScanReply;
7use bytes::{Bytes, BytesMut};
8use reqwest::header::{HeaderName, HeaderValue};
9use reqwest::Client;
10use std::collections::HashMap;
11use std::str::FromStr;
12use std::time::Duration;
13use url::Url;
14use zstd::zstd_safe::WriteBuf;
15
16pub struct AsyncClient<'a> {
17 config: &'a Config,
18 inner: Client,
19}
20
/// Builds an [`AsyncClient`] from `options`.
///
/// The underlying `reqwest` client is configured with:
/// * the request timeout from `options.timeout` (seconds),
/// * an optional proxy applied to all schemes (`options.proxy_config`),
/// * an optional extra PEM root certificate read from `options.tls_settings.ca_path`.
///
/// # Errors
///
/// * [`RspamdError::HttpError`] if the proxy URL is rejected, the CA file is not valid PEM,
///   or the client cannot be built.
/// * [`RspamdError::ConfigError`] if the CA path cannot be resolved or read.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient<'_>, RspamdError> {
    let mut builder = Client::builder().timeout(Duration::from_secs_f64(options.timeout));

    if let Some(ref proxy) = options.proxy_config {
        let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.proxy(proxy);
    }

    if let Some(ca_path) = options
        .tls_settings
        .as_ref()
        .and_then(|tls| tls.ca_path.as_ref())
    {
        // A missing or unresolvable CA path is a configuration problem, not a reason to
        // panic: report both resolution and read failures as `ConfigError`.
        let pem = std::fs::canonicalize(ca_path.as_str())
            .and_then(std::fs::read)
            .map_err(|e| RspamdError::ConfigError(e.to_string()))?;
        let cert = reqwest::Certificate::from_pem(&pem)
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.add_root_certificate(cert);
    }

    let inner = builder
        .build()
        .map_err(|e| RspamdError::HttpError(e.to_string()))?;

    Ok(AsyncClient {
        config: options,
        inner,
    })
}
55
56pub struct ReqwestRequest<'a, B> {
58 endpoint: RspamdEndpoint<'a>,
59 client: AsyncClient<'a>,
60 body: B,
61 envelope_data: Option<EnvelopeData>,
62}
63
64#[maybe_async::maybe_async]
65impl<'a, B: AsRef<[u8]> + Send> Request for ReqwestRequest<'a, B> {
66 type Body = Bytes;
67 type HeaderMap = reqwest::header::HeaderMap;
68
69 async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
70 let mut retry_cnt = self.client.config.retries;
71 let mut maybe_sk = Default::default();
72 let extra_hdrs: HashMap<String, String> =
73 HashMap::from_iter(self.envelope_data.take().unwrap());
74
75 let response = loop {
76 let has_file_header = extra_hdrs.contains_key("File");
78 let need_body = self.endpoint.need_body && !has_file_header;
79 let method = if need_body {
80 reqwest::Method::POST
81 } else {
82 reqwest::Method::GET
83 };
84
85 let mut url = Url::from_str(self.client.config.base_url.as_str())
86 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
87 url.set_path(self.endpoint.url);
88 let mut req = self.client.inner.request(method, url.clone());
89
90 if let Some(ref password) = self.client.config.password {
91 req = req.header("Password", password);
92 }
93
94 if self.client.config.zstd && need_body {
95 req = req.header("Content-Encoding", "zstd");
96 req = req.header("Compression", "zstd");
97 }
98
99 for (k, v) in extra_hdrs.iter() {
100 req = req.header(k, v);
101 }
102
103 if let Some(ref encryption_key) = self.client.config.encryption_key {
104 let inner_req = req
105 .build()
106 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
107 let body = if need_body {
108 if self.client.config.zstd {
109 zstd::encode_all(self.body.as_ref(), 0)?
110 } else {
111 self.body.as_ref().to_vec()
112 }
113 } else {
114 Vec::new()
115 };
116 let encrypted = httpcrypt_encrypt(
117 url.path(),
118 body.as_slice(),
119 inner_req.headers(),
120 encryption_key.as_bytes(),
121 )?;
122 req = self.client.inner.request(reqwest::Method::POST, url);
123 let key_header =
124 make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
125 req = req.header("Key", key_header);
126 req = req.body(encrypted.body);
127 maybe_sk = Some(encrypted.shared_key);
128 } else if need_body {
129 req = if self.client.config.zstd {
130 req.body(reqwest::Body::from(zstd::encode_all(
131 self.body.as_ref(),
132 0,
133 )?))
134 } else {
135 req.body(Bytes::copy_from_slice(self.body.as_ref()))
136 };
137 }
138
139 let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
140 let req = req
141 .build()
142 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
143
144 match self.client.inner.execute(req).await {
145 Ok(v) => break Ok(v),
146 Err(e) => {
147 if (retry_cnt - 1) == 0 {
148 break Err(e);
149 }
150 retry_cnt -= 1;
151 let delay = Duration::from_secs_f64(self.client.config.timeout);
152 tokio::time::sleep(delay).await;
153 continue;
154 }
155 };
156 }
157 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
158
159 if !response.status().is_success() {
160 return Err(RspamdError::HttpError(format!(
161 "Status: {}",
162 response.status()
163 )));
164 }
165
166 if let Some(sk) = maybe_sk {
167 let mut body = BytesMut::from(
168 response
169 .bytes()
170 .await
171 .map_err(|e| RspamdError::HttpError(e.to_string()))?,
172 );
173 let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
174 let mut hdrs = [httparse::EMPTY_HEADER; 64];
175 let mut parsed = httparse::Response::new(&mut hdrs);
176
177 let body_offset = parsed
178 .parse(&body.as_slice()[decrypted_offset..])
179 .map_err(|s| RspamdError::HttpError(s.to_string()))?;
180 let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
181 for hdr in parsed.headers.iter_mut() {
182 output_hdrs.insert(
183 HeaderName::from_str(hdr.name)?,
184 HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?,
185 );
186 }
187 let body = if output_hdrs
188 .get("Compression")
189 .is_some_and(|hv| hv == "zstd")
190 {
191 zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
192 } else {
193 body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
194 };
195 Ok((output_hdrs, body.into()))
196 } else {
197 Ok((response.headers().clone(), response.bytes().await?))
198 }
199 }
200}
201
202#[maybe_async::maybe_async]
203impl<'a, B: AsRef<[u8]> + Send> ReqwestRequest<'a, B> {
204 pub async fn new(
205 client: AsyncClient<'a>,
206 body: B,
207 command: RspamdCommand,
208 envelope_data: EnvelopeData,
209 ) -> Result<ReqwestRequest<'a, B>, RspamdError> {
210 Ok(Self {
211 endpoint: RspamdEndpoint::from_command(command),
212 client,
213 body,
214 envelope_data: Some(envelope_data),
215 })
216 }
217}
218
219#[maybe_async::maybe_async]
240pub async fn scan_async<B: AsRef<[u8]> + Send>(
241 options: &Config,
242 body: B,
243 envelope_data: EnvelopeData,
244) -> Result<RspamdScanReply, RspamdError> {
245 let client = async_client(options)?;
246 let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
247 let (headers, body) = request
248 .response()
249 .await
250 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
251
252 let response = if let Some(offset_header) = headers.get("Message-Offset") {
254 let offset = offset_header
255 .to_str()
256 .map_err(|e| RspamdError::HttpError(format!("Invalid Message-Offset header: {}", e)))?
257 .parse::<usize>()
258 .map_err(|e| RspamdError::HttpError(format!("Invalid Message-Offset value: {}", e)))?;
259
260 if offset < body.len() {
261 let json_part = &body[..offset];
263 let body_part = &body[offset..];
264
265 let mut response = serde_json::from_slice::<RspamdScanReply>(json_part)?;
266 response.rewritten_body = Some(body_part.to_vec());
267 response
268 } else {
269 serde_json::from_slice::<RspamdScanReply>(body.as_ref())?
271 }
272 } else {
273 serde_json::from_slice::<RspamdScanReply>(body.as_ref())?
275 };
276
277 Ok(response)
278}