use super::{AgentId, AgentTask, AgentResult};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Execution context handed to an agent: its identity, its task, and a handle
/// to lock-protected state used for intermediate results and checkpoints.
///
/// Cloning does not copy that state: `shared_state` is an `Arc`, so every clone
/// of a context reads and writes the same intermediate results and checkpoint.
/// (`agent_id` and `task` are duplicated via their own `Clone` impls.)
#[derive(Clone)]
pub struct AgentContext {
    /// Identifier of the agent this context belongs to.
    pub agent_id: AgentId,
    /// Task the agent is working on.
    pub task: AgentTask,
    /// State shared by all clones of this context, guarded by an async `RwLock`.
    shared_state: Arc<RwLock<SharedState>>,
}
/// Mutable state shared by every clone of an `AgentContext`.
#[derive(Default)]
struct SharedState {
    /// Results recorded via `AgentContext::store_intermediate`, in insertion order.
    intermediate_results: Vec<AgentResult>,
    /// Most recent checkpoint, or `None` if none has been stored yet.
    ///
    /// This is an `Option` rather than using `Value::Null` as a "nothing saved"
    /// sentinel, so that data which itself serializes to `null` (e.g. `()` or
    /// `Option::None`) round-trips through `checkpoint`/`load_checkpoint`.
    checkpoint_data: Option<serde_json::Value>,
}

impl AgentContext {
    /// Creates a context for `agent_id` working on `task`, with no intermediate
    /// results and no checkpoint.
    pub fn new(agent_id: AgentId, task: AgentTask) -> Self {
        Self {
            agent_id,
            task,
            shared_state: Arc::new(RwLock::new(SharedState::default())),
        }
    }

    /// Appends `result` to the intermediate results.
    ///
    /// The list lives behind the shared `Arc`, so the result is visible to every
    /// clone of this context.
    pub async fn store_intermediate(&self, result: AgentResult) {
        self.shared_state
            .write()
            .await
            .intermediate_results
            .push(result);
    }

    /// Returns a snapshot (a clone) of all intermediate results stored so far.
    pub async fn get_intermediates(&self) -> Vec<AgentResult> {
        self.shared_state.read().await.intermediate_results.clone()
    }

    /// Serializes `data` to JSON and stores it as the current checkpoint,
    /// replacing any previous one.
    ///
    /// # Errors
    ///
    /// Returns an error if `data` cannot be converted to a `serde_json::Value`.
    /// In that case the existing checkpoint is left untouched.
    pub async fn checkpoint(&self, data: impl serde::Serialize) -> anyhow::Result<()> {
        // Serialize before taking the lock: arbitrary user `Serialize` code should
        // not run while other readers/writers of the shared state are blocked.
        let value = serde_json::to_value(data)?;
        self.shared_state.write().await.checkpoint_data = Some(value);
        Ok(())
    }

    /// Deserializes the current checkpoint as `T`.
    ///
    /// Returns `Ok(None)` if no checkpoint has been stored yet.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored checkpoint cannot be deserialized as `T`.
    pub async fn load_checkpoint<T: serde::de::DeserializeOwned>(
        &self,
    ) -> anyhow::Result<Option<T>> {
        let state = self.shared_state.read().await;
        match state.checkpoint_data.as_ref() {
            None => Ok(None),
            // `&Value` implements `Deserializer`, so deserialize in place rather
            // than cloning the whole JSON tree first.
            Some(value) => Ok(Some(T::deserialize(value)?)),
        }
    }
}