xai_grpc_client/
embedding.rs

1//! Embedding API for generating vector representations.
2//!
3//! This module provides access to xAI's embedding models, allowing you to:
4//! - Generate embeddings from text strings
5//! - Generate embeddings from images
6//! - Support for both text-only and multimodal embedding models
7//!
8//! # Examples
9//!
10//! ## Embedding text
11//!
12//! ```no_run
13//! use xai_grpc_client::{GrokClient, EmbedRequest};
14//!
15//! #[tokio::main]
16//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//!     let mut client = GrokClient::from_env().await?;
18//!
19//!     let request = EmbedRequest::new("embed-large-v1")
20//!         .add_text("Hello, world!")
21//!         .add_text("How are you?");
22//!
23//!     let response = client.embed(request).await?;
24//!
25//!     for embedding in response.embeddings {
26//!         println!("Embedding {} has {} dimensions",
27//!             embedding.index, embedding.vector.len());
28//!     }
29//!     Ok(())
30//! }
31//! ```
32//!
33//! ## Embedding images
34//!
35//! ```no_run
36//! use xai_grpc_client::{GrokClient, EmbedRequest};
37//!
38//! #[tokio::main]
39//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
40//!     let mut client = GrokClient::from_env().await?;
41//!
42//!     let request = EmbedRequest::new("embed-vision-v1")
43//!         .add_image("https://example.com/image.jpg");
44//!
45//!     let response = client.embed(request).await?;
46//!     println!("Generated {} embeddings", response.embeddings.len());
47//!     Ok(())
48//! }
49//! ```
50
51use crate::{proto, request::ImageDetail};
52
/// Request for generating embeddings.
///
/// Supports embedding text strings, images, or a mix of both depending on
/// the model capabilities. You can embed up to 128 inputs in a single request.
///
/// Construct with [`EmbedRequest::new`], then chain the builder methods
/// (`add_text`, `add_image`, `with_encoding_format`, `with_user`).
#[derive(Clone, Debug)]
pub struct EmbedRequest {
    /// Inputs to embed (text or images), in the order they were added.
    /// Each input yields one [`Embedding`] in the response.
    pub inputs: Vec<EmbedInput>,
    /// Model name or alias to use.
    pub model: String,
    /// Encoding format for the embeddings (Float or Base64).
    /// `EmbedRequest::new` initializes this to `Float`.
    pub encoding_format: EmbedEncodingFormat,
    /// Optional user identifier for tracking.
    pub user: Option<String>,
}
68
69impl EmbedRequest {
70    /// Create a new embedding request with the specified model.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use xai_grpc_client::EmbedRequest;
76    ///
77    /// let request = EmbedRequest::new("embed-large-v1");
78    /// ```
79    pub fn new(model: impl Into<String>) -> Self {
80        Self {
81            inputs: Vec::new(),
82            model: model.into(),
83            encoding_format: EmbedEncodingFormat::Float,
84            user: None,
85        }
86    }
87
88    /// Add a text string to embed.
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use xai_grpc_client::EmbedRequest;
94    ///
95    /// let request = EmbedRequest::new("embed-large-v1")
96    ///     .add_text("Hello, world!");
97    /// ```
98    pub fn add_text(mut self, text: impl Into<String>) -> Self {
99        self.inputs.push(EmbedInput::Text(text.into()));
100        self
101    }
102
103    /// Add an image URL to embed.
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use xai_grpc_client::EmbedRequest;
109    ///
110    /// let request = EmbedRequest::new("embed-vision-v1")
111    ///     .add_image("https://example.com/image.jpg");
112    /// ```
113    pub fn add_image(self, url: impl Into<String>) -> Self {
114        self.add_image_with_detail(url, ImageDetail::Auto)
115    }
116
117    /// Add an image URL with specific detail level.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use xai_grpc_client::{EmbedRequest, ImageDetail};
123    ///
124    /// let request = EmbedRequest::new("embed-vision-v1")
125    ///     .add_image_with_detail("https://example.com/image.jpg", ImageDetail::High);
126    /// ```
127    pub fn add_image_with_detail(mut self, url: impl Into<String>, detail: ImageDetail) -> Self {
128        self.inputs.push(EmbedInput::Image {
129            url: url.into(),
130            detail,
131        });
132        self
133    }
134
135    /// Set the encoding format for embeddings.
136    ///
137    /// # Examples
138    ///
139    /// ```
140    /// use xai_grpc_client::{EmbedRequest, EmbedEncodingFormat};
141    ///
142    /// let request = EmbedRequest::new("embed-large-v1")
143    ///     .with_encoding_format(EmbedEncodingFormat::Base64);
144    /// ```
145    pub fn with_encoding_format(mut self, format: EmbedEncodingFormat) -> Self {
146        self.encoding_format = format;
147        self
148    }
149
150    /// Set the user identifier for tracking.
151    pub fn with_user(mut self, user: impl Into<String>) -> Self {
152        self.user = Some(user.into());
153        self
154    }
155}
156
/// Input to be embedded (text or image).
///
/// Each variant corresponds to one entry in [`EmbedRequest::inputs`]; it is
/// built via `add_text` / `add_image` / `add_image_with_detail`.
#[derive(Clone, Debug)]
pub enum EmbedInput {
    /// Text string to embed.
    Text(String),
    /// Image to embed, referenced by URL, with a processing detail level.
    Image {
        /// URL of the image.
        url: String,
        /// Detail level for processing (`ImageDetail::Auto` when added via
        /// `add_image`).
        detail: ImageDetail,
    },
}
170
/// Encoding format for embedding vectors.
///
/// Fieldless public enum, so it derives the full set of common traits
/// (`Copy`, `Hash`, `Default`). The `Float` default matches the initial
/// value set by `EmbedRequest::new`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum EmbedEncodingFormat {
    /// Return embeddings as arrays of floats (the default).
    #[default]
    Float,
    /// Return embeddings as base64-encoded strings.
    Base64,
}
179
/// Response from an embedding request.
#[derive(Clone, Debug)]
pub struct EmbedResponse {
    /// Request identifier.
    pub id: String,
    /// Generated embeddings (one per input); use [`Embedding`]'s `index`
    /// field to map each vector back to its request input.
    pub embeddings: Vec<Embedding>,
    /// Token usage statistics.
    pub usage: EmbeddingUsage,
    /// Model name used (may differ from request if alias was used).
    pub model: String,
    /// Backend configuration fingerprint.
    pub system_fingerprint: String,
}
194
/// A single embedding vector.
#[derive(Clone, Debug)]
pub struct Embedding {
    /// Index of the input that generated this embedding
    /// (position within [`EmbedRequest::inputs`]).
    pub index: usize,
    /// The embedding vector.
    // NOTE(review): always f32 here even though the request can ask for
    // Base64 encoding — how base64 payloads map onto this field is handled
    // outside this module; confirm in the client conversion code.
    pub vector: Vec<f32>,
}
203
/// Usage statistics for an embedding request.
///
/// Two plain `u32` counters, so the struct is `Copy` and derives the full
/// set of common comparison traits in addition to `Default`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct EmbeddingUsage {
    /// Number of text embeddings generated.
    pub num_text_embeddings: u32,
    /// Number of image embeddings generated.
    pub num_image_embeddings: u32,
}
212
/// Convert the wire-format usage message into the public [`EmbeddingUsage`].
impl From<proto::EmbeddingUsage> for EmbeddingUsage {
    fn from(proto: proto::EmbeddingUsage) -> Self {
        Self {
            // `as u32` assumes the proto counters are non-negative and fit in
            // u32; a negative or oversized source value would wrap silently.
            // TODO(review): confirm the generated proto field types (likely
            // i32/i64 depending on the schema) and whether wrapping is safe.
            num_text_embeddings: proto.num_text_embeddings as u32,
            num_image_embeddings: proto.num_image_embeddings as u32,
        }
    }
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Text-only builder: model is stored and inputs accumulate in order.
    #[test]
    fn test_embed_request_builder() {
        let req = EmbedRequest::new("embed-large-v1")
            .add_text("Hello")
            .add_text("World");

        assert_eq!(req.model, "embed-large-v1");
        assert_eq!(req.inputs.len(), 2);
        assert!(matches!(req.inputs[0], EmbedInput::Text(_)));
    }

