use async_trait::async_trait;
use futures::future::BoxFuture;
use oharness_core::{
CompletionRequest, CompletionResponse, Content, LlmCapabilities, ModelId, StopReason, Task,
Usage,
};
use oharness_llm::{ChunkStream, FullLayer, Llm, LlmError, LlmExt, RequestLayer, ResponseLayer};
use oharness_loop::{Agent, ReactLoop};
use oharness_tools::fs::FsToolSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
/// Request layer that tags every outgoing request with a sequential id.
///
/// `counter` holds the next id to hand out; a fresh stamper starts at zero.
struct RequestIdStamp {
    counter: AtomicU64,
}

impl RequestIdStamp {
    /// Creates a stamper whose first issued id will be `0`.
    fn new() -> Self {
        let counter = AtomicU64::default();
        Self { counter }
    }
}
impl RequestLayer for RequestIdStamp {
    /// Stamps `req` with the next sequential id (`req-0`, `req-1`, ...).
    ///
    /// The id is stored as a JSON string under the `example.request_id`
    /// extension key and then reported on stderr.
    fn on_request(&self, req: &mut CompletionRequest) {
        // fetch_add is atomic, so two concurrent calls can never receive the same id.
        let seq = self.counter.fetch_add(1, Ordering::SeqCst);
        let tag = format!("req-{seq}");
        req.extensions
            .insert("example.request_id".into(), serde_json::Value::String(tag.clone()));
        eprintln!("[RequestIdStamp] stamped {tag}");
    }
}
/// Response layer that scrubs live API keys (`sk-live-...`) out of text blocks.
///
/// Only the key material after the `sk-live-` marker is replaced, so a reader
/// can still tell that a key was present. For example, `sk-live-1234abc`
/// becomes `sk-live-REDACTED`. Non-text content blocks are left untouched.
struct RedactSecrets;

impl RedactSecrets {
    /// Marker that introduces a live secret key.
    const SECRET_PREFIX: &'static str = "sk-live-";
    /// Text written in place of the key material.
    const PLACEHOLDER: &'static str = "REDACTED";

    /// Returns a copy of `text` in which the key body after every
    /// `SECRET_PREFIX` is replaced by `PLACEHOLDER`.
    ///
    /// Returns `None` when `text` contains no marker, so untouched blocks are
    /// neither reallocated nor logged.
    ///
    /// The key body is taken to be the run of ASCII alphanumerics and `_`
    /// directly after the marker. This is an assumption about the key format.
    fn redact(text: &str) -> Option<String> {
        if !text.contains(Self::SECRET_PREFIX) {
            return None;
        }
        let mut out = String::with_capacity(text.len());
        let mut rest = text;
        while let Some(pos) = rest.find(Self::SECRET_PREFIX) {
            // The marker is ASCII, so `key_start` is always a char boundary.
            let key_start = pos + Self::SECRET_PREFIX.len();
            out.push_str(&rest[..key_start]);
            let tail = &rest[key_start..];
            let key_len = tail
                .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
                .unwrap_or(tail.len());
            out.push_str(Self::PLACEHOLDER);
            rest = &tail[key_len..];
        }
        out.push_str(rest);
        Some(out)
    }
}

impl ResponseLayer for RedactSecrets {
    /// Redacts live keys in every text block of `res` in place.
    ///
    /// Logs one stderr line per block that was changed.
    fn on_response(&self, res: &mut CompletionResponse) {
        for block in res.content.iter_mut() {
            if let Content::Text { text } = block {
                if let Some(redacted) = Self::redact(text.as_str()) {
                    *text = redacted;
                    eprintln!("[RedactSecrets] redacted secret in response");
                }
            }
        }
    }
}
/// Full layer that reports the wall-clock duration of each `complete()` call
/// on stderr.
///
/// Streaming calls are forwarded without any timing or logging.
struct Timer;

#[async_trait]
impl FullLayer for Timer {
    /// Awaits the wrapped `complete()` future and logs how long it took.
    ///
    /// The result, success or error, is passed through unchanged.
    async fn around_complete<'a>(
        &'a self,
        _req: CompletionRequest,
        call: BoxFuture<'a, Result<CompletionResponse, LlmError>>,
    ) -> Result<CompletionResponse, LlmError> {
        let started_at = Instant::now();
        eprintln!("[Timer] complete() starting");
        let outcome = call.await;
        let took = started_at.elapsed();
        eprintln!("[Timer] complete() finished in {took:?}");
        outcome
    }

    /// Pass-through: awaits the wrapped stream call and returns its result.
    async fn around_stream<'a>(
        &'a self,
        _req: CompletionRequest,
        call: BoxFuture<'a, Result<ChunkStream, LlmError>>,
    ) -> Result<ChunkStream, LlmError> {
        let stream = call.await?;
        Ok(stream)
    }
}
/// Deterministic stand-in LLM.
///
/// Every `complete()` call returns the same canned reply, whose text embeds a
/// fake `sk-live-` key. Streaming is reported as unsupported.
struct ScriptedLlm;

#[async_trait]
impl Llm for ScriptedLlm {
    fn name(&self) -> &str {
        "scripted"
    }

    /// Advertises the default capability set.
    fn capabilities(&self) -> LlmCapabilities {
        Default::default()
    }

    /// Ignores the request and returns the fixed scripted response.
    async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
        let reply = Content::text(
            "All set. My API key is sk-live-1234567890abc — please keep it safe.",
        );
        let usage = Usage {
            tokens_input: 7,
            tokens_output: 20,
            ..Default::default()
        };
        let response = CompletionResponse {
            id: "msg_1".into(),
            model: ModelId::new("middleware-example"),
            content: vec![reply],
            stop_reason: StopReason::EndTurn,
            usage,
        };
        Ok(response)
    }

    /// Streaming is not scripted; always fails with `LlmError::Unsupported`.
    async fn stream(&self, _req: CompletionRequest) -> Result<ChunkStream, LlmError> {
        Err(LlmError::Unsupported("stream"))
    }
}
/// Entry point for the middleware example.
///
/// Wraps `ScriptedLlm` in three middleware layers and runs a single agent task
/// through a `ReactLoop`. It then prints the assistant's reply (after the
/// response layers have run) and the termination reason.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The request and response layers mutate the request and response in
    // place. The full layer (`Timer`) wraps the whole `complete()` call.
    let llm = ScriptedLlm
        .with_request_layer(RequestIdStamp::new())
        .with_response_layer(RedactSecrets)
        .with_full_layer(Timer);

    let agent = Agent::builder()
        .with_llm(Arc::new(llm))
        .with_tools(Arc::new(FsToolSet::new()))
        .with_loop(Box::new(ReactLoop::new()))
        .with_max_turns(2)
        .build()?;

    let outcome = agent.run(Task::new("test middleware composition")).await?;

    // Scan backwards for the most recent assistant turn rather than assuming
    // it is the last message. If the run ends on a non-assistant message
    // (possibly when `max_turns` stops the loop early; not verified here),
    // checking only `.last()` would print no reply at all.
    let last_assistant = outcome
        .final_messages
        .iter()
        .rev()
        .find_map(|msg| match msg {
            oharness_core::Message::Assistant { content, .. } => Some(content),
            _ => None,
        });

    match last_assistant {
        Some(content) => {
            for c in content {
                if let Content::Text { text } = c {
                    println!("Assistant (post-redaction): {text}");
                }
            }
        }
        None => println!("Assistant (post-redaction): <no assistant message>"),
    }

    println!("Termination: {:?}", outcome.termination);
    Ok(())
}