progenitor_client/
progenitor_client.rs

1// Copyright 2025 Oxide Computer Company
2
3#![allow(dead_code)]
4
5//! Support code for generated clients.
6
7use std::ops::{Deref, DerefMut};
8
9use bytes::Bytes;
10use futures_core::Stream;
11use reqwest::RequestBuilder;
12use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize};
13
14#[cfg(not(target_arch = "wasm32"))]
15type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send + Sync>>;
16
17#[cfg(target_arch = "wasm32")]
18type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>;
19
20/// Untyped byte stream used for both success and error responses.
21pub struct ByteStream(InnerByteStream);
22
23impl ByteStream {
24    /// Creates a new ByteStream
25    ///
26    /// Useful for generating test fixtures.
27    pub fn new(inner: InnerByteStream) -> Self {
28        Self(inner)
29    }
30
31    /// Consumes the [`ByteStream`] and return its inner [`Stream`].
32    pub fn into_inner(self) -> InnerByteStream {
33        self.0
34    }
35}
36
37impl Deref for ByteStream {
38    type Target = InnerByteStream;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl DerefMut for ByteStream {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
51/// Interface for which an implementation is generated for all clients.
52pub trait ClientInfo<Inner> {
53    /// Get the version of this API.
54    ///
55    /// This string is pulled directly from the source OpenAPI document and may
56    /// be in any format the API selects.
57    fn api_version() -> &'static str;
58
59    /// Get the base URL to which requests are made.
60    fn baseurl(&self) -> &str;
61
62    /// Get the internal `reqwest::Client` used to make requests.
63    fn client(&self) -> &reqwest::Client;
64
65    /// Get the inner value of type `T` if one is specified.
66    fn inner(&self) -> &Inner;
67}
68
69impl<T, Inner> ClientInfo<Inner> for &T
70where
71    T: ClientInfo<Inner>,
72{
73    fn api_version() -> &'static str {
74        T::api_version()
75    }
76
77    fn baseurl(&self) -> &str {
78        (*self).baseurl()
79    }
80
81    fn client(&self) -> &reqwest::Client {
82        (*self).client()
83    }
84
85    fn inner(&self) -> &Inner {
86        (*self).inner()
87    }
88}
89
/// Metadata about the operation being performed, handed to hook
/// implementations.
pub struct OperationInfo {
    /// The `operationId` of this operation, as written in the source OpenAPI
    /// document.
    pub operation_id: &'static str,
}
95
96/// Interface for changing the behavior of generated clients. All clients
97/// implement this for `&Client`; to override the default behavior, implement
98/// some or all of the interfaces for the `Client` type (without the
99/// reference). This mechanism relies on so-called "auto-ref specialization".
100#[allow(async_fn_in_trait, unused)]
101pub trait ClientHooks<Inner = ()>
102where
103    Self: ClientInfo<Inner>,
104{
105    /// Runs prior to the execution of the request. This may be used to modify
106    /// the request before it is transmitted.
107    async fn pre<E>(
108        &self,
109        request: &mut reqwest::Request,
110        info: &OperationInfo,
111    ) -> std::result::Result<(), Error<E>> {
112        Ok(())
113    }
114
115    /// Runs after completion of the request.
116    async fn post<E>(
117        &self,
118        result: &reqwest::Result<reqwest::Response>,
119        info: &OperationInfo,
120    ) -> std::result::Result<(), Error<E>> {
121        Ok(())
122    }
123
124    /// Execute the request. Note that for almost any reasonable implementation
125    /// this will include code equivalent to this:
126    /// ```
127    /// # use progenitor_client::{ClientHooks, ClientInfo, OperationInfo};
128    /// # struct X;
129    /// # impl ClientInfo<()> for X {
130    /// #   fn api_version() -> &'static str { panic!() }
131    /// #   fn baseurl(&self) -> &str { panic!() }
132    /// #   fn client(&self) -> &reqwest::Client { panic!() }
133    /// #   fn inner(&self) -> &() { panic!() }
134    /// # }
135    /// # impl ClientHooks for X {
136    /// #   async fn exec(
137    /// #       &self,
138    /// #       request: reqwest::Request,
139    /// #       info: &OperationInfo,
140    /// #   ) -> reqwest::Result<reqwest::Response> {
141    ///         self.client().execute(request).await
142    /// #   }
143    /// # }
144    /// ```
145    async fn exec(
146        &self,
147        request: reqwest::Request,
148        info: &OperationInfo,
149    ) -> reqwest::Result<reqwest::Response> {
150        self.client().execute(request).await
151    }
152}
153
154/// Typed value returned by generated client methods.
155///
156/// This is used for successful responses and may appear in error responses
157/// generated from the server (see [`Error::ErrorResponse`])
158pub struct ResponseValue<T> {
159    inner: T,
160    status: reqwest::StatusCode,
161    headers: reqwest::header::HeaderMap,
162    // TODO cookies?
163}
164
165impl<T: DeserializeOwned> ResponseValue<T> {
166    #[doc(hidden)]
167    pub async fn from_response<E>(response: reqwest::Response) -> Result<Self, Error<E>> {
168        let status = response.status();
169        let headers = response.headers().clone();
170        let full = response.bytes().await.map_err(Error::ResponseBodyError)?;
171        let inner =
172            serde_json::from_slice(&full).map_err(|e| Error::InvalidResponsePayload(full, e))?;
173
174        Ok(Self {
175            inner,
176            status,
177            headers,
178        })
179    }
180}
181
182#[cfg(not(target_arch = "wasm32"))]
183impl ResponseValue<reqwest::Upgraded> {
184    #[doc(hidden)]
185    pub async fn upgrade<E: std::fmt::Debug>(
186        response: reqwest::Response,
187    ) -> Result<Self, Error<E>> {
188        let status = response.status();
189        let headers = response.headers().clone();
190        if status == reqwest::StatusCode::SWITCHING_PROTOCOLS {
191            let inner = response.upgrade().await.map_err(Error::InvalidUpgrade)?;
192
193            Ok(Self {
194                inner,
195                status,
196                headers,
197            })
198        } else {
199            Err(Error::UnexpectedResponse(response))
200        }
201    }
202}
203
204impl ResponseValue<ByteStream> {
205    #[doc(hidden)]
206    pub fn stream(response: reqwest::Response) -> Self {
207        let status = response.status();
208        let headers = response.headers().clone();
209        Self {
210            inner: ByteStream(Box::pin(response.bytes_stream())),
211            status,
212            headers,
213        }
214    }
215}
216
217impl ResponseValue<()> {
218    #[doc(hidden)]
219    pub fn empty(response: reqwest::Response) -> Self {
220        let status = response.status();
221        let headers = response.headers().clone();
222        // TODO is there anything we want to do to confirm that there is no
223        // content?
224        Self {
225            inner: (),
226            status,
227            headers,
228        }
229    }
230}
231
232impl<T> ResponseValue<T> {
233    /// Creates a [`ResponseValue`] from the inner type, status, and headers.
234    ///
235    /// Useful for generating test fixtures.
236    pub fn new(inner: T, status: reqwest::StatusCode, headers: reqwest::header::HeaderMap) -> Self {
237        Self {
238            inner,
239            status,
240            headers,
241        }
242    }
243
244    /// Consumes the ResponseValue, returning the wrapped value.
245    pub fn into_inner(self) -> T {
246        self.inner
247    }
248
249    /// Gets the status from this response.
250    pub fn status(&self) -> reqwest::StatusCode {
251        self.status
252    }
253
254    /// Gets the headers from this response.
255    pub fn headers(&self) -> &reqwest::header::HeaderMap {
256        &self.headers
257    }
258
259    /// Gets the parsed value of the Content-Length header, if present and
260    /// valid.
261    pub fn content_length(&self) -> Option<u64> {
262        self.headers
263            .get(reqwest::header::CONTENT_LENGTH)?
264            .to_str()
265            .ok()?
266            .parse::<u64>()
267            .ok()
268    }
269
270    #[doc(hidden)]
271    pub fn map<U: std::fmt::Debug, F, E>(self, f: F) -> Result<ResponseValue<U>, E>
272    where
273        F: FnOnce(T) -> U,
274    {
275        let Self {
276            inner,
277            status,
278            headers,
279        } = self;
280
281        Ok(ResponseValue {
282            inner: f(inner),
283            status,
284            headers,
285        })
286    }
287}
288
289impl ResponseValue<ByteStream> {
290    /// Consumes the `ResponseValue`, returning the wrapped [`Stream`].
291    pub fn into_inner_stream(self) -> InnerByteStream {
292        self.into_inner().into_inner()
293    }
294}
295
296impl<T> Deref for ResponseValue<T> {
297    type Target = T;
298
299    fn deref(&self) -> &Self::Target {
300        &self.inner
301    }
302}
303
304impl<T> DerefMut for ResponseValue<T> {
305    fn deref_mut(&mut self) -> &mut Self::Target {
306        &mut self.inner
307    }
308}
309
310impl<T> AsRef<T> for ResponseValue<T> {
311    fn as_ref(&self) -> &T {
312        &self.inner
313    }
314}
315
316impl<T: std::fmt::Debug> std::fmt::Debug for ResponseValue<T> {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        self.inner.fmt(f)
319    }
320}
321
322/// Error produced by generated client methods.
323///
324/// The type parameter may be a struct if there's a single expected error type
325/// or an enum if there are multiple valid error types. It can be the unit type
326/// if there are no structured returns expected.
327pub enum Error<E = ()> {
328    /// The request did not conform to API requirements.
329    InvalidRequest(String),
330
331    /// A server error either due to the data, or with the connection.
332    CommunicationError(reqwest::Error),
333
334    /// An expected response when upgrading connection.
335    InvalidUpgrade(reqwest::Error),
336
337    /// A documented, expected error response.
338    ErrorResponse(ResponseValue<E>),
339
340    /// Encountered an error reading the body for an expected response.
341    ResponseBodyError(reqwest::Error),
342
343    /// An expected response code whose deserialization failed.
344    InvalidResponsePayload(Bytes, serde_json::Error),
345
346    /// A response not listed in the API description. This may represent a
347    /// success or failure response; check `status().is_success()`.
348    UnexpectedResponse(reqwest::Response),
349
350    /// A custom error from a consumer-defined hook.
351    Custom(String),
352}
353
354impl<E> Error<E> {
355    /// Returns the status code, if the error was generated from a response.
356    pub fn status(&self) -> Option<reqwest::StatusCode> {
357        match self {
358            Error::InvalidRequest(_) => None,
359            Error::Custom(_) => None,
360            Error::CommunicationError(e) => e.status(),
361            Error::ErrorResponse(rv) => Some(rv.status()),
362            Error::InvalidUpgrade(e) => e.status(),
363            Error::ResponseBodyError(e) => e.status(),
364            Error::InvalidResponsePayload(_, _) => None,
365            Error::UnexpectedResponse(r) => Some(r.status()),
366        }
367    }
368
369    /// Converts this error into one without a typed body.
370    ///
371    /// This is useful for unified error handling with APIs that distinguish
372    /// various error response bodies.
373    pub fn into_untyped(self) -> Error {
374        match self {
375            Error::InvalidRequest(s) => Error::InvalidRequest(s),
376            Error::Custom(s) => Error::Custom(s),
377            Error::CommunicationError(e) => Error::CommunicationError(e),
378            Error::ErrorResponse(ResponseValue {
379                inner: _,
380                status,
381                headers,
382            }) => Error::ErrorResponse(ResponseValue {
383                inner: (),
384                status,
385                headers,
386            }),
387            Error::InvalidUpgrade(e) => Error::InvalidUpgrade(e),
388            Error::ResponseBodyError(e) => Error::ResponseBodyError(e),
389            Error::InvalidResponsePayload(b, e) => Error::InvalidResponsePayload(b, e),
390            Error::UnexpectedResponse(r) => Error::UnexpectedResponse(r),
391        }
392    }
393}
394
395impl<E> From<std::convert::Infallible> for Error<E> {
396    fn from(x: std::convert::Infallible) -> Self {
397        match x {}
398    }
399}
400
401impl<E> From<reqwest::Error> for Error<E> {
402    fn from(e: reqwest::Error) -> Self {
403        Self::CommunicationError(e)
404    }
405}
406
407impl<E> From<reqwest::header::InvalidHeaderValue> for Error<E> {
408    fn from(e: reqwest::header::InvalidHeaderValue) -> Self {
409        Self::InvalidRequest(e.to_string())
410    }
411}
412
413impl<E> std::fmt::Display for Error<E>
414where
415    ResponseValue<E>: ErrorFormat,
416{
417    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418        match self {
419            Error::InvalidRequest(s) => {
420                write!(f, "Invalid Request: {}", s)?;
421            }
422            Error::CommunicationError(e) => {
423                write!(f, "Communication Error: {}", e)?;
424            }
425            Error::ErrorResponse(rve) => {
426                write!(f, "Error Response: ")?;
427                rve.fmt_info(f)?;
428            }
429            Error::InvalidUpgrade(e) => {
430                write!(f, "Invalid Response Upgrade: {}", e)?;
431            }
432            Error::ResponseBodyError(e) => {
433                write!(f, "Invalid Response Body Bytes: {}", e)?;
434            }
435            Error::InvalidResponsePayload(b, e) => {
436                write!(f, "Invalid Response Payload ({:?}): {}", b, e)?;
437            }
438            Error::UnexpectedResponse(r) => {
439                write!(f, "Unexpected Response: {:?}", r)?;
440            }
441            Error::Custom(s) => {
442                write!(f, "Error: {}", s)?;
443            }
444        }
445
446        if f.alternate() {
447            use std::error::Error as _;
448
449            let mut src = self.source().and_then(|e| e.source());
450            while let Some(s) = src {
451                write!(f, ": {s}")?;
452                src = s.source();
453            }
454        }
455        Ok(())
456    }
457}
458
459trait ErrorFormat {
460    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
461}
462
463impl<E> ErrorFormat for ResponseValue<E>
464where
465    E: std::fmt::Debug,
466{
467    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468        write!(
469            f,
470            "status: {}; headers: {:?}; value: {:?}",
471            self.status, self.headers, self.inner,
472        )
473    }
474}
475
476impl ErrorFormat for ResponseValue<ByteStream> {
477    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478        write!(
479            f,
480            "status: {}; headers: {:?}; value: <stream>",
481            self.status, self.headers,
482        )
483    }
484}
485
486impl<E> std::fmt::Debug for Error<E>
487where
488    ResponseValue<E>: ErrorFormat,
489{
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        std::fmt::Display::fmt(self, f)
492    }
493}
494impl<E> std::error::Error for Error<E>
495where
496    ResponseValue<E>: ErrorFormat,
497{
498    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
499        match self {
500            Error::CommunicationError(e) => Some(e),
501            Error::InvalidUpgrade(e) => Some(e),
502            Error::ResponseBodyError(e) => Some(e),
503            Error::InvalidResponsePayload(_b, e) => Some(e),
504            _ => None,
505        }
506    }
507}
508
509// See https://url.spec.whatwg.org/#url-path-segment-string
510const PATH_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
511    .add(b' ')
512    .add(b'"')
513    .add(b'#')
514    .add(b'<')
515    .add(b'>')
516    .add(b'?')
517    .add(b'`')
518    .add(b'{')
519    .add(b'}')
520    .add(b'/')
521    .add(b'%');
522
523#[doc(hidden)]
524/// Percent encode input string.
525pub fn encode_path(pc: &str) -> String {
526    percent_encoding::utf8_percent_encode(pc, PATH_SET).to_string()
527}
528
529#[doc(hidden)]
530pub trait RequestBuilderExt<E> {
531    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<RequestBuilder, Error<E>>;
532}
533
534impl<E> RequestBuilderExt<E> for RequestBuilder {
535    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<Self, Error<E>> {
536        Ok(self
537            .header(
538                reqwest::header::CONTENT_TYPE,
539                reqwest::header::HeaderValue::from_static("application/x-www-form-urlencoded"),
540            )
541            .body(
542                serde_urlencoded::to_string(body)
543                    .map_err(|_| Error::InvalidRequest("failed to serialize body".to_string()))?,
544            ))
545    }
546}
547
548#[doc(hidden)]
549pub struct QueryParam<'a, T> {
550    name: &'a str,
551    value: &'a T,
552}
553
554impl<'a, T> QueryParam<'a, T> {
555    #[doc(hidden)]
556    pub fn new(name: &'a str, value: &'a T) -> Self {
557        Self { name, value }
558    }
559}
560impl<T> Serialize for QueryParam<'_, T>
561where
562    T: Serialize,
563{
564    fn serialize<S>(&self, inner: S) -> Result<S::Ok, S::Error>
565    where
566        S: serde::Serializer,
567    {
568        let serializer = QuerySerializer {
569            inner,
570            name: self.name,
571        };
572        self.value.serialize(serializer)
573    }
574}
575
/// Serializer adaptor for one query parameter: it pairs the value being
/// serialized with `name` before handing it to the wrapped serializer `inner`.
pub(crate) struct QuerySerializer<'a, S> {
    inner: S,
    name: &'a str,
}

