//! `inference_lab/dataset.rs`
//!
//! Loading and iteration over OpenAI Batch API JSONL datasets.

1use serde::{Deserialize, Serialize};
2use std::fs::File;
3use std::io::{BufRead, BufReader, Lines};
4use std::path::Path;
5
/// A tokenizer function that takes either chat messages or a raw prompt and returns tokenized output.
/// This allows different implementations (tiktoken, transformers.js, etc.)
/// to be passed in from the CLI or WASM interface.
/// The tokenizer should apply the appropriate chat template for chat-style requests.
pub type TokenizerFn = Box<dyn Fn(&PromptInput) -> Result<Vec<u32>, String> + Send + Sync>;

/// A batch tokenizer function that takes multiple prompt inputs and returns multiple token vectors.
/// This is much faster than tokenizing one at a time.
/// NOTE(review): unlike [`TokenizerFn`], this alias requires only `Send`, not `Sync` — confirm intentional.
pub type BatchTokenizerFn = Box<dyn Fn(&[PromptInput]) -> Result<Vec<Vec<u32>>, String> + Send>;
15
/// OpenAI Batch API format - JSONL entries
///
/// One line of a batch JSONL file: a caller-supplied id, an HTTP method and
/// endpoint url, and the request body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-supplied identifier; carried through to [`UnparsedEntry::request_id`].
    pub custom_id: String,
    /// HTTP method (e.g. "POST"); parsed but not otherwise validated here.
    pub method: String,
    /// Endpoint path (e.g. "/v1/chat/completions"); parsed but not validated here.
    pub url: String,
    /// The request payload: model, sampling parameters, and the prompt input.
    pub body: RequestBody,
}
24
/// Body of a batch request: model name, optional sampling parameters, and the
/// prompt input (chat messages or a raw completion prompt, flattened inline).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    pub model: String,
    /// Cap on generated tokens; `None` when absent from the JSON.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Sampling temperature; parsed but unused by the dataset pipeline visible in this file.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter; parsed but unused by the dataset pipeline visible in this file.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Either `messages` (chat) or `prompt` (completion), captured from the
    /// same JSON object via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub input: RequestInput,
}
37
/// The prompt payload of a request. With `#[serde(untagged)]`, variants are
/// distinguished by their fields and tried in declaration order: an object
/// containing `messages` deserializes as `Chat`, one containing `prompt` as
/// `Completion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestInput {
    /// Chat-style request: a list of role/content messages.
    Chat { messages: Vec<Message> },
    /// Completion-style request: a single raw prompt string.
    Completion { prompt: String },
}
44
/// A single chat message: a role (e.g. "system", "user") and its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}
50
/// Input handed to a tokenizer: either chat messages (for which the tokenizer
/// is expected to apply a chat template — see [`TokenizerFn`]) or a raw
/// completion prompt.
#[derive(Debug, Clone)]
pub enum PromptInput {
    Messages(Vec<Message>),
    Prompt(String),
}
56
/// A processed dataset entry ready for simulation
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Identifier copied from the source request's `custom_id`.
    pub request_id: String,
    /// The tokenized prompt.
    pub prompt_tokens: Vec<u32>,
    /// Optional cap on generated tokens (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Number of tokens in the prompt.
    ///
    /// # Panics
    /// Panics if the prompt holds more than `u32::MAX` tokens. The previous
    /// `as` cast silently truncated in that (pathological) case; failing
    /// loudly surfaces the bug instead of corrupting the count.
    pub fn num_prompt_tokens(&self) -> u32 {
        u32::try_from(self.prompt_tokens.len())
            .expect("prompt token count exceeds u32::MAX")
    }
}
70
/// Unparsed entry from dataset (before tokenization)
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// Identifier copied from the source request's `custom_id`.
    pub request_id: String,
    /// The raw prompt (chat messages or completion string), not yet tokenized.
    pub prompt_input: PromptInput,
    /// Optional cap on generated tokens (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}
