pub mod agents;
pub mod schedulers;
#[cfg(test)]
mod tests;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use crate::chains::agents::ooda::{multistep, one_step};
use crate::chains::schedulers::{MultiAgentScheduler, SingleAgentScheduler};
use crate::context::ContextDump;
use crate::models::Usage;
use crate::tools::invocation::InvocationError;
use crate::tools::toolbox::{invoke_tool, InvokeResult, Toolbox};
use crate::tools::{TerminationMessage, ToolUseError};
use crate::{SapiensConfig, WeakRuntimeObserver};
/// How a tool invocation attempt ended; stored in [`Message::ActionResult`].
///
/// Built from an `InvokeResult` by `From<InvokeResult> for Message`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Outcome {
    /// The tool ran and produced `result`.
    Success { result: String },
    /// No valid invocation was found in the action text; `e` carries the details.
    NoValidInvocationsFound { e: InvocationError },
    /// No invocation at all was found in the action text; `e` carries the details.
    NoInvocationsFound { e: InvocationError },
    /// The tool was invoked but failed with `e`.
    ToolUseError { e: ToolUseError },
}
/// One entry in a chain's [`Context`] history.
///
/// `Observation`, `Orientation`, `Decision` and `Action` carry model output and the
/// associated `Usage`, when one was reported.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Message {
    /// The task the chain is asked to solve.
    Task { content: String },
    /// "Observe" step output.
    Observation { content: String, usage: Option<Usage> },
    /// "Orient" step output.
    Orientation { content: String, usage: Option<Usage> },
    /// "Decide" step output.
    Decision { content: String, usage: Option<Usage> },
    /// "Act" step output; `Runtime::step` passes `content` to `invoke_tool`.
    Action { content: String, usage: Option<Usage> },
    /// Outcome of running the tool requested by the preceding `Action`.
    ActionResult {
        invocation_count: usize,
        tool_name: Option<String>,
        extracted_input: Option<String>,
        outcome: Outcome,
    },
}
impl Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Message::Task { content } => write!(f, "Task: {}", content),
Message::Observation { content,.. } => write!(f, "Observation: {}", content),
Message::Orientation { content ,..} => write!(f, "Orientation: {}", content),
Message::Decision { content,.. } => write!(f, "Decision: {}", content),
Message::Action { content ,..} => write!(f, "Action: {}", content),
Message::ActionResult {
invocation_count,
tool_name,
extracted_input,
outcome,
} => write!(
f,
"ActionResult: {} invocations found, tool_name: {:?}, extracted_input: {:?}, outcome: {:?}",
invocation_count,
tool_name,
extracted_input,
outcome
),
}
}
}
impl From<InvokeResult> for Message {
fn from(result: InvokeResult) -> Self {
match result {
InvokeResult::NoInvocationsFound { e } => Message::ActionResult {
invocation_count: 0,
tool_name: None,
extracted_input: None,
outcome: Outcome::NoInvocationsFound { e },
},
InvokeResult::NoValidInvocationsFound {
e,
invocation_count,
} => Message::ActionResult {
invocation_count,
tool_name: None,
extracted_input: None,
outcome: Outcome::NoValidInvocationsFound { e },
},
InvokeResult::Success {
invocation_count,
tool_name,
extracted_input,
result,
} => Message::ActionResult {
invocation_count,
tool_name: Some(tool_name),
extracted_input: Some(extracted_input),
outcome: Outcome::Success { result },
},
InvokeResult::Error {
invocation_count,
tool_name,
e,
..
} => Message::ActionResult {
invocation_count,
tool_name: Some(tool_name),
extracted_input: None,
outcome: Outcome::ToolUseError { e },
},
}
}
}
/// Ordered history of [`Message`]s for one chain run.
#[derive(Clone, Default)]
pub struct Context {
    /// Messages in insertion order: oldest first, newest last.
    messages: Vec<Message>,
}
impl Context {
    /// Creates an empty context.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a snapshot (a clone) of every message recorded so far.
    pub fn dump(&self) -> ContextDump {
        let messages = self.messages.clone();
        ContextDump { messages }
    }

    /// Returns the content of the most recently added [`Message::Task`], if any.
    pub fn get_latest_task(&self) -> Option<String> {
        // Walk newest-to-oldest so the first match is the latest task.
        for message in self.messages.iter().rev() {
            if let Message::Task { content } = message {
                return Some(content.clone());
            }
        }
        None
    }

