use std::sync::Arc;
use tokio::sync::Mutex;
use crate::{
agent::AgentState,
prompt::PromptArgs,
tools::{StreamWriter, ToolContext, ToolStore},
};
/// Bundle of shared handles: a tool context, a tool store, and an optional
/// stream writer.
///
/// Every handle is an `Arc`'d trait object. A `Runtime` is typically wrapped
/// in an `Arc` itself and attached to a [`RuntimeRequest`].
pub struct Runtime {
    /// Shared tool context.
    pub context: Arc<dyn ToolContext>,
    /// Shared tool store.
    pub store: Arc<dyn ToolStore>,
    /// Optional stream writer; `None` unless set via [`Runtime::with_stream_writer`].
    pub stream_writer: Option<Arc<dyn StreamWriter>>,
}

impl Runtime {
    /// Creates a runtime from a context and a store, with no stream writer attached.
    #[must_use]
    pub fn new(context: Arc<dyn ToolContext>, store: Arc<dyn ToolStore>) -> Self {
        Self {
            context,
            store,
            stream_writer: None,
        }
    }

    /// Builder-style setter that attaches `stream_writer`, replacing any previously
    /// set writer, and returns the updated runtime.
    ///
    /// The runtime is taken by value. Discarding the return value drops it, so the
    /// result is `#[must_use]`.
    #[must_use = "`with_stream_writer` consumes the runtime and returns the updated one"]
    pub fn with_stream_writer(mut self, stream_writer: Arc<dyn StreamWriter>) -> Self {
        self.stream_writer = Some(stream_writer);
        self
    }

    /// Borrows the tool context.
    pub fn context(&self) -> &dyn ToolContext {
        self.context.as_ref()
    }

    /// Borrows the tool store.
    pub fn store(&self) -> &dyn ToolStore {
        self.store.as_ref()
    }

    /// Returns the attached stream writer, or `None` if none was set.
    pub fn stream_writer(&self) -> Option<&Arc<dyn StreamWriter>> {
        self.stream_writer.as_ref()
    }
}
/// A request pairing prompt input and shared agent state with an optional
/// [`Runtime`].
pub struct RuntimeRequest {
    /// Prompt arguments for this request.
    pub input: PromptArgs,
    /// Shared agent state behind an async-aware (`tokio`) mutex.
    pub state: Arc<Mutex<AgentState>>,
    /// Optional runtime; `None` unless set via [`RuntimeRequest::with_runtime`].
    pub runtime: Option<Arc<Runtime>>,
}

impl RuntimeRequest {
    /// Creates a request from prompt input and shared state, with no runtime attached.
    #[must_use]
    pub fn new(input: PromptArgs, state: Arc<Mutex<AgentState>>) -> Self {
        Self {
            input,
            state,
            runtime: None,
        }
    }

    /// Builder-style setter that attaches `runtime`, replacing any previously set
    /// runtime, and returns the updated request.
    ///
    /// The request is taken by value. Discarding the return value drops it, so the
    /// result is `#[must_use]`.
    #[must_use = "`with_runtime` consumes the request and returns the updated one"]
    pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
        self.runtime = Some(runtime);
        self
    }

    /// Returns the attached runtime, or `None` if none was set.
    pub fn runtime(&self) -> Option<&Arc<Runtime>> {
        self.runtime.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{EmptyContext, InMemoryStore};

    /// Fixture: a runtime over an empty context and a fresh in-memory store.
    fn fresh_runtime() -> Runtime {
        let ctx: Arc<dyn ToolContext> = Arc::new(EmptyContext);
        let backing: Arc<dyn ToolStore> = Arc::new(InMemoryStore::new());
        Runtime::new(ctx, backing)
    }

    /// Fixture: a request over default prompt args and a fresh shared agent state.
    fn fresh_request() -> RuntimeRequest {
        let shared = Arc::new(Mutex::new(AgentState::new()));
        RuntimeRequest::new(PromptArgs::new(), shared)
    }

    #[test]
    fn test_runtime_creation() {
        // A newly constructed runtime has no stream writer attached.
        let rt = fresh_runtime();
        assert!(rt.stream_writer().is_none());
    }

    #[test]
    fn test_runtime_request_creation() {
        // A newly constructed request has no runtime attached.
        let req = fresh_request();
        assert!(req.runtime().is_none());
    }

    #[test]
    fn test_runtime_request_with_runtime() {
        // Attaching a runtime through the builder makes it observable via the getter.
        let shared_rt = Arc::new(fresh_runtime());
        let req = fresh_request().with_runtime(shared_rt);
        assert!(req.runtime().is_some());
    }
}