mod types;
pub use types::*;
use std::collections::{HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use crate::error::{Result, TinyAgentsError};
use crate::harness::context::RunContext;
use crate::harness::events::AgentEvent;
use crate::harness::message::Message;
impl SteeringPolicy {
    /// Creates a deny-all policy.
    ///
    /// No command kind is permitted until it is opted in with
    /// [`SteeringPolicy::allow`].
    pub fn new() -> Self {
        let allowed = HashSet::new();
        Self { allowed }
    }

    /// Creates a policy that permits every kind listed in
    /// [`SteeringCommandKind::ALL`].
    pub fn allow_all() -> Self {
        let mut allowed: HashSet<SteeringCommandKind> = HashSet::new();
        allowed.extend(SteeringCommandKind::ALL);
        Self { allowed }
    }

    /// Builder step: permits `kind` in addition to whatever is already allowed.
    ///
    /// Allowing the same kind more than once has no further effect, because
    /// the kinds are stored in a set.
    pub fn allow(mut self, kind: SteeringCommandKind) -> Self {
        self.allowed.insert(kind);
        self
    }

    /// Returns `true` when commands of `kind` may be applied under this policy.
    pub fn is_allowed(&self, kind: SteeringCommandKind) -> bool {
        self.allowed.contains(&kind)
    }
}
impl SteeringHandle {
    /// Builds a handle with an empty command queue, governed by `policy`.
    ///
    /// The queue and the policy live together behind an [`Arc`]. Cloned
    /// handles therefore presumably share a single queue, with one side
    /// sending and the run loop draining.
    pub fn new(policy: SteeringPolicy) -> Self {
        let inner = SteeringInner {
            queue: Mutex::new(VecDeque::new()),
            policy,
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Builds a handle whose policy permits every steering command kind.
    pub fn allow_all() -> Self {
        let policy = SteeringPolicy::allow_all();
        Self::new(policy)
    }

    /// Enqueues `command` behind any commands already waiting.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn send(&self, command: SteeringCommand) {
        self.with_queue(|queue| queue.push_back(command));
    }

    /// Removes and returns every queued command, oldest first.
    ///
    /// The queue is left empty.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn drain(&self) -> Vec<SteeringCommand> {
        self.with_queue(|queue| queue.drain(..).collect())
    }

    /// Reports whether no commands are waiting.
    ///
    /// This is a snapshot: the lock is released before returning, so another
    /// sender may change the answer immediately afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn is_empty(&self) -> bool {
        self.with_queue(|queue| queue.is_empty())
    }

    /// Returns how many commands are waiting.
    ///
    /// Like [`SteeringHandle::is_empty`], this is only a snapshot.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn pending(&self) -> usize {
        self.with_queue(|queue| queue.len())
    }

    /// The policy that decides which commands the run will accept.
    pub fn policy(&self) -> &SteeringPolicy {
        &self.inner.policy
    }

    /// Runs `f` with exclusive access to the queue.
    ///
    /// The lock is held only for the duration of `f`.
    fn with_queue<R>(&self, f: impl FnOnce(&mut VecDeque<SteeringCommand>) -> R) -> R {
        let mut guard = self
            .inner
            .queue
            .lock()
            .expect("steering queue mutex poisoned");
        f(&mut *guard)
    }
}
/// Drains the run's steering queue and applies each command in FIFO order.
///
/// If `ctx` has no steering handle, or nothing is queued, this returns
/// [`SteeringOutcome::Continue`] without side effects. Otherwise every pending
/// command is removed from the queue in one go and handled in order:
///
/// - `Pause` sets the outcome to [`SteeringOutcome::Pause`].
/// - `Resume` reverts a `Pause` seen earlier in the same batch back to
///   `Continue`.
/// - `Cancel` stops processing immediately and returns
///   [`SteeringOutcome::Cancel`]. Commands after it in the batch are discarded.
/// - `InjectMessage` appends the message to `messages`.
/// - `Redirect` appends a system message prefixed with `[steering:redirect] `.
/// - `SetMetadata` replaces `ctx.config.metadata` wholesale.
///
/// After each accepted command takes effect, an [`AgentEvent::Steered`] with
/// `accepted: true` is emitted.
///
/// # Errors
///
/// Returns [`TinyAgentsError::Steering`] for the first command whose kind the
/// handle's policy rejects. Before returning, it emits `Steered` with
/// `accepted: false`. Commands earlier in the batch have already been
/// applied; commands after it are discarded.
pub fn apply_pending_steering<Ctx>(
    ctx: &mut RunContext<Ctx>,
    messages: &mut Vec<Message>,
) -> Result<SteeringOutcome> {
    // Clone the handle out of `ctx` so that `ctx` stays free to be mutably
    // borrowed by `emit` and by the metadata update below.
    let handle = match ctx.steering.clone() {
        Some(handle) => handle,
        None => return Ok(SteeringOutcome::Continue),
    };

    let mut result = SteeringOutcome::Continue;
    for command in handle.drain() {
        let kind = command.kind();
        let label = kind.as_str();

        if !handle.policy().is_allowed(kind) {
            ctx.emit(AgentEvent::Steered {
                command_kind: label.to_string(),
                accepted: false,
            });
            return Err(TinyAgentsError::Steering(format!(
                "steering command `{label}` is not permitted by the run policy"
            )));
        }

        // Apply the command's effect. `Cancel` is only flagged here, so that its
        // "accepted" event is emitted before the early return, as with any
        // other command.
        let cancelled = match command {
            SteeringCommand::Cancel => true,
            SteeringCommand::Pause => {
                result = SteeringOutcome::Pause;
                false
            }
            SteeringCommand::Resume => {
                if result == SteeringOutcome::Pause {
                    result = SteeringOutcome::Continue;
                }
                false
            }
            SteeringCommand::InjectMessage(message) => {
                messages.push(message);
                false
            }
            SteeringCommand::Redirect { instruction } => {
                messages.push(Message::system(format!(
                    "[steering:redirect] {instruction}"
                )));
                false
            }
            SteeringCommand::SetMetadata { metadata } => {
                ctx.config.metadata = metadata;
                false
            }
        };

        ctx.emit(AgentEvent::Steered {
            command_kind: label.to_string(),
            accepted: true,
        });

        if cancelled {
            return Ok(SteeringOutcome::Cancel);
        }
    }

    Ok(result)
}
#[cfg(test)]
mod test;