    /// Appends `message` to the end of the history.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("No terminal tool")]
NoTerminalTool,
#[error("Max steps reached")]
MaxStepsReached,
#[error("Agent failed: {0}")]
AgentFailed(#[from] agents::Error),
}
/// A participant that produces a [`Message`] from the current [`Context`].
#[async_trait::async_trait]
pub trait Agent: Sync + Send {
    /// Error returned when the agent cannot produce a message.
    type Error;

    /// Produces this agent's next message given `context`.
    async fn act(&self, context: &Context) -> Result<Message, Self::Error>;
}
/// Decides which [`Message`] comes next; `Runtime::step` calls this once per step.
#[async_trait::async_trait]
pub trait Scheduler: Sync + Send {
    /// Returns the next message to append, given the current `context`.
    async fn schedule(&mut self, context: &Context) -> Result<Message, Error>;
}
/// Drives a [`Scheduler`] over a [`Context`] and executes tools for action messages.
pub struct Runtime {
    /// History read by the scheduler and appended to by `step`.
    context: Context,
    /// Tools that `Message::Action` content is dispatched to.
    toolbox: Toolbox,
    /// Produces the next message on each step.
    scheduler: Box<dyn Scheduler>,
    /// Weak observer handle; notifications are skipped when `upgrade` fails.
    observer: WeakRuntimeObserver,
}
/// What `Runtime::run` returns once the run has terminated.
pub struct TerminalState {
    /// Termination messages reported by the toolbox; non-empty when produced by `Runtime::run`.
    pub messages: Vec<TerminationMessage>,
}
impl Runtime {
    /// Builds a runtime around `toolbox` and `scheduler`, starting from an empty [`Context`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoTerminalTool`] when `toolbox` reports no terminal tools.
    pub async fn new(
        toolbox: Toolbox,
        scheduler: Box<dyn Scheduler>,
        observer: WeakRuntimeObserver,
    ) -> Result<Self, Error> {
        if !toolbox.has_terminal_tools().await {
            return Err(Error::NoTerminalTool);
        }
        Ok(Self {
            context: Context::new(),
            toolbox,
            scheduler,
            observer,
        })
    }

    /// Calls [`Runtime::step`] repeatedly until a step yields at least one termination message.
    ///
    /// # Errors
    ///
    /// Propagates the first error from [`Runtime::step`]. This loop has no step limit of its
    /// own; it relies on the scheduler to eventually fail (presumably with
    /// [`Error::MaxStepsReached`]) if no termination message ever appears.
    pub async fn run(&mut self) -> Result<TerminalState, Error> {
        loop {
            let messages = self.step().await?;
            if !messages.is_empty() {
                return Ok(TerminalState { messages });
            }
        }
    }

    /// Performs one scheduling round.
    ///
    /// The scheduled message is first appended to the context. It is then forwarded to the
    /// observer, if the observer is still alive. For a [`Message::Action`], the content is
    /// passed to `invoke_tool`. The result is reported to the observer and appended to the
    /// context as a [`Message`].
    ///
    /// Returns the toolbox's termination messages after the round. An empty vector means
    /// the run should continue.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Scheduler::schedule`].
    pub async fn step(&mut self) -> Result<Vec<TerminationMessage>, Error> {
        let message = self.scheduler.schedule(&self.context).await?;
        // Append through `Context::add_message` so every history mutation shares one path.
        self.context.add_message(message.clone());
        if let Some(observer) = self.observer.upgrade() {
            observer
                .lock()
                .await
                .on_message(message.clone().into())
                .await;
        }
        if let Message::Action { content, .. } = message {
            let res = invoke_tool(self.toolbox.clone(), &content).await;
            if let Some(observer) = self.observer.upgrade() {
                observer
                    .lock()
                    .await
                    .on_invocation_result(res.clone().into())
                    .await;
            }
            self.context.add_message(res.into());
        }
        Ok(self.toolbox.termination_messages().await)
    }
}
/// A steppable chain whose history can be inspected.
#[async_trait::async_trait]
pub trait Chain: Sync + Send {
    /// Returns a snapshot of the chain's context.
    fn dump(&self) -> ContextDump;

