rig/providers/together/embedding.rs

1// ================================================================
2//! Together AI Embeddings Integration
3//! From [Together AI Reference](https://docs.together.ai/docs/embeddings-overview)
4// ================================================================
5
6use serde::Deserialize;
7use serde_json::json;
8
9use crate::{
10    embeddings::{self, EmbeddingError},
11    http_client::{self, HttpClientExt},
12};
13
14use super::{
15    Client,
16    client::together_ai_api_types::{ApiErrorResponse, ApiResponse},
17};
18
19// ================================================================
20// Together AI Embedding API
21// ================================================================
22
// --- BAAI BGE family ---

/// Together AI model identifier `BAAI/bge-base-en-v1.5`.
pub const BGE_BASE_EN_V1_5: &str = "BAAI/bge-base-en-v1.5";
/// Together AI model identifier `BAAI/bge-large-en-v1.5`.
pub const BGE_LARGE_EN_V1_5: &str = "BAAI/bge-large-en-v1.5";

// --- BERT-based encoders ---

/// Together AI model identifier `bert-base-uncased`.
pub const BERT_BASE_UNCASED: &str = "bert-base-uncased";
/// Together AI model identifier `sentence-transformers/msmarco-bert-base-dot-v5`.
pub const SENTENCE_BERT: &str = "sentence-transformers/msmarco-bert-base-dot-v5";
/// Together AI model identifier `WhereIsAI/UAE-Large-V1`.
pub const UAE_LARGE_V1: &str = "WhereIsAI/UAE-Large-V1";

// --- Monarch Mixer (M2-BERT) retrieval encoders ---

/// Together AI model identifier `hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1`.
pub const M2_BERT_2K_RETRIEVAL_ENCODER_V1: &str = "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1";
/// Together AI model identifier `togethercomputer/m2-bert-80M-2k-retrieval`.
pub const M2_BERT_80M_2K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-2k-retrieval";
/// Together AI model identifier `togethercomputer/m2-bert-80M-8k-retrieval`.
pub const M2_BERT_80M_8K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-8k-retrieval";
/// Together AI model identifier `togethercomputer/m2-bert-80M-32k-retrieval`.
pub const M2_BERT_80M_32K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
32
33#[derive(Debug, Deserialize)]
34pub struct EmbeddingResponse {
35    pub model: String,
36    pub object: String,
37    pub data: Vec<EmbeddingData>,
38}
39
40impl From<ApiErrorResponse> for EmbeddingError {
41    fn from(err: ApiErrorResponse) -> Self {
42        EmbeddingError::ProviderError(err.message())
43    }
44}
45
46impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
47    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
48        match value {
49            ApiResponse::Ok(response) => Ok(response),
50            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
51        }
52    }
53}
54
55#[derive(Debug, Deserialize)]
56pub struct EmbeddingData {
57    pub object: String,
58    pub embedding: Vec<f64>,
59    pub index: usize,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct Usage {
64    pub prompt_tokens: usize,
65    pub total_tokens: usize,
66}
67
68#[derive(Clone)]
69pub struct EmbeddingModel<T = reqwest::Client> {
70    client: Client<T>,
71    pub model: String,
72    ndims: usize,
73}
74
75impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
76where
77    T: HttpClientExt + Default + Clone + Send + 'static,
78{
79    const MAX_DOCUMENTS: usize = 1024; // This might need to be adjusted based on Together AI's actual limit
80
81    fn ndims(&self) -> usize {
82        self.ndims
83    }
84
85    #[cfg_attr(feature = "worker", worker::send)]
86    async fn embed_texts(
87        &self,
88        documents: impl IntoIterator<Item = String>,
89    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
90        let documents = documents.into_iter().collect::<Vec<_>>();
91
92        let body = serde_json::to_vec(&json!({
93            "model": self.model,
94            "input": documents,
95        }))?;
96
97        let req = self
98            .client
99            .post("/v1/embeddings")?
100            .header("Content-Type", "application/json")
101            .body(body)
102            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
103
104        let response = self.client.send(req).await?;
105
106        if response.status().is_success() {
107            let body: Vec<u8> = response.into_body().await?;
108            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
109
110            match body {
111                ApiResponse::Ok(response) => {
112                    if response.data.len() != documents.len() {
113                        return Err(EmbeddingError::ResponseError(
114                            "Response data length does not match input length".into(),
115                        ));
116                    }
117
118                    Ok(response
119                        .data
120                        .into_iter()
121                        .zip(documents.into_iter())
122                        .map(|(embedding, document)| embeddings::Embedding {
123                            document,
124                            vec: embedding.embedding,
125                        })
126                        .collect())
127                }
128                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
129            }
130        } else {
131            let text = http_client::text(response).await?;
132            Err(EmbeddingError::ProviderError(text))
133        }
134    }
135}
136
137impl<T> EmbeddingModel<T>
138where
139    T: Default,
140{
141    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
142        Self {
143            client,
144            model: model.to_string(),
145            ndims,
146        }
147    }
148}