use super::BinaryMiddleware;
use crate::llms::AiModel;
use crate::middleware::{AgentMiddleware, MiddlewareStage, pipeline::MiddlewarePipeline};
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
/// Serde default for a binary middleware's `timeout_secs`: 30 seconds.
fn default_timeout() -> u64 {
    const DEFAULT_SECS: u64 = 30;
    DEFAULT_SECS
}
/// Middleware configuration, grouped by the hook each list of entries belongs to.
///
/// Every hook list may be omitted when deserializing; a missing list is taken from
/// [`Default`] (i.e. empty) via the container-level `#[serde(default)]`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct MiddlewareConfig {
    /// Entries for the `before_release` hook.
    pub before_release: Vec<MiddlewareEntry>,
    /// Entries for the `on_provider_response` hook.
    pub on_provider_response: Vec<MiddlewareEntry>,
    /// Entries for the `before_prompt` hook.
    pub before_prompt: Vec<MiddlewareEntry>,
    /// Entries for the `on_completion` hook.
    pub on_completion: Vec<MiddlewareEntry>,
    /// Model handed to builtin middleware when they are constructed. Never read from the
    /// config file; install it with `with_moderation_model`.
    #[serde(skip)]
    pub moderation_model: Option<Arc<dyn AiModel>>,
}
impl MiddlewareConfig {
    /// Returns this configuration with `model` installed as the moderation model,
    /// replacing any model that was set before.
    pub fn with_moderation_model(self, model: Arc<dyn AiModel>) -> Self {
        Self {
            moderation_model: Some(model),
            ..self
        }
    }
}
/// One configured middleware. The variant is chosen by which key is present:
/// `builtin`, `dylib` or `binary`.
///
/// Because the enum is `#[serde(untagged)]`, serde tries the variants in declaration order
/// and keeps the first one that deserializes. Keys a variant does not declare are ignored,
/// not rejected, because `deny_unknown_fields` is not set.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum MiddlewareEntry {
    /// A middleware implemented in this crate, selected by name.
    Builtin {
        builtin: BuiltinMiddlewareType,
        /// Overrides the hook's default stages when present.
        #[serde(default)]
        stages: Option<Vec<MiddlewareStage>>,
        /// Raw, builtin-specific settings; `Value::Null` when omitted.
        #[serde(default)]
        config: serde_json::Value,
    },
    /// A middleware loaded from the dynamic library at `dylib`.
    Dylib {
        dylib: PathBuf,
        /// Overrides the hook's default stages when present.
        #[serde(default)]
        stages: Option<Vec<MiddlewareStage>>,
    },
    /// An external executable at `binary`, wrapped in `BinaryMiddleware` together with `args`.
    Binary {
        binary: PathBuf,
        #[serde(default)]
        args: Vec<String>,
        /// Timeout in seconds; defaults to `default_timeout()` (30).
        #[serde(default = "default_timeout")]
        timeout_secs: u64,
        /// Overrides the hook's default stages when present.
        #[serde(default)]
        stages: Option<Vec<MiddlewareStage>>,
    },
}
/// Names of the middleware implemented in this crate. In config files each name is
/// written in snake_case, e.g. `rule_based`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BuiltinMiddlewareType {
    /// Spelled `signature_verification` in config.
    SignatureVerification,
    /// Spelled `rule_based` in config.
    RuleBased,
    /// Spelled `llm_moderation` in config.
    LlmModeration,
    /// Spelled `prompt_exposure` in config.
    PromptExposure,
}
impl MiddlewareConfig {
    /// Builds the pipeline for the `before_release` hook.
    ///
    /// Entries without an explicit `stages` list run at [`MiddlewareStage::Edit`] and
    /// [`MiddlewareStage::Release`].
    pub fn build_before_release_pipeline(&self) -> MiddlewarePipeline {
        self.build_pipeline(
            &self.before_release,
            &[MiddlewareStage::Edit, MiddlewareStage::Release],
        )
    }

    /// Builds the pipeline for the `on_provider_response` hook.
    ///
    /// Entries without an explicit `stages` list run at [`MiddlewareStage::ProviderResponse`].
    pub fn build_provider_response_pipeline(&self) -> MiddlewarePipeline {
        self.build_pipeline(
            &self.on_provider_response,
            &[MiddlewareStage::ProviderResponse],
        )
    }