    /// Image builder: both `add_image` variants produce `Image` inputs.
    #[test]
    fn test_embed_request_with_images() {
        let req = EmbedRequest::new("embed-vision-v1")
            .add_image("https://example.com/img1.jpg")
            .add_image_with_detail("https://example.com/img2.jpg", ImageDetail::High);

        assert_eq!(req.inputs.len(), 2);
        assert!(req
            .inputs
            .iter()
            .all(|input| matches!(input, EmbedInput::Image { .. })));
    }

    /// Text and image inputs can be mixed in one request.
    #[test]
    fn test_embed_request_mixed() {
        let req = EmbedRequest::new("embed-multimodal-v1")
            .add_text("Description")
            .add_image("https://example.com/img.jpg");

        assert_eq!(req.inputs.len(), 2);
    }

    /// `with_encoding_format` overrides the default `Float` encoding.
    #[test]
    fn test_encoding_format() {
        let req = EmbedRequest::new("embed-large-v1")
            .with_encoding_format(EmbedEncodingFormat::Base64);

        assert_eq!(req.encoding_format, EmbedEncodingFormat::Base64);
    }

    /// `with_user` stores the tracking identifier.
    #[test]
    fn test_with_user() {
        let req = EmbedRequest::new("embed-large-v1").with_user("user123");

        assert_eq!(req.user, Some("user123".to_string()));
    }

    /// Default usage counters are zero.
    #[test]
    fn test_embedding_usage_default() {
        let usage = EmbeddingUsage::default();
        assert_eq!(usage.num_text_embeddings, 0);
        assert_eq!(usage.num_image_embeddings, 0);
    }

    /// Proto usage converts field-for-field into the public type.
    #[test]
    fn test_embedding_usage_from_proto() {
        let raw = proto::EmbeddingUsage {
            num_text_embeddings: 5,
            num_image_embeddings: 2,
        };

        let usage = EmbeddingUsage::from(raw);
        assert_eq!(usage.num_text_embeddings, 5);
        assert_eq!(usage.num_image_embeddings, 2);
    }
}