Skip to main content

linger_openai_sdk/
transport.rs

1use crate::error::{HeaderMap, IntoHeaderPair, LingerError};
2use bytes::Bytes;
3use futures_core::Stream;
4use futures_util::StreamExt;
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10/// EN: Boxed byte stream used by response bodies.
11/// 中文:响应体使用的装箱字节流。
12pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes, LingerError>> + Send>>;
13
14/// EN: Runtime-neutral HTTP request body.
15/// 中文:运行时无关的 HTTP 请求体。
16pub enum HttpRequestBody {
17    /// EN: In-memory byte body.
18    /// 中文:内存字节请求体。
19    Bytes(Bytes),
20    /// EN: Incremental byte stream body.
21    /// 中文:增量字节流请求体。
22    Stream(BodyStream),
23}
24
25impl HttpRequestBody {
26    /// EN: Returns true when this body is a stream.
27    /// 中文:当请求体为流时返回 true。
28    pub fn is_stream(&self) -> bool {
29        matches!(self, Self::Stream(_))
30    }
31
32    /// EN: Returns byte body contents without buffering streams.
33    /// 中文:返回字节请求体内容,但不会缓冲流式请求体。
34    pub fn as_bytes(&self) -> Option<&[u8]> {
35        match self {
36            Self::Bytes(bytes) => Some(bytes),
37            Self::Stream(_) => None,
38        }
39    }
40
41    /// EN: Converts the body into an incremental byte stream.
42    /// 中文:将请求体转换为增量字节流。
43    pub fn into_stream(self) -> BodyStream {
44        match self {
45            Self::Bytes(bytes) => {
46                Box::pin(futures_util::stream::once(async move { Ok(bytes) })) as BodyStream
47            }
48            Self::Stream(stream) => stream,
49        }
50    }
51}
52
53impl fmt::Debug for HttpRequestBody {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        match self {
56            Self::Bytes(bytes) => write!(f, "<{} bytes>", bytes.len()),
57            Self::Stream(_) => f.write_str("<stream>"),
58        }
59    }
60}
61
/// HTTP verbs the SDK can issue.
///
/// Marked `#[non_exhaustive]` so further verbs can be added without a breaking change
/// for downstream `match` expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum HttpMethod {
    /// The `GET` method.
    Get,
    /// The `POST` method.
    Post,
    /// The `DELETE` method.
    Delete,
}
77
78/// EN: Runtime-neutral HTTP request passed to transports.
79/// 中文:传给传输层的运行时无关 HTTP 请求。
80pub struct HttpRequest {
81    method: HttpMethod,
82    url: String,
83    path: String,
84    headers: HeaderMap,
85    body: Option<HttpRequestBody>,
86}
87
88impl HttpRequest {
89    /// EN: Creates an SDK HTTP request.
90    /// 中文:创建 SDK HTTP 请求。
91    pub fn new(method: HttpMethod, base_url: impl AsRef<str>, path: impl Into<String>) -> Self {
92        let path = path.into();
93        let url = format!("{}{}", base_url.as_ref().trim_end_matches('/'), path);
94        Self {
95            method,
96            url,
97            path,
98            headers: HeaderMap::new(),
99            body: None,
100        }
101    }
102
103    /// EN: Returns the HTTP method.
104    /// 中文:返回 HTTP 方法。
105    pub fn method(&self) -> HttpMethod {
106        self.method
107    }
108
109    /// EN: Returns the full URL.
110    /// 中文:返回完整 URL。
111    pub fn url(&self) -> &str {
112        &self.url
113    }
114
115    /// EN: Returns the API path.
116    /// 中文:返回 API 路径。
117    pub fn path(&self) -> &str {
118        &self.path
119    }
120
121    /// EN: Adds or replaces a header.
122    /// 中文:添加或替换请求头。
123    pub fn insert_header(&mut self, name: impl Into<String>, value: impl Into<String>) {
124        self.headers.insert(name, value);
125    }
126
127    /// EN: Returns a header value by case-insensitive name.
128    /// 中文:按大小写不敏感名称返回请求头值。
129    pub fn header(&self, name: &str) -> Option<&str> {
130        self.headers.get(name)
131    }
132
133    /// EN: Returns all request headers.
134    /// 中文:返回所有请求头。
135    pub fn headers(&self) -> &HeaderMap {
136        &self.headers
137    }
138
139    /// EN: Sets the request body.
140    /// 中文:设置请求体。
141    pub fn set_body(&mut self, body: impl Into<Bytes>) {
142        self.body = Some(HttpRequestBody::Bytes(body.into()));
143    }
144
145    /// EN: Sets an incremental request body stream.
146    /// 中文:设置增量请求体流。
147    pub fn set_body_stream<S>(&mut self, body: S)
148    where
149        S: Stream<Item = Result<Bytes, LingerError>> + Send + 'static,
150    {
151        self.body = Some(HttpRequestBody::Stream(Box::pin(body)));
152    }
153
154    /// EN: Returns the request body bytes, when present.
155    /// 中文:返回请求体字节,如存在。
156    pub fn body(&self) -> Option<&[u8]> {
157        self.body.as_ref().and_then(HttpRequestBody::as_bytes)
158    }
159
160    /// EN: Returns true when the request body is streamed.
161    /// 中文:当请求体为流式传输时返回 true。
162    pub fn body_is_stream(&self) -> bool {
163        self.body.as_ref().is_some_and(HttpRequestBody::is_stream)
164    }
165
166    /// EN: Consumes the request and returns its body.
167    /// 中文:消耗请求并返回请求体。
168    pub fn into_body(self) -> Option<HttpRequestBody> {
169        self.body
170    }
171}
172
173impl fmt::Debug for HttpRequest {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        f.debug_struct("HttpRequest")
176            .field("method", &self.method)
177            .field("url", &self.url)
178            .field("path", &self.path)
179            .field("headers", &self.headers)
180            .field("body", &self.body.as_ref())
181            .finish()
182    }
183}
184
185enum HttpBody {
186    Bytes(Bytes),
187    Stream(BodyStream),
188}
189
190/// EN: Runtime-neutral HTTP response returned by transports.
191/// 中文:传输层返回的运行时无关 HTTP 响应。
192pub struct HttpResponse {
193    status: u16,
194    headers: HeaderMap,
195    body: HttpBody,
196}
197
198impl HttpResponse {
199    /// EN: Creates a response with an in-memory byte body.
200    /// 中文:创建包含内存字节响应体的响应。
201    pub fn from_bytes<I, P>(status: u16, headers: I, body: impl Into<Bytes>) -> Self
202    where
203        I: IntoIterator<Item = P>,
204        P: IntoHeaderPair,
205    {
206        Self {
207            status,
208            headers: HeaderMap::from_pairs(headers),
209            body: HttpBody::Bytes(body.into()),
210        }
211    }
212
213    /// EN: Creates a response with an incremental byte stream body.
214    /// 中文:创建包含增量字节流响应体的响应。
215    pub fn from_stream<I, P, S>(status: u16, headers: I, body: S) -> Self
216    where
217        I: IntoIterator<Item = P>,
218        P: IntoHeaderPair,
219        S: Stream<Item = Result<Bytes, LingerError>> + Send + 'static,
220    {
221        Self {
222            status,
223            headers: HeaderMap::from_pairs(headers),
224            body: HttpBody::Stream(Box::pin(body)),
225        }
226    }
227
228    /// EN: Returns the HTTP status.
229    /// 中文:返回 HTTP 状态码。
230    pub fn status(&self) -> u16 {
231        self.status
232    }
233
234    /// EN: Returns response headers.
235    /// 中文:返回响应头。
236    pub fn headers(&self) -> &HeaderMap {
237        &self.headers
238    }
239
240    /// EN: Consumes the response and returns headers.
241    /// 中文:消耗响应并返回响应头。
242    pub fn into_parts(self) -> (u16, HeaderMap, BodyStream) {
243        let body = match self.body {
244            HttpBody::Bytes(bytes) => {
245                Box::pin(futures_util::stream::once(async move { Ok(bytes) })) as BodyStream
246            }
247            HttpBody::Stream(stream) => stream,
248        };
249        (self.status, self.headers, body)
250    }
251
252    /// EN: Consumes the response and returns an incremental body stream.
253    /// 中文:消耗响应并返回增量响应体流。
254    pub fn into_body_stream(self) -> BodyStream {
255        self.into_parts().2
256    }
257
258    /// EN: Consumes the response and collects the full body.
259    /// 中文:消耗响应并收集完整响应体。
260    pub async fn into_bytes(self) -> Result<(u16, HeaderMap, Bytes), LingerError> {
261        let (status, headers, mut stream) = self.into_parts();
262        let mut body = Vec::new();
263        while let Some(chunk) = stream.next().await {
264            body.extend_from_slice(&chunk?);
265        }
266        Ok((status, headers, Bytes::from(body)))
267    }
268}
269
270/// EN: Runtime-neutral HTTP transport boundary.
271/// 中文:运行时无关的 HTTP 传输边界。
272pub trait Transport: Send + Sync {
273    /// EN: Sends a request and asynchronously returns a response.
274    /// 中文:发送请求并异步返回响应。
275    fn send(
276        &self,
277        request: HttpRequest,
278    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, LingerError>> + Send + '_>>;
279}
280
281/// EN: Cloneable shared transport handle.
282/// 中文:可克隆的共享传输句柄。
283#[derive(Clone)]
284pub struct SharedTransport {
285    inner: Arc<dyn Transport>,
286}
287
288impl SharedTransport {
289    /// EN: Wraps a concrete transport.
290    /// 中文:包装具体传输实现。
291    pub fn new<T>(transport: T) -> Self
292    where
293        T: Transport + 'static,
294    {
295        Self {
296            inner: Arc::new(transport),
297        }
298    }
299
300    /// EN: Sends a request through the shared transport.
301    /// 中文:通过共享传输发送请求。
302    pub fn send(
303        &self,
304        request: HttpRequest,
305    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, LingerError>> + Send + '_>> {
306        self.inner.send(request)
307    }
308}
309
310impl fmt::Debug for SharedTransport {
311    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312        f.debug_struct("SharedTransport").finish_non_exhaustive()
313    }
314}
315
/// Default [`Transport`] implementation backed by a `reqwest::Client`.
///
/// Only compiled when the `reqwest-transport` feature is enabled.
#[cfg(feature = "reqwest-transport")]
#[derive(Debug, Clone)]
pub struct ReqwestTransport {
    /// Client used to issue every request.
    client: reqwest::Client,
}
323
#[cfg(feature = "reqwest-transport")]
impl Default for ReqwestTransport {
    /// Builds a transport around a client created with `reqwest::Client::new()`.
    fn default() -> Self {
        let client = reqwest::Client::new();
        Self { client }
    }
}
332
#[cfg(feature = "reqwest-transport")]
impl ReqwestTransport {
    /// Wraps an existing, caller-configured `reqwest::Client`.
    pub fn new(client: reqwest::Client) -> Self {
        Self { client }
    }
}
341
#[cfg(feature = "reqwest-transport")]
impl Transport for ReqwestTransport {
    /// Sends `request` with the wrapped `reqwest::Client`.
    ///
    /// Every HTTP status is returned as a normal [`HttpResponse`]; no status check is done
    /// here. Errors reported by `reqwest` while sending, and errors yielded by the response
    /// byte stream, are mapped to `LingerError::transport`. The response body is streamed,
    /// not buffered.
    fn send(
        &self,
        request: HttpRequest,
    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, LingerError>> + Send + '_>> {
        Box::pin(async move {
            let method = match request.method {
                HttpMethod::Get => reqwest::Method::GET,
                HttpMethod::Post => reqwest::Method::POST,
                HttpMethod::Delete => reqwest::Method::DELETE,
            };
            let mut builder = self.client.request(method, request.url);
            for (name, value) in request.headers.iter() {
                builder = builder.header(name, value);
            }
            if let Some(body) = request.body {
                builder = match body {
                    HttpRequestBody::Bytes(bytes) => builder.body(bytes),
                    HttpRequestBody::Stream(stream) => {
                        builder.body(reqwest::Body::wrap_stream(stream))
                    }
                };
            }
            let response = builder
                .send()
                .await
                .map_err(|error| LingerError::transport(error.to_string()))?;
            let status = response.status().as_u16();
            // `HeaderValue::to_str` only accepts visible ASCII, so header values containing
            // other bytes are skipped here rather than failing the whole response.
            let headers =
                HeaderMap::from_pairs(response.headers().iter().filter_map(|(name, value)| {
                    value.to_str().ok().map(|value| (name.as_str(), value))
                }));
            let stream = response
                .bytes_stream()
                .map(|chunk| chunk.map_err(|error| LingerError::transport(error.to_string())));
            // Construct the response directly. Going through `HttpResponse::from_stream`
            // would iterate this freshly built map and run `HeaderMap::from_pairs` over it a
            // second time for no benefit.
            Ok(HttpResponse {
                status,
                headers,
                body: HttpBody::Stream(Box::pin(stream)),
            })
        })
    }
}
377                .map(|chunk| chunk.map_err(|error| LingerError::transport(error.to_string())));
378            Ok(HttpResponse::from_stream(status, headers.iter(), stream))
379        })
380    }
381}