1use actix_codec::Framed;
3use awc::{
4 error::{PayloadError, SendRequestError},
5 http::header::{HeaderMap, HeaderName, HeaderValue},
6 http::{header, Method, StatusCode},
7 ws::Codec,
8 BoxedSocket, ClientRequest, ClientResponse, SendClientRequest,
9};
10use bytes::{Bytes, BytesMut};
11use futures::stream::Peekable;
12use futures::{Stream, StreamExt, TryStreamExt};
13use heck::ToLowerCamelCase;
14use serde::{de::DeserializeOwned, Serialize};
15use serde_qs;
16use std::cmp::max;
17use std::convert::TryFrom;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use std::{env, rc::Rc, str::FromStr, time::Duration};
21use url::{form_urlencoded, Url};
22
23use crate::model::ErrorMessage;
24use crate::{Error, Result};
25
/// Name of the environment variable that `rest_api_url` reads to obtain the API base URL.
pub const YAGNA_API_URL_ENV_VAR: &str = "YAGNA_API_URL";
/// Base URL that `rest_api_url` falls back to when the environment variable cannot be read.
pub const DEFAULT_YAGNA_API_URL: &str = "http://127.0.0.1:7465";
/// Upper bound (10 MiB) passed to `body().limit(..)` when a response body is read into memory.
const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
29
30pub fn rest_api_url() -> Url {
31 let api_url = env::var(YAGNA_API_URL_ENV_VAR).unwrap_or(DEFAULT_YAGNA_API_URL.into());
32 api_url
33 .parse()
34 .unwrap_or_else(|_| panic!("invalid API URL: {}", api_url))
35}
36
/// Authentication scheme applied to every request issued by a built client.
#[derive(Clone, Debug)]
pub enum WebAuth {
    /// Bearer-token authentication; the wrapped string is the token.
    Bearer(String),
}
41
/// HTTP client bound to a base URL.
///
/// Request paths handed to the client are resolved against `base_url`
/// before being passed to the underlying `awc` client.
#[derive(Clone)]
pub struct WebClient {
    /// Base URL that relative request paths are joined onto.
    base_url: Rc<Url>,
    /// Underlying `awc` client that performs the requests.
    awc: awc::Client,
}
49
50pub trait WebInterface {
51 const API_URL_ENV_VAR: &'static str;
52 const API_SUFFIX: &'static str;
53
54 fn rebase_service_url(base_url: Rc<Url>) -> Result<Rc<Url>> {
55 if let Ok(url) = std::env::var(Self::API_URL_ENV_VAR) {
56 return Ok(Url::from_str(&url)?.into());
57 }
58 let suffix = if Self::API_SUFFIX.starts_with('/') {
59 Self::API_SUFFIX[1..].to_string()
60 } else {
61 Self::API_SUFFIX.to_string()
62 };
63 let with_trailing = format!("{}/", suffix);
64 let u = base_url.join(&with_trailing);
65 Ok(u?.into())
66 }
67
68 fn from_client(client: WebClient) -> Self;
69}
70
/// Method and URL of a request, kept alongside it so that errors can report
/// which call failed.
#[derive(Clone)]
pub struct WebRequestMeta {
    /// HTTP method of the request.
    method: Method,
    /// Fully resolved request URL, as a string.
    url: String,
}
76
77impl WebRequestMeta {
78 fn new(method: Method, url: String) -> Self {
79 WebRequestMeta { method, url }
80 }
81
82 fn as_request_err(&self, err: SendRequestError) -> Error {
83 Error::from_request(err, self.method.clone(), self.url.clone())
84 }
85
86 fn as_response_err(&self, code: StatusCode, msg: String) -> Error {
87 Error::from_response(code, msg, self.method.clone(), self.url.clone())
88 }
89}
90
/// A request in one of its stages, paired with the metadata used for error
/// reporting.
///
/// In this file `T` is `ClientRequest` while the request is being configured
/// and `SendClientRequest` once it has been sent.
pub struct WebRequest<T> {
    /// The wrapped `awc` request object for the current stage.
    inner_request: T,
    /// Method and URL of the request.
    meta: WebRequestMeta,
}
95
96impl WebClient {
97 pub fn builder() -> WebClientBuilder {
98 WebClientBuilder::default()
99 }
100
101 pub fn with_token(token: &str) -> WebClient {
102 WebClientBuilder::default().auth_token(token).build()
103 }
104
105 fn url<T: AsRef<str>>(&self, suffix: T) -> Result<url::Url> {
109 Ok(self.base_url.join(suffix.as_ref())?)
110 }
111
112 pub fn request(&self, method: Method, url: &str) -> WebRequest<ClientRequest> {
113 let url = self.url(url).unwrap().to_string();
114 log::debug!("doing {} on {}", method, url);
115 WebRequest {
116 inner_request: self.awc.request(method.clone(), &url),
117 meta: WebRequestMeta::new(method, url),
118 }
119 }
120
121 pub async fn event_stream(&self, url: &str) -> Result<impl Stream<Item = Result<Event>>> {
122 let url = self.url(url).unwrap().to_string();
123 log::debug!("event stream at {}", url);
124 let method = Method::GET;
125 let request = self
126 .awc
127 .request(method.clone(), &url)
128 .insert_header((header::ACCEPT, mime::TEXT_EVENT_STREAM));
129 let stream = request
130 .send()
131 .await
132 .map_err(|e| Error::from_request(e, method, url))?
133 .into_stream()
134 .map_err(Error::from)
135 .event_stream();
136 Ok(stream)
137 }
138
139 pub async fn ws(&self, url: &str) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>)> {
140 let mut url = self.base_url.join(url).unwrap();
141 url.set_scheme("ws")
142 .map_err(|_| Error::InternalError(format!("Invalid URL: {}", url)))?;
143 Ok(self.awc.ws(url.to_string()).connect().await?)
144 }
145
146 pub fn get(&self, url: &str) -> WebRequest<ClientRequest> {
147 self.request(Method::GET, url)
148 }
149
150 pub fn post(&self, url: &str) -> WebRequest<ClientRequest> {
151 self.request(Method::POST, url)
152 }
153
154 pub fn put(&self, url: &str) -> WebRequest<ClientRequest> {
155 self.request(Method::PUT, url)
156 }
157
158 pub fn delete(&self, url: &str) -> WebRequest<ClientRequest> {
159 self.request(Method::DELETE, url)
160 }
161
162 pub fn interface<T: WebInterface>(&self) -> Result<T> {
163 self.interface_at(None)
164 }
165
166 pub fn interface_at<T: WebInterface>(&self, base_url: impl Into<Option<Url>>) -> Result<T> {
167 let base_url = match base_url.into() {
168 Some(url) => url.into(),
169 None => T::rebase_service_url(self.base_url.clone())?,
170 };
171
172 let awc = self.awc.clone();
173 Ok(T::from_client(WebClient { base_url, awc }))
174 }
175}
176
177impl WebRequest<ClientRequest> {
178 pub fn send_json<T: Serialize + std::fmt::Debug>(
179 self,
180 value: &T,
181 ) -> WebRequest<SendClientRequest> {
182 log::trace!("sending payload: {:?}", value);
183 WebRequest {
184 inner_request: self.inner_request.send_json(value),
185 meta: self.meta,
186 }
187 }
188
189 pub fn send_bytes(self, bytes: Vec<u8>) -> WebRequest<SendClientRequest> {
190 let inner_request = self
191 .inner_request
192 .content_type("application/octet-stream")
193 .send_body(bytes);
194 WebRequest {
195 inner_request,
196 meta: self.meta,
197 }
198 }
199
200 pub fn add_header(mut self, name: &str, value: &str) -> Self {
201 self.inner_request = self.inner_request.append_header((name, value));
202 self
203 }
204
205 pub fn send(self) -> WebRequest<SendClientRequest> {
206 WebRequest {
207 inner_request: self.inner_request.send(),
208 meta: self.meta,
209 }
210 }
211}
212
213impl WebRequest<SendClientRequest> {
214 async fn request(
215 self,
216 ) -> Result<ClientResponse<impl Stream<Item = std::result::Result<Bytes, PayloadError>>>> {
217 let meta = self.meta.clone();
218 let mut response = self
219 .inner_request
220 .await
221 .map_err(|e| meta.as_request_err(e))?;
222
223 log::trace!("{:?}", response.headers());
224 if response.status().is_success() {
225 Ok(response)
226 } else {
227 let msg = if response
228 .headers()
229 .get(header::CONTENT_TYPE)
230 .map(|v| v.as_bytes() == b"application/json")
231 .unwrap_or_default()
232 {
233 let err_msg = response.json().await;
234 err_msg
235 .map(|e: ErrorMessage| e.message.unwrap_or_default())
236 .unwrap_or_else(|e| format!("error parsing error msg: {}", e))
237 } else {
238 match response.body().limit(MAX_BODY_SIZE).await {
239 Ok(ref bytes) => String::from_utf8_lossy(bytes).to_string(),
240 Err(e) => e.to_string(),
241 }
242 };
243 Err(meta.as_response_err(response.status(), msg))
244 }
245 }
246
247 pub async fn bytes(self) -> Result<Vec<u8>> {
248 Ok(self.request().await?.body().await?.to_vec())
249 }
250
251 pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
252 let meta = self.meta.clone();
253 let mut response = self.request().await?;
254
255 if StatusCode::NO_CONTENT == response.status()
257 || Some("0")
258 == response
259 .headers()
260 .get(header::CONTENT_LENGTH)
261 .and_then(|h| h.to_str().ok())
262 {
263 return Ok(serde_json::from_value(serde_json::json!(()))?);
264 }
265 let raw_body = response.body().limit(MAX_BODY_SIZE).await?;
266 let body = std::str::from_utf8(&raw_body)?;
267 log::debug!(
268 "WebRequest.json(). method={} url={}, resp='{}'",
269 meta.method,
270 meta.url,
271 body.split_at(512.min(body.len())).0
272 );
273 Ok(serde_json::from_str(body)?)
274 }
275}
276
277pub(crate) fn default_on_timeout<T: Default>(err: Error) -> Result<T> {
280 match err {
281 Error::TimeoutError { msg, url, .. } => {
282 log::trace!("timeout getting url {}: {}", url, msg);
283 Ok(Default::default())
284 }
285 _ => Err(err),
286 }
287}
288
/// Collects the settings used to construct a [`WebClient`].
#[derive(Clone, Debug)]
pub struct WebClientBuilder {
    /// Base URL of the API; when `None`, `build` falls back to `rest_api_url()`.
    pub(crate) api_url: Option<Url>,
    /// Authentication applied to the client, if any.
    pub(crate) auth: Option<WebAuth>,
    /// Headers added as defaults to the client.
    pub(crate) headers: HeaderMap,
    /// Request timeout; when `None`, `build` disables the timeout.
    pub(crate) timeout: Option<Duration>,
}
296
297impl WebClientBuilder {
298 pub fn auth_token(mut self, token: &str) -> Self {
299 self.auth = Some(WebAuth::Bearer(token.to_string()));
300 self
301 }
302
303 pub fn api_url(mut self, url: Url) -> Self {
304 self.api_url = Some(url);
305 self
306 }
307
308 pub fn timeout(mut self, timeout: Duration) -> Self {
309 self.timeout = Some(timeout);
310 self
311 }
312
313 pub fn header(mut self, name: String, value: String) -> Result<Self> {
314 let name = HeaderName::from_str(name.as_str())?;
315 let value = HeaderValue::from_str(value.as_str())?;
316
317 self.headers.insert(name, value);
318 Ok(self)
319 }
320
321 pub fn build(self) -> WebClient {
322 let mut builder = awc::ClientBuilder::new();
323
324 if let Some(timeout) = self.timeout {
325 builder = builder.timeout(timeout);
326 } else {
327 builder = builder.disable_timeout();
328 }
329 if let Some(auth) = &self.auth {
330 builder = match auth {
331 WebAuth::Bearer(token) => builder.bearer_auth(token),
332 }
333 }
334 for (key, value) in self.headers.iter() {
335 builder = builder.add_default_header((key.clone(), value.clone()));
336 }
337
338 WebClient {
339 base_url: Rc::new(self.api_url.unwrap_or_else(rest_api_url)),
340 awc: builder.finish(),
341 }
342 }
343}
344
345impl Default for WebClientBuilder {
346 fn default() -> Self {
347 WebClientBuilder {
348 api_url: None,
349 auth: None,
350 headers: HeaderMap::new(),
351 timeout: None,
352 }
353 }
354}
355
/// Incrementally builds a URL-encoded query string.
pub struct QueryParamsBuilder<'a> {
    /// Serializer that accumulates the encoded `name=value` pairs.
    serializer: form_urlencoded::Serializer<'a, String>,
}
360
361impl<'a> Default for QueryParamsBuilder<'a> {
362 fn default() -> Self {
363 let serializer = form_urlencoded::Serializer::new("".into());
364 QueryParamsBuilder { serializer }
365 }
366}
367
368impl<'a> QueryParamsBuilder<'a> {
369 pub fn put<N: ToString, V: ToString>(mut self, name: N, value: Option<V>) -> Self {
370 if let Some(v) = value {
371 self.serializer
372 .append_pair(&name.to_string().to_lower_camel_case(), &v.to_string());
373 };
374 self
375 }
376
377 pub fn build(mut self) -> String {
378 self.serializer.finish()
379 }
380}
381
/// A single event parsed from a text event stream.
#[derive(Debug)]
pub struct Event {
    /// Value of the `id` field, when present and parseable as `u64`.
    pub id: Option<u64>,
    /// Value of the `event` field.
    pub event: String,
    /// Values of all `data` fields, joined with `\n`.
    pub data: String,
}
388
389impl TryFrom<String> for Event {
390 type Error = Error;
391
392 fn try_from(string: String) -> Result<Self> {
393 let mut id = None;
394 let mut event = String::new();
395 let mut data = Vec::<String>::new();
396
397 for line in string.split('\n') {
398 let split = line.splitn(2, ':').collect::<Vec<_>>();
399 if split.len() < 2 {
400 continue;
401 }
402
403 let value = split[1].trim_start();
404 match split[0] {
405 "event" => event = value.into(),
406 "data" => data.push(value.into()),
407 "id" => {
408 id = match value.parse::<u64>() {
409 Ok(id) => Some(id),
410 _ => None,
411 }
412 }
413 _ => (),
414 }
415 }
416 if event.is_empty() {
417 return Err(Error::EventStreamError("Missing event entry".into()));
418 }
419 let data = data.join("\n");
420 Ok(Event { id, event, data })
421 }
422}
423
/// Extension trait that adapts a stream of byte chunks into a stream of
/// parsed [`Event`]s.
pub trait EventStreamExt<S, E>
where
    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
    E: Into<Error>,
{
    /// Wraps the byte stream in an [`EventStream`].
    fn event_stream(self) -> EventStream<S, E>;
}
431
432impl<S, E> EventStreamExt<S, E> for S
433where
434 S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
435 E: Into<Error>,
436{
437 fn event_stream(self) -> EventStream<S, E> {
438 EventStream::new(self)
439 }
440}
441
/// Stream adapter that assembles byte chunks into [`Event`]s.
pub struct EventStream<S, E>
where
    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
{
    /// Source of byte chunks; peekable so readiness of the next chunk can be probed.
    inner: Peekable<S>,
    /// Bytes received so far that have not yet been turned into an event.
    buffer: BytesMut,
}
449
450impl<S, E> EventStream<S, E>
451where
452 S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
453 E: Into<Error>,
454{
455 pub fn new(stream: S) -> Self {
456 EventStream {
457 inner: stream.peekable(),
458 buffer: BytesMut::new(),
459 }
460 }
461
462 fn next_event(&mut self, start_idx: usize) -> Option<Result<Event>> {
463 let idx = max(0, start_idx as i64 - 1) as usize;
464 if let Some(idx) = Self::find(&self.buffer, b"\n\n", idx) {
465 let bytes = self.buffer.split_to(idx);
466 return String::from_utf8(bytes.to_vec())
467 .map(Event::try_from)
468 .map_err(Error::from)
469 .ok();
470 }
471 None
472 }
473
474 fn find(source: &[u8], find: &[u8], start_idx: usize) -> Option<usize> {
475 let mut find_idx = 0;
476 for (i, b) in source.iter().enumerate().skip(start_idx) {
477 if *b == find[find_idx] {
478 find_idx += 1;
479 if find_idx == find.len() {
480 return Some(i);
481 }
482 } else {
483 find_idx = 0;
484 }
485 }
486 None
487 }
488}
489
/// Yields parsed events as complete event blocks become available in the buffer.
impl<S, E> Stream for EventStream<S, E>
where
    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
    E: Into<Error>,
{
    type Item = std::result::Result<Event, Error>;

    /// Returns the next buffered event, polling the inner stream for more
    /// bytes when the buffer holds no complete event.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Earlier chunks may have contained several events; hand out
        // whatever is already buffered before polling for more bytes.
        if let Some(result) = this.next_event(0) {
            return Poll::Ready(Some(result));
        }

        match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(bytes))) => {
                // Remember the old length so the search can resume near the
                // position where the new bytes begin.
                let idx = this.buffer.len();
                this.buffer.extend(bytes);

                if let Some(result) = this.next_event(idx) {
                    Poll::Ready(Some(result))
                } else {
                    // No complete event yet. If the inner stream already has
                    // another item ready, wake this task right away so it is
                    // polled again to consume that item.
                    if Pin::new(&mut this.inner).poll_peek(cx).is_ready() {
                        cx.waker().wake_by_ref();
                    }
                    Poll::Pending
                }
            }
            // Errors of the inner stream are converted and passed through.
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
            // The inner stream has ended; bytes still in the buffer are not emitted.
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
523
/// Builds a URL string from a format string, optional positional variables
/// and optional query parameters.
///
/// * `$path` and the plain identifiers after it are forwarded to `format!`.
/// * Each identifier marked with `#[query]` is passed to
///   `QueryParamsBuilder::put` together with its own name, so it must be an
///   `Option`; `None` values are left out of the query string.
///
/// The query string is appended after a `?` only when it is longer than one
/// character.
macro_rules! url_format {
    {
        $path:expr $(,$var:ident)* $(,#[query] $varq:ident)* $(,)?
    } => {{
        let mut url = format!( $path $(, $var)* );
        // The identifier itself (via `stringify!`) becomes the parameter name.
        let query = crate::web::QueryParamsBuilder::default()
            $( .put( stringify!($varq), $varq ) )*
            .build();
        if query.len() > 1 {
            url = format!("{}?{}", url, query)
        }
        url
    }};
}
548
549pub fn url_format_obj<T>(base: &str, params: &T) -> String
550where
551 T: Serialize,
552{
553 let qs = serde_qs::to_string(params).unwrap_or("".to_string());
554 if !qs.is_empty() {
555 format!("{}?{}", base, qs)
556 } else {
557 base.to_string()
558 }
559}
560
// Unit tests for the `url_format!` macro and for `EventStream` parsing.
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use bytes::Bytes;
    use crate::web::EventStream;
    use futures::{StreamExt, FutureExt, Stream};
    use crate::Error;

    // A path without placeholders or query parameters is returned verbatim.
    #[test]
    fn static_url() {
        assert_eq!(url_format!("foo"), "foo");
    }

    // A positional `{}` placeholder is filled from the listed variable.
    #[test]
    fn single_placeholder_url() {
        let bar = "qux";
        assert_eq!(url_format!("foo/{}", bar), "foo/qux");
    }

    // A named `{bar}` placeholder is captured from the enclosing scope.
    #[test]
    fn single_var_url() {
        let bar = "qux";
        assert_eq!(url_format!("foo/{bar}"), "foo/qux");
    }

    // Positional and named placeholders can be mixed.
    #[test]
    fn multi_var_url() {
        let bar = "qux";
        let baz = "quz";
        assert_eq!(
            url_format!("foo/{}/fuu/{baz}", bar),
            "foo/qux/fuu/quz"
        );
    }

    // A `None` query parameter produces no query string at all.
    #[test]
    fn empty_query_url() {
        let bar = Option::<String>::None;
        assert_eq!(url_format!("foo", #[query] bar), "foo");
    }

    // A `Some` query parameter is appended as `name=value`.
    #[test]
    #[rustfmt::skip]
    fn single_query_url() {
        let bar= Some("qux");
        assert_eq!(url_format!("foo", #[query] bar), "foo?bar=qux");
    }

    // `None` parameters are skipped while `Some` parameters are kept.
    #[test]
    fn mix_query_url() {
        let bar = Option::<String>::None;
        let baz = Some("quz");
        assert_eq!(url_format!("foo", #[query] bar, #[query] baz), "foo?baz=quz");
    }

    // Several parameters are joined with `&` in declaration order.
    #[test]
    fn multi_query_url() {
        let bar = Some("qux");
        let baz = Some("quz");
        assert_eq!(url_format!("foo", #[query] bar, #[query] baz), "foo?bar=qux&baz=quz");
    }

    // Path placeholders and query parameters can be combined.
    #[test]
    fn multi_var_and_query_url() {
        let bar = "baara";
        let baz = 0;
        let qar = Some(true);
        let qaz = Some(3);
        assert_eq!(
            url_format!(
                "foo/{bar}/fuu/{baz}",
                #[query] qar,
                #[query] qaz
            ),
            "foo/baara/fuu/0?qar=true&qaz=3"
        );
    }

    // Feeds a fixed event-stream document through the `EventStream` built by
    // `f` and checks the parsed events. `f` decides how the document is split
    // into chunks.
    async fn verify_stream<S, F>(f: F) -> anyhow::Result<()>
    where
        S: Stream<Item = std::result::Result<Bytes, Error>> + Unpin + 'static,
        F: Fn(&'static str) -> EventStream<S, Error>,
    {
        // Four blocks separated by blank lines: a full event preceded by a
        // `:ping` comment line, a block holding only `:ping`, an event with
        // an empty `data` field, and an event whose `id` line has no colon.
        let src = r#"
:ping
event: stdout
data: some
data: output
id: 1

:ping

event: stderr
data:
id: 2

event: stdout
data: 0
id

"#;
        let stream = f(src);
        let events = stream.collect::<Vec<_>>().await;

        assert_eq!(events.len(), 4);
        let mut iter = events.into_iter();

        let event = iter.next().unwrap()?;
        assert_eq!(event.event, "stdout".to_string());
        assert_eq!(event.data, "some\noutput".to_string());
        assert_eq!(event.id, Some(1));

        // The block that contains only `:ping` has no `event` field, so it
        // is reported as an error.
        assert!(iter.next().unwrap().is_err());

        let event = iter.next().unwrap()?;
        assert_eq!(event.event, "stderr".to_string());
        assert_eq!(event.data, "".to_string());
        assert_eq!(event.id, Some(2));

        // An `id` line without a colon is ignored, leaving the id unset.
        let event = iter.next().unwrap()?;
        assert_eq!(event.event, "stdout".to_string());
        assert_eq!(event.data, "0".to_string());
        assert_eq!(event.id, None);

        Ok(())
    }

    // Runs the same document as one chunk, as 5-byte chunks and as single
    // bytes, so that block terminators land at different chunk boundaries.
    #[actix_rt::test]
    async fn event_stream() {
        verify_stream(|s| {
            let stream = futures::stream::once(async move { Ok::<_, Error>(Bytes::from(s.to_string().into_bytes()))}.boxed_local());
            EventStream::new(stream)
        }).await.unwrap();

        verify_stream(|s| {
            let stream = futures::stream::iter(s.as_bytes()).chunks(5).map(|v| {
                Ok::<_, Error>(Bytes::from(v.iter().map(|b| **b).collect::<Vec<_>>()))
            });
            EventStream::new(stream)
        }).await.unwrap();

        verify_stream(|s| {
            let stream = futures::stream::iter(s.as_bytes()).chunks(1).map(|v| {
                Ok::<_, Error>(Bytes::from(v.iter().map(|b| **b).collect::<Vec<_>>()))
            });
            EventStream::new(stream)
        }).await.unwrap();
    }
}