pub mod builtin;
pub mod loader;
#[cfg(feature = "python")]
pub mod python;
pub mod register;
pub mod registry;
use std::ops::Deref;
use std::sync::{Arc, RwLock};
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use crate::llm::ToolSpec;
/// A synchronous tool: it exposes a name and a [`ToolSpec`], and is invoked with JSON
/// arguments to produce a JSON result.
///
/// See [`AsyncTool`] for the asynchronous variant, which also receives a [`ToolContext`].
pub trait Tool: Send + Sync {
    /// The tool's name.
    fn name(&self) -> &str;
    /// Builds the [`ToolSpec`] for this tool; building it may fail.
    fn spec(&self) -> Result<ToolSpec>;
    /// Runs the tool with the JSON `args` and returns its JSON output, or an error.
    fn execute(&self, args: Value) -> Result<Value>;
}
/// Declares whether a tool's calls may be executed in parallel.
///
/// `Hash` is derived alongside `Eq` so the policy can be used as a map/set key.
///
/// NOTE(review): this file only declares the policy; the executor that honours it lives
/// elsewhere — confirm the exact scheduling semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolConcurrency {
    /// Calls must not be run in parallel. This is what `AsyncTool::concurrency` returns
    /// unless a tool overrides it.
    SerialOnly,
    /// Calls are safe to run in parallel.
    ParallelSafe,
}
/// Per-call context handed to `AsyncTool::execute`.
///
/// Identifies the session, agent and tool call being served, and optionally carries the
/// managed-state store that backs [`ToolContext::app_state`] and
/// [`ToolContext::session_state`].
#[derive(Clone)]
pub struct ToolContext {
    session_id: String,
    agent_name: String,
    tool_call_id: String,
    /// `None` for contexts built by `new`; set by `attach_state_store`.
    state_store: Option<Arc<registry::RegistryStateStore>>,
}

impl ToolContext {
    /// Creates a context for a single tool call, with no managed-state store attached.
    pub fn new(
        session_id: impl Into<String>,
        agent_name: impl Into<String>,
        tool_call_id: impl Into<String>,
    ) -> Self {
        let session_id: String = session_id.into();
        let agent_name: String = agent_name.into();
        let tool_call_id: String = tool_call_id.into();
        Self {
            session_id,
            agent_name,
            tool_call_id,
            state_store: None,
        }
    }

    /// Returns this context with `state_store` attached, replacing any store already set.
    pub(crate) fn attach_state_store(
        mut self,
        state_store: Arc<registry::RegistryStateStore>,
    ) -> Self {
        self.state_store = Some(state_store);
        self
    }

    /// Identifier of the session this call belongs to.
    pub fn session_id(&self) -> &str {
        self.session_id.as_str()
    }

    /// Name of the agent issuing this call.
    pub fn agent_name(&self) -> &str {
        self.agent_name.as_str()
    }

    /// Identifier of this particular tool call.
    pub fn tool_call_id(&self) -> &str {
        self.tool_call_id.as_str()
    }

    /// Looks up the app state of type `T` in the attached store.
    ///
    /// # Errors
    ///
    /// Fails when no state store is attached, or when the store's lookup fails.
    pub fn app_state<T>(&self) -> Result<State<T>>
    where
        T: Send + Sync + 'static,
    {
        self.require_state_store()?.get_app_state::<T>()
    }

    /// Looks up the state of type `T` for this context's session in the attached store.
    ///
    /// # Errors
    ///
    /// Fails when no state store is attached, or when the store's lookup fails.
    pub fn session_state<T>(&self) -> Result<SessionState<T>>
    where
        T: Send + Sync + 'static,
    {
        self.require_state_store()?
            .get_session_state::<T>(&self.session_id)
    }

