1use std::error::Error as StdError;
6use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
7use std::future::Future;
8use std::io::Read;
9use std::pin::Pin;
10use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
11use std::task::{Context, Poll, Waker};
12use std::thread;
13use std::thread::JoinHandle;
14
15pub use ureq;
16
17use http_adapter::async_trait::async_trait;
18use http_adapter::http::{HeaderValue, StatusCode, Version};
19use http_adapter::{http, HttpClientAdapter};
20use http_adapter::{Request, Response};
21
22#[derive(Clone, Debug)]
23pub struct UreqAdapter {
24 agent: ureq::Agent,
25}
26
27impl UreqAdapter {
28 pub fn new(agent: ureq::Agent) -> Self {
29 Self { agent }
30 }
31}
32
33impl Default for UreqAdapter {
34 #[inline]
35 fn default() -> Self {
36 Self {
37 agent: ureq::Agent::new(),
38 }
39 }
40}
41
42#[derive(Debug)]
43pub enum Error {
44 Http(http::Error),
45 Ureq(Box<ureq::Error>),
46 InvalidHeaderValue(HeaderValue),
47 InvalidHttpVersion(String),
48 InvalidStatusCode(u16),
49 InternalCommunicationError(String),
50}
51
52impl Display for Error {
53 fn fmt(&self, f: &mut Formatter) -> FmtResult {
54 match self {
55 Error::Http(e) => Display::fmt(e, f),
56 Error::Ureq(e) => Display::fmt(e, f),
57 Error::InvalidHeaderValue(header_value) => {
58 write!(f, "Invalid header value: {header_value:?}")
59 }
60 Error::InvalidHttpVersion(version) => {
61 write!(f, "Invalid HTTP version: {version}")
62 }
63 Error::InvalidStatusCode(code) => {
64 write!(f, "Invalid status code: {code}")
65 }
66 Error::InternalCommunicationError(e) => {
67 write!(f, "Internal communication error: {e}")
68 }
69 }
70 }
71}
72
73impl StdError for Error {}
74
75#[inline]
76fn from_request<B>(client: &ureq::Agent, request: &Request<B>) -> Result<ureq::Request, Error> {
77 let mut out = client.request(request.method().as_str(), &request.uri().to_string());
78 for (name, value) in request.headers() {
79 out = out.set(
80 name.as_str(),
81 value.to_str().map_err(|_| Error::InvalidHeaderValue(value.clone()))?,
82 );
83 }
84 Ok(out)
85}
86
87#[inline]
88fn to_response(res: ureq::Response) -> Result<Response<Vec<u8>>, Error> {
89 let version = match res.http_version() {
90 "HTTP/0.9" => Version::HTTP_09,
91 "HTTP/1.0" => Version::HTTP_10,
92 "HTTP/1.1" => Version::HTTP_11,
93 "HTTP/2.0" => Version::HTTP_2,
94 "HTTP/3.0" => Version::HTTP_3,
95 _ => return Err(Error::InvalidHttpVersion(res.http_version().to_string())),
96 };
97
98 let status = StatusCode::from_u16(res.status()).map_err(|_| Error::InvalidStatusCode(res.status()))?;
99
100 let mut response = Response::builder().status(status).version(version);
101
102 for header_name in res.headers_names() {
103 if let Some(header_value) = res.header(&header_name) {
104 response = response.header(header_name, header_value);
105 }
106 }
107 let mut body = vec![];
108 res.into_reader()
109 .read_to_end(&mut body)
110 .map_err(|e| Error::Ureq(Box::new(e.into())))?;
111 response.body(body).map_err(Error::Http)
112}
113
114#[async_trait(?Send)]
115impl HttpClientAdapter for UreqAdapter {
116 type Error = Error;
117
118 async fn execute(&self, request: Request<Vec<u8>>) -> Result<Response<Vec<u8>>, Self::Error> {
119 let req = from_request(&self.agent, &request)?;
120 let res = ThreadFuture::new(|send_result, recv_waker| {
121 move || {
122 let waker = recv_waker
123 .recv()
124 .map_err(|_| Error::InternalCommunicationError("Waker receive channel is closed".to_string()))?;
125 match req.send_bytes(request.body()).map_err(|e| Error::Ureq(Box::new(e))) {
126 Ok(res) => send_result
127 .send(to_response(res))
128 .map_err(|_| Error::InternalCommunicationError("Result send channel is closed for Ok result".to_string()))?,
129 Err(e) => send_result
130 .send(Err(e))
131 .map_err(|_| Error::InternalCommunicationError("Result send channel is closed for Err result".to_string()))?,
132 }
133 waker.wake();
134 Ok(())
135 }
136 })
137 .await;
138 match res {
139 FutureResult::CommunicationError(e) => Err(Error::InternalCommunicationError(e)),
140 FutureResult::Result(r) => r,
141 }
142 }
143}
144
145struct ThreadFuture<Res> {
146 thread: Option<JoinHandle<Result<(), Error>>>,
147 recv_result: Receiver<Res>,
148 send_waker: Sender<Waker>,
149 waker_sent: bool,
150}
151
152impl<Res: Send + 'static> ThreadFuture<Res> {
153 pub fn new<Factory, Body>(factory: Factory) -> ThreadFuture<Res>
154 where
155 Factory: FnOnce(Sender<Res>, Receiver<Waker>) -> Body,
156 Body: FnOnce() -> Result<(), Error> + Send + 'static,
157 {
158 let (send_result, recv_result) = channel();
159 let (send_waker, recv_waker) = channel();
160 let body = factory(send_result, recv_waker);
161 let thread = thread::spawn(body);
162 ThreadFuture {
163 thread: Some(thread),
164 recv_result,
165 send_waker,
166 waker_sent: false,
167 }
168 }
169}
170
171impl<Res> Drop for ThreadFuture<Res> {
172 fn drop(&mut self) {
173 if let Some(thread) = self.thread.take() {
174 let _ = thread.join().expect("Can't join thread");
175 }
176 }
177}
178
179impl<Res> Future for ThreadFuture<Res> {
180 type Output = FutureResult<Res>;
181
182 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
183 if !self.waker_sent {
184 if self.send_waker.send(cx.waker().clone()).is_err() {
185 return Poll::Ready(FutureResult::CommunicationError("Waker send channel is closed".to_string()));
186 }
187 self.waker_sent = true;
188 }
189 match self.recv_result.try_recv() {
190 Ok(res) => Poll::Ready(FutureResult::Result(res)),
191 Err(TryRecvError::Disconnected) => Poll::Ready(FutureResult::CommunicationError(
192 "Result receive channel is closed".to_string(),
193 )),
194 Err(TryRecvError::Empty) => Poll::Pending,
195 }
196 }
197}
198
/// Outcome of awaiting a [`ThreadFuture`].
enum FutureResult<Res> {
    /// A channel between the future and its worker thread closed unexpectedly; carries a
    /// description of which one.
    CommunicationError(String),
    /// The value sent by the worker thread.
    Result(Res),
}