mod lattice;
mod model;
mod processor;
mod task;
mod tokenizer;
mod trie;
use base64::{engine::general_purpose::STANDARD_NO_PAD as BASE64_STANDARD, Engine};
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
pub use lattice::*;
pub use model::*;
pub use processor::*;
pub use task::*;
pub use tokenizer::*;
pub use trie::*;
pub type TokenID = u32;
pub type Token = Vec<u8>;
#[derive(Clone, PartialEq)]
pub struct ScoredToken {
pub value: Token,
pub score: f64,
pub keep: bool,
}
impl ScoredToken {
pub fn new(value: Token, score: f64, keep: bool) -> Self {
Self { value, score, keep }
}
pub fn from_str(value: &str, score: f64, keep: bool) -> Self {
Self {
value: value.as_bytes().to_vec(),
score,
keep,
}
}
pub fn from_u8(value: u8, score: f64, keep: bool) -> Self {
Self {
value: vec![value],
score,
keep,
}
}
pub fn clone_with_score(&self, score: f64) -> Self {
Self {
value: self.value.clone(),
score,
keep: self.keep,
}
}
pub fn clone_with_keep(&self, keep: bool) -> Self {
Self {
value: self.value.clone(),
score: self.score,
keep,
}
}
pub fn len(&self) -> usize {
self.value.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl std::fmt::Debug for ScoredToken {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if let Ok(s) = std::str::from_utf8(&self.value) {
write!(f, "ScoredToken({:?}, {}, {})", s, self.score, self.keep)
} else {
write!(
f,
"ScoredToken({:?}, {}, keep={})",
&self.value, self.score, self.keep
)
}
}
}
impl std::fmt::Display for ScoredToken {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if let Ok(s) = std::str::from_utf8(&self.value) {
write!(f, "{:?}", s)
} else {
write!(f, "{:?}", &self.value)
}
}
}
impl PartialOrd for ScoredToken {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.score.partial_cmp(&other.score)
}
}
impl Serialize for ScoredToken {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("ScoredToken", 2)?;
let mut encoded = false;
let value = String::from_utf8(self.value.clone()).unwrap_or_else(|_| {
encoded = true;
BASE64_STANDARD.encode(&self.value)
});
state.serialize_field("value", &value)?;
state.serialize_field("score", &self.score)?;
if encoded {
state.serialize_field("encoded", &true)?;
}
if self.keep {
state.serialize_field("keep", &true)?;
}
state.end()
}
}
impl<'de> Deserialize<'de> for ScoredToken {
fn deserialize<D>(deserializer: D) -> std::result::Result<ScoredToken, D::Error>
where
D: Deserializer<'de>,
{
struct ScoredTokenVisitor;
impl<'de> serde::de::Visitor<'de> for ScoredTokenVisitor {
type Value = ScoredToken;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("ScoredToken")
}
fn visit_map<V>(self, mut map: V) -> std::result::Result<ScoredToken, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut value: Option<String> = None;
let mut score = None;
let mut encoded = false;
let mut keep = false;
while let Some(key) = map.next_key()? {
match key {
"value" => {
value = map.next_value()?;
}
"score" => {
score = map.next_value()?;
}
"encoded" => {
encoded = map.next_value()?;
}
"keep" => {
keep = map.next_value()?;
}
_ => {
return Err(serde::de::Error::unknown_field(key, FIELDS));
}
}
}
let value = match value {
Some(token) => {
if encoded {
BASE64_STANDARD
.decode(token.as_bytes())
.map_err(serde::de::Error::custom)?
} else {
token.into_bytes()
}
}
None => return Err(serde::de::Error::missing_field("token")),
};
let score = match score {
Some(score) => score,
None => return Err(serde::de::Error::missing_field("score")),
};
Ok(ScoredToken { value, score, keep })
}
}
const FIELDS: &[&str] = &["value", "score", "encoded", "keep"];
deserializer.deserialize_struct("ScoredToken", FIELDS, ScoredTokenVisitor)
}
}
pub fn new_default_vocab() -> Vec<ScoredToken> {
(0..=255)
.map(|id| ScoredToken::new(vec![id as u8], 1.0 / 256.0, false))
.collect()
}
pub fn make_vocab(tokens: &[(&[u8], f64)]) -> Vec<ScoredToken> {
tokens
.iter()
.map(|(token, score)| ScoredToken::new(token.to_vec(), *score, false))
.collect()
}
pub enum Error {
IO(std::io::Error),
SerdeJSON(serde_json::Error),
TokenIdOutOfBounds(TokenID),
NoPath(usize, usize),
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Error::IO(err)
}
}
impl From<serde_json::Error> for Error {
fn from(err: serde_json::Error) -> Self {
Error::SerdeJSON(err)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::IO(err) => write!(f, "{}", err),
Error::SerdeJSON(err) => write!(f, "{}", err),
Error::NoPath(pos, len) => {
write!(f, "no path to position {}/{}", pos, len)
}
Error::TokenIdOutOfBounds(id) => write!(f, "token id {} is out of bounds", id),
}
}
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::IO(err) => write!(f, "IO({:?})", err),
Error::SerdeJSON(err) => write!(f, "SerdeJSON({:?})", err),
Error::NoPath(pos, len) => write!(f, "NoPath({}, {})", pos, len),
Error::TokenIdOutOfBounds(id) => write!(f, "TokenIdOutOfBounds({})", id),
}
}
}
impl std::error::Error for Error {}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_scored_token() {
let scored_token = ScoredToken::new(b"hello".to_vec(), 0.5, false);
let serialized = serde_json::to_string(&scored_token).unwrap();
let deserialized: ScoredToken = serde_json::from_str(&serialized).unwrap();
assert_eq!(scored_token.value, deserialized.value);
assert_eq!(scored_token.score, deserialized.score);
}
}