use crate::error::ActonAIError;
use crate::facade::ActonAI;
use crate::messages::{Message, ToolDefinition};
use crate::stream::CollectedResponse;
use acton_reactive::prelude::*;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, watch};
/// Boxed closure applied to each raw user input line before it is sent to the model.
type InputMapperFn = Box<dyn FnMut(&str) -> String + Send>;
/// System prompt installed by the chat loop when the caller has not set one.
/// Tells the model to use tools, including `exit_conversation` to end the chat.
pub const DEFAULT_SYSTEM_PROMPT: &str = "\
You are a helpful assistant with access to various tools. \
Use tools when appropriate to help the user. \
When the user wants to end the conversation (says goodbye, bye, quit, exit, etc.), \
use the exit_conversation tool.";
/// Configuration for the interactive terminal chat loop (`run_chat_with`).
pub struct ChatConfig {
    // Text printed before reading a line of user input (default: "You: ").
    user_prompt: String,
    // Text printed before the assistant's streamed reply (default: "Assistant: ").
    assistant_prompt: String,
    // Optional transform applied to each trimmed input line before sending.
    input_mapper: Option<InputMapperFn>,
}
impl Default for ChatConfig {
fn default() -> Self {
Self {
user_prompt: "You: ".to_string(),
assistant_prompt: "Assistant: ".to_string(),
input_mapper: None,
}
}
}
impl ChatConfig {
    /// Creates a config with the default prompts and no input mapper.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Overrides the prompt shown before reading user input.
    #[must_use]
    pub fn user_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            user_prompt: prompt.into(),
            ..self
        }
    }

    /// Overrides the prompt shown before the assistant's reply.
    #[must_use]
    pub fn assistant_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            assistant_prompt: prompt.into(),
            ..self
        }
    }

    /// Installs a transform applied to each input line before it is sent.
    #[must_use]
    pub fn map_input<F>(self, f: F) -> Self
    where
        F: FnMut(&str) -> String + Send + 'static,
    {
        Self {
            input_mapper: Some(Box::new(f)),
            ..self
        }
    }
}
impl std::fmt::Debug for ChatConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ChatConfig");
        dbg.field("user_prompt", &self.user_prompt);
        dbg.field("assistant_prompt", &self.assistant_prompt);
        // The mapper closure is opaque; report only whether one is installed.
        dbg.field("has_input_mapper", &self.input_mapper.is_some());
        dbg.finish()
    }
}
fn exit_tool_definition() -> ToolDefinition {
ToolDefinition {
name: "exit_conversation".to_string(),
description: "Call this tool when the user wants to end the conversation, \
say goodbye, or leave. Examples: 'bye', 'goodbye', 'I'm done', \
'quit', 'exit', 'see ya', 'thanks, that's all'."
.to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"farewell": {
"type": "string",
"description": "A friendly farewell message to the user"
}
},
"required": ["farewell"]
}),
}
}
/// Internal actor message: append a user turn and run one LLM round trip.
#[derive(Clone, Debug)]
struct ConvSend {
    // The user's message text.
    content: String,
    // Optional actor to receive `StreamToken` messages as tokens arrive.
    token_target: Option<ActorHandle>,
    // Channel on which the collected response (or error) is delivered.
    result_tx: mpsc::Sender<Result<CollectedResponse, ActonAIError>>,
}
/// Internal self-message: record the assistant's reply in the history.
#[derive(Clone, Debug)]
struct ConvAddAssistant {
    text: String,
}
/// A single streamed token of assistant output, delivered to a token-target actor.
#[derive(Clone, Debug)]
pub struct StreamToken {
    pub text: String,
}
/// Internal actor message: wipe the conversation history.
#[derive(Clone, Debug)]
struct ConvClear;
/// Internal actor message: replace (or clear, via `None`) the system prompt.
#[derive(Clone, Debug)]
struct ConvSetSystemPrompt {
    prompt: Option<String>,
}
/// Actor that owns the conversation history; all mutation is serialized
/// through its message handlers (see `configure_handlers`).
#[acton_actor]
struct ConversationActor {
    history: Vec<Message>,
}
/// Token sink used by the chat loop: prints each streamed token to stdout.
#[derive(Default, Debug)]
struct StdoutTokenPrinter;
/// Shared state threaded into the actor's handlers by `configure_handlers`;
/// bundled so the function doesn't need eight separate parameters.
struct HandlerState {
    // LLM runtime used to issue requests.
    runtime: ActonAI,
    // Handle to the conversation actor itself, for self-addressed messages.
    self_handle: ActorHandle,
    // Publishes history snapshots for `Conversation::history`.
    history_tx: Arc<watch::Sender<Vec<Message>>>,
    // Mirrors history length for lock-free `len` / `is_empty`.
    history_len: Arc<AtomicUsize>,
    // Set by the exit tool when the model decides the chat is over.
    exit_requested: Arc<AtomicBool>,
    // Whether outgoing requests should carry the exit tool.
    exit_tool_enabled: Arc<AtomicBool>,
    // Read side of the system prompt, sampled once per request.
    system_prompt_rx: watch::Receiver<Option<String>>,
    // Write side, driven by `ConvSetSystemPrompt` messages.
    system_prompt_tx: watch::Sender<Option<String>>,
}
/// Registers the conversation actor's message handlers.
///
/// Each handler block clones only the pieces of `HandlerState` it needs so the
/// `move` closures are `'static` and independently owned.
fn configure_handlers(builder: &mut ManagedActor<Idle, ConversationActor>, state: HandlerState) {
    let HandlerState {
        runtime,
        self_handle,
        history_tx,
        history_len,
        exit_requested,
        exit_tool_enabled,
        system_prompt_rx,
        system_prompt_tx,
    } = state;
    {
        // ConvSend: append the user turn, then perform one LLM round trip.
        let runtime = runtime.clone();
        let self_handle = self_handle.clone();
        let history_tx = history_tx.clone();
        let history_len = history_len.clone();
        let exit_requested = exit_requested.clone();
        let exit_tool_enabled = exit_tool_enabled.clone();
        let system_prompt_rx = system_prompt_rx.clone();
        builder.mutate_on::<ConvSend>(move |actor, ctx| {
            let msg = ctx.message().clone();
            actor.model.history.push(Message::user(&msg.content));
            // Mirror history into the watch channel and atomic length so
            // `Conversation` can observe them without messaging the actor.
            let _ = history_tx.send(actor.model.history.clone());
            history_len.store(actor.model.history.len(), Ordering::SeqCst);
            let history = actor.model.history.clone();
            // Sample the system prompt and exit-tool flag at request time;
            // later changes apply only to subsequent requests.
            let system_prompt = system_prompt_rx.borrow().clone();
            let runtime = runtime.clone();
            let exit_requested = exit_requested.clone();
            let exit_tool_enabled_val = exit_tool_enabled.load(Ordering::SeqCst);
            let result_tx = msg.result_tx;
            let token_target = msg.token_target;
            let self_handle = self_handle.clone();
            Reply::pending(async move {
                // NOTE(review): the LLM call runs on its own tokio task —
                // presumably so it is driven independently of the actor's
                // mailbox; confirm against acton_reactive's Reply semantics.
                let llm_result = tokio::spawn(async move {
                    let mut builder = runtime.continue_with(history);
                    if let Some(ref system) = system_prompt {
                        builder = builder.system(system);
                    }
                    if exit_tool_enabled_val {
                        // The exit tool only sets the shared flag; the chat
                        // loop polls it after each turn.
                        let exit_flag = exit_requested.clone();
                        builder = builder.with_tool_callback(
                            exit_tool_definition(),
                            move |_args| {
                                let flag = exit_flag.clone();
                                async move {
                                    flag.store(true, Ordering::SeqCst);
                                    Ok(serde_json::json!({"status": "goodbye"}))
                                }
                            },
                            |_result| {},
                        );
                    }
                    if let Some(target) = token_target {
                        builder = builder.token_target(target);
                    }
                    builder.collect().await
                })
                .await;
                // A JoinError (task panic/cancellation) is surfaced to the
                // caller as a prompt failure.
                let result = match llm_result {
                    Ok(r) => r,
                    Err(join_err) => Err(ActonAIError::prompt_failed(join_err.to_string())),
                };
                if let Ok(ref response) = result {
                    // Record the assistant turn via a self-message so the
                    // actor serializes the history mutation.
                    self_handle
                        .send(ConvAddAssistant {
                            text: response.text.clone(),
                        })
                        .await;
                }
                // Receiver may have been dropped (caller gave up); ignore.
                let _ = result_tx.send(result).await;
            })
        });
    }
    {
        // ConvAddAssistant: append the assistant turn and republish mirrors.
        let history_tx = history_tx.clone();
        let history_len = history_len.clone();
        builder.mutate_on::<ConvAddAssistant>(move |actor, ctx| {
            let text = &ctx.message().text;
            actor.model.history.push(Message::assistant(text));
            let _ = history_tx.send(actor.model.history.clone());
            history_len.store(actor.model.history.len(), Ordering::SeqCst);
            Reply::ready()
        });
    }
    {
        // ConvClear: wipe history and the published mirrors.
        let history_tx = history_tx.clone();
        let history_len = history_len.clone();
        builder.mutate_on::<ConvClear>(move |actor, _ctx| {
            actor.model.history.clear();
            let _ = history_tx.send(actor.model.history.clone());
            history_len.store(0, Ordering::SeqCst);
            Reply::ready()
        });
    }
    // ConvSetSystemPrompt: publish the new prompt; it takes effect on the
    // next ConvSend.
    builder.mutate_on::<ConvSetSystemPrompt>(move |_actor, ctx| {
        let prompt = ctx.message().prompt.clone();
        let _ = system_prompt_tx.send(prompt);
        Reply::ready()
    });
}
/// Handle to a running conversation actor.
///
/// Cheap to clone: every clone talks to the same actor and observes the same
/// history, system prompt, and exit state.
pub struct Conversation {
    // Mailbox of the backing `ConversationActor`.
    handle: ActorHandle,
    // LLM runtime; used by the chat loop to spawn the token-printer actor.
    runtime: ActonAI,
    // Set when the model invokes the exit tool.
    exit_requested: Arc<AtomicBool>,
    // Whether outgoing requests include the exit tool.
    exit_tool_enabled: Arc<AtomicBool>,
    // Latest history snapshot published by the actor.
    history_rx: watch::Receiver<Vec<Message>>,
    // History-length mirror for synchronous `len` / `is_empty`.
    history_len: Arc<AtomicUsize>,
    // Current system prompt, if any.
    system_prompt_rx: watch::Receiver<Option<String>>,
}
// Compile-time assertion that `Conversation` stays `Clone + Send + 'static`,
// so handles can be moved across tasks and threads. Fails to compile if a
// non-conforming field is ever added.
const _: () = {
    #[allow(dead_code)]
    fn assert_clone_send_static<T: Clone + Send + 'static>() {}
    #[allow(dead_code)]
    fn assert_conversation() {
        assert_clone_send_static::<Conversation>();
    }
};
// Manual Clone: every field is itself Clone (handles, Arcs, watch receivers),
// so clones share the same underlying actor and state.
// NOTE(review): `#[derive(Clone)]` would presumably produce the same impl —
// confirm there is no deliberate reason for the hand-written version.
impl Clone for Conversation {
    fn clone(&self) -> Self {
        Self {
            handle: self.handle.clone(),
            runtime: self.runtime.clone(),
            exit_requested: self.exit_requested.clone(),
            exit_tool_enabled: self.exit_tool_enabled.clone(),
            history_rx: self.history_rx.clone(),
            history_len: self.history_len.clone(),
            system_prompt_rx: self.system_prompt_rx.clone(),
        }
    }
}
impl Conversation {
    /// Sends a user message and waits for the model's complete response.
    ///
    /// The message is appended to the history before the request, and the
    /// assistant's reply is appended after it.
    ///
    /// # Errors
    /// Returns an error if the LLM request fails or the conversation actor
    /// has been dropped.
    pub async fn send(
        &self,
        content: impl Into<String>,
    ) -> Result<CollectedResponse, ActonAIError> {
        self.dispatch(content.into(), None).await
    }

