Skip to main content

progenitor_middleware_client/
progenitor_middleware_client.rs

1// Copyright 2025 Oxide Computer Company
2
3#![allow(dead_code)]
4
5//! Support code for generated clients.
6
7use std::ops::{Deref, DerefMut};
8
9use bytes::Bytes;
10use futures_core::Stream;
11use reqwest::RequestBuilder;
12use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize};
13
14#[cfg(not(target_arch = "wasm32"))]
15type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send + Sync>>;
16
17#[cfg(target_arch = "wasm32")]
18type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>;
19
20/// Untyped byte stream used for both success and error responses.
21pub struct ByteStream(InnerByteStream);
22
23impl ByteStream {
24    /// Creates a new ByteStream
25    ///
26    /// Useful for generating test fixtures.
27    pub fn new(inner: InnerByteStream) -> Self {
28        Self(inner)
29    }
30
31    /// Consumes the [`ByteStream`] and return its inner [`Stream`].
32    pub fn into_inner(self) -> InnerByteStream {
33        self.0
34    }
35}
36
37impl Deref for ByteStream {
38    type Target = InnerByteStream;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl DerefMut for ByteStream {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
/// Interface for which an implementation is generated for all clients.
pub trait ClientInfo<Inner> {
    /// Get the version of this API.
    ///
    /// This string is pulled directly from the source OpenAPI document and may
    /// be in any format the API selects.
    fn api_version() -> &'static str;

    /// Get the base URL to which requests are made.
    fn baseurl(&self) -> &str;

    /// Get the internal `reqwest_middleware::ClientWithMiddleware` used to make requests.
    fn client(&self) -> &reqwest_middleware::ClientWithMiddleware;

    /// Get the inner value of type `Inner` if one is specified.
    ///
    /// `Inner` is the client's custom state type; `ClientHooks` defaults it
    /// to `()`.
    fn inner(&self) -> &Inner;
}
68
69impl<T, Inner> ClientInfo<Inner> for &T
70where
71    T: ClientInfo<Inner>,
72{
73    fn api_version() -> &'static str {
74        T::api_version()
75    }
76
77    fn baseurl(&self) -> &str {
78        (*self).baseurl()
79    }
80
81    fn client(&self) -> &reqwest_middleware::ClientWithMiddleware {
82        (*self).client()
83    }
84
85    fn inner(&self) -> &Inner {
86        (*self).inner()
87    }
88}
89
/// Information about an operation, consumed by hook implementations.
///
/// A reference to it is passed to each [`ClientHooks`] method (`pre`, `post`,
/// `exec`).
pub struct OperationInfo {
    /// The corresponding operationId from the source OpenAPI document.
    pub operation_id: &'static str,
}
95
96/// Interface for changing the behavior of generated clients. All clients
97/// implement this for `&Client`; to override the default behavior, implement
98/// some or all of the interfaces for the `Client` type (without the
99/// reference). This mechanism relies on so-called "auto-ref specialization".
100#[allow(async_fn_in_trait, unused)]
101pub trait ClientHooks<Inner = ()>
102where
103    Self: ClientInfo<Inner>,
104{
105    /// Runs prior to the execution of the request. This may be used to modify
106    /// the request before it is transmitted.
107    async fn pre<E>(
108        &self,
109        request: &mut reqwest::Request,
110        info: &OperationInfo,
111    ) -> std::result::Result<(), Error<E>> {
112        Ok(())
113    }
114
115    /// Runs after completion of the request.
116    async fn post<E>(
117        &self,
118        result: &Result<reqwest::Response, reqwest_middleware::Error>,
119        info: &OperationInfo,
120    ) -> std::result::Result<(), Error<E>> {
121        Ok(())
122    }
123
124    /// Execute the request. Note that for almost any reasonable implementation
125    /// this will include code equivalent to this:
126    /// ```
127    /// # use progenitor_middleware_client::{ClientHooks, ClientInfo, OperationInfo};
128    /// # struct X;
129    /// # impl ClientInfo<()> for X {
130    /// #   fn api_version() -> &'static str { panic!() }
131    /// #   fn baseurl(&self) -> &str { panic!() }
132    /// #   fn client(&self) -> &reqwest::Client { panic!() }
133    /// #   fn inner(&self) -> &() { panic!() }
134    /// # }
135    /// # impl ClientHooks for X {
136    /// #   async fn exec(
137    /// #       &self,
138    /// #       request: reqwest::Request,
139    /// #       info: &OperationInfo,
140    /// #   ) -> reqwest::Result<reqwest::Response> {
141    ///         self.client().execute(request).await
142    /// #   }
143    /// # }
144    /// ```
145    async fn exec(
146        &self,
147        request: reqwest::Request,
148        info: &OperationInfo,
149    ) -> Result<reqwest::Response, reqwest_middleware::Error> {
150        self.client().execute(request).await
151    }
152}
153
/// Typed value returned by generated client methods.
///
/// This is used for successful responses and may appear in error responses
/// generated from the server (see [`Error::ErrorResponse`]).
///
/// Alongside the value it keeps the response's status code and headers.
pub struct ResponseValue<T> {
    // The payload: a deserialized body, `()`, a `ByteStream`, an upgraded
    // connection, or whatever was passed to `new`.
    inner: T,
    // HTTP status code of the response.
    status: http::StatusCode,
    // Response headers (cloned from the `reqwest::Response` by the
    // constructors in this file, or supplied to `new`).
    headers: http::HeaderMap,
    // TODO cookies?
}
164
165impl<T: DeserializeOwned> ResponseValue<T> {
166    #[doc(hidden)]
167    pub async fn from_response<E>(response: reqwest::Response) -> Result<Self, Error<E>> {
168        let status = response.status();
169        let headers = response.headers().clone();
170        let full = response.bytes().await.map_err(|e| Error::ResponseBodyError(e.into()))?;
171        let inner =
172            serde_json::from_slice(&full).map_err(|e| Error::InvalidResponsePayload(full, e))?;
173
174        Ok(Self {
175            inner,
176            status,
177            headers,
178        })
179    }
180}
181
182#[cfg(not(target_arch = "wasm32"))]
183impl ResponseValue<reqwest::Upgraded> {
184    #[doc(hidden)]
185    pub async fn upgrade<E: std::fmt::Debug>(
186        response: reqwest::Response,
187    ) -> Result<Self, Error<E>> {
188        let status = response.status();
189        let headers = response.headers().clone();
190        if status == http::StatusCode::SWITCHING_PROTOCOLS {
191            let inner = response.upgrade().await.map_err(|e| Error::InvalidUpgrade(e.into()))?;
192
193            Ok(Self {
194                inner,
195                status,
196                headers,
197            })
198        } else {
199            Err(Error::UnexpectedResponse(response))
200        }
201    }
202}
203
204impl ResponseValue<ByteStream> {
205    #[doc(hidden)]
206    pub fn stream(response: reqwest::Response) -> Self {
207        let status = response.status();
208        let headers = response.headers().clone();
209        Self {
210            inner: ByteStream(Box::pin(response.bytes_stream())),
211            status,
212            headers,
213        }
214    }
215}
216
217impl ResponseValue<()> {
218    #[doc(hidden)]
219    pub fn empty(response: reqwest::Response) -> Self {
220        let status = response.status();
221        let headers = response.headers().clone();
222        // TODO is there anything we want to do to confirm that there is no
223        // content?
224        Self {
225            inner: (),
226            status,
227            headers,
228        }
229    }
230}
231
232impl<T> ResponseValue<T> {
233    /// Creates a [`ResponseValue`] from the inner type, status, and headers.
234    ///
235    /// Useful for generating test fixtures.
236    pub fn new(inner: T, status: http::StatusCode, headers: http::HeaderMap) -> Self {
237        Self {
238            inner,
239            status,
240            headers,
241        }
242    }
243
244    /// Consumes the ResponseValue, returning the wrapped value.
245    pub fn into_inner(self) -> T {
246        self.inner
247    }
248
249    /// Gets the status from this response.
250    pub fn status(&self) -> http::StatusCode {
251        self.status
252    }
253
254    /// Gets the headers from this response.
255    pub fn headers(&self) -> &http::HeaderMap {
256        &self.headers
257    }
258
259    /// Gets the parsed value of the Content-Length header, if present and
260    /// valid.
261    pub fn content_length(&self) -> Option<u64> {
262        self.headers
263            .get(http::header::CONTENT_LENGTH)?
264            .to_str()
265            .ok()?
266            .parse::<u64>()
267            .ok()
268    }
269
270    #[doc(hidden)]
271    pub fn map<U: std::fmt::Debug, F, E>(self, f: F) -> Result<ResponseValue<U>, E>
272    where
273        F: FnOnce(T) -> U,
274    {
275        let Self {
276            inner,
277            status,
278            headers,
279        } = self;
280
281        Ok(ResponseValue {
282            inner: f(inner),
283            status,
284            headers,
285        })
286    }
287}
288
289impl ResponseValue<ByteStream> {
290    /// Consumes the `ResponseValue`, returning the wrapped [`Stream`].
291    pub fn into_inner_stream(self) -> InnerByteStream {
292        self.into_inner().into_inner()
293    }
294}
295
296impl<T> Deref for ResponseValue<T> {
297    type Target = T;
298
299    fn deref(&self) -> &Self::Target {
300        &self.inner
301    }
302}
303
304impl<T> DerefMut for ResponseValue<T> {
305    fn deref_mut(&mut self) -> &mut Self::Target {
306        &mut self.inner
307    }
308}
309
310impl<T> AsRef<T> for ResponseValue<T> {
311    fn as_ref(&self) -> &T {
312        &self.inner
313    }
314}
315
316impl<T: std::fmt::Debug> std::fmt::Debug for ResponseValue<T> {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        self.inner.fmt(f)
319    }
320}
321
/// Error produced by generated client methods.
///
/// The type parameter may be a struct if there's a single expected error type
/// or an enum if there are multiple valid error types. It can be the unit type
/// if there are no structured returns expected.
pub enum Error<E = ()> {
    /// The request did not conform to API requirements.
    ///
    /// Also produced for invalid header values and for form bodies that fail
    /// to serialize.
    InvalidRequest(String),

