use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use futures::future::BoxFuture;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use entelix_core::CurrentToolInvocation;
use entelix_core::LlmRenderable;
use entelix_core::error::{Error, Result};
use entelix_core::service::ToolInvocation;
use crate::agent::event::AgentEvent;
use crate::agent::sink::AgentEventSink;
/// Tower layer that wraps a tool-invocation service in a [`ToolEventService`].
///
/// The layer only holds the shared event sink; every service it produces gets its own
/// handle (an `Arc` clone) to that same sink.
pub struct ToolEventLayer<S: Clone + Send + Sync + 'static> {
    /// Destination for the tool lifecycle events emitted by the wrapped services.
    sink: Arc<dyn AgentEventSink<S>>,
}
impl<S> ToolEventLayer<S>
where
S: Clone + Send + Sync + 'static,
{
pub const NAME: &'static str = "tool_event";
#[must_use]
pub fn new(sink: Arc<dyn AgentEventSink<S>>) -> Self {
Self { sink }
}
}
impl<S> Clone for ToolEventLayer<S>
where
S: Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
sink: Arc::clone(&self.sink),
}
}
}
/// Debug output names the type only; the trait-object sink is deliberately not printed.
impl<S: Clone + Send + Sync + 'static> std::fmt::Debug for ToolEventLayer<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ToolEventLayer");
        builder.finish_non_exhaustive()
    }
}
/// Wraps any inner service in a [`ToolEventService`] that shares this layer's sink.
impl<S: Clone + Send + Sync + 'static, Inner> Layer<Inner> for ToolEventLayer<S> {
    type Service = ToolEventService<S, Inner>;

    fn layer(&self, inner: Inner) -> Self::Service {
        let sink = Arc::clone(&self.sink);
        ToolEventService { inner, sink }
    }
}
/// Exposes [`ToolEventLayer::NAME`] as this layer's registered name.
impl<S: Clone + Send + Sync + 'static> entelix_core::NamedLayer for ToolEventLayer<S> {
    fn layer_name(&self) -> &'static str {
        <ToolEventLayer<S>>::NAME
    }
}
/// Service produced by [`ToolEventLayer`].
///
/// It forwards each tool invocation to `inner` and reports the invocation's start and
/// outcome to `sink`.
pub struct ToolEventService<S: Clone + Send + Sync + 'static, Inner> {
    /// The wrapped tool-invocation service.
    inner: Inner,
    /// Shared destination for tool lifecycle events.
    sink: Arc<dyn AgentEventSink<S>>,
}
impl<S, Inner: Clone> Clone for ToolEventService<S, Inner>
where
S: Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
sink: Arc::clone(&self.sink),
}
}
}
/// Forwards tool invocations to the inner service while reporting lifecycle events.
///
/// The events (`ToolStart` before the call, then `ToolComplete` or `ToolError` after it)
/// are sent only when the invocation context carries a run id. Sink delivery failures are
/// ignored, so event reporting never changes the tool's result.
impl<S, Inner> Service<ToolInvocation> for ToolEventService<S, Inner>
where
    S: Clone + Send + Sync + 'static,
    Inner: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
{
    type Response = Value;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Value>>;

    /// Readiness is exactly the inner service's readiness.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the tool and returns the inner service's result unchanged.
    ///
    /// Before the call, a `CurrentToolInvocation` marker is attached to the context when
    /// one can be built. The marker uses the tool-use id, or the tool name if the id is
    /// empty.
    fn call(&mut self, mut invocation: ToolInvocation) -> Self::Future {
        // `poll_ready` drove *this* `self.inner` to readiness. Move that readied instance
        // into the future and leave a fresh clone behind for the next request. Calling a
        // clone instead would ignore the readiness (and any permit or buffer slot it
        // reserved), which can stall capacity-limited inner services.
        let fresh = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, fresh);
        let sink = Arc::clone(&self.sink);
        Box::pin(async move {
            let run_id = invocation.ctx.run_id().map(str::to_owned);
            let tenant_id = invocation.ctx.tenant_id().clone();
            let tool = invocation.metadata.name.clone();
            let tool_version = invocation.metadata.version.clone();
            let tool_use_id = invocation.tool_use_id.clone();

            if let Some(rid) = &run_id {
                // Clone the input only when a start event is actually emitted.
                let _ = sink
                    .send(AgentEvent::ToolStart {
                        run_id: rid.clone(),
                        tenant_id: tenant_id.clone(),
                        tool_use_id: tool_use_id.clone(),
                        tool: tool.clone(),
                        tool_version: tool_version.clone(),
                        input: invocation.input.clone(),
                    })
                    .await;
            }

            // Fall back to the tool name when the tool-use id is empty.
            let marker_use_id = if tool_use_id.is_empty() {
                tool.clone()
            } else {
                tool_use_id.clone()
            };
            // Best-effort: if the marker cannot be constructed, the call proceeds without it.
            if let Ok(marker) = CurrentToolInvocation::new(marker_use_id, tool.clone()) {
                invocation.ctx = invocation.ctx.clone().add_extension(marker);
            }

            let started_at = Instant::now();
            let result = inner.oneshot(invocation).await;
            // Saturate rather than truncate if the elapsed millis exceed u64.
            let duration_ms = u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX);

            match (&result, run_id) {
                (Ok(output), Some(rid)) => {
                    let _ = sink
                        .send(AgentEvent::ToolComplete {
                            run_id: rid,
                            tenant_id,
                            tool_use_id,
                            tool,
                            tool_version,
                            duration_ms,
                            output: output.clone(),
                        })
                        .await;
                }
                (Err(err), Some(rid)) => {
                    let envelope = err.envelope();
                    let _ = sink
                        .send(AgentEvent::ToolError {
                            run_id: rid,
                            tenant_id,
                            tool_use_id,
                            tool,
                            tool_version,
                            error: err.to_string(),
                            error_for_llm: err.for_llm(),
                            envelope,
                            duration_ms,
                        })
                        .await;
                }
                // No run id: nothing is reported.
                _ => {}
            }
            result
        })
    }
}