    /// Sends a user message, streaming tokens to `token_handle` as
    /// [`StreamToken`] messages, and waits for the complete response.
    ///
    /// # Errors
    /// Same failure modes as [`Conversation::send`].
    pub async fn send_streaming(
        &self,
        content: impl Into<String>,
        token_handle: &ActorHandle,
    ) -> Result<CollectedResponse, ActonAIError> {
        self.dispatch(content.into(), Some(token_handle.clone())).await
    }

    /// Shared request path for `send` / `send_streaming` (previously
    /// duplicated in both): ships a `ConvSend` to the actor and awaits the
    /// single reply on a rendezvous channel.
    async fn dispatch(
        &self,
        content: String,
        token_target: Option<ActorHandle>,
    ) -> Result<CollectedResponse, ActonAIError> {
        // Capacity 1: exactly one reply is ever sent per request.
        let (tx, mut rx) = mpsc::channel(1);
        self.handle
            .send(ConvSend {
                content,
                token_target,
                result_tx: tx,
            })
            .await;
        // `None` means the sender (and thus the actor) was dropped before replying.
        rx.recv().await.unwrap_or_else(|| {
            Err(ActonAIError::prompt_failed(
                "conversation actor dropped".to_string(),
            ))
        })
    }

    /// Returns a snapshot of the conversation history.
    #[must_use]
    pub fn history(&self) -> Vec<Message> {
        self.history_rx.borrow().clone()
    }

