Skip to main content

typespec_client_core/http/
pipeline.rs

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4use typespec::http::RawResponse;
5
6use crate::http::{
7    policies::{Buffer, LoggingPolicy, Policy, TransportPolicy},
8    AsyncRawResponse, ClientOptions, Context, PipelineOptions, Request,
9};
10use std::sync::Arc;
11
12/// Execution pipeline.
13///
14/// A pipeline follows a precise flow:
15///
16/// 1. Client library-specified per-call policies are executed. Per-call policies can fail and bail out of the pipeline
17///    immediately.
18/// 2. User-specified per-call policies in [`ClientOptions::per_call_policies`] are executed.
19/// 3. The retry policy is executed. It allows to re-execute the following policies.
20/// 4. Client library-specified per-retry policies. Per-retry polices are always executed at least once but are
21///    re-executed in case of retries.
22/// 5. User-specified per-retry policies in [`ClientOptions::per_try_policies`] are executed.
23/// 6. The transport policy is executed. Transport policy is always the last policy and is the policy that
24///    actually constructs the [`AsyncRawResponse`] to be passed up the pipeline.
25///
26/// A pipeline is immutable. In other words a policy can either succeed and call the following
27/// policy of fail and return to the calling policy. Arbitrary policy "skip" must be avoided (but
28/// cannot be enforced by code). All policies except Transport policy can assume there is another following policy (so
29/// `self.pipeline[0]` is always valid).
30#[derive(Debug, Clone)]
31pub struct Pipeline {
32    pipeline: Vec<Arc<dyn Policy>>,
33}
34
/// Per-call options accepted by [`Pipeline::send`].
///
/// This type currently has no fields.
#[derive(Default, Debug)]
pub struct PipelineSendOptions;
38
/// Per-call options accepted by [`Pipeline::stream`].
///
/// This type currently has no fields.
#[derive(Default, Debug)]
pub struct PipelineStreamOptions;
42
43impl Pipeline {
44    /// Creates a new pipeline with user-specified and client library-specified policies.
45    ///
46    /// # Arguments
47    /// * `options` - The client options.
48    /// * `per_call_policies` - Policies to be executed per call, before the policies in `ClientOptions::per_call_policies`.
49    /// * `per_try_policies` - Policies to be executed per try, before the policies in `ClientOptions::per_try_policies`.
50    /// * `pipeline_options` - Additional options for the pipeline.
51    ///
52    pub fn new(
53        options: ClientOptions,
54        per_call_policies: Vec<Arc<dyn Policy>>,
55        per_try_policies: Vec<Arc<dyn Policy>>,
56        pipeline_options: Option<PipelineOptions>,
57    ) -> Self {
58        // The number of policies we'll push to the pipeline Vec ourselves.
59        const BUILT_IN_LEN: usize = 3;
60        let mut pipeline: Vec<Arc<dyn Policy>> = Vec::with_capacity(
61            per_call_policies.len()
62                + options.per_call_policies.len()
63                + per_try_policies.len()
64                + options.per_try_policies.len()
65                + BUILT_IN_LEN,
66        );
67
68        #[cfg(debug_assertions)]
69        let initial_capacity = pipeline.capacity();
70
71        pipeline.extend_from_slice(&per_call_policies);
72        pipeline.extend_from_slice(&options.per_call_policies);
73
74        let pipeline_options = pipeline_options.unwrap_or_default();
75
76        let retry_policy = options.retry.to_policy(
77            pipeline_options.retry_headers.clone(),
78            &pipeline_options.retry_status_codes,
79        );
80        pipeline.push(retry_policy);
81
82        pipeline.extend_from_slice(&per_try_policies);
83        pipeline.extend_from_slice(&options.per_try_policies);
84
85        pipeline.push(Arc::new(LoggingPolicy::new(options.logging)));
86
87        let transport: Arc<dyn Policy> =
88            Arc::new(TransportPolicy::new(options.transport.unwrap_or_default()));
89        pipeline.push(transport);
90
91        // Make sure we didn't have to resize the Vec.
92        #[cfg(debug_assertions)]
93        debug_assert_eq!(pipeline.len(), initial_capacity);
94
95        Self { pipeline }
96    }
97
98    /// Gets the policies in the order a [`Request`] is processed.
99    pub fn policies(&self) -> &[Arc<dyn Policy>] {
100        &self.pipeline
101    }
102
103    /// Sends a [`Request`] through each configured [`Policy`] and gets a [`RawResponse`] that is processed by each policy in reverse.
104    pub async fn send(
105        &self,
106        ctx: &Context<'_>,
107        request: &mut Request,
108        _options: Option<PipelineSendOptions>,
109    ) -> crate::Result<RawResponse> {
110        // Signal the TransportPolicy to buffer the entire response.
111        let mut ctx = ctx.to_borrowed();
112        ctx.insert(Buffer);
113
114        self.pipeline[0]
115            .send(&ctx, request, &self.pipeline[1..])
116            .await?
117            .try_into_raw_response()
118            .await
119    }
120
121    /// Sends a [`Request`] through each configured [`Policy`] to get a [`AsyncRawResponse`] that is processed by each policy in reverse.
122    pub async fn stream(
123        &self,
124        ctx: &Context<'_>,
125        request: &mut Request,
126        _options: Option<PipelineStreamOptions>,
127    ) -> crate::Result<AsyncRawResponse> {
128        self.pipeline[0]
129            .send(ctx, request, &self.pipeline[1..])
130            .await
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::{Error, ErrorKind},
        http::{
            headers::Headers, policies::PolicyResult, AsyncRawResponse, FixedRetryOptions,
            JsonFormat, Method, Response, RetryOptions, StatusCode, Transport,
        },
        stream::BytesStream,
        Bytes,
    };
    use futures::{lock::Mutex, StreamExt, TryStreamExt};
    use serde::Deserialize;
    use std::collections::VecDeque;
    use time::Duration;

    #[derive(Debug, Deserialize)]
    struct Model {
        foo: i32,
        bar: String,
    }

    #[tokio::test]
    async fn deserializes_response() {
        // Transport that always answers with the same complete JSON body.
        #[derive(Debug)]
        struct Responder {}

        #[async_trait::async_trait]
        impl Policy for Responder {
            async fn send(
                &self,
                _ctx: &Context,
                _request: &mut Request,
                _next: &[Arc<dyn Policy>],
            ) -> PolicyResult {
                let buffer = Bytes::from_static(br#"{"foo":1,"bar":"baz"}"#);
                let stream: BytesStream = buffer.into();
                Ok(AsyncRawResponse::new(
                    StatusCode::Ok,
                    Headers::new(),
                    Box::pin(stream),
                ))
            }
        }

        // Simulated service method
        async fn service_method() -> crate::Result<Response<Model, JsonFormat>> {
            let options = ClientOptions {
                transport: Some(Transport::with_policy(Arc::new(Responder {}))),
                ..Default::default()
            };
            let pipeline = Pipeline::new(options, Vec::new(), Vec::new(), None);
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            let raw_response = pipeline
                .send(&Context::default(), &mut request, None)
                .await?;
            Ok(raw_response.into())
        }

        let model = service_method().await.unwrap().into_model().unwrap();

        assert_eq!(1, model.foo);
        assert_eq!("baz", &model.bar);
    }

    /// Pass-through policy that counts how many responses flowed back through it.
    #[derive(Debug, Default)]
    struct Counter {
        count: Mutex<usize>,
    }

    impl Counter {
        async fn count(&self) -> usize {
            let count = self.count.lock().await;
            *count
        }
    }

    #[async_trait::async_trait]
    impl Policy for Counter {
        async fn send(
            &self,
            ctx: &Context,
            request: &mut Request,
            next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let result = next[0].send(ctx, request, &next[1..]).await;

            // Increment the counter after the response.
            let mut count = self.count.lock().await;
            *count += 1;

            result
        }
    }

    /// Transport policy that answers each request with the next queued response.
    #[derive(Debug)]
    struct QueuedResponder {
        responses: Mutex<VecDeque<AsyncRawResponse>>,
    }

    impl QueuedResponder {
        fn new(responses: impl IntoIterator<Item = AsyncRawResponse>) -> Self {
            Self {
                responses: Mutex::new(responses.into_iter().collect()),
            }
        }
    }

    #[async_trait::async_trait]
    impl Policy for QueuedResponder {
        /// Panics if more requests are sent than responses were queued.
        async fn send(
            &self,
            _ctx: &Context,
            _request: &mut Request,
            _next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let mut responses = self.responses.lock().await;
            let response = responses.pop_front().expect("expected AsyncRawResponse");
            Ok(response)
        }
    }

    /// Builds a pipeline with a 1ms fixed retry policy, the given counters as the per-call and
    /// per-try policies, and `transport` as the transport policy.
    fn retrying_pipeline(
        per_call_count: Arc<Counter>,
        per_try_count: Arc<Counter>,
        transport: QueuedResponder,
    ) -> Pipeline {
        let options = ClientOptions {
            retry: RetryOptions::fixed(FixedRetryOptions {
                delay: Duration::milliseconds(1),
                ..Default::default()
            }),
            transport: Some(Transport::with_policy(Arc::new(transport))),
            ..Default::default()
        };
        Pipeline::new(options, vec![per_call_count], vec![per_try_count], None)
    }

    #[tokio::test]
    async fn send_retries_in_pipeline() {
        // Throttles once, then fails mid-body once, then succeeds.
        fn responder() -> QueuedResponder {
            let mut headers = Headers::new();
            headers.insert("content-type", "application/json");
            headers.insert("transfer-encoding", "chunked");

            QueuedResponder::new([
                AsyncRawResponse::from_bytes(StatusCode::TooManyRequests, Headers::new(), Vec::new()),
                AsyncRawResponse::new(
                    StatusCode::Ok,
                    headers.clone(),
                    futures::stream::iter([
                        Ok(Bytes::from_static(br#"{"foo":1,"#)),
                        // Simulate an I/O error from default reqwest::Client.
                        Err(Error::new(ErrorKind::Io, "connection reset")),
                    ])
                    .boxed(),
                ),
                AsyncRawResponse::new(
                    StatusCode::Ok,
                    headers,
                    futures::stream::iter([
                        Ok(Bytes::from_static(br#"{"foo":1,"#)),
                        Ok(Bytes::from_static(br#""bar":"baz"}"#)),
                    ])
                    .boxed(),
                ),
            ])
        }

        let per_call_count = Arc::new(Counter::default());
        let per_try_count = Arc::new(Counter::default());

        // Simulated service method
        async fn service_method(
            per_call_count: Arc<Counter>,
            per_try_count: Arc<Counter>,
        ) -> crate::Result<Response<Model, JsonFormat>> {
            let pipeline = retrying_pipeline(per_call_count, per_try_count, responder());
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            let raw_response = pipeline
                .send(&Context::default(), &mut request, None)
                .await?;
            Ok(raw_response.into())
        }

        let resp = service_method(per_call_count.clone(), per_try_count.clone())
            .await
            .expect("expected Response");
        assert_eq!(per_try_count.count().await, 3);
        assert_eq!(per_call_count.count().await, 1);

        let model = resp.into_model().expect("expected Model");
        assert_eq!(per_try_count.count().await, 3);
        assert_eq!(per_call_count.count().await, 1);

        assert_eq!(1, model.foo);
        assert_eq!("baz", &model.bar);
    }

    #[tokio::test]
    async fn stream_out_of_pipeline() {
        // Throttles once, then streams two chunks before an I/O error.
        // The third response must never be requested.
        fn responder() -> QueuedResponder {
            let mut headers = Headers::new();
            headers.insert("content-type", "application/octet-stream");
            headers.insert("transfer-encoding", "chunked");

            QueuedResponder::new([
                AsyncRawResponse::from_bytes(StatusCode::TooManyRequests, Headers::new(), Vec::new()),
                AsyncRawResponse::new(
                    StatusCode::Ok,
                    headers,
                    futures::stream::iter([
                        Ok(vec![0xde, 0xad].into()),
                        Ok(vec![0xbe, 0xef].into()),
                        // Simulate an I/O error from default reqwest::Client.
                        Err(Error::new(ErrorKind::Io, "connection reset")),
                    ])
                    .boxed(),
                ),
                AsyncRawResponse::from_bytes(StatusCode::ImATeapot, Headers::new(), "unexpected"),
            ])
        }

        let per_call_count = Arc::new(Counter::default());
        let per_try_count = Arc::new(Counter::default());

        // Simulated service method
        async fn service_method(
            per_call_count: Arc<Counter>,
            per_try_count: Arc<Counter>,
        ) -> crate::Result<AsyncRawResponse> {
            let pipeline = retrying_pipeline(per_call_count, per_try_count, responder());
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            pipeline
                .stream(&Context::default(), &mut request, None)
                .await
        }

        let resp = service_method(per_call_count.clone(), per_try_count.clone())
            .await
            .expect("expected AsyncRawResponse");
        assert_eq!(per_try_count.count().await, 2);
        assert_eq!(per_call_count.count().await, 1);

        let mut stream = resp.into_body().into_stream();
        assert_eq!(
            stream.try_next().await.expect("first chunk"),
            Some(vec![0xde, 0xad].into())
        );
        assert_eq!(
            stream.try_next().await.expect("second chunk"),
            Some(vec![0xbe, 0xef].into())
        );
        assert!(matches!(stream.try_next().await, Err(e) if *e.kind() == ErrorKind::Io));

        // Make sure we never went back through the pipeline policies.
        assert_eq!(per_try_count.count().await, 2);
        assert_eq!(per_call_count.count().await, 1);
    }
}