// ai_sdk_core/embed/single.rs

use std::marker::PhantomData;
use std::sync::Arc;

use ai_sdk_provider::{EmbedOptions, EmbeddingModel, EmbeddingUsage};

use super::EmbeddingConfig;
use crate::error::EmbedError;
use crate::retry::RetryPolicy;
use crate::Result;

9pub struct EmbedBuilder<M, V, Val> {
11 model: M,
12 value: V,
13 config: EmbeddingConfig,
14 _marker: PhantomData<Val>,
15}
16
17impl<Val> EmbedBuilder<(), (), Val>
19where
20 Val: Send + Sync + Clone + 'static,
21{
22 pub fn new() -> Self {
24 Self {
25 model: (),
26 value: (),
27 config: EmbeddingConfig::default(),
28 _marker: PhantomData,
29 }
30 }
31}
32
33impl<Val> Default for EmbedBuilder<(), (), Val>
34where
35 Val: Send + Sync + Clone + 'static,
36{
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl<M, V, Val> EmbedBuilder<M, V, Val> {
44 pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
46 self.config.retry_policy = retry_policy;
47 self
48 }
49}
50
51impl<V, Val> EmbedBuilder<(), V, Val>
53where
54 Val: Send + Sync + Clone + 'static,
55{
56 pub fn model<Mod: EmbeddingModel<Val> + 'static>(
58 self,
59 model: Mod,
60 ) -> EmbedBuilder<Arc<dyn EmbeddingModel<Val>>, V, Val> {
61 EmbedBuilder {
62 model: Arc::new(model),
63 value: self.value,
64 config: self.config,
65 _marker: PhantomData,
66 }
67 }
68}
69
70impl<M, Val> EmbedBuilder<M, (), Val> {
72 pub fn value(self, value: Val) -> EmbedBuilder<M, Val, Val> {
74 EmbedBuilder {
75 model: self.model,
76 value,
77 config: self.config,
78 _marker: PhantomData,
79 }
80 }
81}
82
83impl<Val> EmbedBuilder<Arc<dyn EmbeddingModel<Val>>, Val, Val>
85where
86 Val: Send + Sync + Clone + 'static,
87{
88 pub async fn execute(self) -> Result<EmbedResult<Val>> {
90 let model = self.model;
91 let value = self.value;
92 let config = self.config;
93
94 let response = config
96 .retry_policy
97 .retry(|| {
98 let options = EmbedOptions {
99 values: vec![value.clone()],
100 provider_options: None,
101 headers: None,
102 };
103 let model = model.clone();
104 async move { model.do_embed(options).await }
105 })
106 .await
107 .map_err(EmbedError::ProviderError)?;
108
109 let embedding = response
111 .embeddings
112 .into_iter()
113 .next()
114 .ok_or(EmbedError::EmptyResponse)?;
115
116 Ok(EmbedResult {
117 value,
118 embedding,
119 usage: response.usage,
120 })
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct EmbedResult<Val> {
127 pub value: Val,
129 pub embedding: Vec<f32>,
131 pub usage: Option<EmbeddingUsage>,
133}
134
135impl<Val> EmbedResult<Val> {
136 pub fn embedding(&self) -> &[f32] {
138 &self.embedding
139 }
140
141 pub fn value(&self) -> &Val {
143 &self.value
144 }
145
146 pub fn usage(&self) -> Option<&EmbeddingUsage> {
148 self.usage.as_ref()
149 }
150}
151
152pub fn embed<Val>() -> EmbedBuilder<(), (), Val>
154where
155 Val: Send + Sync + Clone + 'static,
156{
157 EmbedBuilder::new()
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use ai_sdk_provider::EmbedResponse;
    use async_trait::async_trait;

    /// Test double. It asserts that it receives exactly one value equal to
    /// `expected_value`, then replies with a canned list of embeddings.
    struct StubModel {
        expected_value: &'static str,
        embeddings: Vec<Vec<f32>>,
    }

    #[async_trait]
    impl EmbeddingModel<String> for StubModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_embeddings_per_call(&self) -> Option<usize> {
            Some(100)
        }

        async fn supports_parallel_calls(&self) -> bool {
            true
        }

        async fn do_embed(
            &self,
            options: EmbedOptions<String>,
        ) -> ai_sdk_provider::Result<EmbedResponse> {
            assert_eq!(options.values.len(), 1);
            assert_eq!(options.values[0], self.expected_value);
            Ok(EmbedResponse {
                embeddings: self.embeddings.clone(),
                usage: Some(EmbeddingUsage { tokens: 10 }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[test]
    fn test_embed_builder_defaults() {
        // `embed()` and `Default` must both yield the empty (no model, no value) builder.
        let _from_fn: EmbedBuilder<(), (), String> = embed();
        let _from_default: EmbedBuilder<(), (), String> = EmbedBuilder::default();
    }

    #[tokio::test]
    async fn test_embed_success() {
        let model = StubModel {
            expected_value: "test value",
            embeddings: vec![vec![0.1, 0.2, 0.3]],
        };

        let result = embed()
            .model(model)
            .value("test value".to_string())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.value(), "test value");
        assert_eq!(result.embedding().len(), 3);
        assert_eq!(result.embedding(), &[0.1, 0.2, 0.3]);
        assert_eq!(result.usage().unwrap().tokens, 10);
    }

    #[tokio::test]
    async fn test_embed_empty_response_is_error() {
        // A provider that answers with no embeddings must not yield a result.
        let model = StubModel {
            expected_value: "test value",
            embeddings: Vec::new(),
        };

        let result = embed()
            .model(model)
            .value("test value".to_string())
            .execute()
            .await;

        assert!(result.is_err());
    }
}