use crate::proto;

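/// A text-sampling request, assembled with chainable builder methods.
///
/// # Examples
///
/// ```ignore
/// // Illustrative sketch (marked `ignore` since the enclosing crate path
/// // is not shown here); the model name is a placeholder.
/// let request = SampleRequest::new("grok-beta")
///     .add_prompt("Hello, world!")
///     .with_max_tokens(100)
///     .with_temperature(0.7);
/// ```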
#[derive(Debug, Clone)]
pub struct SampleRequest {
    /// Prompts to generate completions for.
    pub prompts: Vec<String>,
    /// Name of the model to sample from.
    pub model: String,
    /// Number of completions to generate per prompt.
    pub n: Option<i32>,
    /// Maximum number of tokens to generate per completion.
    pub max_tokens: Option<i32>,
    /// Seed for reproducible sampling.
    pub seed: Option<i32>,
    /// Sequences at which generation stops.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: Option<f32>,
    /// Penalty applied to frequently repeated tokens.
    pub frequency_penalty: Option<f32>,
    /// Penalty applied to tokens that have already appeared.
    pub presence_penalty: Option<f32>,
    /// Whether to return log probabilities for sampled tokens.
    pub logprobs: bool,
    /// Number of most-likely tokens to return log probabilities for.
    pub top_logprobs: Option<i32>,
    /// Opaque end-user identifier.
    pub user: Option<String>,
}

impl SampleRequest {
    /// Creates a request for `model` with every optional field unset.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            prompts: Vec::new(),
            model: model.into(),
            n: None,
            max_tokens: None,
            seed: None,
            stop: Vec::new(),
            temperature: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
        }
    }

    /// Appends a prompt to the request.
    pub fn add_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompts.push(prompt.into());
        self
    }

    /// Sets the number of completions to generate per prompt.
    pub fn with_n(mut self, n: i32) -> Self {
        self.n = Some(n);
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
}
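
// The remaining request fields (seed, stop, penalties, logprobs, user) have
// no builder methods; they are public and can be set directly on the struct.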

/// A response to a [`SampleRequest`].
#[derive(Debug, Clone)]
pub struct SampleResponse {
    /// Identifier the server assigned to this request.
    pub id: String,
    /// One entry per generated completion.
    pub choices: Vec<SampleChoice>,
    /// Model that produced the completions.
    pub model: String,
    /// Total tokens consumed (prompt plus completion).
    pub total_tokens: i32,
}

/// A single completion within a [`SampleResponse`].
#[derive(Debug, Clone)]
pub struct SampleChoice {
    /// Zero-based position of this choice in the response.
    pub index: i32,
    /// The generated text.
    pub text: String,
    /// Why generation stopped, e.g. `stop` or `length`.
    pub finish_reason: String,
}

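// Conversions from the generated proto types into the public types above.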
impl From<proto::SampleTextResponse> for SampleResponse {
    fn from(proto: proto::SampleTextResponse) -> Self {
        Self {
            id: proto.id,
            choices: proto.choices.into_iter().map(Into::into).collect(),
            model: proto.model,
            // A missing usage block is treated as zero tokens consumed.
            total_tokens: proto.usage.map(|u| u.total_tokens).unwrap_or(0),
        }
    }
}

impl From<proto::SampleChoice> for SampleChoice {
    fn from(proto: proto::SampleChoice) -> Self {
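        // Map the proto enum onto human-readable strings; unrecognized values
        // (including ReasonInvalid) fall back to "unknown".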
        let finish_reason = match proto::FinishReason::try_from(proto.finish_reason) {
            Ok(proto::FinishReason::ReasonStop) => "stop",
            Ok(proto::FinishReason::ReasonMaxLen) => "length",
            Ok(proto::FinishReason::ReasonMaxContext) => "max_context",
            Ok(proto::FinishReason::ReasonToolCalls) => "tool_calls",
            Ok(proto::FinishReason::ReasonTimeLimit) => "time_limit",
            _ => "unknown",
        };

        Self {
            index: proto.index,
            text: proto.text,
            finish_reason: finish_reason.to_string(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sample_request_builder() {
        let request = SampleRequest::new("grok-2-1212")
            .add_prompt("Hello, world!")
            .add_prompt("How are you?")
            .with_n(3)
            .with_max_tokens(100)
            .with_temperature(0.8);

        assert_eq!(request.model, "grok-2-1212");
        assert_eq!(request.prompts.len(), 2);
        assert_eq!(request.prompts[0], "Hello, world!");
        assert_eq!(request.prompts[1], "How are you?");
        assert_eq!(request.n, Some(3));
        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_sample_request_minimal() {
        let request = SampleRequest::new("grok-beta");

        assert_eq!(request.model, "grok-beta");
        assert_eq!(request.prompts.len(), 0);
        assert_eq!(request.n, None);
        assert_eq!(request.max_tokens, None);
        assert_eq!(request.temperature, None);
        assert!(!request.logprobs);
    }

    #[test]
    fn test_sample_choice_from_proto() {
        let proto_choice = proto::SampleChoice {
            finish_reason: proto::FinishReason::ReasonStop as i32,
            index: 0,
            text: "Hello there!".to_string(),
        };

        let choice: SampleChoice = proto_choice.into();
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello there!");
        assert_eq!(choice.finish_reason, "stop");
    }

    #[test]
    fn test_sample_choice_finish_reasons() {
        let test_cases = vec![
            (proto::FinishReason::ReasonStop, "stop"),
            (proto::FinishReason::ReasonMaxLen, "length"),
            (proto::FinishReason::ReasonMaxContext, "max_context"),
            (proto::FinishReason::ReasonToolCalls, "tool_calls"),
            (proto::FinishReason::ReasonTimeLimit, "time_limit"),
            (proto::FinishReason::ReasonInvalid, "unknown"),
        ];

        for (proto_reason, expected_str) in test_cases {
            let proto_choice = proto::SampleChoice {
                finish_reason: proto_reason as i32,
                index: 0,
                text: "test".to_string(),
            };

            let choice: SampleChoice = proto_choice.into();
            assert_eq!(choice.finish_reason, expected_str);
        }
    }

    #[test]
    fn test_sample_response_from_proto() {
        let proto_response = proto::SampleTextResponse {
            id: "req-123".to_string(),
            choices: vec![
                proto::SampleChoice {
                    finish_reason: proto::FinishReason::ReasonStop as i32,
                    index: 0,
                    text: "First choice".to_string(),
                },
                proto::SampleChoice {
                    finish_reason: proto::FinishReason::ReasonMaxLen as i32,
                    index: 1,
                    text: "Second choice".to_string(),
                },
            ],
            created: None,
            model: "grok-2-1212".to_string(),
            system_fingerprint: "fp_test".to_string(),
            usage: Some(proto::SamplingUsage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
                cached_prompt_text_tokens: 0,
                num_sources_used: 0,
                prompt_image_tokens: 0,
                reasoning_tokens: 0,
                prompt_text_tokens: 10,
                server_side_tools_used: vec![],
            }),
        };

        let response: SampleResponse = proto_response.into();
        assert_eq!(response.id, "req-123");
        assert_eq!(response.model, "grok-2-1212");
        assert_eq!(response.total_tokens, 30);
        assert_eq!(response.choices.len(), 2);
        assert_eq!(response.choices[0].text, "First choice");
        assert_eq!(response.choices[1].text, "Second choice");
    }

    #[test]
    fn test_sample_request_clone() {
        let request = SampleRequest::new("grok-2-1212")
            .add_prompt("Test")
            .with_max_tokens(100);

        let cloned = request.clone();
        assert_eq!(cloned.model, request.model);
        assert_eq!(cloned.prompts, request.prompts);
        assert_eq!(cloned.max_tokens, request.max_tokens);
    }
}