use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fs::File;
use std::io::{BufRead, BufReader, Lines};
use std::path::Path;
5
/// Boxed tokenizer callback: converts one [`PromptInput`] into token ids,
/// or a `String` error message. `Send + Sync` so a single boxed closure can
/// be shared across threads.
pub type TokenizerFn = Box<dyn Fn(&PromptInput) -> Result<Vec<u32>, String> + Send + Sync>;
11
/// Boxed batch tokenizer callback: tokenizes a whole slice of prompts in one
/// call, returning one token vector per prompt.
/// NOTE(review): this is only `Send`, while `TokenizerFn` is `Send + Sync` —
/// confirm whether the asymmetry is intentional.
pub type BatchTokenizerFn = Box<dyn Fn(&[PromptInput]) -> Result<Vec<Vec<u32>>, String> + Send>;
15
/// One line of an OpenAI-style batch JSONL file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-assigned id used to correlate responses with requests.
    pub custom_id: String,
    /// HTTP method of the request, e.g. "POST".
    pub method: String,
    /// Endpoint path, e.g. "/v1/chat/completions" or "/v1/completions".
    pub url: String,
    /// The request payload.
    pub body: RequestBody,
}
24
/// The `body` object of a batch request line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    /// Target model name, e.g. "gpt-3.5-turbo".
    pub model: String,
    /// Optional cap on generated tokens; missing field deserializes to `None`.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; missing field deserializes to `None`.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling cutoff; missing field deserializes to `None`.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Chat `messages` or completion `prompt`; `flatten` places these keys at
    /// the same JSON level as the fields above.
    #[serde(flatten)]
    pub input: RequestInput,
}
37
/// The prompt payload of a request body.
///
/// `untagged`: serde tries the variants in declaration order, so an object
/// with a `messages` key parses as `Chat` and one with a `prompt` key as
/// `Completion`. Keep `Chat` first — variant order matters here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestInput {
    /// Chat-style request (`messages` array).
    Chat { messages: Vec<Message> },
    /// Plain-completion request (single `prompt` string).
    Completion { prompt: String },
}
44
/// A single chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role, e.g. "system" or "user".
    pub role: String,
    /// The message text.
    pub content: String,
}
50
/// Tokenizer input extracted from a batch request: either a chat transcript
/// or a raw completion prompt.
#[derive(Debug, Clone)]
pub enum PromptInput {
    /// Messages from a `RequestInput::Chat` request.
    Messages(Vec<Message>),
    /// Prompt string from a `RequestInput::Completion` request.
    Prompt(String),
}
56
/// A fully tokenized dataset entry.
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// The `custom_id` of the originating batch request.
    pub request_id: String,
    /// Token ids for the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Per-request output-token cap (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Returns the number of prompt tokens as a `u32`.
    ///
    /// # Panics
    /// Panics if the prompt holds more than `u32::MAX` tokens. The previous
    /// `as u32` cast would have silently truncated in that case.
    pub fn num_prompt_tokens(&self) -> u32 {
        u32::try_from(self.prompt_tokens.len())
            .expect("prompt token count exceeds u32::MAX")
    }
}
70
/// A dataset entry whose prompt has not yet been tokenized.
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// The `custom_id` of the originating batch request.
    pub request_id: String,
    /// The prompt content, pending tokenization.
    pub prompt_input: PromptInput,
    /// Per-request output-token cap (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}
