rig/providers/xai/
embedding.rs

1// ================================================================
2//! xAI Embeddings Integration
3//! From [xAI Reference](https://docs.x.ai/api/endpoints#create-embeddings)
4// ================================================================
5
6use serde::Deserialize;
7use serde_json::json;
8
9use crate::embeddings::{self, EmbeddingError};
10
11use super::{
12    client::xai_api_types::{ApiErrorResponse, ApiResponse},
13    Client,
14};
15
16// ================================================================
17// xAI Embedding API
18// ================================================================
/// Identifier of the `v1` embedding model.
///
/// When an [`EmbeddingModel`] is created with this string it is sent verbatim
/// as the `model` field of the `/v1/embeddings` request body.
pub const EMBEDDING_V1: &str = "v1";
21
22#[derive(Debug, Deserialize)]
23pub struct EmbeddingResponse {
24    pub object: String,
25    pub data: Vec<EmbeddingData>,
26    pub model: String,
27    pub usage: Usage,
28}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31    fn from(err: ApiErrorResponse) -> Self {
32        EmbeddingError::ProviderError(err.message())
33    }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38        match value {
39            ApiResponse::Ok(response) => Ok(response),
40            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
41        }
42    }
43}
44
45#[derive(Debug, Deserialize)]
46pub struct EmbeddingData {
47    pub object: String,
48    pub embedding: Vec<f64>,
49    pub index: usize,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct Usage {
54    pub prompt_tokens: usize,
55    pub total_tokens: usize,
56}
57
58#[derive(Clone)]
59pub struct EmbeddingModel {
60    client: Client,
61    pub model: String,
62    ndims: usize,
63}
64
65impl embeddings::EmbeddingModel for EmbeddingModel {
66    const MAX_DOCUMENTS: usize = 1024;
67
68    fn ndims(&self) -> usize {
69        self.ndims
70    }
71
72    #[cfg_attr(feature = "worker", worker::send)]
73    async fn embed_texts(
74        &self,
75        documents: impl IntoIterator<Item = String>,
76    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
77        let documents = documents.into_iter().collect::<Vec<_>>();
78
79        let response = self
80            .client
81            .post("/v1/embeddings")
82            .json(&json!({
83                "model": self.model,
84                "input": documents,
85            }))
86            .send()
87            .await?;
88
89        if response.status().is_success() {
90            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
91                ApiResponse::Ok(response) => {
92                    if response.data.len() != documents.len() {
93                        return Err(EmbeddingError::ResponseError(
94                            "Response data length does not match input length".into(),
95                        ));
96                    }
97
98                    Ok(response
99                        .data
100                        .into_iter()
101                        .zip(documents.into_iter())
102                        .map(|(embedding, document)| embeddings::Embedding {
103                            document,
104                            vec: embedding.embedding,
105                        })
106                        .collect())
107                }
108                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
109            }
110        } else {
111            Err(EmbeddingError::ProviderError(response.text().await?))
112        }
113    }
114}
115
116impl EmbeddingModel {
117    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
118        Self {
119            client,
120            model: model.to_string(),
121            ndims,
122        }
123    }
124}