use std::io::{BufRead, Write};
use crate::sdk::{SDK_SCHEMA_VERSION, SdkCommand, SdkResponse};
use serde_json::json;
use tokio_util::sync::CancellationToken;
/// An event emitted by a [`ProxyHandler`] through its event sink while a
/// command is being handled.
///
/// Events are queued and written to the proxy output after the response of
/// the command that produced them.
#[derive(Clone, Debug)]
pub enum ProxyEvent {
    /// Agent event carried as raw JSON; passed through the secret redactor
    /// before being written when redaction is enabled.
    Agent(serde_json::Value),
}
/// Command-processing logic plugged into a [`StreamingProxy`].
///
/// The proxy calls [`ProxyHandler::handle_command`] once per successfully
/// parsed input line, synchronously, on the thread running
/// [`StreamingProxy::run`].
pub trait ProxyHandler: Sync + Send {
    /// Handles one command and returns the response to write back.
    ///
    /// `event_sink` may be invoked any number of times before returning. The
    /// events are buffered in a bounded queue and written after the response;
    /// once the queue is full, further events are dropped with a warning.
    /// A response whose `command` is `"quit"` ends the session.
    fn handle_command(
        &self,
        command: SdkCommand,
        event_sink: &dyn Fn(ProxyEvent),
    ) -> SdkResponse;
}
/// Tunables for a [`StreamingProxy`].
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Maximum number of events buffered while a single command is being
    /// handled; further events emitted before the handler returns are dropped.
    pub event_channel_capacity: usize,
    /// When `true`, agent events are passed through a [`SecretRedactor`]
    /// before being written.
    pub redact_secrets: bool,
}

impl Default for ProxyConfig {
    /// Buffers up to 256 events per command and enables secret redaction.
    fn default() -> Self {
        let event_channel_capacity = 256;
        let redact_secrets = true;
        Self {
            event_channel_capacity,
            redact_secrets,
        }
    }
}
/// Line-oriented JSON proxy: reads [`SdkCommand`]s (one JSON value per input
/// line), dispatches them to a [`ProxyHandler`], and writes responses and
/// handler events back as JSON lines.
pub struct StreamingProxy<H> {
    handler: H,
    config: ProxyConfig,
}

impl<H: ProxyHandler> StreamingProxy<H> {
    /// Creates a proxy that dispatches commands to `handler` using `config`.
    pub fn new(handler: H, config: ProxyConfig) -> Self {
        Self { handler, config }
    }