    /// Borrows the attached store, or reports which session is missing one.
    fn require_state_store(&self) -> Result<&Arc<registry::RegistryStateStore>> {
        match &self.state_store {
            Some(store) => Ok(store),
            None => Err(anyhow::anyhow!(
                "Tool context for session '{}' does not have a managed-state store",
                self.session_id
            )),
        }
    }
}
// The `state_store` field is not printed (NOTE(review): presumably because
// `RegistryStateStore` need not implement `Debug` — confirm). `finish_non_exhaustive`
// appends a trailing `..` so the output makes it visible that some fields are omitted.
impl std::fmt::Debug for ToolContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolContext")
            .field("session_id", &self.session_id)
            .field("agent_name", &self.agent_name)
            .field("tool_call_id", &self.tool_call_id)
            .finish_non_exhaustive()
    }
}
/// Shared, read-only state handed out by [`ToolContext::app_state`].
///
/// A thin handle around an `Arc<T>`: cloning it only bumps the reference count and never
/// clones `T`, and it dereferences to `&T`.
pub struct State<T> {
    inner: Arc<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, making `State<T>` uncloneable for non-`Clone` `T` even though only the `Arc`
// needs cloning.
impl<T> Clone for State<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> State<T> {
    /// Wraps an already-shared value.
    pub(crate) fn from_arc(inner: Arc<T>) -> Self {
        Self { inner }
    }

    /// Consumes the handle and returns the underlying `Arc`.
    pub fn into_inner(self) -> Arc<T> {
        self.inner
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<T> std::fmt::Debug for State<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("State").field(&self.inner).finish()
    }
}
/// Per-session mutable state handed out by [`ToolContext::session_state`].
///
/// A handle around an `Arc<RwLock<T>>`; every clone refers to the same locked value.
pub struct SessionState<T> {
    inner: Arc<RwLock<T>>,
}

// Hand-written rather than `#[derive(Clone)]`: the derive would require `T: Clone`, but
// cloning the handle only needs to bump the `Arc` reference count.
impl<T> Clone for SessionState<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T> SessionState<T> {
pub(crate) fn from_arc(inner: Arc<RwLock<T>>) -> Self {
Self { inner }
}
pub fn read<R>(&self, reader: impl FnOnce(&T) -> R) -> Result<R> {
let guard = self.inner.read().map_err(|_| {
anyhow::anyhow!("Session state lock was poisoned while acquiring a read guard")
})?;
Ok(reader(&guard))
}
pub fn update<R>(&self, updater: impl FnOnce(&mut T) -> R) -> Result<R> {
let mut guard = self.inner.write().map_err(|_| {
anyhow::anyhow!("Session state lock was poisoned while acquiring a write guard")
})?;
Ok(updater(&mut guard))
}
pub fn get_cloned(&self) -> Result<T>
where
T: Clone,
{
self.read(Clone::clone)
}
}
// Formatting must never block. With a blocking `read()`, debug-printing the state from
// inside an `update` closure on the same thread would re-acquire a lock that thread already
// holds for writing, which std documents as able to panic or deadlock. `try_read` instead
// reports `"<locked>"` in that case; a poisoned lock is still reported as `"<poisoned>"`.
impl<T> std::fmt::Debug for SessionState<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.inner.try_read() {
            Ok(guard) => f.debug_tuple("SessionState").field(&*guard).finish(),
            Err(std::sync::TryLockError::Poisoned(_)) => {
                f.debug_tuple("SessionState").field(&"<poisoned>").finish()
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                f.debug_tuple("SessionState").field(&"<locked>").finish()
            }
        }
    }
}
/// Asynchronous tool interface. Unlike [`Tool`], `execute` is `async` and receives a
/// per-call [`ToolContext`].
#[async_trait]
pub trait AsyncTool: Send + Sync {
    /// The tool's name.
    fn name(&self) -> &str;
    /// Builds the [`ToolSpec`] for this tool; building it may fail.
    fn spec(&self) -> Result<ToolSpec>;
    /// Declares whether calls to this tool may run in parallel.
    ///
    /// Defaults to [`ToolConcurrency::SerialOnly`]; override to return
    /// [`ToolConcurrency::ParallelSafe`]. NOTE(review): how this is honoured is decided by
    /// the caller, which is not in this file.
    fn concurrency(&self) -> ToolConcurrency {
        ToolConcurrency::SerialOnly
    }
    /// Runs the tool with the call's `ctx` and JSON `args`, returning JSON output or an error.
    async fn execute(&self, ctx: ToolContext, args: Value) -> Result<Value>;
}
pub use registry::ToolRegistry;