//! inference_lab/dataset.rs — lazy loading of simulation datasets stored as
//! OpenAI Batch API JSONL files.

1use serde::{Deserialize, Serialize};
2use std::fs::File;
3use std::io::{BufRead, BufReader, Lines};
4use std::path::Path;
5
/// A tokenizer function that takes either chat messages or a raw prompt and returns tokenized output.
/// This allows different implementations (tiktoken, transformers.js, etc.)
/// to be passed in from the CLI or WASM interface.
/// The tokenizer should apply the appropriate chat template for chat-style requests.
/// Errors are reported as human-readable `String` messages.
pub type TokenizerFn = Box<dyn Fn(&PromptInput) -> Result<Vec<u32>, String> + Send + Sync>;
11
/// A batch tokenizer function that takes multiple prompt inputs and returns multiple token vectors.
/// This is much faster than tokenizing one at a time.
///
/// NOTE(review): unlike [`TokenizerFn`] this alias requires `Send` but not
/// `Sync` — confirm the asymmetry is intentional before widening the bound.
pub type BatchTokenizerFn = Box<dyn Fn(&[PromptInput]) -> Result<Vec<Vec<u32>>, String> + Send>;
15
/// OpenAI Batch API format - JSONL entries (one JSON object per line).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-supplied identifier; carried through as the request id downstream.
    pub custom_id: String,
    /// HTTP method (e.g. "POST"); stored but not interpreted by this module.
    pub method: String,
    /// Target endpoint (e.g. "/v1/chat/completions" or "/v1/completions").
    pub url: String,
    /// The request payload.
    pub body: RequestBody,
}
24
/// Request payload carried inside a [`BatchRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
    /// Target model name (e.g. "gpt-3.5-turbo").
    pub model: String,
    /// Optional cap on generated tokens; `#[serde(default)]` makes a missing
    /// field deserialize as `None` rather than failing.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling parameter.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Chat messages or completion prompt. `flatten` means the variant's
    /// fields (`messages` / `prompt`) appear directly on the body object.
    #[serde(flatten)]
    pub input: RequestInput,
}
37
/// The prompt portion of a request body.
///
/// `untagged`: serde selects the variant by shape — a `messages` array
/// deserializes as `Chat`, a `prompt` string as `Completion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestInput {
    /// Chat-style request ("/v1/chat/completions").
    Chat { messages: Vec<Message> },
    /// Plain completion request ("/v1/completions").
    Completion { prompt: String },
}
44
/// A single chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role, e.g. "system" or "user".
    pub role: String,
    /// Message text.
    pub content: String,
}
50
/// Input handed to a tokenizer: either a chat transcript (to which the
/// tokenizer should apply a chat template) or a raw completion prompt.
#[derive(Debug, Clone)]
pub enum PromptInput {
    /// Chat-style messages.
    Messages(Vec<Message>),
    /// Raw prompt text.
    Prompt(String),
}
56
/// A processed dataset entry ready for simulation.
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Identifier copied from the batch request's `custom_id`.
    pub request_id: String,
    /// The tokenized prompt.
    pub prompt_tokens: Vec<u32>,
    /// Optional cap on generated tokens (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}

impl DatasetEntry {
    /// Number of prompt tokens in this entry.
    ///
    /// The length is a `usize`; in debug builds we assert it fits in `u32`
    /// instead of letting `as` silently truncate a pathological
    /// multi-billion-token prompt. Release behavior is unchanged.
    pub fn num_prompt_tokens(&self) -> u32 {
        let n = self.prompt_tokens.len();
        debug_assert!(n <= u32::MAX as usize, "prompt token count {} overflows u32", n);
        n as u32
    }
}
70
/// Unparsed entry from dataset (before tokenization).
#[derive(Debug, Clone)]
pub struct UnparsedEntry {
    /// Identifier copied from the batch request's `custom_id`.
    pub request_id: String,
    /// Prompt content awaiting tokenization.
    pub prompt_input: PromptInput,
    /// Optional cap on generated tokens (`max_tokens` from the request body).
    pub max_output_tokens: Option<u32>,
}
78
/// Iterator over dataset entries, parsing JSON but NOT tokenizing.
/// Tokenization happens in batches in the background thread for performance.
pub struct DatasetIterator<R: BufRead> {
    lines: Lines<R>,
    line_num: usize,
    sent_end_signal: bool,
}

