use futures::stream::BoxStream;
use crate::api::{ApiClient, ApiRequest};
use crate::error::{ApiError, Result};
use crate::raw::request::message::{Message, Role};
use crate::conversation::{LlmSummarizer, Summarizer};
/// A chat conversation: an API client plus the message history sent with each request.
///
/// The history can be compacted by a pluggable [`Summarizer`] (via
/// `maybe_summarize`) while `auto_summary` is enabled.
pub struct Conversation {
    /// Client used to send (and stream) requests to the API.
    pub(crate) client: ApiClient,
    /// Messages sent as context with each request, in the order they were added.
    pub(crate) history: Vec<Message>,
    /// Strategy used to decide when and how to compact `history`.
    summarizer: Box<dyn Summarizer + Send + Sync>,
    /// When `false`, `maybe_summarize` never invokes the summarizer.
    auto_summary: bool,
}
impl Conversation {
    /// Creates an empty conversation that talks to the API through `client`.
    ///
    /// Auto-summarization starts enabled and uses an [`LlmSummarizer`] built
    /// from a clone of the same client.
    pub fn new(client: ApiClient) -> Self {
        let summarizer = LlmSummarizer::new(client.clone());
        Self {
            client,
            history: Vec::new(),
            summarizer: Box::new(summarizer),
            auto_summary: true,
        }
    }

    /// Replaces the summarizer used to compact the history.
    ///
    /// The summarizer is stored as `Box<dyn Summarizer + Send + Sync>`, so it
    /// must itself be `Send + Sync` for the boxing coercion to succeed.
    #[must_use]
    pub fn with_summarizer(mut self, s: impl Summarizer + Send + Sync + 'static) -> Self {
        self.summarizer = Box::new(s);
        self
    }

    /// Enables (`true`) or disables (`false`) automatic history summarization.
    #[must_use]
    pub fn enable_auto_summary(mut self, v: bool) -> Self {
        self.auto_summary = v;
        self
    }

    /// Replaces the conversation history with `history`.
    #[must_use]
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.history = history;
        self
    }

    /// Returns the current history as a read-only slice.
    pub fn history(&self) -> &[Message] {
        &self.history
    }

    /// Returns mutable access to the underlying history vector.
    pub fn history_mut(&mut self) -> &mut Vec<Message> {
        &mut self.history
    }

    /// Appends `message` to the end of the history.
    pub fn add_message(&mut self, message: Message) {
        self.history.push(message);
    }

    /// Appends a [`Role::User`] message containing `text` to the history.
    pub fn push_user_input(&mut self, text: impl Into<String>) {
        let text: String = text.into();
        self.history.push(Message::new(Role::User, &text));
    }

    /// Compacts the history with the configured summarizer when enabled and needed.
    ///
    /// Does nothing if auto-summary is disabled. Otherwise, it does nothing if the
    /// summarizer's `should_summarize` reports that the current history does not
    /// need compacting.
    pub async fn maybe_summarize(&mut self) {
        if !self.auto_summary || !self.summarizer.should_summarize(&self.history) {
            return;
        }
        // Deliberately best-effort: whatever `summarize` returns (including an
        // error) is ignored so a failed summarization never aborts a request.
        // NOTE(review): confirm a failed `summarize` leaves `history` consistent.
        let _ = self.summarizer.summarize(&mut self.history).await;
    }

    /// Sends the current history as a single non-streaming request and records the reply.
    ///
    /// The history is summarized (if needed) before the request. It is summarized
    /// again after the assistant's message has been appended.
    ///
    /// Returns the `content` of the first choice's message, which may be `None`.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the client's `send`.
    /// - Returns `ApiError::Other` if the response contains no choices.
    pub async fn send_once(&mut self) -> Result<Option<String>> {
        self.maybe_summarize().await;

        let req = ApiRequest::builder().messages(self.history.clone());
        let resp = self.client.send(req).await?;

        let assistant_msg = resp
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| ApiError::Other("empty choices from API".to_string()))?;

        let content = assistant_msg.content.clone();
        self.history.push(assistant_msg);
        self.maybe_summarize().await;
        Ok(content)
    }

    /// Starts a streaming request over the current history.
    ///
    /// Like [`Conversation::send_once`], this first runs `maybe_summarize`. Streaming
    /// requests therefore get the same context compaction as non-streaming ones.
    ///
    /// The streamed reply is not appended to the history. Callers that want to keep
    /// it must add it themselves, e.g. with [`Conversation::add_message`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the client while opening the stream.
    pub async fn stream_text(
        &mut self,
    ) -> Result<BoxStream<'_, std::result::Result<String, ApiError>>> {
        self.maybe_summarize().await;
        let req = ApiRequest::builder()
            .messages(self.history.clone())
            .stream(true);
        self.client.stream_text(req).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a conversation around a client with a placeholder token.
    /// The tests below only inspect local state; none of them send a request.
    fn conversation() -> Conversation {
        let client = ApiClient::new("fake-token");
        Conversation::new(client)
    }

    #[test]
    fn new_has_empty_history() {
        let conv = conversation();
        let history = conv.history();
        assert!(history.is_empty());
    }

    #[test]
    fn with_history_seeds_messages() {
        let seeded = vec![Message::new(Role::User, "hi")];
        let conv = conversation().with_history(seeded);
        let history = conv.history();
        assert_eq!(history.len(), 1);
    }

    #[test]
    fn push_user_input_appends_user_role() {
        let mut conv = conversation();
        conv.push_user_input("hello");

        let history = conv.history();
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0].role, Role::User));
    }

    #[test]
    fn add_message_appends() {
        let mut conv = conversation();
        let reply = Message::new(Role::Assistant, "hi");
        conv.add_message(reply);

        let history = conv.history();
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0].role, Role::Assistant));
    }

    #[test]
    fn enable_auto_summary_false() {
        let conv = conversation().enable_auto_summary(false);
        let enabled = conv.auto_summary;
        assert!(!enabled);
    }
}