rig/providers/xai/
embedding.rs1use serde::Deserialize;
7use serde_json::json;
8
9use crate::embeddings::{self, EmbeddingError};
10
11use super::{
12 client::xai_api_types::{ApiErrorResponse, ApiResponse},
13 Client,
14};
15
/// Identifier of xAI's "v1" embedding model.
///
/// NOTE(review): presumably intended as the `model` argument to `EmbeddingModel::new`;
/// confirm against the xAI model list.
pub const EMBEDDING_V1: &str = "v1";
22#[derive(Debug, Deserialize)]
23pub struct EmbeddingResponse {
24 pub object: String,
25 pub data: Vec<EmbeddingData>,
26 pub model: String,
27 pub usage: Usage,
28}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31 fn from(err: ApiErrorResponse) -> Self {
32 EmbeddingError::ProviderError(err.message())
33 }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38 match value {
39 ApiResponse::Ok(response) => Ok(response),
40 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
41 }
42 }
43}
44
45#[derive(Debug, Deserialize)]
46pub struct EmbeddingData {
47 pub object: String,
48 pub embedding: Vec<f64>,
49 pub index: usize,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct Usage {
54 pub prompt_tokens: usize,
55 pub total_tokens: usize,
56}
57
58#[derive(Clone)]
59pub struct EmbeddingModel {
60 client: Client,
61 pub model: String,
62 ndims: usize,
63}
64
65impl embeddings::EmbeddingModel for EmbeddingModel {
66 const MAX_DOCUMENTS: usize = 1024;
67
68 fn ndims(&self) -> usize {
69 self.ndims
70 }
71
72 #[cfg_attr(feature = "worker", worker::send)]
73 async fn embed_texts(
74 &self,
75 documents: impl IntoIterator<Item = String>,
76 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
77 let documents = documents.into_iter().collect::<Vec<_>>();
78
79 let response = self
80 .client
81 .post("/v1/embeddings")
82 .json(&json!({
83 "model": self.model,
84 "input": documents,
85 }))
86 .send()
87 .await?;
88
89 if response.status().is_success() {
90 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
91 ApiResponse::Ok(response) => {
92 if response.data.len() != documents.len() {
93 return Err(EmbeddingError::ResponseError(
94 "Response data length does not match input length".into(),
95 ));
96 }
97
98 Ok(response
99 .data
100 .into_iter()
101 .zip(documents.into_iter())
102 .map(|(embedding, document)| embeddings::Embedding {
103 document,
104 vec: embedding.embedding,
105 })
106 .collect())
107 }
108 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
109 }
110 } else {
111 Err(EmbeddingError::ProviderError(response.text().await?))
112 }
113 }
114}
115
116impl EmbeddingModel {
117 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
118 Self {
119 client,
120 model: model.to_string(),
121 ndims,
122 }
123 }
124}