alith_interface/requests/embeddings/request.rs

1use crate::{llms::LLMBackend, requests::req_components::RequestConfig};
2use std::sync::Arc;
3
4use super::{EmbeddingsError, response::EmbeddingsResponse};
5
6pub struct EmbeddingsRequest {
7    pub model: String,
8    pub input: Vec<String>,
9    backend: Arc<LLMBackend>,
10    config: RequestConfig,
11    llm_interface_errors: Vec<EmbeddingsError>,
12}
13
14impl Clone for EmbeddingsRequest {
15    fn clone(&self) -> Self {
16        Self {
17            model: self.model.clone(),
18            input: self.input.clone(),
19            backend: self.backend.clone(),
20            config: self.config.clone(),
21            llm_interface_errors: Vec::new(),
22        }
23    }
24}
25
26impl EmbeddingsRequest {
27    pub fn new(backend: Arc<LLMBackend>) -> EmbeddingsRequest {
28        EmbeddingsRequest {
29            model: String::new(),
30            input: Vec::new(),
31            backend: Arc::clone(&backend),
32            config: RequestConfig::new(backend.model_ctx_size(), backend.inference_ctx_size()),
33            llm_interface_errors: Vec::new(),
34        }
35    }
36
37    pub fn reset_embedding_request(&mut self) {
38        self.input = Vec::new();
39    }
40
41    pub async fn request(&mut self) -> crate::Result<EmbeddingsResponse, EmbeddingsError> {
42        self.llm_interface_errors.clear();
43        let mut retry_count: u8 = 0;
44        loop {
45            if retry_count >= self.config.retry_after_fail_n_times {
46                let llm_interface_error = EmbeddingsError::ExceededRetryCount {
47                    message: format!("Request failed after {retry_count} attempts."),
48                    errors: std::mem::take(&mut self.llm_interface_errors),
49                };
50                tracing::error!(?llm_interface_error);
51                eprintln!("{}", llm_interface_error);
52                return Err(llm_interface_error);
53            }
54            tracing::info!("{}", self);
55            match self.backend.embeddings_request(self).await {
56                Err(e) => {
57                    tracing::warn!(?e);
58                    retry_count += 1;
59                    match e {
60                        EmbeddingsError::RequestBuilderError { .. }
61                        | EmbeddingsError::ClientError { .. } => {
62                            return Err(e);
63                        }
64
65                        _ => (),
66                    }
67                    self.llm_interface_errors.push(e);
68                    continue;
69                }
70                Ok(res) => {
71                    tracing::info!("{:?}", res);
72                    return Ok(res);
73                }
74            };
75        }
76    }
77
78    pub fn set_input(&mut self, input: Vec<String>) {
79        self.input = input;
80    }
81}
82
83impl std::fmt::Display for EmbeddingsRequest {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        writeln!(f)?;
86        writeln!(f, "CompletionRequest:")?;
87        writeln!(f, "  input: {:?}", self.input)
88    }
89}