Skip to main content

openai_core/resources/
common.rs

1//! Shared request-builder primitives used across resource namespaces.
2
3use std::collections::BTreeMap;
4use std::marker::PhantomData;
5use std::time::Duration;
6
7use bytes::Bytes;
8use http::Method;
9use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
10use serde::Serialize;
11use serde_json::Value;
12use tokio_util::sync::CancellationToken;
13
14use crate::Client;
15use crate::config::RequestOptions;
16use crate::error::{Error, Result};
17use crate::files::{MultipartField, UploadSource};
18use crate::json_payload::JsonPayload;
19use crate::pagination::{CursorPage, ListEnvelope};
20use crate::response_meta::ApiResponse;
21use crate::stream::{RawSseStream, SseStream};
22use crate::transport::{RequestSpec, merge_json_body};
23
24/// URL path encoding set used for dynamic path segments.
25const PATH_SEGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
26    .add(b'/')
27    .add(b'?')
28    .add(b'#')
29    .add(b'%')
30    .add(b'&')
31    .add(b'+')
32    .add(b'=');
33
34pub(crate) fn value_from<T>(value: &T) -> Result<Value>
35where
36    T: Serialize,
37{
38    serde_json::to_value(value).map_err(|error| {
39        crate::error::Error::Serialization(crate::SerializationError::new(error.to_string()))
40    })
41}
42
43/// 对单个路径参数做安全编码,避免动态 ID 改写 URL 结构。
44pub(crate) fn encode_path_segment(segment: impl AsRef<str>) -> String {
45    utf8_percent_encode(segment.as_ref(), PATH_SEGMENT_ENCODE_SET).to_string()
46}
47
/// Returns `true` when `metadata` contains no entries.
///
/// Takes the map by reference so it can be used as a plain function
/// predicate.
pub(crate) fn metadata_is_empty(metadata: &BTreeMap<String, String>) -> bool {
    BTreeMap::is_empty(metadata)
}
51
52/// Shared state for typed JSON request builders defined in longtail namespaces.
53#[derive(Debug, Clone)]
54pub(crate) struct TypedJsonRequestState<P> {
55    pub(crate) client: Option<Client>,
56    pub(crate) params: P,
57    pub(crate) body_override: Option<Value>,
58    pub(crate) options: RequestOptions,
59    pub(crate) extra_body: BTreeMap<String, Value>,
60    pub(crate) provider_options: BTreeMap<String, Value>,
61}
62
63impl<P> TypedJsonRequestState<P> {
64    pub(crate) fn new(client: Client, params: P) -> Self {
65        Self {
66            client: Some(client),
67            params,
68            body_override: None,
69            options: RequestOptions::default(),
70            extra_body: BTreeMap::new(),
71            provider_options: BTreeMap::new(),
72        }
73    }
74
75    pub(crate) fn body_value(mut self, body: impl Into<JsonPayload>) -> Self {
76        self.body_override = Some(body.into().into_raw());
77        self
78    }
79
80    pub(crate) fn extra_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
81        self.options.insert_header(key, value);
82        self
83    }
84
85    pub(crate) fn extra_query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
86        self.options.insert_query(key, value);
87        self
88    }
89
90    pub(crate) fn extra_body(
91        mut self,
92        key: impl Into<String>,
93        value: impl Into<JsonPayload>,
94    ) -> Self {
95        self.extra_body.insert(key.into(), value.into().into_raw());
96        self
97    }
98
99    pub(crate) fn provider_option(
100        mut self,
101        key: impl Into<String>,
102        value: impl Into<JsonPayload>,
103    ) -> Self {
104        self.provider_options
105            .insert(key.into(), value.into().into_raw());
106        self
107    }
108
109    pub(crate) fn timeout(mut self, timeout: Duration) -> Self {
110        self.options.timeout = Some(timeout);
111        self
112    }
113
114    pub(crate) fn max_retries(mut self, max_retries: u32) -> Self {
115        self.options.max_retries = Some(max_retries);
116        self
117    }
118
119    pub(crate) fn cancellation_token(mut self, token: CancellationToken) -> Self {
120        self.options.cancellation_token = Some(token);
121        self
122    }
123}
124
125impl<P> TypedJsonRequestState<P>
126where
127    P: Serialize,
128{
129    pub(crate) fn build_spec(
130        mut self,
131        endpoint_id: &'static str,
132        path: &'static str,
133    ) -> Result<(Client, RequestSpec)> {
134        let client = self
135            .client
136            .take()
137            .ok_or_else(|| Error::InvalidConfig("请求构建器缺少客户端".into()))?;
138        let provider_key = client.provider().kind().as_key();
139        let body = merge_json_body(
140            Some(
141                self.body_override
142                    .take()
143                    .unwrap_or(value_from(&self.params)?),
144            ),
145            &self.extra_body,
146            provider_key,
147            &self.provider_options,
148        );
149        let mut spec = RequestSpec::new(endpoint_id, Method::POST, path);
150        spec.body = Some(body);
151        spec.options = self.options;
152        Ok((client, spec))
153    }
154}
155
156/// 表示通用 JSON 请求构建器。
157#[derive(Debug, Clone)]
158pub struct JsonRequestBuilder<T> {
159    pub(crate) client: Client,
160    pub(crate) spec: RequestSpec,
161    pub(crate) extra_body: BTreeMap<String, Value>,
162    pub(crate) provider_options: BTreeMap<String, Value>,
163    pub(crate) _marker: PhantomData<T>,
164}
165
166impl<T> JsonRequestBuilder<T> {
167    pub(crate) fn new(
168        client: Client,
169        endpoint_id: &'static str,
170        method: Method,
171        path: impl Into<String>,
172    ) -> Self {
173        Self {
174            client,
175            spec: RequestSpec::new(endpoint_id, method, path),
176            extra_body: BTreeMap::new(),
177            provider_options: BTreeMap::new(),
178            _marker: PhantomData,
179        }
180    }
181
182    /// 设置整个请求体为一个 `serde_json::Value`。
183    pub fn body_value(mut self, body: impl Into<JsonPayload>) -> Self {
184        self.spec.body = Some(body.into().into_raw());
185        self
186    }
187
188    /// 使用任意可序列化对象设置请求体。
189    ///
190    /// # Errors
191    ///
192    /// 当序列化失败时返回错误。
193    pub fn json_body<U>(mut self, body: &U) -> Result<Self>
194    where
195        U: Serialize,
196    {
197        self.spec.body = Some(value_from(body)?);
198        Ok(self)
199    }
200
201    /// 添加一个额外请求头。
202    pub fn extra_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
203        self.spec.options.insert_header(key, value);
204        self
205    }
206
207    /// 删除一个默认请求头。
208    pub fn remove_header(mut self, key: impl Into<String>) -> Self {
209        self.spec.options.remove_header(key);
210        self
211    }
212
213    /// 添加一个额外查询参数。
214    pub fn extra_query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
215        self.spec.options.insert_query(key, value);
216        self
217    }
218
219    /// 在 JSON 根对象中追加字段。
220    pub fn extra_body(mut self, key: impl Into<String>, value: impl Into<JsonPayload>) -> Self {
221        self.extra_body.insert(key.into(), value.into().into_raw());
222        self
223    }
224
225    /// 在 provider 对应的 `provider_options` 下追加字段。
226    pub fn provider_option(
227        mut self,
228        key: impl Into<String>,
229        value: impl Into<JsonPayload>,
230    ) -> Self {
231        self.provider_options
232            .insert(key.into(), value.into().into_raw());
233        self
234    }
235
236    /// 覆盖请求超时时间。
237    pub fn timeout(mut self, timeout: Duration) -> Self {
238        self.spec.options.timeout = Some(timeout);
239        self
240    }
241
242    /// 覆盖最大重试次数。
243    pub fn max_retries(mut self, max_retries: u32) -> Self {
244        self.spec.options.max_retries = Some(max_retries);
245        self
246    }
247
248    /// 设置取消令牌。
249    pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
250        self.spec.options.cancellation_token = Some(token);
251        self
252    }
253
254    /// 添加 Multipart 文本字段。
255    pub fn multipart_text(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
256        let multipart = self.spec.multipart.get_or_insert_default();
257        multipart.fields.push(MultipartField {
258            name: name.into(),
259            value: value.into(),
260        });
261        self
262    }
263
264    /// 添加 Multipart 文件字段。
265    pub fn multipart_file(mut self, name: impl Into<String>, file: UploadSource) -> Self {
266        let multipart = self.spec.multipart.get_or_insert_default();
267        multipart.files.push((name.into(), file));
268        self
269    }
270
271    pub(crate) fn into_spec(mut self) -> RequestSpec {
272        let provider_key = self.client.provider().kind().as_key();
273        self.spec.body = Some(merge_json_body(
274            self.spec.body.take(),
275            &self.extra_body,
276            provider_key,
277            &self.provider_options,
278        ));
279        self.spec
280    }
281}
282
283impl<T> JsonRequestBuilder<T>
284where
285    T: serde::de::DeserializeOwned,
286{
287    /// 发送请求并返回业务对象。
288    ///
289    /// # Errors
290    ///
291    /// 当请求失败或反序列化失败时返回错误。
292    pub async fn send(self) -> Result<T> {
293        Ok(self.send_with_meta().await?.data)
294    }
295
296    /// 发送请求并保留响应元信息。
297    ///
298    /// # Errors
299    ///
300    /// 当请求失败或反序列化失败时返回错误。
301    pub async fn send_with_meta(self) -> Result<ApiResponse<T>> {
302        let client = self.client.clone();
303        client.execute_json(self.into_spec()).await
304    }
305
306    /// 发送请求并返回原始 `http::Response<Bytes>`。
307    ///
308    /// # Errors
309    ///
310    /// 当请求失败时返回错误。
311    pub async fn send_raw(self) -> Result<http::Response<Bytes>> {
312        let client = self.client.clone();
313        client.execute_raw_http(self.into_spec()).await
314    }
315
316    /// 发送请求并返回原始 SSE 事件流。
317    ///
318    /// 该方法会自动追加 `Accept: text/event-stream` 请求头。
319    ///
320    /// # Errors
321    ///
322    /// 当请求失败时返回错误。
323    pub async fn send_raw_sse(self) -> Result<RawSseStream> {
324        let client = self.client.clone();
325        let mut spec = self.into_spec();
326        spec.options.insert_header("accept", "text/event-stream");
327        client.execute_raw_sse(spec).await
328    }
329}
330
331impl<T> JsonRequestBuilder<T>
332where
333    T: serde::de::DeserializeOwned + Send + 'static,
334{
335    /// 发送请求并把 SSE 数据流解析为指定类型。
336    ///
337    /// 该方法会自动追加 `Accept: text/event-stream` 请求头。
338    ///
339    /// # Errors
340    ///
341    /// 当请求失败或 SSE 事件反序列化失败时返回错误。
342    pub async fn send_sse(self) -> Result<SseStream<T>> {
343        let client = self.client.clone();
344        let mut spec = self.into_spec();
345        spec.options.insert_header("accept", "text/event-stream");
346        client.execute_sse(spec).await
347    }
348}
349
350/// 表示二进制响应请求构建器。
351#[derive(Debug, Clone)]
352pub struct BytesRequestBuilder {
353    pub(crate) inner: JsonRequestBuilder<Bytes>,
354}
355
356/// 表示不关心响应体的请求构建器。
357#[derive(Debug, Clone)]
358pub struct NoContentRequestBuilder {
359    pub(crate) inner: JsonRequestBuilder<Bytes>,
360}
361
362impl BytesRequestBuilder {
363    pub(crate) fn new(
364        client: Client,
365        endpoint_id: &'static str,
366        method: Method,
367        path: impl Into<String>,
368    ) -> Self {
369        Self {
370            inner: JsonRequestBuilder::new(client, endpoint_id, method, path),
371        }
372    }
373
374    /// 设置 JSON 请求体。
375    pub fn body_value(mut self, body: impl Into<JsonPayload>) -> Self {
376        self.inner = self.inner.body_value(body);
377        self
378    }
379
380    /// 设置可序列化请求体。
381    ///
382    /// # Errors
383    ///
384    /// 当序列化失败时返回错误。
385    pub fn json_body<U>(mut self, body: &U) -> Result<Self>
386    where
387        U: Serialize,
388    {
389        self.inner = self.inner.json_body(body)?;
390        Ok(self)
391    }
392
393    /// 追加请求头。
394    pub fn extra_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
395        self.inner = self.inner.extra_header(key, value);
396        self
397    }
398
399    /// 追加查询参数。
400    pub fn extra_query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
401        self.inner = self.inner.extra_query(key, value);
402        self
403    }
404
405    /// 在 provider 对应的 `provider_options` 下追加字段。
406    pub fn provider_option(
407        mut self,
408        key: impl Into<String>,
409        value: impl Into<JsonPayload>,
410    ) -> Self {
411        self.inner = self.inner.provider_option(key, value);
412        self
413    }
414
415    /// 覆盖请求超时时间。
416    pub fn timeout(mut self, timeout: Duration) -> Self {
417        self.inner = self.inner.timeout(timeout);
418        self
419    }
420
421    /// 覆盖最大重试次数。
422    pub fn max_retries(mut self, max_retries: u32) -> Self {
423        self.inner = self.inner.max_retries(max_retries);
424        self
425    }
426
427    /// 设置取消令牌。
428    pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
429        self.inner = self.inner.cancellation_token(token);
430        self
431    }
432
433    /// 添加 Multipart 文本字段。
434    pub fn multipart_text(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
435        self.inner = self.inner.multipart_text(name, value);
436        self
437    }
438
439    /// 添加 Multipart 文件字段。
440    pub fn multipart_file(mut self, name: impl Into<String>, file: UploadSource) -> Self {
441        self.inner = self.inner.multipart_file(name, file);
442        self
443    }
444
445    /// 在 JSON 根对象中追加字段。
446    pub fn extra_body(mut self, key: impl Into<String>, value: impl Into<JsonPayload>) -> Self {
447        self.inner = self.inner.extra_body(key, value);
448        self
449    }
450
451    /// 发送请求并返回原始字节。
452    ///
453    /// # Errors
454    ///
455    /// 当请求失败时返回错误。
456    pub async fn send(self) -> Result<Bytes> {
457        Ok(self.send_with_meta().await?.data)
458    }
459
460    /// 发送请求并保留响应元信息。
461    ///
462    /// # Errors
463    ///
464    /// 当请求失败时返回错误。
465    pub async fn send_with_meta(self) -> Result<ApiResponse<Bytes>> {
466        let client = self.inner.client.clone();
467        client.execute_bytes(self.inner.into_spec()).await
468    }
469
470    /// 发送请求并返回原始 HTTP 响应。
471    ///
472    /// # Errors
473    ///
474    /// 当请求失败时返回错误。
475    pub async fn send_raw(self) -> Result<http::Response<Bytes>> {
476        let client = self.inner.client.clone();
477        client.execute_raw_http(self.inner.into_spec()).await
478    }
479
480    /// 发送请求并返回原始 SSE 事件流。
481    ///
482    /// 该方法会自动追加 `Accept: text/event-stream` 请求头。
483    ///
484    /// # Errors
485    ///
486    /// 当请求失败时返回错误。
487    pub async fn send_raw_sse(self) -> Result<RawSseStream> {
488        let client = self.inner.client.clone();
489        let mut spec = self.inner.into_spec();
490        spec.options.insert_header("accept", "text/event-stream");
491        client.execute_raw_sse(spec).await
492    }
493
494    /// 发送请求并把 SSE 数据流解析为指定类型。
495    ///
496    /// 该方法会自动追加 `Accept: text/event-stream` 请求头。
497    ///
498    /// # Errors
499    ///
500    /// 当请求失败或 SSE 事件反序列化失败时返回错误。
501    pub async fn send_sse<T>(self) -> Result<SseStream<T>>
502    where
503        T: serde::de::DeserializeOwned + Send + 'static,
504    {
505        let client = self.inner.client.clone();
506        let mut spec = self.inner.into_spec();
507        spec.options.insert_header("accept", "text/event-stream");
508        client.execute_sse(spec).await
509    }
510}
511
512impl NoContentRequestBuilder {
513    pub(crate) fn new(
514        client: Client,
515        endpoint_id: &'static str,
516        method: Method,
517        path: impl Into<String>,
518    ) -> Self {
519        Self {
520            inner: JsonRequestBuilder::new(client, endpoint_id, method, path),
521        }
522    }
523
524    /// 设置 JSON 请求体。
525    pub fn body_value(mut self, body: impl Into<JsonPayload>) -> Self {
526        self.inner = self.inner.body_value(body);
527        self
528    }
529
530    /// 设置可序列化请求体。
531    ///
532    /// # Errors
533    ///
534    /// 当序列化失败时返回错误。
535    pub fn json_body<U>(mut self, body: &U) -> Result<Self>
536    where
537        U: Serialize,
538    {
539        self.inner = self.inner.json_body(body)?;
540        Ok(self)
541    }
542
543    /// 追加请求头。
544    pub fn extra_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
545        self.inner = self.inner.extra_header(key, value);
546        self
547    }
548
549    /// 删除一个默认请求头。
550    pub fn remove_header(mut self, key: impl Into<String>) -> Self {
551        self.inner = self.inner.remove_header(key);
552        self
553    }
554
555    /// 追加查询参数。
556    pub fn extra_query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
557        self.inner = self.inner.extra_query(key, value);
558        self
559    }
560
561    /// 在 JSON 根对象中追加字段。
562    pub fn extra_body(mut self, key: impl Into<String>, value: impl Into<JsonPayload>) -> Self {
563        self.inner = self.inner.extra_body(key, value);
564        self
565    }
566
567    /// 在 provider 对应的 `provider_options` 下追加字段。
568    pub fn provider_option(
569        mut self,
570        key: impl Into<String>,
571        value: impl Into<JsonPayload>,
572    ) -> Self {
573        self.inner = self.inner.provider_option(key, value);
574        self
575    }
576
577    /// 覆盖请求超时时间。
578    pub fn timeout(mut self, timeout: Duration) -> Self {
579        self.inner = self.inner.timeout(timeout);
580        self
581    }
582
583    /// 覆盖最大重试次数。
584    pub fn max_retries(mut self, max_retries: u32) -> Self {
585        self.inner = self.inner.max_retries(max_retries);
586        self
587    }
588
589    /// 设置取消令牌。
590    pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
591        self.inner = self.inner.cancellation_token(token);
592        self
593    }
594
595    /// 发送请求并忽略响应体。
596    ///
597    /// # Errors
598    ///
599    /// 当请求失败时返回错误。
600    pub async fn send(self) -> Result<()> {
601        self.send_with_meta().await.map(|_| ())
602    }
603
604    /// 发送请求并保留响应元信息。
605    ///
606    /// # Errors
607    ///
608    /// 当请求失败时返回错误。
609    pub async fn send_with_meta(self) -> Result<ApiResponse<()>> {
610        let client = self.inner.client.clone();
611        let response = client.execute_bytes(self.inner.into_spec()).await?;
612        let (_, meta) = response.into_parts();
613        Ok(ApiResponse::new((), meta))
614    }
615
616    /// 发送请求并返回原始 HTTP 响应。
617    ///
618    /// # Errors
619    ///
620    /// 当请求失败时返回错误。
621    pub async fn send_raw(self) -> Result<http::Response<Bytes>> {
622        let client = self.inner.client.clone();
623        client.execute_raw_http(self.inner.into_spec()).await
624    }
625}
626
627/// 表示列表请求构建器。
628#[derive(Debug, Clone)]
629pub struct ListRequestBuilder<T> {
630    pub(crate) inner: JsonRequestBuilder<ListEnvelope<T>>,
631}
632
633impl<T> ListRequestBuilder<T> {
634    pub(crate) fn new(client: Client, endpoint_id: &'static str, path: impl Into<String>) -> Self {
635        Self {
636            inner: JsonRequestBuilder::new(client, endpoint_id, Method::GET, path),
637        }
638    }
639
640    /// 设置 `after` 游标。
641    pub fn after(mut self, cursor: impl Into<String>) -> Self {
642        self.inner = self.inner.extra_query("after", cursor);
643        self
644    }
645
646    /// 设置 `before` 游标。
647    pub fn before(mut self, cursor: impl Into<String>) -> Self {
648        self.inner = self.inner.extra_query("before", cursor);
649        self
650    }
651
652    /// 设置分页大小。
653    pub fn limit(mut self, limit: u32) -> Self {
654        self.inner = self.inner.extra_query("limit", limit.to_string());
655        self
656    }
657
658    /// 追加请求头。
659    pub fn extra_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
660        self.inner = self.inner.extra_header(key, value);
661        self
662    }
663
664    /// 在根对象追加额外字段。
665    pub fn extra_body(mut self, key: impl Into<String>, value: impl Into<JsonPayload>) -> Self {
666        self.inner = self.inner.extra_body(key, value);
667        self
668    }
669}
670
671impl<T> ListRequestBuilder<T>
672where
673    T: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
674{
675    /// 发送列表请求并返回游标分页对象。
676    ///
677    /// # Errors
678    ///
679    /// 当请求失败或反序列化失败时返回错误。
680    pub async fn send(self) -> Result<CursorPage<T>> {
681        let client = self.inner.client.clone();
682        let path = self.inner.spec.path.clone();
683        let endpoint_id = self.inner.spec.endpoint_id;
684        let response = client
685            .execute_json::<ListEnvelope<T>>(self.inner.into_spec())
686            .await?;
687        let ListEnvelope {
688            object,
689            data,
690            first_id,
691            last_id,
692            has_more,
693            extra,
694        } = response.data;
695        let mut next_query = BTreeMap::new();
696        if let Some(last_id) = &last_id {
697            next_query.insert("after".into(), Some(last_id.clone()));
698        }
699        Ok(CursorPage::from(ListEnvelope {
700            object,
701            data,
702            first_id,
703            last_id,
704            has_more,
705            extra,
706        })
707        .with_next_request(if has_more {
708            Some(crate::client::PageRequestSpec {
709                client,
710                endpoint_id,
711                method: Method::GET,
712                path,
713                query: next_query,
714            })
715        } else {
716            None
717        }))
718    }
719}
720
#[cfg(test)]
mod tests {
    use percent_encoding::percent_decode_str;
    use proptest::prelude::*;

    use super::encode_path_segment;

    proptest! {
        // Property: percent-decoding an encoded segment yields the original
        // string for arbitrary input.
        #[test]
        fn encoded_path_segment_roundtrips_through_percent_decode(segment in any::<String>()) {
            let escaped = encode_path_segment(&segment);
            let restored = percent_decode_str(&escaped).decode_utf8().unwrap();
            prop_assert_eq!(restored, segment);
        }
    }
}