use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
/// A single IPC message: a routing topic, a typed payload, and sender metadata.
///
/// Derives `Serialize`/`Deserialize`, so the field names below are the serialized keys;
/// renaming a field changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IpcMessage {
    /// Free-form routing topic; no format is validated by this type.
    pub topic: String,
    /// Message body; see [`IpcPayload`] for the possible variants.
    pub payload: IpcPayload,
    /// Optional signature bytes; `None` for an unsigned message.
    /// NOTE(review): what bytes are signed and with which scheme is not defined here — confirm.
    pub signature: Option<Vec<u8>>,
    /// Identifier of the message's source.
    pub source_id: Uuid,
    /// UTC timestamp; [`IpcMessage::new`] sets it to the construction time.
    pub timestamp: DateTime<Utc>,
}
impl IpcMessage {
    /// Creates an unsigned message stamped with the current UTC time.
    ///
    /// `topic` accepts anything convertible into a `String`. A signature can be attached
    /// afterwards with [`IpcMessage::with_signature`].
    #[must_use]
    pub fn new(topic: impl Into<String>, payload: IpcPayload, source_id: Uuid) -> Self {
        let topic: String = topic.into();
        let timestamp = Utc::now();
        Self {
            topic,
            payload,
            signature: None,
            source_id,
            timestamp,
        }
    }

    /// Consumes the message and returns it with `signature` attached.
    ///
    /// Any previously attached signature is replaced; all other fields are kept as-is.
    #[must_use]
    pub fn with_signature(self, signature: Vec<u8>) -> Self {
        Self {
            signature: Some(signature),
            ..self
        }
    }
}
/// Body of an [`IpcMessage`].
///
/// Serialized as an internally tagged object: the variant name in `snake_case` is stored in a
/// `"type"` field alongside the variant's own fields, e.g.
/// `{"type": "agent_response", "text": "...", "is_final": true}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum IpcPayload {
    /// Text input, with optional arbitrary JSON context.
    UserInput {
        text: String,
        // Omitted from the output when `None`, and defaults to `None` when the key is absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        context: Option<Value>,
    },
    /// Text produced in response.
    /// NOTE(review): `is_final` presumably marks the last chunk of a streamed response — confirm.
    AgentResponse {
        text: String,
        is_final: bool,
    },
    /// A request for approval of `action` on `resource`, with a `reason` string.
    /// NOTE(review): semantics inferred from field names only — confirm with producers.
    ApprovalRequired {
        action: String,
        resource: String,
        reason: String,
    },
    /// Arbitrary JSON data for messages not modeled by the other variants.
    Custom {
        data: Value,
    },
}
/// Reasons [`IpcRateLimiter::check_quota`] can reject a payload.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum QuotaError {
    /// The source's byte budget for the current window would be exceeded.
    #[error("Rate limit exceeded")]
    RateLimited,
    /// A single payload exceeded the per-payload size cap.
    #[error("Payload too large")]
    PayloadTooLarge,
}
/// Per-source byte-rate limiter for IPC payloads.
///
/// Each source id gets a fixed one-second window with a byte budget; see
/// [`IpcRateLimiter::check_quota`] for the exact limits. All methods take `&self`; per-source
/// state is kept in a `DashMap`.
#[derive(Debug)]
pub struct IpcRateLimiter {
    // Per source: (start of the current window, bytes accepted so far in that window).
    state: dashmap::DashMap<Uuid, (std::time::Instant, usize)>,
    // Time of the last prune pass over `state`; used to throttle pruning.
    last_prune: std::sync::Mutex<std::time::Instant>,
}
impl IpcRateLimiter {
    /// Largest single payload accepted, in bytes (5 MiB). The cap is inclusive.
    pub const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024;
    /// Total bytes one source may send within a single [`Self::WINDOW`] (10 MiB, inclusive).
    pub const MAX_BYTES_PER_WINDOW: usize = 10 * 1024 * 1024;
    /// Length of the fixed per-source accounting window.
    pub const WINDOW: std::time::Duration = std::time::Duration::from_secs(1);
    /// Number of tracked sources above which expired entries become eligible for pruning.
    const PRUNE_THRESHOLD: usize = 1000;
    /// Minimum time between two prune passes.
    const PRUNE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a limiter that tracks no sources yet.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: dashmap::DashMap::new(),
            last_prune: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Charges `size_bytes` against `source_id`'s budget for its current window.
    ///
    /// Each source has a fixed window of [`Self::WINDOW`]. The first call made at least one
    /// window-length after the window began starts a fresh window with an empty budget.
    /// A rejected call is not charged. When more than `PRUNE_THRESHOLD` sources are tracked,
    /// entries whose window has expired are pruned, at most once per `PRUNE_INTERVAL`.
    ///
    /// # Errors
    ///
    /// - [`QuotaError::PayloadTooLarge`] if `size_bytes` exceeds [`Self::MAX_PAYLOAD_BYTES`].
    /// - [`QuotaError::RateLimited`] if accepting `size_bytes` would push the source past
    ///   [`Self::MAX_BYTES_PER_WINDOW`] within the current window.
    pub fn check_quota(&self, source_id: Uuid, size_bytes: usize) -> Result<(), QuotaError> {
        if size_bytes > Self::MAX_PAYLOAD_BYTES {
            return Err(QuotaError::PayloadTooLarge);
        }
        let now = std::time::Instant::now();
        if self.state.len() > Self::PRUNE_THRESHOLD {
            // No entry guard is held here, so pruning cannot deadlock against our own shard lock.
            self.prune_expired(now);
        }
        let mut entry = self.state.entry(source_id).or_insert((now, 0));
        let (window_start, used) = &mut *entry;
        if now.saturating_duration_since(*window_start) >= Self::WINDOW {
            *window_start = now;
            *used = 0;
        }
        let total = used.saturating_add(size_bytes);
        if total > Self::MAX_BYTES_PER_WINDOW {
            return Err(QuotaError::RateLimited);
        }
        *used = total;
        Ok(())
    }