78
/// Streaming iterator over the lines of a JSONL batch-request source.
///
/// Tracks the current line number for error messages and whether the
/// one-shot end-of-stream marker has already been emitted.
pub struct DatasetIterator<R: BufRead> {
    lines: Lines<R>,
    line_num: usize,
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wraps `reader` in a fresh iterator: line counter at zero, end-of-stream
    /// marker not yet sent.
    pub fn new(reader: R) -> Self {
        DatasetIterator {
            sent_end_signal: false,
            line_num: 0,
            lines: reader.lines(),
        }
    }
}
96
97impl<R: BufRead> Iterator for DatasetIterator<R> {
98 type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
99
100 fn next(&mut self) -> Option<Self::Item> {
101 loop {
102 self.line_num += 1;
103 let line = match self.lines.next() {
104 Some(Ok(line)) => line,
105 Some(Err(e)) => return Some(Err(Box::new(e))),
106 None => {
107 if !self.sent_end_signal {
109 self.sent_end_signal = true;
110 return Some(Ok(None));
111 } else {
112 return None;
113 }
114 }
115 };
116
117 if line.trim().is_empty() {
119 continue;
120 }
121
122 let batch_request: BatchRequest = match serde_json::from_str(&line) {
124 Ok(req) => req,
125 Err(e) => {
126 return Some(Err(format!(
127 "Failed to parse line {}: {}",
128 self.line_num, e
129 )
130 .into()))
131 }
132 };
133
134 return Some(Ok(Some(UnparsedEntry {
135 request_id: batch_request.custom_id,
136 prompt_input: match batch_request.body.input {
137 RequestInput::Chat { messages } => PromptInput::Messages(messages),
138 RequestInput::Completion { prompt } => PromptInput::Prompt(prompt),
139 },
140 max_output_tokens: batch_request.body.max_tokens,
141 })));
142 }
143 }
144}
145
/// Namespace for constructing dataset iterators from files or in-memory data.
pub struct DatasetLoader;
148
149impl DatasetLoader {
150 pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
152 let file = File::open(path)?;
153 let reader = BufReader::new(file);
154 let count = reader
155 .lines()
156 .filter_map(|line| line.ok())
157 .filter(|line| !line.trim().is_empty())
158 .count();
159 Ok(count)
160 }
161
162 pub fn from_file<P: AsRef<Path>>(
165 path: P,
166 ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
167 let file = File::open(path)?;
168 let reader = BufReader::new(file);
169 Ok(DatasetIterator::new(reader))
170 }
171
172 pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
175 let cursor = std::io::Cursor::new(data);
176 let reader = BufReader::new(cursor);
177 DatasetIterator::new(reader)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    // A chat-style body (containing `messages`) must deserialize into
    // `RequestInput::Chat` through the untagged enum.
    #[test]
    fn test_parse_batch_request() {
        let json = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-1");
        match batch_request.body.input {
            RequestInput::Chat { messages } => assert_eq!(messages.len(), 2),
            RequestInput::Completion { .. } => panic!("expected chat request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(100));
    }

    // End-to-end pass over a two-line JSONL string: entries stream in order,
    // then the iterator yields its one-shot Ok(None) end marker, then None.
    #[test]
    fn test_dataset_iterator() {
        let test_data = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut iter = DatasetLoader::from_string(test_data.to_string());

        let entry1 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry1.request_id, "req-1");
        match entry1.prompt_input {
            PromptInput::Messages(messages) => {
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].role, "user");
                assert_eq!(messages[0].content, "Hello");
            }
            PromptInput::Prompt(_) => panic!("expected chat prompt input"),
        }
        assert_eq!(entry1.max_output_tokens, Some(10));

        let entry2 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry2.request_id, "req-2");
        assert_eq!(entry2.max_output_tokens, Some(20));

        // End-of-stream marker: Some(Ok(None)) is produced exactly once...
        let end_signal = iter.next().unwrap().unwrap();
        assert!(end_signal.is_none());

        // ...and afterwards the iterator is fully exhausted.
        assert!(iter.next().is_none());
    }

    // A completion-style body (containing `prompt`) must deserialize into
    // `RequestInput::Completion` through the untagged enum.
    #[test]
    fn test_parse_completion_batch_request() {
        let json = r#"{
            "custom_id": "request-2",
            "method": "POST",
            "url": "/v1/completions",
            "body": {
                "model": "gpt-3.5-turbo-instruct",
                "prompt": "Hello, world",
                "max_tokens": 32
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-2");
        match batch_request.body.input {
            RequestInput::Completion { prompt } => assert_eq!(prompt, "Hello, world"),
            RequestInput::Chat { .. } => panic!("expected completion request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(32));
    }
}