// Generates a `Serializer` method for a scalar type that emits the value as a
// single `(name, value)` pair.
macro_rules! serialize_scalar {
    ($method:ident, $ty:ty) => {
        fn $method(self, v: $ty) -> Result<Self::Ok, Self::Error> {
            let pairs = [(self.name, v)];
            pairs.serialize(self.inner)
        }
    };
}
588
589impl<'a, S> serde::Serializer for QuerySerializer<'a, S>
590where
591    S: serde::Serializer,
592{
593    type Ok = S::Ok;
594    type Error = S::Error;
595    type SerializeSeq = QuerySeq<'a, S::SerializeSeq>;
596    type SerializeTuple = S::SerializeTuple;
597    type SerializeTupleStruct = S::SerializeTupleStruct;
598    type SerializeTupleVariant = S::SerializeTupleVariant;
599    type SerializeMap = S::SerializeMap;
600    type SerializeStruct = S::SerializeStruct;
601    type SerializeStructVariant = S::SerializeStructVariant;
602
603    serialize_scalar!(serialize_bool, bool);
604    serialize_scalar!(serialize_i8, i8);
605    serialize_scalar!(serialize_i16, i16);
606    serialize_scalar!(serialize_i32, i32);
607    serialize_scalar!(serialize_i64, i64);
608    serialize_scalar!(serialize_u8, u8);
609    serialize_scalar!(serialize_u16, u16);
610    serialize_scalar!(serialize_u32, u32);
611    serialize_scalar!(serialize_u64, u64);
612    serialize_scalar!(serialize_f32, f32);
613    serialize_scalar!(serialize_f64, f64);
614    serialize_scalar!(serialize_char, char);
615    serialize_scalar!(serialize_str, &str);
616
617    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
618        self.inner.serialize_bytes(v)
619    }
620
621    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
622        self.inner.serialize_none()
623    }
624
625    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
626    where
627        T: ?Sized + Serialize,
628    {
629        // Serialize the value through self which will proxy into the inner
630        // Serializer as appropriate.
631        value.serialize(self)
632    }
633
634    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
635        self.inner.serialize_unit()
636    }
637
638    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
639        self.inner.serialize_unit_struct(name)
640    }
641
642    fn serialize_unit_variant(
643        self,
644        _name: &'static str,
645        _variant_index: u32,
646        variant: &'static str,
647    ) -> Result<Self::Ok, Self::Error> {
648        // A query parameter with a list of enumerated values will produce an
649        // enum with unit variants. We treat these as scalar values, ignoring
650        // the unit variant wrapper.
651        variant.serialize(self)
652    }
653
654    fn serialize_newtype_struct<T>(
655        self,
656        name: &'static str,
657        value: &T,
658    ) -> Result<Self::Ok, Self::Error>
659    where
660        T: ?Sized + Serialize,
661    {
662        self.inner.serialize_newtype_struct(name, value)
663    }
664
665    fn serialize_newtype_variant<T>(
666        self,
667        name: &'static str,
668        _variant_index: u32,
669        variant: &'static str,
670        value: &T,
671    ) -> Result<Self::Ok, Self::Error>
672    where
673        T: ?Sized + Serialize,
674    {
675        // As with serde_json, we treat a newtype variant like a struct with a
676        // single field. This may seem a little weird, but if an OpenAPI
677        // document were to specify a query parameter whose schema was a oneOf
678        // whose elements were objects with a single field, the user would end
679        // up with an enum like this as a parameter.
680        let mut map = self.inner.serialize_struct(name, 1)?;
681        map.serialize_field(variant, value)?;
682        map.end()
683    }
684
685    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
686        let Self { inner, name, .. } = self;
687        Ok(QuerySeq {
688            inner: inner.serialize_seq(len)?,
689            name,
690        })
691    }
692
693    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
694        self.inner.serialize_tuple(len)
695    }
696
697    fn serialize_tuple_struct(
698        self,
699        name: &'static str,
700        len: usize,
701    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
702        self.inner.serialize_tuple_struct(name, len)
703    }
704
705    fn serialize_tuple_variant(
706        self,
707        name: &'static str,
708        variant_index: u32,
709        variant: &'static str,
710        len: usize,
711    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
712        self.inner
713            .serialize_tuple_variant(name, variant_index, variant, len)
714    }
715
716    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
717        self.inner.serialize_map(len)
718    }
719
720    fn serialize_struct(
721        self,
722        name: &'static str,
723        len: usize,
724    ) -> Result<Self::SerializeStruct, Self::Error> {
725        self.inner.serialize_struct(name, len)
726    }
727
728    fn serialize_struct_variant(
729        self,
730        name: &'static str,
731        variant_index: u32,
732        variant: &'static str,
733        len: usize,
734    ) -> Result<Self::SerializeStructVariant, Self::Error> {
735        self.inner
736            .serialize_struct_variant(name, variant_index, variant, len)
737    }
738}
739
740#[doc(hidden)]
741pub struct QuerySeq<'a, S> {
742    inner: S,
743    name: &'a str,
744}
745
746impl<S> serde::ser::SerializeSeq for QuerySeq<'_, S>
747where
748    S: serde::ser::SerializeSeq,
749{
750    type Ok = S::Ok;
751
752    type Error = S::Error;
753
754    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
755    where
756        T: ?Sized + Serialize,
757    {
758        let v = (self.name, value);
759        self.inner.serialize_element(&v)
760    }
761
762    fn end(self) -> Result<Self::Ok, Self::Error> {
763        self.inner.end()
764    }
765}