rspamd_client/backend/
async_client.rs1use std::str::FromStr;
2use std::time::Duration;
3use std::collections::HashMap;
4use bytes::{Bytes, BytesMut};
5use reqwest::Client;
6use reqwest::header::{HeaderName, HeaderValue};
7use url::Url;
8use zstd::zstd_safe::WriteBuf;
9use crate::backend::traits::*;
10use crate::config::{Config, EnvelopeData};
11use crate::error::RspamdError;
12use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
13use crate::protocol::RspamdScanReply;
14use crate::protocol::encryption::{httpcrypt_encrypt, httpcrypt_decrypt, make_key_header};
15
16pub struct AsyncClient<'a> {
17 config: &'a Config,
18 inner: Client,
19}
20
/// Builds an [`AsyncClient`] from `options`.
///
/// The underlying `reqwest` client is configured with:
/// * a request timeout of `options.timeout` seconds,
/// * a proxy for all schemes when `options.proxy_config` is set,
/// * an additional PEM root certificate when
///   `options.tls_settings.ca_path` is set.
///
/// # Errors
///
/// * [`RspamdError::ConfigError`] if the CA certificate path cannot be
///   resolved or read (previously an unresolvable path panicked).
/// * [`RspamdError::HttpError`] if the proxy URL or the certificate is
///   invalid, or the client cannot be built.
///
/// # Panics
///
/// `Duration::from_secs_f64` panics if `options.timeout` is negative, NaN or
/// too large.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient, RspamdError> {
    let mut builder = Client::builder().timeout(Duration::from_secs_f64(options.timeout));

    if let Some(ref proxy) = options.proxy_config {
        let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.proxy(proxy);
    }

    // Only a configured CA path changes the TLS setup; other TLS settings are
    // not read here.
    if let Some(ca_path) = options
        .tls_settings
        .as_ref()
        .and_then(|tls| tls.ca_path.as_ref())
    {
        builder = builder.add_root_certificate(read_ca_certificate(ca_path.as_str())?);
    }

    let inner = builder
        .build()
        .map_err(|e| RspamdError::HttpError(e.to_string()))?;

    Ok(AsyncClient {
        inner,
        config: options,
    })
}

