1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5
/// Request body for the `/embeddings` endpoint.
///
/// Construct it with the derived `EmbeddingRequestBuilder` (mutable builder pattern) or with
/// `EmbeddingRequest::new` / `EmbeddingRequest::new_array`. The struct is serialized to JSON
/// as-is: there is no `rename_all`, so field names are the wire keys, in declaration order.
#[derive(Debug, Clone, Serialize, Builder)]
#[builder(pattern = "mutable")]
pub struct EmbeddingRequest {
    /// Text to embed: a single string or an array of strings. This is the only field without
    /// a builder default, so `build()` fails if it was never set.
    input: EmbeddingInput,
    /// Model to use; the builder falls back to `EmbeddingModel::default()` when unset.
    #[builder(default)]
    model: EmbeddingModel,
    /// Optional encoding for the returned vectors. The setter takes the bare value
    /// (`strip_option`), and the key is omitted from the JSON body when `None`.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<EmbeddingEncodingFormat>,
    /// Optional end-user identifier. The setter accepts anything `Into<String>`, and the key
    /// is omitted from the JSON body when `None`.
    #[builder(default, setter(strip_option, into))]
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
23
/// The `input` payload of an embeddings request: one text or several.
///
/// `#[serde(untagged)]` serializes the inner value directly, either as a JSON string or as
/// a JSON array of strings, with no variant-name wrapper.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text to embed.
    String(String),
    /// Several texts embedded in one request.
    StringArray(Vec<String>),
}
31
/// Embedding models this client can request.
///
/// Each variant (de)serializes under the explicit wire name given by its `rename`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum EmbeddingModel {
    /// Wire name `"text-embedding-ada-002"`; this is the `Default` variant.
    #[default]
    #[serde(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,
}
38
/// Encoding requested for the returned embedding vectors.
///
/// Variants serialize in `snake_case`, i.e. `"float"` and `"base64"`. The type is
/// serialize-only; it is never read back from a response.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingEncodingFormat {
    /// Serialized as `"float"`; this is the `Default` variant.
    #[default]
    Float,
    /// Serialized as `"base64"`.
    ///
    /// NOTE(review): `EmbeddingData::embedding` is a `Vec<f32>`. If the server answers this
    /// format with a string payload instead of a number array, response deserialization
    /// will fail. Confirm before relying on it.
    Base64,
}
46
/// Deserialized body of a successful embeddings response.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag reported by the server; the live tests expect `"list"`.
    pub object: String,
    /// The embeddings, presumably one per input text: the live tests expect 1 entry for a
    /// single string and 2 for a two-element array.
    pub data: Vec<EmbeddingData>,
    /// Model identifier as reported by the server. It need not equal the requested
    /// `EmbeddingModel` wire name (it may carry a version suffix).
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
54
/// Token counts reported alongside an embeddings response.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
    /// Tokens counted for the input text(s), as reported by the server.
    pub prompt_tokens: usize,
    /// Total tokens counted for the request, as reported by the server.
    pub total_tokens: usize,
}
60
/// One embedding entry inside an `EmbeddingResponse`.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    /// Position of the input this embedding belongs to. The live tests expect 0 for the
    /// first input and 1 for the second.
    pub index: usize,
    /// The embedding vector.
    ///
    /// NOTE(review): this deserializes only from a JSON array of numbers, so it presumably
    /// fails for a `base64` encoded response. Confirm.
    pub embedding: Vec<f32>,
    /// Object type tag reported by the server; the live tests expect `"embedding"`.
    pub object: String,
}
70
71impl IntoRequest for EmbeddingRequest {
72 fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
73 let url = format!("{}/embeddings", base_url);
74 client.post(url).json(&self)
75 }
76}
77
78impl EmbeddingRequest {
79 pub fn new(input: impl Into<EmbeddingInput>) -> Self {
80 EmbeddingRequestBuilder::default()
81 .input(input.into())
82 .build()
83 .unwrap()
84 }
85
86 pub fn new_array(input: Vec<String>) -> Self {
87 EmbeddingRequestBuilder::default()
88 .input(input.into())
89 .build()
90 .unwrap()
91 }
92}
93
94impl From<String> for EmbeddingInput {
95 fn from(s: String) -> Self {
96 Self::String(s)
97 }
98}
99
100impl From<Vec<String>> for EmbeddingInput {
101 fn from(s: Vec<String>) -> Self {
102 Self::StringArray(s)
103 }
104}
105
106impl From<&[String]> for EmbeddingInput {
107 fn from(s: &[String]) -> Self {
108 Self::StringArray(s.to_vec())
109 }
110}
111
112impl From<&str> for EmbeddingInput {
113 fn from(s: &str) -> Self {
114 Self::String(s.to_owned())
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;

    // NOTE(review): the async tests below call the live embeddings endpoint through the
    // shared `SDK` client, so they presumably need network access and valid credentials.
    // Confirm.

    /// Vector length produced by `text-embedding-ada-002`.
    const ADA_002_DIMENSIONS: usize = 1536;

    /// Asserts the shape of a successful response. It must contain `expected_len` entries,
    /// each at its own `index` and each holding a full-length vector.
    fn assert_embeddings(res: &EmbeddingResponse, expected_len: usize) {
        assert_eq!(res.object, "list");
        // The server reports a versioned model name (e.g. `text-embedding-ada-002-v2`) whose
        // suffix can change, so only the model family is pinned.
        assert!(
            res.model.starts_with("text-embedding-ada-002"),
            "unexpected model: {}",
            res.model
        );
        assert_eq!(res.data.len(), expected_len);
        for (position, data) in res.data.iter().enumerate() {
            assert_eq!(data.index, position);
            assert_eq!(data.object, "embedding");
            assert_eq!(data.embedding.len(), ADA_002_DIMENSIONS);
        }
    }

    #[test]
    fn input_conversions_pick_the_matching_variant() {
        assert!(matches!(EmbeddingInput::from("a"), EmbeddingInput::String(ref s) if s == "a"));
        assert!(matches!(
            EmbeddingInput::from("a".to_string()),
            EmbeddingInput::String(ref s) if s == "a"
        ));
        let texts = vec!["a".to_string(), "b".to_string()];
        assert!(matches!(
            EmbeddingInput::from(&texts[..]),
            EmbeddingInput::StringArray(ref v) if *v == texts
        ));
        assert!(matches!(
            EmbeddingInput::from(texts.clone()),
            EmbeddingInput::StringArray(ref v) if *v == texts
        ));
    }

    #[tokio::test]
    async fn string_embedding_should_work() -> Result<()> {
        let req = EmbeddingRequest::new("The quick brown fox jumped over the lazy dog.");
        let res = SDK.embedding(req).await?;
        assert_embeddings(&res, 1);
        Ok(())
    }

    #[tokio::test]
    async fn array_string_embedding_should_work() -> Result<()> {
        let req = EmbeddingRequest::new_array(vec![
            "The quick brown fox jumped over the lazy dog.".into(),
            "我是谁?宇宙有没有尽头?".into(),
        ]);
        let res = SDK.embedding(req).await?;
        assert_embeddings(&res, 2);
        Ok(())
    }
}