// openai/embeddings.rs

//! Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
//!
//! Related guide: [Embeddings](https://beta.openai.com/docs/guides/embeddings)

5use super::{openai_post, ApiResponseOrError, Credentials};
6use serde::{Deserialize, Serialize};
7
8#[derive(Serialize, Clone)]
9struct CreateEmbeddingsRequestBody<'a> {
10    model: &'a str,
11    input: Vec<&'a str>,
12    #[serde(skip_serializing_if = "str::is_empty")]
13    user: &'a str,
14}
15
16#[derive(Deserialize, Clone)]
17pub struct Embeddings {
18    pub data: Vec<Embedding>,
19    pub model: String,
20    pub usage: EmbeddingsUsage,
21}
22
23#[derive(Deserialize, Clone, Copy)]
24pub struct EmbeddingsUsage {
25    pub prompt_tokens: u32,
26    pub total_tokens: u32,
27}
28
29#[derive(Deserialize, Clone)]
30pub struct Embedding {
31    #[serde(rename = "embedding")]
32    pub vec: Vec<f64>,
33}
34
35impl Embeddings {
36    /// Creates an embedding vector representing the input text.
37    ///
38    /// # Arguments
39    ///
40    /// * `model` - ID of the model to use.
41    ///   You can use the [List models](https://beta.openai.com/docs/api-reference/models/list)
42    ///   API to see all of your available models, or see our [Model overview](https://beta.openai.com/docs/models/overview)
43    ///   for descriptions of them.
44    /// * `input` - Input text to get embeddings for, encoded as a string or array of tokens.
45    ///   To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays.
46    ///   Each input must not exceed 8192 tokens in length.
47    /// * `user` - A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
48    ///   [Learn more](https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids).
49    /// * `credentials` - The OpenAI credentials.
50    pub async fn create(
51        model: &str,
52        input: Vec<&str>,
53        user: &str,
54        credentials: Credentials,
55    ) -> ApiResponseOrError<Self> {
56        openai_post(
57            "embeddings",
58            &CreateEmbeddingsRequestBody { model, input, user },
59            Some(credentials),
60        )
61        .await
62    }
63
64    pub fn distances(&self) -> Vec<f64> {
65        let mut distances = Vec::new();
66        let mut last_embedding: Option<&Embedding> = None;
67
68        for embedding in &self.data {
69            if let Some(other) = last_embedding {
70                distances.push(embedding.distance(other));
71            }
72
73            last_embedding = Some(embedding);
74        }
75
76        distances
77    }
78}
79
80impl Embedding {
81    pub async fn create(
82        model: &str,
83        input: &str,
84        user: &str,
85        credentials: Credentials,
86    ) -> ApiResponseOrError<Self> {
87        let mut embeddings = Embeddings::create(model, vec![input], user, credentials).await?;
88        Ok(embeddings.data.swap_remove(0))
89    }
90
91    pub fn magnitude(&self) -> f64 {
92        self.vec.iter().map(|x| x * x).sum::<f64>().sqrt()
93    }
94
95    pub fn distance(&self, other: &Self) -> f64 {
96        let dot_product: f64 = self
97            .vec
98            .iter()
99            .zip(other.vec.iter())
100            .map(|(x, y)| x * y)
101            .sum();
102        let product_of_magnitudes = self.magnitude() * other.magnitude();
103
104        1.0 - dot_product / product_of_magnitudes
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    /// Builds an `Embeddings` fixture from raw vectors, using placeholder model/usage values.
    fn embeddings_from(vectors: Vec<Vec<f64>>) -> Embeddings {
        Embeddings {
            data: vectors
                .into_iter()
                .map(|values| Embedding { vec: values })
                .collect(),
            model: "text-embedding-ada-002".to_string(),
            usage: EmbeddingsUsage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    // The two async tests below call the live API; they need network access and the
    // credentials that `Credentials::from_env` reads.
    #[tokio::test]
    async fn embeddings() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embeddings = Embeddings::create(
            "text-embedding-ada-002",
            vec!["The food was delicious and the waiter..."],
            "",
            credentials,
        )
        .await
        .unwrap();

        assert!(!embeddings.data.first().unwrap().vec.is_empty());
    }

    #[tokio::test]
    async fn embedding() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embedding = Embedding::create(
            "text-embedding-ada-002",
            "The food was delicious and the waiter...",
            "",
            credentials,
        )
        .await
        .unwrap();

        assert!(!embedding.vec.is_empty());
    }

    #[test]
    fn right_angle() {
        let embeddings = embeddings_from(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        assert_eq!(embeddings.distances()[0], 1.0);
    }

    #[test]
    fn non_right_angle() {
        // 1 - cos(45°) = 1 - 1/√2
        let embeddings = embeddings_from(vec![vec![1.0, 1.0, 0.0], vec![0.0, 1.0, 0.0]]);
        assert_eq!(embeddings.distances()[0], 0.29289321881345254);
    }

    #[test]
    fn same_and_opposite_directions() {
        // 3-4-5 vectors keep every intermediate value exactly representable in f64.
        let embeddings =
            embeddings_from(vec![vec![3.0, 4.0], vec![3.0, 4.0], vec![-3.0, -4.0]]);
        assert_eq!(embeddings.distances(), vec![0.0, 2.0]);
    }

    #[test]
    fn fewer_than_two_embeddings_have_no_distances() {
        assert!(embeddings_from(vec![]).distances().is_empty());
        assert!(embeddings_from(vec![vec![1.0, 0.0]]).distances().is_empty());
    }
}