pdk_classy/
client.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! HTTP client to make requests with.
6
7use http::uri::Scheme;
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::fmt::{Debug, Display, Formatter};
11use std::str::FromStr;
12use std::{
13    convert::Infallible,
14    fmt,
15    future::Future,
16    marker::PhantomData,
17    rc::Rc,
18    task::{Poll, Waker},
19    time::Duration,
20};
21
22use crate::proxy_wasm::types::{Bytes, Status};
23use serde::de::StdError;
24
25use crate::http_constants::{
26    DEFAULT_TIMEOUT, HEADER_AUTHORITY, HEADER_METHOD, HEADER_PATH, HEADER_SCHEME, HEADER_STATUS,
27    METHOD_DELETE, METHOD_GET, METHOD_OPTIONS, METHOD_POST, METHOD_PUT, USER_AGENT_HEADER,
28};
29use crate::user_agent::UserAgent;
30use crate::{
31    extract::{Extract, FromContext},
32    host::Host,
33    reactor::root::{BoxedExtractor, RootReactor},
34    types::{Cid, RequestId},
35};
36
37#[derive(Clone, Debug, PartialEq, Eq)]
38/// The response of an HTTP call.
39pub struct HttpCallResponse {
40    pub request_id: RequestId,
41    pub num_headers: usize,
42    pub body_size: usize,
43    pub num_trailers: usize,
44}
45
46/// An asynchronous HTTP client to make Requests with.
47pub struct HttpClient {
48    reactor: Rc<RootReactor>,
49    host: Rc<dyn Host>,
50    user_agent: Rc<UserAgent>,
51}
52
53/// The Errors that may occur when processing a Request.
54#[derive(thiserror::Error, Debug, Clone)]
55pub enum HttpClientError {
56    /// Proxy status problem.
57    #[error("Proxy status problem: {0:?}")]
58    Status(Status),
59
60    /// Request awaited on create context event.
61    #[error("Request awaited on create context event")]
62    AwaitedOnCreateContext,
63}
64
65impl HttpClient {
66    pub(crate) fn new(
67        reactor: Rc<RootReactor>,
68        host: Rc<dyn Host>,
69        user_agent: Rc<UserAgent>,
70    ) -> Self {
71        Self {
72            reactor,
73            host,
74            user_agent,
75        }
76    }
77
78    /// Creates a request that will forward the whole response to the caller.
79    ///
80    /// Note: If you want to avoid reading body buffers in certain situations see the `extract_with` method.
81    pub fn request<'a>(
82        &'a self,
83        service: &'a Service,
84    ) -> RequestBuilder<'a, DefaultResponseExtractor> {
85        RequestBuilder::new(self, service, DefaultResponseExtractor)
86    }
87}
88
89impl<C> FromContext<C> for HttpClient
90where
91    Rc<dyn Host>: FromContext<C, Error = Infallible>,
92    Rc<RootReactor>: FromContext<C, Error = Infallible>,
93{
94    type Error = Infallible;
95
96    fn from_context(context: &C) -> Result<Self, Self::Error> {
97        let reactor = context.extract()?;
98        let host = context.extract()?;
99        let agent = context.extract()?;
100        Ok(Self::new(reactor, host, agent))
101    }
102}
103
104/// The request to be sent to the backend.
105pub struct Request<T> {
106    reactor: Rc<RootReactor>,
107    request_id: RequestId,
108    cid_and_waker: Option<(Cid, Waker)>,
109    error: Option<HttpClientError>,
110    _response_type: PhantomData<T>,
111}
112
113/// An accessor for response parts.
114pub trait ResponseBuffers {
115    /// Returns the status code from a Response.
116    fn status_code(&self) -> u32;
117
118    /// Returns a header value by name if exists.
119    /// Known Limitations: The header value will be converted to an utf-8 String
120    /// If the bytes correspond to a non utf-8 string they will be parsed as an iso_8859_1 encoding.
121    fn header(&self, name: &str) -> Option<String>;
122
123    /// Returns a [`Vec`] containing all pairs of header names and their values.
124    /// Known Limitations: The header values will be converted to utf-8 Strings
125    /// If the bytes correspond to a non utf-8 string they will be parsed as an iso_8859_1 encoding.
126    fn headers(&self) -> Vec<(String, String)>;
127
128    /// Returns the body in binary format.
129    fn body(&self, start: usize, max_size: usize) -> Option<Bytes>;
130
131    /// Returns a [`Vec`] containing all pairs of trailer names and their values.
132    fn trailers(&self) -> Vec<(String, String)>;
133}
134
135impl ResponseBuffers for Rc<dyn Host> {
136    fn status_code(&self) -> u32 {
137        self.header(HEADER_STATUS)
138            .and_then(|status| status.parse::<u32>().ok())
139            .unwrap_or_default()
140    }
141
142    fn header(&self, name: &str) -> Option<String> {
143        self.get_http_call_response_header(name)
144    }
145
146    fn headers(&self) -> Vec<(String, String)> {
147        self.get_http_call_response_headers()
148    }
149
150    fn body(&self, start: usize, max_size: usize) -> Option<Bytes> {
151        self.get_http_call_response_body(start, max_size)
152    }
153
154    fn trailers(&self) -> Vec<(String, String)> {
155        self.get_http_call_response_trailers()
156    }
157}
158
159/// A low-level trait for extracting a Response and convert it to a
160/// [`ResponseExtractor::Output`] type.
161pub trait ResponseExtractor {
162    /// The output type
163    type Output;
164
165    /// Extracts the Response from their low-level components
166    ///
167    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output;
168}
169
/// A [`ResponseExtractor`] backed by a closure, letting callers read only the parts of the
/// response they need (for example, skipping the body).
pub struct FnResponseExtractor<F> {
    /// Closure invoked once with the response metadata and buffers.
    function: F,
}
175
176impl<F, T> ResponseExtractor for FnResponseExtractor<F>
177where
178    F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
179{
180    type Output = T;
181
182    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
183        (self.function)(event, buffers)
184    }
185}
186
187impl<F, T> FnResponseExtractor<F>
188where
189    F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
190{
191    pub fn from_fn(function: F) -> FnResponseExtractor<F>
192    where
193        F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
194    {
195        FnResponseExtractor { function }
196    }
197}
198
199/// A builder for a request to be sent to the upstream.
200pub struct RequestBuilder<'a, E> {
201    client: &'a HttpClient,
202    extractor: E,
203    service: &'a Service,
204    path: Option<&'a str>,
205    headers: Option<Vec<(&'a str, &'a str)>>,
206    body: Option<&'a [u8]>,
207    trailers: Option<Vec<(&'a str, &'a str)>>,
208    timeout: Option<Duration>,
209}
210
211impl<'a, E> RequestBuilder<'a, E>
212where
213    E: ResponseExtractor + 'static,
214    E::Output: 'static,
215{
216    fn new(client: &'a HttpClient, service: &'a Service, extractor: E) -> Self {
217        Self {
218            client,
219            extractor,
220            service,
221            path: None,
222            headers: None,
223            body: None,
224            trailers: None,
225            timeout: None,
226        }
227    }
228
229    /// Sets the extractor to be used to extract only the necessary data from Response.
230    pub fn extractor<T>(self, extractor: T) -> RequestBuilder<'a, T>
231    where
232        T: ResponseExtractor,
233    {
234        RequestBuilder {
235            client: self.client,
236            extractor,
237            service: self.service,
238            path: self.path,
239            headers: self.headers,
240            body: self.body,
241            trailers: self.trailers,
242            timeout: self.timeout,
243        }
244    }
245
246    /// Sets the extractor to be used to extract only the necessary data from Response.
247    pub fn extract_with<F, T>(self, function: F) -> RequestBuilder<'a, FnResponseExtractor<F>>
248    where
249        F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
250    {
251        self.extractor(FnResponseExtractor::from_fn(function))
252    }
253
254    /// Sets the path to be used in the request.
255    pub fn path(mut self, path: &'a str) -> Self {
256        self.path = Some(path);
257        self
258    }
259
260    /// Sets the headers to be used in the request.
261    pub fn headers(mut self, headers: Vec<(&'a str, &'a str)>) -> Self {
262        self.headers = Some(headers);
263        self
264    }
265
266    /// Sets the body to be used in the request.
267    pub fn body(mut self, body: &'a [u8]) -> Self {
268        self.body = Some(body);
269        self
270    }
271
272    /// Sets the trailers to be used in the request.
273    pub fn trailers(mut self, trailers: Vec<(&'a str, &'a str)>) -> Self {
274        self.trailers = Some(trailers);
275        self
276    }
277
278    /// Sets the timeout for the request.
279    pub fn timeout(mut self, timeout: Duration) -> Self {
280        self.timeout = Some(timeout);
281        self
282    }
283
284    /// Executes the request using the POST method.
285    pub fn post(self) -> Request<E::Output> {
286        self.send(METHOD_POST)
287    }
288    /// Executes the request using the PUT method.
289    pub fn put(self) -> Request<E::Output> {
290        self.send(METHOD_PUT)
291    }
292    /// Executes the request using the GET method.
293    pub fn get(self) -> Request<E::Output> {
294        self.send(METHOD_GET)
295    }
296    /// Executes the request using the OPTIONS method.
297    pub fn options(self) -> Request<E::Output> {
298        self.send(METHOD_OPTIONS)
299    }
300    /// Executes the request using the DELETE method.
301    pub fn delete(self) -> Request<E::Output> {
302        self.send(METHOD_DELETE)
303    }
304
305    #[must_use]
306    /// Executes the request using the provided method.
307    pub fn send(mut self, method: &str) -> Request<E::Output> {
308        let mut headers = self.headers.take().unwrap_or_default();
309
310        headers.push((HEADER_PATH, self.path.unwrap_or(self.service.uri().path())));
311        headers.push((HEADER_AUTHORITY, self.service.uri().authority()));
312        headers.push((HEADER_METHOD, method));
313        headers.push((USER_AGENT_HEADER, self.client.user_agent.value()));
314        headers.push((HEADER_SCHEME, self.service.uri().scheme()));
315
316        let body = self.body.take();
317        let trailers = self.trailers.take().unwrap_or_default();
318        let timeout = self.timeout.take().unwrap_or(DEFAULT_TIMEOUT);
319
320        match self.client.host.dispatch_http_call(
321            self.service.cluster_name(),
322            headers,
323            body,
324            trailers,
325            timeout,
326        ) {
327            Ok(request_id) => {
328                let request_id: RequestId = request_id.into();
329                let extractor = boxed_extractor(self.client.host.clone(), self.extractor);
330                self.client.reactor.insert_extractor(request_id, extractor);
331                Request::new(self.client.reactor.clone(), request_id)
332            }
333            Err(err) => Request::error(self.client.reactor.clone(), HttpClientError::Status(err)),
334        }
335    }
336}
337
338impl<E: ResponseExtractor> ResponseExtractor for RequestBuilder<'_, E> {
339    type Output = E::Output;
340
341    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
342        self.extractor.extract(event, buffers)
343    }
344}
345
346fn boxed_extractor<E>(buffers: Rc<dyn Host>, extractor: E) -> BoxedExtractor
347where
348    E: ResponseExtractor + 'static,
349    E::Output: 'static,
350{
351    Box::new(move |event| Box::new(extractor.extract(event, &buffers)))
352}
353
354/// A default implementation of [`ResponseExtractor`] which ignores the response.
355pub struct EmptyResponseExtractor;
356
357impl ResponseExtractor for EmptyResponseExtractor {
358    type Output = ();
359
360    fn extract(self, _event: &HttpCallResponse, _buffers: &dyn ResponseBuffers) -> Self::Output {}
361}
362
363impl<T> Request<T> {
364    fn new(reactor: Rc<RootReactor>, request_id: RequestId) -> Self {
365        Request {
366            reactor,
367            request_id,
368            error: None,
369            cid_and_waker: None,
370            _response_type: PhantomData,
371        }
372    }
373
374    fn error(reactor: Rc<RootReactor>, error: HttpClientError) -> Self {
375        Request {
376            reactor,
377            request_id: RequestId::from(0),
378            error: Some(error),
379            cid_and_waker: None,
380            _response_type: PhantomData,
381        }
382    }
383
384    pub fn id(&self) -> RequestId {
385        self.request_id
386    }
387}
388
389impl<T> Drop for Request<T> {
390    fn drop(&mut self) {
391        if self.error.is_none() {
392            let reactor = self.reactor.as_ref();
393
394            // Ensure that all related objects were removed
395            reactor.remove_extractor(self.request_id);
396            reactor.remove_response(self.request_id);
397            reactor.remove_client(self.request_id);
398        }
399    }
400}
401
402impl<T: Unpin + 'static> Future for Request<T> {
403    type Output = Result<T, HttpClientError>;
404
405    fn poll(
406        mut self: std::pin::Pin<&mut Self>,
407        cx: &mut std::task::Context<'_>,
408    ) -> Poll<Self::Output> {
409        if let Some(error) = self.error.clone() {
410            return Poll::Ready(Err(error));
411        }
412
413        if let Some((_event, content)) = self.reactor.remove_response(self.request_id) {
414            // It should be safe to unwrap here
415            let content = content.expect("response content should have been extracted");
416
417            // It should be safe to unwrap here
418            let content = content.downcast().expect("downcasting");
419
420            Poll::Ready(Ok(*content))
421        } else {
422            let this = &mut *self.as_mut();
423            match this.cid_and_waker.as_ref() {
424                None => {
425                    let cid = this.reactor.active_cid();
426
427                    // Register the waker in the reactor.
428                    this.reactor
429                        .insert_client(this.request_id, cx.waker().clone());
430                    this.reactor.set_paused(cid, true);
431                    this.cid_and_waker = Some((cid, cx.waker().clone()));
432                }
433                Some((cid, waker)) if !waker.will_wake(cx.waker()) => {
434                    // Deregister the waker from the reactor to remove the old waker.
435                    let _ = this
436                        .reactor
437                        .remove_client(this.request_id)
438                        // It should be safe to unwrap here
439                        .expect("stored extractor");
440
441                    // Register the waker in the reactor with the new waker.
442                    this.reactor
443                        .insert_client(this.request_id, cx.waker().clone());
444                    this.cid_and_waker = Some((*cid, cx.waker().clone()));
445                }
446                Some(_) => {}
447            }
448            Poll::Pending
449        }
450    }
451}
452
453/// A default implementation of [`ResponseExtractor`].
454pub struct DefaultResponseExtractor;
455
456impl ResponseExtractor for DefaultResponseExtractor {
457    type Output = HttpClientResponse;
458
459    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
460        let mut map = HashMap::new();
461        for (k, v) in buffers.headers().into_iter() {
462            match map.entry(k) {
463                Entry::Vacant(e) => {
464                    e.insert(v);
465                }
466                Entry::Occupied(mut e) => {
467                    e.insert(format!("{},{}", e.get(), v));
468                }
469            }
470        }
471
472        let body = buffers.body(0, event.body_size).unwrap_or_default();
473
474        HttpClientResponse::new(map, body)
475    }
476}
477
478/// The response for the default [`DefaultResponseExtractor`].
479#[derive(Debug)]
480pub struct HttpClientResponse {
481    headers: HashMap<String, String>,
482    body: Bytes,
483}
484
485impl HttpClientResponse {
486    pub fn new(headers: HashMap<String, String>, body: Bytes) -> Self {
487        Self { headers, body }
488    }
489
490    /// Returns the status code.
491    pub fn status_code(&self) -> u32 {
492        self.header(HEADER_STATUS)
493            .and_then(|status| status.parse::<u32>().ok())
494            .unwrap_or_default()
495    }
496
497    /// Returns a list of headers grouped by name and value.
498    pub fn headers(&self) -> &HashMap<String, String> {
499        &self.headers
500    }
501
502    /// Returns a header by name if exists.
503    pub fn header(&self, header: &str) -> Option<&String> {
504        self.headers.get(header)
505    }
506
507    /// Returns the body in binary format.
508    pub fn body(&self) -> &[u8] {
509        self.body.as_slice()
510    }
511
512    /// Returns the body in [`String`] format.
513    pub fn as_utf8_lossy(&self) -> String {
514        String::from_utf8_lossy(&self.body).to_string()
515    }
516}
517
518/// Represents an invalid URI error.
519pub struct InvalidUri(InvalidUriKind);
520
521enum InvalidUriKind {
522    Delegate(http::uri::InvalidUri),
523    MissingAuthority,
524    InvalidSchema,
525}
526
527impl Display for InvalidUri {
528    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
529        match &self.0 {
530            InvalidUriKind::Delegate(d) => Display::fmt(d, f),
531            InvalidUriKind::MissingAuthority => Display::fmt("authority missing", f),
532            InvalidUriKind::InvalidSchema => Display::fmt("scheme not supported", f),
533        }
534    }
535}
536
537impl Debug for InvalidUri {
538    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
539        Display::fmt(&self, f)
540    }
541}
542
543impl StdError for InvalidUri {}
544
545/// Represents a URI.
546#[derive(Clone, Debug)]
547pub struct Uri {
548    delegate: http::Uri,
549}
550
551impl Uri {
552    /// Returns the URI path.
553    pub fn path(&self) -> &str {
554        self.delegate
555            .path_and_query()
556            .map(|path_and_query| path_and_query.as_str())
557            .unwrap_or_else(|| self.delegate.path())
558    }
559
560    /// Returns the URI schema.
561    pub fn scheme(&self) -> &str {
562        // The unwrap should never take effect since we don't allow construction without scheme
563        self.delegate.scheme_str().unwrap_or_default()
564    }
565
566    /// Returns the URI authority.
567    pub fn authority(&self) -> &str {
568        // The unwrap should never take effect since we don't allow construction without authority
569        self.delegate
570            .authority()
571            .map(|authority| authority.as_str())
572            .unwrap_or_default()
573    }
574}
575
576impl Display for Uri {
577    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
578        f.write_fmt(format_args!("{}", self.delegate.to_string().as_str()))
579    }
580}
581
582impl FromStr for Uri {
583    type Err = InvalidUri;
584
585    fn from_str(s: &str) -> Result<Self, Self::Err> {
586        match s.parse::<http::Uri>() {
587            Ok(delegate) => {
588                if delegate.authority().is_none() {
589                    return Err(InvalidUri(InvalidUriKind::MissingAuthority));
590                }
591
592                if delegate
593                    .scheme()
594                    .map(|s| {
595                        !s.eq(&Scheme::HTTP)
596                            && !s.eq(&Scheme::HTTPS)
597                            && !s.as_str().eq_ignore_ascii_case("h2")
598                    })
599                    .unwrap_or(true)
600                {
601                    return Err(InvalidUri(InvalidUriKind::InvalidSchema));
602                }
603
604                Ok(Self { delegate })
605            }
606            Err(e) => Err(InvalidUri(InvalidUriKind::Delegate(e))),
607        }
608    }
609}
610
611#[derive(Clone, Debug)]
612/// Represents the upstream to be called.
613pub struct Service {
614    cluster_name: String,
615    uri: Uri,
616}
617
618impl Service {
619    pub fn from<'a>(name: &'a str, namespace: &'a str, uri: Uri) -> Service {
620        let cluster_name = format!("{name}.{namespace}.svc");
621        Service { cluster_name, uri }
622    }
623
624    pub fn new(cluster_name: &str, uri: Uri) -> Service {
625        Service {
626            cluster_name: cluster_name.to_string(),
627            uri,
628        }
629    }
630
631    /// The name of the cluster where the request will be forwarded to
632    pub fn cluster_name(&self) -> &str {
633        self.cluster_name.as_str()
634    }
635
636    /// The URI of the upstream to be called.
637    pub fn uri(&self) -> &Uri {
638        &self.uri
639    }
640}
641
#[cfg(test)]
mod test {
    use super::Uri;