/// Resolves `ca_path` and parses the PEM root certificate stored there.
///
/// Path resolution and I/O failures map to [`RspamdError::ConfigError`].
/// Certificate parse failures map to [`RspamdError::HttpError`].
#[cfg(feature = "async")]
fn read_ca_certificate(ca_path: &str) -> Result<reqwest::Certificate, RspamdError> {
    let resolved =
        std::fs::canonicalize(ca_path).map_err(|e| RspamdError::ConfigError(e.to_string()))?;
    let pem = std::fs::read(&resolved).map_err(|e| RspamdError::ConfigError(e.to_string()))?;
    reqwest::Certificate::from_pem(&pem).map_err(|e| RspamdError::HttpError(e.to_string()))
}
53
54pub struct ReqwestRequest<'a> {
56 endpoint: RspamdEndpoint<'a>,
57 client: AsyncClient<'a>,
58 body: Bytes,
59 envelope_data: Option<EnvelopeData>,
60}
61
62#[maybe_async::maybe_async]
63impl<'a> Request for ReqwestRequest<'a> {
64 type Body = Bytes;
65 type HeaderMap = reqwest::header::HeaderMap;
66
67 async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
68 let mut retry_cnt = self.client.config.retries;
69 let mut maybe_sk = Default::default();
70 let extra_hdrs : HashMap<String, String> = HashMap::from_iter(self.envelope_data.take().unwrap().into_iter());
71
72 let response = loop {
73 let has_file_header = extra_hdrs.contains_key("File");
75 let need_body = self.endpoint.need_body && !has_file_header;
76 let method = if need_body { reqwest::Method::POST } else { reqwest::Method::GET };
77
78 let mut url = Url::from_str(self.client.config.base_url.as_str())
79 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
80 url.set_path(self.endpoint.url);
81 let mut req = self.client.inner.request(method, url.clone());
82
83 if let Some(ref password) = self.client.config.password {
84 req = req.header("Password", password);
85 }
86
87 if self.client.config.zstd && need_body {
88 req = req.header("Content-Encoding", "zstd");
89 req = req.header("Compression", "zstd");
90 }
91
92 for (k, v) in extra_hdrs.iter() {
93 req = req.header(k, v);
94 }
95
96 if let Some(ref encryption_key) = self.client.config.encryption_key {
97 let inner_req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
98 let body = if need_body {
99 if self.client.config.zstd {
100 zstd::encode_all(self.body.as_ref(), 0)?
101 }
102 else {
103 self.body.to_vec()
104 }
105 } else {
106 Vec::new()
107 };
108 let encrypted = httpcrypt_encrypt(
109 url.path(),
110 body.as_slice(),
111 inner_req.headers(),
112 encryption_key.as_bytes(),
113 )?;
114 req = self.client.inner.request(reqwest::Method::POST, url);
115 let key_header = make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
116 req = req.header("Key", key_header);
117 req = req.body(encrypted.body);
118 maybe_sk = Some(encrypted.shared_key);
119 }
120 else if need_body {
121 req = if self.client.config.zstd {
122 req.body(reqwest::Body::from(zstd::encode_all(self.body.as_ref(), 0)?))
123 }
124 else {
125 req.body(self.body.clone())
126 };
127 }
128
129 let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
130 let req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
131
132 match self.client.inner.execute(req).await {
133 Ok(v) => break Ok(v),
134 Err(e) => {
135 if (retry_cnt - 1) == 0 {
136 break Err(e);
137 }
138 retry_cnt -= 1;
139 let delay = Duration::from_secs_f64(self.client.config.timeout);
140 tokio::time::sleep(delay).await;
141 continue;
142 }
143 };
144 }.map_err(|e| RspamdError::HttpError(e.to_string()))?;
145
146 if !response.status().is_success() {
147 return Err(RspamdError::HttpError(format!(
148 "Status: {}",
149 response.status()
150 )));
151 }
152
153 if let Some(sk) = maybe_sk {
154 let mut body = BytesMut::from(response.bytes().await.map_err(|e| RspamdError::HttpError(e.to_string()))?);
155 let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
156 let mut hdrs = [httparse::EMPTY_HEADER; 64];
157 let mut parsed = httparse::Response::new(&mut hdrs);
158
159 let body_offset = parsed.parse(&body.as_slice()[decrypted_offset..]).map_err(|s| RspamdError::HttpError(s.to_string()))?;
160 let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
161 for hdr in parsed.headers.iter_mut() {
162 output_hdrs.insert(HeaderName::from_str(hdr.name)?, HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?);
163 }
164 let body = if output_hdrs.get("Compression").map_or(false,
165 |hv| hv == "zstd") {
166 zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
167 } else {
168 body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
169 };
170 Ok((output_hdrs, body.into()))
171 }
172 else {
173 Ok((response.headers().clone(), response.bytes().await?))
174 }
175 }
176}
177
178#[maybe_async::maybe_async]
179impl<'a> ReqwestRequest<'a> {
180 pub async fn new<T: Into<Bytes>>(
181 client: AsyncClient<'a>,
182 body: T,
183 command: RspamdCommand,
184 envelope_data: EnvelopeData,
185 ) -> Result<ReqwestRequest<'a>, RspamdError> {
186 Ok(Self {
187 endpoint: RspamdEndpoint::from_command(command),
188 client,
189 body: body.into(),
190 envelope_data: Some(envelope_data),
191 })
192 }
193}
194
195#[maybe_async::maybe_async]
216pub async fn scan_async<T: Into<Bytes>>(options: &Config, body: T, envelope_data: EnvelopeData) -> Result<RspamdScanReply, RspamdError> {
217 let client = async_client(options)?;
218 let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
219 let (_, body) = request.response().await.map_err(|e| RspamdError::HttpError(e.to_string()))?;
220 let response = serde_json::from_slice::<RspamdScanReply>(body.as_ref())?;
221 Ok(response)
222}