1use serde::{Deserialize, Serialize};
2use std::fs::File;
3use std::io::{BufRead, BufReader, Lines};
4use std::path::Path;
5
/// Tokenizes a single prompt into model token ids.
///
/// Boxed trait object so different tokenizer backends can be plugged in at
/// runtime; errors are reported as plain strings.
pub type TokenizerFn = Box<dyn Fn(&PromptInput) -> Result<Vec<u32>, String> + Send + Sync>;

/// Tokenizes a batch of prompts in one call, returning one token vector per input.
///
/// NOTE(review): unlike `TokenizerFn`, this alias is only `Send`, not `Sync` —
/// confirm whether that asymmetry is intentional.
pub type BatchTokenizerFn = Box<dyn Fn(&[PromptInput]) -> Result<Vec<Vec<u32>>, String> + Send>;
/// One line of an OpenAI-style batch-API JSONL file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-supplied id used to correlate responses with requests.
    pub custom_id: String,
    /// HTTP method of the wrapped request (e.g. "POST").
    pub method: String,
    /// Target endpoint path (e.g. "/v1/chat/completions").
    pub url: String,
    /// The request payload itself.
    pub body: RequestBody,
}
24
/// The inner request payload of a batch line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    /// Target model name.
    pub model: String,
    /// Optional cap on generated tokens; a missing field deserializes to `None`.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling cutoff.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Chat messages or a raw completion prompt; `flatten` reads the variant's
    /// fields (`messages` / `prompt`) from the same JSON object as the fields above.
    #[serde(flatten)]
    pub input: RequestInput,
}
37
/// Either a chat-style or a completion-style request payload.
///
/// With `untagged`, serde tries variants in declaration order and picks the
/// first whose fields match: a `messages` array selects `Chat`, a `prompt`
/// string selects `Completion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestInput {
    Chat { messages: Vec<Message> },
    Completion { prompt: String },
}
44
/// A single chat message: a role (e.g. "system", "user") plus its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}
50
/// Tokenizer input extracted from a request: either a chat message list or a
/// raw completion prompt.
#[derive(Debug, Clone)]
pub enum PromptInput {
    Messages(Vec<Message>),
    Prompt(String),
}
56
/// A fully tokenized dataset entry, ready for scheduling/batching.
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// `custom_id` of the originating batch request.
    pub request_id: String,
    /// Token ids produced by tokenizing the prompt.
    pub prompt_tokens: Vec<u32>,
    /// `max_tokens` from the request body, if present.
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Number of tokens in the tokenized prompt.
    ///
    /// # Panics
    ///
    /// Panics if the prompt holds more than `u32::MAX` tokens. (The previous
    /// `as u32` cast would silently truncate such a count; failing loudly is
    /// preferable to a wrong token budget.)
    pub fn num_prompt_tokens(&self) -> u32 {
        u32::try_from(self.prompt_tokens.len())
            .expect("prompt token count exceeds u32::MAX")
    }
}
70
/// A request pulled from the file but not yet tokenized.
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// `custom_id` of the originating batch request.
    pub request_id: String,
    /// Prompt content awaiting tokenization.
    pub prompt_input: PromptInput,
    /// `max_tokens` from the request body, if present.
    pub max_output_tokens: Option<u32>,
}
78
/// Streaming iterator over the entries of a JSONL batch file.
///
/// Items are `Ok(Some(entry))` for each parsed request, then a single
/// `Ok(None)` end-of-stream marker, then `None`.
pub struct DatasetIterator<R: BufRead> {
    // Underlying line iterator over the reader.
    lines: Lines<R>,
    // 1-based line counter used in parse-error messages.
    line_num: usize,
    // Whether the `Ok(None)` end marker has already been yielded.
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wraps a buffered reader; nothing is read until iteration begins.
    pub fn new(reader: R) -> Self {
        let lines = reader.lines();
        DatasetIterator {
            lines,
            line_num: 0,
            sent_end_signal: false,
        }
    }
}
96
97impl<R: BufRead> Iterator for DatasetIterator<R> {
98 type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
99
100 fn next(&mut self) -> Option<Self::Item> {
101 loop {
102 self.line_num += 1;
103 let line = match self.lines.next() {
104 Some(Ok(line)) => line,
105 Some(Err(e)) => return Some(Err(Box::new(e))),
106 None => {
107 if !self.sent_end_signal {
109 self.sent_end_signal = true;
110 return Some(Ok(None));
111 } else {
112 return None;
113 }
114 }
115 };
116
117 if line.trim().is_empty() {
119 continue;
120 }
121
122 let batch_request: BatchRequest = match serde_json::from_str(&line) {
124 Ok(req) => req,
125 Err(e) => {
126 return Some(Err(format!(
127 "Failed to parse line {}: {}",
128 self.line_num, e
129 )
130 .into()))
131 }
132 };
133
134 return Some(Ok(Some(UnparsedEntry {
135 request_id: batch_request.custom_id,
136 prompt_input: match batch_request.body.input {
137 RequestInput::Chat { messages } => PromptInput::Messages(messages),
138 RequestInput::Completion { prompt } => PromptInput::Prompt(prompt),
139 },
140 max_output_tokens: batch_request.body.max_tokens,
141 })));
142 }
143 }
144}
145
146pub struct DatasetLoader;
148
149impl DatasetLoader {
150 fn count_non_empty_lines<R: BufRead>(reader: R) -> Result<usize, std::io::Error> {
151 let mut count = 0;
152 for line in reader.lines() {
153 if !line?.trim().is_empty() {
154 count += 1;
155 }
156 }
157 Ok(count)
158 }
159
160 pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
162 let file = File::open(path)?;
163 let reader = BufReader::new(file);
164 Self::count_non_empty_lines(reader)
165 }
166
167 pub fn from_file<P: AsRef<Path>>(
170 path: P,
171 ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
172 let file = File::open(path)?;
173 let reader = BufReader::new(file);
174 Ok(DatasetIterator::new(reader))
175 }
176
177 pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
180 let cursor = std::io::Cursor::new(data);
181 let reader = BufReader::new(cursor);
182 DatasetIterator::new(reader)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Read};

    // A chat-style batch line parses into `RequestInput::Chat` with its
    // messages and max_tokens intact.
    #[test]
    fn test_parse_batch_request() {
        let json = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-1");
        match batch_request.body.input {
            RequestInput::Chat { messages } => assert_eq!(messages.len(), 2),
            RequestInput::Completion { .. } => panic!("expected chat request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(100));
    }

    // The iterator yields each JSONL entry in order, then a single Ok(None)
    // end-of-stream marker, then None on every subsequent call.
    #[test]
    fn test_dataset_iterator() {
        let test_data = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut iter = DatasetLoader::from_string(test_data.to_string());

        let entry1 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry1.request_id, "req-1");
        match entry1.prompt_input {
            PromptInput::Messages(messages) => {
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].role, "user");
                assert_eq!(messages[0].content, "Hello");
            }
            PromptInput::Prompt(_) => panic!("expected chat prompt input"),
        }
        assert_eq!(entry1.max_output_tokens, Some(10));

        let entry2 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry2.request_id, "req-2");
        assert_eq!(entry2.max_output_tokens, Some(20));

        // Explicit end-of-stream signal: Ok(None) exactly once.
        let end_signal = iter.next().unwrap().unwrap();
        assert!(end_signal.is_none());

        // After the signal, the iterator is fused to None.
        assert!(iter.next().is_none());
    }

    // A completion-style body (prompt instead of messages) selects the
    // untagged `Completion` variant.
    #[test]
    fn test_parse_completion_batch_request() {
        let json = r#"{
            "custom_id": "request-2",
            "method": "POST",
            "url": "/v1/completions",
            "body": {
                "model": "gpt-3.5-turbo-instruct",
                "prompt": "Hello, world",
                "max_tokens": 32
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-2");
        match batch_request.body.input {
            RequestInput::Completion { prompt } => assert_eq!(prompt, "Hello, world"),
            RequestInput::Chat { .. } => panic!("expected completion request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(32));
    }

    // I/O errors from the underlying reader must propagate out of
    // count_non_empty_lines rather than being swallowed.
    #[test]
    fn test_count_entries_propagates_read_errors() {
        // Reader that serves bytes normally until `fail_after` bytes have
        // been consumed, then returns an error on every subsequent read.
        struct FailingReader {
            bytes: Vec<u8>,
            pos: usize,
            fail_after: usize,
        }

        impl Read for FailingReader {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                if self.pos >= self.fail_after {
                    return Err(io::Error::other("forced read failure"));
                }

                // Serve at most up to the failure point and the end of data.
                let remaining_until_failure = self.fail_after.saturating_sub(self.pos);
                let remaining_bytes = self.bytes.len().saturating_sub(self.pos);
                let to_copy = buf.len().min(remaining_until_failure).min(remaining_bytes);

                if to_copy == 0 {
                    return Ok(0);
                }

                buf[..to_copy].copy_from_slice(&self.bytes[self.pos..self.pos + to_copy]);
                self.pos += to_copy;
                Ok(to_copy)
            }
        }

        // fail_after = 7 lets "first\ns" through, so the failure lands
        // mid-way through the second line.
        let reader = BufReader::new(FailingReader {
            bytes: b"first\nsecond\n".to_vec(),
            pos: 0,
            fail_after: 7,
        });

        let err = DatasetLoader::count_non_empty_lines(reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }
}
306}