use crate::error::ActonAIError;
use crate::facade::ActonAI;
use crate::messages::{Message, ToolDefinition};
use crate::prompt::PromptBuilder;
use crate::stream::CollectedResponse;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Boxed, sendable closure that rewrites one line of user input before it is sent.
type InputMapperFn = Box<dyn FnMut(&str) -> String + Send + 'static>;
/// System prompt installed by the interactive chat loop when none has been set.
///
/// It tells the model to call the `exit_conversation` tool when the user wants
/// to stop, which is what ends the loop.
pub const DEFAULT_SYSTEM_PROMPT: &str = concat!(
    "You are a helpful assistant with access to various tools. ",
    "Use tools when appropriate to help the user. ",
    "When the user wants to end the conversation (says goodbye, bye, quit, exit, etc.), ",
    "use the exit_conversation tool."
);
/// Presentation settings for the interactive stdin/stdout chat loop.
pub struct ChatConfig {
    /// Optional hook applied to each trimmed, non-empty input line before sending.
    input_mapper: Option<InputMapperFn>,
    /// Text printed before reading each line of user input.
    user_prompt: String,
    /// Text printed before the assistant's streamed reply.
    assistant_prompt: String,
}
impl Default for ChatConfig {
    /// Uses the prompts `"You: "` and `"Assistant: "` with no input mapper.
    fn default() -> Self {
        let user_prompt = String::from("You: ");
        let assistant_prompt = String::from("Assistant: ");
        Self {
            input_mapper: None,
            user_prompt,
            assistant_prompt,
        }
    }
}
impl ChatConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn user_prompt(mut self, prompt: impl Into<String>) -> Self {
self.user_prompt = prompt.into();
self
}
#[must_use]
pub fn assistant_prompt(mut self, prompt: impl Into<String>) -> Self {
self.assistant_prompt = prompt.into();
self
}
#[must_use]
pub fn map_input<F>(mut self, f: F) -> Self
where
F: FnMut(&str) -> String + Send + 'static,
{
self.input_mapper = Some(Box::new(f));
self
}
}
impl std::fmt::Debug for ChatConfig {
    /// Shows both prompts and whether a mapper is installed; the closure itself is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_input_mapper = self.input_mapper.is_some();
        let mut out = f.debug_struct("ChatConfig");
        out.field("user_prompt", &self.user_prompt);
        out.field("assistant_prompt", &self.assistant_prompt);
        out.field("has_input_mapper", &has_input_mapper);
        out.finish()
    }
}
/// Builds the `exit_conversation` tool that the model can call to end a chat.
///
/// The input schema is an object with one required string field, `farewell`.
fn exit_tool_definition() -> ToolDefinition {
    const DESCRIPTION: &str = "Call this tool when the user wants to end the conversation, \
        say goodbye, or leave. Examples: 'bye', 'goodbye', 'I'm done', \
        'quit', 'exit', 'see ya', 'thanks, that's all'.";

    let farewell_property = serde_json::json!({
        "type": "string",
        "description": "A friendly farewell message to the user"
    });
    let input_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "farewell": farewell_property
        },
        "required": ["farewell"]
    });

    ToolDefinition {
        name: String::from("exit_conversation"),
        description: DESCRIPTION.to_owned(),
        input_schema,
    }
}
/// A multi-turn chat session that keeps its own message history.
///
/// Each send forwards the full history (plus the optional system prompt) to the
/// runtime and records the assistant's reply.
pub struct Conversation {
    /// Runtime used to build and send each request.
    runtime: ActonAI,
    /// Optional system prompt attached to every request.
    system_prompt: Option<String>,
    /// Messages exchanged so far, oldest first.
    history: Vec<Message>,
    /// Whether the `exit_conversation` tool is offered to the model.
    exit_tool_enabled: bool,
    /// Set by the exit tool's callback; shared so that it can be observed elsewhere.
    exit_requested: Arc<AtomicBool>,
}
impl Conversation {
    /// Assembles a conversation from its parts, with the exit flag initially cleared.
    fn new(
        runtime: ActonAI,
        system_prompt: Option<String>,
        history: Vec<Message>,
        exit_tool_enabled: bool,
    ) -> Self {
        Self {
            runtime,
            history,
            system_prompt,
            exit_requested: Arc::new(AtomicBool::new(false)),
            exit_tool_enabled,
        }
    }

