use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use tokio::sync::{RwLock, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::engine::approval::{ApprovalDecision, UserInputDecision};
use crate::engine::context_snapshot::ThreadContextSnapshot;
use crate::engine::op::Op;
use crate::engine::start_turn::StartTurnParams;
use crate::engine::turn_port::TurnEnginePort;
use crate::events::Event;
use crate::turn::TurnLoopMode;
/// Handle for talking to an engine task over channels.
///
/// Every field is either a channel sender or an `Arc`, so cloning a handle
/// yields another endpoint onto the same channels and the same cancellation
/// token.
pub struct EngineHandle<P, R> {
    /// Sender for engine operations ([`Op`]).
    pub tx_op: mpsc::Sender<Op>,
    /// Receiver for engine [`Event`]s, shared between clones via `Arc`.
    pub rx_event: Arc<RwLock<mpsc::Receiver<Event>>>,
    /// Cancellation token shared between clones.
    /// NOTE(review): the mutex suggests the token is replaced elsewhere
    /// (e.g. once per turn) — confirm against the engine loop.
    cancel_token: Arc<StdMutex<CancellationToken>>,
    /// Sender for tool-call approval decisions.
    tx_approval: mpsc::Sender<ApprovalDecision<P>>,
    /// Sender for responses to user-input requests.
    tx_user_input: mpsc::Sender<UserInputDecision<R>>,
    /// Sender for steering text.
    tx_steer: mpsc::Sender<String>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would require
// `P: Clone` and `R: Clone`, but every field here is `Clone` for any `P`/`R`
// (`mpsc::Sender<T>` and `Arc<T>` implement `Clone` unconditionally).
impl<P, R> Clone for EngineHandle<P, R> {
    fn clone(&self) -> Self {
        Self {
            tx_op: self.tx_op.clone(),
            rx_event: Arc::clone(&self.rx_event),
            cancel_token: Arc::clone(&self.cancel_token),
            tx_approval: self.tx_approval.clone(),
            tx_user_input: self.tx_user_input.clone(),
            tx_steer: self.tx_steer.clone(),
        }
    }
}
impl<P, R> EngineHandle<P, R>
where
    P: Send + Sync + 'static,
    R: Send + Sync + 'static,
{
    /// How long the read-only query methods wait for the engine's reply.
    const QUERY_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a handle from already-created channel endpoints and the shared
    /// cancellation token.
    #[must_use]
    pub fn new(
        tx_op: mpsc::Sender<Op>,
        rx_event: Arc<RwLock<mpsc::Receiver<Event>>>,
        cancel_token: Arc<StdMutex<CancellationToken>>,
        tx_approval: mpsc::Sender<ApprovalDecision<P>>,
        tx_user_input: mpsc::Sender<UserInputDecision<R>>,
        tx_steer: mpsc::Sender<String>,
    ) -> Self {
        Self {
            tx_op,
            rx_event,
            cancel_token,
            tx_approval,
            tx_user_input,
            tx_steer,
        }
    }

    /// Enqueues `op` on the engine's operation channel.
    ///
    /// # Errors
    /// Returns an error if the operation channel's receiver has been dropped.
    pub async fn send(&self, op: Op) -> Result<()> {
        self.tx_op.send(op).await?;
        Ok(())
    }

    /// Locks the shared cancellation token, recovering the guard if the mutex
    /// was poisoned.
    fn lock_cancel_token(&self) -> std::sync::MutexGuard<'_, CancellationToken> {
        // Poisoning only records that another holder panicked; the token is
        // still usable, so cancellation must keep working.
        self.cancel_token
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Cancels the current cancellation token.
    pub fn cancel(&self) {
        self.lock_cancel_token().cancel();
    }

    /// Reports whether the current cancellation token has been cancelled.
    #[must_use]
    // `dead_code` is allowed because this accessor may have no callers inside
    // the crate. NOTE(review): confirm it is still wanted as API surface.
    #[allow(dead_code)]
    pub fn is_cancelled(&self) -> bool {
        self.lock_cancel_token().is_cancelled()
    }

    /// Approves tool call `id` with no cache key and without remembering the
    /// decision for the session.
    ///
    /// # Errors
    /// Returns an error if the approval channel's receiver has been dropped.
    pub async fn approve_tool_call(&self, id: impl Into<String>) -> Result<()> {
        self.approve_tool_call_with_options(id, None, false).await
    }

    /// Approves tool call `id`, optionally tagging it with `cache_key` and
    /// asking for the approval to be remembered for the session.
    ///
    /// # Errors
    /// Returns an error if the approval channel's receiver has been dropped.
    pub async fn approve_tool_call_with_options(
        &self,
        id: impl Into<String>,
        cache_key: Option<String>,
        remember_for_session: bool,
    ) -> Result<()> {
        self.tx_approval
            .send(ApprovalDecision::Approved {
                id: id.into(),
                cache_key,
                remember_for_session,
            })
            .await?;
        Ok(())
    }

    /// Denies tool call `id`.
    ///
    /// # Errors
    /// Returns an error if the approval channel's receiver has been dropped.
    pub async fn deny_tool_call(&self, id: impl Into<String>) -> Result<()> {
        self.tx_approval
            .send(ApprovalDecision::Denied { id: id.into() })
            .await?;
        Ok(())
    }

    /// Asks for tool call `id` to be retried under `policy`.
    ///
    /// # Errors
    /// Returns an error if the approval channel's receiver has been dropped.
    pub async fn retry_tool_with_policy(&self, id: impl Into<String>, policy: P) -> Result<()> {
        self.tx_approval
            .send(ApprovalDecision::RetryWithPolicy {
                id: id.into(),
                policy,
            })
            .await?;
        Ok(())
    }

    /// Submits `response` for user-input request `id`.
    ///
    /// # Errors
    /// Returns an error if the user-input channel's receiver has been dropped.
    pub async fn submit_user_input(&self, id: impl Into<String>, response: R) -> Result<()> {
        self.tx_user_input
            .send(UserInputDecision::Submitted {
                id: id.into(),
                response,
            })
            .await?;
        Ok(())
    }

    /// Cancels user-input request `id`.
    ///
    /// # Errors
    /// Returns an error if the user-input channel's receiver has been dropped.
    pub async fn cancel_user_input(&self, id: impl Into<String>) -> Result<()> {
        self.tx_user_input
            .send(UserInputDecision::Cancelled { id: id.into() })
            .await?;
        Ok(())
    }

    /// Sends `content` on the steering channel.
    ///
    /// # Errors
    /// Returns an error if the steering channel's receiver has been dropped.
    pub async fn steer(&self, content: impl Into<String>) -> Result<()> {
        self.tx_steer.send(content.into()).await?;
        Ok(())
    }

    /// Sends the op built by `make_op` and waits up to
    /// [`Self::QUERY_TIMEOUT`] for the engine's reply.
    ///
    /// `label` names the request in errors: `"<label> timed out"` when the
    /// deadline passes, `"engine dropped <label>"` when the reply sender is
    /// dropped without answering.
    async fn query_with_timeout<T>(
        &self,
        make_op: impl FnOnce(oneshot::Sender<T>) -> Op,
        label: &str,
    ) -> Result<T> {
        let (reply, rx) = oneshot::channel();
        self.send(make_op(reply)).await?;
        tokio::time::timeout(Self::QUERY_TIMEOUT, rx)
            .await
            .map_err(|_| anyhow::anyhow!("{} timed out", label))?
            .map_err(|_| anyhow::anyhow!("engine dropped {}", label))
    }

    /// Asks the engine for a snapshot of the thread context.
    ///
    /// # Errors
    /// Fails if the op cannot be sent, the reply does not arrive within
    /// [`Self::QUERY_TIMEOUT`], or the engine drops the reply sender.
    pub async fn query_context_snapshot(&self) -> Result<ThreadContextSnapshot> {
        self.query_with_timeout(|reply| Op::QueryContext { reply }, "context query")
            .await
    }

    /// Asks the engine for its harness task graph as JSON.
    ///
    /// # Errors
    /// Fails if the op cannot be sent, the reply does not arrive within
    /// [`Self::QUERY_TIMEOUT`], or the engine drops the reply sender.
    pub async fn query_harness_task_graph(&self) -> Result<serde_json::Value> {
        self.query_with_timeout(
            |reply| Op::QueryHarnessTaskGraph { reply },
            "harness task-graph query",
        )
        .await
    }

    /// Asks the engine for its harness cycles as JSON.
    ///
    /// # Errors
    /// Fails if the op cannot be sent, the reply does not arrive within
    /// [`Self::QUERY_TIMEOUT`], or the engine drops the reply sender.
    pub async fn query_harness_cycles(&self) -> Result<serde_json::Value> {
        self.query_with_timeout(
            |reply| Op::QueryHarnessCycles { reply },
            "harness cycles query",
        )
        .await
    }

    /// Sends a truncate-before-last-user-message op and returns the engine's
    /// boolean reply.
    ///
    /// # Errors
    /// Fails if the op cannot be sent or the engine drops the reply sender.
    pub async fn truncate_before_last_user_message(&self) -> Result<bool> {
        let (reply, rx) = oneshot::channel();
        self.send(Op::TruncateBeforeLastUserMessage { reply })
            .await?;
        // Unlike the read-only queries, this waits without a timeout.
        // NOTE(review): presumably intentional, since a timed-out truncation
        // could still be applied afterwards and leave the caller with a wrong
        // answer — confirm.
        rx.await
            .map_err(|_| anyhow::anyhow!("engine dropped truncate-before-last-user reply"))
    }
}
/// Implements [`TurnEnginePort`] by translating port calls into engine ops
/// and cancellation requests on the handle.
#[async_trait]
impl<P, R> TurnEnginePort for EngineHandle<P, R>
where
    P: Send + Sync + 'static,
    R: Send + Sync + 'static,
{
    /// Validates `params` and enqueues them as an `Op::SendMessage`.
    ///
    /// # Errors
    /// Returns the validation failure (converted with `anyhow::Error::msg`) if
    /// `params` is invalid, or an error if the operation channel is closed.
    async fn start_turn(&self, params: StartTurnParams) -> Result<()> {
        params.validate().map_err(anyhow::Error::msg)?;
        self.send(Op::SendMessage {
            content: params.prompt,
            // Borrow the mode setting; fixes the mis-encoded `¶ms.mode`.
            mode: TurnLoopMode::from_setting(&params.mode),
            model: params.model,
            // Turns started through this port never carry a goal objective.
            goal_objective: None,
            reasoning_effort: params.reasoning_effort,
            reasoning_effort_auto: params.reasoning_effort_auto,
            auto_model: params.auto_model,
            allow_shell: params.allow_shell,
            trust_mode: params.trust_mode,
            auto_approve: params.auto_approve,
            approval_mode: params.approval_mode,
            temperature: params.temperature,
            top_p: params.top_p,
            max_output_tokens: params.max_output_tokens,
        })
        .await
    }

    /// Cancels through the handle's shared cancellation token.
    fn cancel_active_turn(&self) {
        self.cancel();
    }
}