Skip to main content

inference_lab/
dataset.rs

1use serde::{Deserialize, Serialize};
2use std::fs::File;
3use std::io::{BufRead, BufReader, Lines};
4use std::path::Path;
5
/// A tokenizer function that takes messages and returns tokenized output.
/// This allows different implementations (tiktoken, transformers.js, etc.)
/// to be passed in from the CLI or WASM interface.
/// The tokenizer should apply the appropriate chat template.
pub type TokenizerFn = Box<dyn Fn(&[Message]) -> Result<Vec<u32>, String> + Send + Sync>;

/// A batch tokenizer function that takes multiple message arrays and returns multiple token vectors.
/// This is much faster than tokenizing one at a time.
// NOTE(review): this alias is only `Send`, while `TokenizerFn` above is
// `Send + Sync` — confirm the asymmetry is intentional.
pub type BatchTokenizerFn = Box<dyn Fn(&[&[Message]]) -> Result<Vec<Vec<u32>>, String> + Send>;
15
/// OpenAI Batch API format - JSONL entries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-supplied correlation id; carried downstream as the entry's `request_id`.
    pub custom_id: String,
    /// HTTP method of the batched call (e.g. "POST").
    pub method: String,
    /// Endpoint path (e.g. "/v1/chat/completions").
    pub url: String,
    /// The chat-completion request payload.
    pub body: RequestBody,
}
24
/// Body of a batched chat-completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    /// Target model name (e.g. "gpt-3.5-turbo").
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// Optional cap on generated tokens; `None` when absent from the JSON.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; `None` when absent from the JSON.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling parameter; `None` when absent from the JSON.
    #[serde(default)]
    pub top_p: Option<f32>,
}
36
/// A single chat message: a role plus its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message role, e.g. "system" or "user".
    pub role: String,
    /// Message text.
    pub content: String,
}
42
/// A processed dataset entry ready for simulation
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Identifier copied from the batch request's `custom_id`.
    pub request_id: String,
    /// Token ids of the already-tokenized prompt.
    pub prompt_tokens: Vec<u32>,
    /// Optional cap on generated tokens (from `max_tokens` in the request body).
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Number of prompt tokens in this entry.
    ///
    /// # Panics
    /// Panics if the prompt holds more than `u32::MAX` tokens. The previous
    /// `as u32` cast would silently truncate in that (pathological) case;
    /// a loud failure is preferable to a wrong token count.
    pub fn num_prompt_tokens(&self) -> u32 {
        u32::try_from(self.prompt_tokens.len())
            .expect("prompt token count exceeds u32::MAX")
    }
}
56
/// Unparsed entry from dataset (before tokenization)
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// The batch request's `custom_id`.
    pub request_id: String,
    /// Messages still awaiting tokenization (see `TokenizerFn`/`BatchTokenizerFn`).
    pub messages: Vec<Message>,
    /// Optional output-token cap carried over from the request's `max_tokens`.
    pub max_output_tokens: Option<u32>,
}
64
/// Iterator over dataset entries, parsing JSON but NOT tokenizing.
/// Tokenization happens in batches in the background thread for performance.
pub struct DatasetIterator<R: BufRead> {
    lines: Lines<R>,
    line_num: usize,
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wraps `reader` in an iterator positioned before the first line,
    /// with the one-shot end-of-dataset signal still unsent.
    pub fn new(reader: R) -> Self {
        let lines = reader.lines();
        DatasetIterator {
            lines,
            line_num: 0,
            sent_end_signal: false,
        }
    }
}
82
83impl<R: BufRead> Iterator for DatasetIterator<R> {
84    type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
85
86    fn next(&mut self) -> Option<Self::Item> {
87        loop {
88            self.line_num += 1;
89            let line = match self.lines.next() {
90                Some(Ok(line)) => line,
91                Some(Err(e)) => return Some(Err(Box::new(e))),
92                None => {
93                    // End of dataset - signal completion with Ok(None) once, then end iterator
94                    if !self.sent_end_signal {
95                        self.sent_end_signal = true;
96                        return Some(Ok(None));
97                    } else {
98                        return None;
99                    }
100                }
101            };
102
103            // Skip empty lines
104            if line.trim().is_empty() {
105                continue;
106            }
107
108            // Parse the batch request (but don't tokenize yet - that happens in batches)
109            let batch_request: BatchRequest = match serde_json::from_str(&line) {
110                Ok(req) => req,
111                Err(e) => {
112                    return Some(Err(format!(
113                        "Failed to parse line {}: {}",
114                        self.line_num, e
115                    )
116                    .into()))
117                }
118            };
119
120            return Some(Ok(Some(UnparsedEntry {
121                request_id: batch_request.custom_id,
122                messages: batch_request.body.messages,
123                max_output_tokens: batch_request.body.max_tokens,
124            })));
125        }
126    }
127}
128
129/// Dataset loader that provides lazy iteration over entries
130pub struct DatasetLoader;
131
132impl DatasetLoader {
133    /// Count the number of non-empty lines in a JSONL file (fast approximation of entry count)
134    pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
135        let file = File::open(path)?;
136        let reader = BufReader::new(file);
137        let count = reader
138            .lines()
139            .filter_map(|line| line.ok())
140            .filter(|line| !line.trim().is_empty())
141            .count();
142        Ok(count)
143    }
144
145    /// Create an iterator from a JSONL file in OpenAI batch API format
146    /// Returns unparsed entries (without tokenization)
147    pub fn from_file<P: AsRef<Path>>(
148        path: P,
149    ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
150        let file = File::open(path)?;
151        let reader = BufReader::new(file);
152        Ok(DatasetIterator::new(reader))
153    }
154
155    /// Create an iterator from a string (useful for testing or WASM)
156    /// Returns unparsed entries (without tokenization)
157    pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
158        let cursor = std::io::Cursor::new(data);
159        let reader = BufReader::new(cursor);
160        DatasetIterator::new(reader)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    // Deserializing a pretty-printed OpenAI batch entry populates the id,
    // both messages, and the optional `max_tokens` field.
    #[test]
    fn test_parse_batch_request() {
        let json = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-1");
        assert_eq!(batch_request.body.messages.len(), 2);
        assert_eq!(batch_request.body.max_tokens, Some(100));
    }

    // End-to-end over a two-line JSONL string: each entry is yielded in order,
    // then the iterator emits its one-shot Ok(None) end signal, then None.
    #[test]
    fn test_dataset_iterator() {
        let test_data = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut iter = DatasetLoader::from_string(test_data.to_string());

        // Triple unwrap: iterator item -> Result -> Option (entry present).
        let entry1 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry1.request_id, "req-1");
        assert_eq!(entry1.messages.len(), 1);
        assert_eq!(entry1.messages[0].role, "user");
        assert_eq!(entry1.messages[0].content, "Hello");
        assert_eq!(entry1.max_output_tokens, Some(10));

        let entry2 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry2.request_id, "req-2");
        assert_eq!(entry2.max_output_tokens, Some(20));

        // Should get Ok(None) signaling end of dataset
        let end_signal = iter.next().unwrap().unwrap();
        assert!(end_signal.is_none());

        // After that, iterator itself should be exhausted
        assert!(iter.next().is_none());
    }
}
215}