    /// Builds the pipeline for the `before_prompt` hook.
    ///
    /// Entries without an explicit `stages` list run at [`MiddlewareStage::BeforePrompt`].
    pub fn build_before_prompt_pipeline(&self) -> MiddlewarePipeline {
        self.build_pipeline(&self.before_prompt, &[MiddlewareStage::BeforePrompt])
    }

    /// Builds the pipeline for the `on_completion` hook.
    ///
    /// Entries without an explicit `stages` list run at [`MiddlewareStage::Completion`].
    pub fn build_completion_pipeline(&self) -> MiddlewarePipeline {
        self.build_pipeline(&self.on_completion, &[MiddlewareStage::Completion])
    }

    /// Returns `true` when none of the four hooks has a configured entry.
    ///
    /// `moderation_model` is not considered here.
    pub fn is_empty(&self) -> bool {
        self.before_release.is_empty()
            && self.on_provider_response.is_empty()
            && self.before_prompt.is_empty()
            && self.on_completion.is_empty()
    }

    /// Builds a pipeline from `entries`, keeping their configuration order.
    ///
    /// An entry that fails to build is logged by [`Self::build_entry`] and left out, so the
    /// pipeline can hold fewer middleware than `entries`.
    // NOTE(review): this block does not stop anything on a failed entry. A caller that must
    // fail closed would have to compare the pipeline length with `entries.len()` — confirm
    // that such a check exists.
    fn build_pipeline(
        &self,
        entries: &[MiddlewareEntry],
        default_stages: &[MiddlewareStage],
    ) -> MiddlewarePipeline {
        let middleware: Vec<Box<dyn AgentMiddleware>> = entries
            .iter()
            .filter_map(|entry| self.build_entry(entry, default_stages))
            .collect();
        MiddlewarePipeline::new(middleware)
    }

    /// Returns an entry's explicit stage list when it has one, otherwise a copy of
    /// `default_stages`. An explicit empty list is kept as empty.
    fn resolve_stages(
        stages: Option<&[MiddlewareStage]>,
        default_stages: &[MiddlewareStage],
    ) -> Vec<MiddlewareStage> {
        stages.unwrap_or(default_stages).to_vec()
    }

