use async_trait::async_trait;
use std::sync::Arc;
use chrono::Utc;
use klieo_core::agent::Agent;
use klieo_core::checkpoint::{ApprovalDecision, RunCheckpoint};
use klieo_core::error::{Error as KlieoError, ToolError};
use klieo_core::ids::ThreadId;
use klieo_core::llm::{Message, Role};
use klieo_core::runtime::{resume_from_checkpoint, RunOptions};
use klieo_core::tool::{ToolCtx, ToolInvoker};
use klieo_core::ToolDef;
use klieo_hitl::{run_with_hitl, HitlConfig};
use klieo_hitl_client::HitlClient;
use crate::resume_ticket::{ResumeTicketRecord, ResumeTicketStore};
use crate::AgentContextFactory;
/// Fixed, peer-facing `reason` string placed in every `"status": "suspended"` envelope.
///
/// The review policy's own suspension reason is only logged, never sent to the caller;
/// this constant is what the caller sees instead.
const SUSPENDED_REASON_WIRE: &str = "workflow suspended for human review";
/// Human-in-the-loop client and configuration used when running a workflow.
///
/// Both handles are `Arc`s, so cloning the bundle only bumps reference counts and the
/// clones share the same client and configuration.
#[derive(Clone)]
pub(crate) struct HitlBundle {
    /// HITL client handed to `run_with_hitl`.
    pub(crate) client: Arc<HitlClient>,
    /// HITL configuration handed to `run_with_hitl` alongside `client`.
    pub(crate) cfg: Arc<HitlConfig>,
}
/// Result of running a [`WorkflowMaterialiser`]: everything a materialised workflow exposes.
pub(crate) struct WorkflowMaterialisation {
    /// Invoker that serves the workflow as a tool.
    pub(crate) invoker: Arc<dyn ToolInvoker>,
    /// Handle used to continue a run of this workflow that was suspended for review.
    pub(crate) resume_handle: Arc<dyn WorkflowResumeHandle>,
    /// Workflow name. NOTE(review): presumably the same name the invoker advertises as its
    /// tool name; confirm at the construction site.
    pub(crate) name: String,
}
/// Deferred, one-shot constructor for a workflow.
///
/// It receives the shared HITL bundle, the optional resume-ticket store and the optional
/// governor holder, and produces a [`WorkflowMaterialisation`]. Because it is a boxed
/// `FnOnce + Send + 'static`, it can be stored, sent to another thread, and is consumed
/// by the single call that materialises the workflow.
pub(crate) type WorkflowMaterialiser = Box<
    dyn FnOnce(
            HitlBundle,
            Option<Arc<ResumeTicketStore>>,
            Option<crate::GovernorBundleHolder>,
        ) -> WorkflowMaterialisation
        + Send
        + 'static,
