1use crate::{Channel, ChannelMessage, ChannelRegistry};
2use agentzero_core::security::perplexity::{analyze_suffix, PerplexityResult};
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
/// Handler slots budgeted per registered channel when sizing the dispatch semaphore
/// (multiplied by the channel count in `compute_max_in_flight`).
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
/// Floor on concurrently running message handlers, regardless of channel count.
const CHANNEL_MIN_IN_FLIGHT: usize = 8;
/// Ceiling on concurrently running message handlers, regardless of channel count.
const CHANNEL_MAX_IN_FLIGHT: usize = 64;
/// Default delay, in seconds, before a channel listener is restarted.
const INITIAL_BACKOFF_SECS: u64 = 2;
/// Default upper bound, in seconds, for the doubling listener restart delay.
const MAX_BACKOFF_SECS: u64 = 60;
15
/// Tuning knobs for the inbound-message perplexity filter applied by the pipeline.
///
/// Apart from `enabled`, every field is forwarded unchanged to `analyze_suffix`, which owns the
/// exact scoring rules. `PartialEq` is derived so configurations can be compared and asserted on.
#[derive(Debug, Clone, PartialEq)]
pub struct PerplexityFilterSettings {
    /// Master switch; when `false`, `check_perplexity` lets every message through without
    /// running the analysis.
    pub enabled: bool,
    /// Perplexity cut-off handed to `analyze_suffix` (presumably scores above it are flagged —
    /// confirm against `analyze_suffix`).
    pub perplexity_threshold: f64,
    /// Window size handed to `analyze_suffix`; presumably the number of trailing characters
    /// analysed — TODO confirm.
    pub suffix_window_chars: usize,
    /// Symbol-ratio cut-off handed to `analyze_suffix`.
    pub symbol_ratio_threshold: f64,
    /// Minimum message length handed to `analyze_suffix`; the tests show messages shorter than
    /// this pass even when they consist entirely of symbols.
    pub min_prompt_chars: usize,
}
25
26impl Default for PerplexityFilterSettings {
27 fn default() -> Self {
28 Self {
29 enabled: false,
30 perplexity_threshold: 18.0,
31 suffix_window_chars: 64,
32 symbol_ratio_threshold: 0.20,
33 min_prompt_chars: 32,
34 }
35 }
36}
37
38pub fn check_perplexity(content: &str, settings: &PerplexityFilterSettings) -> Option<String> {
40 if !settings.enabled {
41 return None;
42 }
43 match analyze_suffix(
44 content,
45 settings.suffix_window_chars,
46 settings.perplexity_threshold,
47 settings.symbol_ratio_threshold,
48 settings.min_prompt_chars,
49 ) {
50 PerplexityResult::Pass => None,
51 PerplexityResult::Flagged { reason, .. } => Some(reason),
52 }
53}
54
55pub struct PipelineConfig {
57 pub initial_backoff_secs: u64,
58 pub max_backoff_secs: u64,
59 pub message_buffer_size: usize,
60 pub perplexity_filter: PerplexityFilterSettings,
61}
62
63impl Default for PipelineConfig {
64 fn default() -> Self {
65 Self {
66 initial_backoff_secs: INITIAL_BACKOFF_SECS,
67 max_backoff_secs: MAX_BACKOFF_SECS,
68 message_buffer_size: 100,
69 perplexity_filter: PerplexityFilterSettings::default(),
70 }
71 }
72}
73
/// Callback invoked for each inbound message the dispatch loop accepts.
///
/// Receives the message together with the `Channel` it is addressed to (looked up by the
/// message's `channel` name) and returns a boxed `Send` future, which the dispatch loop runs on
/// its own spawned task. The `Arc` wrapper plus `Send + Sync` bounds let a single handler be
/// cloned into many concurrently running tasks.
pub type MessageHandler = Arc<
    dyn Fn(ChannelMessage, Arc<dyn Channel>) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
