bep/providers/xai/
embedding.rs

1// ================================================================
//! xAI Embeddings Integration
//! See the [xAI API reference](https://docs.x.ai/api/endpoints#create-embeddings)
4// ================================================================
5
6use serde::Deserialize;
7use serde_json::json;
8
9use crate::embeddings::{self, EmbeddingError};
10
11use super::{
12    client::xai_api_types::{ApiErrorResponse, ApiResponse},
13    Client,
14};
15
16// ================================================================
17// xAI Embedding API
18// ================================================================
/// Model identifier string for xAI's `v1` embedding model.
///
/// Can be passed as the `model` argument of [`EmbeddingModel::new`].
pub const EMBEDDING_V1: &str = "v1";
21
/// Successful response body of the embeddings endpoint, as deserialized from JSON.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Value of the response's `object` field (not inspected by this module).
    pub object: String,
    /// Returned embedding items; `embed_texts` requires this to have exactly as
    /// many entries as there were input documents.
    pub data: Vec<EmbeddingData>,
    /// Value of the response's `model` field — presumably the model that served
    /// the request; confirm against the xAI docs.
    pub model: String,
    /// Token usage figures reported with the response.
    pub usage: Usage,
}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31    fn from(err: ApiErrorResponse) -> Self {
32        EmbeddingError::ProviderError(err.message())
33    }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38        match value {
39            ApiResponse::Ok(response) => Ok(response),
40            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
41        }
42    }
43}
44
/// A single item of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the item's `object` field (not inspected by this module).
    pub object: String,
    /// The embedding vector; `embed_texts` moves it into `Embedding::vec`.
    pub embedding: Vec<f64>,
    /// Index reported by the API for this item.
    ///
    /// NOTE(review): `embed_texts` pairs items with inputs by position and never
    /// consults this field — confirm the API always returns items in input order.
    pub index: usize,
}
51
/// Token usage figures deserialized from the response's `usage` object.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Presumably the number of tokens in the submitted inputs — confirm against
    /// the xAI docs.
    pub prompt_tokens: usize,
    /// Presumably the total number of tokens counted for the request — confirm
    /// against the xAI docs.
    pub total_tokens: usize,
}
57
/// Handle for one xAI embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to issue the `POST /v1/embeddings` request.
    client: Client,
    /// Model name sent as the `model` field of every request.
    pub model: String,
    /// Dimensionality returned by `ndims()`; supplied by the caller of `new` and
    /// not checked against the vectors the API actually returns.
    ndims: usize,
}
64
65impl embeddings::EmbeddingModel for EmbeddingModel {
66    const MAX_DOCUMENTS: usize = 1024;
67
68    fn ndims(&self) -> usize {
69        self.ndims
70    }
71
72    async fn embed_texts(
73        &self,
74        documents: impl IntoIterator<Item = String>,
75    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
76        let documents = documents.into_iter().collect::<Vec<_>>();
77
78        let response = self
79            .client
80            .post("/v1/embeddings")
81            .json(&json!({
82                "model": self.model,
83                "input": documents,
84            }))
85            .send()
86            .await?;
87
88        if response.status().is_success() {
89            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
90                ApiResponse::Ok(response) => {
91                    if response.data.len() != documents.len() {
92                        return Err(EmbeddingError::ResponseError(
93                            "Response data length does not match input length".into(),
94                        ));
95                    }
96
97                    Ok(response
98                        .data
99                        .into_iter()
100                        .zip(documents.into_iter())
101                        .map(|(embedding, document)| embeddings::Embedding {
102                            document,
103                            vec: embedding.embedding,
104                        })
105                        .collect())
106                }
107                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
108            }
109        } else {
110            Err(EmbeddingError::ProviderError(response.text().await?))
111        }
112    }
113}
114
115impl EmbeddingModel {
116    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
117        Self {
118            client,
119            model: model.to_string(),
120            ndims,
121        }
122    }
123}