    /// Advances the chain by one step. A non-empty result means the chain terminated.
    async fn step(&mut self) -> Result<Vec<TerminationMessage>, Error>;
}
/// Chain driven by a single `one_step::Agent`.
pub struct SingleStepOODAChain {
    /// Runtime that owns the context, toolbox and scheduler.
    runtime: Runtime,
}
impl SingleStepOODAChain {
    /// Builds the chain: one `one_step::Agent` wrapped in a [`SingleAgentScheduler`].
    ///
    /// `config.max_steps` is forwarded to the scheduler (presumably as its step limit).
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoTerminalTool`] if `toolbox` has no terminal tools
    /// (checked by [`Runtime::new`]).
    pub async fn new(
        config: SapiensConfig,
        toolbox: Toolbox,
        observer: WeakRuntimeObserver,
    ) -> Result<Self, Error> {
        let agent = one_step::Agent::new(config.clone(), toolbox.clone(), observer.clone()).await;
        let scheduler =
            SingleAgentScheduler::new(config.max_steps, Box::new(agent), observer.clone());
        Ok(Self {
            runtime: Runtime::new(toolbox, Box::new(scheduler), observer).await?,
        })
    }

    /// Appends `task` to the context as a [`Message::Task`] and returns the chain (builder style).
    pub fn with_task(mut self, task: String) -> Self {
        // Use the context's own append method instead of touching its private vector.
        self.runtime
            .context
            .add_message(Message::Task { content: task });
        self
    }
}
/// `Chain` implementation that forwards to the wrapped [`Runtime`].
#[async_trait::async_trait]
impl Chain for SingleStepOODAChain {
    fn dump(&self) -> ContextDump {
        let runtime = &self.runtime;
        runtime.context.dump()
    }

    async fn step(&mut self) -> Result<Vec<TerminationMessage>, Error> {
        let runtime = &mut self.runtime;
        runtime.step().await
    }
}
/// Chain driven by four `multistep` agents: observer, orienter, decider and actor.
pub struct MultiStepOODAChain {
    /// Runtime that owns the context, toolbox and scheduler.
    runtime: Runtime,
}
impl MultiStepOODAChain {
    /// Builds the chain: four `multistep` agents handed to a [`MultiAgentScheduler`].
    ///
    /// The agents are the observer, orienter, decider and actor, in that order.
    /// `config.max_steps` is forwarded to the scheduler (presumably as its step limit).
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoTerminalTool`] if `toolbox` has no terminal tools
    /// (checked by [`Runtime::new`]).
    pub async fn new(
        config: SapiensConfig,
        toolbox: Toolbox,
        observer: WeakRuntimeObserver,
    ) -> Result<Self, Error> {
        // The scheduler receives the agents in this OODA order: observe, orient, decide, act.
        let agents = vec![
            multistep::Agent::new_observer(config.clone(), toolbox.clone(), observer.clone()).await,
            multistep::Agent::new_orienter(config.clone(), toolbox.clone(), observer.clone()).await,
            multistep::Agent::new_decider(config.clone(), toolbox.clone(), observer.clone()).await,
            multistep::Agent::new_actor(config.clone(), toolbox.clone(), observer.clone()).await,
        ];
        // Erase the concrete agent type so the scheduler can hold them as trait objects.
        let agents = agents
            .into_iter()
            .map(|a| Box::new(a) as Box<dyn Agent<Error = agents::Error>>)
            .collect();
        let scheduler = MultiAgentScheduler::new(config.max_steps, agents, observer.clone());
        Ok(Self {
            runtime: Runtime::new(toolbox, Box::new(scheduler), observer).await?,
        })
    }

    /// Appends `task` to the context as a [`Message::Task`] and returns the chain (builder style).
    pub fn with_task(mut self, task: String) -> Self {
        // Use the context's own append method instead of touching its private vector.
        self.runtime
            .context
            .add_message(Message::Task { content: task });
        self
    }
}
/// `Chain` implementation that forwards to the wrapped [`Runtime`].
#[async_trait::async_trait]
impl Chain for MultiStepOODAChain {
    fn dump(&self) -> ContextDump {
        let runtime = &self.runtime;
        runtime.context.dump()
    }

    async fn step(&mut self) -> Result<Vec<TerminationMessage>, Error> {
        let runtime = &mut self.runtime;
        runtime.step().await
    }
}