    /// Sends a user message and waits for the complete reply.
    ///
    /// Equivalent to [`send_with`](Self::send_with) with no extra builder configuration.
    ///
    /// # Errors
    ///
    /// Returns any error produced while collecting the response. In that case the
    /// history is left exactly as it was before the call.
    pub async fn send(
        &mut self,
        content: impl Into<String>,
    ) -> Result<CollectedResponse, ActonAIError> {
        self.send_with(content, |b| b).await
    }

    /// Appends `content` as a user message, sends the whole history to the model,
    /// and on success appends the assistant's reply text to the history.
    ///
    /// The request is built from a clone of the history plus the system prompt (if
    /// set). When the exit tool is enabled, it is registered so that a call to
    /// `exit_conversation` sets the flag reported by [`should_exit`](Self::should_exit).
    /// `configure` runs last and may add further builder options.
    ///
    /// # Errors
    ///
    /// Returns the error from collecting the response. The user message pushed for
    /// this turn is removed again, so a failed turn does not leave an unanswered
    /// message in the history (which would otherwise be duplicated on a retry).
    pub async fn send_with<F>(
        &mut self,
        content: impl Into<String>,
        configure: F,
    ) -> Result<CollectedResponse, ActonAIError>
    where
        F: FnOnce(PromptBuilder) -> PromptBuilder,
    {
        let user_content = content.into();
        self.history.push(Message::user(&user_content));

        let mut builder = self.runtime.continue_with(self.history.clone());
        if let Some(ref system) = self.system_prompt {
            builder = builder.system(system);
        }
        if self.exit_tool_enabled {
            let exit_flag = Arc::clone(&self.exit_requested);
            builder = builder.with_tool_callback(
                exit_tool_definition(),
                move |_args| {
                    // The callback may be invoked more than once, so each future
                    // gets its own handle to the shared flag.
                    let flag = Arc::clone(&exit_flag);
                    async move {
                        flag.store(true, Ordering::SeqCst);
                        Ok(serde_json::json!({"status": "goodbye"}))
                    }
                },
                |_result| {},
            );
        }
        builder = configure(builder);

        let outcome = builder.collect().await;
        match outcome {
            Ok(response) => {
                self.history.push(Message::assistant(&response.text));
                Ok(response)
            }
            Err(err) => {
                // Roll back the user message pushed above. `&mut self` guarantees
                // nothing else touched the history in between.
                self.history.pop();
                Err(err)
            }
        }
    }

