Skip to main content

chat_core/chat/
embed.rs

1use crate::{
2    chat::{Chat, state::Embedded},
3    error::{ChatError, ChatFailure},
4    traits::EmbeddingsProvider,
5    types::{
6        callback::CallbackRetryContext, messages::Messages, metadata::Metadata,
7        response::EmbeddingsResponse,
8    },
9};
10
11impl<CP: EmbeddingsProvider> Chat<CP, Embedded> {
12    pub async fn embed(
13        &mut self,
14        messages: &mut Messages,
15    ) -> Result<EmbeddingsResponse, ChatFailure> {
16        if let Some(strategy) = self.before_strategy.as_mut() {
17            strategy(messages, None).await;
18        }
19        let max_retries = self.max_retries.unwrap_or(1);
20        let mut last_metadata: Option<Metadata> = None;
21        let mut last_err: Option<ChatError> = None;
22
23        for idx in 0..max_retries {
24            match self.model.embed(messages).await {
25                Ok(res) => {
26                    return Ok(res);
27                }
28                Err(failure) => {
29                    last_err = Some(failure.err.clone());
30                    if let Some(metadata) = failure.metadata.as_ref() {
31                        match &mut last_metadata {
32                            Some(existing) => {
33                                existing.extend(metadata);
34                            }
35                            None => {
36                                last_metadata = Some(metadata.clone());
37                            }
38                        }
39                    }
40
41                    let ctx = CallbackRetryContext { idx, failure };
42
43                    if let Some(strategy) = self.retry_strategy.as_mut() {
44                        strategy(messages, last_metadata.as_ref(), ctx).await
45                    }
46
47                    continue;
48                }
49            }
50        }
51
52        Err(ChatFailure {
53            err: last_err.unwrap_or(ChatError::RateLimited),
54            metadata: last_metadata,
55        })
56    }
57}