    #[test]
    fn successfully_parse_http() {
        assert!("http://some.com/foo?some=val".parse::<Uri>().is_ok());
    }

    #[test]
    fn successfully_parse_https() {
        assert!("https://some.com/foo".parse::<Uri>().is_ok());
    }

    #[test]
    fn successfully_parse_h2() {
        assert!("h2://some.com/foo".parse::<Uri>().is_ok());
    }

    #[test]
    fn error_invalid_scheme() {
        assert!("ftp://some.com/foo".parse::<Uri>().is_err());
    }

    #[test]
    fn error_on_missing_scheme() {
        assert!("some.com/foo".parse::<Uri>().is_err());
    }

    #[test]
    fn error_on_missing_host() {
        assert!("/foo".parse::<Uri>().is_err());
    }

    // `path()` is used as the default request path, so it must keep the query string.
    #[test]
    fn path_includes_query() {
        let uri: Uri = "http://some.com/foo?some=val".parse().unwrap();
        assert_eq!(uri.path(), "/foo?some=val");
    }

    #[test]
    fn exposes_scheme_and_authority() {
        let uri: Uri = "https://some.com:8443/foo".parse().unwrap();
        assert_eq!(uri.scheme(), "https");
        assert_eq!(uri.authority(), "some.com:8443");
    }

    #[test]
    fn display_matches_input() {
        let uri: Uri = "https://some.com:8443/foo?a=b".parse().unwrap();
        assert_eq!(uri.to_string(), "https://some.com:8443/foo?a=b");
    }
}