Skip to main content

conjure_codegen/
clients.rs

1// Copyright 2025 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::TokenStream;
16use quote::quote;
17use syn::Ident;
18
19use crate::{
20    context::{BaseModule, Context},
21    types::objects::{
22        ArgumentDefinition, AuthType, EndpointDefinition, ParameterType, ServiceDefinition, Type,
23    },
24};
25
/// The flavor of client trait being generated for a service.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Style {
    /// `async` methods; binary request bodies must also satisfy the context's
    /// `Sync` and `Send` bounds.
    Async,
    /// Blocking methods; binary response bodies are exposed as iterators.
    Sync,
    /// `async` methods on a trait marked `local`; binary request bodies carry no
    /// `Sync`/`Send` bounds.
    Local,
}
32
33pub fn generate(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
34    let sync_trait = generate_trait(ctx, def, Style::Sync);
35    let async_trait = generate_trait(ctx, def, Style::Async);
36    let local_trait = generate_trait(ctx, def, Style::Local);
37
38    quote! {
39        use conjure_http::endpoint;
40
41        #sync_trait
42        #async_trait
43        #local_trait
44    }
45}
46
47fn generate_trait(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
48    let docs = ctx.docs(def.docs());
49    let service_name = def.service_name().name();
50    let name = trait_name(ctx, def, style);
51    let version = match ctx.version() {
52        Some(version) => {
53            let some = ctx.some_ident(def.service_name());
54            quote!(, version = #some(#version))
55        }
56        None => quote!(),
57    };
58    let local = match style {
59        Style::Local => quote!(, local),
60        Style::Async | Style::Sync => quote!(),
61    };
62    let params = params(ctx, def, style);
63
64    let endpoints = def
65        .endpoints()
66        .iter()
67        .map(|e| generate_trait_endpoint(ctx, def, e, style));
68
69    quote! {
70        #docs
71        #[conjure_http::conjure_client(name = #service_name #version #local)]
72        pub trait #name #params {
73            #(#endpoints)*
74        }
75    }
76}
77
78fn trait_name(ctx: &Context, def: &ServiceDefinition, style: Style) -> Ident {
79    match style {
80        Style::Async => ctx.type_name(&format!("Async{}", def.service_name().name())),
81        Style::Local => ctx.type_name(&format!("LocalAsync{}", def.service_name().name())),
82        Style::Sync => ctx.type_name(def.service_name().name()),
83    }
84}
85
86fn params(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
87    let mut params = vec![];
88    if service_has_binary_request_body(ctx, def) {
89        params.push(quote! {
90            #[request_writer]
91            O
92        })
93    }
94
95    if !def.endpoints().is_empty() {
96        let result = ctx.result_ident(def.service_name());
97        let trait_ = match style {
98            Style::Async | Style::Local => quote!(conjure_http::private::Stream),
99            Style::Sync => {
100                let iterator = ctx.iterator_ident(def.service_name());
101                quote!(#iterator)
102            }
103        };
104        params.push(quote! {
105            #[response_body]
106            I: #trait_<Item = #result<conjure_http::private::Bytes, conjure_http::private::Error>>
107        });
108    }
109
110    if params.is_empty() {
111        quote!()
112    } else {
113        quote!(<#(#params),*>)
114    }
115}
116
117fn service_has_binary_request_body(ctx: &Context, def: &ServiceDefinition) -> bool {
118    def.endpoints()
119        .iter()
120        .any(|e| endpoint_has_binary_request_body(ctx, e))
121}
122
123fn endpoint_has_binary_request_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
124    endpoint.args().iter().any(|a| match a.param_type() {
125        ParameterType::Body(_) => ctx.is_binary(a.type_()),
126        _ => false,
127    })
128}
129
130fn generate_trait_endpoint(
131    ctx: &Context,
132    def: &ServiceDefinition,
133    endpoint: &EndpointDefinition,
134    style: Style,
135) -> TokenStream {
136    let docs = ctx.docs(endpoint.docs());
137    let method = endpoint
138        .http_method()
139        .as_str()
140        .parse::<TokenStream>()
141        .unwrap();
142    let path = path(endpoint);
143    let endpoint_name = &**endpoint.endpoint_name();
144    let async_ = match style {
145        Style::Async | Style::Local => quote!(async),
146        Style::Sync => quote!(),
147    };
148    let name = ctx.field_name(endpoint.endpoint_name());
149    let accept = accept(ctx, endpoint);
150
151    let auth_arg = auth_arg(endpoint);
152    let args = endpoint.args().iter().map(|a| arg(ctx, def, a, style));
153
154    let result = ctx.result_ident(def.service_name());
155
156    let ret_ty = rust_return_type(ctx, def, endpoint);
157    let ret_ty = quote!(#result<#ret_ty, conjure_http::private::Error>);
158
159    quote! {
160        #docs
161        #[endpoint(method = #method, path = #path, name = #endpoint_name, accept = #accept)]
162        #async_ fn #name(&self #auth_arg #(, #args)*) -> #ret_ty;
163    }
164}
165
166/// We need to strip the legacy regexes off of path params:
167///
168/// /foo/{bar:.*} -> /foo/{bar}
169fn path(endpoint: &EndpointDefinition) -> String {
170    endpoint
171        .http_path()
172        .split('/')
173        .map(
174            |segment| match segment.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
175                Some(segment) => format!("{{{}}}", segment.split(':').next().unwrap()),
176                None => segment.to_string(),
177            },
178        )
179        .collect::<Vec<_>>()
180        .join("/")
181}
182
183fn accept(ctx: &Context, endpoint: &EndpointDefinition) -> TokenStream {
184    match return_type(ctx, endpoint) {
185        ReturnType::None => quote!(conjure_http::client::conjure::EmptyResponseDeserializer),
186        ReturnType::Json(ty) => {
187            if ctx.is_iterable(ty) {
188                quote!(conjure_http::client::conjure::CollectionResponseDeserializer)
189            } else {
190                quote!(conjure_http::client::StdResponseDeserializer)
191            }
192        }
193        ReturnType::Binary => quote!(conjure_http::client::conjure::BinaryResponseDeserializer),
194        ReturnType::OptionalBinary => {
195            quote!(conjure_http::client::conjure::OptionalBinaryResponseDeserializer)
196        }
197    }
198}
199
200fn auth_arg(endpoint: &EndpointDefinition) -> TokenStream {
201    match endpoint.auth() {
202        Some(auth) => {
203            let params = match auth {
204                AuthType::Header(_) => quote!(),
205                AuthType::Cookie(cookie) => {
206                    let name = &cookie.cookie_name();
207                    quote!((cookie_name = #name))
208                }
209            };
210            quote!(, #[auth #params] auth_: &conjure_object::BearerToken)
211        }
212        None => quote!(),
213    }
214}
215
216fn arg(
217    ctx: &Context,
218    def: &ServiceDefinition,
219    arg: &ArgumentDefinition,
220    style: Style,
221) -> TokenStream {
222    let name = ctx.field_name(arg.arg_name());
223
224    let attr = match arg.param_type() {
225        ParameterType::Body(_) => {
226            let serializer = if ctx.is_binary(arg.type_()) {
227                quote!(conjure_http::client::conjure::BinaryRequestSerializer)
228            } else {
229                quote!(conjure_http::client::StdRequestSerializer)
230            };
231            quote!(#[body(serializer = #serializer)])
232        }
233        ParameterType::Header(header) => {
234            let name = &**header.param_id();
235            let mut encoder = if ctx.is_optional(arg.type_()).is_some() {
236                quote!(conjure_http::client::conjure::PlainSeqEncoder)
237            } else {
238                quote!(conjure_http::client::conjure::PlainEncoder)
239            };
240            if ctx.is_aliased(arg.type_()) {
241                let dealiased = ctx.dealiased_type(arg.type_());
242                let dealiased = ctx.rust_type(BaseModule::Clients, def.service_name(), dealiased);
243                encoder = quote!(conjure_http::client::AsRefEncoder<#encoder, #dealiased>)
244            }
245            quote!(#[header(name = #name, encoder = #encoder)])
246        }
247        ParameterType::Path(_) => {
248            let name = &**arg.arg_name();
249            quote!(#[path(name = #name, encoder = conjure_http::client::conjure::PlainEncoder)])
250        }
251        ParameterType::Query(query) => {
252            let name = &**query.param_id();
253            let mut encoder = if ctx.is_iterable(arg.type_()) {
254                quote!(conjure_http::client::conjure::PlainSeqEncoder)
255            } else {
256                quote!(conjure_http::client::conjure::PlainEncoder)
257            };
258            if ctx.is_aliased(arg.type_()) {
259                let dealiased = ctx.dealiased_type(arg.type_());
260                let dealiased = ctx.rust_type(BaseModule::Clients, def.service_name(), dealiased);
261                encoder = quote!(conjure_http::client::AsRefEncoder<#encoder, #dealiased>)
262            }
263            quote!(#[query(name = #name, encoder = #encoder)])
264        }
265    };
266
267    let ty = if ctx.is_binary(arg.type_()) {
268        match style {
269            Style::Async => {
270                let sync = ctx.sync_ident(def.service_name());
271                let send = ctx.send_ident(def.service_name());
272                quote!(impl conjure_http::client::AsyncWriteBody<O> + #sync + #send)
273            }
274            Style::Local => quote!(impl conjure_http::client::LocalAsyncWriteBody<O>),
275            Style::Sync => quote!(impl conjure_http::client::WriteBody<O>),
276        }
277    } else {
278        ctx.borrowed_rust_type(BaseModule::Clients, def.service_name(), arg.type_())
279    };
280    quote!(#attr #name: #ty)
281}
282
283fn rust_return_type(
284    ctx: &Context,
285    def: &ServiceDefinition,
286    endpoint: &EndpointDefinition,
287) -> TokenStream {
288    match return_type(ctx, endpoint) {
289        ReturnType::None => quote!(()),
290        ReturnType::Json(ty) => ctx.rust_type(BaseModule::Clients, def.service_name(), ty),
291        ReturnType::Binary => quote!(I),
292        ReturnType::OptionalBinary => {
293            let option = ctx.option_ident(def.service_name());
294            quote!(#option<I>)
295        }
296    }
297}
298
299fn return_type<'a>(ctx: &Context, endpoint: &'a EndpointDefinition) -> ReturnType<'a> {
300    match endpoint.returns() {
301        Some(ty) => match ctx.is_optional(ty) {
302            Some(inner) if ctx.is_binary(inner) => ReturnType::OptionalBinary,
303            _ if ctx.is_binary(ty) => ReturnType::Binary,
304            _ => ReturnType::Json(ty),
305        },
306        None => ReturnType::None,
307    }
308}
309
310enum ReturnType<'a> {
311    None,
312    Json(&'a Type),
313    Binary,
314    OptionalBinary,
315}