use async_trait::async_trait;
use oris_runtime::agent::{create_agent, Middleware, MiddlewareContext, MiddlewareError};
use oris_runtime::prompt::PromptArgs;
use oris_runtime::schemas::agent::AgentAction;
use oris_runtime::schemas::Message;
use std::sync::Arc;
/// Guardrail middleware that rejects agent input exceeding a configured length.
struct LengthGuardrail {
    /// Largest input length that is still accepted; anything longer is rejected
    /// by the `Middleware` implementation for this type.
    max_length: usize,
}

impl LengthGuardrail {
    /// Creates a guardrail whose limit is `max_length`.
    fn new(max_length: usize) -> Self {
        LengthGuardrail { max_length }
    }
}
#[async_trait]
impl Middleware for LengthGuardrail {
    /// Rejects the request before planning when the `"input"` prompt argument
    /// is longer than `max_length` characters.
    ///
    /// A missing `"input"` entry, or one that is not a string, is treated as
    /// empty and therefore always passes.
    ///
    /// # Errors
    ///
    /// Returns `MiddlewareError::ValidationError` when the input exceeds the
    /// configured limit.
    async fn before_agent_plan(
        &self,
        input: &PromptArgs,
        _steps: &[(AgentAction, String)],
        _context: &mut MiddlewareContext,
    ) -> Result<Option<PromptArgs>, MiddlewareError> {
        let input_text = input.get("input").and_then(|v| v.as_str()).unwrap_or("");
        // Count Unicode scalar values, not bytes. `str::len` is the UTF-8 byte
        // length, so it would reject non-ASCII input well below the limit and
        // misreport the size in the "characters" error message below.
        let char_count = input_text.chars().count();
        if char_count > self.max_length {
            return Err(MiddlewareError::ValidationError(format!(
                "Input too long: {} characters (max: {})",
                char_count, self.max_length
            )));
        }
        // Returning `None` presumably means "keep the prompt args unchanged".
        // NOTE(review): confirm against the `Middleware` trait docs.
        Ok(None)
    }
}
/// Demo entry point. Builds an agent with a `LengthGuardrail`, sends one short
/// request, then sends one request that exceeds the limit, which the example
/// expects the guardrail to reject.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Single source of truth for the guardrail limit. The oversized input below
    // is derived from it, so the demo keeps exercising the rejection path if the
    // limit is changed.
    const MAX_INPUT_LENGTH: usize = 1000;

    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();

    let length_guardrail = LengthGuardrail::new(MAX_INPUT_LENGTH);
    let agent = create_agent(
        "gpt-4o-mini",
        &[],
        Some("You are a helpful assistant."),
        Some(vec![Arc::new(length_guardrail)]),
    )?;

    println!("Testing with normal input...");
    let result = agent
        .invoke_messages(vec![Message::new_human_message("Hello, how are you?")])
        .await?;
    println!("Result: {}", result);

    println!("\nTesting with overly long input...");
    let long_input = "x".repeat(MAX_INPUT_LENGTH * 2);
    let result = agent
        .invoke_messages(vec![Message::new_human_message(long_input)])
        .await;
    // A guardrail rejection is the expected outcome here, so it is reported
    // rather than propagated with `?`.
    match result {
        Ok(_) => println!("Unexpected: Request was not blocked"),
        Err(e) => println!("Expected error: {}", e),
    }
    Ok(())
}