use futures::StreamExt;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use crate::client::Client;
use crate::error::Error;
use crate::response::{Responses, ToolUseResponse};
/// A multi-turn conversation that sends prompts through a borrowed [`Client`] and keeps a
/// local record of every completed [`Turn`].
pub struct Conversation<'client> {
    /// Client used to send prompts and to receive the streamed responses.
    client: &'client Client,
    /// Completed turns, appended in the order they finished.
    history: Vec<Turn>,
}
/// One prompt together with the responses received for it.
///
/// `responses` may be empty when the turn was sent with response collection disabled.
#[derive(Clone, Debug)]
pub struct Turn {
    /// The prompt that was sent for this turn.
    pub prompt: String,
    /// The responses collected while the reply was streamed.
    pub responses: Responses,
}

impl Turn {
    /// Returns the text content of this turn's responses.
    pub fn text(&self) -> String {
        let Self { responses, .. } = self;
        responses.text_content()
    }
}
/// Callback invoked with the content of each streamed text response.
type TextCallback<'a> = Box<dyn FnMut(&str) + Send + 'a>;
/// Callback invoked with the content of each streamed thinking response.
type ThinkingCallback<'a> = Box<dyn FnMut(&str) + Send + 'a>;
/// Callback invoked with each streamed tool-use response.
type ToolUseCallback<'a> = Box<dyn FnMut(&ToolUseResponse) + Send + 'a>;

/// Builder for a single conversation turn, created by [`Conversation::turn`].
///
/// Nothing is sent until one of the terminal methods (`send`, `send_text`, `send_as`) is
/// awaited, so dropping the builder silently discards the prompt.
#[must_use = "a TurnBuilder does nothing until `send`, `send_text`, or `send_as` is awaited"]
pub struct TurnBuilder<'a, 'c> {
    /// Conversation this turn belongs to; the completed turn is appended to its history.
    conversation: &'a mut Conversation<'c>,
    /// Prompt text handed to the client when the turn is sent.
    prompt: String,
    /// Optional callback for text responses.
    on_text: Option<TextCallback<'a>>,
    /// Optional callback for thinking responses.
    on_thinking: Option<ThinkingCallback<'a>>,
    /// Optional callback for tool-use responses.
    on_tool_use: Option<ToolUseCallback<'a>>,
    /// Whether streamed responses are accumulated into the returned `Responses`
    /// and into the recorded history entry.
    collect: bool,
}
impl<'a> Conversation<'a> {
    /// Creates a conversation with an empty history, backed by `client`.
    pub(crate) fn new(client: &'a Client) -> Self {
        Self {
            client,
            history: Vec::new(),
        }
    }

    /// Starts building a new turn with the given prompt.
    ///
    /// The returned builder has response collection enabled and no callbacks registered.
    /// The prompt is not sent until one of the builder's `send*` methods is awaited.
    pub fn turn(&mut self, prompt: impl Into<String>) -> TurnBuilder<'_, 'a> {
        TurnBuilder {
            conversation: self,
            prompt: prompt.into(),
            on_text: None,
            on_thinking: None,
            on_tool_use: None,
            collect: true,
        }
    }

    /// Sends `prompt` as a new turn and returns the text content of the responses.
    ///
    /// Shorthand for `self.turn(prompt).send_text().await`.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the prompt or from the response stream.
    pub async fn say(&mut self, prompt: &str) -> Result<String, Error> {
        self.turn(prompt).send_text().await
    }

    /// Returns all completed turns, oldest first.
    pub fn history(&self) -> &[Turn] {
        &self.history
    }

    /// Returns the most recently completed turn, or `None` if there is none.
    pub fn last(&self) -> Option<&Turn> {
        self.history.last()
    }

    /// Removes every turn from the local history.
    ///
    /// NOTE(review): this only clears the locally stored turns; whether any client- or
    /// server-side context is also reset is not visible from this module.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }

    /// Returns the client this conversation sends through.
    ///
    /// The reference carries the client's own lifetime `'a` rather than the borrow of `self`,
    /// so it can be held across later `&mut self` calls such as [`turn`](Self::turn).
    pub fn client(&self) -> &'a Client {
        self.client
    }
}
impl<'a, 'c> TurnBuilder<'a, 'c> {
    /// Registers a callback that receives the content of every streamed text response.
    ///
    /// Replaces any previously registered text callback.
    pub fn on_text<F>(mut self, f: F) -> Self
    where
        F: FnMut(&str) + Send + 'a,
    {
        self.on_text = Some(Box::new(f));
        self
    }

