mod session;
use schemars::JsonSchema;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::agent::error::{AgentError, AgentResult};
pub use session::{
AgentBuilder, AgentEvent, AgentInput, AgentRunInput, AgentRunOutput, ExecutionProfile,
PreparedRequest, SessionAgent,
};
/// Interface for an agent that is driven by submitting inputs and polling for events.
///
/// Implementors provide [`Agent::send`], [`Agent::next`] and [`Agent::spawn`]. The
/// convenience methods [`Agent::cast`], [`Agent::call`], [`Agent::steer`] and
/// [`Agent::cancel`] have default implementations built only on `send` and `next`.
// The `async_fn_in_trait` lint warns that callers of a public trait declared with `async fn`
// cannot add a `Send` bound to the returned futures. It is silenced here; generic callers
// therefore cannot rely on these futures being `Send`.
#[allow(async_fn_in_trait)]
pub trait Agent: Send + 'static {
    /// Payload wrapped in [`AgentInput::Message`] and [`AgentInput::Steer`].
    type Input: Clone + Serialize + DeserializeOwned + Send + Sync + 'static;
    /// Tool-call type used to parameterize [`AgentEvent`].
    type ToolCall: Clone + Serialize + DeserializeOwned + Send + Sync + 'static;
    /// Tool-result type used to parameterize [`AgentEvent`].
    type ToolResult: Clone + Serialize + DeserializeOwned + Send + Sync + 'static;
    /// Reply type delivered in [`AgentEvent::Completed`]; must also implement [`JsonSchema`].
    type Output: Clone + Serialize + DeserializeOwned + JsonSchema + Send + Sync + 'static;

    /// Submits `input` to the agent.
    ///
    /// Events resulting from the input are observed through [`Agent::next`].
    ///
    /// # Errors
    /// Returns whatever error the implementation reports for the submission.
    async fn send(&mut self, input: AgentInput<Self::Input>) -> AgentResult<()>;

    /// Polls for the next event produced by the agent.
    ///
    /// `Ok(None)` means the event stream has ended. The default helpers on this trait treat
    /// that as an [`AgentError::Internal`] when it happens before the event they wait for.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports while producing the event.
    async fn next(
        &mut self,
    ) -> AgentResult<Option<AgentEvent<Self::ToolCall, Self::ToolResult, Self::Output>>>;

    /// Sends `input` as an [`AgentInput::Message`] without consuming any events
    /// (fire-and-forget counterpart of [`Agent::call`]).
    ///
    /// # Errors
    /// Propagates any error from [`Agent::send`].
    async fn cast(&mut self, input: Self::Input) -> AgentResult<()> {
        self.send(AgentInput::Message(input)).await
    }

    /// Sends `input` as an [`AgentInput::Message`] and waits for the turn's reply.
    ///
    /// Non-terminal events are consumed and discarded while waiting.
    ///
    /// # Errors
    /// - Any error from [`Agent::send`] or [`Agent::next`].
    /// - [`AgentError::Cancelled`] if [`AgentEvent::Cancelled`] arrives before completion.
    /// - [`AgentError::Internal`] if the event stream ends without a terminal event.
    async fn call(&mut self, input: Self::Input) -> AgentResult<Self::Output> {
        self.send(AgentInput::Message(input)).await?;
        await_reply(self, "agent ended turn without a terminal event").await
    }

    /// Sends `input` as an [`AgentInput::Steer`] and waits for the turn's reply.
    ///
    /// Non-terminal events are consumed and discarded while waiting.
    ///
    /// # Errors
    /// - Any error from [`Agent::send`] or [`Agent::next`].
    /// - [`AgentError::Cancelled`] if [`AgentEvent::Cancelled`] arrives before completion.
    /// - [`AgentError::Internal`] if the event stream ends without a terminal event.
    async fn steer(&mut self, input: Self::Input) -> AgentResult<Self::Output> {
        self.send(AgentInput::Steer(input)).await?;
        await_reply(self, "agent ended steered turn without a terminal event").await
    }

    /// Sends [`AgentInput::Cancel`] and waits until [`AgentEvent::Cancelled`] is observed.
    ///
    /// Other non-terminal events are consumed and discarded while waiting.
    ///
    /// # Errors
    /// - Any error from [`Agent::send`] or [`Agent::next`].
    /// - [`AgentError::Internal`] if [`AgentEvent::Completed`] arrives first, or if the
    ///   event stream ends before cancellation is observed.
    async fn cancel(&mut self) -> AgentResult<()> {
        self.send(AgentInput::Cancel).await?;
        loop {
            match self.next().await? {
                Some(AgentEvent::Cancelled) => return Ok(()),
                // A completion here means the turn finished instead of being cancelled.
                Some(AgentEvent::Completed { .. }) => {
                    return Err(AgentError::Internal {
                        message: "cancel completed without observing cancellation".to_string(),
                    });
                }
                Some(_) => {}
                None => {
                    return Err(AgentError::Internal {
                        message: "agent ended without observing cancellation".to_string(),
                    });
                }
            }
        }
    }

    /// Consumes the agent and returns an input handle and an output handle.
    ///
    /// NOTE(review): the behavior of [`AgentRunInput`] and [`AgentRunOutput`] is defined in
    /// the `session` module. Presumably the agent keeps running behind these handles;
    /// confirm there.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports while starting the run.
    async fn spawn(
        self,
    ) -> AgentResult<(
        AgentRunInput<Self::Input>,
        AgentRunOutput<Self::ToolCall, Self::ToolResult, Self::Output>,
    )>
    where
        Self: Sized;
}

/// Drains events from `agent` until the current turn reaches a terminal event.
///
/// Returns the reply carried by [`AgentEvent::Completed`] and maps [`AgentEvent::Cancelled`]
/// to [`AgentError::Cancelled`]. All other events are discarded. If the event stream ends
/// first, returns [`AgentError::Internal`] with `ended_message` as its message.
///
/// Shared by the default [`Agent::call`] and [`Agent::steer`] implementations.
async fn await_reply<A>(agent: &mut A, ended_message: &str) -> AgentResult<A::Output>
where
    // Trait default methods do not assume `Self: Sized`, so the helper must not either.
    A: Agent + ?Sized,
{
    loop {
        match agent.next().await? {
            Some(AgentEvent::Completed { reply, .. }) => return Ok(reply),
            Some(AgentEvent::Cancelled) => return Err(AgentError::Cancelled),
            Some(_) => {}
            None => {
                return Err(AgentError::Internal {
                    message: ended_message.to_string(),
                });
            }
        }
    }
}