xai_grpc_client/sample.rs

//! Sample API for raw text sampling
//!
//! The Sample service provides a simpler alternative to the Chat service for basic text completion.
//! It's useful for straightforward text generation without the conversation structure.
//!
//! **Note**: For most use cases, the Chat API (`GrokClient::complete_chat`) is recommended
//! as it provides more features and better conversation management.
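//!
//! # Example
//!
//! A minimal sketch of building a request with the builder methods below. The model
//! name and prompt are illustrative, and the import path assumes the crate exposes
//! this module as `xai_grpc_client::sample`; actually sending the request requires a
//! client method for the Sample service, which lives outside this module.
//!
//! ```
//! use xai_grpc_client::sample::SampleRequest;
//!
//! let request = SampleRequest::new("grok-2-1212")
//!     .add_prompt("Write a haiku about Rust.")
//!     .with_n(1)
//!     .with_max_tokens(64)
//!     .with_temperature(0.7);
//!
//! assert_eq!(request.prompts.len(), 1);
//! assert_eq!(request.max_tokens, Some(64));
//! ```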

use crate::proto;

/// Request for text sampling
#[derive(Debug, Clone)]
pub struct SampleRequest {
    /// Text prompts to sample from
    pub prompts: Vec<String>,
    /// Model name
    pub model: String,
    /// Number of completions (1-128)
    pub n: Option<i32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<i32>,
    /// Random seed for determinism
    pub seed: Option<i32>,
    /// Stop sequences
    pub stop: Vec<String>,
    /// Temperature (0-2)
    pub temperature: Option<f32>,
    /// Top-p sampling
    pub top_p: Option<f32>,
    /// Frequency penalty (-2 to 2)
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (-2 to 2)
    pub presence_penalty: Option<f32>,
    /// Return log probabilities
    pub logprobs: bool,
    /// Number of top logprobs (0-8)
    pub top_logprobs: Option<i32>,
    /// User identifier
    pub user: Option<String>,
}

impl SampleRequest {
    /// Create a new sample request
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            prompts: Vec::new(),
            model: model.into(),
            n: None,
            max_tokens: None,
            seed: None,
            stop: Vec::new(),
            temperature: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
        }
    }

    /// Add a prompt
    pub fn add_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompts.push(prompt.into());
        self
    }

    /// Set number of completions
    pub fn with_n(mut self, n: i32) -> Self {
        self.n = Some(n);
        self
    }

    /// Set max tokens
    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set temperature
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
}

/// Response from sampling
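///
/// # Examples
///
/// Reading the generated text out of a response. The values here are constructed by
/// hand purely for illustration; in real use the fields are filled in by the
/// `From<proto::SampleTextResponse>` conversion below, and the import path assumes
/// the crate exposes this module as `xai_grpc_client::sample`.
///
/// ```
/// use xai_grpc_client::sample::{SampleChoice, SampleResponse};
///
/// let response = SampleResponse {
///     id: "req-123".to_string(),
///     choices: vec![SampleChoice {
///         index: 0,
///         text: "Hello there!".to_string(),
///         finish_reason: "stop".to_string(),
///     }],
///     model: "grok-2-1212".to_string(),
///     total_tokens: 12,
/// };
///
/// for choice in &response.choices {
///     println!("[{}] {}", choice.finish_reason, choice.text);
/// }
/// ```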
#[derive(Debug, Clone)]
pub struct SampleResponse {
    /// Request ID
    pub id: String,
    /// Generated completions
    pub choices: Vec<SampleChoice>,
    /// Model used
    pub model: String,
    /// Total tokens used
    pub total_tokens: i32,
}

/// A single completion choice
#[derive(Debug, Clone)]
pub struct SampleChoice {
    /// Index of this choice
    pub index: i32,
    /// Generated text
    pub text: String,
    /// Finish reason
    pub finish_reason: String,
}

impl From<proto::SampleTextResponse> for SampleResponse {
    fn from(proto: proto::SampleTextResponse) -> Self {
        Self {
            id: proto.id,
            choices: proto.choices.into_iter().map(Into::into).collect(),
            model: proto.model,
            total_tokens: proto.usage.map(|u| u.total_tokens).unwrap_or(0),
        }
    }
}

impl From<proto::SampleChoice> for SampleChoice {
    fn from(proto: proto::SampleChoice) -> Self {
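        // Map the protobuf finish-reason enum onto the short strings exposed by
        // `SampleChoice::finish_reason`; unrecognized values fall back to "unknown".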
        let finish_reason = match proto::FinishReason::try_from(proto.finish_reason) {
            Ok(proto::FinishReason::ReasonStop) => "stop",
            Ok(proto::FinishReason::ReasonMaxLen) => "length",
            Ok(proto::FinishReason::ReasonMaxContext) => "max_context",
            Ok(proto::FinishReason::ReasonToolCalls) => "tool_calls",
            Ok(proto::FinishReason::ReasonTimeLimit) => "time_limit",
            _ => "unknown",
        };

        Self {
            index: proto.index,
            text: proto.text,
            finish_reason: finish_reason.to_string(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sample_request_builder() {
        let request = SampleRequest::new("grok-2-1212")
            .add_prompt("Hello, world!")
            .add_prompt("How are you?")
            .with_n(3)
            .with_max_tokens(100)
            .with_temperature(0.8);

        assert_eq!(request.model, "grok-2-1212");
        assert_eq!(request.prompts.len(), 2);
        assert_eq!(request.prompts[0], "Hello, world!");
        assert_eq!(request.prompts[1], "How are you?");
        assert_eq!(request.n, Some(3));
        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_sample_request_minimal() {
        let request = SampleRequest::new("grok-beta");

        assert_eq!(request.model, "grok-beta");
        assert_eq!(request.prompts.len(), 0);
        assert_eq!(request.n, None);
        assert_eq!(request.max_tokens, None);
        assert_eq!(request.temperature, None);
        assert!(!request.logprobs);
    }

    #[test]
    fn test_sample_choice_from_proto() {
        let proto_choice = proto::SampleChoice {
            finish_reason: proto::FinishReason::ReasonStop as i32,
            index: 0,
            text: "Hello there!".to_string(),
        };

        let choice: SampleChoice = proto_choice.into();
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello there!");
        assert_eq!(choice.finish_reason, "stop");
    }

    #[test]
    fn test_sample_choice_finish_reasons() {
        let test_cases = vec![
            (proto::FinishReason::ReasonStop, "stop"),
            (proto::FinishReason::ReasonMaxLen, "length"),
            (proto::FinishReason::ReasonMaxContext, "max_context"),
            (proto::FinishReason::ReasonToolCalls, "tool_calls"),
            (proto::FinishReason::ReasonTimeLimit, "time_limit"),
            (proto::FinishReason::ReasonInvalid, "unknown"),
        ];

        for (proto_reason, expected_str) in test_cases {
            let proto_choice = proto::SampleChoice {
                finish_reason: proto_reason as i32,
                index: 0,
                text: "test".to_string(),
            };

            let choice: SampleChoice = proto_choice.into();
            assert_eq!(choice.finish_reason, expected_str);
        }
    }

    #[test]
    fn test_sample_response_from_proto() {
        let proto_response = proto::SampleTextResponse {
            id: "req-123".to_string(),
            choices: vec![
                proto::SampleChoice {
                    finish_reason: proto::FinishReason::ReasonStop as i32,
                    index: 0,
                    text: "First choice".to_string(),
                },
                proto::SampleChoice {
                    finish_reason: proto::FinishReason::ReasonMaxLen as i32,
                    index: 1,
                    text: "Second choice".to_string(),
                },
            ],
            created: None,
            model: "grok-2-1212".to_string(),
            system_fingerprint: "fp_test".to_string(),
            usage: Some(proto::SamplingUsage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
                cached_prompt_text_tokens: 0,
                num_sources_used: 0,
                prompt_image_tokens: 0,
                reasoning_tokens: 0,
                prompt_text_tokens: 10,
                server_side_tools_used: vec![],
            }),
        };

        let response: SampleResponse = proto_response.into();
        assert_eq!(response.id, "req-123");
        assert_eq!(response.model, "grok-2-1212");
        assert_eq!(response.total_tokens, 30);
        assert_eq!(response.choices.len(), 2);
        assert_eq!(response.choices[0].text, "First choice");
        assert_eq!(response.choices[1].text, "Second choice");
    }

    #[test]
    fn test_sample_request_clone() {
        let request = SampleRequest::new("grok-2-1212")
            .add_prompt("Test")
            .with_max_tokens(100);

        let cloned = request.clone();
        assert_eq!(cloned.model, request.model);
        assert_eq!(cloned.prompts, request.prompts);
        assert_eq!(cloned.max_tokens, request.max_tokens);
    }
}