use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinHandle;
use crate::error::RuntimeError;
use crate::state::AgentState;
/// Outcome of a subgraph run, tagged with the identifier of the subgraph
/// that produced it.
///
/// Every variant carries `subgraph_id`, so the originating subgraph can be
/// recovered regardless of how the run ended.
#[derive(Debug)]
pub enum SubgraphResult {
    /// The subgraph completed and produced an [`AgentState`].
    Completed {
        /// Identifier of the subgraph.
        subgraph_id: String,
        /// State carried by the completed run.
        state: AgentState,
    },
    /// The subgraph ended with an error.
    Failed {
        /// Identifier of the subgraph.
        subgraph_id: String,
        /// The error describing the failure.
        error: RuntimeError,
    },
    /// The subgraph was cancelled before it produced a result.
    Cancelled {
        /// Identifier of the subgraph.
        subgraph_id: String,
    },
}
impl SubgraphResult {
    /// Returns `true` if this is the [`SubgraphResult::Completed`] variant.
    pub fn is_completed(&self) -> bool {
        matches!(self, Self::Completed { .. })
    }

    /// Returns `true` if this is the [`SubgraphResult::Failed`] variant.
    pub fn is_failed(&self) -> bool {
        matches!(self, Self::Failed { .. })
    }

    /// Returns `true` if this is the [`SubgraphResult::Cancelled`] variant.
    pub fn is_cancelled(&self) -> bool {
        matches!(self, Self::Cancelled { .. })
    }

    /// Returns the identifier of the subgraph, whichever variant this is.
    pub fn subgraph_id(&self) -> &str {
        match self {
            Self::Completed { subgraph_id, .. }
            | Self::Failed { subgraph_id, .. }
            | Self::Cancelled { subgraph_id } => subgraph_id,
        }
    }

    /// Borrows the state of a completed run.
    ///
    /// Returns `None` for the `Failed` and `Cancelled` variants.
    pub fn state(&self) -> Option<&AgentState> {
        // Variants are listed explicitly (no `_` arm) so that adding a new
        // variant forces this match to be revisited.
        match self {
            Self::Completed { state, .. } => Some(state),
            Self::Failed { .. } | Self::Cancelled { .. } => None,
        }
    }

    /// Consumes the result and converts it into a `Result`.
    ///
    /// # Errors
    ///
    /// * `Failed` yields the stored error unchanged.
    /// * `Cancelled` yields a [`RuntimeError::InvalidState`] whose message
    ///   names the cancelled subgraph.
    pub fn into_state(self) -> Result<AgentState, RuntimeError> {
        match self {
            Self::Completed { state, .. } => Ok(state),
            Self::Failed { error, .. } => Err(error),
            Self::Cancelled { subgraph_id } => Err(RuntimeError::InvalidState(format!(
                "Subgraph '{}' was cancelled",
                subgraph_id
            ))),
        }
    }

