bep/providers/xai/
embedding.rs1use serde::Deserialize;
7use serde_json::json;
8
9use crate::embeddings::{self, EmbeddingError};
10
11use super::{
12 client::xai_api_types::{ApiErrorResponse, ApiResponse},
13 Client,
14};
15
/// Model identifier string `"v1"` for xAI embeddings.
// NOTE(review): presumably intended as the `model` argument to `EmbeddingModel::new`;
// confirm the current model name against the xAI API documentation.
pub const EMBEDDING_V1: &str = "v1";
21
/// Successful response body of the xAI `POST /v1/embeddings` endpoint, decoded by
/// `EmbeddingModel::embed_texts` (wrapped in `ApiResponse`).
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag reported by the API; not read by `embed_texts`.
    pub object: String,
    /// Embedding entries. `embed_texts` rejects a response whose length differs from
    /// the number of input documents.
    pub data: Vec<EmbeddingData>,
    /// Model identifier echoed by the API (presumably the model that produced the
    /// embeddings — confirm against the xAI API docs).
    pub model: String,
    /// Token usage counters reported by the API.
    pub usage: Usage,
}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31 fn from(err: ApiErrorResponse) -> Self {
32 EmbeddingError::ProviderError(err.message())
33 }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38 match value {
39 ApiResponse::Ok(response) => Ok(response),
40 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
41 }
42 }
43}
44
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API; not read by `embed_texts`.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Presumably the position of the input document this entry was computed from
    /// (TODO confirm against the xAI API docs).
    pub index: usize,
}
51
/// Token usage counters returned with an embeddings response.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Prompt (input) token count as reported by the API.
    pub prompt_tokens: usize,
    /// Total token count as reported by the API.
    // NOTE(review): for embeddings this is presumably equal to `prompt_tokens` — confirm.
    pub total_tokens: usize,
}
57
/// Handle for computing embeddings with an xAI model through a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to issue `POST /v1/embeddings` requests.
    client: Client,
    /// Model identifier sent as the `"model"` field of each request.
    pub model: String,
    /// Embedding dimensionality returned by `ndims()`; supplied by the caller and not
    /// checked against the vectors the API returns.
    ndims: usize,
}
64
65impl embeddings::EmbeddingModel for EmbeddingModel {
66 const MAX_DOCUMENTS: usize = 1024;
67
68 fn ndims(&self) -> usize {
69 self.ndims
70 }
71
72 async fn embed_texts(
73 &self,
74 documents: impl IntoIterator<Item = String>,
75 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
76 let documents = documents.into_iter().collect::<Vec<_>>();
77
78 let response = self
79 .client
80 .post("/v1/embeddings")
81 .json(&json!({
82 "model": self.model,
83 "input": documents,
84 }))
85 .send()
86 .await?;
87
88 if response.status().is_success() {
89 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
90 ApiResponse::Ok(response) => {
91 if response.data.len() != documents.len() {
92 return Err(EmbeddingError::ResponseError(
93 "Response data length does not match input length".into(),
94 ));
95 }
96
97 Ok(response
98 .data
99 .into_iter()
100 .zip(documents.into_iter())
101 .map(|(embedding, document)| embeddings::Embedding {
102 document,
103 vec: embedding.embedding,
104 })
105 .collect())
106 }
107 ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
108 }
109 } else {
110 Err(EmbeddingError::ProviderError(response.text().await?))
111 }
112 }
113}
114
115impl EmbeddingModel {
116 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
117 Self {
118 client,
119 model: model.to_string(),
120 ndims,
121 }
122 }
123}