>;
/// Registration entry for a workflow: holds the deferred constructor that materialises it.
pub(crate) struct WorkflowRegistration {
    /// Consumed once to build the workflow's invoker and resume handle.
    pub(crate) materialise: WorkflowMaterialiser,
}
/// Handle for continuing a workflow run that was suspended for human review.
///
/// Used as a trait object (`Arc<dyn WorkflowResumeHandle>`); `#[async_trait]` makes the
/// async method usable through `dyn`.
#[async_trait]
pub(crate) trait WorkflowResumeHandle: Send + Sync {
    /// Resumes the run captured in `checkpoint`, applying the reviewer's `decision`.
    ///
    /// `tenant_label` is attached to the agent context the run is resumed in.
    ///
    /// # Errors
    /// Returns a [`ToolError`] when the resumed run cannot be completed.
    // NOTE(review): presumably only called from code behind the `http` feature; without it
    // the method would trip `dead_code`, hence the conditional allow.
    #[cfg_attr(not(feature = "http"), allow(dead_code))]
    async fn resume(
        &self,
        checkpoint: RunCheckpoint,
        decision: ApprovalDecision,
        tenant_label: String,
    ) -> Result<serde_json::Value, ToolError>;
}
/// Exposes an agent workflow as a single tool served through [`ToolInvoker`].
///
/// `A` is only used at the type level: tool-call arguments must decode into `A::Input`
/// before the workflow is started.
pub(crate) struct WorkflowAsToolInvoker<A>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + Send + 'static,
{
    /// Tool name advertised by `catalogue` and required by `invoke`; also recorded in resume
    /// tickets and used as the prefix of each run's thread id.
    pub(crate) name: String,
    /// System prompt used for fresh runs and for resumed runs.
    pub(crate) system_prompt: String,
    /// Schema advertised as the tool's input description in `catalogue`.
    pub(crate) input_schema: serde_json::Value,
    /// Builds a fresh agent context for every invocation and every resume.
    pub(crate) ctx_factory: AgentContextFactory,
    /// Run options; cloned for each run.
    pub(crate) run_options: RunOptions,
    /// HITL client and configuration passed to `run_with_hitl`.
    pub(crate) hitl: HitlBundle,
    /// Resume-ticket store; when `None`, suspensions are reported without a ticket.
    pub(crate) ticket_store: Option<Arc<ResumeTicketStore>>,
    /// Optional governor bundle; when set, each run's context is wrapped by the governor.
    #[cfg(feature = "governor")]
    pub(crate) governor: Option<crate::governor::GovernorBundle>,
    /// Ties the struct to `A` without storing an `A`. The `fn() -> A` form keeps the struct
    /// `Send`/`Sync` regardless of whether `A` itself is.
    _agent: std::marker::PhantomData<fn() -> A>,
}
impl<A> WorkflowAsToolInvoker<A>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + Send + 'static,
{
    /// Assembles an invoker for workflow `name` from already-built parts.
    ///
    /// `governor` is only a parameter when the `governor` feature is enabled.
    // Each argument is an independent piece of per-workflow wiring, so the argument-count
    // lint is silenced rather than introducing an ad-hoc parameter struct.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        name: String,
        system_prompt: String,
        input_schema: serde_json::Value,
        ctx_factory: AgentContextFactory,
        run_options: RunOptions,
        hitl: HitlBundle,
        ticket_store: Option<Arc<ResumeTicketStore>>,
        #[cfg(feature = "governor")] governor: Option<crate::governor::GovernorBundle>,
    ) -> Self {
        Self {
            _agent: std::marker::PhantomData,
            #[cfg(feature = "governor")]
            governor,
            ticket_store,
            hitl,
            run_options,
            ctx_factory,
            input_schema,
            system_prompt,
            name,
        }
    }

    /// Stores `checkpoint` under a freshly minted ticket owned by `principal` and returns the
    /// suspended envelope that carries the ticket.
    ///
    /// Best-effort: with no ticket store configured, or when persisting fails (logged at
    /// `warn`), the ticket-less envelope from `suspended_no_ticket` is returned instead.
    async fn issue_resume_ticket(
        &self,
        principal: String,
        checkpoint: RunCheckpoint,
    ) -> serde_json::Value {
        let store = match self.ticket_store.as_ref() {
            Some(store) => store,
            None => return suspended_no_ticket(),
        };
        let ticket = ResumeTicketStore::mint_token();
        let outcome = store
            .persist(
                &ticket,
                &ResumeTicketRecord {
                    principal,
                    workflow_name: self.name.clone(),
                    checkpoint,
                    created_at: Utc::now(),
                },
            )
            .await;
        match outcome {
            Ok(_) => serde_json::json!({
                "status": "suspended",
                "ticket": ticket,
                "reason": SUSPENDED_REASON_WIRE,
            }),
            Err(err) => {
                tracing::warn!(
                    workflow = %self.name,
                    error = %err,
                    "persist of resume ticket failed; falling back to no-ticket envelope",
                );
                suspended_no_ticket()
            }
        }
    }
}
fn suspended_no_ticket() -> serde_json::Value {
serde_json::json!({
"status": "suspended",
"reason": SUSPENDED_REASON_WIRE,
})
}
#[async_trait]
impl<A> ToolInvoker for WorkflowAsToolInvoker<A>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + Send + 'static,
{
    /// Advertises exactly one tool, named after the workflow, whose arguments are described
    /// by `input_schema`.
    fn catalogue(&self) -> Vec<ToolDef> {
        vec![ToolDef::new(
            self.name.clone(),
            format!("klieo workflow: {}", self.name),
            self.input_schema.clone(),
        )]
    }