    /// Instantiates one configured entry.
    ///
    /// Returns `None` after logging an error if a builtin or dylib entry cannot be created.
    /// Binary entries always succeed here; this block never checks the executable itself.
    fn build_entry(
        &self,
        entry: &MiddlewareEntry,
        default_stages: &[MiddlewareStage],
    ) -> Option<Box<dyn AgentMiddleware>> {
        match entry {
            MiddlewareEntry::Builtin {
                builtin,
                stages,
                config,
            } => {
                let active_stages = Self::resolve_stages(stages.as_deref(), default_stages);
                match super::builtin::create_builtin_middleware(
                    builtin,
                    config,
                    active_stages,
                    self.moderation_model.clone(),
                ) {
                    Ok(mw) => {
                        tracing::info!(
                            builtin_type = ?builtin,
                            "Loaded builtin middleware"
                        );
                        Some(mw)
                    }
                    Err(e) => {
                        tracing::error!(
                            builtin_type = ?builtin,
                            error = %e,
                            "Failed to create builtin middleware"
                        );
                        None
                    }
                }
            }
            MiddlewareEntry::Dylib { dylib, stages } => {
                let active_stages = Self::resolve_stages(stages.as_deref(), default_stages);
                // SAFETY: NOTE(review) — the contract of `DylibMiddleware::load` is not visible
                // here. It is presumably `unsafe` because loading a shared library runs native
                // code that Rust cannot check. This block only forwards a path taken from
                // operator-supplied configuration. Document the real invariant on `load` and
                // restate it here.
                let loaded = unsafe { super::DylibMiddleware::load(dylib, active_stages) };
                match loaded {
                    Ok(mw) => {
                        tracing::info!(
                            dylib = ?dylib,
                            "Loaded dynamic library middleware"
                        );
                        Some(Box::new(mw))
                    }
                    Err(e) => {
                        // NOTE(review): the message promises a refusal to start, but this block
                        // only drops the entry; the refusal must be enforced by a caller.
                        tracing::error!(
                            dylib = ?dylib,
                            error = %e,
                            "Failed to load dynamic library middleware — agent will refuse to start. \
                             Fix the dylib path or remove from config."
                        );
                        None
                    }
                }
            }
            MiddlewareEntry::Binary {
                binary,
                args,
                timeout_secs,
                stages,
            } => {
                let active_stages = Self::resolve_stages(stages.as_deref(), default_stages);
                // Non-UTF-8 or missing file names fall back to a generic display name.
                let display_name = binary
                    .file_name()
                    .and_then(|n| n.to_str())
                    .unwrap_or("binary")
                    .to_string();
                Some(Box::new(BinaryMiddleware {
                    display_name,
                    path: binary.clone(),
                    args: args.clone(),
                    timeout: Duration::from_secs(*timeout_secs),
                    active_stages,
                }))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a YAML fixture, panicking with the parse error if it is malformed.
    fn parse(yaml: &str) -> MiddlewareConfig {
        serde_yaml::from_str(yaml).unwrap()
    }

    #[test]
    fn config_deserialize_empty() {
        let config = parse("{}");
        assert!(config.is_empty());
    }

    #[test]
    fn config_deserialize_binary_entry() {
        // Entry fields must be nested under the `- binary:` item. At column 0 they would be
        // parsed as unknown top-level keys.
        let yaml = r#"
before_release:
  - binary: ./hooks/moderate
    timeout_secs: 15
    stages: [release]
"#;
        let config = parse(yaml);
        assert_eq!(config.before_release.len(), 1);
        match &config.before_release[0] {
            MiddlewareEntry::Binary {
                binary,
                timeout_secs,
                stages,
                ..
            } => {
                assert_eq!(binary, &PathBuf::from("./hooks/moderate"));
                assert_eq!(*timeout_secs, 15);
                assert_eq!(stages.as_ref().unwrap(), &[MiddlewareStage::Release]);
            }
            _ => panic!("Expected Binary entry"),
        }
    }

    #[test]
    fn config_deserialize_builtin_entry() {
        let yaml = r#"
before_release:
  - builtin: rule_based
    stages: [edit, release]
    config:
      max_content_length: 50000
      pii_patterns: true
"#;
        let config = parse(yaml);
        assert_eq!(config.before_release.len(), 1);
        match &config.before_release[0] {
            MiddlewareEntry::Builtin {
                builtin,
                stages,
                config: builtin_config,
            } => {
                assert!(matches!(builtin, BuiltinMiddlewareType::RuleBased));
                assert_eq!(
                    stages.as_ref().unwrap(),
                    &[MiddlewareStage::Edit, MiddlewareStage::Release]
                );
                assert_eq!(builtin_config["max_content_length"], 50000);
                assert_eq!(builtin_config["pii_patterns"], true);
            }
            _ => panic!("Expected Builtin entry"),
        }
    }

    #[test]
    fn config_deserialize_mixed() {
        let yaml = r#"
before_release:
  - builtin: signature_verification
    stages: [release]
  - binary: ./hooks/moderate
    timeout_secs: 30
    stages: [release]
on_provider_response:
  - binary: ./hooks/transform
    timeout_secs: 10
"#;
        let config = parse(yaml);
        assert_eq!(config.before_release.len(), 2);
        assert_eq!(config.on_provider_response.len(), 1);
        assert!(!config.is_empty());
    }

    #[test]
    fn config_default_timeout() {
        let yaml = r#"
before_release:
  - binary: ./hooks/check
"#;
        let config = parse(yaml);
        match &config.before_release[0] {
            MiddlewareEntry::Binary { timeout_secs, .. } => {
                assert_eq!(*timeout_secs, 30);
            }
            _ => panic!("Expected Binary entry"),
        }
    }

    #[test]
    fn build_binary_pipeline() {
        let yaml = r#"
before_release:
  - binary: /bin/true
    timeout_secs: 5
"#;
        let config = parse(yaml);
        let pipeline = config.build_before_release_pipeline();
        assert_eq!(pipeline.len(), 1);
        assert!(!pipeline.is_empty());
    }

    #[test]
    fn build_builtin_pipeline_creates_rule_based() {
        let yaml = r#"
before_release:
  - builtin: rule_based
    config: {}
"#;
        let config = parse(yaml);
        let pipeline = config.build_before_release_pipeline();
        assert_eq!(pipeline.len(), 1);
    }
}