use std::future::Future;
use std::pin::Pin;
use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
use zeph_tools::ToolError;
use zeph_tools::executor::{ToolCall, ToolOutput};
/// Value produced by [`RuntimeLayer::before_tool`].
///
/// `None` means the layer has nothing to contribute; the trait's default hook
/// returns it. `Some(outcome)` carries a tool outcome of the same shape that
/// [`RuntimeLayer::after_tool`] receives: `Ok(Some(output))`, `Ok(None)`, or
/// `Err(error)`. NOTE(review): presumably the dispatcher uses `Some(outcome)`
/// in place of actually running the tool — confirm against the caller.
pub type BeforeToolResult = Option<std::result::Result<Option<ToolOutput>, ToolError>>;
/// Per-call context passed by reference to every [`RuntimeLayer`] hook.
///
/// It only borrows the conversation id and holds a plain counter. That makes it
/// trivially `Copy`, and `Default` yields the "no conversation, turn 0" context.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the current conversation, or `None` when the caller has none.
    pub conversation_id: Option<&'a str>,
    /// Turn counter supplied by the caller.
    /// NOTE(review): whether numbering starts at 0 or 1 is decided by the caller — confirm.
    pub turn_number: u32,
}
/// Hook points around chat requests and tool calls.
///
/// Every hook has a default body that does nothing, so an implementor overrides
/// only the hooks it needs. Hooks return `Pin<Box<dyn Future + Send>>` rather
/// than being `async fn`s, which keeps the trait dyn-compatible (usable as
/// `dyn RuntimeLayer`). Implementors must be `Send + Sync`.
pub trait RuntimeLayer: Send + Sync {
    /// Chat hook that sees the outgoing `messages` and the offered `tools`,
    /// before any response exists.
    ///
    /// Resolving to `Some(response)` lets the layer supply the response itself
    /// (a short-circuit). NOTE(review): that the caller then skips the provider
    /// is assumed from the design — confirm in the dispatcher.
    ///
    /// The default resolves to `None` immediately.
    fn before_chat<'a>(
        &'a self,
        _ctx: &'a LayerContext<'_>,
        _messages: &'a [Message],
        _tools: &'a [ToolDefinition],
    ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
        Box::pin(async { None })
    }

    /// Chat hook that observes the `response` for the current request.
    ///
    /// The default completes immediately without doing anything.
    fn after_chat<'a>(
        &'a self,
        _ctx: &'a LayerContext<'_>,
        _response: &'a ChatResponse,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async {})
    }

    /// Tool hook that sees a pending tool `call`.
    ///
    /// See [`BeforeToolResult`] for the meaning of the resolved value.
    /// The default resolves to `None` immediately.
    fn before_tool<'a>(
        &'a self,
        _ctx: &'a LayerContext<'_>,
        _call: &'a ToolCall,
    ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
        Box::pin(async { None })
    }

    /// Tool hook that observes a tool `call` together with its outcome
    /// (`Ok(Some(output))`, `Ok(None)`, or `Err(error)`).
    ///
    /// The default completes immediately without doing anything.
    fn after_tool<'a>(
        &'a self,
        _ctx: &'a LayerContext<'_>,
        _call: &'a ToolCall,
        _result: &'a Result<Option<ToolOutput>, ToolError>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async {})
    }
}
/// A [`RuntimeLayer`] that overrides no hooks.
///
/// Every hook falls back to the trait's default: the `before_*` hooks resolve
/// to `None` and the `after_*` hooks resolve to `()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopLayer;

impl RuntimeLayer for NoopLayer {}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;
    use zeph_llm::provider::Role;

    /// Context with no conversation id at turn zero — what most tests need.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A `shell` tool call with no parameters.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Counts invocations of the two chat hooks.
    #[derive(Default)]
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    /// Answers every chat request itself with a fixed text response.
    struct ShortCircuitLayer;

    impl RuntimeLayer for ShortCircuitLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            Box::pin(std::future::ready(Some(ChatResponse::Text(
                "short-circuited".into(),
            ))))
        }
    }

    /// Stores the text of the last `ChatResponse::Text` seen by `after_chat`.
    struct CapturingAfter {
        captured: Arc<Mutex<Option<String>>>,
    }

    impl RuntimeLayer for CapturingAfter {
        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            if let ChatResponse::Text(t) = response {
                *self.captured.lock().unwrap() = Some(t.clone());
            }
            Box::pin(std::future::ready(()))
        }
    }

    /// Appends `<hook>_<id>` to a shared log from every hook, so tests can
    /// check the order in which hooks ran across several layers.
    struct RecordingLayer {
        id: u32,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingLayer {
        fn new(id: u32, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                id,
                log: Arc::clone(log),
            }
        }

        // The guard is dropped before the hook returns its future, so no lock
        // is ever held across an await point.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{hook}_{}", self.id));
        }
    }

    impl RuntimeLayer for RecordingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.record("before");
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after");
            Box::pin(std::future::ready(()))
        }

        fn before_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
        ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
            self.record("before_tool");
            Box::pin(std::future::ready(None))
        }

        fn after_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
            _result: &'a Result<Option<ToolOutput>, ToolError>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after_tool");
            Box::pin(std::future::ready(()))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}
        assert_runtime_layer::<NoopLayer>();
        // The trait must stay dyn-compatible so layers can be held as trait objects.
        let _boxed: Box<dyn RuntimeLayer> = Box::new(NoopLayer);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_chat_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());
        layer.after_chat(&ctx, &resp).await;
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer.after_tool(&ctx, &call, &result).await;
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::default();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());
        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;
        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let observer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let short_circuit = ShortCircuitLayer;
        let ctx = empty_ctx();
        // Feed the response actually produced by a short-circuiting layer into
        // after_chat, instead of a hand-built stand-in.
        let sc_response = short_circuit
            .before_chat(&ctx, &[], &[])
            .await
            .expect("short-circuit layer must produce a response");
        observer.after_chat(&ctx, &sc_response).await;
        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuited"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = RecordingLayer::new(1, &log);
        let layer_b = RecordingLayer::new(2, &log);
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());
        layer_a.before_chat(&ctx, &[], &[]).await;
        layer_b.before_chat(&ctx, &[], &[]).await;
        layer_a.after_chat(&ctx, &resp).await;
        layer_b.after_chat(&ctx, &resp).await;
        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = RecordingLayer::new(1, &log);
        let layer_b = RecordingLayer::new(2, &log);
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer_a.before_tool(&ctx, &call).await;
        layer_b.before_tool(&ctx, &call).await;
        layer_a.after_tool(&ctx, &call, &result).await;
        layer_b.after_tool(&ctx, &call, &result).await;
        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }
}