    /// Runs the workflow for one tool call.
    ///
    /// The arguments are validated by decoding them into `A::Input`. The raw JSON is then
    /// seeded as the user message of a fresh thread, and the workflow is driven through
    /// `run_with_hitl`. The outcome is translated by `map_hitl_result`: the final text on
    /// success, or a suspended envelope on human-review suspension.
    ///
    /// # Errors
    /// - [`ToolError::UnknownTool`] if `name` is not this workflow's name.
    /// - [`ToolError::InvalidArgs`] if `args` does not decode into `A::Input`. The serde
    ///   detail is logged, not returned to the peer.
    /// - [`ToolError::Permanent`] if seeding the input fails or the run fails.
    async fn invoke(
        &self,
        name: &str,
        args: serde_json::Value,
        tool_ctx: ToolCtx,
    ) -> Result<serde_json::Value, ToolError> {
        if name != self.name {
            return Err(ToolError::UnknownTool(name.into()));
        }
        // Validate against the workflow's typed input. Decoding through `&args` (serde_json
        // implements `Deserializer` for `&Value`) avoids deep-copying the whole payload just
        // to throw the decoded value away; the raw JSON is still needed below.
        let _decoded: A::Input = <A::Input as serde::Deserialize>::deserialize(&args)
            .map_err(|err| {
                tracing::warn!(
                    workflow = %self.name,
                    error = %err,
                    "decode of MCP tools/call args failed",
                );
                ToolError::InvalidArgs("arguments do not match inputSchema".into())
            })?;
        let caller_principal = tool_ctx.caller_principal.clone();
        let mut ctx = (self.ctx_factory)();
        // Derive the run's cancellation from the caller's token so that cancelling the tool
        // call reaches the run. NOTE(review): assumes tokio-util `child_token` semantics.
        ctx.cancel = tool_ctx.cancel.child_token();
        ctx.progress = tool_ctx.progress.clone();
        // Runs are labelled with a hash of the caller principal, not the raw principal.
        if let Some(principal) = caller_principal.as_ref() {
            ctx = ctx.with_tenant_label(klieo_core::principal_hash(principal.as_str()));
        }
        if let Some(anchor) = tool_ctx.parent_anchor.as_ref() {
            ctx = ctx.with_parent_anchor(anchor.as_str().to_string());
        }
        #[cfg(feature = "governor")]
        if let Some(bundle) = self.governor.as_ref() {
            ctx = crate::governor::wrap_ctx_with_governor(ctx, bundle);
        }
        // One thread per run, keyed by workflow name and run id.
        let thread = ThreadId::new(format!("{}:{}", self.name, ctx.run_id));
        seed_user_message(&ctx, &thread, &args, &self.name).await?;
        let result = run_with_hitl(
            &ctx,
            &self.system_prompt,
            thread,
            self.run_options.clone(),
            &self.hitl.client,
            &self.hitl.cfg,
        )
        .await;
        self.map_hitl_result(
            caller_principal.map(|principal| principal.as_str().to_string()),
            result,
        )
        .await
    }
}
#[async_trait]
impl<A> WorkflowResumeHandle for WorkflowAsToolInvoker<A>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + Send + 'static,
{
    /// Continues a suspended run from `checkpoint` with the reviewer's `decision`.
    ///
    /// A fresh context is built from `ctx_factory` and labelled with `tenant_label`. With the
    /// `governor` feature and a configured bundle, it is wrapped by the governor before
    /// `resume_from_checkpoint` is called. Unlike `invoke`, no cancellation token, progress
    /// sink or parent anchor from a caller is attached.
    ///
    /// # Errors
    /// Any resume failure is logged and collapsed to a generic `ToolError::Permanent`.
    /// NOTE(review): if `resume_from_checkpoint` can itself return `Error::Suspended`, that is
    /// reported here as a failure rather than as a new suspension; confirm this is intended.
    async fn resume(
        &self,
        checkpoint: RunCheckpoint,
        decision: ApprovalDecision,
        tenant_label: String,
    ) -> Result<serde_json::Value, ToolError> {
        let base = (self.ctx_factory)().with_tenant_label(tenant_label);
        #[cfg(feature = "governor")]
        let base = if let Some(bundle) = self.governor.as_ref() {
            crate::governor::wrap_ctx_with_governor(base, bundle)
        } else {
            base
        };
        let options = self.run_options.clone();
        resume_from_checkpoint(&base, &self.system_prompt, checkpoint, decision, options)
            .await
            .map(serde_json::Value::String)
            .map_err(|err| {
                tracing::warn!(
                    workflow = %self.name,
                    error = %err,
                    "workflow resume failed",
                );
                ToolError::Permanent("workflow resume failed".into())
            })
    }
}
impl<A> WorkflowAsToolInvoker<A>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + Send + 'static,
{
    /// Translates the outcome of a HITL-gated run into the tool-call result.
    ///
    /// - A completed run yields its final text as a JSON string.
    /// - A suspension yields a `"status": "suspended"` envelope. It carries a resume ticket
    ///   only when a ticket store is configured and the caller principal is known. The
    ///   policy's suspension reason is logged but never sent to the peer.
    /// - Any other error is logged and collapsed to a generic `ToolError::Permanent`.
    async fn map_hitl_result(
        &self,
        caller_principal: Option<String>,
        result: Result<String, KlieoError>,
    ) -> Result<serde_json::Value, ToolError> {
        let (reason, checkpoint) = match result {
            Ok(text) => return Ok(serde_json::Value::String(text)),
            Err(KlieoError::Suspended { reason, checkpoint }) => (reason, checkpoint),
            Err(other) => {
                tracing::warn!(workflow = %self.name, error = %other, "workflow execution failed");
                return Err(ToolError::Permanent("workflow execution failed".into()));
            }
        };
        tracing::info!(
            workflow = %self.name,
            policy_reason = %reason,
            "workflow suspended on ReviewPolicy; not echoed to peer",
        );
        // A ticket can only be issued when there is both somewhere to store it and a
        // principal to bind it to.
        let ticket_owner = caller_principal.filter(|_| self.ticket_store.is_some());
        if let Some(owner) = ticket_owner {
            Ok(self.issue_resume_ticket(owner, *checkpoint).await)
        } else {
            tracing::warn!(
                workflow = %self.name,
                "resume unavailable (no checkpoint KV / no caller principal)",
            );
            Ok(suspended_no_ticket())
        }
    }
}
async fn seed_user_message(
ctx: &klieo_core::AgentContext,
thread: &ThreadId,
args: &serde_json::Value,
workflow: &str,
) -> Result<(), ToolError> {
let body = args.to_string();
ctx.short_term
.append(
thread.clone(),
Message {
role: Role::User,
content: body,
tool_calls: vec![],
tool_call_id: None,
},
)
.await
.map_err(|err| {
tracing::warn!(workflow = %workflow, error = %err, "seed user message failed");
ToolError::Permanent("workflow input persistence failed".into())
})
}