llm_sdk/api/
embedding.rs

1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Builder)]
7#[builder(pattern = "mutable")]
8pub struct EmbeddingRequest {
9    /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any array must be 2048 dimensions or less.
10    input: EmbeddingInput,
11    /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
12    #[builder(default)]
13    model: EmbeddingModel,
14    /// The format to return the embeddings in. Can be either float or base64.
15    #[builder(default, setter(strip_option))]
16    #[serde(skip_serializing_if = "Option::is_none")]
17    encoding_format: Option<EmbeddingEncodingFormat>,
18    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.
19    #[builder(default, setter(strip_option, into))]
20    #[serde(skip_serializing_if = "Option::is_none")]
21    user: Option<String>,
22}
23
24// currently we don't support array of integers, or array of array of integers
25#[derive(Debug, Clone, Serialize)]
26#[serde(untagged)]
27pub enum EmbeddingInput {
28    String(String),
29    StringArray(Vec<String>),
30}
31
32#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
33pub enum EmbeddingModel {
34    #[default]
35    #[serde(rename = "text-embedding-ada-002")]
36    TextEmbeddingAda002,
37}
38
39#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
40#[serde(rename_all = "snake_case")]
41pub enum EmbeddingEncodingFormat {
42    #[default]
43    Float,
44    Base64,
45}
46
47#[derive(Debug, Clone, Deserialize)]
48pub struct EmbeddingResponse {
49    pub object: String,
50    pub data: Vec<EmbeddingData>,
51    pub model: String,
52    pub usage: EmbeddingUsage,
53}
54
55#[derive(Debug, Clone, Deserialize)]
56pub struct EmbeddingUsage {
57    pub prompt_tokens: usize,
58    pub total_tokens: usize,
59}
60
61#[derive(Debug, Clone, Deserialize)]
62pub struct EmbeddingData {
63    /// The index of the embedding in the list of embeddings.
64    pub index: usize,
65    /// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide.
66    pub embedding: Vec<f32>,
67    /// The object type, which is always "embedding".
68    pub object: String,
69}
70
71impl IntoRequest for EmbeddingRequest {
72    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
73        let url = format!("{}/embeddings", base_url);
74        client.post(url).json(&self)
75    }
76}
77
78impl EmbeddingRequest {
79    pub fn new(input: impl Into<EmbeddingInput>) -> Self {
80        EmbeddingRequestBuilder::default()
81            .input(input.into())
82            .build()
83            .unwrap()
84    }
85
86    pub fn new_array(input: Vec<String>) -> Self {
87        EmbeddingRequestBuilder::default()
88            .input(input.into())
89            .build()
90            .unwrap()
91    }
92}
93
94impl From<String> for EmbeddingInput {
95    fn from(s: String) -> Self {
96        Self::String(s)
97    }
98}
99
100impl From<Vec<String>> for EmbeddingInput {
101    fn from(s: Vec<String>) -> Self {
102        Self::StringArray(s)
103    }
104}
105
106impl From<&[String]> for EmbeddingInput {
107    fn from(s: &[String]) -> Self {
108        Self::StringArray(s.to_vec())
109    }
110}
111
112impl From<&str> for EmbeddingInput {
113    fn from(s: &str) -> Self {
114        Self::String(s.to_owned())
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;

    /// Vector length produced by text-embedding-ada-002.
    const ADA_002_DIMENSIONS: usize = 1536;

    /// Checks the envelope fields shared by every successful embeddings response.
    fn check_envelope(response: &EmbeddingResponse, expected_entries: usize) {
        assert_eq!(response.data.len(), expected_entries);
        assert_eq!(response.object, "list");
        // The API reports a versioned model id rather than the one that was requested.
        assert_eq!(response.model, "text-embedding-ada-002-v2");
    }

    /// Checks a single embedding entry against its expected position.
    fn check_entry(entry: &EmbeddingData, expected_index: usize) {
        assert_eq!(entry.embedding.len(), ADA_002_DIMENSIONS);
        assert_eq!(entry.index, expected_index);
        assert_eq!(entry.object, "embedding");
    }

    #[tokio::test]
    async fn string_embedding_should_work() -> Result<()> {
        let request = EmbeddingRequest::new("The quick brown fox jumped over the lazy dog.");
        let response = SDK.embedding(request).await?;
        check_envelope(&response, 1);
        check_entry(&response.data[0], 0);
        Ok(())
    }

    #[tokio::test]
    async fn array_string_embedding_should_work() -> Result<()> {
        let inputs: Vec<String> = vec![
            "The quick brown fox jumped over the lazy dog.".into(),
            "我是谁?宇宙有没有尽头?".into(),
        ];
        let response = SDK.embedding(EmbeddingRequest::new_array(inputs)).await?;
        check_envelope(&response, 2);
        check_entry(&response.data[1], 1);
        Ok(())
    }
}