Skip to main content

linger_openai_sdk/
embeddings.rs

1use crate::error::LingerError;
2use crate::RequestId;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::BTreeMap;
6
7/// EN: Request body for `POST /v1/embeddings`.
8/// 中文:`POST /v1/embeddings` 的请求体。
9#[derive(Clone, Debug, Serialize, PartialEq)]
10#[non_exhaustive]
11pub struct CreateEmbeddingRequest {
12    /// EN: Model id used to create embeddings.
13    /// 中文:用于创建 embeddings 的模型 ID。
14    pub model: String,
15    /// EN: Input text or text array.
16    /// 中文:输入文本或文本数组。
17    pub input: EmbeddingInput,
18    /// EN: Optional embedding encoding format.
19    /// 中文:可选的 embedding 编码格式。
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub encoding_format: Option<EmbeddingEncodingFormat>,
22    /// EN: Optional output vector dimension count for supported models.
23    /// 中文:支持的模型可选输出向量维度数。
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub dimensions: Option<u32>,
26    /// EN: Optional end-user identifier.
27    /// 中文:可选的终端用户标识。
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub user: Option<String>,
30    /// EN: Forward-compatible optional fields not yet covered by handwritten types.
31    /// 中文:手写类型尚未覆盖的前向兼容可选字段。
32    #[serde(flatten)]
33    pub extra: BTreeMap<String, Value>,
34}
35
36impl CreateEmbeddingRequest {
37    /// EN: Starts building an embeddings request.
38    /// 中文:开始构建 embeddings 请求。
39    pub fn builder() -> CreateEmbeddingRequestBuilder {
40        CreateEmbeddingRequestBuilder::default()
41    }
42}
43
44/// EN: Builder for create-embedding requests.
45/// 中文:创建 embedding 请求的构建器。
46#[derive(Clone, Debug, Default)]
47#[non_exhaustive]
48pub struct CreateEmbeddingRequestBuilder {
49    model: Option<String>,
50    input: Option<EmbeddingInput>,
51    encoding_format: Option<EmbeddingEncodingFormat>,
52    dimensions: Option<u32>,
53    user: Option<String>,
54    extra: BTreeMap<String, Value>,
55}
56
57impl CreateEmbeddingRequestBuilder {
58    /// EN: Sets the embedding model id.
59    /// 中文:设置 embedding 模型 ID。
60    pub fn model(mut self, model: impl Into<String>) -> Self {
61        self.model = Some(model.into());
62        self
63    }
64
65    /// EN: Sets the embedding input.
66    /// 中文:设置 embedding 输入。
67    pub fn input(mut self, input: impl Into<EmbeddingInput>) -> Self {
68        self.input = Some(input.into());
69        self
70    }
71
72    /// EN: Sets the embedding encoding format.
73    /// 中文:设置 embedding 编码格式。
74    pub fn encoding_format(mut self, encoding_format: EmbeddingEncodingFormat) -> Self {
75        self.encoding_format = Some(encoding_format);
76        self
77    }
78
79    /// EN: Sets the output vector dimensions.
80    /// 中文:设置输出向量维度。
81    pub fn dimensions(mut self, dimensions: u32) -> Self {
82        self.dimensions = Some(dimensions);
83        self
84    }
85
86    /// EN: Sets the optional end-user identifier.
87    /// 中文:设置可选的终端用户标识。
88    pub fn user(mut self, user: impl Into<String>) -> Self {
89        self.user = Some(user.into());
90        self
91    }
92
93    /// EN: Adds a forward-compatible JSON field.
94    /// 中文:添加前向兼容的 JSON 字段。
95    pub fn extra(mut self, name: impl Into<String>, value: Value) -> Self {
96        self.extra.insert(name.into(), value);
97        self
98    }
99
100    /// EN: Builds and validates the request.
101    /// 中文:构建并校验请求。
102    pub fn build(self) -> Result<CreateEmbeddingRequest, LingerError> {
103        let model = self
104            .model
105            .filter(|value| !value.trim().is_empty())
106            .ok_or_else(|| LingerError::invalid_config("model is required"))?;
107        let input = self
108            .input
109            .ok_or_else(|| LingerError::invalid_config("input is required"))?;
110        if input.is_empty() {
111            return Err(LingerError::invalid_config("input must not be empty"));
112        }
113        Ok(CreateEmbeddingRequest {
114            model,
115            input,
116            encoding_format: self.encoding_format,
117            dimensions: self.dimensions,
118            user: self.user,
119            extra: self.extra,
120        })
121    }
122}
123
124/// EN: Embeddings API input value.
125/// 中文:Embeddings API 输入值。
126#[derive(Clone, Debug, Serialize, PartialEq, Eq)]
127#[serde(untagged)]
128#[non_exhaustive]
129pub enum EmbeddingInput {
130    /// EN: Single text input.
131    /// 中文:单条文本输入。
132    Text(String),
133    /// EN: Multiple text inputs.
134    /// 中文:多条文本输入。
135    Texts(Vec<String>),
136}
137
138impl EmbeddingInput {
139    fn is_empty(&self) -> bool {
140        match self {
141            Self::Text(value) => value.is_empty(),
142            Self::Texts(values) => values.is_empty() || values.iter().any(String::is_empty),
143        }
144    }
145}
146
147impl From<&str> for EmbeddingInput {
148    fn from(value: &str) -> Self {
149        Self::Text(value.to_string())
150    }
151}
152
153impl From<String> for EmbeddingInput {
154    fn from(value: String) -> Self {
155        Self::Text(value)
156    }
157}
158
159impl From<Vec<String>> for EmbeddingInput {
160    fn from(value: Vec<String>) -> Self {
161        Self::Texts(value)
162    }
163}
164
165/// EN: Embedding encoding format.
166/// 中文:Embedding 编码格式。
167#[derive(Clone, Copy, Debug, Serialize, PartialEq, Eq)]
168#[serde(rename_all = "snake_case")]
169#[non_exhaustive]
170pub enum EmbeddingEncodingFormat {
171    /// EN: Floating-point vector output.
172    /// 中文:浮点向量输出。
173    Float,
174    /// EN: Base64-encoded vector output.
175    /// 中文:Base64 编码的向量输出。
176    Base64,
177}
178
179/// EN: Response object returned by the Embeddings API.
180/// 中文:Embeddings API 返回的响应对象。
181#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
182#[non_exhaustive]
183pub struct EmbeddingResponse {
184    /// EN: API list object type.
185    /// 中文:API 列表对象类型。
186    pub object: String,
187    /// EN: Embedding items.
188    /// 中文:Embedding 项。
189    #[serde(default)]
190    pub data: Vec<Embedding>,
191    /// EN: Model used to create the embeddings.
192    /// 中文:用于创建 embeddings 的模型。
193    pub model: String,
194    /// EN: Token usage.
195    /// 中文:Token 用量。
196    pub usage: EmbeddingUsage,
197    /// EN: OpenAI request id from response headers.
198    /// 中文:响应头中的 OpenAI 请求 ID。
199    #[serde(skip)]
200    request_id: Option<RequestId>,
201}
202
203impl EmbeddingResponse {
204    pub(crate) fn with_request_id(mut self, request_id: Option<RequestId>) -> Self {
205        self.request_id = request_id;
206        self
207    }
208
209    /// EN: Returns the OpenAI request id, when present.
210    /// 中文:返回 OpenAI 请求 ID,如存在。
211    pub fn request_id(&self) -> Option<&RequestId> {
212        self.request_id.as_ref()
213    }
214}
215
216/// EN: Single embedding item.
217/// 中文:单个 embedding 项。
218#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
219#[non_exhaustive]
220pub struct Embedding {
221    /// EN: API object type.
222    /// 中文:API 对象类型。
223    pub object: String,
224    /// EN: Embedding vector.
225    /// 中文:Embedding 向量。
226    pub embedding: Vec<f32>,
227    /// EN: Index of this embedding in the response.
228    /// 中文:该 embedding 在响应中的索引。
229    pub index: u32,
230}
231
232/// EN: Token usage for an embeddings request.
233/// 中文:Embeddings 请求的 token 用量。
234#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
235#[non_exhaustive]
236pub struct EmbeddingUsage {
237    /// EN: Prompt token count.
238    /// 中文:Prompt token 数量。
239    pub prompt_tokens: u64,
240    /// EN: Total token count.
241    /// 中文:总 token 数量。
242    pub total_tokens: u64,
243}