Skip to main content

openai_core/files/
mod.rs

1//! 文件上传相关抽象。
2
3use std::fmt;
4use std::io::Read;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7
8use bytes::Bytes;
9use reqwest::header::CONTENT_TYPE;
10use reqwest::multipart::Part;
11use tokio::io::{AsyncRead, AsyncReadExt};
12
13use crate::error::{Error, Result};
14
15/// 表示 Multipart 文本字段。
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct MultipartField {
18    /// 字段名称。
19    pub name: String,
20    /// 字段值。
21    pub value: String,
22}
23
24/// 统一的文件输入类型别名。
25pub type FileLike = UploadSource;
26
27/// 表示 `to_file()` 可接受的统一输入。
28pub enum ToFileInput {
29    /// 来自文件路径。
30    Path(PathBuf),
31    /// 来自内存字节。
32    Bytes(Bytes),
33    /// 来自已有上传源。
34    UploadSource(UploadSource),
35    /// 来自读取器。
36    Reader(Box<dyn Read + Send>),
37    /// 来自异步读取器。
38    AsyncReader(Pin<Box<dyn AsyncRead + Send>>),
39    /// 来自 HTTP 响应。
40    Response(reqwest::Response),
41}
42
43impl fmt::Debug for ToFileInput {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Self::Path(path) => f.debug_tuple("ToFileInput::Path").field(path).finish(),
47            Self::Bytes(bytes) => f
48                .debug_struct("ToFileInput::Bytes")
49                .field("size", &bytes.len())
50                .finish(),
51            Self::UploadSource(source) => f
52                .debug_tuple("ToFileInput::UploadSource")
53                .field(source)
54                .finish(),
55            Self::Reader(_) => f.write_str("ToFileInput::Reader(..)"),
56            Self::AsyncReader(_) => f.write_str("ToFileInput::AsyncReader(..)"),
57            Self::Response(response) => f
58                .debug_struct("ToFileInput::Response")
59                .field("url", response.url())
60                .finish(),
61        }
62    }
63}
64
65impl ToFileInput {
66    /// 从路径构造输入。
67    pub fn path(path: impl Into<PathBuf>) -> Self {
68        Self::Path(path.into())
69    }
70
71    /// 从读取器构造输入。
72    pub fn reader<R>(reader: R) -> Self
73    where
74        R: Read + Send + 'static,
75    {
76        Self::Reader(Box::new(reader))
77    }
78
79    /// 从异步读取器构造输入。
80    pub fn async_reader<R>(reader: R) -> Self
81    where
82        R: AsyncRead + Send + 'static,
83    {
84        Self::AsyncReader(Box::pin(reader))
85    }
86
87    /// 从已有上传源构造输入。
88    pub fn upload(source: UploadSource) -> Self {
89        Self::UploadSource(source)
90    }
91}
92
93impl From<PathBuf> for ToFileInput {
94    fn from(value: PathBuf) -> Self {
95        Self::Path(value)
96    }
97}
98
99impl From<Vec<u8>> for ToFileInput {
100    fn from(value: Vec<u8>) -> Self {
101        Self::Bytes(Bytes::from(value))
102    }
103}
104
105impl From<Bytes> for ToFileInput {
106    fn from(value: Bytes) -> Self {
107        Self::Bytes(value)
108    }
109}
110
111impl From<UploadSource> for ToFileInput {
112    fn from(value: UploadSource) -> Self {
113        Self::UploadSource(value)
114    }
115}
116
117impl From<reqwest::Response> for ToFileInput {
118    fn from(value: reqwest::Response) -> Self {
119        Self::Response(value)
120    }
121}
122
123/// 从统一输入构造上传文件对象。
124///
125/// 当输入本身无法推导文件名时,调用方应显式提供 `filename`。
126///
127/// # Errors
128///
129/// 当读取输入失败,或字节/读取器输入未提供文件名时返回错误。
130pub async fn to_file(
131    input: impl Into<ToFileInput>,
132    filename: Option<impl Into<String>>,
133) -> Result<UploadSource> {
134    let filename = filename.map(Into::into);
135    match input.into() {
136        ToFileInput::Path(path) => {
137            let mut source = UploadSource::from_path(path)?;
138            if let Some(filename) = filename {
139                source = source.with_filename(filename);
140            }
141            Ok(source)
142        }
143        ToFileInput::Bytes(bytes) => {
144            let filename = filename.ok_or_else(|| {
145                Error::InvalidConfig("字节输入无法自动推导文件名,请显式提供 filename".into())
146            })?;
147            Ok(UploadSource::from_bytes(bytes, filename))
148        }
149        ToFileInput::UploadSource(source) => Ok(match filename {
150            Some(filename) => source.with_filename(filename),
151            None => source,
152        }),
153        ToFileInput::Reader(reader) => {
154            let filename = filename.ok_or_else(|| {
155                Error::InvalidConfig("读取器输入无法自动推导文件名,请显式提供 filename".into())
156            })?;
157            UploadSource::from_reader(reader, filename)
158        }
159        ToFileInput::AsyncReader(mut reader) => {
160            let filename = filename.ok_or_else(|| {
161                Error::InvalidConfig("异步读取器输入无法自动推导文件名,请显式提供 filename".into())
162            })?;
163            let mut buffer = Vec::new();
164            reader
165                .read_to_end(&mut buffer)
166                .await
167                .map_err(|error| Error::InvalidConfig(format!("读取异步上传流失败: {error}")))?;
168            Ok(UploadSource::from_bytes(buffer, filename))
169        }
170        ToFileInput::Response(response) => {
171            let mut source = UploadSource::from_response(response).await?;
172            if let Some(filename) = filename {
173                source = source.with_filename(filename);
174            }
175            Ok(source)
176        }
177    }
178}
179
180/// 表示一个可上传的文件来源。
181#[derive(Clone)]
182pub enum UploadSource {
183    /// 直接由内存字节构成。
184    Bytes {
185        /// 文件字节。
186        data: Bytes,
187        /// 文件名。
188        filename: String,
189        /// 可选 MIME 类型。
190        mime_type: Option<String>,
191    },
192    /// 由文件路径读取得到。
193    Path {
194        /// 原始路径。
195        path: PathBuf,
196        /// 文件字节。
197        data: Bytes,
198        /// 文件名。
199        filename: String,
200        /// 可选 MIME 类型。
201        mime_type: Option<String>,
202    },
203    /// 由通用读取器读取得到。
204    Reader {
205        /// 文件字节。
206        data: Bytes,
207        /// 文件名。
208        filename: String,
209        /// 可选 MIME 类型。
210        mime_type: Option<String>,
211    },
212}
213
214impl UploadSource {
215    /// 从文件路径创建上传源。
216    ///
217    /// # Errors
218    ///
219    /// 当文件不存在、无法读取或无法推导文件名时返回错误。
220    pub fn from_path<P>(path: P) -> Result<Self>
221    where
222        P: AsRef<Path>,
223    {
224        let path = path.as_ref();
225        let data = std::fs::read(path)
226            .map(Bytes::from)
227            .map_err(|error| Error::InvalidConfig(format!("读取上传文件失败: {error}")))?;
228        let filename = path
229            .file_name()
230            .and_then(|value| value.to_str())
231            .ok_or_else(|| Error::InvalidConfig("无法从路径推导文件名".into()))?
232            .to_owned();
233        let mime_type = mime_guess::from_path(path).first_raw().map(str::to_owned);
234
235        Ok(Self::Path {
236            path: path.to_path_buf(),
237            data,
238            filename,
239            mime_type,
240        })
241    }
242
243    /// 从内存字节创建上传源。
244    pub fn from_bytes<T, U>(bytes: T, filename: U) -> Self
245    where
246        T: Into<Bytes>,
247        U: Into<String>,
248    {
249        Self::Bytes {
250            data: bytes.into(),
251            filename: filename.into(),
252            mime_type: None,
253        }
254    }
255
256    /// 从通用读取器读取字节并创建上传源。
257    ///
258    /// # Errors
259    ///
260    /// 当读取器读取失败时返回错误。
261    pub fn from_reader<R, U>(mut reader: R, filename: U) -> Result<Self>
262    where
263        R: Read,
264        U: Into<String>,
265    {
266        let mut buffer = Vec::new();
267        reader
268            .read_to_end(&mut buffer)
269            .map_err(|error| Error::InvalidConfig(format!("读取上传流失败: {error}")))?;
270
271        Ok(Self::Reader {
272            data: Bytes::from(buffer),
273            filename: filename.into(),
274            mime_type: None,
275        })
276    }
277
278    /// 从 HTTP 响应中读取字节并创建上传源。
279    ///
280    /// 该方法会优先使用响应 URL 的最后一个路径段作为文件名。
281    /// 如果无法推导,则回退为 `upload.bin`。
282    ///
283    /// # Errors
284    ///
285    /// 当响应体读取失败时返回错误。
286    pub async fn from_response(response: reqwest::Response) -> Result<Self> {
287        let filename = response
288            .url()
289            .path_segments()
290            .and_then(|mut segments| segments.rfind(|segment| !segment.is_empty()))
291            .map(str::to_owned)
292            .unwrap_or_else(|| "upload.bin".into());
293        let mime_type = response
294            .headers()
295            .get(CONTENT_TYPE)
296            .and_then(|value| value.to_str().ok())
297            .map(str::to_owned);
298        let data = response
299            .bytes()
300            .await
301            .map_err(|error| Error::InvalidConfig(format!("读取上传响应失败: {error}")))?;
302
303        Ok(Self::Bytes {
304            data,
305            filename,
306            mime_type,
307        })
308    }
309
310    /// 覆盖 MIME 类型。
311    pub fn with_mime_type<T>(mut self, mime_type: T) -> Self
312    where
313        T: Into<String>,
314    {
315        let mime_type = Some(mime_type.into());
316        match &mut self {
317            Self::Bytes {
318                mime_type: target, ..
319            }
320            | Self::Path {
321                mime_type: target, ..
322            }
323            | Self::Reader {
324                mime_type: target, ..
325            } => {
326                *target = mime_type;
327            }
328        }
329        self
330    }
331
332    /// 覆盖文件名。
333    pub fn with_filename<T>(mut self, filename: T) -> Self
334    where
335        T: Into<String>,
336    {
337        let filename = filename.into();
338        match &mut self {
339            Self::Bytes {
340                filename: target, ..
341            }
342            | Self::Path {
343                filename: target, ..
344            }
345            | Self::Reader {
346                filename: target, ..
347            } => {
348                *target = filename;
349            }
350        }
351        self
352    }
353
354    /// 返回文件名。
355    pub fn filename(&self) -> &str {
356        match self {
357            Self::Bytes { filename, .. }
358            | Self::Path { filename, .. }
359            | Self::Reader { filename, .. } => filename,
360        }
361    }
362
363    /// 返回 MIME 类型。
364    pub fn mime_type(&self) -> Option<&str> {
365        match self {
366            Self::Bytes { mime_type, .. }
367            | Self::Path { mime_type, .. }
368            | Self::Reader { mime_type, .. } => mime_type.as_deref(),
369        }
370    }
371
372    /// 返回原始字节。
373    pub fn bytes(&self) -> &Bytes {
374        match self {
375            Self::Bytes { data, .. } | Self::Path { data, .. } | Self::Reader { data, .. } => data,
376        }
377    }
378
379    /// 把上传源转换为 `reqwest::multipart::Part`。
380    ///
381    /// # Errors
382    ///
383    /// 当 MIME 类型非法时返回错误。
384    pub fn to_part(&self) -> Result<Part> {
385        let mut part = Part::bytes(self.bytes().to_vec()).file_name(self.filename().to_owned());
386
387        if let Some(mime_type) = self.mime_type() {
388            part = part
389                .mime_str(mime_type)
390                .map_err(|error| Error::InvalidConfig(format!("非法 MIME 类型: {error}")))?;
391        }
392
393        Ok(part)
394    }
395}
396
397impl fmt::Debug for UploadSource {
398    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399        let mut debug = f.debug_struct("UploadSource");
400        debug.field("filename", &self.filename());
401        debug.field("mime_type", &self.mime_type());
402        debug.field("size", &self.bytes().len());
403
404        if let Self::Path { path, .. } = self {
405            debug.field("path", path);
406        }
407
408        debug.finish()
409    }
410}