use crate::GenerationParameters;
use crate::ModelConstraints;
use futures_util::Future;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
mod ext;
pub use ext::*;
mod task;
pub use task::*;
mod chat_builder;
pub use chat_builder::*;
mod boxed;
pub use boxed::*;
/// A type that can create new [`ChatSession`]s.
pub trait CreateChatSession {
    /// Error type for session creation.
    ///
    /// `ChatModel` and `StructuredChatModel` also use this type for their `on_token`
    /// callbacks and returned futures.
    type Error: Send + Sync + 'static;
    /// The session type produced by [`CreateChatSession::new_chat_session`].
    type ChatSession: ChatSession;
    /// Creates a new chat session.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the session cannot be created.
    fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error>;
}
/// A chat model that can extend a [`ChatSession`] with new messages.
///
/// `Sampler` configures generation and defaults to [`GenerationParameters`]. The session
/// and error types come from the [`CreateChatSession`] supertrait.
pub trait ChatModel<Sampler = GenerationParameters>: CreateChatSession {
    /// Adds `messages` to `session` using `sampler`, passing produced text to `on_token`.
    ///
    /// NOTE(review): `on_token` is presumably called once per generated token, and an `Err`
    /// returned from it presumably stops generation — confirm against implementations.
    ///
    /// The returned future borrows `self` and `session` for `'a`, but `messages` is not tied
    /// to `'a`. Implementations therefore cannot hold on to the `messages` slice inside the
    /// future and must copy whatever they need from it up front.
    ///
    /// # Errors
    ///
    /// The future resolves to `Err(Self::Error)` on failure.
    fn add_messages_with_callback<'a>(
        &'a self,
        session: &'a mut Self::ChatSession,
        messages: &[ChatMessage],
        sampler: Sampler,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a;
}
/// A [`ChatModel`] that can additionally constrain its output with `Constraints`.
///
/// The value produced by a constrained call has type `Constraints::Output`.
pub trait StructuredChatModel<Constraints: ModelConstraints, Sampler = GenerationParameters>:
    ChatModel<Sampler>
{
    /// Adds `messages` to `session` using `sampler`, with output restricted by `constraints`.
    ///
    /// Produced text is passed to `on_token`. NOTE(review): presumably once per token, with
    /// an `Err` from the callback stopping generation — confirm against implementations.
    ///
    /// As with [`ChatModel::add_messages_with_callback`], the returned future borrows only
    /// `self` and `session` for `'a`, not `messages`.
    ///
    /// NOTE(review): the name says "message" (singular) although a slice of messages is
    /// accepted; it is kept as-is because renaming would break implementors and callers.
    ///
    /// # Errors
    ///
    /// The future resolves to `Err(Self::Error)` on failure.
    fn add_message_with_callback_and_constraints<'a>(
        &'a self,
        session: &'a mut Self::ChatSession,
        messages: &[ChatMessage],
        sampler: Sampler,
        constraints: Constraints,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<Constraints::Output, Self::Error>> + Send + 'a;
}
/// A structured chat model that knows how to build default constraints producing a `T`.
///
/// The supertrait uses the default `Sampler` of [`StructuredChatModel`]
/// ([`GenerationParameters`]).
pub trait CreateDefaultChatConstraintsForType<T>:
    StructuredChatModel<Self::DefaultConstraints>
{
    /// The default constraints; their output type is exactly `T`.
    type DefaultConstraints: ModelConstraints<Output = T>;
    /// Creates the default constraints for `T`.
    ///
    /// Takes no `self`, so it is called on the type: `Model::create_default_constraints()`.
    fn create_default_constraints() -> Self::DefaultConstraints;
}
#[doc = include_str!("../../docs/chat_session.md")]
pub trait ChatSession {
    /// Error type returned by the fallible session operations below.
    type Error: Send + Sync + 'static;

    /// Writes the serialized form of this session into `into`.
    ///
    /// NOTE(review): whether existing contents of `into` are kept (append) is up to the
    /// implementation — confirm before relying on it.
    fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error>;

    /// Serializes this session into a newly allocated byte vector.
    ///
    /// Provided in terms of [`ChatSession::write_to`] on an empty buffer.
    fn to_bytes(&self) -> Result<Vec<u8>, Self::Error> {
        let mut buffer = Vec::new();
        self.write_to(&mut buffer).map(|()| buffer)
    }

    /// Reconstructs a session from bytes previously produced by serialization.
    fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Returns the session's message history as an owned list.
    fn history(&self) -> Vec<ChatMessage>;

    /// Attempts to produce a separate copy of this session.
    fn try_clone(&self) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// Prints `prompt` to stdout (no newline added), then reads a single line from stdin.
///
/// A trailing line terminator (`\n` or `\r\n`) is removed from the returned text. If stdin
/// hits end-of-file before a terminator, the text read so far is returned unchanged (an
/// empty string when nothing was read).
///
/// # Errors
///
/// Returns any I/O error raised while flushing stdout or reading from stdin.
pub fn prompt_input(prompt: impl Display) -> Result<String, std::io::Error> {
    use std::io::Write;

    print!("{}", prompt);
    // `print!` does not flush; without this the prompt may not appear before we block.
    std::io::stdout().flush()?;

    let mut input = String::new();
    std::io::stdin().read_line(&mut input)?;
    trim_line_terminator(&mut input);
    Ok(input)
}

/// Removes one trailing `\n` or `\r\n` from `line`, leaving it untouched otherwise.
///
/// Popping unconditionally would delete a real character when input ends at EOF without a
/// newline, and would leave a stray `\r` for CRLF input.
fn trim_line_terminator(line: &mut String) {
    if line.ends_with('\n') {
        line.pop();
        if line.ends_with('\r') {
            line.pop();
        }
    }
}
/// The author role of a [`ChatMessage`].
///
/// Serialized with serde as the strings `"developer"`, `"user"`, and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    /// The system prompt; serialized as `"developer"`.
    #[serde(rename = "developer")]
    SystemPrompt,
    /// A message from the user; serialized as `"user"`.
    #[serde(rename = "user")]
    UserMessage,
    /// A reply from the model; serialized as `"assistant"`.
    #[serde(rename = "assistant")]
    ModelAnswer,
}
/// A single chat message: the role that authored it plus its text.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Which participant authored the message.
    role: MessageType,
    /// The message text.
    content: String,
}

impl ChatMessage {
    /// Creates a message with the given `role`, converting `contents` to an owned `String`.
    pub fn new(role: MessageType, contents: impl ToString) -> Self {
        let content = contents.to_string();
        Self { role, content }
    }

    /// Returns the role that authored this message.
    pub fn role(&self) -> MessageType {
        self.role
    }

    /// Returns the message text, borrowed from `self`.
    pub fn content(&self) -> &str {
        self.content.as_str()
    }
}
/// Conversion into a [`ChatMessage`].
///
/// Any `ToString` value becomes a [`MessageType::UserMessage`] whose content is the value's
/// string form. An existing [`ChatMessage`] converts to itself unchanged.
pub trait IntoChatMessage {
    /// Converts `self` into a [`ChatMessage`].
    fn into_chat_message(self) -> ChatMessage;
}

impl<S: ToString> IntoChatMessage for S {
    fn into_chat_message(self) -> ChatMessage {
        // `ChatMessage::new` already calls `to_string`; converting here first would build
        // the text once and then copy it into a second allocation.
        ChatMessage::new(MessageType::UserMessage, self)
    }
}

impl IntoChatMessage for ChatMessage {
    fn into_chat_message(self) -> ChatMessage {
        self
    }
}