use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::kernel::event::Event;
use crate::kernel::state::KernelState;
use crate::kernel::step::{Next, StepFn};
use crate::kernel::KernelError;
use crate::prompt::PromptArgs;
use super::unified_agent::{AgentInput, UnifiedAgent};
use crate::graph::RunnableConfig;
/// Kernel state carried between invocations of the agent step.
///
/// Deserialization is lenient: any field absent from the input is taken from
/// [`AgentStepState::default`] (i.e. `PromptArgs::default()` and `None`).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct AgentStepState {
    /// Arguments handed to the agent (wrapped in `AgentInput::State`) each time the step runs.
    pub prompt_args: PromptArgs,
    /// Presumably the most recent agent output. NOTE(review): this file only
    /// ever initializes it to `None`; confirm where it is populated.
    pub last_output: Option<String>,
}
impl AgentStepState {
    /// Creates a state holding `prompt_args` and no recorded output (`last_output` is `None`).
    pub fn new(prompt_args: PromptArgs) -> Self {
        let last_output = None;
        Self {
            last_output,
            prompt_args,
        }
    }
}
impl KernelState for AgentStepState {
    /// Reports version `1` for this state type.
    ///
    /// NOTE(review): presumably the layout version the kernel uses for
    /// persisted state; confirm against the `KernelState` trait docs.
    fn version(&self) -> u32 {
        const CURRENT_VERSION: u32 = 1;
        CURRENT_VERSION
    }
}
/// Adapts a [`UnifiedAgent`] into a kernel [`StepFn`] operating on [`AgentStepState`].
///
/// On every step, the adapter clones both the shared agent handle and `config`,
/// then runs the agent to completion.
pub struct AgentStepFnAdapter {
    /// Agent invoked on each step; held in an [`Arc`] so it can be cloned cheaply per call.
    pub agent: Arc<UnifiedAgent>,
    /// Configuration cloned per step and passed by reference to `invoke_with_config`.
    pub config: RunnableConfig,
}
impl AgentStepFnAdapter {
    /// Wraps `agent` so the kernel can drive it, using `config` for every invocation.
    pub fn new(agent: Arc<UnifiedAgent>, config: RunnableConfig) -> Self {
        Self {
            config,
            agent,
        }
    }
}
impl StepFn<AgentStepState> for AgentStepFnAdapter {
    /// Runs the wrapped agent once, synchronously, on `state.prompt_args`.
    ///
    /// The async `invoke_with_config` call is driven to completion with
    /// [`tokio::runtime::Handle::block_on`] on the runtime current to this thread.
    ///
    /// Returns:
    /// - `Next::Emit` containing a single `Event::StateUpdated` whose step id is
    ///   `"agent"` and whose payload is `{"output": <output>}` when the agent completes;
    /// - `Next::Interrupt` carrying the agent's interrupt value when it interrupts.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::Driver` when no Tokio runtime is current on this
    /// thread, or when the agent invocation fails (the error's `Display` text is kept).
    ///
    /// # Panics
    ///
    /// `Handle::block_on` panics when called from within an asynchronous
    /// execution context. Calling this directly from inside an async task
    /// (without `tokio::task::block_in_place`) therefore panics rather than
    /// returning an error. It also propagates a panic from the agent future.
    fn next(&self, state: &AgentStepState) -> Result<Next, KernelError> {
        let runtime = match tokio::runtime::Handle::try_current() {
            Ok(current) => current,
            Err(_) => {
                return Err(KernelError::Driver(
                    "Tokio runtime required: call from a thread with an entered runtime (e.g. after Runtime::new() and rt.enter()), or use block_in_place from an async task. Do not call from inside an async task without block_in_place.".into(),
                ));
            }
        };

        // Owned copies are moved into the future so it does not borrow `self` or `state`.
        let shared_agent = Arc::clone(&self.agent);
        let run_config = self.config.clone();
        let args = state.prompt_args.clone();
        let invocation = async move {
            shared_agent
                .invoke_with_config(AgentInput::State(args), &run_config)
                .await
        };

        let outcome = runtime
            .block_on(invocation)
            .map_err(|err| KernelError::Driver(err.to_string()))?;

        let step = match outcome {
            super::AgentInvokeResult::Complete(output) => {
                let payload = serde_json::json!({ "output": output });
                let update = Event::StateUpdated {
                    step_id: Some("agent".to_string()),
                    payload,
                };
                Next::Emit(vec![update])
            }
            super::AgentInvokeResult::Interrupt { interrupt_value } => {
                Next::Interrupt(crate::kernel::step::InterruptInfo {
                    value: interrupt_value,
                })
            }
        };
        Ok(step)
    }
}