78
/// Iterator over dataset entries, parsing JSON but NOT tokenizing.
/// Tokenization happens in batches in the background thread for performance.
pub struct DatasetIterator<R: BufRead> {
    /// Line iterator over the underlying reader.
    lines: Lines<R>,
    /// 1-based line counter, used in parse-error messages.
    line_num: usize,
    /// Whether the one-time `Ok(None)` end-of-dataset signal has been emitted.
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wrap a buffered reader, positioned before the first line and with the
    /// end-of-dataset signal not yet emitted.
    pub fn new(reader: R) -> Self {
        let lines = reader.lines();
        DatasetIterator {
            lines,
            line_num: 0,
            sent_end_signal: false,
        }
    }
}
96
97impl<R: BufRead> Iterator for DatasetIterator<R> {
98    type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
99
100    fn next(&mut self) -> Option<Self::Item> {
101        loop {
102            self.line_num += 1;
103            let line = match self.lines.next() {
104                Some(Ok(line)) => line,
105                Some(Err(e)) => return Some(Err(Box::new(e))),
106                None => {
107                    // End of dataset - signal completion with Ok(None) once, then end iterator
108                    if !self.sent_end_signal {
109                        self.sent_end_signal = true;
110                        return Some(Ok(None));
111                    } else {
112                        return None;
113                    }
114                }
115            };
116
117            // Skip empty lines
118            if line.trim().is_empty() {
119                continue;
120            }
121
122            // Parse the batch request (but don't tokenize yet - that happens in batches)
123            let batch_request: BatchRequest = match serde_json::from_str(&line) {
124                Ok(req) => req,
125                Err(e) => {
126                    return Some(Err(format!(
127                        "Failed to parse line {}: {}",
128                        self.line_num, e
129                    )
130                    .into()))
131                }
132            };
133
134            return Some(Ok(Some(UnparsedEntry {
135                request_id: batch_request.custom_id,
136                prompt_input: match batch_request.body.input {
137                    RequestInput::Chat { messages } => PromptInput::Messages(messages),
138                    RequestInput::Completion { prompt } => PromptInput::Prompt(prompt),
139                },
140                max_output_tokens: batch_request.body.max_tokens,
141            })));
142        }
143    }
144}
145
146/// Dataset loader that provides lazy iteration over entries
147pub struct DatasetLoader;
148
149impl DatasetLoader {
150    fn count_non_empty_lines<R: BufRead>(reader: R) -> Result<usize, std::io::Error> {
151        let mut count = 0;
152        for line in reader.lines() {
153            if !line?.trim().is_empty() {
154                count += 1;
155            }
156        }
157        Ok(count)
158    }
159
160    /// Count the number of non-empty lines in a JSONL file (fast approximation of entry count)
161    pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
162        let file = File::open(path)?;
163        let reader = BufReader::new(file);
164        Self::count_non_empty_lines(reader)
165    }
166
167    /// Create an iterator from a JSONL file in OpenAI batch API format
168    /// Returns unparsed entries (without tokenization)
169    pub fn from_file<P: AsRef<Path>>(
170        path: P,
171    ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
172        let file = File::open(path)?;
173        let reader = BufReader::new(file);
174        Ok(DatasetIterator::new(reader))
175    }
176
177    /// Create an iterator from a string (useful for testing or WASM)
178    /// Returns unparsed entries (without tokenization)
179    pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
180        let cursor = std::io::Cursor::new(data);
181        let reader = BufReader::new(cursor);
182        DatasetIterator::new(reader)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Read};

    // A body with a `messages` array must deserialize as RequestInput::Chat
    // (untagged enum matching on field names).
    #[test]
    fn test_parse_batch_request() {
        let json = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-1");
        match batch_request.body.input {
            RequestInput::Chat { messages } => assert_eq!(messages.len(), 2),
            RequestInput::Completion { .. } => panic!("expected chat request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(100));
    }

    // End-to-end iteration over a two-line JSONL string: two entries, then the
    // one-time Ok(None) end-of-dataset signal, then iterator exhaustion.
    #[test]
    fn test_dataset_iterator() {
        let test_data = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut iter = DatasetLoader::from_string(test_data.to_string());

        let entry1 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry1.request_id, "req-1");
        match entry1.prompt_input {
            PromptInput::Messages(messages) => {
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].role, "user");
                assert_eq!(messages[0].content, "Hello");
            }
            PromptInput::Prompt(_) => panic!("expected chat prompt input"),
        }
        assert_eq!(entry1.max_output_tokens, Some(10));

        let entry2 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry2.request_id, "req-2");
        assert_eq!(entry2.max_output_tokens, Some(20));

        // Should get Ok(None) signaling end of dataset
        let end_signal = iter.next().unwrap().unwrap();
        assert!(end_signal.is_none());

        // After that, iterator itself should be exhausted
        assert!(iter.next().is_none());
    }

    // A body with a `prompt` string must deserialize as RequestInput::Completion.
    #[test]
    fn test_parse_completion_batch_request() {
        let json = r#"{
            "custom_id": "request-2",
            "method": "POST",
            "url": "/v1/completions",
            "body": {
                "model": "gpt-3.5-turbo-instruct",
                "prompt": "Hello, world",
                "max_tokens": 32
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-2");
        match batch_request.body.input {
            RequestInput::Completion { prompt } => assert_eq!(prompt, "Hello, world"),
            RequestInput::Chat { .. } => panic!("expected completion request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(32));
    }

    // count_non_empty_lines must surface I/O errors from the underlying reader
    // rather than swallowing them.
    #[test]
    fn test_count_entries_propagates_read_errors() {
        // Reader that serves `bytes` normally until `fail_after` bytes have
        // been read, then errors on every subsequent read.
        struct FailingReader {
            bytes: Vec<u8>,
            pos: usize,
            fail_after: usize,
        }

        impl Read for FailingReader {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                if self.pos >= self.fail_after {
                    return Err(io::Error::other("forced read failure"));
                }

                // Serve at most up to the failure point and the end of the data.
                let remaining_until_failure = self.fail_after.saturating_sub(self.pos);
                let remaining_bytes = self.bytes.len().saturating_sub(self.pos);
                let to_copy = buf.len().min(remaining_until_failure).min(remaining_bytes);

                if to_copy == 0 {
                    return Ok(0);
                }

                buf[..to_copy].copy_from_slice(&self.bytes[self.pos..self.pos + to_copy]);
                self.pos += to_copy;
                Ok(to_copy)
            }
        }

        // fail_after = 7 lands mid-way through the second line ("second\n").
        let reader = BufReader::new(FailingReader {
            bytes: b"first\nsecond\n".to_vec(),
            pos: 0,
            fail_after: 7,
        });

        let err = DatasetLoader::count_non_empty_lines(reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }
}