use std::{
borrow::Borrow,
pin::Pin,
task::{Context, Poll},
};
use futures::Stream;
use tokio::sync::mpsc::UnboundedReceiver;
use crate::{LlamaModel, Token};
/// Handle through which the tokens of a completion are received.
///
/// Tokens arrive on an unbounded tokio mpsc receiver. The handle also carries a
/// `LlamaModel`, which the conversion methods use to turn tokens into bytes or text.
pub struct CompletionHandle {
/// Receiving end of the token channel.
pub(super) rx: UnboundedReceiver<Token>,
/// Model used to decode received tokens.
pub(super) model: LlamaModel,
}
impl CompletionHandle {
pub fn next_token(&mut self) -> Option<Token> {
tokio::task::block_in_place(|| self.rx.blocking_recv())
}
pub async fn next_token_async(&mut self) -> Option<Token> {
self.rx.recv().await
}
pub fn into_bytes(self) -> TokensToBytes<CompletionHandle> {
let model = self.model.clone();
TokensToBytes::new(self, model)
}
pub fn into_strings(self) -> TokensToStrings<CompletionHandle> {
let model = self.model.clone();
TokensToStrings::new(self, model)
}
pub fn into_string(self) -> String {
self.model.clone().decode_tokens(self)
}
pub async fn into_string_async(mut self) -> String {
let mut tokens = Vec::new();
while let Some(token) = self.next_token_async().await {
tokens.push(token);
}
self.model.decode_tokens(tokens)
}
}
impl Iterator for CompletionHandle {
type Item = Token;
fn next(&mut self) -> Option<Self::Item> {
self.next_token()
}
}
/// Asynchronous stream of received tokens, polled straight from the channel
/// receiver.
impl Stream for CompletionHandle {
    type Item = Token;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only the receiver is touched, so a plain mutable reference suffices.
        let this = self.get_mut();
        this.rx.poll_recv(cx)
    }
}
/// Adapter that wraps a source of tokens and converts each token into its
/// byte piece via the model's `token_to_byte_piece`.
pub struct TokensToBytes<I> {
/// Wrapped source of tokens (an iterator or a stream).
inner: I,
/// Model used to convert each token into bytes.
model: LlamaModel,
}
impl<I> TokensToBytes<I> {
pub fn new(inner: I, model: LlamaModel) -> TokensToBytes<I> {
Self { inner, model }
}
}
/// Iterator form: pulls one token from the wrapped iterator and yields its
/// byte piece.
impl<I> Iterator for TokensToBytes<I>
where
    I: Iterator,
    I::Item: Borrow<Token>,
{
    type Item = Vec<u8>;

    fn next(&mut self) -> Option<Self::Item> {
        // `?` ends iteration as soon as the wrapped iterator does.
        let token = self.inner.next()?;
        Some(self.model.token_to_byte_piece(*token.borrow()))
    }
}
/// Stream form: polls the wrapped stream and converts each ready token into
/// its byte piece.
impl<I> Stream for TokensToBytes<I>
where
    I: Stream + Unpin,
    I::Item: Borrow<Token>,
{
    type Item = Vec<u8>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // `I: Unpin`, so the inner stream can be re-pinned from a plain `&mut`.
        match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(Some(token)) => {
                Poll::Ready(Some(this.model.token_to_byte_piece(*token.borrow())))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Incremental UTF-8 decoder for a sequence of byte pieces.
///
/// A multi-byte character may be split across consecutive pieces. The bytes of
/// an incomplete trailing sequence are kept in `buf` and prepended to the next
/// piece, so the character is emitted once all of its bytes have arrived.
#[derive(Debug, Default)]
struct TokenDecoder {
    /// Bytes of an incomplete UTF-8 sequence carried over from earlier pieces.
    buf: Vec<u8>,
}

impl TokenDecoder {
    /// Creates a decoder with no pending bytes.
    fn new() -> TokenDecoder {
        TokenDecoder::default()
    }

    /// Feeds one byte piece to the decoder and returns the text that can be
    /// emitted so far.
    ///
    /// * Valid UTF-8 is returned as-is.
    /// * Every invalid byte sequence is replaced by U+FFFD.
    /// * An incomplete sequence at the very end of the input is withheld (the
    ///   returned string may therefore be empty) and retried together with the
    ///   next piece.
    fn add_token(&mut self, token: &[u8]) -> String {
        let mut token = token;
        if !self.buf.is_empty() {
            // Pending bytes exist: decode them together with the new piece.
            self.buf.extend_from_slice(token);
            token = self.buf.as_slice();
        }
        let mut out = String::with_capacity(token.len());

        loop {
            match std::str::from_utf8(token) {
                Ok(s) => {
                    out.push_str(s);
                    self.buf.clear();
                    break;
                }
                Err(err) => {
                    let (valid, rest) = token.split_at(err.valid_up_to());
                    // SAFETY: `Utf8Error::valid_up_to` guarantees that the
                    // bytes before that index are valid UTF-8, and `valid` is
                    // exactly that prefix.
                    out.push_str(unsafe { std::str::from_utf8_unchecked(valid) });

                    match err.error_len() {
                        Some(len) => {
                            // Definitely invalid bytes: substitute and resume
                            // right after them.
                            out.push(char::REPLACEMENT_CHARACTER);
                            token = &rest[len..];
                        }
                        None => {
                            // Input ended in the middle of a character. `rest`
                            // may alias `self.buf`, so copy it to the stack
                            // before rewriting the buffer. An incomplete
                            // sequence is shorter than the 4-byte UTF-8
                            // maximum, so the scratch array always fits it.
                            let mut scratch = [0u8; 4];
                            let tail = &mut scratch[..rest.len()];
                            tail.copy_from_slice(rest);
                            self.buf.clear();
                            self.buf.extend_from_slice(tail);
                            break;
                        }
                    }
                }
            }
        }
        out
    }

    /// Flushes any pending bytes as lossily-decoded text.
    ///
    /// Returns `None` when nothing is pending. After this call the decoder is
    /// empty again.
    fn last_part(&mut self) -> Option<String> {
        if self.buf.is_empty() {
            return None;
        }
        // `into_owned` reuses the `String` the lossy conversion already built
        // whenever a replacement was needed, instead of copying it again.
        let out = String::from_utf8_lossy(&self.buf).into_owned();
        self.buf.clear();
        Some(out)
    }
}
/// Adapter that turns a source of tokens into `String` chunks.
///
/// Tokens are first converted to byte pieces and then passed through a
/// stateful decoder, so the emitted chunks are always valid `String`s.
pub struct TokensToStrings<I> {
/// Token-to-bytes adapter supplying the raw byte pieces.
completion: TokensToBytes<I>,
/// Decoder that assembles byte pieces into text across calls.
decoder: TokenDecoder,
}
impl<I> TokensToStrings<I> {
pub fn new(inner: I, model: LlamaModel) -> Self {
Self {
completion: TokensToBytes::new(inner, model),
decoder: TokenDecoder::new(),
}
}
}
/// Iterator form: yields one decoded chunk per byte piece, followed by one
/// final chunk if the decoder still holds pending bytes when the source ends.
///
/// NOTE(review): after the source has reported exhaustion it is queried again
/// on the following call; this relies on the source continuing to return
/// `None` — confirm for non-fused iterators.
impl<I> Iterator for TokensToStrings<I>
where
    I: Iterator,
    I::Item: Borrow<Token>,
{
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        match self.completion.next() {
            Some(bytes) => Some(self.decoder.add_token(&bytes)),
            // Source finished: flush whatever the decoder is still holding.
            None => self.decoder.last_part(),
        }
    }
}
/// Stream form: yields one decoded chunk per ready byte piece, followed by one
/// final chunk if the decoder still holds pending bytes when the source ends.
///
/// NOTE(review): after the source has reported completion it is polled again
/// on the following call; this relies on the source continuing to report
/// `Ready(None)` — confirm for the streams used here.
impl<I> Stream for TokensToStrings<I>
where
    I: Stream + Unpin,
    I::Item: Borrow<Token>,
{
    type Item = String;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.completion).poll_next(cx);
        polled.map(|maybe_bytes| match maybe_bytes {
            Some(bytes) => Some(this.decoder.add_token(&bytes)),
            // Source finished: flush whatever the decoder is still holding.
            None => this.decoder.last_part(),
        })
    }
}