Skip to main content

agentzero_channels/
pipeline.rs

1use crate::{Channel, ChannelMessage, ChannelRegistry};
2use agentzero_core::security::perplexity::{analyze_suffix, PerplexityResult};
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
// Concurrency and flow control constants.

/// In-flight handler slots budgeted per registered channel; multiplied by the
/// channel count in `compute_max_in_flight`.
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
/// Lower bound applied to the computed in-flight limit.
const CHANNEL_MIN_IN_FLIGHT: usize = 8;
/// Upper bound applied to the computed in-flight limit.
const CHANNEL_MAX_IN_FLIGHT: usize = 64;
/// Default value of `PipelineConfig::initial_backoff_secs`.
const INITIAL_BACKOFF_SECS: u64 = 2;
/// Default value of `PipelineConfig::max_backoff_secs`.
const MAX_BACKOFF_SECS: u64 = 60;
15
/// Tunables for the perplexity filter that screens inbound messages.
///
/// Apart from `enabled`, every field is forwarded unchanged to
/// `analyze_suffix` by [`check_perplexity`].
#[derive(Debug, Clone)]
pub struct PerplexityFilterSettings {
    /// Master switch; when `false`, [`check_perplexity`] returns `None` without
    /// analysing the message.
    pub enabled: bool,
    /// Perplexity threshold handed to `analyze_suffix`.
    // NOTE(review): presumably the score above which a suffix is flagged — confirm.
    pub perplexity_threshold: f64,
    /// Suffix window size, in characters, handed to `analyze_suffix`.
    pub suffix_window_chars: usize,
    /// Symbol-ratio threshold handed to `analyze_suffix`.
    pub symbol_ratio_threshold: f64,
    /// Minimum prompt length, in characters, handed to `analyze_suffix`.
    // NOTE(review): the tests suggest shorter messages are not flagged — confirm.
    pub min_prompt_chars: usize,
}

impl Default for PerplexityFilterSettings {
    /// Returns the stock configuration: filter disabled, with thresholds ready
    /// for when it is switched on.
    fn default() -> Self {
        PerplexityFilterSettings {
            enabled: false,
            min_prompt_chars: 32,
            suffix_window_chars: 64,
            perplexity_threshold: 18.0,
            symbol_ratio_threshold: 0.20,
        }
    }
}
37
38/// Check a message against the perplexity filter. Returns `Some(reason)` if blocked.
39pub fn check_perplexity(content: &str, settings: &PerplexityFilterSettings) -> Option<String> {
40    if !settings.enabled {
41        return None;
42    }
43    match analyze_suffix(
44        content,
45        settings.suffix_window_chars,
46        settings.perplexity_threshold,
47        settings.symbol_ratio_threshold,
48        settings.min_prompt_chars,
49    ) {
50        PerplexityResult::Pass => None,
51        PerplexityResult::Flagged { reason, .. } => Some(reason),
52    }
53}
54
55/// Configuration for the message processing pipeline.
56pub struct PipelineConfig {
57    pub initial_backoff_secs: u64,
58    pub max_backoff_secs: u64,
59    pub message_buffer_size: usize,
60    pub perplexity_filter: PerplexityFilterSettings,
61}
62
63impl Default for PipelineConfig {
64    fn default() -> Self {
65        Self {
66            initial_backoff_secs: INITIAL_BACKOFF_SECS,
67            max_backoff_secs: MAX_BACKOFF_SECS,
68            message_buffer_size: 100,
69            perplexity_filter: PerplexityFilterSettings::default(),
70        }
71    }
72}
73
/// Callback type for processing incoming channel messages.
///
/// The dispatch loop calls the handler once per dispatched message, passing
/// the message and the channel registered under the message's `channel` name.
/// The handler returns a boxed future resolving to `()`, which is awaited
/// inside a spawned Tokio task — hence the `Send`/`Sync` bounds.
pub type MessageHandler = Arc<
    dyn Fn(ChannelMessage, Arc<dyn Channel>) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
