rspamd_client/backend/
async_client.rs1use std::str::FromStr;
2use std::time::Duration;
3use std::collections::HashMap;
4use bytes::{Bytes, BytesMut};
5use reqwest::Client;
6use reqwest::header::{HeaderName, HeaderValue};
7use url::Url;
8use zstd::zstd_safe::WriteBuf;
9use crate::backend::traits::*;
10use crate::config::{Config, EnvelopeData};
11use crate::error::RspamdError;
12use crate::protocol::commands::{RspamdCommand, RspamdEndpoint};
13use crate::protocol::RspamdScanReply;
14use crate::protocol::encryption::{httpcrypt_encrypt, httpcrypt_decrypt, make_key_header};
15
16pub struct AsyncClient<'a> {
17 config: &'a Config,
18 inner: Client,
19}
20
/// Builds an [`AsyncClient`] from `options`.
///
/// Applies the configured request timeout, the optional proxy (`options.proxy_config`)
/// and an optional extra PEM root certificate (`options.tls_settings.ca_path`).
///
/// # Errors
///
/// Returns [`RspamdError::ConfigError`] if the CA file path cannot be resolved or read,
/// and [`RspamdError::HttpError`] if the proxy URL, the certificate or the final client
/// configuration is rejected by `reqwest`.
#[cfg(feature = "async")]
pub fn async_client(options: &Config) -> Result<AsyncClient, RspamdError> {
    let mut builder = Client::builder().timeout(Duration::from_secs_f64(options.timeout));

    if let Some(proxy) = &options.proxy_config {
        let proxy = reqwest::Proxy::all(proxy.proxy_url.clone())
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.proxy(proxy);
    }

    let ca_path = options
        .tls_settings
        .as_ref()
        .and_then(|tls| tls.ca_path.as_ref());
    if let Some(ca_path) = ca_path {
        // The path comes from user configuration: a missing or unreadable file is a
        // configuration error, not a reason to panic.
        let pem = std::fs::canonicalize(ca_path.as_str())
            .and_then(std::fs::read)
            .map_err(|e| RspamdError::ConfigError(e.to_string()))?;
        let cert = reqwest::Certificate::from_pem(&pem)
            .map_err(|e| RspamdError::HttpError(e.to_string()))?;
        builder = builder.add_root_certificate(cert);
    }

    Ok(AsyncClient {
        inner: builder
            .build()
            .map_err(|e| RspamdError::HttpError(e.to_string()))?,
        config: options,
    })
}
53
54pub struct ReqwestRequest<'a> {
56 endpoint: RspamdEndpoint<'a>,
57 client: AsyncClient<'a>,
58 body: Bytes,
59 envelope_data: Option<EnvelopeData>,
60}
61
62#[maybe_async::maybe_async]
63impl<'a> Request for ReqwestRequest<'a> {
64 type Body = Bytes;
65 type HeaderMap = reqwest::header::HeaderMap;
66
67 async fn response(mut self) -> Result<(Self::HeaderMap, Self::Body), RspamdError> {
68 let mut retry_cnt = self.client.config.retries;
69 let mut maybe_sk = Default::default();
70 let extra_hdrs : HashMap<String, String> = HashMap::from_iter(self.envelope_data.take().unwrap().into_iter());
71
72 let response = loop {
73 let method = if self.endpoint.need_body { reqwest::Method::POST } else { reqwest::Method::GET };
74
75 let mut url = Url::from_str(self.client.config.base_url.as_str())
76 .map_err(|e| RspamdError::HttpError(e.to_string()))?;
77 url.set_path(self.endpoint.url);
78 let mut req = self.client.inner.request(method, url.clone());
79
80 if let Some(ref password) = self.client.config.password {
81 req = req.header("Password", password);
82 }
83
84 if self.client.config.zstd {
85 req = req.header("Content-Encoding", "zstd");
86 req = req.header("Compression", "zstd");
87 }
88
89 for (k, v) in extra_hdrs.iter() {
90 req = req.header(k, v);
91 }
92
93 if let Some(ref encryption_key) = self.client.config.encryption_key {
94 let inner_req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
95 let body = if self.client.config.zstd {
96 zstd::encode_all(self.body.as_ref(), 0)?
97 }
98 else {
99 self.body.to_vec()
100 };
101 let encrypted = httpcrypt_encrypt(
102 url.path(),
103 body.as_slice(),
104 inner_req.headers(),
105 encryption_key.as_bytes(),
106 )?;
107 req = self.client.inner.request(reqwest::Method::POST, url);
108 let key_header = make_key_header(encryption_key.as_str(), encrypted.peer_key.as_str())?;
109 req = req.header("Key", key_header);
110 req = req.body(encrypted.body);
111 maybe_sk = Some(encrypted.shared_key);
112 }
113 else if self.endpoint.need_body {
114 req = if self.client.config.zstd {
115 req.body(reqwest::Body::from(zstd::encode_all(self.body.as_ref(), 0)?))
116 }
117 else {
118 req.body(self.body.clone())
119 };
120 }
121
122 let req = req.timeout(Duration::from_secs_f64(self.client.config.timeout));
123 let req = req.build().map_err(|e| RspamdError::HttpError(e.to_string()))?;
124
125 match self.client.inner.execute(req).await {
126 Ok(v) => break Ok(v),
127 Err(e) => {
128 if (retry_cnt - 1) == 0 {
129 break Err(e);
130 }
131 retry_cnt -= 1;
132 let delay = Duration::from_secs_f64(self.client.config.timeout);
133 tokio::time::sleep(delay).await;
134 continue;
135 }
136 };
137 }.map_err(|e| RspamdError::HttpError(e.to_string()))?;
138
139 if !response.status().is_success() {
140 return Err(RspamdError::HttpError(format!(
141 "Status: {}",
142 response.status()
143 )));
144 }
145
146 if let Some(sk) = maybe_sk {
147 let mut body = BytesMut::from(response.bytes().await.map_err(|e| RspamdError::HttpError(e.to_string()))?);
148 let decrypted_offset = httpcrypt_decrypt(body.as_mut(), sk)?;
149 let mut hdrs = [httparse::EMPTY_HEADER; 64];
150 let mut parsed = httparse::Response::new(&mut hdrs);
151
152 let body_offset = parsed.parse(&body.as_slice()[decrypted_offset..]).map_err(|s| RspamdError::HttpError(s.to_string()))?;
153 let mut output_hdrs = reqwest::header::HeaderMap::with_capacity(parsed.headers.len());
154 for hdr in parsed.headers.iter_mut() {
155 output_hdrs.insert(HeaderName::from_str(hdr.name)?, HeaderValue::from_str(std::str::from_utf8(hdr.value)?)?);
156 }
157 let body = if output_hdrs.get("Compression").map_or(false,
158 |hv| hv == "zstd") {
159 zstd::decode_all(&body.as_slice()[body_offset.unwrap() + decrypted_offset..])?
160 } else {
161 body.as_slice()[body_offset.unwrap() + decrypted_offset..].to_vec()
162 };
163 Ok((output_hdrs, body.into()))
164 }
165 else {
166 Ok((response.headers().clone(), response.bytes().await?))
167 }
168 }
169}
170
171#[maybe_async::maybe_async]
172impl<'a> ReqwestRequest<'a> {
173 pub async fn new<T: Into<Bytes>>(
174 client: AsyncClient<'a>,
175 body: T,
176 command: RspamdCommand,
177 envelope_data: EnvelopeData,
178 ) -> Result<ReqwestRequest<'a>, RspamdError> {
179 Ok(Self {
180 endpoint: RspamdEndpoint::from_command(command),
181 client,
182 body: body.into(),
183 envelope_data: Some(envelope_data),
184 })
185 }
186}
187
188#[maybe_async::maybe_async]
209pub async fn scan_async<T: Into<Bytes>>(options: &Config, body: T, envelope_data: EnvelopeData) -> Result<RspamdScanReply, RspamdError> {
210 let client = async_client(options)?;
211 let request = ReqwestRequest::new(client, body, RspamdCommand::Scan, envelope_data).await?;
212 let (_, body) = request.response().await.map_err(|e| RspamdError::HttpError(e.to_string()))?;
213 let response = serde_json::from_slice::<RspamdScanReply>(body.as_ref())?;
214 Ok(response)
215}