1use serde::{Deserialize, Serialize};
2use std::fs::File;
3use std::io::{BufRead, BufReader, Lines};
4use std::path::Path;
5
/// Boxed callback that maps a slice of [`Message`]s to a `Vec<u32>`, or to a
/// `String` describing why it failed.
///
/// The `Send + Sync` bounds allow the boxed closure to be moved to, and shared
/// between, threads.
///
/// NOTE(review): the name suggests the returned `u32`s are token ids for the
/// whole message list — confirm against the closures that are plugged in.
pub type TokenizerFn = Box<dyn Fn(&[Message]) -> Result<Vec<u32>, String> + Send + Sync>;
11
/// Boxed callback that receives several message slices in one call and returns
/// a `Vec<Vec<u32>>`, or a `String` describing why it failed.
///
/// NOTE(review): presumably the result holds one inner vector per input slice,
/// in the same order — confirm against the closures that are plugged in.
///
/// NOTE(review): unlike [`TokenizerFn`], this alias requires `Send` but not
/// `Sync` — confirm that the difference is intentional.
pub type BatchTokenizerFn = Box<dyn Fn(&[&[Message]]) -> Result<Vec<Vec<u32>>, String> + Send>;
15
/// One record of the line-delimited JSON input.
///
/// [`DatasetIterator`] deserializes every non-blank input line into this type.
/// All four fields are required by the derived `Deserialize` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-chosen identifier; copied into `UnparsedEntry::request_id`.
    pub custom_id: String,
    /// Request method as a plain string (the tests in this file use `"POST"`).
    /// Parsed but not otherwise read in this file.
    pub method: String,
    /// Request target as a plain string (the tests in this file use
    /// `"/v1/chat/completions"`). Parsed but not otherwise read in this file.
    pub url: String,
    /// The request payload: model name, messages and optional sampling settings.
    pub body: RequestBody,
}
24
/// Payload of a [`BatchRequest`].
///
/// `model` and `messages` are required; the three optional fields are marked
/// `#[serde(default)]`, so they deserialize to `None` when absent from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    /// Model name as given in the input. Not read elsewhere in this file.
    pub model: String,
    /// The conversation messages; copied into `UnparsedEntry::messages`.
    pub messages: Vec<Message>,
    /// Optional output-token limit; copied into
    /// `UnparsedEntry::max_output_tokens`.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional `temperature` value. Not read elsewhere in this file.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional `top_p` value. Not read elsewhere in this file.
    #[serde(default)]
    pub top_p: Option<f32>,
}
36
/// A single chat message inside a [`RequestBody`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role as a free-form string (the tests in this file use
    /// `"system"` and `"user"`); no validation is performed here.
    pub role: String,
    /// The message text.
    pub content: String,
}
42
/// A request whose prompt is stored as a vector of `u32` values.
///
/// NOTE(review): presumably the tokenized counterpart of `UnparsedEntry`;
/// nothing in the visible part of this file constructs it — confirm with the
/// code that does.
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Identifier of the request this entry belongs to.
    pub request_id: String,
    /// The prompt, as a sequence of `u32` values.
    pub prompt_tokens: Vec<u32>,
    /// Optional limit on the number of output tokens.
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Returns the length of `prompt_tokens` as a `u32`.
    ///
    /// The conversion is a plain `as` cast, so a length above `u32::MAX`
    /// would be truncated rather than reported.
    pub fn num_prompt_tokens(&self) -> u32 {
        let token_count: usize = self.prompt_tokens.len();
        token_count as u32
    }
}
56
/// A request as read from the input, still carrying its raw messages.
///
/// [`DatasetIterator`] builds one of these from each successfully parsed
/// [`BatchRequest`].
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// Copied from `BatchRequest::custom_id`.
    pub request_id: String,
    /// Copied from `RequestBody::messages`.
    pub messages: Vec<Message>,
    /// Copied from `RequestBody::max_tokens`; `None` when the input omitted it.
    pub max_output_tokens: Option<u32>,
}
64
/// Line-by-line reader over a buffered input source.
///
/// Holds the underlying line iterator plus two pieces of bookkeeping that the
/// `Iterator` implementation updates as it goes.
pub struct DatasetIterator<R: BufRead> {
    /// Remaining lines of the wrapped reader.
    lines: Lines<R>,
    /// Number of read attempts made so far; starts at zero.
    line_num: usize,
    /// Whether the one-off end-of-input marker has already been handed out.
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wraps `reader`, starting with a zero line counter and the end marker
    /// not yet sent.
    pub fn new(reader: R) -> Self {
        let lines = reader.lines();
        DatasetIterator {
            lines,
            line_num: 0,
            sent_end_signal: false,
        }
    }
}
82
83impl<R: BufRead> Iterator for DatasetIterator<R> {
84 type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
85
86 fn next(&mut self) -> Option<Self::Item> {
87 loop {
88 self.line_num += 1;
89 let line = match self.lines.next() {
90 Some(Ok(line)) => line,
91 Some(Err(e)) => return Some(Err(Box::new(e))),
92 None => {
93 if !self.sent_end_signal {
95 self.sent_end_signal = true;
96 return Some(Ok(None));
97 } else {
98 return None;
99 }
100 }
101 };
102
103 if line.trim().is_empty() {
105 continue;
106 }
107
108 let batch_request: BatchRequest = match serde_json::from_str(&line) {
110 Ok(req) => req,
111 Err(e) => {
112 return Some(Err(format!(
113 "Failed to parse line {}: {}",
114 self.line_num, e
115 )
116 .into()))
117 }
118 };
119
120 return Some(Ok(Some(UnparsedEntry {
121 request_id: batch_request.custom_id,
122 messages: batch_request.body.messages,
123 max_output_tokens: batch_request.body.max_tokens,
124 })));
125 }
126 }
127}
128
129pub struct DatasetLoader;
131
132impl DatasetLoader {
133 pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
135 let file = File::open(path)?;
136 let reader = BufReader::new(file);
137 let count = reader
138 .lines()
139 .filter_map(|line| line.ok())
140 .filter(|line| !line.trim().is_empty())
141 .count();
142 Ok(count)
143 }
144
145 pub fn from_file<P: AsRef<Path>>(
148 path: P,
149 ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
150 let file = File::open(path)?;
151 let reader = BufReader::new(file);
152 Ok(DatasetIterator::new(reader))
153 }
154
155 pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
158 let cursor = std::io::Cursor::new(data);
159 let reader = BufReader::new(cursor);
160 DatasetIterator::new(reader)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// A pretty-printed request object deserializes into `BatchRequest`.
    #[test]
    fn test_parse_batch_request() {
        let payload = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let parsed = serde_json::from_str::<BatchRequest>(payload).unwrap();

        assert_eq!(parsed.custom_id, "request-1");
        assert_eq!(parsed.body.messages.len(), 2);
        assert_eq!(parsed.body.max_tokens, Some(100));
    }

    /// Two JSONL lines yield two entries, then the end marker, then `None`.
    #[test]
    fn test_dataset_iterator() {
        let jsonl = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut entries = DatasetLoader::from_string(String::from(jsonl));

        // First line.
        let first = entries.next().unwrap().unwrap().unwrap();
        assert_eq!(first.request_id, "req-1");
        assert_eq!(first.messages.len(), 1);
        assert_eq!(first.messages[0].role, "user");
        assert_eq!(first.messages[0].content, "Hello");
        assert_eq!(first.max_output_tokens, Some(10));

        // Second line.
        let second = entries.next().unwrap().unwrap().unwrap();
        assert_eq!(second.request_id, "req-2");
        assert_eq!(second.max_output_tokens, Some(20));

        // Exhausted input: one `Ok(None)` marker is produced first...
        let marker = entries.next().unwrap().unwrap();
        assert!(marker.is_none());

        // ...and after that the iterator is finished.
        assert!(entries.next().is_none());
    }
}