conjure_macros/
lib.rs

1// Copyright 2022 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Macros exposed by conjure-http.
15//!
16//! Do not consume directly.
17#![warn(missing_docs)]
18
19use proc_macro::TokenStream;
20use syn::{Error, ItemTrait, TraitItem};
21
22mod client;
23mod endpoints;
24mod path;
25
26/// Creates a Conjure client type implementing the annotated trait.
27///
28/// For a trait named `MyService`, the macro will create a type named `MyServiceClient` which
29/// implements the Conjure `Service`/`AsyncService` and `MyService` traits.
30///
31/// The attribute has several parameters:
32///
33/// * `name` - The value of the `service` field in the `Endpoint` extension. Defaults to the trait's
34///   name.
35/// * `version` - The value of the `version` field in the `Endpoint` extension. Defaults to
36///   `Some(env!("CARGO_PKG_VERSION"))`.
37/// * `local` - For async clients, causes the generated struct to use the `LocalAsyncClient` APIs
38///   that don't have a `Send` bound.
39///
40/// # Parameters
41///
/// The trait can optionally be declared generic over the request writer and response body types by
/// using the `#[request_writer]` and `#[response_body]` annotations on the type parameters.
44///
45/// # Endpoints
46///
47/// Each method corresponds to a separate HTTP endpoint, and is expected to take `&self` and return
48/// `Result<T, Error>`. Each must be annotated with `#[endpoint]`, which has several
49/// parameters:
50///
51/// * `method` - The HTTP method (e.g. `GET`). Required.
52/// * `path` - The HTTP path template. Path parameters should be identified by `{name}` and must
53///   make up an entire path component. Required.
54/// * `name` - The value of the `name` field in the `Endpoint` extension. Defaults to the method's
55///   name.
56/// * `accept` - A type implementing `DeserializeResponse` which will be used to create the return
57///   value. Defaults to returning `()`.
58///
59/// Each method argument must have an annotation describing the type of parameter. One of:
60///
61/// * `#[path]` - A path parameter.
62///
63///     Parameters:
64///     * `name` - The name of the path template parameter. Defaults to the argument name.
65///     * `encoder` - A type implementing `EncodeParam` which will be used to encode the value into
66///       a string. Defaults to `DisplayParamEncoder`.
67/// * `#[query]` - A query parameter.
68///
69///     Parameters:
70///     * `name` - The string used as the key in the encoded URI. Required.
71///     * `encoder` - A type implementing `EncodeParam` which will be used to encode the value into
72///       a string. Defaults to `DisplayParamEncoder`.
/// * `#[auth]` - A `BearerToken` used to authenticate the request. A method may have at most one
///   auth parameter.
75///
76///     Parameters:
77///     * `cookie_name` - The name of the cookie used if the token is to be passed via a `Cookie`
78///       header. If unset, it will be passed via an `Authorization` header instead.
79/// * `#[header]` - A header.
80///
81///     Parameters:
82///     * `name` - The header name. Required.
83///     * `encoder` - A type implementing `EncodeHeader` which will be used to encode the value
84///       into a header. Defaults to `DisplayHeaderEncoder`.
/// * `#[body]` - The request body. A method may have at most one body parameter.
///
///     Parameters:
///     * `serializer` - A type implementing `SerializeRequest` which will be used to serialize the
///       value into a body. Defaults to `StdRequestSerializer`.
///
/// # Async
91///
92/// Both blocking and async clients are supported. For technical reasons, async method definitions
93/// will be rewritten by the macro to require the returned future be `Send` unless the `local` flag
94/// is set in the attribute.
95///
96/// # Examples
97///
98/// ```rust,ignore
99/// use conjure_error::Error;
100/// use conjure_http::{conjure_client, endpoint};
101/// use conjure_http::client::{
102///     AsyncClient, AsyncService, Client, ConjureRuntime, StdResponseDeserializer,
103///     DeserializeResponse, DisplaySeqEncoder, RequestBody, SerializeRequest, Service, WriteBody,
104/// };
105/// use conjure_object::BearerToken;
106/// use http::Response;
107/// use http::header::HeaderValue;
108/// use std::io::Write;
109/// use std::sync::Arc;
110///
111/// #[conjure_client]
112/// trait MyService {
113///     #[endpoint(method = GET, path = "/yaks/{yak_id}", accept = StdResponseDeserializer)]
114///     fn get_yak(&self, #[auth] auth: &BearerToken, #[path] yak_id: i32) -> Result<String, Error>;
115///
116///     #[endpoint(method = POST, path = "/yaks")]
117///     fn create_yak(
118///         &self,
119///         #[auth] auth_token: &BearerToken,
120///         #[query(name = "parentName", encoder = DisplaySeqEncoder)] parent_id: Option<&str>,
121///         #[body] yak: &str,
122///     ) -> Result<(), Error>;
123/// }
124///
125/// fn do_work(client: impl Client, runtime: &Arc<ConjureRuntime>, auth: &BearerToken) -> Result<(), Error> {
126///     let client = MyServiceClient::new(client, runtime);
127///     client.create_yak(auth, None, "my cool yak")?;
128///
129///     Ok(())
130/// }
131///
132/// #[conjure_client]
133/// trait MyServiceAsync {
134///     #[endpoint(method = GET, path = "/yaks/{yak_id}", accept = StdResponseDeserializer)]
135///     async fn get_yak(
136///         &self,
137///         #[auth] auth: &BearerToken,
138///         #[path] yak_id: i32,
139///     ) -> Result<String, Error>;
140///
141///     #[endpoint(method = POST, path = "/yaks")]
142///     async fn create_yak(
143///         &self,
144///         #[auth] auth_token: &BearerToken,
145///         #[query(name = "parentName", encoder = DisplaySeqEncoder)] parent_id: Option<&str>,
146///         #[body] yak: &str,
147///     ) -> Result<(), Error>;
148/// }
149///
150/// async fn do_work_async<C>(client: C, runtime: &Arc<ConjureRuntime>, auth: &BearerToken) -> Result<(), Error>
151/// where
152///     C: AsyncClient + Sync + Send,
153///     C::ResponseBody: 'static + Send,
154/// {
155///     let client = MyServiceAsyncClient::new(client, runtime);
156///     client.create_yak(auth, None, "my cool yak").await?;
157///
158///     Ok(())
159/// }
160///
161/// #[conjure_client]
162/// trait MyStreamingService<#[response_body] I, #[request_writer] O>
163/// where
164///     O: Write,
165/// {
166///     #[endpoint(method = POST, path = "/streamData")]
167///     fn upload_stream(
168///         &self,
169///         #[body(serializer = StreamingRequestSerializer)] body: StreamingRequest,
170///     ) -> Result<(), Error>;
171///
172///     #[endpoint(method = GET, path = "/streamData", accept = StreamingResponseDeserializer)]
173///     fn download_stream(&self) -> Result<I, Error>;
174/// }
175///
176/// struct StreamingRequest;
177///
178/// impl<W> WriteBody<W> for StreamingRequest
179/// where
180///     W: Write,
181/// {
182///     fn write_body(&mut self, w: &mut W) -> Result<(), Error> {
183///         // ...
184///         Ok(())
185///     }
186///
187///     fn reset(&mut self) -> bool {
188///         true
189///     }
190/// }
191///
192/// enum StreamingRequestSerializer {}
193///
194/// impl<W> SerializeRequest<'static, StreamingRequest, W> for StreamingRequestSerializer
195/// where
196///     W: Write,
197/// {
198///     fn content_type(_: &ConjureRuntime, _: &StreamingRequest) -> HeaderValue {
199///         HeaderValue::from_static("text/plain")
200///     }
201///
202///     fn serialize(_: &ConjureRuntime, value: StreamingRequest) -> Result<RequestBody<'static, W>, Error> {
203///         Ok(RequestBody::Streaming(Box::new(value)))
204///     }
205/// }
206///
207/// enum StreamingResponseDeserializer {}
208///
209/// impl<R> DeserializeResponse<R, R> for StreamingResponseDeserializer {
210///     fn accept(_: &ConjureRuntime) -> Option<HeaderValue> {
211///         None
212///     }
213///
214///     fn deserialize(_: &ConjureRuntime, response: Response<R>) -> Result<R, Error> {
215///         Ok(response.into_body())
216///     }
217/// }
218/// ```
219#[proc_macro_attribute]
220pub fn conjure_client(attr: TokenStream, item: TokenStream) -> TokenStream {
221    client::generate(attr, item)
222}
223
224/// Creates a Conjure service type wrapping types implementing the annotated trait.
225///
226/// For a trait named `MyService`, the macro will create a type named `MyServiceEndpoints` which
/// implements the Conjure `Service` trait.
228///
/// The attribute has several parameters:
230///
231/// * `name` - The value returned from the `EndpointMetadata::service_name` method. Defaults to the
232///   trait name.
233/// * `use_legacy_error_serialization` - If set, parameters of service errors will be serialized in
234///   old stringified format.
235///
236/// # Parameters
237///
238/// The trait can optionally be declared generic over the request body and response writer types by
239/// using the `#[request_body]` and `#[response_writer]` annotations on the type parameters.
240///
241/// # Endpoints
242///
243/// Each method corresponds to a separate HTTP endpoint, and is expected to take `&self` and return
244/// `Result<T, Error>`. Each must be annotated with `#[endpoint]`, which has several parameters:
245///
246/// * `method` - The HTTP method (e.g. `GET`). Required.
247/// * `path` - The HTTP path template. Path parameters should be identified by `{name}` and must
248///   make up an entire path component. Required.
249/// * `name` - The value returned from the `EndpointMetadata::name` method. Defaults to the method
250///   name.
251/// * `produces` - A type implementing `SerializeResponse` which will be used to convert the value
252///   returned by the method into a response. Defaults to `EmptyResponseSerializer`.
253///
254/// Each method argument must have an annotation describing the type of parameter. One of:
255///
256/// * `#[path]` - A path parameter.
257///
258///     Parameters:
259///     * `name` - The name of the path template parameter. Defaults to the argument name.
260///     * `decoder` - A type implementing `DecodeParam` which will be used to decode the value.
261///       Defaults to `FromStrDecoder`.
262///     * `safe` - If set, the parameter will be added to the `SafeParams` response extension.
263///     * `log_as` - The name of the parameter used in request logging and error reporting. Defaults
264///       to the argument name.
265/// * `#[query]` - A query parameter.
266///
267///     Parameters:
268///     * `name` - The string used as the key in the encoded URI. Required.
269///     * `decoder` - A type implementing `DecodeParam` which will be used to decode the value.
270///       Defaults to `FromStrDecoder`.
271///     * `safe` - If set, the parameter will be added to the `SafeParams` response extension.
272///     * `log_as` - The name of the parameter used in request logging and error reporting. Defaults
273///       to the argument name.
274/// * `#[auth]` - A `BearerToken` used to authenticate the request.
275///
276///     Parameters:
277///     * `cookie_name` - The name of the cookie if the token is to be parsed from a `Cookie`
278///       header. If unset, it will be parsed from an `Authorization` header instead.
279/// * `#[header]` - A header parameter.
280///
281///     Parameters:
282///     * `name` - The header name. Required.
283///     * `decoder` - A type implementing `DecodeHeader` which will be used to decode the value.
284///       Defaults to `FromStrDecoder`.
285///     * `safe` - If set, the parameter will be added to the `SafeParams` response extension.
286///     * `log_as` - The name of the parameter used in request logging and error reporting. Defaults
287///       to the argument name.
288/// * `#[body]` - The request body.
289///
290///     Parameters:
291///     * `deserializer` - A type implementing `DeserializeRequest` which will be used to
292///       deserialize the request body into a value. Defaults to `StdRequestDeserializer`.
293///     * `safe` - If set, the parameter will be added to the `SafeParams` response extension.
294///     * `log_as` - The name of the parameter used in request logging and error reporting. Defaults
295///       to the argument name.
296/// * `#[context]` - A `RequestContext` which provides lower level access to the request.
297///
298/// # Async
299///
300/// Both blocking and async services are supported. For technical reasons, async method definitions
301/// will be rewritten by the macro to require the returned future be `Send`.
302///
303/// # Examples
304///
305/// ```rust,ignore
306/// use conjure_error::Error;
307/// use conjure_http::{conjure_endpoints, endpoint};
308/// use conjure_http::server::{
309///     ConjureRuntime, DeserializeRequest, FromStrOptionDecoder, ResponseBody, SerializeResponse,
310///     StdResponseSerializer, WriteBody,
311/// };
312/// use conjure_object::BearerToken;
313/// use http::Response;
314/// use http::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
315/// use std::io::Write;
316///
317/// #[conjure_endpoints]
318/// trait MyService {
319///     #[endpoint(method = GET, path = "/yaks/{yak_id}", produces = StdResponseSerializer)]
320///     fn get_yak(
321///         &self,
322///         #[auth] auth: BearerToken,
323///         #[path(safe)] yak_id: i32,
324///     ) -> Result<String, Error>;
325///
326///     #[endpoint(method = POST, path = "/yaks")]
327///     fn create_yak(
328///         &self,
329///         #[auth] auth: BearerToken,
330///         #[query(name = "parentName", decoder = FromStrOptionDecoder)] parent_id: Option<String>,
331///         #[body] yak: String,
332///     ) -> Result<(), Error>;
333/// }
334///
335/// #[conjure_endpoints]
336/// trait AsyncMyService {
337///     #[endpoint(method = GET, path = "/yaks/{yak_id}", produces = StdResponseSerializer)]
338///     async fn get_yak(
339///         &self,
340///         #[auth] auth: BearerToken,
341///         #[path(safe)] yak_id: i32,
342///     ) -> Result<String, Error>;
343///
344///     #[endpoint(method = POST, path = "/yaks")]
345///     async fn create_yak(
346///         &self,
347///         #[auth] auth: BearerToken,
348///         #[query(name = "parentName", decoder = FromStrOptionDecoder)] parent_id: Option<String>,
349///         #[body] yak: String,
350///     ) -> Result<(), Error>;
351/// }
352///
353/// #[conjure_endpoints]
354/// trait MyStreamingService<#[request_body] I, #[response_writer] O>
355/// where
356///     O: Write,
357/// {
358///     #[endpoint(method = POST, path = "/streamData")]
359///     fn receive_stream(
360///         &self,
361///         #[body(deserializer = StreamingRequestDeserializer)] body: I,
///     ) -> Result<(), Error>;
363///
364///     #[endpoint(method = GET, path = "/streamData", produces = StreamingResponseSerializer)]
365///     fn stream_response(&self) -> Result<StreamingResponse, Error>;
366/// }
367///
368/// struct StreamingRequestDeserializer;
369///
370/// impl<I> DeserializeRequest<I, I> for StreamingRequestDeserializer {
371///     fn deserialize(
372///         _runtime: &ConjureRuntime,
373///         _headers: &HeaderMap,
374///         body: I,
375///     ) -> Result<I, Error> {
376///         Ok(body)
377///     }
378/// }
379///
380/// struct StreamingResponse;
381///
382/// impl<O> WriteBody<O> for StreamingResponse
383/// where
384///     O: Write,
385/// {
386///     fn write_body(self: Box<Self>, w: &mut O) -> Result<(), Error> {
387///         // ...
388///         Ok(())
389///     }
390/// }
391///
392/// struct StreamingResponseSerializer;
393///
394/// impl<O> SerializeResponse<StreamingResponse, O> for StreamingResponseSerializer
395/// where
396///     O: Write,
397/// {
398///     fn serialize(
399///         _runtime: &ConjureRuntime,
400///         _request_headers: &HeaderMap,
401///         body: StreamingResponse,
402///     ) -> Result<Response<ResponseBody<O>>, Error> {
403///         let mut response = Response::new(ResponseBody::Streaming(Box::new(body)));
404///         response.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
405///         Ok(response)
406///     }
407/// }
408/// ```
409#[proc_macro_attribute]
410pub fn conjure_endpoints(attr: TokenStream, item: TokenStream) -> TokenStream {
411    endpoints::generate(attr, item)
412}
413
414/// A no-op attribute macro required due to technical limitations of Rust's macro system.
415#[proc_macro_attribute]
416pub fn endpoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
417    item
418}
419
420struct Errors(Vec<Error>);
421
422impl Errors {
423    fn new() -> Self {
424        Errors(vec![])
425    }
426
427    fn push(&mut self, error: Error) {
428        self.0.push(error);
429    }
430
431    fn build(mut self) -> Result<(), Error> {
432        let Some(mut error) = self.0.pop() else {
433            return Ok(());
434        };
435        for other in self.0 {
436            error.combine(other);
437        }
438        Err(error)
439    }
440}
441
442#[derive(Copy, Clone)]
443enum Asyncness {
444    Sync,
445    Async,
446    LocalAsync,
447}
448
449impl Asyncness {
450    fn resolve(trait_: &ItemTrait, local: bool) -> Result<Self, Error> {
451        let mut it = trait_.items.iter().filter_map(|t| match t {
452            TraitItem::Fn(f) => Some(f),
453            _ => None,
454        });
455
456        let Some(first) = it.next() else {
457            return Ok(Asyncness::Sync);
458        };
459
460        let is_async = first.sig.asyncness.is_some();
461
462        let mut errors = Errors::new();
463
464        for f in it {
465            if f.sig.asyncness.is_some() != is_async {
466                errors.push(Error::new_spanned(
467                    f,
468                    "all methods must either be sync or async",
469                ));
470            }
471        }
472
473        errors.build()?;
474        let asyncness = if is_async {
475            if local {
476                Asyncness::LocalAsync
477            } else {
478                Asyncness::Async
479            }
480        } else {
481            Asyncness::Sync
482        };
483        Ok(asyncness)
484    }
485}