rig/providers/together/
embedding.rs

// ================================================================
//! Together AI Embeddings Integration
//! From [Together AI Reference](https://docs.together.ai/docs/embeddings-overview)
// ================================================================
5
6use serde::Deserialize;
7use serde_json::json;
8
9use crate::{
10    embeddings::{self, EmbeddingError},
11    http_client::{self, HttpClientExt},
12};
13
14use super::{
15    Client,
16    client::together_ai_api_types::{ApiErrorResponse, ApiResponse},
17};
18
// ================================================================
// Together AI Embedding API
// ================================================================
/// Name of the `BAAI/bge-base-en-v1.5` model.
pub const BGE_BASE_EN_V1_5: &str = "BAAI/bge-base-en-v1.5";
/// Name of the `BAAI/bge-large-en-v1.5` model.
pub const BGE_LARGE_EN_V1_5: &str = "BAAI/bge-large-en-v1.5";
/// Name of the `bert-base-uncased` model.
pub const BERT_BASE_UNCASED: &str = "bert-base-uncased";
/// Name of the `hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1` model.
pub const M2_BERT_2K_RETRIEVAL_ENCODER_V1: &str = "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1";
/// Name of the `togethercomputer/m2-bert-80M-32k-retrieval` model.
pub const M2_BERT_80M_32K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
/// Name of the `togethercomputer/m2-bert-80M-2k-retrieval` model.
pub const M2_BERT_80M_2K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-2k-retrieval";
/// Name of the `togethercomputer/m2-bert-80M-8k-retrieval` model.
pub const M2_BERT_80M_8K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-8k-retrieval";
/// Name of the `sentence-transformers/msmarco-bert-base-dot-v5` model.
pub const SENTENCE_BERT: &str = "sentence-transformers/msmarco-bert-base-dot-v5";
/// Name of the `WhereIsAI/UAE-Large-V1` model.
pub const UAE_LARGE_V1: &str = "WhereIsAI/UAE-Large-V1";
31
32#[derive(Debug, Deserialize)]
33pub struct EmbeddingResponse {
34    pub model: String,
35    pub object: String,
36    pub data: Vec<EmbeddingData>,
37}
38
39impl From<ApiErrorResponse> for EmbeddingError {
40    fn from(err: ApiErrorResponse) -> Self {
41        EmbeddingError::ProviderError(err.message())
42    }
43}
44
45impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
46    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
47        match value {
48            ApiResponse::Ok(response) => Ok(response),
49            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
50        }
51    }
52}
53
54#[derive(Debug, Deserialize)]
55pub struct EmbeddingData {
56    pub object: String,
57    pub embedding: Vec<f64>,
58    pub index: usize,
59}
60
61#[derive(Debug, Deserialize)]
62pub struct Usage {
63    pub prompt_tokens: usize,
64    pub total_tokens: usize,
65}
66
67#[derive(Clone)]
68pub struct EmbeddingModel<T = reqwest::Client> {
69    client: Client<T>,
70    pub model: String,
71    ndims: usize,
72}
73
74impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
75where
76    T: HttpClientExt + Default + Clone + Send + 'static,
77{
78    const MAX_DOCUMENTS: usize = 1024; // This might need to be adjusted based on Together AI's actual limit
79
80    type Client = Client<T>;
81
82    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
83        Self::new(client.clone(), model, dims.unwrap_or_default())
84    }
85
86    fn ndims(&self) -> usize {
87        self.ndims
88    }
89
90    #[cfg_attr(feature = "worker", worker::send)]
91    async fn embed_texts(
92        &self,
93        documents: impl IntoIterator<Item = String>,
94    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
95        let documents = documents.into_iter().collect::<Vec<_>>();
96
97        let body = serde_json::to_vec(&json!({
98            "model": self.model,
99            "input": documents,
100        }))?;
101
102        let req = self
103            .client
104            .post("/v1/embeddings")?
105            .body(body)
106            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
107
108        let response = self.client.send(req).await?;
109
110        if response.status().is_success() {
111            let body: Vec<u8> = response.into_body().await?;
112            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
113
114            match body {
115                ApiResponse::Ok(response) => {
116                    if response.data.len() != documents.len() {
117                        return Err(EmbeddingError::ResponseError(
118                            "Response data length does not match input length".into(),
119                        ));
120                    }
121
122                    Ok(response
123                        .data
124                        .into_iter()
125                        .zip(documents.into_iter())
126                        .map(|(embedding, document)| embeddings::Embedding {
127                            document,
128                            vec: embedding.embedding,
129                        })
130                        .collect())
131                }
132                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
133            }
134        } else {
135            let text = http_client::text(response).await?;
136            Err(EmbeddingError::ProviderError(text))
137        }
138    }
139}
140
141impl<T> EmbeddingModel<T>
142where
143    T: Default,
144{
145    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
146        Self {
147            client,
148            model: model.into(),
149            ndims,
150        }
151    }
152}