use crate::error::Result;
use crate::types::{Event, Message, Model, Response, ResponseRequest, Role};
/// A conversation: an optional model choice, an optional output-token cap, and the
/// ordered message history that is sent with every request.
///
/// `Default` yields the same empty, model-less state as [`Chat::new_auto`].
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct Chat {
    /// Model to request. `None` for chats created with [`Chat::new_auto`]; presumably
    /// the client/service then chooses a model — confirm against `Client::send`.
    model: Option<Model>,
    /// Cap forwarded as `ResponseRequest::max_output_tokens`; `None` leaves it unset.
    max_output_tokens: Option<u32>,
    /// Conversation history in send order; replies are appended by `send`/`stream`.
    messages: Vec<Message>,
}
impl Chat {
    /// Creates an empty conversation that will request the model named `model`.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: Some(Model::new(model)),
            max_output_tokens: None,
            messages: Vec::new(),
        }
    }

    /// Creates an empty conversation with no model set.
    ///
    /// Requests built from this chat carry `model: None`; presumably the client or
    /// service then picks a model — confirm against `Client::send`.
    #[must_use]
    pub fn new_auto() -> Self {
        Self {
            model: None,
            max_output_tokens: None,
            messages: Vec::new(),
        }
    }

    /// Builder-style setter: replaces the model to request and returns the chat.
    #[must_use = "this consumes the chat; use the returned value"]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(Model::new(model));
        self
    }

    /// Builder-style setter: caps the number of output tokens per response.
    #[must_use = "this consumes the chat; use the returned value"]
    pub fn max_output_tokens(mut self, max: u32) -> Self {
        self.max_output_tokens = Some(max);
        self
    }

    /// In-place counterpart of [`Chat::max_output_tokens`], for chats already borrowed.
    pub fn set_max_output_tokens(&mut self, max: u32) -> &mut Self {
        self.max_output_tokens = Some(max);
        self
    }

    /// Appends an already-built message to the end of the history.
    pub fn push_message(&mut self, message: Message) -> &mut Self {
        self.messages.push(message);
        self
    }

    /// Appends a text message authored by `role`.
    pub fn push_text(&mut self, role: Role, text: impl Into<String>) -> &mut Self {
        self.push_message(Message::text(role, text))
    }

    /// Appends a text message authored by [`Role::User`].
    pub fn push_user(&mut self, text: impl Into<String>) -> &mut Self {
        self.push_text(Role::User, text)
    }

    /// Shorthand for [`Chat::push_user`].
    pub fn push(&mut self, text: impl Into<String>) -> &mut Self {
        self.push_user(text)
    }

    /// The history in send order, including any recorded replies.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Consumes the chat and returns its history.
    pub fn into_messages(self) -> Vec<Message> {
        self.messages
    }

    /// Snapshots the current state into a request (history is cloned; no tools are sent).
    fn to_request(&self) -> ResponseRequest {
        ResponseRequest {
            model: self.model.clone(),
            messages: self.messages.clone(),
            max_output_tokens: self.max_output_tokens,
            tools: Vec::new(),
        }
    }

    /// Appends the reply carried by `resp` to the history and hands `resp` back.
    fn record_reply(&mut self, resp: Response) -> Response {
        // The caller also receives the response, so the history keeps its own copy.
        self.messages.push(resp.message.clone());
        resp
    }

    /// Sends the whole history through `client` and records the reply.
    ///
    /// # Errors
    /// Propagates any error from `client.send`. In that case no reply is appended.
    /// Messages pushed before the call stay in the history.
    pub async fn send(&mut self, client: &crate::client::Client) -> Result<Response> {
        let resp = client.send(self.to_request()).await?;
        Ok(self.record_reply(resp))
    }

    /// Like [`Chat::send`], but forwards each streamed [`Event`] to `on_event`.
    ///
    /// # Errors
    /// Propagates any error from `client.stream`. In that case no reply is appended.
    pub async fn stream<F>(
        &mut self,
        client: &crate::client::Client,
        on_event: F,
    ) -> Result<Response>
    where
        F: FnMut(Event),
    {
        let resp = client.stream(self.to_request(), on_event).await?;
        Ok(self.record_reply(resp))
    }
}
/// A [`Chat`] bound to the client it is sent through, so `send`/`stream` need no
/// client argument. The session borrows the client for its whole lifetime.
pub struct ChatSession<'client> {
    /// Client used for every request this session makes.
    client: &'client crate::client::Client,
    /// The conversation being driven by this session.
    chat: Chat,
}
impl<'a> ChatSession<'a> {
pub(crate) fn new(client: &'a crate::client::Client, chat: Chat) -> Self {
Self { client, chat }
}
pub fn chat(&self) -> &Chat {
&self.chat
}
pub fn chat_mut(&mut self) -> &mut Chat {
&mut self.chat
}
pub fn into_chat(self) -> Chat {
self.chat
}
pub fn push_message(&mut self, message: Message) -> &mut Self {
self.chat.push_message(message);
self
}
pub fn push_text(&mut self, role: Role, text: impl Into<String>) -> &mut Self {
self.chat.push_text(role, text);
self
}
pub fn push_user(&mut self, text: impl Into<String>) -> &mut Self {
self.chat.push_user(text);
self
}
pub fn push(&mut self, text: impl Into<String>) -> &mut Self {
self.push_user(text)
}
pub async fn send(&mut self) -> Result<Response> {
self.chat.send(self.client).await
}
pub async fn stream<F>(&mut self, on_event: F) -> Result<Response>
where
F: FnMut(Event),
{
self.chat.stream(self.client, on_event).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized chat must come back with its history *and* its settings intact.
    #[test]
    fn chat_serde_roundtrip() {
        let mut chat = Chat::new("x");
        chat.push_text(Role::User, "hi");
        chat.push_text(Role::Assistant, "hello");
        chat.set_max_output_tokens(123);
        let bytes = serde_json::to_vec(&chat).unwrap();
        let back: Chat = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(back.messages().len(), 2);
        // Settings configured above must survive too, not only the messages.
        assert_eq!(back.max_output_tokens, Some(123));
        assert!(back.model.is_some());
    }

    /// A `new_auto` chat has no model or cap, and stays that way through serde.
    #[test]
    fn auto_chat_serde_roundtrip_keeps_settings_unset() {
        let chat = Chat::new_auto();
        let bytes = serde_json::to_vec(&chat).unwrap();
        let back: Chat = serde_json::from_slice(&bytes).unwrap();
        assert!(back.model.is_none());
        assert_eq!(back.max_output_tokens, None);
        assert!(back.messages().is_empty());
    }

    /// Builders chain, and `push` appends just like `push_user`.
    #[test]
    fn builders_and_push_configure_and_append() {
        let mut chat = Chat::new_auto().model("m").max_output_tokens(7);
        chat.push("a").push_user("b");
        assert!(chat.model.is_some());
        assert_eq!(chat.max_output_tokens, Some(7));
        assert_eq!(chat.messages().len(), 2);
        assert_eq!(chat.into_messages().len(), 2);
    }
}