1use super::{openai_post, ApiResponseOrError, Credentials};
6use serde::{Deserialize, Serialize};
7
8#[derive(Serialize, Clone)]
9struct CreateEmbeddingsRequestBody<'a> {
10 model: &'a str,
11 input: Vec<&'a str>,
12 #[serde(skip_serializing_if = "str::is_empty")]
13 user: &'a str,
14}
15
16#[derive(Deserialize, Clone)]
17pub struct Embeddings {
18 pub data: Vec<Embedding>,
19 pub model: String,
20 pub usage: EmbeddingsUsage,
21}
22
23#[derive(Deserialize, Clone, Copy)]
24pub struct EmbeddingsUsage {
25 pub prompt_tokens: u32,
26 pub total_tokens: u32,
27}
28
29#[derive(Deserialize, Clone)]
30pub struct Embedding {
31 #[serde(rename = "embedding")]
32 pub vec: Vec<f64>,
33}
34
35impl Embeddings {
36 pub async fn create(
51 model: &str,
52 input: Vec<&str>,
53 user: &str,
54 credentials: Credentials,
55 ) -> ApiResponseOrError<Self> {
56 openai_post(
57 "embeddings",
58 &CreateEmbeddingsRequestBody { model, input, user },
59 Some(credentials),
60 )
61 .await
62 }
63
64 pub fn distances(&self) -> Vec<f64> {
65 let mut distances = Vec::new();
66 let mut last_embedding: Option<&Embedding> = None;
67
68 for embedding in &self.data {
69 if let Some(other) = last_embedding {
70 distances.push(embedding.distance(other));
71 }
72
73 last_embedding = Some(embedding);
74 }
75
76 distances
77 }
78}
79
80impl Embedding {
81 pub async fn create(
82 model: &str,
83 input: &str,
84 user: &str,
85 credentials: Credentials,
86 ) -> ApiResponseOrError<Self> {
87 let mut embeddings = Embeddings::create(model, vec![input], user, credentials).await?;
88 Ok(embeddings.data.swap_remove(0))
89 }
90
91 pub fn magnitude(&self) -> f64 {
92 self.vec.iter().map(|x| x * x).sum::<f64>().sqrt()
93 }
94
95 pub fn distance(&self, other: &Self) -> f64 {
96 let dot_product: f64 = self
97 .vec
98 .iter()
99 .zip(other.vec.iter())
100 .map(|(x, y)| x * y)
101 .sum();
102 let product_of_magnitudes = self.magnitude() * other.magnitude();
103
104 1.0 - dot_product / product_of_magnitudes
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    /// Model used for live API calls and as filler in locally built fixtures.
    const MODEL: &str = "text-embedding-ada-002";
    /// Sample text sent in the live API tests.
    const INPUT: &str = "The food was delicious and the waiter...";

    /// Builds an `Embeddings` value from raw vectors without calling the API.
    fn embeddings_from(vectors: Vec<Vec<f64>>) -> Embeddings {
        Embeddings {
            data: vectors.into_iter().map(|vec| Embedding { vec }).collect(),
            model: MODEL.to_string(),
            usage: EmbeddingsUsage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    /// Float comparison with a small absolute tolerance, so assertions do not hinge
    /// on the last bit of rounding.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // Live test: loads `.env` via `dotenv`, builds `Credentials::from_env()` and calls the API.
    #[tokio::test]
    async fn embeddings() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embeddings = Embeddings::create(MODEL, vec![INPUT], "", credentials)
            .await
            .unwrap();

        assert!(!embeddings.data.first().unwrap().vec.is_empty());
    }

    // Live test: same requirements as `embeddings` above.
    #[tokio::test]
    async fn embedding() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embedding = Embedding::create(MODEL, INPUT, "", credentials)
            .await
            .unwrap();

        assert!(!embedding.vec.is_empty());
    }

    #[test]
    fn right_angle() {
        let embeddings = embeddings_from(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        assert_close(embeddings.distances()[0], 1.0);
    }

    #[test]
    fn non_right_angle() {
        let embeddings = embeddings_from(vec![vec![1.0, 1.0, 0.0], vec![0.0, 1.0, 0.0]]);
        // The vectors are 45 degrees apart: 1 - cos(45°) = 1 - 1/sqrt(2) ≈ 0.29289321881345254.
        assert_close(
            embeddings.distances()[0],
            1.0 - std::f64::consts::FRAC_1_SQRT_2,
        );
    }

    #[test]
    fn fewer_than_two_embeddings_have_no_distances() {
        assert!(embeddings_from(vec![]).distances().is_empty());
        assert!(embeddings_from(vec![vec![1.0, 0.0]]).distances().is_empty());
    }

    #[test]
    fn distances_are_between_consecutive_embeddings() {
        let embeddings = embeddings_from(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0],
        ]);
        let distances = embeddings.distances();

        assert_eq!(distances.len(), 2);
        assert_close(distances[0], 1.0);
        assert_close(distances[1], 0.0);
    }
}
187}