    /// Runs one proxy session to completion.
    ///
    /// First writes a `proxy_ready` line carrying [`SDK_SCHEMA_VERSION`] to
    /// `writer`. Then `reader` is processed line by line until end of input,
    /// until a response to a `quit` command, or until `cancel` is cancelled
    /// (checked between lines). Events emitted by the handler are buffered, up
    /// to `event_channel_capacity` per command, and written after each response.
    /// On success the writer is handed back to the caller.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingProxyError::Io`] if reading from `reader` or writing
    /// to `writer` fails.
    pub fn run<R: BufRead, W: Write>(
        self,
        reader: R,
        writer: W,
        cancel: CancellationToken,
    ) -> Result<W, StreamingProxyError> {
        let redactor = self.config.redact_secrets.then(SecretRedactor::default);
        // `sync_channel(0)` is a rendezvous channel: `try_send` only succeeds
        // when a receiver is already blocked in `recv`. Events are drained only
        // after the handler returns, so a capacity of 0 would silently drop
        // every event. Always keep at least one slot.
        let capacity = self.config.event_channel_capacity.max(1);
        let (event_tx, event_rx) = std::sync::mpsc::sync_channel::<ProxyEvent>(capacity);
        let mut engine = ProxyEngine {
            reader,
            writer,
            handler: self.handler,
            redactor,
            event_rx,
            event_tx,
            cancel,
            line_number: 0,
        };
        engine.write_json(&json!({
            "type": "proxy_ready",
            "schema_version": SDK_SCHEMA_VERSION,
        }))?;
        engine.run_loop()
    }
}
struct ProxyEngine<R, W, H> {
reader: R,
writer: W,
handler: H,
redactor: Option<SecretRedactor>,
event_rx: std::sync::mpsc::Receiver<ProxyEvent>,
event_tx: std::sync::mpsc::SyncSender<ProxyEvent>,
cancel: CancellationToken,
line_number: usize,
}
impl<R: BufRead, W: Write, H: ProxyHandler> ProxyEngine<R, W, H> {
fn run_loop(mut self) -> Result<W, StreamingProxyError> {
let mut line = String::new();
loop {
if self.cancel.is_cancelled() {
let _ = self.write_json(&json!({"type": "proxy_cancelled"}));
self.drain_events()?;
return Ok(self.writer);
}
line.clear();
match self.reader.read_line(&mut line) {
Ok(0) => {
self.drain_events()?;
return Ok(self.writer);
}
Ok(_n) => {}
Err(e) => return Err(StreamingProxyError::Io(e.to_string())),
};
self.line_number += 1;
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let command: SdkCommand = match serde_json::from_str(trimmed) {
Ok(cmd) => cmd,
Err(e) => {
let error_resp = json!({
"type": "proxy_error",
"line_number": self.line_number,
"error": format!("parse error: {e}"),
"raw": trimmed,
});
self.write_json(&error_resp)?;
continue;
}
};
let tx = self.event_tx.clone();
let sink = move |event: ProxyEvent| {
if tx.try_send(event).is_err() {
tracing::warn!("proxy event channel full, dropping event");
}
};
let response = self.handler.handle_command(command, &sink);
self.write_json(
&serde_json::to_value(&response)
.unwrap_or(json!({"type":"response","success":false})),
)?;
self.drain_events()?;
if response.command == "quit" {
return Ok(self.writer);
}
}
}
fn drain_events(&mut self) -> Result<(), StreamingProxyError> {
while let Ok(event) = self.event_rx.try_recv() {
match event {
ProxyEvent::Agent(mut value) => {
if let Some(ref redactor) = self.redactor {
value = redactor.redact(&value);
}
self.write_json(&value)?;
}
}
}
Ok(())
}
fn write_json(&mut self, value: &serde_json::Value) -> Result<(), StreamingProxyError> {
let mut line = serde_json::to_string(value).unwrap_or_else(|_| {
r#"{"type":"proxy_error","error":"serialization failed"}"#.to_owned()
});
line.push('\n');
self.writer
.write_all(line.as_bytes())
.map_err(|e| StreamingProxyError::Io(e.to_string()))?;
self.writer
.flush()
.map_err(|e| StreamingProxyError::Io(e.to_string()))?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct SecretRedactor {
value_patterns: Vec<String>,
value_regexes: Vec<regex::Regex>,
sensitive_fields: Vec<String>,
}
impl Default for SecretRedactor {
fn default() -> Self {
let value_patterns = vec![
r"sk-ant-[a-zA-Z0-9]{20,}".to_owned(),
r"sk-[a-zA-Z0-9]{20,}".to_owned(),
r"eyJ[a-zA-Z0-9_-]{10,}\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+".to_owned(),
];
Self {
value_regexes: compile_value_patterns(&value_patterns),
value_patterns,
sensitive_fields: vec![
"password".to_owned(),
"secret".to_owned(),
"token".to_owned(),
"api_key".to_owned(),
"apikey".to_owned(),
"private_key".to_owned(),
"access_token".to_owned(),
"refresh_token".to_owned(),
],
}
}
}
fn compile_value_patterns(patterns: &[String]) -> Vec<regex::Regex> {
patterns
.iter()
.filter_map(|pattern| match regex::Regex::new(pattern) {
Ok(regex) => Some(regex),
Err(err) => {
tracing::warn!(pattern, error = %err, "invalid secret redaction pattern ignored");
None
}
})
.collect()
}
impl SecretRedactor {
pub fn new(value_patterns: Vec<String>) -> Self {
Self {
value_regexes: compile_value_patterns(&value_patterns),
value_patterns,
sensitive_fields: Self::default().sensitive_fields,
}
}
pub fn patterns(&self) -> &[String] {
&self.value_patterns
}
pub fn redact(&self, value: &serde_json::Value) -> serde_json::Value {
self.redact_value(value)
}
fn redact_value(&self, value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut new_map = serde_json::Map::new();
for (k, v) in map {
if self.is_sensitive_field(k) {
new_map.insert(
k.clone(),
serde_json::Value::String("[REDACTED]".to_owned()),
);
} else {
new_map.insert(k.clone(), self.redact_value(v));
}
}
serde_json::Value::Object(new_map)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(|v| self.redact_value(v)).collect())
}
serde_json::Value::String(s) => {
if self.matches_value_pattern(s) {
serde_json::Value::String("[REDACTED]".to_owned())
} else {
serde_json::Value::String(s.clone())
}
}
other => other.clone(),
}
}
fn is_sensitive_field(&self, name: &str) -> bool {
let lower = name.to_lowercase();
self.sensitive_fields.iter().any(|f| f == &lower)
}
fn matches_value_pattern(&self, value: &str) -> bool {
self.value_regexes.iter().any(|regex| regex.is_match(value))
}
}
/// Errors that end a [`StreamingProxy::run`] session.
#[derive(Debug)]
pub enum StreamingProxyError {
    /// Reading input or writing output failed; holds the underlying error's message.
    Io(String),
}

impl std::fmt::Display for StreamingProxyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::Io(detail) = self;
        write!(f, "proxy I/O error: {detail}")
    }
}

impl std::error::Error for StreamingProxyError {}