llmservice_flows/
embeddings.rs

1use serde::Serialize;
2use urlencoding::encode;
3
4use crate::LLMApi;
5use crate::Retry;
6
7/// The input type for the embeddings.
8///
9/// For more detail about parameters, please refer to
10/// [OpenAI docs](https://platform.openai.com/docs/api-reference/embeddings/create)
11///
12#[derive(Debug, Serialize)]
13pub enum EmbeddingsInput {
14    String(String),
15    Vec(Vec<String>),
16}
17
18impl LLMApi for (Option<&str>, EmbeddingsInput) {
19    type Output = Vec<Vec<f64>>;
20    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
21        create_embeddings_inner(endpoint, api_key, self.0, &self.1).await
22    }
23}
24
25impl<'a> crate::LLMServiceFlows<'a> {
26    /// Create embeddings from the provided input.
27    ///
28    /// `params` is an [EmbeddingsInput] object.
29    ///
30    ///```rust
31    ///   // This code snippet computes embeddings for `text`, the question created in previous step.
32    ///   // Wrap the `text` in EmbeddingsInput struct.
33    ///   let input = EmbeddingsInput::String(text.to_string());
34    ///   // Call the create_embeddings function.
35    ///   let question_vector = match llm.create_embeddings(Some("text-embedding-ada-002"), input).await {
36    ///       Ok(r) => r[0],
37    ///       Err(e) => {your error handling},
38    ///   };
39    /// ```
40
41    pub async fn create_embeddings(
42        &self,
43        model: Option<&str>,
44        input: EmbeddingsInput,
45    ) -> Result<Vec<Vec<f64>>, String> {
46        self.keep_trying((model, input)).await
47    }
48}
49
50async fn create_embeddings_inner(
51    endpoint: &str,
52    api_key: &str,
53    model: Option<&str>,
54    input: &EmbeddingsInput,
55) -> Retry<Vec<Vec<f64>>> {
56    let flows_user = unsafe { crate::_get_flows_user() };
57
58    let uri = format!(
59        "{}/{}/create_embeddings?endpoint={}&api_key={}&model={}",
60        crate::LLM_API_PREFIX.as_str(),
61        flows_user,
62        encode(endpoint),
63        encode(api_key),
64        encode(model.unwrap_or_default())
65    );
66    let body = match input {
67        EmbeddingsInput::String(ref s) => serde_json::to_vec(&s).unwrap_or_default(),
68        EmbeddingsInput::Vec(ref v) => serde_json::to_vec(&v).unwrap_or_default(),
69    };
70    match reqwest::Client::new()
71        .post(uri)
72        .header("Content-Type", "application/json")
73        .header("Content-Length", body.len())
74        .body(body)
75        .send()
76        .await
77    {
78        Ok(res) => {
79            let status = res.status();
80            let body = res.bytes().await.unwrap();
81            match status.is_success() {
82                true => Retry::No(
83                    serde_json::from_slice::<Vec<Vec<f64>>>(&body.as_ref())
84                        .or(Err(String::from("Unexpected error"))),
85                ),
86                false => {
87                    match status.into() {
88                        409 | 429 | 503 => {
89                            // 409 TryAgain 429 RateLimitError
90                            // 503 ServiceUnavailable
91                            Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
92                        }
93                        _ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
94                    }
95                }
96            }
97        }
98        Err(e) => Retry::No(Err(e.to_string())),
99    }
100}