use crate::conversation::StreamToken;
use crate::error::ActonAIError;
use crate::facade::ActonAI;
use crate::llm::SamplingParams;
use crate::messages::{
LLMRequest, LLMStreamEnd, LLMStreamStart, LLMStreamToken, LLMStreamToolCall, Message,
StopReason, ToolCall, ToolDefinition,
};
use crate::stream::{CollectedResponse, ExecutedToolCall};
use crate::tools::ToolError;
use crate::types::{AgentId, CorrelationId};
use acton_reactive::prelude::*;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Notify;
// Boxed user callback fired once when a stream round opens. `FnMut` so
// callers can accumulate state across invocations.
type StartCallback = Box<dyn FnMut() + Send + 'static>;
// Receives each streamed token fragment as it arrives.
type TokenCallback = Box<dyn FnMut(&str) + Send + 'static>;
// Receives the provider-reported stop reason when a stream round ends.
type EndCallback = Box<dyn FnMut(StopReason) + Send + 'static>;
// Receives a tool execution outcome: the JSON value on success, or the
// rendered error string on failure.
type ToolResultCallback = Box<dyn FnMut(Result<&serde_json::Value, &str>) + Send + 'static>;
// Boxed future returned by every tool executor.
type ToolFuture = Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send>>;
/// Object-safe interface for executing one tool call with JSON arguments.
pub trait ToolExecutorFn: Send + Sync {
    /// Starts executing the tool with `args`, returning a boxed future
    /// that resolves to the tool's JSON result or a [`ToolError`].
    fn call(&self, args: serde_json::Value) -> ToolFuture;
}
/// Adapts a plain async closure into a [`ToolExecutorFn`] trait object.
struct ClosureToolExecutor<F> {
    func: F,
}

impl<F, Fut> ToolExecutorFn for ClosureToolExecutor<F>
where
    F: Fn(serde_json::Value) -> Fut + Send + Sync,
    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
{
    /// Invokes the wrapped closure and boxes the resulting future.
    fn call(&self, args: serde_json::Value) -> ToolFuture {
        let future = (self.func)(args);
        Box::pin(future)
    }
}
/// Bridges a runtime-registered builtin executor into [`ToolExecutorFn`].
struct BuiltinToolExecutorAdapter {
    executor: Arc<crate::tools::BoxedToolExecutor>,
}

impl ToolExecutorFn for BuiltinToolExecutorAdapter {
    /// Delegates to the shared builtin executor; the `Arc` is cloned so
    /// the returned future owns its handle.
    fn call(&self, args: serde_json::Value) -> ToolFuture {
        let shared = self.executor.clone();
        Box::pin(async move { shared.execute(args).await })
    }
}
/// A tool the model may invoke: its definition (advertised to the
/// provider), the executor that runs it, and an optional per-result
/// callback.
pub struct ToolSpec {
    /// Definition (name, description, JSON input schema) sent to the provider.
    pub definition: ToolDefinition,
    // Shared executor; `Arc` so clones of the spec reuse the same one.
    executor: Arc<dyn ToolExecutorFn>,
    // Invoked with each execution outcome. Deliberately NOT preserved by
    // `Clone` (boxed `FnMut` cannot be cloned) — see the `Clone` impl.
    on_result: Option<ToolResultCallback>,
}
impl std::fmt::Debug for ToolSpec {
    /// Renders only the tool definition; the executor and callback are
    /// opaque, so the output ends with `..` (non-exhaustive).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("ToolSpec");
        builder.field("definition", &self.definition);
        builder.finish_non_exhaustive()
    }
}
impl Clone for ToolSpec {
fn clone(&self) -> Self {
Self {
definition: self.definition.clone(),
executor: self.executor.clone(),
on_result: None,
}
}
}
// User callbacks wrapped for shared mutable access from actor message
// handlers: the handlers are `Send` closures, so the `FnMut` boxes are
// guarded by a mutex and shared via `Arc` across rounds.
type WrappedStartCallback = Arc<std::sync::Mutex<StartCallback>>;
type WrappedTokenCallback = Arc<std::sync::Mutex<TokenCallback>>;
type WrappedEndCallback = Arc<std::sync::Mutex<EndCallback>>;
/// Fluent builder for one LLM prompt round-trip: conversation content,
/// streaming callbacks, tools, provider selection, and sampling
/// parameters, executed via `collect`.
pub struct PromptBuilder {
    // Runtime facade that owns providers and builtin tools.
    runtime: ActonAI,
    // The user's prompt text; ignored when `conversation_history` is set.
    user_content: String,
    // Optional system prompt prepended to the message list.
    system_prompt: Option<String>,
    // Explicit history that replaces the single-user-turn default.
    conversation_history: Option<Vec<Message>>,
    // Streaming lifecycle callbacks (start / per-token / end).
    on_start: Option<StartCallback>,
    on_token: Option<TokenCallback>,
    on_end: Option<EndCallback>,
    // Tools advertised to the model, in registration order.
    tools: Vec<ToolSpec>,
    // Upper bound on model<->tool round trips (defaults to 10).
    max_tool_rounds: usize,
    // Named provider override; `None` uses the runtime default provider.
    provider_name: Option<String>,
    // Optional actor that receives each token as a `StreamToken` message.
    token_target: Option<ActorHandle>,
    // Optional sampling parameter overrides.
    sampling: Option<SamplingParams>,
}
impl PromptBuilder {
    /// Creates a builder for a single prompt executed on `runtime`.
    ///
    /// Defaults: no system prompt, no history, no callbacks, no tools,
    /// at most 10 tool rounds, the runtime's default provider, and
    /// provider-default sampling.
    #[must_use]
    pub(crate) fn new(runtime: ActonAI, user_content: String) -> Self {
        Self {
            runtime,
            user_content,
            system_prompt: None,
            conversation_history: None,
            on_start: None,
            on_token: None,
            on_end: None,
            tools: Vec::new(),
            // Guard against runaway model<->tool loops.
            max_tool_rounds: 10,
            provider_name: None,
            token_target: None,
            sampling: None,
        }
    }

