rig/providers/together/
embedding.rs

1// ================================================================
2//! Together AI Embeddings Integration
3//! From [Together AI Reference](https://docs.together.ai/docs/embeddings-overview)
4// ================================================================
5
6use serde::Deserialize;
7use serde_json::json;
8
9use crate::embeddings::{self, EmbeddingError};
10
11use super::{
12    client::together_ai_api_types::{ApiErrorResponse, ApiResponse},
13    Client,
14};
15
16// ================================================================
17// Together AI Embedding API
18// ================================================================
19
// Model name strings for embedding models. `EmbeddingModel` sends its `model`
// string verbatim as the `"model"` field of the embeddings request.
// NOTE(review): presumably these are the identifiers Together AI accepts for
// that field — confirm against the provider's current model list.
pub const BGE_BASE_EN_V1_5: &str = "BAAI/bge-base-en-v1.5";
pub const BGE_LARGE_EN_V1_5: &str = "BAAI/bge-large-en-v1.5";
pub const BERT_BASE_UNCASED: &str = "bert-base-uncased";
pub const M2_BERT_2K_RETRIEVAL_ENCODER_V1: &str = "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1";
pub const M2_BERT_80M_32K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
pub const M2_BERT_80M_2K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-2k-retrieval";
pub const M2_BERT_80M_8K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-8k-retrieval";
pub const SENTENCE_BERT: &str = "sentence-transformers/msmarco-bert-base-dot-v5";
pub const UAE_LARGE_V1: &str = "WhereIsAI/UAE-Large-V1";
29
/// Success payload deserialized from the `/v1/embeddings` response body.
///
/// `embed_texts` parses the body as `ApiResponse<EmbeddingResponse>` and
/// consumes `data` to build the returned embeddings.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Value of the `model` field in the response.
    /// NOTE(review): presumably the name of the model that served the request — confirm.
    pub model: String,
    /// Value of the `object` field in the response (not inspected in this file).
    pub object: String,
    /// Embedding entries; `embed_texts` requires this to have the same length
    /// as the list of input documents.
    pub data: Vec<EmbeddingData>,
}
36
37impl From<ApiErrorResponse> for EmbeddingError {
38    fn from(err: ApiErrorResponse) -> Self {
39        EmbeddingError::ProviderError(err.message())
40    }
41}
42
43impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
44    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
45        match value {
46            ApiResponse::Ok(response) => Ok(response),
47            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
48        }
49    }
50}
51
/// A single entry of `EmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the entry's `object` field (not inspected in this file).
    pub object: String,
    /// The embedding vector; `embed_texts` moves it into `embeddings::Embedding::vec`.
    pub embedding: Vec<f64>,
    /// Value of the entry's `index` field.
    /// NOTE(review): presumably the position of the corresponding input; the
    /// visible code pairs entries with inputs by order and never reads this.
    pub index: usize,
}
58
/// Token accounting structure.
///
/// NOTE(review): not referenced by any other item visible in this file;
/// presumably mirrors a `usage` object in the provider response — confirm.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Deserialized from the `prompt_tokens` field.
    pub prompt_tokens: usize,
    /// Deserialized from the `total_tokens` field.
    pub total_tokens: usize,
}
64
/// Handle for requesting embeddings from a specific Together AI model.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to issue the POST request to `/v1/embeddings`.
    client: Client,
    /// Model name sent as the `"model"` field of each request.
    pub model: String,
    /// Dimension count reported by `ndims()`; supplied by the caller of `new`
    /// and not checked against the vectors actually returned.
    ndims: usize,
}
71
72impl embeddings::EmbeddingModel for EmbeddingModel {
73    const MAX_DOCUMENTS: usize = 1024; // This might need to be adjusted based on Together AI's actual limit
74
75    fn ndims(&self) -> usize {
76        self.ndims
77    }
78
79    #[cfg_attr(feature = "worker", worker::send)]
80    async fn embed_texts(
81        &self,
82        documents: impl IntoIterator<Item = String>,
83    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
84        let documents = documents.into_iter().collect::<Vec<_>>();
85
86        let response = self
87            .client
88            .post("/v1/embeddings")
89            .json(&json!({
90                "model": self.model,
91                "input": documents,
92            }))
93            .send()
94            .await?;
95
96        if response.status().is_success() {
97            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
98                ApiResponse::Ok(response) => {
99                    if response.data.len() != documents.len() {
100                        return Err(EmbeddingError::ResponseError(
101                            "Response data length does not match input length".into(),
102                        ));
103                    }
104
105                    Ok(response
106                        .data
107                        .into_iter()
108                        .zip(documents.into_iter())
109                        .map(|(embedding, document)| embeddings::Embedding {
110                            document,
111                            vec: embedding.embedding,
112                        })
113                        .collect())
114                }
115                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
116            }
117        } else {
118            Err(EmbeddingError::ProviderError(response.text().await?))
119        }
120    }
121}
122
123impl EmbeddingModel {
124    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
125        Self {
126            client,
127            model: model.to_string(),
128            ndims,
129        }
130    }
131}