1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use crate::IntoRequest;
use derive_builder::Builder;
use reqwest::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};

/// Request body sent to the embeddings endpoint.
///
/// All fields are private; values are assembled through the `derive_builder`-generated
/// `EmbeddingRequestBuilder`. `input` is the only field without a builder default, so it is
/// the only one that has to be set before `build()` can succeed.
#[derive(Debug, Clone, Serialize, Builder)]
#[builder(pattern = "mutable")]
pub struct EmbeddingRequest {
    /// Input text to embed: a single string or an array of strings (token arrays are not
    /// supported by `EmbeddingInput` yet). The input must not exceed the max input tokens for
    /// the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any
    /// array must be 2048 dimensions or less.
    input: EmbeddingInput,
    /// ID of the model to use. You can use the List models API to see all of your available
    /// models, or see the Model overview for descriptions of them. Defaults to
    /// `EmbeddingModel::default()` when not set on the builder.
    #[builder(default)]
    model: EmbeddingModel,
    /// The format to return the embeddings in. Can be either float or base64.
    /// Left out of the serialized request when `None`.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<EmbeddingEncodingFormat>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and
    /// detect abuse. Left out of the serialized request when `None`.
    #[builder(default, setter(strip_option, into))]
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}

/// Input payload of an embedding request: either one string or a list of strings.
///
/// Serialized with `#[serde(untagged)]`, so the JSON value is the bare string or the bare
/// array, without any variant name around it.
// Limitation: array of integers and array of array of integers (token inputs) are not
// supported yet.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text to embed.
    String(String),
    /// Several texts to embed in one request.
    StringArray(Vec<String>),
}

/// Embedding models that can be named in an [`EmbeddingRequest`].
///
/// Only one model is declared here; it is also the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum EmbeddingModel {
    /// Serialized (and deserialized) as the string `"text-embedding-ada-002"`.
    #[default]
    #[serde(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,
}

/// Encoding requested for the returned embeddings.
///
/// Variants are serialized in snake_case, i.e. as `"float"` and `"base64"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingEncodingFormat {
    /// The `Default` variant.
    #[default]
    Float,
    /// Base64 encoding of the embedding.
    Base64,
}

/// Response body deserialized from the embeddings endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type of the response; the tests in this file expect `"list"`.
    pub object: String,
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// Model id reported by the server. It may differ from the requested id: the tests in
    /// this file expect `"text-embedding-ada-002-v2"` for the default model.
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: EmbeddingUsage,
}

/// Token usage section of an [`EmbeddingResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
    /// Value of the `prompt_tokens` field (presumably the tokens consumed by the input —
    /// confirm against the API reference).
    pub prompt_tokens: usize,
    /// Value of the `total_tokens` field (presumably the total tokens billed for the request —
    /// confirm against the API reference).
    pub total_tokens: usize,
}

/// A single embedding entry inside [`EmbeddingResponse::data`].
///
/// NOTE(review): `embedding` is typed as a list of floats. If a request with
/// `EmbeddingEncodingFormat::Base64` makes the server return a string here instead,
/// deserialization of this struct would presumably fail — confirm before relying on base64.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    /// The index of the embedding in the list of embeddings.
    pub index: usize,
    /// The embedding vector, which is a list of floats. The length of the vector depends on
    /// the model as listed in the embedding guide.
    pub embedding: Vec<f32>,
    /// The object type, which is always "embedding".
    pub object: String,
}

impl IntoRequest for EmbeddingRequest {
    /// Turns this request into a `POST` to `{base_url}/embeddings` whose JSON body is `self`.
    ///
    /// `base_url` is used verbatim as the prefix; no slash normalization is performed.
    fn into_request(self, base_url: &str, client: Client) -> RequestBuilder {
        let endpoint = format!("{base_url}/embeddings");
        let pending = client.post(endpoint);
        pending.json(&self)
    }
}

impl EmbeddingRequest {
    /// Creates a request for `input`, leaving `model`, `encoding_format` and `user` at their
    /// builder defaults.
    ///
    /// Accepts anything that converts into [`EmbeddingInput`], e.g. `&str`, `String`,
    /// `Vec<String>` or `&[String]`.
    ///
    /// # Panics
    ///
    /// Only if the generated builder rejects the request, which is not expected: `input` is
    /// the sole field without a builder default and it is always set here.
    pub fn new(input: impl Into<EmbeddingInput>) -> Self {
        EmbeddingRequestBuilder::default()
            .input(input.into())
            .build()
            // State the invariant instead of a bare `unwrap()` so a future required field
            // produces a self-explanatory panic message.
            .expect("EmbeddingRequestBuilder cannot fail: `input` is set and all other fields have defaults")
    }

    /// Creates a request that embeds several texts at once.
    ///
    /// Convenience wrapper around [`EmbeddingRequest::new`] that pins the argument type to
    /// `Vec<String>`, which lets callers write `vec!["a".into(), "b".into()]` without type
    /// annotations.
    pub fn new_array(input: Vec<String>) -> Self {
        // `Vec<String>` converts into `EmbeddingInput::StringArray`, so the shared
        // constructor builds exactly the same request.
        Self::new(input)
    }
}

impl From<String> for EmbeddingInput {
    fn from(s: String) -> Self {
        Self::String(s)
    }
}

impl From<Vec<String>> for EmbeddingInput {
    fn from(s: Vec<String>) -> Self {
        Self::StringArray(s)
    }
}

impl From<&[String]> for EmbeddingInput {
    fn from(s: &[String]) -> Self {
        Self::StringArray(s.to_vec())
    }
}

impl From<&str> for EmbeddingInput {
    fn from(s: &str) -> Self {
        Self::String(s.to_owned())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;

    /// Sends one sentence through `SDK.embedding` and checks the shape of the response.
    // NOTE(review): presumably this talks to the live API via the crate-level `SDK` — confirm.
    #[tokio::test]
    async fn string_embedding_should_work() -> Result<()> {
        let request = EmbeddingRequest::new("The quick brown fox jumped over the lazy dog.");
        let response = SDK.embedding(request).await?;
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.object, "list");
        // The server reports a versioned model id rather than the requested one.
        assert_eq!(response.model, "text-embedding-ada-002-v2");
        let entry = &response.data[0];
        assert_eq!(entry.embedding.len(), 1536);
        assert_eq!(entry.index, 0);
        assert_eq!(entry.object, "embedding");
        Ok(())
    }

    /// Sends two texts in one request and checks the second entry of the response.
    #[tokio::test]
    async fn array_string_embedding_should_work() -> Result<()> {
        let texts: Vec<String> = vec![
            "The quick brown fox jumped over the lazy dog.".into(),
            "我是谁?宇宙有没有尽头?".into(),
        ];
        let request = EmbeddingRequest::new_array(texts);
        let response = SDK.embedding(request).await?;
        assert_eq!(response.data.len(), 2);
        assert_eq!(response.object, "list");
        // The server reports a versioned model id rather than the requested one.
        assert_eq!(response.model, "text-embedding-ada-002-v2");
        let entry = &response.data[1];
        assert_eq!(entry.embedding.len(), 1536);
        assert_eq!(entry.index, 1);
        assert_eq!(entry.object, "embedding");
        Ok(())
    }
}