    /// Sets the system prompt prepended to the conversation.
    #[must_use]
    pub fn system(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Replaces the conversation with an explicit message history.
    ///
    /// When set, the builder's original user content is NOT appended;
    /// the history is sent as-is (after any system prompt).
    #[must_use]
    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
        self.conversation_history = Some(messages.into_iter().collect());
        self
    }

    /// Registers a callback fired when each stream round starts.
    #[must_use]
    pub fn on_start<F>(mut self, f: F) -> Self
    where
        F: FnMut() + Send + 'static,
    {
        self.on_start = Some(Box::new(f));
        self
    }

    /// Registers a callback fired for every streamed token.
    #[must_use]
    pub fn on_token<F>(mut self, f: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.on_token = Some(Box::new(f));
        self
    }

    /// Registers a callback fired when each stream round ends, receiving
    /// the provider-reported stop reason.
    #[must_use]
    pub fn on_end<F>(mut self, f: F) -> Self
    where
        F: FnMut(StopReason) + Send + 'static,
    {
        self.on_end = Some(Box::new(f));
        self
    }

    /// Registers a tool from its parts: name, description, JSON input
    /// schema, and an async executor closure.
    #[must_use]
    pub fn tool<F, Fut>(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: serde_json::Value,
        executor: F,
    ) -> Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
    {
        let definition = ToolDefinition {
            name: name.into(),
            description: description.into(),
            input_schema,
        };
        let spec = ToolSpec {
            definition,
            executor: Arc::new(ClosureToolExecutor { func: executor }),
            on_result: None,
        };
        self.tools.push(spec);
        self
    }

    /// Registers a tool from a pre-built [`ToolDefinition`] and an async
    /// executor closure.
    #[must_use]
    pub fn with_tool<F, Fut>(mut self, definition: ToolDefinition, executor: F) -> Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
    {
        let spec = ToolSpec {
            definition,
            executor: Arc::new(ClosureToolExecutor { func: executor }),
            on_result: None,
        };
        self.tools.push(spec);
        self
    }

    /// Registers a tool with an additional `on_result` callback that is
    /// invoked with each execution outcome (JSON value or error string).
    #[must_use]
    pub fn with_tool_callback<F, Fut, C>(
        mut self,
        definition: ToolDefinition,
        executor: F,
        on_result: C,
    ) -> Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
        C: FnMut(Result<&serde_json::Value, &str>) + Send + 'static,
    {
        let spec = ToolSpec {
            definition,
            executor: Arc::new(ClosureToolExecutor { func: executor }),
            on_result: Some(Box::new(on_result)),
        };
        self.tools.push(spec);
        self
    }