    /// Clears the conversation history.
    ///
    /// Fire-and-forget: the clear is applied asynchronously by the actor and
    /// may not be visible to `history()` / `len()` immediately.
    pub fn clear(&self) {
        let handle = self.handle.clone();
        tokio::spawn(async move {
            handle.send(ConvClear).await;
        });
    }

    /// Number of messages currently in the history.
    #[must_use]
    pub fn len(&self) -> usize {
        self.history_len.load(Ordering::SeqCst)
    }

    /// True if the history contains no messages.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.history_len.load(Ordering::SeqCst) == 0
    }

    /// Returns the current system prompt, if one is set.
    #[must_use]
    pub fn system_prompt(&self) -> Option<String> {
        self.system_prompt_rx.borrow().clone()
    }

    /// Sets the system prompt used for subsequent requests (fire-and-forget).
    pub fn set_system_prompt(&self, prompt: impl Into<String>) {
        let handle = self.handle.clone();
        let prompt = prompt.into();
        tokio::spawn(async move {
            handle
                .send(ConvSetSystemPrompt {
                    prompt: Some(prompt),
                })
                .await;
        });
    }

    /// Removes the system prompt for subsequent requests (fire-and-forget).
    pub fn clear_system_prompt(&self) {
        let handle = self.handle.clone();
        tokio::spawn(async move {
            handle.send(ConvSetSystemPrompt { prompt: None }).await;
        });
    }

    /// True once the model has invoked the `exit_conversation` tool.
    #[must_use]
    pub fn should_exit(&self) -> bool {
        self.exit_requested.load(Ordering::SeqCst)
    }

    /// Resets the exit flag, allowing the conversation to continue.
    pub fn clear_exit(&self) {
        self.exit_requested.store(false, Ordering::SeqCst);
    }

    /// Returns the shared exit flag, e.g. for wiring into external shutdown logic.
    #[must_use]
    pub fn exit_requested(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.exit_requested)
    }

    /// True if outgoing requests currently include the exit tool.
    #[must_use]
    pub fn is_exit_tool_enabled(&self) -> bool {
        self.exit_tool_enabled.load(Ordering::SeqCst)
    }

    /// Runs an interactive terminal chat loop with the default [`ChatConfig`].
    ///
    /// # Errors
    /// Returns the first LLM error encountered.
    pub async fn run_chat(&self) -> Result<(), ActonAIError> {
        self.run_chat_with(ChatConfig::default()).await
    }

    /// Runs an interactive terminal chat loop with the given configuration.
    ///
    /// Enables the exit tool, installs [`DEFAULT_SYSTEM_PROMPT`] if no system
    /// prompt is set, and streams assistant tokens straight to stdout. The
    /// loop ends on stdin EOF, on an LLM error, or when the model calls the
    /// exit tool.
    ///
    /// NOTE(review): stdin reads here are blocking inside an async fn — fine
    /// for a dedicated chat binary, but they would stall a shared runtime
    /// worker; consider `tokio::task::spawn_blocking` if this is ever used
    /// inside a larger application.
    ///
    /// # Errors
    /// Returns the first LLM error encountered.
    pub async fn run_chat_with(&self, mut config: ChatConfig) -> Result<(), ActonAIError> {
        use std::io::{BufRead, Write};
        self.exit_tool_enabled.store(true, Ordering::SeqCst);
        if self.system_prompt_rx.borrow().is_none() {
            self.handle
                .send(ConvSetSystemPrompt {
                    prompt: Some(DEFAULT_SYSTEM_PROMPT.to_string()),
                })
                .await;
        }
        // Spawn a small actor that prints streamed tokens as they arrive.
        let mut actor_runtime = self.runtime.runtime().clone();
        let mut token_actor = actor_runtime.new_actor::<StdoutTokenPrinter>();
        token_actor.mutate_on::<StreamToken>(|_actor, ctx| {
            print!("{}", ctx.message().text);
            std::io::stdout().flush().ok();
            Reply::ready()
        });
        let token_handle = token_actor.start().await;
        let stdin = std::io::stdin();
        let result = loop {
            print!("{}", config.user_prompt);
            std::io::stdout().flush().ok();
            let mut input = String::new();
            // 0 bytes read means EOF (e.g. Ctrl-D): exit cleanly.
            if stdin.lock().read_line(&mut input).unwrap_or(0) == 0 {
                break Ok(());
            }
            let input = input.trim();
            if input.is_empty() {
                continue;
            }
            let content = match config.input_mapper.as_mut() {
                Some(mapper) => mapper(input),
                None => input.to_string(),
            };
            print!("{}", config.assistant_prompt);
            std::io::stdout().flush().ok();
            match self.send_streaming(&content, &token_handle).await {
                Ok(_) => {
                    // Streaming already printed the reply; terminate the line.
                    println!();
                }
                Err(e) => {
                    break Err(e);
                }
            }
            if self.should_exit() {
                break Ok(());
            }
        };
        // Always stop the token actor, even when breaking out with an error.
        let _ = token_handle.stop().await;
        result
    }
}
impl std::fmt::Debug for Conversation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Conversation");
        dbg.field("history_len", &self.history_len.load(Ordering::SeqCst));
        // Report presence of the prompt, not its (possibly long) contents.
        dbg.field("has_system_prompt", &self.system_prompt_rx.borrow().is_some());
        dbg.field(
            "exit_tool_enabled",
            &self.exit_tool_enabled.load(Ordering::SeqCst),
        );
        dbg.field("exit_requested", &self.exit_requested.load(Ordering::SeqCst));
        dbg.finish_non_exhaustive()
    }
}
/// Builder for [`Conversation`]; configures the system prompt, seed history,
/// and exit-tool behavior before spawning the backing actor.
pub struct ConversationBuilder {
    // LLM runtime the conversation will use.
    runtime: ActonAI,
    // System prompt to seed the watch channel with, if any.
    system_prompt: Option<String>,
    // Messages to restore into the actor's history at startup.
    history: Vec<Message>,
    // Whether the exit tool is attached to requests from the start.
    exit_tool_enabled: bool,
}
impl ConversationBuilder {
    /// Starts a builder bound to `runtime`: no system prompt, empty history,
    /// exit tool disabled.
    pub(crate) fn new(runtime: ActonAI) -> Self {
        Self {
            runtime,
            system_prompt: None,
            history: Vec::new(),
            exit_tool_enabled: false,
        }
    }