80
81/// Start the message processing pipeline.
82///
83/// 1. Spawn supervised listeners for each registered channel.
84/// 2. Run the dispatch loop with semaphore-bounded concurrency.
85/// 3. For each message, call the handler with the originating channel.
86pub async fn start_pipeline(
87    registry: &ChannelRegistry,
88    handler: MessageHandler,
89    config: PipelineConfig,
90) -> anyhow::Result<()> {
91    let channels: HashMap<String, Arc<dyn Channel>> = registry
92        .all_channels()
93        .into_iter()
94        .map(|ch| (ch.name().to_string(), ch))
95        .collect();
96
97    if channels.is_empty() {
98        tracing::warn!("no channels registered; pipeline has nothing to do");
99        return Ok(());
100    }
101
102    let max_in_flight = compute_max_in_flight(channels.len());
103    let (tx, rx) = mpsc::channel(config.message_buffer_size);
104
105    // Spawn a supervised listener for each channel.
106    for channel in channels.values() {
107        spawn_supervised_listener(
108            channel.clone(),
109            tx.clone(),
110            config.initial_backoff_secs,
111            config.max_backoff_secs,
112        );
113    }
114
115    // Drop our sender copy so the dispatch loop exits when all listeners stop.
116    drop(tx);
117
118    run_dispatch_loop(
119        rx,
120        Arc::new(channels),
121        handler,
122        max_in_flight,
123        Arc::new(config.perplexity_filter),
124    )
125    .await;
126
127    Ok(())
128}
129
130/// Spawn a supervised listener with exponential backoff reconnection.
131fn spawn_supervised_listener(
132    channel: Arc<dyn Channel>,
133    tx: mpsc::Sender<ChannelMessage>,
134    initial_backoff_secs: u64,
135    max_backoff_secs: u64,
136) -> tokio::task::JoinHandle<()> {
137    tokio::spawn(async move {
138        let name = channel.name().to_string();
139        let mut backoff = initial_backoff_secs;
140
141        loop {
142            tracing::info!(channel = %name, "starting channel listener");
143
144            match channel.listen(tx.clone()).await {
145                Ok(()) => {
146                    tracing::info!(channel = %name, "channel listener exited cleanly");
147                    backoff = initial_backoff_secs;
148                }
149                Err(e) => {
150                    tracing::error!(
151                        channel = %name,
152                        error = %e,
153                        backoff_secs = backoff,
154                        "channel listener failed, will retry"
155                    );
156                }
157            }
158
159            // If the receiver is closed, all consumers are gone — stop.
160            if tx.is_closed() {
161                tracing::info!(channel = %name, "pipeline receiver closed, stopping listener");
162                break;
163            }
164
165            tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
166            backoff = (backoff * 2).min(max_backoff_secs);
167        }
168    })
169}
170
171/// Central message dispatch loop with bounded concurrency.
172async fn run_dispatch_loop(
173    mut rx: mpsc::Receiver<ChannelMessage>,
174    channels: Arc<HashMap<String, Arc<dyn Channel>>>,
175    handler: MessageHandler,
176    max_in_flight: usize,
177    perplexity_settings: Arc<PerplexityFilterSettings>,
178) {
179    let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight));
180
181    while let Some(msg) = rx.recv().await {
182        // Perplexity filter: check inbound message content before dispatching.
183        if let Some(reason) = check_perplexity(&msg.content, &perplexity_settings) {
184            tracing::warn!(
185                channel = %msg.channel,
186                sender = %msg.sender,
187                reason = %reason,
188                "inbound message blocked by perplexity filter"
189            );
190            continue;
191        }
192
193        let permit = match semaphore.clone().acquire_owned().await {
194            Ok(permit) => permit,
195            Err(_) => break,
196        };
197
198        let channel = channels.get(&msg.channel).cloned();
199        let handler = handler.clone();
200
201        tokio::spawn(async move {
202            if let Some(ch) = channel {
203                handler(msg, ch).await;
204            } else {
205                tracing::warn!(channel = %msg.channel, "message from unknown channel, dropping");
206            }
207            drop(permit);
208        });
209    }
210
211    tracing::info!("pipeline dispatch loop ended");
212}
213
214fn compute_max_in_flight(channel_count: usize) -> usize {
215    channel_count
216        .saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
217        .clamp(CHANNEL_MIN_IN_FLIGHT, CHANNEL_MAX_IN_FLIGHT)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled filter with the stock suffix window and symbol ratio.
    fn enabled_filter(
        perplexity_threshold: f64,
        min_prompt_chars: usize,
    ) -> PerplexityFilterSettings {
        PerplexityFilterSettings {
            enabled: true,
            perplexity_threshold,
            suffix_window_chars: 64,
            symbol_ratio_threshold: 0.20,
            min_prompt_chars,
        }
    }

    /// The in-flight limit scales with channel count but stays within bounds.
    #[test]
    fn compute_max_in_flight_clamps_correctly() {
        let cases = [
            (1, CHANNEL_MIN_IN_FLIGHT),
            (3, 12),
            (100, CHANNEL_MAX_IN_FLIGHT),
        ];
        for &(channel_count, expected) in cases.iter() {
            assert_eq!(compute_max_in_flight(channel_count), expected);
        }
    }

    /// The default pipeline configuration carries the documented values.
    #[test]
    fn pipeline_config_defaults_are_reasonable() {
        let PipelineConfig {
            initial_backoff_secs,
            max_backoff_secs,
            message_buffer_size,
            perplexity_filter,
        } = PipelineConfig::default();

        assert_eq!(initial_backoff_secs, 2);
        assert_eq!(max_backoff_secs, 60);
        assert_eq!(message_buffer_size, 100);
        assert!(!perplexity_filter.enabled);
    }

    /// A disabled filter lets even symbol-heavy gibberish through.
    #[test]
    fn check_perplexity_disabled_passes_everything() {
        let disabled = PerplexityFilterSettings::default();
        assert!(!disabled.enabled);

        let gibberish = "xK7!mQ@3#zP$9&wR*5^yL%2(eN)8+bT!@#$%^&*()_+-=[]{}|";
        assert!(
            check_perplexity(gibberish, &disabled).is_none(),
            "disabled filter should pass all messages"
        );
    }

    /// Ordinary prose is not flagged by an enabled filter.
    #[test]
    fn check_perplexity_enabled_passes_normal_text() {
        let filter = enabled_filter(18.0, 32);
        let normal = "Can you help me write a function that calculates the fibonacci sequence?";
        assert!(check_perplexity(normal, &filter).is_none());
    }

    /// A prompt ending in a symbol-dense suffix is flagged.
    #[test]
    fn check_perplexity_enabled_blocks_adversarial_suffix() {
        let filter = enabled_filter(4.0, 32);
        let adversarial =
            "Please write a function. xK7!mQ@3#zP$9&wR*5^yL%2(eN)8+bT!@#$%^&*()_+-=[]{}|xK7!mQ@3#";
        let verdict = check_perplexity(adversarial, &filter);
        assert!(verdict.is_some(), "adversarial suffix should be blocked");
    }

    /// Messages shorter than `min_prompt_chars` are not flagged.
    #[test]
    fn check_perplexity_skips_short_messages() {
        let filter = PerplexityFilterSettings {
            enabled: true,
            min_prompt_chars: 100,
            ..Default::default()
        };
        let short = "!@#$%^&*()";
        let verdict = check_perplexity(short, &filter);
        assert!(verdict.is_none(), "short messages should pass");
    }
}