    /// Registers a callback that receives the content of every streamed thinking response.
    ///
    /// Replaces any previously registered thinking callback.
    pub fn on_thinking<F>(mut self, f: F) -> Self
    where
        F: FnMut(&str) + Send + 'a,
    {
        self.on_thinking = Some(Box::new(f));
        self
    }

    /// Registers a callback that receives every streamed tool-use response.
    ///
    /// Replaces any previously registered tool-use callback.
    pub fn on_tool_use<F>(mut self, f: F) -> Self
    where
        F: FnMut(&ToolUseResponse) + Send + 'a,
    {
        self.on_tool_use = Some(Box::new(f));
        self
    }

    /// Controls whether streamed responses are accumulated into the [`Responses`] returned by
    /// [`send`](Self::send) and stored in the history entry. Defaults to `true`.
    ///
    /// [`send_text`](Self::send_text) and [`send_as`](Self::send_as) always collect, because
    /// their result is derived from the collected responses.
    pub fn collect(mut self, collect: bool) -> Self {
        self.collect = collect;
        self
    }

    /// Sends the prompt, drains the response stream, and returns the collected responses.
    ///
    /// Each streamed response is first offered to the registered callbacks (text, then
    /// thinking, then tool use). If collection is enabled, it is then appended to the result.
    /// Once the stream ends, the turn is appended to the conversation history.
    ///
    /// # Errors
    ///
    /// Returns the first error from sending the prompt or from the response stream. In that
    /// case the turn is not recorded in the history.
    pub async fn send(self) -> Result<Responses, Error> {
        let TurnBuilder {
            conversation,
            prompt,
            mut on_text,
            mut on_thinking,
            mut on_tool_use,
            collect,
        } = self;
        conversation.client.query(&prompt).await?;
        let mut responses = Responses::new();
        let mut stream = std::pin::pin!(conversation.client.receive());
        while let Some(result) = stream.next().await {
            let response = result?;
            if let Some(text) = response.as_text()
                && let Some(ref mut cb) = on_text
            {
                cb(text.content());
            }
            if let Some(thinking) = response.as_thinking()
                && let Some(ref mut cb) = on_thinking
            {
                cb(thinking.content());
            }
            if let Some(tool_use) = response.as_tool_use()
                && let Some(ref mut cb) = on_tool_use
            {
                cb(tool_use);
            }
            if collect {
                responses.push(response);
            }
        }
        // History keeps its own copy so the caller can take ownership of the returned value.
        conversation.history.push(Turn {
            prompt,
            responses: responses.clone(),
        });
        Ok(responses)
    }

    /// Sends the turn and returns the text content of its responses.
    ///
    /// Collection is forced on, since the text is extracted from the collected responses.
    /// Honoring `collect(false)` here would make this silently return an empty string.
    ///
    /// # Errors
    ///
    /// Same as [`send`](Self::send).
    pub async fn send_text(self) -> Result<String, Error> {
        let responses = self.collect(true).send().await?;
        Ok(responses.text_content())
    }

    /// Sends the turn and deserializes the structured output of its completion into `T`.
    ///
    /// Collection is forced on, since the completion is read from the collected responses.
    /// Honoring `collect(false)` here would always produce a misleading
    /// "no structured output" error.
    ///
    /// # Errors
    ///
    /// Returns the errors of [`send`](Self::send). Returns `Error::ProtocolError` when the
    /// responses carry no completion with structured output. Returns a deserialization error
    /// when that output does not match `T`.
    pub async fn send_as<T>(self) -> Result<T, Error>
    where
        T: DeserializeOwned + JsonSchema,
    {
        let responses = self.collect(true).send().await?;
        let structured_output = responses
            .completion()
            .and_then(|c| c.structured_output())
            .cloned()
            .ok_or_else(|| Error::ProtocolError("no structured output in response".to_owned()))?;
        let result = serde_json::from_value::<T>(structured_output)?;
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A turn without responses yields empty text and keeps its prompt verbatim.
    #[test]
    fn test_turn_text() {
        let prompt = String::from("Hello");
        let turn = Turn {
            prompt,
            responses: Responses::new(),
        };
        assert!(turn.text().is_empty());
        assert_eq!(turn.prompt.as_str(), "Hello");
    }
}