    /// A failure reported by `reqwest` or a `reqwest_middleware` layer while
    /// sending the request or receiving the response (for example, a
    /// connection problem).
    CommunicationError(reqwest_middleware::Error),

    /// Upgrading the connection failed after the server answered with
    /// `101 Switching Protocols`.
    InvalidUpgrade(reqwest_middleware::Error),

    /// A documented, expected error response.
    ErrorResponse(ResponseValue<E>),

    /// Encountered an error reading the body for an expected response.
    ResponseBodyError(reqwest_middleware::Error),

    /// A response with an expected status code whose body failed to
    /// deserialize; holds the raw body bytes and the JSON error.
    InvalidResponsePayload(Bytes, serde_json::Error),

    /// A response not listed in the API description. This may represent a
    /// success or failure response; check `status().is_success()`.
    UnexpectedResponse(reqwest::Response),

    /// A custom error from a consumer-defined hook.
    Custom(String),
}
353
354impl<E> Error<E> {
355    /// Returns the status code, if the error was generated from a response.
356    pub fn status(&self) -> Option<http::StatusCode> {
357        match self {
358            Error::InvalidRequest(_) => None,
359            Error::Custom(_) => None,
360            Error::CommunicationError(e) => e.status(),
361            Error::ErrorResponse(rv) => Some(rv.status()),
362            Error::InvalidUpgrade(e) => e.status(),
363            Error::ResponseBodyError(e) => e.status(),
364            Error::InvalidResponsePayload(_, _) => None,
365            Error::UnexpectedResponse(r) => Some(r.status()),
366        }
367    }
368
369    /// Converts this error into one without a typed body.
370    ///
371    /// This is useful for unified error handling with APIs that distinguish
372    /// various error response bodies.
373    pub fn into_untyped(self) -> Error {
374        match self {
375            Error::InvalidRequest(s) => Error::InvalidRequest(s),
376            Error::Custom(s) => Error::Custom(s),
377            Error::CommunicationError(e) => Error::CommunicationError(e),
378            Error::ErrorResponse(ResponseValue {
379                inner: _,
380                status,
381                headers,
382            }) => Error::ErrorResponse(ResponseValue {
383                inner: (),
384                status,
385                headers,
386            }),
387            Error::InvalidUpgrade(e) => Error::InvalidUpgrade(e),
388            Error::ResponseBodyError(e) => Error::ResponseBodyError(e),
389            Error::InvalidResponsePayload(b, e) => Error::InvalidResponsePayload(b, e),
390            Error::UnexpectedResponse(r) => Error::UnexpectedResponse(r),
391        }
392    }
393}
394
395impl<E> From<std::convert::Infallible> for Error<E> {
396    fn from(x: std::convert::Infallible) -> Self {
397        match x {}
398    }
399}
400
401impl<E> From<reqwest_middleware::Error> for Error<E> {
402    fn from(e: reqwest_middleware::Error) -> Self {
403        Self::CommunicationError(e)
404    }
405}
406
407impl<E> From<reqwest::Error> for Error<E> {
408    fn from(e: reqwest::Error) -> Self {
409        Self::CommunicationError(e.into())
410    }
411}
412
413impl<E> From<http::header::InvalidHeaderValue> for Error<E> {
414    fn from(e: http::header::InvalidHeaderValue) -> Self {
415        Self::InvalidRequest(e.to_string())
416    }
417}
418
419impl<E> std::fmt::Display for Error<E>
420where
421    ResponseValue<E>: ErrorFormat,
422{
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        match self {
425            Error::InvalidRequest(s) => {
426                write!(f, "Invalid Request: {}", s)?;
427            }
428            Error::CommunicationError(e) => {
429                write!(f, "Communication Error: {}", e)?;
430            }
431            Error::ErrorResponse(rve) => {
432                write!(f, "Error Response: ")?;
433                rve.fmt_info(f)?;
434            }
435            Error::InvalidUpgrade(e) => {
436                write!(f, "Invalid Response Upgrade: {}", e)?;
437            }
438            Error::ResponseBodyError(e) => {
439                write!(f, "Invalid Response Body Bytes: {}", e)?;
440            }
441            Error::InvalidResponsePayload(b, e) => {
442                write!(f, "Invalid Response Payload ({:?}): {}", b, e)?;
443            }
444            Error::UnexpectedResponse(r) => {
445                write!(f, "Unexpected Response: {:?}", r)?;
446            }
447            Error::Custom(s) => {
448                write!(f, "Error: {}", s)?;
449            }
450        }
451
452        if f.alternate() {
453            use std::error::Error as _;
454
455            let mut src = self.source().and_then(|e| e.source());
456            while let Some(s) = src {
457                write!(f, ": {s}")?;
458                src = s.source();
459            }
460        }
461        Ok(())
462    }
463}
464
/// Formatting hook used by `Error`'s `Display` (and therefore `Debug`) impl to
/// render the payload of [`Error::ErrorResponse`]. It is implemented both for
/// `Debug` values and for `ByteStream` bodies.
trait ErrorFormat {
    /// Writes the status, headers, and a rendering of the value to `f`.
    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
}
468
469impl<E> ErrorFormat for ResponseValue<E>
470where
471    E: std::fmt::Debug,
472{
473    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
474        write!(
475            f,
476            "status: {}; headers: {:?}; value: {:?}",
477            self.status, self.headers, self.inner,
478        )
479    }
480}
481
482impl ErrorFormat for ResponseValue<ByteStream> {
483    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484        write!(
485            f,
486            "status: {}; headers: {:?}; value: <stream>",
487            self.status, self.headers,
488        )
489    }
490}
491
492impl<E> std::fmt::Debug for Error<E>
493where
494    ResponseValue<E>: ErrorFormat,
495{
496    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497        std::fmt::Display::fmt(self, f)
498    }
499}
500impl<E> std::error::Error for Error<E>
501where
502    ResponseValue<E>: ErrorFormat,
503{
504    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
505        match self {
506            Error::CommunicationError(e) => Some(e),
507            Error::InvalidUpgrade(e) => Some(e),
508            Error::ResponseBodyError(e) => Some(e),
509            Error::InvalidResponsePayload(_b, e) => Some(e),
510            _ => None,
511        }
512    }
513}
514
515// See https://url.spec.whatwg.org/#url-path-segment-string
516const PATH_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
517    .add(b' ')
518    .add(b'"')
519    .add(b'#')
520    .add(b'<')
521    .add(b'>')
522    .add(b'?')
523    .add(b'`')
524    .add(b'{')
525    .add(b'}')
526    .add(b'/')
527    .add(b'%');
528
529#[doc(hidden)]
530/// Percent encode input string.
531pub fn encode_path(pc: &str) -> String {
532    percent_encoding::utf8_percent_encode(pc, PATH_SET).to_string()
533}
534
#[doc(hidden)]
/// Extension methods on [`reqwest::RequestBuilder`] used by generated clients.
pub trait RequestBuilderExt<E> {
    /// Attaches `body`, encoded as `application/x-www-form-urlencoded`, as the
    /// request body. Returns [`Error::InvalidRequest`] if encoding fails.
    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<RequestBuilder, Error<E>>;
}
539
540impl<E> RequestBuilderExt<E> for RequestBuilder {
541    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<Self, Error<E>> {
542        Ok(self
543            .header(
544                http::header::CONTENT_TYPE,
545                http::header::HeaderValue::from_static("application/x-www-form-urlencoded"),
546            )
547            .body(
548                serde_urlencoded::to_string(body)
549                    .map_err(|_| Error::InvalidRequest("failed to serialize body".to_string()))?,
550            ))
551    }
552}
553
/// A single named query parameter, borrowed from the caller, that serializes
/// its value under `name` (see the `Serialize` impl).
#[doc(hidden)]
pub struct QueryParam<'a, T> {
    name: &'a str,
    value: &'a T,
}

