1use typespec::http::RawResponse;
5
6use crate::http::{
7 policies::{Buffer, LoggingPolicy, Policy, TransportPolicy},
8 AsyncRawResponse, ClientOptions, Context, PipelineOptions, Request,
9};
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
31pub struct Pipeline {
32 pipeline: Vec<Arc<dyn Policy>>,
33}
34
/// Per-call options accepted by [`Pipeline::send`].
///
/// Currently carries no settings; `send` accepts the value but does not read it.
#[derive(Default, Debug)]
pub struct PipelineSendOptions;
38
/// Per-call options accepted by [`Pipeline::stream`].
///
/// Currently carries no settings; `stream` accepts the value but does not read it.
#[derive(Default, Debug)]
pub struct PipelineStreamOptions;
42
43impl Pipeline {
44 pub fn new(
53 options: ClientOptions,
54 per_call_policies: Vec<Arc<dyn Policy>>,
55 per_try_policies: Vec<Arc<dyn Policy>>,
56 pipeline_options: Option<PipelineOptions>,
57 ) -> Self {
58 const BUILT_IN_LEN: usize = 3;
60 let mut pipeline: Vec<Arc<dyn Policy>> = Vec::with_capacity(
61 per_call_policies.len()
62 + options.per_call_policies.len()
63 + per_try_policies.len()
64 + options.per_try_policies.len()
65 + BUILT_IN_LEN,
66 );
67
68 #[cfg(debug_assertions)]
69 let initial_capacity = pipeline.capacity();
70
71 pipeline.extend_from_slice(&per_call_policies);
72 pipeline.extend_from_slice(&options.per_call_policies);
73
74 let pipeline_options = pipeline_options.unwrap_or_default();
75
76 let retry_policy = options.retry.to_policy(
77 pipeline_options.retry_headers.clone(),
78 &pipeline_options.retry_status_codes,
79 );
80 pipeline.push(retry_policy);
81
82 pipeline.extend_from_slice(&per_try_policies);
83 pipeline.extend_from_slice(&options.per_try_policies);
84
85 pipeline.push(Arc::new(LoggingPolicy::new(options.logging)));
86
87 let transport: Arc<dyn Policy> =
88 Arc::new(TransportPolicy::new(options.transport.unwrap_or_default()));
89 pipeline.push(transport);
90
91 #[cfg(debug_assertions)]
93 debug_assert_eq!(pipeline.len(), initial_capacity);
94
95 Self { pipeline }
96 }
97
98 pub fn policies(&self) -> &[Arc<dyn Policy>] {
100 &self.pipeline
101 }
102
103 pub async fn send(
105 &self,
106 ctx: &Context<'_>,
107 request: &mut Request,
108 _options: Option<PipelineSendOptions>,
109 ) -> crate::Result<RawResponse> {
110 let mut ctx = ctx.to_borrowed();
112 ctx.insert(Buffer);
113
114 self.pipeline[0]
115 .send(&ctx, request, &self.pipeline[1..])
116 .await?
117 .try_into_raw_response()
118 .await
119 }
120
121 pub async fn stream(
123 &self,
124 ctx: &Context<'_>,
125 request: &mut Request,
126 _options: Option<PipelineStreamOptions>,
127 ) -> crate::Result<AsyncRawResponse> {
128 self.pipeline[0]
129 .send(ctx, request, &self.pipeline[1..])
130 .await
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::{Error, ErrorKind},
        http::{
            headers::Headers, policies::PolicyResult, AsyncRawResponse, FixedRetryOptions,
            JsonFormat, Method, Response, RetryOptions, StatusCode, Transport,
        },
        stream::BytesStream,
        Bytes,
    };
    use futures::{lock::Mutex, StreamExt, TryStreamExt};
    use serde::Deserialize;
    use std::collections::VecDeque;
    use time::Duration;

    /// Model that the mock JSON bodies below deserialize into.
    #[derive(Debug, Deserialize)]
    struct Model {
        foo: i32,
        bar: String,
    }

    /// A body buffered by `Pipeline::send` converts into a typed `Response` whose model
    /// deserializes correctly.
    #[tokio::test]
    async fn deserializes_response() {
        /// Mock transport that always returns `200 OK` with a fixed JSON body.
        #[derive(Debug)]
        struct Responder {}

        #[async_trait::async_trait]
        impl Policy for Responder {
            async fn send(
                &self,
                _ctx: &Context,
                _request: &mut Request,
                _next: &[Arc<dyn Policy>],
            ) -> PolicyResult {
                let buffer = Bytes::from_static(br#"{"foo":1,"bar":"baz"}"#);
                let stream: BytesStream = buffer.into();
                let response =
                    AsyncRawResponse::new(StatusCode::Ok, Headers::new(), Box::pin(stream));
                Ok(std::future::ready(response).await)
            }
        }

        /// Mimics a generated client method: build a pipeline around the mock transport,
        /// send one GET, and convert the raw response into a typed response.
        async fn service_method() -> crate::Result<Response<Model, JsonFormat>> {
            let options = ClientOptions {
                transport: Some(Transport::with_policy(Arc::new(Responder {}))),
                ..Default::default()
            };
            let pipeline = Pipeline::new(options, Vec::new(), Vec::new(), None);
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            let raw_response = pipeline
                .send(&Context::default(), &mut request, None)
                .await?;
            Ok(raw_response.into())
        }

        let model = service_method().await.unwrap().into_model().unwrap();

        assert_eq!(1, model.foo);
        assert_eq!("baz", &model.bar);
    }

    /// Pass-through policy that counts how many times it has been invoked.
    #[derive(Debug, Default)]
    struct Counter {
        count: Mutex<usize>,
    }

    impl Counter {
        /// Returns the number of completed invocations so far.
        async fn count(&self) -> usize {
            let count = self.count.lock().await;
            *count
        }
    }

    #[async_trait::async_trait]
    impl Policy for Counter {
        /// Forwards to the next policy, then increments the count.
        ///
        /// The increment happens after the downstream call returns, whether it succeeded
        /// or failed.
        async fn send(
            &self,
            ctx: &Context,
            request: &mut Request,
            next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let result = next[0].send(ctx, request, &next[1..]).await;

            let mut count = self.count.lock().await;
            *count += 1;

            result
        }
    }

    /// `Pipeline::send` retries both a throttled response and a response whose body
    /// fails part-way through.
    ///
    /// Per-call policies should run once per call; per-try policies once per attempt.
    #[tokio::test]
    async fn send_retries_in_pipeline() {
        /// Mock transport that pops the next canned response on each call.
        ///
        /// It panics if asked for more responses than were queued.
        #[derive(Debug)]
        struct Responder {
            responses: Mutex<VecDeque<AsyncRawResponse>>,
        }

        impl Default for Responder {
            fn default() -> Self {
                let mut headers = Headers::new();
                headers.insert("content-type", "application/json");
                headers.insert("transfer-encoding", "chunked");

                Self {
                    responses: Mutex::new(VecDeque::from_iter([
                        // Attempt 1: throttled, with an empty body.
                        AsyncRawResponse::from_bytes(
                            StatusCode::TooManyRequests,
                            Headers::new(),
                            Vec::new(),
                        ),
                        // Attempt 2: success status, but the body stream yields one chunk
                        // and then an I/O error, simulating a dropped connection mid-body.
                        AsyncRawResponse::new(
                            StatusCode::Ok,
                            headers.clone(),
                            futures::stream::iter([
                                Ok(Bytes::from_static(br#"{"foo":1,"#)),
                                Err(Error::new(ErrorKind::Io, "connection reset")),
                            ])
                            .boxed(),
                        ),
                        // Attempt 3: the complete JSON body, split across two chunks.
                        AsyncRawResponse::new(
                            StatusCode::Ok,
                            headers,
                            futures::stream::iter([
                                Ok(Bytes::from_static(br#"{"foo":1,"#)),
                                Ok(Bytes::from_static(br#""bar":"baz"}"#)),
                            ])
                            .boxed(),
                        ),
                    ])),
                }
            }
        }

        #[async_trait::async_trait]
        impl Policy for Responder {
            async fn send(
                &self,
                _ctx: &Context,
                _request: &mut Request,
                _next: &[Arc<dyn Policy>],
            ) -> PolicyResult {
                let mut responses = self.responses.lock().await;
                let response = responses.pop_front().expect("expected AsyncRawResponse");
                Ok(response)
            }
        }

        let per_call_count = Arc::new(Counter::default());
        let per_try_count = Arc::new(Counter::default());

        /// Builds a pipeline with the given counters as per-call and per-try policies,
        /// a 1 ms fixed retry delay, and the mock transport; then sends one GET.
        async fn service_method(
            per_call_count: Arc<Counter>,
            per_try_count: Arc<Counter>,
        ) -> crate::Result<Response<Model, JsonFormat>> {
            let options = ClientOptions {
                retry: RetryOptions::fixed(FixedRetryOptions {
                    delay: Duration::milliseconds(1),
                    ..Default::default()
                }),
                transport: Some(Transport::with_policy(Arc::new(Responder::default()))),
                ..Default::default()
            };
            let pipeline = Pipeline::new(options, vec![per_call_count], vec![per_try_count], None);
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            let raw_response = pipeline
                .send(&Context::default(), &mut request, None)
                .await?;
            Ok(raw_response.into())
        }

        let resp = service_method(per_call_count.clone(), per_try_count.clone())
            .await
            .expect("expected Response");
        // All three queued responses were consumed during `send`, including the one whose
        // body failed mid-stream, while the per-call policy ran only once.
        assert_eq!(per_try_count.count().await, 3);
        assert_eq!(per_call_count.count().await, 1);

        // Deserializing the already-buffered body must not go back through the pipeline.
        let model = resp.into_model().expect("expected Model");
        assert_eq!(per_try_count.count().await, 3);
        assert_eq!(per_call_count.count().await, 1);

        assert_eq!(1, model.foo);
        assert_eq!("baz", &model.bar);
    }

    /// `Pipeline::stream` still retries a throttled response, but hands the body back
    /// unread.
    ///
    /// An error while reading that body therefore reaches the caller instead of being
    /// retried.
    #[tokio::test]
    async fn stream_out_of_pipeline() {
        /// Mock transport that pops the next canned response on each call.
        ///
        /// It panics if asked for more responses than were queued.
        #[derive(Debug)]
        struct Responder {
            responses: Mutex<VecDeque<AsyncRawResponse>>,
        }

        impl Default for Responder {
            fn default() -> Self {
                let mut headers = Headers::new();
                headers.insert("content-type", "application/x-octet-stream");
                headers.insert("transfer-encoding", "chunked");

                Self {
                    responses: Mutex::new(VecDeque::from_iter([
                        // Attempt 1: throttled, with an empty body.
                        AsyncRawResponse::from_bytes(
                            StatusCode::TooManyRequests,
                            Headers::new(),
                            Vec::new(),
                        ),
                        // Attempt 2: success status; the body yields two chunks and then an
                        // I/O error, which the caller should observe while streaming.
                        AsyncRawResponse::new(
                            StatusCode::Ok,
                            headers.clone(),
                            futures::stream::iter([
                                Ok(vec![0xde, 0xad].into()),
                                Ok(vec![0xbe, 0xef].into()),
                                Err(Error::new(ErrorKind::Io, "connection reset")),
                            ])
                            .boxed(),
                        ),
                        // Should never be returned: the per-try count of 2 asserted below
                        // shows the pipeline stops after attempt 2.
                        AsyncRawResponse::from_bytes(
                            StatusCode::ImATeapot,
                            Headers::new(),
                            r#"unexpected"#,
                        ),
                    ])),
                }
            }
        }

        #[async_trait::async_trait]
        impl Policy for Responder {
            async fn send(
                &self,
                _ctx: &Context,
                _request: &mut Request,
                _next: &[Arc<dyn Policy>],
            ) -> PolicyResult {
                let mut responses = self.responses.lock().await;
                let response = responses.pop_front().expect("expected AsyncRawResponse");
                Ok(response)
            }
        }

        let per_call_count = Arc::new(Counter::default());
        let per_try_count = Arc::new(Counter::default());

        /// Builds a pipeline with the given counters as per-call and per-try policies,
        /// a 1 ms fixed retry delay, and the mock transport.
        ///
        /// Then streams one GET without buffering the body.
        async fn service_method(
            per_call_count: Arc<Counter>,
            per_try_count: Arc<Counter>,
        ) -> crate::Result<AsyncRawResponse> {
            let options = ClientOptions {
                retry: RetryOptions::fixed(FixedRetryOptions {
                    delay: Duration::milliseconds(1),
                    ..Default::default()
                }),
                transport: Some(Transport::with_policy(Arc::new(Responder::default()))),
                ..Default::default()
            };
            let pipeline = Pipeline::new(options, vec![per_call_count], vec![per_try_count], None);
            let mut request = Request::new("http://localhost".parse().unwrap(), Method::Get);
            pipeline
                .stream(&Context::default(), &mut request, None)
                .await
        }

        let resp = service_method(per_call_count.clone(), per_try_count.clone())
            .await
            .expect("expected AsyncRawResponse");
        // The throttled attempt was retried once; the second response was returned as-is.
        assert_eq!(per_try_count.count().await, 2);
        assert_eq!(per_call_count.count().await, 1);

        // The caller reads the body directly: two good chunks, then the injected I/O error.
        let mut stream = resp.into_body().into_stream();
        assert_eq!(
            stream.try_next().await.expect("first chunk"),
            Some(vec![0xde, 0xad].into())
        );
        assert_eq!(
            stream.try_next().await.expect("second chunk"),
            Some(vec![0xbe, 0xef].into())
        );
        assert!(matches!(stream.try_next().await, Err(e) if *e.kind() == ErrorKind::Io));

        // The mid-body error did not trigger another attempt through the pipeline.
        assert_eq!(per_try_count.count().await, 2);
        assert_eq!(per_call_count.count().await, 1);
    }
}