// rspamd_client/backend/async_client.rs

1use std::str::FromStr;
2use std::time::Duration;
3use std::collections::HashMap;
4use bytes::{Bytes, BytesMut};
5use reqwest::Client;
6use reqwest::header::{HeaderName, HeaderValue};
7use url::Url;
8use zstd::zstd_safe::WriteBuf;
9use crate::backend::traits::*;
10use crate::config::{Config, EnvelopeData};
11use crate::error::RspamdError;
12use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
13use crate::protocol::RspamdScanReply;
14use crate::protocol::encryption::{httpcrypt_encrypt, httpcrypt_decrypt, make_key_header};
15
16pub struct AsyncClient<'a> {
17	config: &'a Config,
18	inner: Client,
19}
20
/// Builds an [`AsyncClient`] from `options`.
///
/// The underlying `reqwest` client gets the configured request timeout, an
/// optional proxy (`options.proxy_config`) and an optional extra root CA
/// certificate (`options.tls_settings.ca_path`, read as PEM).
///
/// # Errors
/// * `RspamdError::ConfigError` if the CA path cannot be resolved or read.
/// * `RspamdError::HttpError` if the proxy URL, the CA certificate or the final
///   client configuration is rejected by `reqwest`.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient, RspamdError> {
	let mut builder = Client::builder().timeout(Duration::from_secs_f64(options.timeout));

	if let Some(ref proxy) = options.proxy_config {
		let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
			.map_err(|e| RspamdError::HttpError(e.to_string()))?;
		builder = builder.proxy(proxy);
	}

	if let Some(ca_path) = options
		.tls_settings
		.as_ref()
		.and_then(|tls| tls.ca_path.as_ref())
	{
		// A bad CA path is a configuration problem, not a reason to panic:
		// report resolution and read failures the same way.
		let pem = std::fs::canonicalize(ca_path.as_str())
			.and_then(std::fs::read)
			.map_err(|e| RspamdError::ConfigError(e.to_string()))?;
		let cert = reqwest::Certificate::from_pem(&pem)
			.map_err(|e| RspamdError::HttpError(e.to_string()))?;
		builder = builder.add_root_certificate(cert);
	}

	let inner = builder
		.build()
		.map_err(|e| RspamdError::HttpError(e.to_string()))?;

	Ok(AsyncClient {
		inner,
		config: options,
	})
}
53
54// Temporary structure for making a request
55pub struct ReqwestRequest<'a> {
56	endpoint: RspamdEndpoint<'a>,
57	client: AsyncClient<'a>,
58	body: Bytes,
59	envelope_data: Option<EnvelopeData>,
60}
61
62#[maybe_async::maybe_async]
63impl<'a> Request for ReqwestRequest<'a> {
64	type Body = Bytes;
65	type HeaderMap = reqwest::header::HeaderMap;
66
67	async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
68		let mut retry_cnt = self.client.config.retries;
69		let mut maybe_sk = Default::default();
70		let extra_hdrs :  HashMap<String, String> = HashMap::from_iter(self.envelope_data.take().unwrap().into_iter());
71
72		let response = loop {
73			let method = if self.endpoint.need_body {  reqwest::Method::POST } else { reqwest::Method::GET };
74
75			let mut url = Url::from_str(self.client.config.base_url.as_str())
76				.map_err(|e| RspamdError::HttpError(e.to_string()))?;
77			url.set_path(self.endpoint.url);
78			let mut req = self.client.inner.request(method, url.clone());
79
80			if let Some(ref password) = self.client.config.password {
81				req = req.header("Password", password);
82			}
83
84			if self.client.config.zstd {
85				req = req.header("Content-Encoding", "zstd");
86				req = req.header("Compression", "zstd");
87			}
88
89			for (k, v) in extra_hdrs.iter() {
90				req = req.header(k, v);
91			}
92
93			if let Some(ref encryption_key) = self.client.config.encryption_key {
94				let inner_req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
95				let body = if self.client.config.zstd {
96					zstd::encode_all(self.body.as_ref(), 0)?
97				}
98				else {
99					self.body.to_vec()
100				};
101				let encrypted = httpcrypt_encrypt(
102					url.path(),
103					body.as_slice(),
104					inner_req.headers(),
105					encryption_key.as_bytes(),
106				)?;
107				req = self.client.inner.request(reqwest::Method::POST, url);
108				let key_header = make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
109				req = req.header("Key", key_header);
110				req = req.body(encrypted.body);
111				maybe_sk = Some(encrypted.shared_key);
112			}
113			else if self.endpoint.need_body {
114				req = if self.client.config.zstd {
115					req.body(reqwest::Body::from(zstd::encode_all(self.body.as_ref(), 0)?))
116				}
117				else {
118					req.body(self.body.clone())
119				};
120			}
121
122			let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
123			let req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
124
125			match self.client.inner.execute(req).await {
126				Ok(v) => break Ok(v),
127				Err(e) => {
128					if (retry_cnt - 1) == 0 {
129						break Err(e);
130					}
131					retry_cnt -= 1;
132					let delay = Duration::from_secs_f64(self.client.config.timeout);
133					tokio::time::sleep(delay).await;
134					continue;
135				}
136			};
137		}.map_err(|e| RspamdError::HttpError(e.to_string()))?;
138
139		if !response.status().is_success() {
140			return Err(RspamdError::HttpError(format!(
141				"Status: {}",
142				response.status()
143			)));
144		}
145
146		if let Some(sk) = maybe_sk {
147			let mut body = BytesMut::from(response.bytes().await.map_err(|e| RspamdError::HttpError(e.to_string()))?);
148			let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
149			let mut hdrs = [httparse::EMPTY_HEADER; 64];
150			let mut parsed = httparse::Response::new(&mut hdrs);
151
152			let body_offset = parsed.parse(&body.as_slice()[decrypted_offset..]).map_err(|s| RspamdError::HttpError(s.to_string()))?;
153			let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
154			for hdr in parsed.headers.iter_mut() {
155				output_hdrs.insert(HeaderName::from_str(hdr.name)?, HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?);
156			}
157			let body = if output_hdrs.get("Compression").map_or(false,
158																		|hv| hv == "zstd") {
159				zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
160			} else {
161				body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
162			};
163			Ok((output_hdrs, body.into()))
164		}
165		else {
166			Ok((response.headers().clone(), response.bytes().await?))
167		}
168	}
169}
170
171#[maybe_async::maybe_async]
172impl<'a> ReqwestRequest<'a> {
173	pub async fn new<T: Into<Bytes>>(
174		client: AsyncClient<'a>,
175		body: T,
176		command: RspamdCommand,
177		envelope_data: EnvelopeData,
178	) -> Result<ReqwestRequest<'a>, RspamdError> {
179		Ok(Self {
180			endpoint: RspamdEndpoint::from_command(command),
181			client,
182			body: body.into(),
183			envelope_data: Some(envelope_data),
184		})
185	}
186}
187
188/// Scan an email asynchronously, returning the parsed reply or error.
189/// Example:
190/// ```rust
191/// use rspamd_client::config::Config;
192/// use rspamd_client::scan_async;
193/// use rspamd_client::error::RspamdError;
194/// use bytes::Bytes;
195/// use std::str::FromStr;
196///
197///	#[tokio::main]
198/// async fn main() -> Result<(), RspamdError> {
199/// 	let config = Config::builder()
200/// 		.base_url("http://localhost:11333".to_string())
201/// 		.build();
202/// 	let envelope = Default::default();
203/// 	let email = "...";
204/// 	let response = scan_async(&config, email, envelope).await?;
205/// 	Ok(())
206/// }
207/// ```
208#[maybe_async::maybe_async]
209pub async fn scan_async<T: Into<Bytes>>(options: &Config, body: T, envelope_data: EnvelopeData) -> Result<RspamdScanReply, RspamdError> {
210	let client = async_client(options)?;
211	let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
212	let (_, body) = request.response().await.map_err(|e| RspamdError::HttpError(e.to_string()))?;
213	let response = serde_json::from_slice::<RspamdScanReply>(body.as_ref())?;
214	Ok(response)
215}