    /// Sets the system prompt for the conversation.
    #[must_use]
    pub fn system(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Seeds the conversation with previously saved messages.
    #[must_use]
    pub fn restore(self, messages: impl IntoIterator<Item = Message>) -> Self {
        Self {
            history: messages.into_iter().collect(),
            ..self
        }
    }

    /// Attaches the `exit_conversation` tool to every request.
    #[must_use]
    pub fn with_exit_tool(self) -> Self {
        Self {
            exit_tool_enabled: true,
            ..self
        }
    }

    /// Detaches the `exit_conversation` tool (the default).
    #[must_use]
    pub fn without_exit_tool(self) -> Self {
        Self {
            exit_tool_enabled: false,
            ..self
        }
    }

    /// Spawns the conversation actor and returns a handle to it.
    pub async fn build(self) -> Conversation {
        let seed_history = self.history;
        // Watch channels carry the initial snapshots; atomics mirror them
        // for synchronous access.
        let (history_tx, history_rx) = watch::channel(seed_history.clone());
        let (system_prompt_tx, system_prompt_rx) = watch::channel(self.system_prompt);
        let history_len = Arc::new(AtomicUsize::new(seed_history.len()));
        let exit_requested = Arc::new(AtomicBool::new(false));
        let exit_tool_enabled = Arc::new(AtomicBool::new(self.exit_tool_enabled));

        let mut rt = self.runtime.runtime().clone();
        let mut conv_actor = rt.new_actor::<ConversationActor>();
        // Seed the actor's history before it starts handling messages.
        conv_actor.model.history = seed_history;
        let handle = conv_actor.handle().clone();
        configure_handlers(
            &mut conv_actor,
            HandlerState {
                runtime: self.runtime.clone(),
                self_handle: handle.clone(),
                history_tx: Arc::new(history_tx),
                history_len: Arc::clone(&history_len),
                exit_requested: Arc::clone(&exit_requested),
                exit_tool_enabled: Arc::clone(&exit_tool_enabled),
                system_prompt_rx: system_prompt_rx.clone(),
                system_prompt_tx,
            },
        );
        let _ = conv_actor.start().await;
        Conversation {
            handle,
            runtime: self.runtime,
            exit_requested,
            exit_tool_enabled,
            history_rx,
            history_len,
            system_prompt_rx,
        }
    }

    /// Builds the conversation and immediately runs the default chat loop.
    pub async fn run_chat(self) -> Result<(), ActonAIError> {
        self.build().await.run_chat().await
    }

    /// Builds the conversation and runs the chat loop with `config`.
    pub async fn run_chat_with(self, config: ChatConfig) -> Result<(), ActonAIError> {
        self.build().await.run_chat_with(config).await
    }
}
impl std::fmt::Debug for ConversationBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ConversationBuilder");
        dbg.field("has_system_prompt", &self.system_prompt.is_some());
        dbg.field("history_len", &self.history.len());
        dbg.field("exit_tool_enabled", &self.exit_tool_enabled);
        dbg.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exit_tool_definition_has_required_fields() {
        let tool = exit_tool_definition();
        assert_eq!(tool.name, "exit_conversation");
        assert!(tool.description.contains("goodbye"));
        // Schema must declare `farewell` and mark it required.
        let properties = tool.input_schema.get("properties").unwrap();
        assert!(properties.get("farewell").is_some());
        let required = tool
            .input_schema
            .get("required")
            .unwrap()
            .as_array()
            .unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("farewell")));
    }

    #[test]
    fn exit_flag_atomic_operations() {
        let exit_flag = Arc::new(AtomicBool::new(false));
        assert!(!exit_flag.load(Ordering::SeqCst));
        exit_flag.store(true, Ordering::SeqCst);
        assert!(exit_flag.load(Ordering::SeqCst));
        exit_flag.store(false, Ordering::SeqCst);
        assert!(!exit_flag.load(Ordering::SeqCst));
    }

    #[test]
    fn chat_config_default_values() {
        let cfg = ChatConfig::new();
        assert_eq!(cfg.user_prompt, "You: ");
        assert_eq!(cfg.assistant_prompt, "Assistant: ");
        assert!(cfg.input_mapper.is_none());
    }

    #[test]
    fn chat_config_custom_user_prompt() {
        assert_eq!(ChatConfig::new().user_prompt(">>> ").user_prompt, ">>> ");
    }

    #[test]
    fn chat_config_custom_assistant_prompt() {
        assert_eq!(
            ChatConfig::new().assistant_prompt("AI: ").assistant_prompt,
            "AI: "
        );
    }

    #[test]
    fn chat_config_with_input_mapper() {
        let mut cfg = ChatConfig::new().map_input(|s| format!("[test] {}", s));
        let mapper = cfg.input_mapper.as_mut().unwrap();
        assert_eq!(mapper("hello"), "[test] hello");
    }

    #[test]
    fn chat_config_debug_impl() {
        let cfg = ChatConfig::new()
            .user_prompt("test> ")
            .map_input(|s| s.to_string());
        let rendered = format!("{:?}", cfg);
        assert!(rendered.contains("test> "));
        assert!(rendered.contains("has_input_mapper"));
    }

    #[test]
    fn chat_config_chaining() {
        let cfg = ChatConfig::new()
            .user_prompt("U> ")
            .assistant_prompt("A> ")
            .map_input(|s| s.to_uppercase());
        assert_eq!(cfg.user_prompt, "U> ");
        assert_eq!(cfg.assistant_prompt, "A> ");
        assert!(cfg.input_mapper.is_some());
    }

    #[test]
    fn default_system_prompt_is_sensible() {
        assert!(DEFAULT_SYSTEM_PROMPT.contains("helpful"));
        assert!(DEFAULT_SYSTEM_PROMPT.contains("exit_conversation"));
    }
}