    /// Caps the number of model<->tool round trips; exceeding it makes
    /// `collect` return an error.
    #[must_use]
    pub fn max_tool_rounds(mut self, max: usize) -> Self {
        self.max_tool_rounds = max;
        self
    }

    /// Selects a named provider instead of the runtime default.
    #[must_use]
    pub fn provider(mut self, name: impl Into<String>) -> Self {
        self.provider_name = Some(name.into());
        self
    }

    /// Replaces all sampling parameters at once.
    #[must_use]
    pub fn sampling(mut self, params: SamplingParams) -> Self {
        self.sampling = Some(params);
        self
    }

    /// Overrides the sampling temperature.
    #[must_use]
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .temperature = Some(temperature);
        self
    }

    /// Overrides nucleus (top-p) sampling.
    #[must_use]
    pub fn top_p(mut self, top_p: f64) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .top_p = Some(top_p);
        self
    }

    /// Overrides top-k sampling.
    #[must_use]
    pub fn top_k(mut self, top_k: u32) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .top_k = Some(top_k);
        self
    }

    /// Sets stop sequences that end generation when emitted.
    #[must_use]
    pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .stop_sequences = Some(sequences);
        self
    }

    /// Sets the frequency penalty.
    #[must_use]
    pub fn frequency_penalty(mut self, penalty: f64) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .frequency_penalty = Some(penalty);
        self
    }

    /// Sets the presence penalty.
    #[must_use]
    pub fn presence_penalty(mut self, penalty: f64) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .presence_penalty = Some(penalty);
        self
    }

    /// Sets the sampling seed for reproducible generation (where the
    /// provider supports it).
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.sampling
            .get_or_insert_with(SamplingParams::default)
            .seed = Some(seed);
        self
    }

    /// Forwards every streamed token to `handle` as a `StreamToken`
    /// message, in addition to any `on_token` callback.
    #[must_use]
    pub fn token_target(mut self, handle: ActorHandle) -> Self {
        self.token_target = Some(handle);
        self
    }

    /// Adds every builtin tool registered on the runtime (if any),
    /// sharing the runtime's executors. Builtins are registered without
    /// result callbacks.
    #[must_use]
    pub fn use_builtins(mut self) -> Self {
        if let Some(builtins) = self.runtime.builtins() {
            for (name, config) in builtins.configs() {
                if let Some(executor) = builtins.get_executor(name) {
                    let adapter = BuiltinToolExecutorAdapter { executor };
                    self.tools.push(ToolSpec {
                        definition: config.definition.clone(),
                        executor: Arc::new(adapter),
                        on_result: None,
                    });
                }
            }
        }
        self
    }

    /// Runs the prompt to completion: streams each round internally,
    /// executes tools between rounds when the model requests them, and
    /// returns the final text with the terminating stop reason, the
    /// total token count, and a transcript of executed tool calls.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime is shut down, a named provider is
    /// unknown, the tool-round limit is exceeded, or stream collection
    /// fails.
    pub async fn collect(self) -> Result<CollectedResponse, ActonAIError> {
        if self.runtime.is_shutdown() {
            return Err(ActonAIError::runtime_shutdown());
        }
        let PromptBuilder {
            runtime,
            user_content,
            system_prompt,
            conversation_history,
            on_start,
            on_token,
            on_end,
            mut tools,
            max_tool_rounds,
            provider_name,
            token_target,
            sampling,
        } = self;
        // Resolve the provider: a named lookup must exist; otherwise use
        // the runtime default.
        let provider_handle = if let Some(ref name) = provider_name {
            runtime.provider_handle_named(name).ok_or_else(|| {
                ActonAIError::configuration(
                    "provider",
                    format!(
                        "provider '{}' not found; available: {}",
                        name,
                        runtime.provider_names().collect::<Vec<_>>().join(", ")
                    ),
                )
            })?
        } else {
            runtime.provider_handle()
        };
        // Assemble the outgoing messages: optional system prompt, then
        // either the explicit history or the single user turn.
        let mut messages = Vec::new();
        if let Some(ref system) = system_prompt {
            messages.push(Message::system(system));
        }
        if let Some(history) = conversation_history {
            messages.extend(history);
        } else {
            messages.push(Message::user(&user_content));
        }
        let tool_definitions: Vec<ToolDefinition> =
            tools.iter().map(|t| t.definition.clone()).collect();
        let has_tools = !tool_definitions.is_empty();
        let mut executed_tool_calls = Vec::new();
        let mut total_token_count = 0;
        let mut final_text;
        // BUG FIX: the response previously always reported
        // `StopReason::EndTurn`, hiding truncation (`MaxTokens`) and stop
        // sequences from callers (`is_truncated()` could never be true).
        // Track the stop reason of the terminating round instead.
        let mut final_stop_reason = StopReason::EndTurn;
        let mut rounds = 0;
        // Wrap callbacks so each round's collector actor can share them.
        let on_start: Option<WrappedStartCallback> =
            on_start.map(|f| Arc::new(std::sync::Mutex::new(f)));
        let on_token: Option<WrappedTokenCallback> =
            on_token.map(|f| Arc::new(std::sync::Mutex::new(f)));
        let on_end: Option<WrappedEndCallback> = on_end.map(|f| Arc::new(std::sync::Mutex::new(f)));
        loop {
            rounds += 1;
            if rounds > max_tool_rounds {
                return Err(ActonAIError::prompt_failed(format!(
                    "exceeded maximum tool rounds ({max_tool_rounds})",
                )));
            }
            let correlation_id = CorrelationId::new();
            let agent_id = AgentId::new();
            let request = LLMRequest {
                correlation_id: correlation_id.clone(),
                agent_id,
                messages: messages.clone(),
                tools: if has_tools {
                    Some(tool_definitions.clone())
                } else {
                    None
                },
                sampling: sampling.clone(),
            };
            let (text, stop_reason, token_count, tool_calls) = collect_stream_round(
                &runtime,
                &provider_handle,
                &request,
                correlation_id,
                StreamRoundCallbacks {
                    on_start: on_start.clone(),
                    on_token: on_token.clone(),
                    on_end: on_end.clone(),
                    token_target: token_target.clone(),
                },
            )
            .await?;
            final_text = text.clone();
            total_token_count += token_count;
            match stop_reason {
                StopReason::EndTurn | StopReason::MaxTokens | StopReason::StopSequence => {
                    final_stop_reason = stop_reason;
                    break;
                }
                StopReason::ToolUse => {
                    if tool_calls.is_empty() {
                        // Provider claimed tool use but sent no calls;
                        // nothing more can be done this turn. Report the
                        // provider's reason honestly.
                        final_stop_reason = stop_reason;
                        break;
                    }
                    // Execute every requested tool, recording both the
                    // caller-visible transcript entry and the result fed
                    // back to the model.
                    let mut tool_results = Vec::new();
                    for tool_call in &tool_calls {
                        let result = execute_tool_with_callback(&mut tools, tool_call).await;
                        let executed = match &result {
                            Ok(value) => ExecutedToolCall::success(
                                &tool_call.id,
                                &tool_call.name,
                                tool_call.arguments.clone(),
                                value.clone(),
                            ),
                            Err(e) => ExecutedToolCall::error(
                                &tool_call.id,
                                &tool_call.name,
                                tool_call.arguments.clone(),
                                e.to_string(),
                            ),
                        };
                        executed_tool_calls.push(executed);
                        tool_results.push(result);
                    }
                    // Extend the conversation with the assistant's tool
                    // request plus one tool-result message per call, then
                    // loop for another round.
                    messages.push(Message::assistant_with_tools(text, tool_calls.clone()));
                    for (tool_call, result) in tool_calls.iter().zip(tool_results.iter()) {
                        let result_str = match result {
                            Ok(v) => serde_json::to_string(v).unwrap_or_default(),
                            Err(e) => format!("Error: {e}"),
                        };
                        messages.push(Message::tool(&tool_call.id, result_str));
                    }
                }
            }
        }
        Ok(CollectedResponse::with_tool_calls(
            final_text,
            final_stop_reason,
            total_token_count,
            executed_tool_calls,
        ))
    }
}
/// Callback bundle handed to a single stream round.
struct StreamRoundCallbacks {
    on_start: Option<WrappedStartCallback>,
    on_token: Option<WrappedTokenCallback>,
    on_end: Option<WrappedEndCallback>,
    // Actor that receives each token, in addition to `on_token`.
    token_target: Option<ActorHandle>,
}
/// Runs one streaming round: spins up a temporary collector actor that
/// subscribes to stream events filtered by `correlation_id`, sends
/// `request` to the provider, waits for the matching `LLMStreamEnd`,
/// then returns `(text, stop_reason, token_count, tool_calls)`.
///
/// NOTE(review): there is no timeout — if the provider never emits a
/// matching `LLMStreamEnd`, this awaits forever. Consider a deadline.
async fn collect_stream_round(
    runtime: &ActonAI,
    provider_handle: &ActorHandle,
    request: &LLMRequest,
    correlation_id: CorrelationId,
    callbacks: StreamRoundCallbacks,
) -> Result<(String, StopReason, usize, Vec<ToolCall>), ActonAIError> {
    let StreamRoundCallbacks {
        on_start,
        on_token,
        on_end,
        token_target,
    } = callbacks;
    // Signals completion from the LLMStreamEnd handler back to this fn.
    let stream_done = Arc::new(Notify::new());
    let stream_done_signal = stream_done.clone();
    // The end handler deposits the collected state here so it survives
    // the actor being stopped.
    let result_container: Arc<std::sync::Mutex<Option<CollectorResultData>>> =
        Arc::new(std::sync::Mutex::new(None));
    let result_container_clone = result_container.clone();
    let mut actor_runtime = runtime.runtime().clone();
    let mut collector = actor_runtime.new_actor::<StreamCollector>();
    // Stream start: fire the user's on_start callback.
    let on_start_clone = on_start.clone();
    let expected_id = correlation_id.clone();
    collector.mutate_on::<LLMStreamStart>(move |_actor, envelope| {
        // Ignore events belonging to other in-flight requests.
        if envelope.message().correlation_id == expected_id {
            if let Some(ref callback) = on_start_clone {
                if let Ok(mut f) = callback.lock() {
                    f();
                }
            }
        }
        Reply::ready()
    });
    // Token: accumulate text/count, invoke on_token, and optionally
    // forward the token to the target actor asynchronously.
    let on_token_clone = on_token.clone();
    let token_target_clone = token_target.clone();
    let expected_id = correlation_id.clone();
    collector.mutate_on::<LLMStreamToken>(move |actor, envelope| {
        if envelope.message().correlation_id == expected_id {
            let token = &envelope.message().token;
            actor.model.buffer.push_str(token);
            actor.model.token_count += 1;
            if let Some(ref callback) = on_token_clone {
                if let Ok(mut f) = callback.lock() {
                    f(token);
                }
            }
            if let Some(ref target) = token_target_clone {
                let target = target.clone();
                let text = token.to_string();
                // Forwarding requires an async send, so return a pending
                // reply that performs it.
                return Reply::pending(async move {
                    target.send(StreamToken { text }).await;
                });
            }
        }
        Reply::ready()
    });
    // Tool call: record each call requested by the model this round.
    let expected_id = correlation_id.clone();
    collector.mutate_on::<LLMStreamToolCall>(move |actor, envelope| {
        if envelope.message().correlation_id == expected_id {
            actor
                .model
                .tool_calls
                .push(envelope.message().tool_call.clone());
        }
        Reply::ready()
    });
    // Stream end: fire on_end, snapshot the accumulated state into the
    // shared container, and wake the awaiting caller.
    let on_end_clone = on_end.clone();
    let expected_id = correlation_id.clone();
    collector.mutate_on::<LLMStreamEnd>(move |actor, envelope| {
        if envelope.message().correlation_id == expected_id {
            actor.model.stop_reason = Some(envelope.message().stop_reason);
            if let Some(ref callback) = on_end_clone {
                if let Ok(mut f) = callback.lock() {
                    f(envelope.message().stop_reason);
                }
            }
            if let Ok(mut container) = result_container_clone.lock() {
                // `mem::take` moves the buffers out without cloning.
                *container = Some(CollectorResultData {
                    buffer: std::mem::take(&mut actor.model.buffer),
                    stop_reason: actor.model.stop_reason,
                    token_count: actor.model.token_count,
                    tool_calls: std::mem::take(&mut actor.model.tool_calls),
                });
            }
            stream_done_signal.notify_one();
        }
        Reply::ready()
    });
    // Subscribe before starting and before sending the request so no
    // events are missed. NOTE(review): assumes pre-start subscription is
    // effective — confirm against acton_reactive's documentation.
    collector.handle().subscribe::<LLMStreamStart>().await;
    collector.handle().subscribe::<LLMStreamToken>().await;
    collector.handle().subscribe::<LLMStreamToolCall>().await;
    collector.handle().subscribe::<LLMStreamEnd>().await;
    let collector_handle = collector.start().await;
    provider_handle.send(request.clone()).await;
    // Park until the LLMStreamEnd handler signals completion.
    stream_done.notified().await;
    let _ = collector_handle.stop().await;
    // Retrieve the snapshot; absence means the end handler never stored
    // one (e.g. the mutex was poisoned), which is reported as a failure.
    let result = result_container
        .lock()
        .ok()
        .and_then(|mut guard| guard.take())
        .ok_or_else(|| {
            ActonAIError::prompt_failed("failed to retrieve collected stream data".to_string())
        })?;
    Ok((
        result.buffer,
        result.stop_reason.unwrap_or(StopReason::EndTurn),
        result.token_count,
        result.tool_calls,
    ))
}
async fn execute_tool_with_callback(
tools: &mut [ToolSpec],
tool_call: &ToolCall,
) -> Result<serde_json::Value, ToolError> {
for spec in tools.iter_mut() {
if spec.definition.name == tool_call.name {
let result = spec.executor.call(tool_call.arguments.clone()).await;
if let Some(ref mut callback) = spec.on_result {
match &result {
Ok(value) => callback(Ok(value)),
Err(e) => {
let error_str = e.to_string();
callback(Err(&error_str));
}
}
}
return result;
}
}
Err(ToolError::not_found(&tool_call.name))
}
/// Actor state accumulating one stream round's output.
#[acton_actor]
struct StreamCollector {
    // Concatenated text of all tokens received so far.
    buffer: String,
    // Number of token messages received.
    token_count: usize,
    // Provider-reported stop reason, set by the `LLMStreamEnd` handler.
    stop_reason: Option<StopReason>,
    // Tool calls the model requested during this round.
    tool_calls: Vec<ToolCall>,
}
/// Snapshot of `StreamCollector` state taken when the stream ends,
/// handed back to `collect_stream_round` through a shared mutex.
#[derive(Debug, Clone, Default)]
struct CollectorResultData {
    buffer: String,
    stop_reason: Option<StopReason>,
    token_count: usize,
    tool_calls: Vec<ToolCall>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Debug must show the definition while hiding the non-Debug executor
    // and callback fields (rendered non-exhaustively).
    #[test]
    fn tool_spec_debug_impl() {
        let spec = ToolSpec {
            definition: ToolDefinition {
                name: "test".to_string(),
                description: "Test tool".to_string(),
                input_schema: serde_json::json!({}),
            },
            executor: Arc::new(ClosureToolExecutor {
                func: |_args: serde_json::Value| async { Ok(serde_json::json!({})) },
            }),
            on_result: None,
        };
        let debug = format!("{:?}", spec);
        assert!(debug.contains("test"));
    }