    /// Borrows the error of a failed run.
    ///
    /// Returns `None` for the `Completed` and `Cancelled` variants.
    pub fn error(&self) -> Option<&RuntimeError> {
        match self {
            Self::Failed { error, .. } => Some(error),
            Self::Completed { .. } | Self::Cancelled { .. } => None,
        }
    }
}
/// Handle to a subgraph task: pairs the subgraph's identifier with the
/// [`JoinHandle`] of the task that computes its [`SubgraphResult`].
///
/// `Debug` is derived so the handle can be logged or embedded in other
/// `Debug` types.
#[derive(Debug)]
pub struct SubgraphHandle {
    /// Identifier of the subgraph this handle refers to (also available
    /// through `id()`).
    pub subgraph_id: String,
    /// Join handle of the task producing the subgraph's result.
    handle: JoinHandle<SubgraphResult>,
}
impl SubgraphHandle {
pub(crate) fn new(subgraph_id: String, handle: JoinHandle<SubgraphResult>) -> Self {
Self { subgraph_id, handle }
}
pub fn id(&self) -> &str {
&self.subgraph_id
}
pub fn is_finished(&self) -> bool {
self.handle.is_finished()
}
pub fn abort(&self) {
self.handle.abort();
}
pub async fn join(self) -> SubgraphResult {
match self.handle.await {
Ok(result) => result,
Err(e) if e.is_cancelled() => SubgraphResult::Cancelled {
subgraph_id: self.subgraph_id,
},
Err(e) => SubgraphResult::Failed {
subgraph_id: self.subgraph_id,
error: RuntimeError::InvalidState(format!("Subgraph task panicked: {}", e)),
},
}
}
}
impl Future for SubgraphHandle {
type Output = SubgraphResult;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.handle).poll(cx) {
Poll::Ready(Ok(result)) => Poll::Ready(result),
Poll::Ready(Err(e)) if e.is_cancelled() => Poll::Ready(SubgraphResult::Cancelled {
subgraph_id: self.subgraph_id.clone(),
}),
Poll::Ready(Err(e)) => Poll::Ready(SubgraphResult::Failed {
subgraph_id: self.subgraph_id.clone(),
error: RuntimeError::InvalidState(format!("Subgraph task panicked: {}", e)),
}),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A completed result exposes its state and no error.
    #[test]
    fn test_subgraph_result_completed() {
        let result = SubgraphResult::Completed {
            subgraph_id: "test".to_string(),
            state: AgentState::new(),
        };
        assert!(result.is_completed());
        assert!(!result.is_failed());
        assert_eq!(result.subgraph_id(), "test");
        assert!(result.state().is_some());
        assert!(result.error().is_none());
        assert!(result.into_state().is_ok());
    }

    /// A failed result exposes its error and hands it back from `into_state`.
    #[test]
    fn test_subgraph_result_failed() {
        let result = SubgraphResult::Failed {
            subgraph_id: "test".to_string(),
            error: RuntimeError::RecursionLimit(100),
        };
        assert!(!result.is_completed());
        assert!(result.is_failed());
        assert_eq!(result.subgraph_id(), "test");
        assert!(result.state().is_none());
        assert!(result.error().is_some());
        assert!(matches!(
            result.into_state(),
            Err(RuntimeError::RecursionLimit(100))
        ));
    }

    /// A cancelled result has neither state nor error and converts to `Err`.
    #[test]
    fn test_subgraph_result_cancelled() {
        let result = SubgraphResult::Cancelled {
            subgraph_id: "test".to_string(),
        };
        assert!(!result.is_completed());
        assert!(!result.is_failed());
        assert_eq!(result.subgraph_id(), "test");
        assert!(result.state().is_none());
        assert!(result.error().is_none());
        assert!(result.into_state().is_err());
    }

    /// Awaiting the handle directly yields the task's own result.
    #[tokio::test]
    async fn test_subgraph_handle() {
        let handle = tokio::spawn(async {
            SubgraphResult::Completed {
                subgraph_id: "test".to_string(),
                state: AgentState::new(),
            }
        });
        let subgraph_handle = SubgraphHandle::new("test".to_string(), handle);
        assert_eq!(subgraph_handle.id(), "test");
        let result = subgraph_handle.await;
        assert!(result.is_completed());
    }

    /// `join` yields the same result as awaiting the handle.
    #[tokio::test]
    async fn test_subgraph_handle_join() {
        let handle = tokio::spawn(async {
            SubgraphResult::Completed {
                subgraph_id: "test".to_string(),
                state: AgentState::new(),
            }
        });
        let result = SubgraphHandle::new("test".to_string(), handle).join().await;
        assert!(result.is_completed());
        assert_eq!(result.subgraph_id(), "test");
    }

    /// Aborting a task that never completes is reported as `Cancelled`.
    #[tokio::test]
    async fn test_subgraph_handle_abort_reports_cancelled() {
        // The task can only end through the abort: it awaits a future that
        // never resolves.
        let handle = tokio::spawn(async { std::future::pending::<SubgraphResult>().await });
        let subgraph_handle = SubgraphHandle::new("test".to_string(), handle);
        assert!(!subgraph_handle.is_finished());

        subgraph_handle.abort();
        let result = subgraph_handle.join().await;

        assert!(matches!(result, SubgraphResult::Cancelled { .. }));
        assert_eq!(result.subgraph_id(), "test");
    }

    /// A panicking task is reported as `Failed` under the handle's id.
    #[tokio::test]
    async fn test_subgraph_handle_panic_reports_failed() {
        async fn panicking_subgraph() -> SubgraphResult {
            panic!("boom")
        }

        let handle = tokio::spawn(panicking_subgraph());
        let result = SubgraphHandle::new("test".to_string(), handle).await;

        assert!(result.is_failed());
        assert_eq!(result.subgraph_id(), "test");
        assert!(result.error().is_some());
    }
}