use crate::{TokenizedInput, Tokenizer};
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
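/// Configuration for converting tokenizer output into PyTorch-style tensors:
/// target device and dtype, padding/truncation behavior, and which auxiliary
/// tensors (attention mask, token type ids) to emit.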
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchConfig {
pub device: String,
pub dtype: TensorDType,
pub max_length: Option<usize>,
pub padding: PaddingStrategy,
pub truncation: TruncationStrategy,
pub return_attention_mask: bool,
pub return_token_type_ids: bool,
pub batch_size: usize,
}
impl Default for PyTorchConfig {
fn default() -> Self {
Self {
device: "cpu".to_string(),
dtype: TensorDType::Int64,
max_length: Some(512),
padding: PaddingStrategy::LongestFirst,
truncation: TruncationStrategy::LongestFirst,
return_attention_mask: true,
return_token_type_ids: false,
batch_size: 32,
}
}
}
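/// Element type tag mirroring a subset of `torch` dtypes.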
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TensorDType {
Int32,
Int64,
Float32,
Float64,
}
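/// Batch padding behavior: none (`False`), pad to the longest sequence in the
/// batch (`LongestFirst`), or pad to the configured `max_length` (`MaxLength`).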
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingStrategy {
False,
LongestFirst,
MaxLength,
}
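/// Truncation behavior, named after the Hugging Face strategies. Note that
/// `convert_to_batch` currently only distinguishes `False` from the rest, so
/// `OnlyFirst`/`OnlySecond` behave like `LongestFirst` here.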
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TruncationStrategy {
False,
LongestFirst,
MaxLength,
OnlyFirst,
OnlySecond,
}
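/// CPU-side stand-in for a `torch.Tensor`. Values are always stored as `i64`;
/// `dtype` is a logical tag only, since no real tensor backend sits behind
/// this type.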
#[derive(Debug, Clone)]
pub struct PyTorchTensor {
pub data: Vec<i64>,
pub shape: Vec<usize>,
pub device: String,
pub dtype: TensorDType,
}
impl PyTorchTensor {
pub fn new(data: Vec<i64>, shape: Vec<usize>, device: String, dtype: TensorDType) -> Self {
Self {
data,
shape,
device,
dtype,
}
}
pub fn size(&self) -> Vec<usize> {
self.shape.clone()
}
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
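/// Returns a tensor viewing the same data under `new_shape`; fails if the
/// element counts differ.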
pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self> {
let new_size: usize = new_shape.iter().product();
if new_size != self.numel() {
return Err(anyhow!("Cannot reshape tensor: have {} elements, requested shape {:?}", self.numel(), new_shape));
}
Ok(Self {
data: self.data.clone(),
shape: new_shape,
device: self.device.clone(),
dtype: self.dtype,
})
}
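/// Returns a copy labeled with a new device. This is metadata only; no data
/// transfer takes place.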
pub fn to_device(&self, device: &str) -> Self {
Self {
data: self.data.clone(),
shape: self.shape.clone(),
device: device.to_string(),
dtype: self.dtype,
}
}
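/// Returns a copy labeled with a new dtype. The values remain stored as
/// `i64`; only the tag changes.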
pub fn to_dtype(&self, dtype: TensorDType) -> Self {
Self {
data: self.data.clone(),
shape: self.shape.clone(),
device: self.device.clone(),
dtype,
}
}
}
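/// A batch of model inputs: `input_ids` of shape `[batch, seq_len]`, optional
/// masks of the same shape, and the original (pre-padding) sequence lengths.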
#[derive(Debug, Clone)]
pub struct PyTorchBatch {
pub input_ids: PyTorchTensor,
pub attention_mask: Option<PyTorchTensor>,
pub token_type_ids: Option<PyTorchTensor>,
pub special_tokens_mask: Option<PyTorchTensor>,
pub sequence_lengths: Vec<usize>,
}
impl PyTorchBatch {
pub fn new(
input_ids: PyTorchTensor,
attention_mask: Option<PyTorchTensor>,
token_type_ids: Option<PyTorchTensor>,
special_tokens_mask: Option<PyTorchTensor>,
sequence_lengths: Vec<usize>,
) -> Self {
Self {
input_ids,
attention_mask,
token_type_ids,
special_tokens_mask,
sequence_lengths,
}
}
pub fn batch_size(&self) -> usize {
self.input_ids.shape[0]
}
pub fn sequence_length(&self) -> usize {
self.input_ids.shape[1]
}
pub fn to_device(&self, device: &str) -> Self {
Self {
input_ids: self.input_ids.to_device(device),
attention_mask: self.attention_mask.as_ref().map(|t| t.to_device(device)),
token_type_ids: self.token_type_ids.as_ref().map(|t| t.to_device(device)),
special_tokens_mask: self.special_tokens_mask.as_ref().map(|t| t.to_device(device)),
sequence_lengths: self.sequence_lengths.clone(),
}
}
pub fn to_dtype(&self, dtype: TensorDType) -> Self {
Self {
input_ids: self.input_ids.to_dtype(dtype),
attention_mask: self.attention_mask.as_ref().map(|t| t.to_dtype(dtype)),
token_type_ids: self.token_type_ids.as_ref().map(|t| t.to_dtype(dtype)),
special_tokens_mask: self.special_tokens_mask.as_ref().map(|t| t.to_dtype(dtype)),
sequence_lengths: self.sequence_lengths.clone(),
}
}
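/// No-op stand-in for `Tensor.pin_memory()`: with no real backend there is no
/// page-locked memory to move into, so this simply clones the batch.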
pub fn pin_memory(&self) -> Self {
self.clone()
}
}
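/// Adapts any [`Tokenizer`] to produce PyTorch-style batches according to a
/// [`PyTorchConfig`].
///
/// A minimal usage sketch (marked `ignore` because it assumes some concrete
/// `Tokenizer` implementation, e.g. this crate's `CharTokenizer`):
///
/// ```ignore
/// let pt = PyTorchTokenizer::from_tokenizer(my_tokenizer);
/// let batch = pt.encode_batch_to_tensors(&texts)?;
/// assert_eq!(batch.batch_size(), texts.len());
/// ```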
pub struct PyTorchTokenizer<T: Tokenizer> {
tokenizer: Arc<T>,
config: PyTorchConfig,
}
impl<T: Tokenizer> PyTorchTokenizer<T> {
pub fn new(tokenizer: T, config: PyTorchConfig) -> Self {
Self {
tokenizer: Arc::new(tokenizer),
config,
}
}
pub fn from_tokenizer(tokenizer: T) -> Self {
Self::new(tokenizer, PyTorchConfig::default())
}
pub fn with_config(mut self, config: PyTorchConfig) -> Self {
self.config = config;
self
}
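/// Encodes a single text and returns it as a batch of size 1.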
pub fn encode_to_tensors(&self, text: &str) -> Result<PyTorchBatch> {
let tokenized = self.tokenizer.encode(text)?;
self.convert_to_batch(vec![tokenized])
}
pub fn encode_pair_to_tensors(&self, text_a: &str, text_b: &str) -> Result<PyTorchBatch> {
let tokenized = self.tokenizer.encode_pair(text_a, text_b)?;
self.convert_to_batch(vec![tokenized])
}
pub fn encode_batch_to_tensors(&self, texts: &[String]) -> Result<PyTorchBatch> {
let mut tokenized_batch = Vec::new();
for text in texts {
let tokenized = self.tokenizer.encode(text)?;
tokenized_batch.push(tokenized);
}
self.convert_to_batch(tokenized_batch)
}
pub fn encode_pair_batch_to_tensors(
&self,
text_pairs: &[(String, String)],
) -> Result<PyTorchBatch> {
let mut tokenized_batch = Vec::new();
for (text_a, text_b) in text_pairs {
let tokenized = self.tokenizer.encode_pair(text_a, text_b)?;
tokenized_batch.push(tokenized);
}
self.convert_to_batch(tokenized_batch)
}
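/// Pads and/or truncates each tokenized sequence to a common length, then
/// packs the results into row-major `[batch, seq_len]` tensors.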
fn convert_to_batch(&self, tokenized_inputs: Vec<TokenizedInput>) -> Result<PyTorchBatch> {
if tokenized_inputs.is_empty() {
return Err(anyhow!("Cannot create batch from empty input"));
}
let batch_size = tokenized_inputs.len();
let sequence_lengths: Vec<usize> =
tokenized_inputs.iter().map(|t| t.input_ids.len()).collect();
// Determine the common sequence length implied by the padding strategy.
let seq_length = match self.config.padding {
PaddingStrategy::False => {
// With padding disabled, every sequence must already share one length.
let first_len = sequence_lengths[0];
if !sequence_lengths.iter().all(|&len| len == first_len) {
return Err(anyhow!(
"All sequences must be the same length when padding is disabled"
));
}
first_len
},
PaddingStrategy::LongestFirst => sequence_lengths.iter().copied().max().unwrap_or(0),
// Falls back to 512 if no explicit max_length is configured.
PaddingStrategy::MaxLength => self.config.max_length.unwrap_or(512),
};
// Cap the target length at max_length unless truncation is disabled.
let final_seq_length = if let Some(max_len) = self.config.max_length {
match self.config.truncation {
TruncationStrategy::False => seq_length,
_ => seq_length.min(max_len),
}
} else {
seq_length
};
let mut input_ids_data = Vec::with_capacity(batch_size * final_seq_length);
let mut attention_mask_data = Vec::with_capacity(batch_size * final_seq_length);
let mut token_type_ids_data = Vec::with_capacity(batch_size * final_seq_length);
let mut special_tokens_mask_data = Vec::with_capacity(batch_size * final_seq_length);
// NOTE: assumes pad token id 0; a configurable pad id would be needed for
// vocabularies that reserve 0 for something else.
let pad_token_id: u32 = 0;
for tokenized in &tokenized_inputs {
let mut seq_input_ids = tokenized.input_ids.clone();
if seq_input_ids.len() > final_seq_length {
// Refuse to silently drop tokens when truncation is explicitly disabled.
if matches!(self.config.truncation, TruncationStrategy::False) {
return Err(anyhow!(
"Sequence of length {} exceeds target length {} but truncation is disabled",
seq_input_ids.len(),
final_seq_length
));
}
seq_input_ids.truncate(final_seq_length);
}
while seq_input_ids.len() < final_seq_length {
seq_input_ids.push(pad_token_id);
}
input_ids_data.extend(seq_input_ids.into_iter().map(|id| id as i64));
if self.config.return_attention_mask {
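// 1 marks real tokens, 0 marks padding positions.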
let actual_length = tokenized.input_ids.len().min(final_seq_length);
for i in 0..final_seq_length {
attention_mask_data.push(if i < actual_length { 1 } else { 0 });
}
}
if self.config.return_token_type_ids {
let token_type_ids = tokenized
.token_type_ids
.clone()
.unwrap_or_else(|| vec![0; tokenized.input_ids.len()]);
let mut seq_token_type_ids = token_type_ids;
if seq_token_type_ids.len() > final_seq_length {
seq_token_type_ids.truncate(final_seq_length);
}
while seq_token_type_ids.len() < final_seq_length {
seq_token_type_ids.push(0);
}
token_type_ids_data.extend(seq_token_type_ids.into_iter().map(|id| id as i64));
}
let special_tokens_mask = tokenized
.special_tokens_mask
.clone()
.unwrap_or_else(|| vec![0; tokenized.input_ids.len()]);
let mut seq_special_tokens_mask = special_tokens_mask;
if seq_special_tokens_mask.len() > final_seq_length {
seq_special_tokens_mask.truncate(final_seq_length);
}
while seq_special_tokens_mask.len() < final_seq_length {
seq_special_tokens_mask.push(0);
}
special_tokens_mask_data
.extend(seq_special_tokens_mask.into_iter().map(|mask| mask as i64));
}
let input_ids = PyTorchTensor::new(
input_ids_data,
vec![batch_size, final_seq_length],
self.config.device.clone(),
self.config.dtype,
);
let attention_mask = if self.config.return_attention_mask {
Some(PyTorchTensor::new(
attention_mask_data,
vec![batch_size, final_seq_length],
self.config.device.clone(),
self.config.dtype,
))
} else {
None
};
let token_type_ids = if self.config.return_token_type_ids {
Some(PyTorchTensor::new(
token_type_ids_data,
vec![batch_size, final_seq_length],
self.config.device.clone(),
self.config.dtype,
))
} else {
None
};
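// Only materialize the special-tokens mask when at least one special token
// is actually present in the batch.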
let special_tokens_mask = if special_tokens_mask_data.iter().any(|&mask| mask != 0) {
Some(PyTorchTensor::new(
special_tokens_mask_data,
vec![batch_size, final_seq_length],
self.config.device.clone(),
self.config.dtype,
))
} else {
None
};
Ok(PyTorchBatch::new(
input_ids,
attention_mask,
token_type_ids,
special_tokens_mask,
sequence_lengths,
))
}
pub fn tokenizer(&self) -> &T {
&self.tokenizer
}
pub fn config(&self) -> &PyTorchConfig {
&self.config
}
pub fn set_device(&mut self, device: String) {
self.config.device = device;
}
pub fn set_max_length(&mut self, max_length: Option<usize>) {
self.config.max_length = max_length;
}
pub fn set_padding(&mut self, padding: PaddingStrategy) {
self.config.padding = padding;
}
pub fn set_truncation(&mut self, truncation: TruncationStrategy) {
self.config.truncation = truncation;
}
}
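/// Simple in-memory text dataset, loosely analogous to
/// `torch.utils.data.Dataset`, yielding batches via [`PyTorchDataset::batch_iter`].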
pub struct PyTorchDataset {
texts: Vec<String>,
#[allow(dead_code)]
tokenizer_config: PyTorchConfig,
}
impl PyTorchDataset {
pub fn new(texts: Vec<String>, config: PyTorchConfig) -> Self {
Self {
texts,
tokenizer_config: config,
}
}
pub fn len(&self) -> usize {
self.texts.len()
}
pub fn is_empty(&self) -> bool {
self.texts.is_empty()
}
pub fn get_item(&self, index: usize) -> Option<&str> {
self.texts.get(index).map(|s| s.as_str())
}
pub fn batch_iter(&self, batch_size: usize) -> BatchIterator<'_> {
BatchIterator::new(&self.texts, batch_size)
}
}
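/// Iterates over consecutive, non-overlapping slices of at most `batch_size`
/// texts; the final batch may be smaller.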
pub struct BatchIterator<'a> {
texts: &'a [String],
batch_size: usize,
current_index: usize,
}
impl<'a> BatchIterator<'a> {
fn new(texts: &'a [String], batch_size: usize) -> Self {
Self {
texts,
batch_size,
current_index: 0,
}
}
}
impl<'a> Iterator for BatchIterator<'a> {
type Item = &'a [String];
fn next(&mut self) -> Option<Self::Item> {
if self.current_index >= self.texts.len() {
return None;
}
let end_index = (self.current_index + self.batch_size).min(self.texts.len());
let batch = &self.texts[self.current_index..end_index];
self.current_index = end_index;
Some(batch)
}
}
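/// Assorted helpers: debug formatting, memory accounting, a collate function,
/// and model-input validation.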
pub struct PyTorchUtils;
impl PyTorchUtils {
pub fn tensor_to_debug_string(tensor: &PyTorchTensor) -> String {
// Show at most the first 10 elements so large tensors stay readable.
format!(
"PyTorchTensor(shape={:?}, device={}, dtype={:?}, data={:?})",
tensor.shape,
tensor.device,
tensor.dtype,
&tensor.data[..tensor.data.len().min(10)]
)
}
pub fn tensor_memory_usage(tensor: &PyTorchTensor) -> usize {
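// Size implied by the logical dtype tag; actual storage is always i64.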
let element_size = match tensor.dtype {
TensorDType::Int32 | TensorDType::Float32 => 4,
TensorDType::Int64 | TensorDType::Float64 => 8,
};
tensor.numel() * element_size
}
pub fn collate_fn<T: Tokenizer>(
tokenizer: &PyTorchTokenizer<T>,
texts: Vec<String>,
) -> Result<PyTorchBatch> {
tokenizer.encode_batch_to_tensors(&texts)
}
pub fn validate_model_inputs(batch: &PyTorchBatch) -> Result<()> {
// input_ids must be rank 2, since batch_size() and sequence_length()
// index shape[0] and shape[1] directly. Comparing input_ids.shape against
// those values would be a tautology, so check the rank instead.
if batch.input_ids.shape.len() != 2 {
return Err(anyhow!("Invalid input_ids shape: expected [batch, seq_len]"));
}
let batch_size = batch.batch_size();
let seq_length = batch.sequence_length();
if let Some(ref mask) = batch.attention_mask {
if mask.shape != vec![batch_size, seq_length] {
return Err(anyhow!("Invalid attention_mask shape"));
}
}
if let Some(ref type_ids) = batch.token_type_ids {
if type_ids.shape != vec![batch_size, seq_length] {
return Err(anyhow!("Invalid token_type_ids shape"));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::char::CharTokenizer;
use std::collections::HashMap;
fn create_test_char_tokenizer() -> CharTokenizer {
let mut vocab = HashMap::new();
vocab.insert("[PAD]".to_string(), 0);
vocab.insert("[UNK]".to_string(), 1);
vocab.insert("[CLS]".to_string(), 2);
vocab.insert("[SEP]".to_string(), 3);
vocab.insert("h".to_string(), 4);
vocab.insert("e".to_string(), 5);
vocab.insert("l".to_string(), 6);
vocab.insert("o".to_string(), 7);
vocab.insert("w".to_string(), 8);
vocab.insert("r".to_string(), 9);
vocab.insert("d".to_string(), 10);
vocab.insert(" ".to_string(), 11);
vocab.insert("t".to_string(), 12);
vocab.insert("s".to_string(), 13);
CharTokenizer::new(vocab)
}
#[test]
fn test_pytorch_config() {
let config = PyTorchConfig::default();
assert_eq!(config.device, "cpu");
assert_eq!(config.max_length, Some(512));
assert!(config.return_attention_mask);
assert!(!config.return_token_type_ids);
}
#[test]
fn test_pytorch_tensor() {
let data = vec![1, 2, 3, 4];
let shape = vec![2, 2];
let tensor = PyTorchTensor::new(
data.clone(),
shape.clone(),
"cpu".to_string(),
TensorDType::Int64,
);
assert_eq!(tensor.data, data);
assert_eq!(tensor.shape, shape);
assert_eq!(tensor.device, "cpu");
assert_eq!(tensor.numel(), 4);
}
#[test]
fn test_tensor_reshape() {
let data = vec![1, 2, 3, 4, 5, 6];
let tensor = PyTorchTensor::new(data, vec![2, 3], "cpu".to_string(), TensorDType::Int64);
let reshaped = tensor.reshape(vec![3, 2]).expect("reshape to a compatible shape should succeed");
assert_eq!(reshaped.shape, vec![3, 2]);
assert_eq!(reshaped.numel(), 6);
}
#[test]
fn test_pytorch_tokenizer() {
let tokenizer = create_test_char_tokenizer();
let pytorch_tokenizer = PyTorchTokenizer::from_tokenizer(tokenizer);
let batch = pytorch_tokenizer.encode_to_tensors("hello").expect("single-text encoding should succeed");
assert_eq!(batch.batch_size(), 1);
assert!(batch.attention_mask.is_some());
}
#[test]
fn test_batch_encoding() {
let tokenizer = create_test_char_tokenizer();
let pytorch_tokenizer = PyTorchTokenizer::from_tokenizer(tokenizer);
let texts = vec!["hello".to_string(), "world".to_string()];
let batch = pytorch_tokenizer
.encode_batch_to_tensors(&texts)
.expect("batch encoding should succeed");
assert_eq!(batch.batch_size(), 2);
assert!(batch.attention_mask.is_some());
assert_eq!(batch.sequence_lengths.len(), 2);
}
#[test]
fn test_pytorch_dataset() {
let texts = vec!["hello".to_string(), "world".to_string(), "test".to_string()];
let config = PyTorchConfig::default();
let dataset = PyTorchDataset::new(texts, config);
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get_item(0), Some("hello"));
let batches: Vec<_> = dataset.batch_iter(2).collect();
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].len(), 2);
assert_eq!(batches[1].len(), 1);
}
#[test]
fn test_tensor_utilities() {
let tensor = PyTorchTensor::new(
vec![1, 2, 3, 4],
vec![2, 2],
"cpu".to_string(),
TensorDType::Int64,
);
let debug_str = PyTorchUtils::tensor_to_debug_string(&tensor);
assert!(debug_str.contains("shape=[2, 2]"));
assert!(debug_str.contains("device=cpu"));
let memory_usage = PyTorchUtils::tensor_memory_usage(&tensor);
// 4 elements * 8 bytes each for Int64.
assert_eq!(memory_usage, 4 * 8);
}
#[test]
fn test_padding_strategies() {
let tokenizer = create_test_char_tokenizer();
let config = PyTorchConfig {
padding: PaddingStrategy::MaxLength,
max_length: Some(10),
..PyTorchConfig::default()
};
let pytorch_tokenizer = PyTorchTokenizer::new(tokenizer, config);
let texts = vec!["hi".to_string(), "hello world".to_string()];
let batch = pytorch_tokenizer
.encode_batch_to_tensors(&texts)
.expect("encoding with MaxLength padding should succeed");
assert_eq!(batch.sequence_length(), 10);
assert_eq!(batch.batch_size(), 2);
}
}