Skip to main content

xai_rust/api/
tokenizer.rs

//! Tokenizer API for tokenizing text and counting tokens.
2
3use crate::client::XaiClient;
4use crate::models::tokenizer::{TokenizeRequest, TokenizeResponse};
5use crate::{Error, Result};
6
7/// Tokenizer API client.
8#[derive(Debug, Clone)]
9pub struct TokenizerApi {
10    client: XaiClient,
11}
12
13impl TokenizerApi {
14    pub(crate) fn new(client: XaiClient) -> Self {
15        Self { client }
16    }
17
18    /// Tokenize text using a specific model.
19    ///
20    /// # Example
21    ///
22    /// ```rust,no_run
23    /// use xai_rust::{XaiClient, TokenizeRequest};
24    ///
25    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
26    /// let client = XaiClient::from_env()?;
27    ///
28    /// let request = TokenizeRequest::new("grok-4", "Hello, world!");
29    /// let response = client.tokenizer().tokenize(request).await?;
30    ///
31    /// println!("Token count: {}", response.count());
32    /// println!("Tokens: {:?}", response.tokens);
33    /// # Ok(())
34    /// # }
35    /// ```
36    pub async fn tokenize(&self, request: TokenizeRequest) -> Result<TokenizeResponse> {
37        let url = format!("{}/tokenize-text", self.client.base_url());
38
39        let response = self
40            .client
41            .send(self.client.http().post(&url).json(&request))
42            .await?;
43
44        if !response.status().is_success() {
45            return Err(Error::from_response(response).await);
46        }
47
48        Ok(response.json().await?)
49    }
50
51    /// Tokenize text using a model name and text string.
52    ///
53    /// # Example
54    ///
55    /// ```rust,no_run
56    /// use xai_rust::XaiClient;
57    ///
58    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
59    /// let client = XaiClient::from_env()?;
60    ///
61    /// let response = client.tokenizer()
62    ///     .tokenize_text("grok-4", "Hello, world!")
63    ///     .await?;
64    ///
65    /// println!("Tokens: {:?}", response.tokens);
66    /// # Ok(())
67    /// # }
68    /// ```
69    pub async fn tokenize_text(
70        &self,
71        model: impl Into<String>,
72        text: impl Into<String>,
73    ) -> Result<TokenizeResponse> {
74        self.tokenize(TokenizeRequest::new(model, text)).await
75    }
76
77    /// Count tokens in text using a specific model.
78    ///
79    /// # Example
80    ///
81    /// ```rust,no_run
82    /// use xai_rust::XaiClient;
83    ///
84    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
85    /// let client = XaiClient::from_env()?;
86    ///
87    /// let count = client.tokenizer()
88    ///     .count_tokens("grok-4", "Hello, world!")
89    ///     .await?;
90    ///
91    /// println!("Token count: {}", count);
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub async fn count_tokens(
96        &self,
97        model: impl Into<String>,
98        text: impl Into<String>,
99    ) -> Result<usize> {
100        let response = self.tokenize_text(model, text).await?;
101        Ok(response.count())
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client whose base URL points at the given mock server.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    /// Mounts a `POST /tokenize-text` mock that replies with `template`.
    async fn mount_tokenize(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/tokenize-text"))
            .respond_with(template)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn tokenize_forwards_model_and_text_payload() {
        let server = MockServer::start().await;
        mount_tokenize(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [10, 20, 30],
                "token_count": 3
            })),
        )
        .await;

        let client = client_for(&server);
        let response = client
            .tokenizer()
            .tokenize(TokenizeRequest::new("grok-4", "Hello tokenizer"))
            .await
            .unwrap();
        assert_eq!(response.tokens, vec![10, 20, 30]);
        assert_eq!(response.count(), 3);

        // Check the payload here on the test task rather than inside the
        // responder closure: a panic there happens on the mock server's side
        // and shows up in the test only as a failed request, not a readable
        // assertion message.
        let requests = server
            .received_requests()
            .await
            .expect("request recording is enabled by default");
        assert_eq!(requests.len(), 1, "expected exactly one tokenize call");
        let body: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap();
        assert_eq!(body["model"], "grok-4");
        assert_eq!(body["text"], "Hello tokenizer");
    }

    #[tokio::test]
    async fn count_tokens_prefers_explicit_token_count() {
        let server = MockServer::start().await;
        mount_tokenize(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [1, 2],
                "token_count": 9
            })),
        )
        .await;

        let client = client_for(&server);
        let count = client
            .tokenizer()
            .count_tokens("grok-4", "count this")
            .await
            .unwrap();
        assert_eq!(count, 9);
    }

    #[tokio::test]
    async fn count_tokens_falls_back_to_token_vector_length() {
        let server = MockServer::start().await;
        mount_tokenize(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [7, 8, 9, 10]
            })),
        )
        .await;

        let client = client_for(&server);
        let count = client
            .tokenizer()
            .count_tokens("grok-4", "fallback count")
            .await
            .unwrap();
        assert_eq!(count, 4);
    }

    #[tokio::test]
    async fn tokenize_returns_api_error_for_non_success_response() {
        let server = MockServer::start().await;
        mount_tokenize(
            &server,
            ResponseTemplate::new(400).set_body_json(json!({
                "error": {
                    "message": "bad tokenize request",
                    "type": "invalid_request_error"
                }
            })),
        )
        .await;

        let client = client_for(&server);
        let err = client
            .tokenizer()
            .tokenize_text("grok-4", "bad input")
            .await
            .unwrap_err();
        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "bad tokenize request");
                assert_eq!(error_type.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("expected api error, got {other:?}"),
        }
    }
}