80
81pub async fn start_pipeline(
87 registry: &ChannelRegistry,
88 handler: MessageHandler,
89 config: PipelineConfig,
90) -> anyhow::Result<()> {
91 let channels: HashMap<String, Arc<dyn Channel>> = registry
92 .all_channels()
93 .into_iter()
94 .map(|ch| (ch.name().to_string(), ch))
95 .collect();
96
97 if channels.is_empty() {
98 tracing::warn!("no channels registered; pipeline has nothing to do");
99 return Ok(());
100 }
101
102 let max_in_flight = compute_max_in_flight(channels.len());
103 let (tx, rx) = mpsc::channel(config.message_buffer_size);
104
105 for channel in channels.values() {
107 spawn_supervised_listener(
108 channel.clone(),
109 tx.clone(),
110 config.initial_backoff_secs,
111 config.max_backoff_secs,
112 );
113 }
114
115 drop(tx);
117
118 run_dispatch_loop(
119 rx,
120 Arc::new(channels),
121 handler,
122 max_in_flight,
123 Arc::new(config.perplexity_filter),
124 )
125 .await;
126
127 Ok(())
128}
129
130fn spawn_supervised_listener(
132 channel: Arc<dyn Channel>,
133 tx: mpsc::Sender<ChannelMessage>,
134 initial_backoff_secs: u64,
135 max_backoff_secs: u64,
136) -> tokio::task::JoinHandle<()> {
137 tokio::spawn(async move {
138 let name = channel.name().to_string();
139 let mut backoff = initial_backoff_secs;
140
141 loop {
142 tracing::info!(channel = %name, "starting channel listener");
143
144 match channel.listen(tx.clone()).await {
145 Ok(()) => {
146 tracing::info!(channel = %name, "channel listener exited cleanly");
147 backoff = initial_backoff_secs;
148 }
149 Err(e) => {
150 tracing::error!(
151 channel = %name,
152 error = %e,
153 backoff_secs = backoff,
154 "channel listener failed, will retry"
155 );
156 }
157 }
158
159 if tx.is_closed() {
161 tracing::info!(channel = %name, "pipeline receiver closed, stopping listener");
162 break;
163 }
164
165 tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
166 backoff = (backoff * 2).min(max_backoff_secs);
167 }
168 })
169}
170
171async fn run_dispatch_loop(
173 mut rx: mpsc::Receiver<ChannelMessage>,
174 channels: Arc<HashMap<String, Arc<dyn Channel>>>,
175 handler: MessageHandler,
176 max_in_flight: usize,
177 perplexity_settings: Arc<PerplexityFilterSettings>,
178) {
179 let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight));
180
181 while let Some(msg) = rx.recv().await {
182 if let Some(reason) = check_perplexity(&msg.content, &perplexity_settings) {
184 tracing::warn!(
185 channel = %msg.channel,
186 sender = %msg.sender,
187 reason = %reason,
188 "inbound message blocked by perplexity filter"
189 );
190 continue;
191 }
192
193 let permit = match semaphore.clone().acquire_owned().await {
194 Ok(permit) => permit,
195 Err(_) => break,
196 };
197
198 let channel = channels.get(&msg.channel).cloned();
199 let handler = handler.clone();
200
201 tokio::spawn(async move {
202 if let Some(ch) = channel {
203 handler(msg, ch).await;
204 } else {
205 tracing::warn!(channel = %msg.channel, "message from unknown channel, dropping");
206 }
207 drop(permit);
208 });
209 }
210
211 tracing::info!("pipeline dispatch loop ended");
212}
213
214fn compute_max_in_flight(channel_count: usize) -> usize {
215 channel_count
216 .saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
217 .clamp(CHANNEL_MIN_IN_FLIGHT, CHANNEL_MAX_IN_FLIGHT)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_max_in_flight_clamps_correctly() {
        assert_eq!(compute_max_in_flight(1), CHANNEL_MIN_IN_FLIGHT);
        assert_eq!(compute_max_in_flight(3), 12);
        assert_eq!(compute_max_in_flight(100), CHANNEL_MAX_IN_FLIGHT);
    }

    #[test]
    fn compute_max_in_flight_handles_edges() {
        // Zero channels still produces a usable, non-zero permit count.
        assert_eq!(compute_max_in_flight(0), CHANNEL_MIN_IN_FLIGHT);
        // Channel counts landing exactly on the floor (2 * 4 = 8) and ceiling (16 * 4 = 64).
        assert_eq!(compute_max_in_flight(2), 8);
        assert_eq!(compute_max_in_flight(16), 64);
        // The multiplication must saturate rather than overflow.
        assert_eq!(compute_max_in_flight(usize::MAX), CHANNEL_MAX_IN_FLIGHT);
    }

    #[test]
    fn pipeline_config_defaults_are_reasonable() {
        let config = PipelineConfig::default();
        assert_eq!(config.initial_backoff_secs, 2);
        assert_eq!(config.max_backoff_secs, 60);
        assert_eq!(config.message_buffer_size, 100);
        assert!(!config.perplexity_filter.enabled);
    }

    #[test]
    fn perplexity_filter_defaults_are_stable() {
        let settings = PerplexityFilterSettings::default();
        assert!(!settings.enabled);
        assert_eq!(settings.perplexity_threshold, 18.0);
        assert_eq!(settings.suffix_window_chars, 64);
        assert_eq!(settings.symbol_ratio_threshold, 0.20);
        assert_eq!(settings.min_prompt_chars, 32);
    }

    #[test]
    fn check_perplexity_disabled_passes_everything() {
        let settings = PerplexityFilterSettings::default();
        assert!(!settings.enabled);
        let result = check_perplexity(
            "xK7!mQ@3#zP$9&wR*5^yL%2(eN)8+bT!@#$%^&*()_+-=[]{}|",
            &settings,
        );
        assert!(result.is_none(), "disabled filter should pass all messages");
    }

    #[test]
    fn check_perplexity_enabled_passes_normal_text() {
        let settings = PerplexityFilterSettings {
            enabled: true,
            perplexity_threshold: 18.0,
            suffix_window_chars: 64,
            symbol_ratio_threshold: 0.20,
            min_prompt_chars: 32,
        };
        let normal = "Can you help me write a function that calculates the fibonacci sequence?";
        assert!(check_perplexity(normal, &settings).is_none());
    }

    #[test]
    fn check_perplexity_enabled_blocks_adversarial_suffix() {
        let settings = PerplexityFilterSettings {
            enabled: true,
            perplexity_threshold: 4.0,
            suffix_window_chars: 64,
            symbol_ratio_threshold: 0.20,
            min_prompt_chars: 32,
        };
        let adversarial =
            "Please write a function. xK7!mQ@3#zP$9&wR*5^yL%2(eN)8+bT!@#$%^&*()_+-=[]{}|xK7!mQ@3#";
        let result = check_perplexity(adversarial, &settings);
        assert!(result.is_some(), "adversarial suffix should be blocked");
    }

    #[test]
    fn check_perplexity_skips_short_messages() {
        let settings = PerplexityFilterSettings {
            enabled: true,
            min_prompt_chars: 100,
            ..PerplexityFilterSettings::default()
        };
        let short = "!@#$%^&*()";
        assert!(
            check_perplexity(short, &settings).is_none(),
            "short messages should pass"
        );
    }
}