rig/providers/together/
embedding.rs1use serde::Deserialize;
7use serde_json::json;
8
9use crate::{
10 embeddings::{self, EmbeddingError},
11 http_client::{self, HttpClientExt},
12};
13
14use super::{
15 Client,
16 client::together_ai_api_types::{ApiErrorResponse, ApiResponse},
17};
18
// Model identifiers that can be passed as the `model` argument of
// `EmbeddingModel::new`.

/// Together AI model id `BAAI/bge-base-en-v1.5`.
pub const BGE_BASE_EN_V1_5: &str = "BAAI/bge-base-en-v1.5";
/// Together AI model id `BAAI/bge-large-en-v1.5`.
pub const BGE_LARGE_EN_V1_5: &str = "BAAI/bge-large-en-v1.5";
/// Together AI model id `bert-base-uncased`.
pub const BERT_BASE_UNCASED: &str = "bert-base-uncased";
/// Together AI model id `hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1`.
pub const M2_BERT_2K_RETRIEVAL_ENCODER_V1: &str = "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1";
/// Together AI model id `togethercomputer/m2-bert-80M-32k-retrieval`.
pub const M2_BERT_80M_32K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
/// Together AI model id `togethercomputer/m2-bert-80M-2k-retrieval`.
pub const M2_BERT_80M_2K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-2k-retrieval";
/// Together AI model id `togethercomputer/m2-bert-80M-8k-retrieval`.
pub const M2_BERT_80M_8K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-8k-retrieval";
/// Together AI model id `sentence-transformers/msmarco-bert-base-dot-v5`.
pub const SENTENCE_BERT: &str = "sentence-transformers/msmarco-bert-base-dot-v5";
/// Together AI model id `WhereIsAI/UAE-Large-V1`.
pub const UAE_LARGE_V1: &str = "WhereIsAI/UAE-Large-V1";
32
/// Successful body of a Together AI `POST /v1/embeddings` call.
///
/// Deserialized as the `Ok` variant of `ApiResponse<EmbeddingResponse>`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Model identifier echoed back by the API.
    pub model: String,
    /// Object type tag reported by the API.
    pub object: String,
    /// Embedding entries; `embed_texts` expects exactly one per input document.
    pub data: Vec<EmbeddingData>,
}
39
40impl From<ApiErrorResponse> for EmbeddingError {
41 fn from(err: ApiErrorResponse) -> Self {
42 EmbeddingError::ProviderError(err.message())
43 }
44}
45
46impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
47 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
48 match value {
49 ApiResponse::Ok(response) => Ok(response),
50 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
51 }
52 }
53}
54
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API.
    pub object: String,
    /// The embedding vector for one input document.
    pub embedding: Vec<f64>,
    /// Position of this entry — presumably the index of the corresponding
    /// element in the request's `input` array (OpenAI-style); confirm against
    /// the Together API reference.
    pub index: usize,
}
61
/// Token accounting block as reported by the embeddings API.
///
/// NOTE(review): no type in this file references `Usage` (in particular
/// `EmbeddingResponse` has no `usage` field), so it is never populated here.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Tokens counted for the prompt/input.
    pub prompt_tokens: usize,
    /// Total tokens counted for the request.
    pub total_tokens: usize,
}
67
/// Handle for a Together AI embedding model.
///
/// Pairs a Together [`Client`] with a model identifier and the vector
/// dimensionality the caller declared for that model.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// HTTP client used to issue `/v1/embeddings` requests.
    client: Client<T>,
    /// Model identifier sent as the `"model"` field of each request.
    pub model: String,
    /// Caller-supplied dimensionality, returned unchanged by `ndims()`.
    ndims: usize,
}
74
75impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
76where
77 T: HttpClientExt + Default + Clone + Send + 'static,
78{
79 const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize {
82 self.ndims
83 }
84
85 #[cfg_attr(feature = "worker", worker::send)]
86 async fn embed_texts(
87 &self,
88 documents: impl IntoIterator<Item = String>,
89 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
90 let documents = documents.into_iter().collect::<Vec<_>>();
91
92 let body = serde_json::to_vec(&json!({
93 "model": self.model,
94 "input": documents,
95 }))?;
96
97 let req = self
98 .client
99 .post("/v1/embeddings")?
100 .header("Content-Type", "application/json")
101 .body(body)
102 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
103
104 let response = self.client.send(req).await?;
105
106 if response.status().is_success() {
107 let body: Vec<u8> = response.into_body().await?;
108 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
109
110 match body {
111 ApiResponse::Ok(response) => {
112 if response.data.len() != documents.len() {
113 return Err(EmbeddingError::ResponseError(
114 "Response data length does not match input length".into(),
115 ));
116 }
117
118 Ok(response
119 .data
120 .into_iter()
121 .zip(documents.into_iter())
122 .map(|(embedding, document)| embeddings::Embedding {
123 document,
124 vec: embedding.embedding,
125 })
126 .collect())
127 }
128 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
129 }
130 } else {
131 let text = http_client::text(response).await?;
132 Err(EmbeddingError::ProviderError(text))
133 }
134 }
135}
136
137impl<T> EmbeddingModel<T>
138where
139 T: Default,
140{
141 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
142 Self {
143 client,
144 model: model.to_string(),
145 ndims,
146 }
147 }
148}