#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::{Read, Result as IOResult, Write};
use alloc::string::{String, ToString};
use alloc::vec::Vec;
#[cfg(feature = "convert-detect")]
use crate::convert::ConversionError;
use crate::{Definition, InitializationError, Kitoken};
/// Byte sequence expected at the very start of serialized data.
/// `Definition::from_slice` rejects input whose leading bytes differ, and
/// `Definition::to_writer` / `Definition::to_vec` emit it first.
const MAGIC: &[u8] = b"kitoken";
/// Two bytes that directly follow `MAGIC` in serialized data.
/// Checked for exact equality on read and emitted right after `MAGIC` on write.
const VERSION: &[u8] = &[0, 1];
/// Errors returned by the deserialization entry points in this file
/// (`from_reader`, `from_file`, `from_slice` on `Definition` and `Kitoken`).
///
/// Every variant displays as the `Display` output of its payload (`"{0}"`).
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum DeserializationError {
/// The input bytes could not be decoded; the string carries the reason
/// (for example "invalid size", "invalid magic", "invalid version",
/// "unknown format", or the text of a postcard error).
#[error("{0}")]
InvalidData(String),
/// Wraps an `InitializationError`; created through the
/// `From<InitializationError>` implementation in this file.
#[error("{0}")]
InitializationError(InitializationError),
/// An I/O failure while reading the input. Only present with the `std`
/// feature; converted from `std::io::Error` via `#[from]`.
#[cfg(feature = "std")]
#[error("{0}")]
IOError(#[from] std::io::Error),
}
impl From<InitializationError> for DeserializationError {
fn from(e: InitializationError) -> Self {
Self::InitializationError(e)
}
}
impl Definition {
    /// Reads `reader` to its end and deserializes the collected bytes with
    /// [`Definition::from_slice`].
    ///
    /// # Errors
    ///
    /// Returns [`DeserializationError::IOError`] if reading fails, otherwise
    /// whatever [`Definition::from_slice`] reports for the bytes that were read.
    #[cfg(feature = "std")]
    pub fn from_reader<R: Read>(reader: &mut R) -> Result<Self, DeserializationError> {
        let mut data = Vec::new();
        reader.read_to_end(&mut data)?;
        Self::from_slice(&data)
    }

    /// Opens the file at `path` and deserializes its contents with
    /// [`Definition::from_reader`].
    ///
    /// # Errors
    ///
    /// Returns [`DeserializationError::IOError`] if the file cannot be opened,
    /// otherwise whatever [`Definition::from_reader`] reports.
    #[cfg(feature = "std")]
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, DeserializationError> {
        let mut file = File::open(path)?;
        Self::from_reader(&mut file)
    }

    /// Decodes the native layout: `MAGIC`, then `VERSION`, then a
    /// postcard-encoded definition.
    ///
    /// Failures are returned as plain messages so that each `from_slice`
    /// variant can wrap them in its own error type. The checks run in the
    /// order size, magic, version, payload.
    fn decode_native(slice: &[u8]) -> Result<Self, String> {
        let header_len = MAGIC.len() + VERSION.len();
        if slice.len() < header_len {
            return Err("invalid size".to_string());
        }
        // The length check above guarantees both splits are in bounds.
        let (header, payload) = slice.split_at(header_len);
        let (magic, version) = header.split_at(MAGIC.len());
        if magic != MAGIC {
            return Err("invalid magic".to_string());
        }
        if version != VERSION {
            return Err("invalid version".to_string());
        }
        postcard::from_bytes(payload).map_err(|e| e.to_string())
    }

    /// Deserializes a definition from `slice`, which must be in the native
    /// layout (`MAGIC`, `VERSION`, postcard payload).
    ///
    /// # Errors
    ///
    /// Returns [`DeserializationError::InvalidData`] with the message
    /// "invalid size", "invalid magic" or "invalid version" when the header
    /// does not match, or with the postcard error text when the payload
    /// cannot be decoded.
    #[cfg(not(feature = "convert-detect"))]
    pub fn from_slice(slice: &[u8]) -> Result<Self, DeserializationError> {
        Self::decode_native(slice).map_err(DeserializationError::InvalidData)
    }

    /// Deserializes a definition from `slice`, trying the native layout first
    /// and then every converter enabled through a `convert-*` feature, in the
    /// order listed below. The first decoder that succeeds wins.
    ///
    /// # Errors
    ///
    /// The individual decoder errors are discarded; if no decoder succeeds,
    /// [`DeserializationError::InvalidData`] with the message "unknown format"
    /// is returned.
    #[cfg(feature = "convert-detect")]
    pub fn from_slice(slice: &[u8]) -> Result<Self, DeserializationError> {
        let formats = &[
            // Kept as a non-capturing closure so it can share one array with
            // the converter functions below.
            |slice: &[u8]| Self::decode_native(slice).map_err(ConversionError::InvalidData),
            #[cfg(feature = "convert-tiktoken")]
            Definition::from_tiktoken_slice,
            #[cfg(feature = "convert-sentencepiece")]
            Definition::from_sentencepiece_slice,
            #[cfg(feature = "convert-tokenizers")]
            Definition::from_tokenizers_slice,
            #[cfg(feature = "convert-tekken")]
            Definition::from_tekken_slice,
        ];
        formats
            .iter()
            .find_map(|f| f(slice).ok())
            .ok_or_else(|| DeserializationError::InvalidData("unknown format".to_string()))
    }

    /// Serializes the definition to `writer` as `MAGIC`, `VERSION`, and the
    /// postcard-encoded payload.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::Error`] of kind `InvalidData` if postcard
    /// serialization fails, or the writer's own error if a write fails.
    #[cfg(feature = "std")]
    pub fn to_writer<W: Write>(&self, writer: &mut W) -> IOResult<()> {
        // Serialize before touching the writer so that a serialization failure
        // is reported as an error and cannot leave a header-only fragment behind.
        let data = postcard::to_allocvec(self)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
        writer.write_all(MAGIC)?;
        writer.write_all(VERSION)?;
        writer.write_all(&data)
    }

    /// Creates (or truncates) the file at `path` and writes the serialized
    /// definition to it with [`Definition::to_writer`].
    ///
    /// # Errors
    ///
    /// Returns the error from creating the file or from
    /// [`Definition::to_writer`].
    #[cfg(feature = "std")]
    pub fn to_file<P: AsRef<std::path::Path>>(&self, path: P) -> IOResult<()> {
        let mut file = File::create(path)?;
        self.to_writer(&mut file)
    }

    /// Serializes the definition into a new vector containing `MAGIC`,
    /// `VERSION`, and the postcard-encoded payload.
    ///
    /// # Panics
    ///
    /// Panics if postcard serialization fails.
    pub fn to_vec(&self) -> Vec<u8> {
        let data = postcard::to_allocvec(self).unwrap();
        let mut vec = Vec::with_capacity(MAGIC.len() + VERSION.len() + data.len());
        vec.extend_from_slice(MAGIC);
        vec.extend_from_slice(VERSION);
        vec.extend_from_slice(&data);
        vec
    }
}
impl Kitoken {
#[cfg(feature = "std")]
pub fn from_reader<R: Read>(reader: &mut R) -> Result<Self, DeserializationError> {
let definition = Definition::from_reader(reader)?;
Ok(Self::from_definition(definition)?)
}
#[cfg(feature = "std")]
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, DeserializationError> {
let definition = Definition::from_file(path)?;
Ok(Self::from_definition(definition)?)
}
pub fn from_slice(slice: &[u8]) -> Result<Self, DeserializationError> {
let definition = Definition::from_slice(slice)?;
Ok(Self::from_definition(definition)?)
}
#[cfg(feature = "std")]
pub fn to_writer<W: Write>(&self, writer: &mut W) -> IOResult<()> {
let definition = self.to_definition();
definition.to_writer(writer)
}
#[cfg(feature = "std")]
pub fn to_file<P: AsRef<std::path::Path>>(&self, path: P) -> IOResult<()> {
let definition = self.to_definition();
definition.to_file(path)
}
pub fn to_vec(&self) -> Vec<u8> {
let definition = self.to_definition();
definition.to_vec()
}
}