ya_client/
web.rs

1//! Web utils
2use actix_codec::Framed;
3use awc::{
4    error::{PayloadError, SendRequestError},
5    http::header::{HeaderMap, HeaderName, HeaderValue},
6    http::{header, Method, StatusCode},
7    ws::Codec,
8    BoxedSocket, ClientRequest, ClientResponse, SendClientRequest,
9};
10use bytes::{Bytes, BytesMut};
11use futures::stream::Peekable;
12use futures::{Stream, StreamExt, TryStreamExt};
13use heck::ToLowerCamelCase;
14use serde::{de::DeserializeOwned, Serialize};
15use serde_qs;
16use std::cmp::max;
17use std::convert::TryFrom;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use std::{env, rc::Rc, str::FromStr, time::Duration};
21use url::{form_urlencoded, Url};
22
23use crate::model::ErrorMessage;
24use crate::{Error, Result};
25
/// Name of the environment variable that [`rest_api_url`] reads to locate the REST API.
pub const YAGNA_API_URL_ENV_VAR: &str = "YAGNA_API_URL";
/// URL that [`rest_api_url`] falls back to when the environment variable cannot be read.
pub const DEFAULT_YAGNA_API_URL: &str = "http://127.0.0.1:7465";
/// Byte limit (10 MiB) passed to `.limit(..)` when JSON and error response bodies are read.
const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
29
30pub fn rest_api_url() -> Url {
31    let api_url = env::var(YAGNA_API_URL_ENV_VAR).unwrap_or(DEFAULT_YAGNA_API_URL.into());
32    api_url
33        .parse()
34        .unwrap_or_else(|_| panic!("invalid API URL: {}", api_url))
35}
36
/// Authentication scheme a [`WebClientBuilder`] can be configured with.
#[derive(Clone, Debug)]
pub enum WebAuth {
    /// Bearer token; handed to `bearer_auth` of the underlying client builder.
    Bearer(String),
}
41
/// Convenient wrapper for the [`awc::Client`](https://docs.rs/awc/1.0/awc/struct.Client.html)
/// that is bound to a base URL and created through a builder.
#[derive(Clone)]
pub struct WebClient {
    /// URL against which request paths are joined.
    base_url: Rc<Url>,
    /// Underlying HTTP client used to issue the requests.
    awc: awc::Client,
}
49
50pub trait WebInterface {
51    const API_URL_ENV_VAR: &'static str;
52    const API_SUFFIX: &'static str;
53
54    fn rebase_service_url(base_url: Rc<Url>) -> Result<Rc<Url>> {
55        if let Ok(url) = std::env::var(Self::API_URL_ENV_VAR) {
56            return Ok(Url::from_str(&url)?.into());
57        }
58        let suffix = if Self::API_SUFFIX.starts_with('/') {
59            Self::API_SUFFIX[1..].to_string()
60        } else {
61            Self::API_SUFFIX.to_string()
62        };
63        let with_trailing = format!("{}/", suffix);
64        let u = base_url.join(&with_trailing);
65        Ok(u?.into())
66    }
67
68    fn from_client(client: WebClient) -> Self;
69}
70
71#[derive(Clone)]
72pub struct WebRequestMeta {
73    method: Method,
74    url: String,
75}
76
77impl WebRequestMeta {
78    fn new(method: Method, url: String) -> Self {
79        WebRequestMeta { method, url }
80    }
81
82    fn as_request_err(&self, err: SendRequestError) -> Error {
83        Error::from_request(err, self.method.clone(), self.url.clone())
84    }
85
86    fn as_response_err(&self, code: StatusCode, msg: String) -> Error {
87        Error::from_response(code, msg, self.method.clone(), self.url.clone())
88    }
89}
90
/// A request of type `T` paired with the method and URL it was created for.
pub struct WebRequest<T> {
    /// The wrapped request object.
    inner_request: T,
    /// Method and URL used when building error values.
    meta: WebRequestMeta,
}
95
96impl WebClient {
97    pub fn builder() -> WebClientBuilder {
98        WebClientBuilder::default()
99    }
100
101    pub fn with_token(token: &str) -> WebClient {
102        WebClientBuilder::default().auth_token(token).build()
103    }
104
105    /// constructs endpoint url in form of `<base_url>/<suffix>`.
106    ///
107    /// suffix should not have leading slash ie. `offer` not `/offer`
108    fn url<T: AsRef<str>>(&self, suffix: T) -> Result<url::Url> {
109        Ok(self.base_url.join(suffix.as_ref())?)
110    }
111
112    pub fn request(&self, method: Method, url: &str) -> WebRequest<ClientRequest> {
113        let url = self.url(url).unwrap().to_string();
114        log::debug!("doing {} on {}", method, url);
115        WebRequest {
116            inner_request: self.awc.request(method.clone(), &url),
117            meta: WebRequestMeta::new(method, url),
118        }
119    }
120
121    pub async fn event_stream(&self, url: &str) -> Result<impl Stream<Item = Result<Event>>> {
122        let url = self.url(url).unwrap().to_string();
123        log::debug!("event stream at {}", url);
124        let method = Method::GET;
125        let request = self
126            .awc
127            .request(method.clone(), &url)
128            .insert_header((header::ACCEPT, mime::TEXT_EVENT_STREAM));
129        let stream = request
130            .send()
131            .await
132            .map_err(|e| Error::from_request(e, method, url))?
133            .into_stream()
134            .map_err(Error::from)
135            .event_stream();
136        Ok(stream)
137    }
138
139    pub async fn ws(&self, url: &str) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>)> {
140        let mut url = self.base_url.join(url).unwrap();
141        url.set_scheme("ws")
142            .map_err(|_| Error::InternalError(format!("Invalid URL: {}", url)))?;
143        Ok(self.awc.ws(url.to_string()).connect().await?)
144    }
145
146    pub fn get(&self, url: &str) -> WebRequest<ClientRequest> {
147        self.request(Method::GET, url)
148    }
149
150    pub fn post(&self, url: &str) -> WebRequest<ClientRequest> {
151        self.request(Method::POST, url)
152    }
153
154    pub fn put(&self, url: &str) -> WebRequest<ClientRequest> {
155        self.request(Method::PUT, url)
156    }
157
158    pub fn delete(&self, url: &str) -> WebRequest<ClientRequest> {
159        self.request(Method::DELETE, url)
160    }
161
162    pub fn interface<T: WebInterface>(&self) -> Result<T> {
163        self.interface_at(None)
164    }
165
166    pub fn interface_at<T: WebInterface>(&self, base_url: impl Into<Option<Url>>) -> Result<T> {
167        let base_url = match base_url.into() {
168            Some(url) => url.into(),
169            None => T::rebase_service_url(self.base_url.clone())?,
170        };
171
172        let awc = self.awc.clone();
173        Ok(T::from_client(WebClient { base_url, awc }))
174    }
175}
176
177impl WebRequest<ClientRequest> {
178    pub fn send_json<T: Serialize + std::fmt::Debug>(
179        self,
180        value: &T,
181    ) -> WebRequest<SendClientRequest> {
182        log::trace!("sending payload: {:?}", value);
183        WebRequest {
184            inner_request: self.inner_request.send_json(value),
185            meta: self.meta,
186        }
187    }
188
189    pub fn send_bytes(self, bytes: Vec<u8>) -> WebRequest<SendClientRequest> {
190        let inner_request = self
191            .inner_request
192            .content_type("application/octet-stream")
193            .send_body(bytes);
194        WebRequest {
195            inner_request,
196            meta: self.meta,
197        }
198    }
199
200    pub fn add_header(mut self, name: &str, value: &str) -> Self {
201        self.inner_request = self.inner_request.append_header((name, value));
202        self
203    }
204
205    pub fn send(self) -> WebRequest<SendClientRequest> {
206        WebRequest {
207            inner_request: self.inner_request.send(),
208            meta: self.meta,
209        }
210    }
211}
212
213impl WebRequest<SendClientRequest> {
214    async fn request(
215        self,
216    ) -> Result<ClientResponse<impl Stream<Item = std::result::Result<Bytes, PayloadError>>>> {
217        let meta = self.meta.clone();
218        let mut response = self
219            .inner_request
220            .await
221            .map_err(|e| meta.as_request_err(e))?;
222
223        log::trace!("{:?}", response.headers());
224        if response.status().is_success() {
225            Ok(response)
226        } else {
227            let msg = if response
228                .headers()
229                .get(header::CONTENT_TYPE)
230                .map(|v| v.as_bytes() == b"application/json")
231                .unwrap_or_default()
232            {
233                let err_msg = response.json().await;
234                err_msg
235                    .map(|e: ErrorMessage| e.message.unwrap_or_default())
236                    .unwrap_or_else(|e| format!("error parsing error msg: {}", e))
237            } else {
238                match response.body().limit(MAX_BODY_SIZE).await {
239                    Ok(ref bytes) => String::from_utf8_lossy(bytes).to_string(),
240                    Err(e) => e.to_string(),
241                }
242            };
243            Err(meta.as_response_err(response.status(), msg))
244        }
245    }
246
247    pub async fn bytes(self) -> Result<Vec<u8>> {
248        Ok(self.request().await?.body().await?.to_vec())
249    }
250
251    pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
252        let meta = self.meta.clone();
253        let mut response = self.request().await?;
254
255        // allow empty body and no content (204) to pass smoothly
256        if StatusCode::NO_CONTENT == response.status()
257            || Some("0")
258                == response
259                    .headers()
260                    .get(header::CONTENT_LENGTH)
261                    .and_then(|h| h.to_str().ok())
262        {
263            return Ok(serde_json::from_value(serde_json::json!(()))?);
264        }
265        let raw_body = response.body().limit(MAX_BODY_SIZE).await?;
266        let body = std::str::from_utf8(&raw_body)?;
267        log::debug!(
268            "WebRequest.json(). method={} url={}, resp='{}'",
269            meta.method,
270            meta.url,
271            body.split_at(512.min(body.len())).0
272        );
273        Ok(serde_json::from_str(body)?)
274    }
275}
276
277// this is used internally to translate from HTTP Timeout into default result
278// (empty vec most of the time)
279pub(crate) fn default_on_timeout<T: Default>(err: Error) -> Result<T> {
280    match err {
281        Error::TimeoutError { msg, url, .. } => {
282            log::trace!("timeout getting url {}: {}", url, msg);
283            Ok(Default::default())
284        }
285        _ => Err(err),
286    }
287}
288
289#[derive(Clone, Debug)]
290pub struct WebClientBuilder {
291    pub(crate) api_url: Option<Url>,
292    pub(crate) auth: Option<WebAuth>,
293    pub(crate) headers: HeaderMap,
294    pub(crate) timeout: Option<Duration>,
295}
296
297impl WebClientBuilder {
298    pub fn auth_token(mut self, token: &str) -> Self {
299        self.auth = Some(WebAuth::Bearer(token.to_string()));
300        self
301    }
302
303    pub fn api_url(mut self, url: Url) -> Self {
304        self.api_url = Some(url);
305        self
306    }
307
308    pub fn timeout(mut self, timeout: Duration) -> Self {
309        self.timeout = Some(timeout);
310        self
311    }
312
313    pub fn header(mut self, name: String, value: String) -> Result<Self> {
314        let name = HeaderName::from_str(name.as_str())?;
315        let value = HeaderValue::from_str(value.as_str())?;
316
317        self.headers.insert(name, value);
318        Ok(self)
319    }
320
321    pub fn build(self) -> WebClient {
322        let mut builder = awc::ClientBuilder::new();
323
324        if let Some(timeout) = self.timeout {
325            builder = builder.timeout(timeout);
326        } else {
327            builder = builder.disable_timeout();
328        }
329        if let Some(auth) = &self.auth {
330            builder = match auth {
331                WebAuth::Bearer(token) => builder.bearer_auth(token),
332            }
333        }
334        for (key, value) in self.headers.iter() {
335            builder = builder.add_default_header((key.clone(), value.clone()));
336        }
337
338        WebClient {
339            base_url: Rc::new(self.api_url.unwrap_or_else(rest_api_url)),
340            awc: builder.finish(),
341        }
342    }
343}
344
345impl Default for WebClientBuilder {
346    fn default() -> Self {
347        WebClientBuilder {
348            api_url: None,
349            auth: None,
350            headers: HeaderMap::new(),
351            timeout: None,
352        }
353    }
354}
355
356/// Builder for the query part of the URLs.
357pub struct QueryParamsBuilder<'a> {
358    serializer: form_urlencoded::Serializer<'a, String>,
359}
360
361impl<'a> Default for QueryParamsBuilder<'a> {
362    fn default() -> Self {
363        let serializer = form_urlencoded::Serializer::new("".into());
364        QueryParamsBuilder { serializer }
365    }
366}
367
368impl<'a> QueryParamsBuilder<'a> {
369    pub fn put<N: ToString, V: ToString>(mut self, name: N, value: Option<V>) -> Self {
370        if let Some(v) = value {
371            self.serializer
372                .append_pair(&name.to_string().to_lower_camel_case(), &v.to_string());
373        };
374        self
375    }
376
377    pub fn build(mut self) -> String {
378        self.serializer.finish()
379    }
380}
381
382#[derive(Debug)]
383pub struct Event {
384    pub id: Option<u64>,
385    pub event: String,
386    pub data: String,
387}
388
389impl TryFrom<String> for Event {
390    type Error = Error;
391
392    fn try_from(string: String) -> Result<Self> {
393        let mut id = None;
394        let mut event = String::new();
395        let mut data = Vec::<String>::new();
396
397        for line in string.split('\n') {
398            let split = line.splitn(2, ':').collect::<Vec<_>>();
399            if split.len() < 2 {
400                continue;
401            }
402
403            let value = split[1].trim_start();
404            match split[0] {
405                "event" => event = value.into(),
406                "data" => data.push(value.into()),
407                "id" => {
408                    id = match value.parse::<u64>() {
409                        Ok(id) => Some(id),
410                        _ => None,
411                    }
412                }
413                _ => (),
414            }
415        }
416        if event.is_empty() {
417            return Err(Error::EventStreamError("Missing event entry".into()));
418        }
419        let data = data.join("\n");
420        Ok(Event { id, event, data })
421    }
422}
423
424pub trait EventStreamExt<S, E>
425where
426    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
427    E: Into<Error>,
428{
429    fn event_stream(self) -> EventStream<S, E>;
430}
431
432impl<S, E> EventStreamExt<S, E> for S
433where
434    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
435    E: Into<Error>,
436{
437    fn event_stream(self) -> EventStream<S, E> {
438        EventStream::new(self)
439    }
440}
441
442pub struct EventStream<S, E>
443where
444    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
445{
446    inner: Peekable<S>,
447    buffer: BytesMut,
448}
449
450impl<S, E> EventStream<S, E>
451where
452    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
453    E: Into<Error>,
454{
455    pub fn new(stream: S) -> Self {
456        EventStream {
457            inner: stream.peekable(),
458            buffer: BytesMut::new(),
459        }
460    }
461
462    fn next_event(&mut self, start_idx: usize) -> Option<Result<Event>> {
463        let idx = max(0, start_idx as i64 - 1) as usize;
464        if let Some(idx) = Self::find(&self.buffer, b"\n\n", idx) {
465            let bytes = self.buffer.split_to(idx);
466            return String::from_utf8(bytes.to_vec())
467                .map(Event::try_from)
468                .map_err(Error::from)
469                .ok();
470        }
471        None
472    }
473
474    fn find(source: &[u8], find: &[u8], start_idx: usize) -> Option<usize> {
475        let mut find_idx = 0;
476        for (i, b) in source.iter().enumerate().skip(start_idx) {
477            if *b == find[find_idx] {
478                find_idx += 1;
479                if find_idx == find.len() {
480                    return Some(i);
481                }
482            } else {
483                find_idx = 0;
484            }
485        }
486        None
487    }
488}
489
490impl<S, E> Stream for EventStream<S, E>
491where
492    S: Stream<Item = std::result::Result<Bytes, E>> + Unpin + 'static,
493    E: Into<Error>,
494{
495    type Item = std::result::Result<Event, Error>;
496
497    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
498        let this = self.get_mut();
499        if let Some(result) = this.next_event(0) {
500            return Poll::Ready(Some(result));
501        }
502
503        match Pin::new(&mut this.inner).poll_next(cx) {
504            Poll::Ready(Some(Ok(bytes))) => {
505                let idx = this.buffer.len();
506                this.buffer.extend(bytes);
507
508                if let Some(result) = this.next_event(idx) {
509                    Poll::Ready(Some(result))
510                } else {
511                    if Pin::new(&mut this.inner).poll_peek(cx).is_ready() {
512                        cx.waker().wake_by_ref();
513                    }
514                    Poll::Pending
515                }
516            }
517            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
518            Poll::Ready(None) => Poll::Ready(None),
519            Poll::Pending => Poll::Pending,
520        }
521    }
522}
523
/// Macro to facilitate URL formatting for REST API async bindings.
///
/// The path part works like `format!(..)`, except that the `ident = value`
/// argument syntax is not supported. Arguments marked with `#[query]` are
/// passed to `QueryParamsBuilder::put` under their own name and end up in
/// the query string.
///
/// ```text
/// url_format!("foo")                   => "foo"
/// url_format!("foo/{bar}")             => "foo/" + bar
/// url_format!("foo/{}", bar)           => "foo/" + bar
/// url_format!("foo/{bar}", bar="expr") => not supported
/// url_format!("foo", #[query] bar)     => "foo?bar=" + bar
/// ```
macro_rules! url_format {
    (
        $path:expr $(,$var:ident)* $(,#[query] $varq:ident)* $(,)?
    ) => {{
        let path = format!( $path $(, $var)* );
        let query = crate::web::QueryParamsBuilder::default()
            $( .put( stringify!($varq), $varq ) )*
            .build();
        if query.len() > 1 {
            format!("{}?{}", path, query)
        } else {
            path
        }
    }};
}
548
549pub fn url_format_obj<T>(base: &str, params: &T) -> String
550where
551    T: Serialize,
552{
553    let qs = serde_qs::to_string(params).unwrap_or("".to_string());
554    if !qs.is_empty() {
555        format!("{}?{}", base, qs)
556    } else {
557        base.to_string()
558    }
559}
560
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use bytes::Bytes;
    use crate::web::EventStream;
    use futures::{StreamExt, FutureExt, Stream};
    use crate::Error;

    // ---- url_format! -------------------------------------------------------

    #[test]
    fn static_url() {
        let formatted = url_format!("foo");
        assert_eq!(formatted, "foo");
    }

    #[test]
    fn single_placeholder_url() {
        let bar = "qux";
        let formatted = url_format!("foo/{}", bar);
        assert_eq!(formatted, "foo/qux");
    }

    #[test]
    fn single_var_url() {
        let bar = "qux";
        let formatted = url_format!("foo/{bar}");
        assert_eq!(formatted, "foo/qux");
    }

    // Naming a variable that does not exist is a compilation error, e.g.:
    //
    //    let bar = "qux";
    //    url_format!("foo/{baz}", bar)

    #[test]
    fn multi_var_url() {
        let bar = "qux";
        let baz = "quz";
        let formatted = url_format!("foo/{}/fuu/{baz}", bar);
        assert_eq!(formatted, "foo/qux/fuu/quz");
    }

    #[test]
    fn empty_query_url() {
        let bar: Option<String> = None;
        let formatted = url_format!("foo", #[query] bar);
        assert_eq!(formatted, "foo");
    }

    #[test]
    #[rustfmt::skip]
    fn single_query_url() {
        let bar = Some("qux");
        let formatted = url_format!("foo", #[query] bar);
        assert_eq!(formatted, "foo?bar=qux");
    }

    #[test]
    fn mix_query_url() {
        let bar: Option<String> = None;
        let baz = Some("quz");
        let formatted = url_format!("foo", #[query] bar, #[query] baz);
        assert_eq!(formatted, "foo?baz=quz");
    }

    #[test]
    fn multi_query_url() {
        let bar = Some("qux");
        let baz = Some("quz");
        let formatted = url_format!("foo", #[query] bar, #[query] baz);
        assert_eq!(formatted, "foo?bar=qux&baz=quz");
    }

    #[test]
    fn multi_var_and_query_url() {
        let bar = "baara";
        let baz = 0;
        let qar = Some(true);
        let qaz = Some(3);
        let formatted = url_format!(
            "foo/{bar}/fuu/{baz}",
            #[query] qar,
            #[query] qaz
        );
        assert_eq!(formatted, "foo/baara/fuu/0?qar=true&qaz=3");
    }

    // ---- EventStream -------------------------------------------------------

    /// Feeds a fixed payload through the stream built by `f` and checks the
    /// four resulting items: three events and one error for the comment-only
    /// block.
    async fn verify_stream<S, F>(f: F) -> anyhow::Result<()>
    where
        S: Stream<Item = std::result::Result<Bytes, Error>> + Unpin + 'static,
        F: Fn(&'static str) -> EventStream<S, Error>,
    {
        let src = r#"
:ping
event: stdout
data: some
data: output
id: 1

:ping

event: stderr
data:
id: 2

event: stdout
data: 0
id

"#;
        let collected = f(src).collect::<Vec<_>>().await;
        assert_eq!(collected.len(), 4);
        let mut remaining = collected.into_iter();

        let first = remaining.next().unwrap()?;
        assert_eq!(first.event, "stdout");
        assert_eq!(first.data, "some\noutput");
        assert_eq!(first.id, Some(1));

        // The block holding only `:ping` has no `event` field.
        let second = remaining.next().unwrap();
        assert!(second.is_err());

        let third = remaining.next().unwrap()?;
        assert_eq!(third.event, "stderr");
        assert_eq!(third.data, "");
        assert_eq!(third.id, Some(2));

        let fourth = remaining.next().unwrap()?;
        assert_eq!(fourth.event, "stdout");
        assert_eq!(fourth.data, "0");
        assert_eq!(fourth.id, None);

        Ok(())
    }

    #[actix_rt::test]
    async fn event_stream() {
        // Whole payload delivered as a single chunk.
        verify_stream(|s| {
            let whole = Bytes::from(s.as_bytes().to_vec());
            let source = futures::stream::once(async move { Ok::<_, Error>(whole) }.boxed_local());
            EventStream::new(source)
        }).await.unwrap();

        // Payload delivered in chunks of five bytes.
        verify_stream(|s| {
            let source = futures::stream::iter(s.as_bytes()).chunks(5).map(|chunk| {
                Ok::<_, Error>(Bytes::from(chunk.into_iter().copied().collect::<Vec<u8>>()))
            });
            EventStream::new(source)
        }).await.unwrap();

        // Payload delivered one byte at a time.
        verify_stream(|s| {
            let source = futures::stream::iter(s.as_bytes()).chunks(1).map(|chunk| {
                Ok::<_, Error>(Bytes::from(chunk.into_iter().copied().collect::<Vec<u8>>()))
            });
            EventStream::new(source)
        }).await.unwrap();
    }
}