    /// Like [`send`](Self::send), but also registers `on_token` on the prompt
    /// builder via `on_token`. It is presumably called with each streamed text
    /// chunk as it arrives.
    ///
    /// # Errors
    ///
    /// Same as [`send_with`](Self::send_with).
    pub async fn send_streaming<F>(
        &mut self,
        content: impl Into<String>,
        on_token: F,
    ) -> Result<CollectedResponse, ActonAIError>
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.send_with(content, |b| b.on_token(on_token)).await
    }

    /// The messages exchanged so far, oldest first.
    #[must_use]
    pub fn history(&self) -> &[Message] {
        &self.history
    }

    /// Mutable access to the history, for example to edit or truncate it.
    #[must_use]
    pub fn history_mut(&mut self) -> &mut Vec<Message> {
        &mut self.history
    }

    /// Removes all messages. The system prompt and exit flag are left unchanged.
    pub fn clear(&mut self) {
        self.history.clear();
    }

    /// Number of messages in the history.
    #[must_use]
    pub fn len(&self) -> usize {
        self.history.len()
    }

    /// Whether the history contains no messages.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.history.is_empty()
    }

    /// The system prompt attached to each request, if any.
    #[must_use]
    pub fn system_prompt(&self) -> Option<&str> {
        self.system_prompt.as_deref()
    }

    /// Sets or replaces the system prompt used for subsequent requests.
    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
        self.system_prompt = Some(prompt.into());
    }

    /// Removes the system prompt for subsequent requests.
    pub fn clear_system_prompt(&mut self) {
        self.system_prompt = None;
    }

    /// Whether the exit flag is set.
    ///
    /// The flag is set by the exit tool's callback, or by anyone holding the handle
    /// from [`exit_requested`](Self::exit_requested). It stays set until
    /// [`clear_exit`](Self::clear_exit) is called.
    #[must_use]
    pub fn should_exit(&self) -> bool {
        self.exit_requested.load(Ordering::SeqCst)
    }

    /// Resets the exit flag.
    pub fn clear_exit(&mut self) {
        self.exit_requested.store(false, Ordering::SeqCst);
    }

    /// Returns a shared handle to the exit flag, so other code can observe or set it.
    #[must_use]
    pub fn exit_requested(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.exit_requested)
    }

    /// Whether the `exit_conversation` tool is offered to the model.
    #[must_use]
    pub fn is_exit_tool_enabled(&self) -> bool {
        self.exit_tool_enabled
    }

    /// Runs the interactive chat loop with the default [`ChatConfig`].
    ///
    /// # Errors
    ///
    /// Same as [`run_chat_with`](Self::run_chat_with).
    pub async fn run_chat(&mut self) -> Result<(), ActonAIError> {
        self.run_chat_with(ChatConfig::default()).await
    }

    /// Runs an interactive read–send–print loop on stdin/stdout.
    ///
    /// Before the loop starts, this:
    /// - enables the exit tool;
    /// - installs [`DEFAULT_SYSTEM_PROMPT`] if no system prompt is set;
    /// - clears any exit flag left over from an earlier session.
    ///
    /// The first two changes persist after the loop returns. Each trimmed, non-empty
    /// line is passed through the optional input mapper and sent with streaming
    /// output. The loop ends on EOF, on a stdin read error (treated like EOF), or
    /// once the exit flag is set after a turn.
    ///
    /// Note: stdin is read with blocking `std::io` calls, so the calling executor
    /// thread is blocked while it waits for input.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`send_streaming`](Self::send_streaming).
    pub async fn run_chat_with(&mut self, mut config: ChatConfig) -> Result<(), ActonAIError> {
        use std::io::{BufRead, Write};

        self.exit_tool_enabled = true;
        if self.system_prompt.is_none() {
            self.system_prompt = Some(DEFAULT_SYSTEM_PROMPT.to_string());
        }
        // Without this, a flag set in a previous session (or by an earlier `send`)
        // would end this session right after its first turn.
        self.clear_exit();

        let stdin = std::io::stdin();
        let mut line = String::new();
        loop {
            print!("{}", config.user_prompt);
            std::io::stdout().flush().ok();

            line.clear();
            // Ok(0) means EOF. Read errors are deliberately treated the same way:
            // they end the session instead of aborting it with an error.
            if stdin.lock().read_line(&mut line).unwrap_or(0) == 0 {
                break;
            }
            let input = line.trim();
            if input.is_empty() {
                continue;
            }
            let content = match config.input_mapper.as_mut() {
                Some(mapper) => mapper(input),
                None => input.to_string(),
            };

            print!("{}", config.assistant_prompt);
            std::io::stdout().flush().ok();
            self.send_streaming(content, |token| {
                print!("{token}");
                std::io::stdout().flush().ok();
            })
            .await?;
            println!();

            if self.should_exit() {
                break;
            }
        }
        Ok(())
    }
}
impl std::fmt::Debug for Conversation {
    /// Summarises the session's state without printing message contents.
    /// The runtime is omitted, which `finish_non_exhaustive` marks as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let history_len = self.history.len();
        let has_system_prompt = self.system_prompt.is_some();
        let exit_requested = self.exit_requested.load(Ordering::SeqCst);

        let mut out = f.debug_struct("Conversation");
        out.field("history_len", &history_len);
        out.field("has_system_prompt", &has_system_prompt);
        out.field("exit_tool_enabled", &self.exit_tool_enabled);
        out.field("exit_requested", &exit_requested);
        out.finish_non_exhaustive()
    }
}
/// Collects the settings for a [`Conversation`] before it is created.
pub struct ConversationBuilder {
    /// Runtime handed to the built conversation.
    runtime: ActonAI,
    /// System prompt to start with, if any.
    system_prompt: Option<String>,
    /// Messages to seed the conversation's history with.
    history: Vec<Message>,
    /// Whether the built conversation offers the `exit_conversation` tool.
    exit_tool_enabled: bool,
}
impl ConversationBuilder {
    /// Starts a builder with no system prompt, an empty history, and the exit tool disabled.
    pub(crate) fn new(runtime: ActonAI) -> Self {
        Self {
            runtime,
            history: Vec::new(),
            system_prompt: None,
            exit_tool_enabled: false,
        }
    }

