progenitor_client/
progenitor_client.rs

1// Copyright 2025 Oxide Computer Company
2
3#![allow(dead_code)]
4
5//! Support code for generated clients.
6
7use std::ops::{Deref, DerefMut};
8
9use bytes::Bytes;
10use futures_core::Stream;
11use reqwest::RequestBuilder;
12use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize};
13
/// Boxed, pinned stream of response-body chunks wrapped by [`ByteStream`].
///
/// On non-`wasm32` targets the stream is additionally required to be
/// `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send + Sync>>;

/// Boxed, pinned stream of response-body chunks wrapped by [`ByteStream`].
///
/// On `wasm32` the `Send + Sync` bounds are omitted.
#[cfg(target_arch = "wasm32")]
type InnerByteStream = std::pin::Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>;
19
/// Untyped byte stream used for both success and error responses.
///
/// Wraps a boxed, pinned stream of [`Bytes`] chunks. [`Deref`] and
/// [`DerefMut`] expose the inner pinned stream directly.
pub struct ByteStream(InnerByteStream);
22
23impl ByteStream {
24    /// Creates a new ByteStream
25    ///
26    /// Useful for generating test fixtures.
27    pub fn new(inner: InnerByteStream) -> Self {
28        Self(inner)
29    }
30
31    /// Consumes the [`ByteStream`] and return its inner [`Stream`].
32    pub fn into_inner(self) -> InnerByteStream {
33        self.0
34    }
35}
36
37impl Deref for ByteStream {
38    type Target = InnerByteStream;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl DerefMut for ByteStream {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
/// Typed value returned by generated client methods.
///
/// This is used for successful responses and may appear in error responses
/// generated from the server (see [`Error::ErrorResponse`]).
///
/// Besides the typed body, it carries the HTTP status code and headers.
pub struct ResponseValue<T> {
    /// The typed body (for example a deserialized value, a [`ByteStream`],
    /// or `()` for an empty response).
    inner: T,
    /// HTTP status code of the response.
    status: reqwest::StatusCode,
    /// HTTP headers of the response.
    headers: reqwest::header::HeaderMap,
    // TODO cookies?
}
61
62impl<T: DeserializeOwned> ResponseValue<T> {
63    #[doc(hidden)]
64    pub async fn from_response<E>(response: reqwest::Response) -> Result<Self, Error<E>> {
65        let status = response.status();
66        let headers = response.headers().clone();
67        let full = response.bytes().await.map_err(Error::ResponseBodyError)?;
68        let inner =
69            serde_json::from_slice(&full).map_err(|e| Error::InvalidResponsePayload(full, e))?;
70
71        Ok(Self {
72            inner,
73            status,
74            headers,
75        })
76    }
77}
78
79#[cfg(not(target_arch = "wasm32"))]
80impl ResponseValue<reqwest::Upgraded> {
81    #[doc(hidden)]
82    pub async fn upgrade<E: std::fmt::Debug>(
83        response: reqwest::Response,
84    ) -> Result<Self, Error<E>> {
85        let status = response.status();
86        let headers = response.headers().clone();
87        if status == reqwest::StatusCode::SWITCHING_PROTOCOLS {
88            let inner = response.upgrade().await.map_err(Error::InvalidUpgrade)?;
89
90            Ok(Self {
91                inner,
92                status,
93                headers,
94            })
95        } else {
96            Err(Error::UnexpectedResponse(response))
97        }
98    }
99}
100
101impl ResponseValue<ByteStream> {
102    #[doc(hidden)]
103    pub fn stream(response: reqwest::Response) -> Self {
104        let status = response.status();
105        let headers = response.headers().clone();
106        Self {
107            inner: ByteStream(Box::pin(response.bytes_stream())),
108            status,
109            headers,
110        }
111    }
112}
113
114impl ResponseValue<()> {
115    #[doc(hidden)]
116    pub fn empty(response: reqwest::Response) -> Self {
117        let status = response.status();
118        let headers = response.headers().clone();
119        // TODO is there anything we want to do to confirm that there is no
120        // content?
121        Self {
122            inner: (),
123            status,
124            headers,
125        }
126    }
127}
128
129impl<T> ResponseValue<T> {
130    /// Creates a [`ResponseValue`] from the inner type, status, and headers.
131    ///
132    /// Useful for generating test fixtures.
133    pub fn new(inner: T, status: reqwest::StatusCode, headers: reqwest::header::HeaderMap) -> Self {
134        Self {
135            inner,
136            status,
137            headers,
138        }
139    }
140
141    /// Consumes the ResponseValue, returning the wrapped value.
142    pub fn into_inner(self) -> T {
143        self.inner
144    }
145
146    /// Gets the status from this response.
147    pub fn status(&self) -> reqwest::StatusCode {
148        self.status
149    }
150
151    /// Gets the headers from this response.
152    pub fn headers(&self) -> &reqwest::header::HeaderMap {
153        &self.headers
154    }
155
156    /// Gets the parsed value of the Content-Length header, if present and
157    /// valid.
158    pub fn content_length(&self) -> Option<u64> {
159        self.headers
160            .get(reqwest::header::CONTENT_LENGTH)?
161            .to_str()
162            .ok()?
163            .parse::<u64>()
164            .ok()
165    }
166
167    #[doc(hidden)]
168    pub fn map<U: std::fmt::Debug, F, E>(self, f: F) -> Result<ResponseValue<U>, E>
169    where
170        F: FnOnce(T) -> U,
171    {
172        let Self {
173            inner,
174            status,
175            headers,
176        } = self;
177
178        Ok(ResponseValue {
179            inner: f(inner),
180            status,
181            headers,
182        })
183    }
184}
185
186impl ResponseValue<ByteStream> {
187    /// Consumes the `ResponseValue`, returning the wrapped [`Stream`].
188    pub fn into_inner_stream(self) -> InnerByteStream {
189        self.into_inner().into_inner()
190    }
191}
192
193impl<T> Deref for ResponseValue<T> {
194    type Target = T;
195
196    fn deref(&self) -> &Self::Target {
197        &self.inner
198    }
199}
200
201impl<T> DerefMut for ResponseValue<T> {
202    fn deref_mut(&mut self) -> &mut Self::Target {
203        &mut self.inner
204    }
205}
206
207impl<T> AsRef<T> for ResponseValue<T> {
208    fn as_ref(&self) -> &T {
209        &self.inner
210    }
211}
212
213impl<T: std::fmt::Debug> std::fmt::Debug for ResponseValue<T> {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        self.inner.fmt(f)
216    }
217}
218
219/// Error produced by generated client methods.
220///
221/// The type parameter may be a struct if there's a single expected error type
222/// or an enum if there are multiple valid error types. It can be the unit type
223/// if there are no structured returns expected.
224pub enum Error<E = ()> {
225    /// The request did not conform to API requirements.
226    InvalidRequest(String),
227
228    /// A server error either due to the data, or with the connection.
229    CommunicationError(reqwest::Error),
230
231    /// An expected response when upgrading connection.
232    InvalidUpgrade(reqwest::Error),
233
234    /// A documented, expected error response.
235    ErrorResponse(ResponseValue<E>),
236
237    /// Encountered an error reading the body for an expected response.
238    ResponseBodyError(reqwest::Error),
239
240    /// An expected response code whose deserialization failed.
241    InvalidResponsePayload(Bytes, serde_json::Error),
242
243    /// A response not listed in the API description. This may represent a
244    /// success or failure response; check `status().is_success()`.
245    UnexpectedResponse(reqwest::Response),
246
247    /// An error occurred in the processing of a request pre-hook.
248    PreHookError(String),
249
250    /// An error occurred in the processing of a request post-hook.
251    PostHookError(String),
252}
253
254impl<E> Error<E> {
255    /// Returns the status code, if the error was generated from a response.
256    pub fn status(&self) -> Option<reqwest::StatusCode> {
257        match self {
258            Error::InvalidRequest(_) => None,
259            Error::PreHookError(_) => None,
260            Error::PostHookError(_) => None,
261            Error::CommunicationError(e) => e.status(),
262            Error::ErrorResponse(rv) => Some(rv.status()),
263            Error::InvalidUpgrade(e) => e.status(),
264            Error::ResponseBodyError(e) => e.status(),
265            Error::InvalidResponsePayload(_, _) => None,
266            Error::UnexpectedResponse(r) => Some(r.status()),
267        }
268    }
269
270    /// Converts this error into one without a typed body.
271    ///
272    /// This is useful for unified error handling with APIs that distinguish
273    /// various error response bodies.
274    pub fn into_untyped(self) -> Error {
275        match self {
276            Error::InvalidRequest(s) => Error::InvalidRequest(s),
277            Error::PreHookError(s) => Error::PreHookError(s),
278            Error::PostHookError(s) => Error::PostHookError(s),
279            Error::CommunicationError(e) => Error::CommunicationError(e),
280            Error::ErrorResponse(ResponseValue {
281                inner: _,
282                status,
283                headers,
284            }) => Error::ErrorResponse(ResponseValue {
285                inner: (),
286                status,
287                headers,
288            }),
289            Error::InvalidUpgrade(e) => Error::InvalidUpgrade(e),
290            Error::ResponseBodyError(e) => Error::ResponseBodyError(e),
291            Error::InvalidResponsePayload(b, e) => Error::InvalidResponsePayload(b, e),
292            Error::UnexpectedResponse(r) => Error::UnexpectedResponse(r),
293        }
294    }
295}
296
297impl<E> From<reqwest::Error> for Error<E> {
298    fn from(e: reqwest::Error) -> Self {
299        Self::CommunicationError(e)
300    }
301}
302
303impl<E> From<reqwest::header::InvalidHeaderValue> for Error<E> {
304    fn from(e: reqwest::header::InvalidHeaderValue) -> Self {
305        Self::InvalidRequest(e.to_string())
306    }
307}
308
309impl<E> std::fmt::Display for Error<E>
310where
311    ResponseValue<E>: ErrorFormat,
312{
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        match self {
315            Error::InvalidRequest(s) => {
316                write!(f, "Invalid Request: {}", s)?;
317            }
318            Error::CommunicationError(e) => {
319                write!(f, "Communication Error: {}", e)?;
320            }
321            Error::ErrorResponse(rve) => {
322                write!(f, "Error Response: ")?;
323                rve.fmt_info(f)?;
324            }
325            Error::InvalidUpgrade(e) => {
326                write!(f, "Invalid Response Upgrade: {}", e)?;
327            }
328            Error::ResponseBodyError(e) => {
329                write!(f, "Invalid Response Body Bytes: {}", e)?;
330            }
331            Error::InvalidResponsePayload(b, e) => {
332                write!(f, "Invalid Response Payload ({:?}): {}", b, e)?;
333            }
334            Error::UnexpectedResponse(r) => {
335                write!(f, "Unexpected Response: {:?}", r)?;
336            }
337            Error::PreHookError(s) => {
338                write!(f, "Pre-hook Error: {}", s)?;
339            }
340            Error::PostHookError(s) => {
341                write!(f, "Post-hook Error: {}", s)?;
342            }
343        }
344
345        if f.alternate() {
346            use std::error::Error as _;
347
348            let mut src = self.source().and_then(|e| e.source());
349            while let Some(s) = src {
350                write!(f, ": {s}")?;
351                src = s.source();
352            }
353        }
354        Ok(())
355    }
356}
357
/// Formats the details of a [`ResponseValue`] held by [`Error::ErrorResponse`].
///
/// Implemented for every `ResponseValue<E>` with `E: Debug`, and separately
/// for `ResponseValue<ByteStream>`, whose body is rendered as `<stream>`.
trait ErrorFormat {
    /// Writes the status, the headers, and a rendering of the body.
    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
}
361
362impl<E> ErrorFormat for ResponseValue<E>
363where
364    E: std::fmt::Debug,
365{
366    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        write!(
368            f,
369            "status: {}; headers: {:?}; value: {:?}",
370            self.status, self.headers, self.inner,
371        )
372    }
373}
374
375impl ErrorFormat for ResponseValue<ByteStream> {
376    fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        write!(
378            f,
379            "status: {}; headers: {:?}; value: <stream>",
380            self.status, self.headers,
381        )
382    }
383}
384
385impl<E> std::fmt::Debug for Error<E>
386where
387    ResponseValue<E>: ErrorFormat,
388{
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        std::fmt::Display::fmt(self, f)
391    }
392}
393impl<E> std::error::Error for Error<E>
394where
395    ResponseValue<E>: ErrorFormat,
396{
397    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
398        match self {
399            Error::CommunicationError(e) => Some(e),
400            Error::InvalidUpgrade(e) => Some(e),
401            Error::ResponseBodyError(e) => Some(e),
402            Error::InvalidResponsePayload(_b, e) => Some(e),
403            _ => None,
404        }
405    }
406}
407
// Bytes percent-encoded by `encode_path`. This is the control characters from
// `percent_encoding::CONTROLS` plus the WHATWG "path percent-encode set"
// additions (space " # < > ? ` { }), extended with `/` and `%`. Encoding `/`
// and `%` keeps the input confined to a single path segment: a `/` cannot
// start a new segment, and a literal `%` cannot be mistaken for an escape.
// See https://url.spec.whatwg.org/#url-path-segment-string
const PATH_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
    .add(b' ')
    .add(b'"')
    .add(b'#')
    .add(b'<')
    .add(b'>')
    .add(b'?')
    .add(b'`')
    .add(b'{')
    .add(b'}')
    .add(b'/')
    .add(b'%');