impl<'a, T> QueryParam<'a, T> {
    /// Pairs a parameter name with a borrowed value.
    #[doc(hidden)]
    pub fn new(name: &'a str, value: &'a T) -> Self {
        QueryParam { value, name }
    }
}
566impl<T> Serialize for QueryParam<'_, T>
567where
568    T: Serialize,
569{
570    fn serialize<S>(&self, inner: S) -> Result<S::Ok, S::Error>
571    where
572        S: serde::Serializer,
573    {
574        let serializer = QuerySerializer {
575            inner,
576            name: self.name,
577        };
578        self.value.serialize(serializer)
579    }
580}
581
/// `serde::Serializer` adapter used by [`QueryParam`]. It holds the wrapped
/// serializer and the parameter name that scalar values are paired with.
pub(crate) struct QuerySerializer<'a, S> {
    // The serializer that ultimately receives the output.
    inner: S,
    // Query parameter name used as the key for scalar values.
    name: &'a str,
}
586
// Generates one `serde::Serializer` method (`$f`) for scalar type `$t`. The
// value is paired with the parameter name, wrapped in a one-element array, and
// serialized into the inner serializer. It is expanded inside the
// `impl serde::Serializer for QuerySerializer` block, which supplies
// `self.name`, `self.inner`, `Self::Ok`, and `Self::Error`.
macro_rules! serialize_scalar {
    ($f:ident, $t:ty) => {
        fn $f(self, v: $t) -> Result<Self::Ok, Self::Error> {
            [(self.name, v)].serialize(self.inner)
        }
    };
}
594
595impl<'a, S> serde::Serializer for QuerySerializer<'a, S>
596where
597    S: serde::Serializer,
598{
599    type Ok = S::Ok;
600    type Error = S::Error;
601    type SerializeSeq = QuerySeq<'a, S::SerializeSeq>;
602    type SerializeTuple = S::SerializeTuple;
603    type SerializeTupleStruct = S::SerializeTupleStruct;
604    type SerializeTupleVariant = S::SerializeTupleVariant;
605    type SerializeMap = S::SerializeMap;
606    type SerializeStruct = S::SerializeStruct;
607    type SerializeStructVariant = S::SerializeStructVariant;
608
609    serialize_scalar!(serialize_bool, bool);
610    serialize_scalar!(serialize_i8, i8);
611    serialize_scalar!(serialize_i16, i16);
612    serialize_scalar!(serialize_i32, i32);
613    serialize_scalar!(serialize_i64, i64);
614    serialize_scalar!(serialize_u8, u8);
615    serialize_scalar!(serialize_u16, u16);
616    serialize_scalar!(serialize_u32, u32);
617    serialize_scalar!(serialize_u64, u64);
618    serialize_scalar!(serialize_f32, f32);
619    serialize_scalar!(serialize_f64, f64);
620    serialize_scalar!(serialize_char, char);
621    serialize_scalar!(serialize_str, &str);
622
623    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
624        self.inner.serialize_bytes(v)
625    }
626
627    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
628        self.inner.serialize_none()
629    }
630
631    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
632    where
633        T: ?Sized + Serialize,
634    {
635        // Serialize the value through self which will proxy into the inner
636        // Serializer as appropriate.
637        value.serialize(self)
638    }
639
640    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
641        self.inner.serialize_unit()
642    }
643
644    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
645        self.inner.serialize_unit_struct(name)
646    }
647
648    fn serialize_unit_variant(
649        self,
650        _name: &'static str,
651        _variant_index: u32,
652        variant: &'static str,
653    ) -> Result<Self::Ok, Self::Error> {
654        // A query parameter with a list of enumerated values will produce an
655        // enum with unit variants. We treat these as scalar values, ignoring
656        // the unit variant wrapper.
657        variant.serialize(self)
658    }
659
660    fn serialize_newtype_struct<T>(
661        self,
662        name: &'static str,
663        value: &T,
664    ) -> Result<Self::Ok, Self::Error>
665    where
666        T: ?Sized + Serialize,
667    {
668        self.inner.serialize_newtype_struct(name, value)
669    }
670
671    fn serialize_newtype_variant<T>(
672        self,
673        name: &'static str,
674        _variant_index: u32,
675        variant: &'static str,
676        value: &T,
677    ) -> Result<Self::Ok, Self::Error>
678    where
679        T: ?Sized + Serialize,
680    {
681        // As with serde_json, we treat a newtype variant like a struct with a
682        // single field. This may seem a little weird, but if an OpenAPI
683        // document were to specify a query parameter whose schema was a oneOf
684        // whose elements were objects with a single field, the user would end
685        // up with an enum like this as a parameter.
686        let mut map = self.inner.serialize_struct(name, 1)?;
687        map.serialize_field(variant, value)?;
688        map.end()
689    }
690
691    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
692        let Self { inner, name, .. } = self;
693        Ok(QuerySeq {
694            inner: inner.serialize_seq(len)?,
695            name,
696        })
697    }
698
699    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
700        self.inner.serialize_tuple(len)
701    }
702
703    fn serialize_tuple_struct(
704        self,
705        name: &'static str,
706        len: usize,
707    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
708        self.inner.serialize_tuple_struct(name, len)
709    }
710
711    fn serialize_tuple_variant(
712        self,
713        name: &'static str,
714        variant_index: u32,
715        variant: &'static str,
716        len: usize,
717    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
718        self.inner
719            .serialize_tuple_variant(name, variant_index, variant, len)
720    }
721
722    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
723        self.inner.serialize_map(len)
724    }
725
726    fn serialize_struct(
727        self,
728        name: &'static str,
729        len: usize,
730    ) -> Result<Self::SerializeStruct, Self::Error> {
731        self.inner.serialize_struct(name, len)
732    }
733
734    fn serialize_struct_variant(
735        self,
736        name: &'static str,
737        variant_index: u32,
738        variant: &'static str,
739        len: usize,
740    ) -> Result<Self::SerializeStructVariant, Self::Error> {
741        self.inner
742            .serialize_struct_variant(name, variant_index, variant, len)
743    }
744}
745
#[doc(hidden)]
/// `SerializeSeq` adapter returned by `QuerySerializer::serialize_seq`. It
/// pairs every element of the sequence with the parameter name.
pub struct QuerySeq<'a, S> {
    // Sequence serializer obtained from the inner serializer.
    inner: S,
    // Query parameter name attached to each element.
    name: &'a str,
}
751
752impl<S> serde::ser::SerializeSeq for QuerySeq<'_, S>
753where
754    S: serde::ser::SerializeSeq,
755{
756    type Ok = S::Ok;
757
758    type Error = S::Error;
759
760    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
761    where
762        T: ?Sized + Serialize,
763    {
764        let v = (self.name, value);
765        self.inner.serialize_element(&v)
766    }
767
768    fn end(self) -> Result<Self::Ok, Self::Error> {
769        self.inner.end()
770    }
771}