1use anyhow::{Error, Result};
4use async_openai::types::CreateEmbeddingRequest;
5use async_trait::async_trait;
6use log::{debug, error};
7use tokio::time::timeout;
8
9use crate::{
10 keys::Keys, policies::Policies, utils::get_openai_client, OrchRequest,
11 ResponseType,
12};
13
/// Number of `f32` components stored in an [`EmbeddingResponse`].
pub const EMBEDDING_SIZE: usize = 1536;

/// Text to be embedded; the wrapped `String` is sent as the request input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingRequest(pub String);

/// Embedding vector holding exactly [`EMBEDDING_SIZE`] components.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResponse(pub [f32; EMBEDDING_SIZE]);
18
// Marker implementation: no trait items are overridden here.
// NOTE(review): presumably needed so `EmbeddingResponse` can serve as
// `OrchRequest::Res` — confirm against the trait definition.
impl ResponseType for EmbeddingResponse {}
20
21#[async_trait]
22impl OrchRequest for EmbeddingRequest {
23 type Res = EmbeddingResponse;
24 async fn send(
25 &self,
26 policies: Policies,
27 keys: Keys,
28 id: u64,
29 ) -> Result<Self::Res> {
30 debug!("starting request {}", id);
31 let client = get_openai_client(&keys);
32 let mut retry_policy = policies.retry_policy;
33
34 let request = CreateEmbeddingRequest {
35 model: "text-embedding-ada-002".to_string(),
36 input: async_openai::types::EmbeddingInput::String(self.0.to_string()),
37 user: None,
38 };
39
40 loop {
42 let timer = timing::start();
43 let response = timeout(
44 policies.timeout_policy.timeout,
45 client.embeddings().create(request.clone()),
46 )
47 .await;
48
49 let response = match response {
50 Ok(response) => response,
51 Err(err) => {
52 debug!(
53 "request {} timed out after {}s",
54 id,
55 policies.timeout_policy.timeout.as_secs_f32()
56 );
57 if retry_policy.failed_request().await {
58 continue;
59 } else {
60 error!("request {} reached max retry", id);
61 return Err(Error::new(err).context("reached max retry"));
62 }
63 }
64 };
65
66 let response = match response {
68 Ok(response) => response,
69 Err(err) => {
70 if retry_policy.failed_request().await {
71 continue;
72 } else {
73 return Err(Error::new(err).context("reached max retry"));
74 }
75 }
76 };
77
78 debug!(
79 "got response for {} in {}",
80 id,
81 timer.elapsed().as_secs_f32()
82 );
83 let embedding = response
84 .data
85 .first()
86 .ok_or(Error::msg("response.data is empty"))?
87 .embedding
88 .clone();
89
90 return Ok(EmbeddingResponse(embedding.as_slice().try_into()?));
91 }
92 }
93}