openai_orch/
embed.rs

1//! Requests and responses using Embeddings models.
2
3use anyhow::{Error, Result};
4use async_openai::types::CreateEmbeddingRequest;
5use async_trait::async_trait;
6use log::{debug, error};
7use tokio::time::timeout;
8
9use crate::{
10  keys::Keys, policies::Policies, utils::get_openai_client, OrchRequest,
11  ResponseType,
12};
13
/// Number of `f32` components in an [`EmbeddingResponse`] vector.
///
/// 1536 is the output dimensionality of the `text-embedding-ada-002` model
/// that [`EmbeddingRequest`] is sent to.
pub const EMBEDDING_SIZE: usize = 1536;

/// A request to embed a piece of text; the wrapped `String` is sent as the
/// embedding input.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingRequest(pub String);

/// An embedding vector of exactly [`EMBEDDING_SIZE`] components.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResponse(pub [f32; EMBEDDING_SIZE]);
18
19impl ResponseType for EmbeddingResponse {}
20
21#[async_trait]
22impl OrchRequest for EmbeddingRequest {
23  type Res = EmbeddingResponse;
24  async fn send(
25    &self,
26    policies: Policies,
27    keys: Keys,
28    id: u64,
29  ) -> Result<Self::Res> {
30    debug!("starting request {}", id);
31    let client = get_openai_client(&keys);
32    let mut retry_policy = policies.retry_policy;
33
34    let request = CreateEmbeddingRequest {
35      model: "text-embedding-ada-002".to_string(),
36      input: async_openai::types::EmbeddingInput::String(self.0.to_string()),
37      user:  None,
38    };
39
40    // continue trying until we get a response or we reach max retry
41    loop {
42      let timer = timing::start();
43      let response = timeout(
44        policies.timeout_policy.timeout,
45        client.embeddings().create(request.clone()),
46      )
47      .await;
48
49      let response = match response {
50        Ok(response) => response,
51        Err(err) => {
52          debug!(
53            "request {} timed out after {}s",
54            id,
55            policies.timeout_policy.timeout.as_secs_f32()
56          );
57          if retry_policy.failed_request().await {
58            continue;
59          } else {
60            error!("request {} reached max retry", id);
61            return Err(Error::new(err).context("reached max retry"));
62          }
63        }
64      };
65
66      // if we got a response, we need to check if it's an error
67      let response = match response {
68        Ok(response) => response,
69        Err(err) => {
70          if retry_policy.failed_request().await {
71            continue;
72          } else {
73            return Err(Error::new(err).context("reached max retry"));
74          }
75        }
76      };
77
78      debug!(
79        "got response for {} in {}",
80        id,
81        timer.elapsed().as_secs_f32()
82      );
83      let embedding = response
84        .data
85        .first()
86        .ok_or(Error::msg("response.data is empty"))?
87        .embedding
88        .clone();
89
90      return Ok(EmbeddingResponse(embedding.as_slice().try_into()?));
91    }
92  }
93}