    // Cloning shares the executor but intentionally drops `on_result`
    // (boxed FnMut callbacks cannot be cloned).
    #[test]
    fn tool_spec_clone() {
        let spec = ToolSpec {
            definition: ToolDefinition {
                name: "test".to_string(),
                description: "Test tool".to_string(),
                input_schema: serde_json::json!({}),
            },
            executor: Arc::new(ClosureToolExecutor {
                func: |_args: serde_json::Value| async { Ok(serde_json::json!({})) },
            }),
            on_result: Some(Box::new(|_result| {})),
        };
        let cloned = spec.clone();
        assert_eq!(cloned.definition.name, "test");
        assert!(cloned.on_result.is_none());
    }

    // Basic constructor invariants of CollectedResponse.
    #[test]
    fn collected_response_new_creates_correctly() {
        let response = CollectedResponse::new("Hello world".to_string(), StopReason::EndTurn, 2);
        assert_eq!(response.text, "Hello world");
        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.token_count, 2);
        assert!(response.tool_calls.is_empty());
    }

    // is_complete is true only for EndTurn.
    #[test]
    fn collected_response_is_complete() {
        let complete = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(complete.is_complete());
        let incomplete = CollectedResponse::new("test".to_string(), StopReason::MaxTokens, 1);
        assert!(!incomplete.is_complete());
    }

    // is_truncated is true for MaxTokens.
    #[test]
    fn collected_response_is_truncated() {
        let truncated = CollectedResponse::new("test".to_string(), StopReason::MaxTokens, 1);
        assert!(truncated.is_truncated());
        let complete = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(!complete.is_truncated());
    }
}