// rig/providers/together/embedding.rs

use serde::Deserialize;
use serde_json::json;

use crate::embeddings::{self, EmbeddingError};

use super::{
    client::together_ai_api_types::{ApiErrorResponse, ApiResponse},
    Client,
};
/// `BAAI/bge-base-en-v1.5` embedding model
pub const BGE_BASE_EN_V1_5: &str = "BAAI/bge-base-en-v1.5";
/// `BAAI/bge-large-en-v1.5` embedding model
pub const BGE_LARGE_EN_V1_5: &str = "BAAI/bge-large-en-v1.5";
/// `bert-base-uncased` embedding model
pub const BERT_BASE_UNCASED: &str = "bert-base-uncased";
/// `hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1` embedding model
pub const M2_BERT_2K_RETRIEVAL_ENCODER_V1: &str = "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1";
/// `togethercomputer/m2-bert-80M-32k-retrieval` embedding model
pub const M2_BERT_80M_32K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
/// `togethercomputer/m2-bert-80M-2k-retrieval` embedding model
pub const M2_BERT_80M_2K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-2k-retrieval";
/// `togethercomputer/m2-bert-80M-8k-retrieval` embedding model
pub const M2_BERT_80M_8K_RETRIEVAL: &str = "togethercomputer/m2-bert-80M-8k-retrieval";
/// `sentence-transformers/msmarco-bert-base-dot-v5` embedding model
pub const SENTENCE_BERT: &str = "sentence-transformers/msmarco-bert-base-dot-v5";
/// `WhereIsAI/UAE-Large-V1` embedding model
pub const UAE_LARGE_V1: &str = "WhereIsAI/UAE-Large-V1";
29
30#[derive(Debug, Deserialize)]
31pub struct EmbeddingResponse {
32 pub model: String,
33 pub object: String,
34 pub data: Vec<EmbeddingData>,
35}
36
37impl From<ApiErrorResponse> for EmbeddingError {
38 fn from(err: ApiErrorResponse) -> Self {
39 EmbeddingError::ProviderError(err.message())
40 }
41}
42
43impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
44 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
45 match value {
46 ApiResponse::Ok(response) => Ok(response),
47 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
48 }
49 }
50}
51
52#[derive(Debug, Deserialize)]
53pub struct EmbeddingData {
54 pub object: String,
55 pub embedding: Vec<f64>,
56 pub index: usize,
57}
58
59#[derive(Debug, Deserialize)]
60pub struct Usage {
61 pub prompt_tokens: usize,
62 pub total_tokens: usize,
63}
64
65#[derive(Clone)]
66pub struct EmbeddingModel {
67 client: Client,
68 pub model: String,
69 ndims: usize,
70}
71
72impl embeddings::EmbeddingModel for EmbeddingModel {
73 const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize {
76 self.ndims
77 }
78
79 #[cfg_attr(feature = "worker", worker::send)]
80 async fn embed_texts(
81 &self,
82 documents: impl IntoIterator<Item = String>,
83 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
84 let documents = documents.into_iter().collect::<Vec<_>>();
85
86 let response = self
87 .client
88 .post("/v1/embeddings")
89 .json(&json!({
90 "model": self.model,
91 "input": documents,
92 }))
93 .send()
94 .await?;
95
96 if response.status().is_success() {
97 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
98 ApiResponse::Ok(response) => {
99 if response.data.len() != documents.len() {
100 return Err(EmbeddingError::ResponseError(
101 "Response data length does not match input length".into(),
102 ));
103 }
104
105 Ok(response
106 .data
107 .into_iter()
108 .zip(documents.into_iter())
109 .map(|(embedding, document)| embeddings::Embedding {
110 document,
111 vec: embedding.embedding,
112 })
113 .collect())
114 }
115 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
116 }
117 } else {
118 Err(EmbeddingError::ProviderError(response.text().await?))
119 }
120 }
121}
122
123impl EmbeddingModel {
124 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
125 Self {
126 client,
127 model: model.to_string(),
128 ndims,
129 }
130 }
131}