http_adapter_ureq/
lib.rs

1//! # HTTP adapter implementation for [`ureq`](https://crates.io/crates/ureq)
2//!
3//! For more details refer to [`http-adapter`](https://crates.io/crates/http-adapter)
4
5use std::error::Error as StdError;
6use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
7use std::future::Future;
8use std::io::Read;
9use std::pin::Pin;
10use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
11use std::task::{Context, Poll, Waker};
12use std::thread;
13use std::thread::JoinHandle;
14
15pub use ureq;
16
17use http_adapter::async_trait::async_trait;
18use http_adapter::http::{HeaderValue, StatusCode, Version};
19use http_adapter::{http, HttpClientAdapter};
20use http_adapter::{Request, Response};
21
/// [`HttpClientAdapter`] implementation that runs each request through a blocking
/// [`ureq::Agent`] on its own worker thread.
#[derive(Clone, Debug)]
pub struct UreqAdapter {
	/// Agent used to build and send every request executed by this adapter.
	agent: ureq::Agent,
}
26
27impl UreqAdapter {
28	pub fn new(agent: ureq::Agent) -> Self {
29		Self { agent }
30	}
31}
32
33impl Default for UreqAdapter {
34	#[inline]
35	fn default() -> Self {
36		Self {
37			agent: ureq::Agent::new(),
38		}
39	}
40}
41
/// Errors produced by [`UreqAdapter`].
#[derive(Debug)]
pub enum Error {
	/// The `http` response builder rejected the converted response (for example a header name
	/// or value coming from the server).
	Http(http::Error),
	/// `ureq` failed to send the request, or reading the response body failed.
	Ureq(Box<ureq::Error>),
	/// A request header value could not be converted to a string for `ureq`
	/// ([`HeaderValue::to_str`] failed).
	InvalidHeaderValue(HeaderValue),
	/// The response reported an HTTP version string this adapter does not map to [`Version`].
	InvalidHttpVersion(String),
	/// The response status code was rejected by [`StatusCode::from_u16`].
	InvalidStatusCode(u16),
	/// A channel between the polling future and the worker thread was closed unexpectedly;
	/// the string says which one.
	InternalCommunicationError(String),
}
51
52impl Display for Error {
53	fn fmt(&self, f: &mut Formatter) -> FmtResult {
54		match self {
55			Error::Http(e) => Display::fmt(e, f),
56			Error::Ureq(e) => Display::fmt(e, f),
57			Error::InvalidHeaderValue(header_value) => {
58				write!(f, "Invalid header value: {header_value:?}")
59			}
60			Error::InvalidHttpVersion(version) => {
61				write!(f, "Invalid HTTP version: {version}")
62			}
63			Error::InvalidStatusCode(code) => {
64				write!(f, "Invalid status code: {code}")
65			}
66			Error::InternalCommunicationError(e) => {
67				write!(f, "Internal communication error: {e}")
68			}
69		}
70	}
71}
72
73impl StdError for Error {}
74
75#[inline]
76fn from_request<B>(client: &ureq::Agent, request: &Request<B>) -> Result<ureq::Request, Error> {
77	let mut out = client.request(request.method().as_str(), &request.uri().to_string());
78	for (name, value) in request.headers() {
79		out = out.set(
80			name.as_str(),
81			value.to_str().map_err(|_| Error::InvalidHeaderValue(value.clone()))?,
82		);
83	}
84	Ok(out)
85}
86
87#[inline]
88fn to_response(res: ureq::Response) -> Result<Response<Vec<u8>>, Error> {
89	let version = match res.http_version() {
90		"HTTP/0.9" => Version::HTTP_09,
91		"HTTP/1.0" => Version::HTTP_10,
92		"HTTP/1.1" => Version::HTTP_11,
93		"HTTP/2.0" => Version::HTTP_2,
94		"HTTP/3.0" => Version::HTTP_3,
95		_ => return Err(Error::InvalidHttpVersion(res.http_version().to_string())),
96	};
97
98	let status = StatusCode::from_u16(res.status()).map_err(|_| Error::InvalidStatusCode(res.status()))?;
99
100	let mut response = Response::builder().status(status).version(version);
101
102	for header_name in res.headers_names() {
103		if let Some(header_value) = res.header(&header_name) {
104			response = response.header(header_name, header_value);
105		}
106	}
107	let mut body = vec![];
108	res.into_reader()
109		.read_to_end(&mut body)
110		.map_err(|e| Error::Ureq(Box::new(e.into())))?;
111	response.body(body).map_err(Error::Http)
112}
113
114#[async_trait(?Send)]
115impl HttpClientAdapter for UreqAdapter {
116	type Error = Error;
117
118	async fn execute(&self, request: Request<Vec<u8>>) -> Result<Response<Vec<u8>>, Self::Error> {
119		let req = from_request(&self.agent, &request)?;
120		let res = ThreadFuture::new(|send_result, recv_waker| {
121			move || {
122				let waker = recv_waker
123					.recv()
124					.map_err(|_| Error::InternalCommunicationError("Waker receive channel is closed".to_string()))?;
125				match req.send_bytes(request.body()).map_err(|e| Error::Ureq(Box::new(e))) {
126					Ok(res) => send_result
127						.send(to_response(res))
128						.map_err(|_| Error::InternalCommunicationError("Result send channel is closed for Ok result".to_string()))?,
129					Err(e) => send_result
130						.send(Err(e))
131						.map_err(|_| Error::InternalCommunicationError("Result send channel is closed for Err result".to_string()))?,
132				}
133				waker.wake();
134				Ok(())
135			}
136		})
137		.await;
138		match res {
139			FutureResult::CommunicationError(e) => Err(Error::InternalCommunicationError(e)),
140			FutureResult::Result(r) => r,
141		}
142	}
143}
144
/// Future that resolves once a worker thread delivers its result over a channel.
struct ThreadFuture<Res> {
	/// Worker thread handle; taken and joined when the future is dropped.
	thread: Option<JoinHandle<Result<(), Error>>>,
	/// Receives the value the worker sends back.
	recv_result: Receiver<Res>,
	/// Hands the waker from the first poll to the worker.
	send_waker: Sender<Waker>,
	/// Set once the waker has been sent; it is sent at most once.
	waker_sent: bool,
}
151
152impl<Res: Send + 'static> ThreadFuture<Res> {
153	pub fn new<Factory, Body>(factory: Factory) -> ThreadFuture<Res>
154	where
155		Factory: FnOnce(Sender<Res>, Receiver<Waker>) -> Body,
156		Body: FnOnce() -> Result<(), Error> + Send + 'static,
157	{
158		let (send_result, recv_result) = channel();
159		let (send_waker, recv_waker) = channel();
160		let body = factory(send_result, recv_waker);
161		let thread = thread::spawn(body);
162		ThreadFuture {
163			thread: Some(thread),
164			recv_result,
165			send_waker,
166			waker_sent: false,
167		}
168	}
169}
170
171impl<Res> Drop for ThreadFuture<Res> {
172	fn drop(&mut self) {
173		if let Some(thread) = self.thread.take() {
174			let _ = thread.join().expect("Can't join thread");
175		}
176	}
177}
178
179impl<Res> Future for ThreadFuture<Res> {
180	type Output = FutureResult<Res>;
181
182	fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
183		if !self.waker_sent {
184			if self.send_waker.send(cx.waker().clone()).is_err() {
185				return Poll::Ready(FutureResult::CommunicationError("Waker send channel is closed".to_string()));
186			}
187			self.waker_sent = true;
188		}
189		match self.recv_result.try_recv() {
190			Ok(res) => Poll::Ready(FutureResult::Result(res)),
191			Err(TryRecvError::Disconnected) => Poll::Ready(FutureResult::CommunicationError(
192				"Result receive channel is closed".to_string(),
193			)),
194			Err(TryRecvError::Empty) => Poll::Pending,
195		}
196	}
197}
198
/// Output of [`ThreadFuture`]: either the worker's value or a description of a channel failure.
enum FutureResult<Res> {
	/// A channel between the future and the worker thread was closed; the string says which.
	CommunicationError(String),
	/// The value the worker thread sent back.
	Result(Res),
}