    /// Removes sources whose window has expired, at most once per `PRUNE_INTERVAL`.
    ///
    /// Uses `try_lock`, so concurrent callers never queue behind a prune in progress. The caller
    /// that wins the lock does the work and the others skip it.
    fn prune_expired(&self, now: std::time::Instant) {
        let mut last_prune = match self.last_prune.try_lock() {
            Ok(guard) => guard,
            // The guarded value is a bare `Instant`, so a panic elsewhere cannot leave it
            // inconsistent. Recover rather than silently disabling pruning for good.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => poisoned.into_inner(),
            Err(std::sync::TryLockError::WouldBlock) => return,
        };
        if now.saturating_duration_since(*last_prune) <= Self::PRUNE_INTERVAL {
            return;
        }
        *last_prune = now;
        self.state
            .retain(|_, entry| now.saturating_duration_since(entry.0) < Self::WINDOW);
    }
}
impl Default for IpcRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: usize = 1024 * 1024;

    #[test]
    fn test_ipc_rate_limiter_size() {
        let limiter = IpcRateLimiter::new();
        let source_id = Uuid::new_v4();
        assert_eq!(limiter.check_quota(source_id, MIB), Ok(()));
        assert_eq!(
            limiter.check_quota(source_id, 6 * MIB),
            Err(QuotaError::PayloadTooLarge)
        );
    }

    #[test]
    fn test_ipc_rate_limiter_size_boundary() {
        let limiter = IpcRateLimiter::new();
        // The per-payload cap is inclusive: exactly 5 MiB passes, one byte more does not.
        assert_eq!(limiter.check_quota(Uuid::new_v4(), 5 * MIB), Ok(()));
        assert_eq!(
            limiter.check_quota(Uuid::new_v4(), 5 * MIB + 1),
            Err(QuotaError::PayloadTooLarge)
        );
    }

    #[test]
    fn test_ipc_rate_limiter_frequency() {
        let limiter = IpcRateLimiter::new();
        let source_id = Uuid::new_v4();
        assert_eq!(limiter.check_quota(source_id, 4 * MIB), Ok(()));
        assert_eq!(limiter.check_quota(source_id, 4 * MIB), Ok(()));
        assert_eq!(
            limiter.check_quota(source_id, 4 * MIB),
            Err(QuotaError::RateLimited)
        );
    }

    #[test]
    fn test_ipc_rate_limiter_rejection_not_charged() {
        let limiter = IpcRateLimiter::new();
        let source_id = Uuid::new_v4();
        assert_eq!(limiter.check_quota(source_id, 4 * MIB), Ok(()));
        assert_eq!(limiter.check_quota(source_id, 4 * MIB), Ok(()));
        assert_eq!(
            limiter.check_quota(source_id, 4 * MIB),
            Err(QuotaError::RateLimited)
        );
        // The rejected 4 MiB was not charged, so exactly 2 MiB still fits the 10 MiB budget.
        assert_eq!(limiter.check_quota(source_id, 2 * MIB), Ok(()));
        assert_eq!(
            limiter.check_quota(source_id, 1),
            Err(QuotaError::RateLimited)
        );
    }

    #[test]
    fn test_ipc_rate_limiter_sources_independent() {
        let limiter = IpcRateLimiter::new();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        assert_eq!(limiter.check_quota(first, 5 * MIB), Ok(()));
        assert_eq!(limiter.check_quota(first, 5 * MIB), Ok(()));
        assert_eq!(limiter.check_quota(first, 1), Err(QuotaError::RateLimited));
        // Exhausting one source's budget must not affect another source.
        assert_eq!(limiter.check_quota(second, 5 * MIB), Ok(()));
    }

    #[test]
    fn test_ipc_message_signature() {
        let msg = IpcMessage::new(
            "test.topic",
            IpcPayload::AgentResponse {
                text: "hello".into(),
                is_final: true,
            },
            Uuid::new_v4(),
        );
        assert!(msg.signature.is_none());
        let signed = msg.with_signature(vec![1, 2, 3]);
        assert_eq!(signed.signature, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_ipc_payload_wire_format() {
        let payload = IpcPayload::AgentResponse {
            text: "hello".into(),
            is_final: true,
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            json,
            serde_json::json!({"type": "agent_response", "text": "hello", "is_final": true})
        );
        let back: IpcPayload = serde_json::from_value(json).unwrap();
        assert_eq!(back, payload);
    }

    #[test]
    fn test_user_input_context_optional() {
        let payload: IpcPayload =
            serde_json::from_value(serde_json::json!({"type": "user_input", "text": "hi"}))
                .unwrap();
        assert_eq!(
            payload,
            IpcPayload::UserInput {
                text: "hi".into(),
                context: None,
            }
        );
        // `context: None` is omitted on the way back out.
        assert_eq!(
            serde_json::to_value(&payload).unwrap(),
            serde_json::json!({"type": "user_input", "text": "hi"})
        );
    }
}