421
422#[doc(hidden)]
423/// Percent encode input string.
424pub fn encode_path(pc: &str) -> String {
425    percent_encoding::utf8_percent_encode(pc, PATH_SET).to_string()
426}
427
/// Extension methods on [`RequestBuilder`] used by generated clients.
#[doc(hidden)]
pub trait RequestBuilderExt<E> {
    /// Sets `body`, encoded as `application/x-www-form-urlencoded`, as the
    /// request body.
    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<RequestBuilder, Error<E>>;
}
432
433impl<E> RequestBuilderExt<E> for RequestBuilder {
434    fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<Self, Error<E>> {
435        Ok(self
436            .header(
437                reqwest::header::CONTENT_TYPE,
438                reqwest::header::HeaderValue::from_static("application/x-www-form-urlencoded"),
439            )
440            .body(
441                serde_urlencoded::to_string(body)
442                    .map_err(|_| Error::InvalidRequest("failed to serialize body".to_string()))?,
443            ))
444    }
445}
446
/// A named query parameter: serializes `value` as one or more
/// `(name, value)` pairs.
#[doc(hidden)]
pub struct QueryParam<'a, T> {
    /// Query-string key attached to every serialized value.
    name: &'a str,
    /// The value serialized under `name`.
    value: &'a T,
}
452
453impl<'a, T> QueryParam<'a, T> {
454    #[doc(hidden)]
455    pub fn new(name: &'a str, value: &'a T) -> Self {
456        Self { name, value }
457    }
458}
459impl<T> Serialize for QueryParam<'_, T>
460where
461    T: Serialize,
462{
463    fn serialize<S>(&self, inner: S) -> Result<S::Ok, S::Error>
464    where
465        S: serde::Serializer,
466    {
467        let serializer = QuerySerializer {
468            inner,
469            name: self.name,
470        };
471        self.value.serialize(serializer)
472    }
473}
474
/// Serializer adapter that attaches `name` to the scalar and sequence values
/// it receives before handing them to `inner`.
pub(crate) struct QuerySerializer<'a, S> {
    /// The wrapped serializer that receives the `(name, value)` pairs.
    inner: S,
    /// Query-string key attached to each scalar or sequence element.
    name: &'a str,
}
479
// Generates a `serde::Serializer` method `$f` for the scalar type `$t`. The
// value is emitted on the inner serializer as a single-element array holding
// the pair `(self.name, v)`.
macro_rules! serialize_scalar {
    ($f:ident, $t:ty) => {
        fn $f(self, v: $t) -> Result<Self::Ok, Self::Error> {
            [(self.name, v)].serialize(self.inner)
        }
    };
}
487
488impl<'a, S> serde::Serializer for QuerySerializer<'a, S>
489where
490    S: serde::Serializer,
491{
492    type Ok = S::Ok;
493    type Error = S::Error;
494    type SerializeSeq = QuerySeq<'a, S::SerializeSeq>;
495    type SerializeTuple = S::SerializeTuple;
496    type SerializeTupleStruct = S::SerializeTupleStruct;
497    type SerializeTupleVariant = S::SerializeTupleVariant;
498    type SerializeMap = S::SerializeMap;
499    type SerializeStruct = S::SerializeStruct;
500    type SerializeStructVariant = S::SerializeStructVariant;
501
502    serialize_scalar!(serialize_bool, bool);
503    serialize_scalar!(serialize_i8, i8);
504    serialize_scalar!(serialize_i16, i16);
505    serialize_scalar!(serialize_i32, i32);
506    serialize_scalar!(serialize_i64, i64);
507    serialize_scalar!(serialize_u8, u8);
508    serialize_scalar!(serialize_u16, u16);
509    serialize_scalar!(serialize_u32, u32);
510    serialize_scalar!(serialize_u64, u64);
511    serialize_scalar!(serialize_f32, f32);
512    serialize_scalar!(serialize_f64, f64);
513    serialize_scalar!(serialize_char, char);
514    serialize_scalar!(serialize_str, &str);
515
516    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
517        self.inner.serialize_bytes(v)
518    }
519
520    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
521        self.inner.serialize_none()
522    }
523
524    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
525    where
526        T: ?Sized + Serialize,
527    {
528        // Serialize the value through self which will proxy into the inner
529        // Serializer as appropriate.
530        value.serialize(self)
531    }
532
533    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
534        self.inner.serialize_unit()
535    }
536
537    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
538        self.inner.serialize_unit_struct(name)
539    }
540
541    fn serialize_unit_variant(
542        self,
543        _name: &'static str,
544        _variant_index: u32,
545        variant: &'static str,
546    ) -> Result<Self::Ok, Self::Error> {
547        // A query parameter with a list of enumerated values will produce an
548        // enum with unit variants. We treat these as scalar values, ignoring
549        // the unit variant wrapper.
550        variant.serialize(self)
551    }
552
553    fn serialize_newtype_struct<T>(
554        self,
555        name: &'static str,
556        value: &T,
557    ) -> Result<Self::Ok, Self::Error>
558    where
559        T: ?Sized + Serialize,
560    {
561        self.inner.serialize_newtype_struct(name, value)
562    }
563
564    fn serialize_newtype_variant<T>(
565        self,
566        name: &'static str,
567        _variant_index: u32,
568        variant: &'static str,
569        value: &T,
570    ) -> Result<Self::Ok, Self::Error>
571    where
572        T: ?Sized + Serialize,
573    {
574        // As with serde_json, we treat a newtype variant like a struct with a
575        // single field. This may seem a little weird, but if an OpenAPI
576        // document were to specify a query parameter whose schema was a oneOf
577        // whose elements were objects with a single field, the user would end
578        // up with an enum like this as a parameter.
579        let mut map = self.inner.serialize_struct(name, 1)?;
580        map.serialize_field(variant, value)?;
581        map.end()
582    }
583
584    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
585        let Self { inner, name, .. } = self;
586        Ok(QuerySeq {
587            inner: inner.serialize_seq(len)?,
588            name,
589        })
590    }
591
592    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
593        self.inner.serialize_tuple(len)
594    }
595
596    fn serialize_tuple_struct(
597        self,
598        name: &'static str,
599        len: usize,
600    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
601        self.inner.serialize_tuple_struct(name, len)
602    }
603
604    fn serialize_tuple_variant(
605        self,
606        name: &'static str,
607        variant_index: u32,
608        variant: &'static str,
609        len: usize,
610    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
611        self.inner
612            .serialize_tuple_variant(name, variant_index, variant, len)
613    }
614
615    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
616        self.inner.serialize_map(len)
617    }
618
619    fn serialize_struct(
620        self,
621        name: &'static str,
622        len: usize,
623    ) -> Result<Self::SerializeStruct, Self::Error> {
624        self.inner.serialize_struct(name, len)
625    }
626
627    fn serialize_struct_variant(
628        self,
629        name: &'static str,
630        variant_index: u32,
631        variant: &'static str,
632        len: usize,
633    ) -> Result<Self::SerializeStructVariant, Self::Error> {
634        self.inner
635            .serialize_struct_variant(name, variant_index, variant, len)
636    }
637}
638
/// Sequence serializer returned by `QuerySerializer::serialize_seq`. Each
/// element is emitted on `inner` as a `(name, element)` pair.
#[doc(hidden)]
pub struct QuerySeq<'a, S> {
    /// The inner sequence serializer.
    inner: S,
    /// Query-string key paired with every element.
    name: &'a str,
}
644
645impl<S> serde::ser::SerializeSeq for QuerySeq<'_, S>
646where
647    S: serde::ser::SerializeSeq,
648{
649    type Ok = S::Ok;
650
651    type Error = S::Error;
652
653    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
654    where
655        T: ?Sized + Serialize,
656    {
657        let v = (self.name, value);
658        self.inner.serialize_element(&v)
659    }
660
661    fn end(self) -> Result<Self::Ok, Self::Error> {
662        self.inner.end()
663    }
664}