    /// Sets the system prompt the conversation starts with.
    #[must_use]
    pub fn system(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Seeds the history with previously saved messages, replacing any set earlier.
    #[must_use]
    pub fn restore(self, messages: impl IntoIterator<Item = Message>) -> Self {
        let history: Vec<Message> = messages.into_iter().collect();
        Self { history, ..self }
    }

    /// Offers the `exit_conversation` tool to the model.
    #[must_use]
    pub fn with_exit_tool(self) -> Self {
        self.exit_tool_flag(true)
    }

    /// Does not offer the `exit_conversation` tool (the default).
    #[must_use]
    pub fn without_exit_tool(self) -> Self {
        self.exit_tool_flag(false)
    }

    /// Shared setter behind `with_exit_tool` / `without_exit_tool`.
    fn exit_tool_flag(self, enabled: bool) -> Self {
        Self {
            exit_tool_enabled: enabled,
            ..self
        }
    }

    /// Creates the [`Conversation`] from the collected settings.
    #[must_use]
    pub fn build(self) -> Conversation {
        let Self {
            runtime,
            system_prompt,
            history,
            exit_tool_enabled,
        } = self;
        Conversation::new(runtime, system_prompt, history, exit_tool_enabled)
    }

    /// Builds the conversation and runs its interactive loop with the default config.
    ///
    /// # Errors
    ///
    /// Returns whatever `Conversation::run_chat` returns.
    pub async fn run_chat(self) -> Result<(), ActonAIError> {
        let mut conversation = self.build();
        conversation.run_chat().await
    }

    /// Builds the conversation and runs its interactive loop with `config`.
    ///
    /// # Errors
    ///
    /// Returns whatever `Conversation::run_chat_with` returns.
    pub async fn run_chat_with(self, config: ChatConfig) -> Result<(), ActonAIError> {
        let mut conversation = self.build();
        conversation.run_chat_with(config).await
    }
}
impl std::fmt::Debug for ConversationBuilder {
    /// Summarises the pending settings. The runtime is omitted, which
    /// `finish_non_exhaustive` marks as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_system_prompt = self.system_prompt.is_some();
        let history_len = self.history.len();

        let mut out = f.debug_struct("ConversationBuilder");
        out.field("has_system_prompt", &has_system_prompt);
        out.field("history_len", &history_len);
        out.field("exit_tool_enabled", &self.exit_tool_enabled);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exit_tool_definition_has_required_fields() {
        let definition = exit_tool_definition();
        assert_eq!(definition.name, "exit_conversation");
        assert!(definition.description.contains("goodbye"));

        let schema = &definition.input_schema;
        let farewell = schema
            .get("properties")
            .and_then(|properties| properties.get("farewell"));
        assert!(farewell.is_some());

        let required = schema
            .get("required")
            .and_then(|value| value.as_array())
            .expect("schema lists its required fields");
        assert!(required.iter().any(|v| v.as_str() == Some("farewell")));
    }

    #[test]
    fn exit_flag_atomic_operations() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        // Set, then clear, checking the value after each store.
        for value in [true, false] {
            flag.store(value, Ordering::SeqCst);
            assert_eq!(flag.load(Ordering::SeqCst), value);
        }
    }

    #[test]
    fn chat_config_default_values() {
        let defaults = ChatConfig::new();
        assert_eq!(defaults.user_prompt, "You: ");
        assert_eq!(defaults.assistant_prompt, "Assistant: ");
        assert!(defaults.input_mapper.is_none());
    }

    #[test]
    fn chat_config_custom_user_prompt() {
        let custom = ChatConfig::new().user_prompt(">>> ");
        assert_eq!(custom.user_prompt, ">>> ");
    }

    #[test]
    fn chat_config_custom_assistant_prompt() {
        let custom = ChatConfig::new().assistant_prompt("AI: ");
        assert_eq!(custom.assistant_prompt, "AI: ");
    }

    #[test]
    fn chat_config_with_input_mapper() {
        let mut config = ChatConfig::new().map_input(|s| format!("[test] {}", s));
        let mapper = config
            .input_mapper
            .as_mut()
            .expect("map_input installs a mapper");
        assert_eq!(mapper("hello"), "[test] hello");
    }

    #[test]
    fn chat_config_debug_impl() {
        let rendered = format!(
            "{:?}",
            ChatConfig::new()
                .user_prompt("test> ")
                .map_input(|s| s.to_string())
        );
        assert!(rendered.contains("test> "));
        assert!(rendered.contains("has_input_mapper"));
    }

    #[test]
    fn chat_config_chaining() {
        let chained = ChatConfig::new()
            .user_prompt("U> ")
            .assistant_prompt("A> ")
            .map_input(|s| s.to_uppercase());
        assert_eq!(chained.user_prompt, "U> ");
        assert_eq!(chained.assistant_prompt, "A> ");
        assert!(chained.input_mapper.is_some());
    }

    #[test]
    fn default_system_prompt_is_sensible() {
        for word in ["helpful", "exit_conversation"] {
            assert!(DEFAULT_SYSTEM_PROMPT.contains(word));
        }
    }
}