// rspamd_client/backend/async_client.rs

1use std::str::FromStr;
2use std::time::Duration;
3use std::collections::HashMap;
4use bytes::{Bytes, BytesMut};
5use reqwest::Client;
6use reqwest::header::{HeaderName, HeaderValue};
7use url::Url;
8use zstd::zstd_safe::WriteBuf;
9use crate::backend::traits::*;
10use crate::config::{Config, EnvelopeData};
11use crate::error::RspamdError;
12use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
13use crate::protocol::RspamdScanReply;
14use crate::protocol::encryption::{httpcrypt_encrypt, httpcrypt_decrypt, make_key_header};
15
16pub struct AsyncClient<'a> {
17	config: &'a Config,
18	inner: Client,
19}
20
/// Build an [`AsyncClient`] from `options`.
///
/// The underlying `reqwest` client gets:
/// * a request timeout of `options.timeout` seconds;
/// * the proxy from `options.proxy_config`, if set;
/// * the PEM file at `options.tls_settings.ca_path`, if set, as an extra root
///   certificate.
///
/// # Errors
///
/// * [`RspamdError::ConfigError`] if the CA file path cannot be resolved or
///   the file cannot be read.
/// * [`RspamdError::HttpError`] if the proxy URL or the certificate is invalid,
///   or if the `reqwest` client cannot be built.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient<'_>, RspamdError> {
	let mut builder = Client::builder()
		.timeout(Duration::from_secs_f64(options.timeout));

	if let Some(ref proxy) = options.proxy_config {
		let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
			.map_err(|e| RspamdError::HttpError(e.to_string()))?;
		builder = builder.proxy(proxy);
	}

	let ca_path = options
		.tls_settings
		.as_ref()
		.and_then(|tls| tls.ca_path.as_ref());
	if let Some(ca_path) = ca_path {
		// An unresolvable path is a configuration problem, not a reason to panic,
		// so canonicalize errors are reported exactly like read errors.
		let pem = std::fs::canonicalize(ca_path.as_str())
			.and_then(std::fs::read)
			.map_err(|e| RspamdError::ConfigError(e.to_string()))?;
		let cert = reqwest::Certificate::from_pem(&pem)
			.map_err(|e| RspamdError::HttpError(e.to_string()))?;
		builder = builder.add_root_certificate(cert);
	}

	let inner = builder
		.build()
		.map_err(|e| RspamdError::HttpError(e.to_string()))?;
	Ok(AsyncClient {
		inner,
		config: options,
	})
}
53
54// Temporary structure for making a request
55pub struct ReqwestRequest<'a> {
56	endpoint: RspamdEndpoint<'a>,
57	client: AsyncClient<'a>,
58	body: Bytes,
59	envelope_data: Option<EnvelopeData>,
60}
61
62#[maybe_async::maybe_async]
63impl<'a> Request for ReqwestRequest<'a> {
64	type Body = Bytes;
65	type HeaderMap = reqwest::header::HeaderMap;
66
67	async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
68		let mut retry_cnt = self.client.config.retries;
69		let mut maybe_sk = Default::default();
70		let extra_hdrs :  HashMap<String, String> = HashMap::from_iter(self.envelope_data.take().unwrap().into_iter());
71
72		let response = loop {
73			// Check if File header is present - if so, we don't need to send the body
74			let has_file_header = extra_hdrs.contains_key("File");
75			let need_body = self.endpoint.need_body && !has_file_header;
76			let method = if need_body {  reqwest::Method::POST } else { reqwest::Method::GET };
77
78			let mut url = Url::from_str(self.client.config.base_url.as_str())
79				.map_err(|e| RspamdError::HttpError(e.to_string()))?;
80			url.set_path(self.endpoint.url);
81			let mut req = self.client.inner.request(method, url.clone());
82
83			if let Some(ref password) = self.client.config.password {
84				req = req.header("Password", password);
85			}
86
87			if self.client.config.zstd && need_body {
88				req = req.header("Content-Encoding", "zstd");
89				req = req.header("Compression", "zstd");
90			}
91
92			for (k, v) in extra_hdrs.iter() {
93				req = req.header(k, v);
94			}
95
96			if let Some(ref encryption_key) = self.client.config.encryption_key {
97				let inner_req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
98				let body = if need_body {
99					if self.client.config.zstd {
100						zstd::encode_all(self.body.as_ref(), 0)?
101					}
102					else {
103						self.body.to_vec()
104					}
105				} else {
106					Vec::new()
107				};
108				let encrypted = httpcrypt_encrypt(
109					url.path(),
110					body.as_slice(),
111					inner_req.headers(),
112					encryption_key.as_bytes(),
113				)?;
114				req = self.client.inner.request(reqwest::Method::POST, url);
115				let key_header = make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
116				req = req.header("Key", key_header);
117				req = req.body(encrypted.body);
118				maybe_sk = Some(encrypted.shared_key);
119			}
120			else if need_body {
121				req = if self.client.config.zstd {
122					req.body(reqwest::Body::from(zstd::encode_all(self.body.as_ref(), 0)?))
123				}
124				else {
125					req.body(self.body.clone())
126				};
127			}
128
129			let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
130			let req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
131
132			match self.client.inner.execute(req).await {
133				Ok(v) => break Ok(v),
134				Err(e) => {
135					if (retry_cnt - 1) == 0 {
136						break Err(e);
137					}
138					retry_cnt -= 1;
139					let delay = Duration::from_secs_f64(self.client.config.timeout);
140					tokio::time::sleep(delay).await;
141					continue;
142				}
143			};
144		}.map_err(|e| RspamdError::HttpError(e.to_string()))?;
145
146		if !response.status().is_success() {
147			return Err(RspamdError::HttpError(format!(
148				"Status: {}",
149				response.status()
150			)));
151		}
152
153		if let Some(sk) = maybe_sk {
154			let mut body = BytesMut::from(response.bytes().await.map_err(|e| RspamdError::HttpError(e.to_string()))?);
155			let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
156			let mut hdrs = [httparse::EMPTY_HEADER; 64];
157			let mut parsed = httparse::Response::new(&mut hdrs);
158
159			let body_offset = parsed.parse(&body.as_slice()[decrypted_offset..]).map_err(|s| RspamdError::HttpError(s.to_string()))?;
160			let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
161			for hdr in parsed.headers.iter_mut() {
162				output_hdrs.insert(HeaderName::from_str(hdr.name)?, HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?);
163			}
164			let body = if output_hdrs.get("Compression").map_or(false,
165																		|hv| hv == "zstd") {
166				zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
167			} else {
168				body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
169			};
170			Ok((output_hdrs, body.into()))
171		}
172		else {
173			Ok((response.headers().clone(), response.bytes().await?))
174		}
175	}
176}
177
178#[maybe_async::maybe_async]
179impl<'a> ReqwestRequest<'a> {
180	pub async fn new<T: Into<Bytes>>(
181		client: AsyncClient<'a>,
182		body: T,
183		command: RspamdCommand,
184		envelope_data: EnvelopeData,
185	) -> Result<ReqwestRequest<'a>, RspamdError> {
186		Ok(Self {
187			endpoint: RspamdEndpoint::from_command(command),
188			client,
189			body: body.into(),
190			envelope_data: Some(envelope_data),
191		})
192	}
193}
194
195/// Scan an email asynchronously, returning the parsed reply or error.
196/// Example:
197/// ```rust
198/// use rspamd_client::config::Config;
199/// use rspamd_client::scan_async;
200/// use rspamd_client::error::RspamdError;
201/// use bytes::Bytes;
202/// use std::str::FromStr;
203///
204///	#[tokio::main]
205/// async fn main() -> Result<(), RspamdError> {
206/// 	let config = Config::builder()
207/// 		.base_url("http://localhost:11333".to_string())
208/// 		.build();
209/// 	let envelope = Default::default();
210/// 	let email = "...";
211/// 	let response = scan_async(&config, email, envelope).await?;
212/// 	Ok(())
213/// }
214/// ```
215#[maybe_async::maybe_async]
216pub async fn scan_async<T: Into<Bytes>>(options: &Config, body: T, envelope_data: EnvelopeData) -> Result<RspamdScanReply, RspamdError> {
217	let client = async_client(options)?;
218	let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
219	let (_, body) = request.response().await.map_err(|e| RspamdError::HttpError(e.to_string()))?;
220	let response = serde_json::from_slice::<RspamdScanReply>(body.as_ref())?;
221	Ok(response)
222}