impl<R: BufRead> DatasetIterator<R> {
    /// Wrap a buffered reader, positioned before the first line and with the
    /// end-of-dataset signal not yet emitted.
    pub fn new(reader: R) -> Self {
        DatasetIterator {
            sent_end_signal: false,
            line_num: 0,
            lines: reader.lines(),
        }
    }
}
96
97impl<R: BufRead> Iterator for DatasetIterator<R> {
98    type Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>;
99
100    fn next(&mut self) -> Option<Self::Item> {
101        loop {
102            self.line_num += 1;
103            let line = match self.lines.next() {
104                Some(Ok(line)) => line,
105                Some(Err(e)) => return Some(Err(Box::new(e))),
106                None => {
107                    // End of dataset - signal completion with Ok(None) once, then end iterator
108                    if !self.sent_end_signal {
109                        self.sent_end_signal = true;
110                        return Some(Ok(None));
111                    } else {
112                        return None;
113                    }
114                }
115            };
116
117            // Skip empty lines
118            if line.trim().is_empty() {
119                continue;
120            }
121
122            // Parse the batch request (but don't tokenize yet - that happens in batches)
123            let batch_request: BatchRequest = match serde_json::from_str(&line) {
124                Ok(req) => req,
125                Err(e) => {
126                    return Some(Err(format!(
127                        "Failed to parse line {}: {}",
128                        self.line_num, e
129                    )
130                    .into()))
131                }
132            };
133
134            return Some(Ok(Some(UnparsedEntry {
135                request_id: batch_request.custom_id,
136                prompt_input: match batch_request.body.input {
137                    RequestInput::Chat { messages } => PromptInput::Messages(messages),
138                    RequestInput::Completion { prompt } => PromptInput::Prompt(prompt),
139                },
140                max_output_tokens: batch_request.body.max_tokens,
141            })));
142        }
143    }
144}
145
146/// Dataset loader that provides lazy iteration over entries
147pub struct DatasetLoader;
148
149impl DatasetLoader {
150    /// Count the number of non-empty lines in a JSONL file (fast approximation of entry count)
151    pub fn count_entries<P: AsRef<Path>>(path: P) -> Result<usize, std::io::Error> {
152        let file = File::open(path)?;
153        let reader = BufReader::new(file);
154        let count = reader
155            .lines()
156            .filter_map(|line| line.ok())
157            .filter(|line| !line.trim().is_empty())
158            .count();
159        Ok(count)
160    }
161
162    /// Create an iterator from a JSONL file in OpenAI batch API format
163    /// Returns unparsed entries (without tokenization)
164    pub fn from_file<P: AsRef<Path>>(
165        path: P,
166    ) -> Result<DatasetIterator<BufReader<File>>, std::io::Error> {
167        let file = File::open(path)?;
168        let reader = BufReader::new(file);
169        Ok(DatasetIterator::new(reader))
170    }
171
172    /// Create an iterator from a string (useful for testing or WASM)
173    /// Returns unparsed entries (without tokenization)
174    pub fn from_string(data: String) -> DatasetIterator<BufReader<std::io::Cursor<String>>> {
175        let cursor = std::io::Cursor::new(data);
176        let reader = BufReader::new(cursor);
177        DatasetIterator::new(reader)
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    // A chat-style body (has `messages`) must deserialize as RequestInput::Chat.
    #[test]
    fn test_parse_batch_request() {
        let json = r#"{
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello, how are you?"}
                ],
                "max_tokens": 100
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-1");
        match batch_request.body.input {
            RequestInput::Chat { messages } => assert_eq!(messages.len(), 2),
            RequestInput::Completion { .. } => panic!("expected chat request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(100));
    }

    // End-to-end iteration over a two-line JSONL string, including the
    // Ok(None) end-of-dataset signal and final iterator exhaustion.
    #[test]
    fn test_dataset_iterator() {
        let test_data = r#"{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "World"}], "max_tokens": 20}}"#;

        let mut iter = DatasetLoader::from_string(test_data.to_string());

        let entry1 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry1.request_id, "req-1");
        match entry1.prompt_input {
            PromptInput::Messages(messages) => {
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].role, "user");
                assert_eq!(messages[0].content, "Hello");
            }
            PromptInput::Prompt(_) => panic!("expected chat prompt input"),
        }
        assert_eq!(entry1.max_output_tokens, Some(10));

        let entry2 = iter.next().unwrap().unwrap().unwrap();
        assert_eq!(entry2.request_id, "req-2");
        assert_eq!(entry2.max_output_tokens, Some(20));

        // Should get Ok(None) signaling end of dataset
        let end_signal = iter.next().unwrap().unwrap();
        assert!(end_signal.is_none());

        // After that, iterator itself should be exhausted
        assert!(iter.next().is_none());
    }

    // A completion-style body (has `prompt`) must deserialize as
    // RequestInput::Completion via the untagged enum.
    #[test]
    fn test_parse_completion_batch_request() {
        let json = r#"{
            "custom_id": "request-2",
            "method": "POST",
            "url": "/v1/completions",
            "body": {
                "model": "gpt-3.5-turbo-instruct",
                "prompt": "Hello, world",
                "max_tokens": 32
            }
        }"#;

        let batch_request: BatchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(batch_request.custom_id, "request-2");
        match batch_request.body.input {
            RequestInput::Completion { prompt } => assert_eq!(prompt, "Hello, world"),
            RequestInput::Chat { .. } => panic!("expected completion request"),
        }
        assert_eq!(batch_request.body.max_tokens, Some(32));
    }
}