use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::AgentState;
pub use super::context::ToolContext;
pub use super::file_backend::FileBackend;
pub use super::store::ToolStore;
pub use super::stream::StreamWriter;
/// Bundle of shared handles that is handed to a tool for a single tool call.
///
/// Every handle is reference-counted (`Arc`), so the values behind them can be
/// shared with whoever else holds a clone of the same `Arc`.
pub struct ToolRuntime {
    /// Agent state shared behind an async (`tokio`) mutex; lock it via `state()`.
    pub state: Arc<Mutex<AgentState>>,
    /// Context object exposed to the tool as a `ToolContext` trait object.
    pub context: Arc<dyn ToolContext>,
    /// Store exposed to the tool as a `ToolStore` trait object.
    pub store: Arc<dyn ToolStore>,
    /// Optional sink used by `stream()`; when `None`, streamed messages are dropped.
    pub stream_writer: Option<Arc<dyn StreamWriter>>,
    /// Optional file backend; `None` unless one was attached with `with_file_backend`.
    pub file_backend: Option<Arc<dyn FileBackend>>,
    /// Identifier string supplied to `new` for the tool call this runtime belongs to.
    pub tool_call_id: String,
}
impl ToolRuntime {
    /// Creates a runtime from its mandatory parts.
    ///
    /// The optional `stream_writer` and `file_backend` both start out as `None`;
    /// attach them afterwards with [`with_stream_writer`](Self::with_stream_writer)
    /// and [`with_file_backend`](Self::with_file_backend).
    pub fn new(
        state: Arc<Mutex<AgentState>>,
        context: Arc<dyn ToolContext>,
        store: Arc<dyn ToolStore>,
        tool_call_id: String,
    ) -> Self {
        Self {
            tool_call_id,
            state,
            context,
            store,
            file_backend: None,
            stream_writer: None,
        }
    }

    /// Consumes the runtime and returns it with `writer` installed as the stream
    /// writer, replacing any writer that was set before.
    pub fn with_stream_writer(self, writer: Arc<dyn StreamWriter>) -> Self {
        Self {
            stream_writer: Some(writer),
            ..self
        }
    }

    /// Consumes the runtime and returns it with `backend` installed as the file
    /// backend, replacing any backend that was set before.
    pub fn with_file_backend(self, backend: Arc<dyn FileBackend>) -> Self {
        Self {
            file_backend: Some(backend),
            ..self
        }
    }

    /// Borrows the file backend, or returns `None` when none has been attached.
    pub fn file_backend(&self) -> Option<&Arc<dyn FileBackend>> {
        self.file_backend.as_ref()
    }

    /// Waits for the state mutex and returns its guard.
    ///
    /// The state stays locked until the returned guard is dropped.
    pub async fn state(&self) -> tokio::sync::MutexGuard<'_, AgentState> {
        self.state.lock().await
    }

    /// Borrows the tool context as a trait object.
    pub fn context(&self) -> &dyn ToolContext {
        &*self.context
    }

    /// Borrows the tool store as a trait object.
    pub fn store(&self) -> &dyn ToolStore {
        &*self.store
    }

    /// Returns the enhanced store, if any.
    ///
    /// Currently this always returns `None`.
    // NOTE(review): looks like a placeholder until an enhanced store is wired in — confirm.
    pub fn enhanced_store(&self) -> Option<&dyn crate::tools::long_term_memory::EnhancedToolStore> {
        None
    }

    /// Forwards `message` to the stream writer.
    ///
    /// Does nothing when no stream writer has been attached.
    pub fn stream(&self, message: &str) {
        let Some(writer) = self.stream_writer.as_ref() else {
            return;
        };
        writer.write(message);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{EmptyContext, InMemoryStore};

    /// Tool-call id handed to every runtime built in these tests.
    const CALL_ID: &str = "test_call_1";

    /// Builds a fresh agent state wrapped the way `ToolRuntime::new` expects it.
    fn fresh_state() -> Arc<Mutex<AgentState>> {
        Arc::new(Mutex::new(AgentState::new()))
    }

    /// The id passed to `new` is stored in the `tool_call_id` field.
    #[tokio::test]
    async fn test_tool_runtime_creation() {
        let runtime = ToolRuntime::new(
            fresh_state(),
            Arc::new(EmptyContext),
            Arc::new(InMemoryStore::new()),
            CALL_ID.to_string(),
        );

        assert_eq!(runtime.tool_call_id, CALL_ID);
    }

    /// A field written through one state guard is visible through a later one.
    #[tokio::test]
    async fn test_tool_runtime_state_access() {
        let shared_state = fresh_state();
        let runtime = ToolRuntime::new(
            Arc::clone(&shared_state),
            Arc::new(EmptyContext),
            Arc::new(InMemoryStore::new()),
            CALL_ID.to_string(),
        );

        let mut write_guard = runtime.state().await;
        write_guard.set_field("test_key".to_string(), serde_json::json!("test_value"));
        // Release the lock before taking it again below; otherwise the second
        // `state().await` would never complete.
        drop(write_guard);

        let read_guard = runtime.state().await;
        let expected = serde_json::json!("test_value");
        assert_eq!(read_guard.get_field("test_key"), Some(&expected));
    }

    /// `context()` exposes the context given to `new`; `EmptyContext` reports no user id.
    #[tokio::test]
    async fn test_tool_runtime_context() {
        let runtime = ToolRuntime::new(
            fresh_state(),
            Arc::new(EmptyContext),
            Arc::new(InMemoryStore::new()),
            CALL_ID.to_string(),
        );

        assert_eq!(runtime.context().user_id(), None);
    }

    /// A value `put` through `store()` can be read back with `get`.
    #[tokio::test]
    async fn test_tool_runtime_store() {
        let backing_store: Arc<dyn ToolStore> = Arc::new(InMemoryStore::new());
        let runtime = ToolRuntime::new(
            fresh_state(),
            Arc::new(EmptyContext),
            Arc::clone(&backing_store),
            CALL_ID.to_string(),
        );

        runtime
            .store()
            .put(&["test"], "key1", serde_json::json!("value1"))
            .await;

        let fetched = runtime.store().get(&["test"], "key1").await;
        assert_eq!(fetched, Some(serde_json::json!("value1")));
    }
}