use std::sync::Arc;
use std::time::Duration;
use rusqlite::params;
use tokio::sync::Notify;
use crate::types::{IpcError, IpcResult, MessageInfo};
use convergio_db::pool::ConnPool;
pub const DEFAULT_RATE_LIMIT: u32 = 100;
pub struct SendParams<'a> {
pub from: &'a str,
pub to: &'a str,
pub content: &'a str,
pub msg_type: &'a str,
pub priority: i32,
pub rate_limit: u32,
}
pub fn send(pool: &ConnPool, notify: &Arc<Notify>, p: &SendParams<'_>) -> IpcResult<String> {
let conn = pool.get()?;
check_rate_limit(&conn, p.from, p.rate_limit)?;
let id = generate_msg_id();
conn.execute(
"INSERT INTO ipc_messages (id, from_agent, to_agent, content, msg_type, priority)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![id, p.from, p.to, p.content, p.msg_type, p.priority],
)?;
notify.notify_waiters();
Ok(id)
}
pub fn broadcast(
pool: &ConnPool,
notify: &Arc<Notify>,
from: &str,
content: &str,
msg_type: &str,
channel: Option<&str>,
rate_limit: u32,
) -> IpcResult<String> {
let conn = pool.get()?;
check_rate_limit(&conn, from, rate_limit)?;
let id = generate_msg_id();
conn.execute(
"INSERT INTO ipc_messages (id, from_agent, to_agent, channel, content, msg_type)
VALUES (?1, ?2, NULL, ?3, ?4, ?5)",
params![id, from, channel, content, msg_type],
)?;
notify.notify_waiters();
Ok(id)
}
pub fn receive(
pool: &ConnPool,
agent: &str,
from_filter: Option<&str>,
channel_filter: Option<&str>,
limit: u32,
peek: bool,
) -> IpcResult<Vec<MessageInfo>> {
let conn = pool.get()?;
let tx = conn.unchecked_transaction()?;
let mut conditions = vec!["(to_agent = ?1 OR to_agent IS NULL)".to_string()];
let mut p: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(agent.to_string())];
if let Some(from) = from_filter {
p.push(Box::new(from.to_string()));
conditions.push(format!("from_agent = ?{}", p.len()));
}
if let Some(ch) = channel_filter {
p.push(Box::new(ch.to_string()));
conditions.push(format!("channel = ?{}", p.len()));
}
let sql = format!(
"SELECT id, from_agent, to_agent, channel, content, msg_type, created_at
FROM ipc_messages WHERE {} AND read_at IS NULL
ORDER BY created_at ASC LIMIT {limit}",
conditions.join(" AND ")
);
let refs: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|v| v.as_ref()).collect();
let mut stmt = tx.prepare(&sql)?;
let rows: Vec<(String, MessageInfo)> = stmt
.query_map(refs.as_slice(), |row| {
Ok((row.get::<_, String>(0)?, map_message(row)?))
})?
.filter_map(|r| r.ok())
.collect();
drop(stmt);
if !peek {
for (id, _) in &rows {
tx.execute(
"UPDATE ipc_messages SET read_at = strftime('%Y-%m-%dT%H:%M:%f','now')
WHERE id = ?1",
params![id],
)?;
}
}
tx.commit()?;
Ok(rows.into_iter().map(|(_, m)| m).collect())
}
pub async fn receive_wait(
pool: &ConnPool,
notify: &Arc<Notify>,
agent: &str,
from_filter: Option<&str>,
channel_filter: Option<&str>,
limit: u32,
timeout_secs: u64,
) -> IpcResult<Vec<MessageInfo>> {
let msgs = receive(pool, agent, from_filter, channel_filter, limit, false)?;
if !msgs.is_empty() {
return Ok(msgs);
}
let n = notify.clone();
let a = agent.to_string();
let f = from_filter.map(String::from);
let c = channel_filter.map(String::from);
match tokio::time::timeout(Duration::from_secs(timeout_secs), async {
loop {
n.notified().await;
let msgs = receive(pool, &a, f.as_deref(), c.as_deref(), limit, false)?;
if !msgs.is_empty() {
return Ok(msgs);
}
}
})
.await
{
Ok(result) => result,
Err(_) => Ok(Vec::new()),
}
}
pub fn history(
pool: &ConnPool,
agent_filter: Option<&str>,
channel_filter: Option<&str>,
limit: u32,
since: Option<&str>,
) -> IpcResult<Vec<MessageInfo>> {
let conn = pool.get()?;
let mut conds: Vec<String> = Vec::new();
let mut p: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
if let Some(agent) = agent_filter {
p.push(Box::new(agent.to_string()));
let i = p.len();
conds.push(format!("(from_agent = ?{i} OR to_agent = ?{i})"));
}
if let Some(ch) = channel_filter {
p.push(Box::new(ch.to_string()));
conds.push(format!("channel = ?{}", p.len()));
}
if let Some(ts) = since {
p.push(Box::new(ts.to_string()));
conds.push(format!("created_at >= ?{}", p.len()));
}
let where_cl = if conds.is_empty() {
String::new()
} else {
format!("WHERE {}", conds.join(" AND "))
};
let sql = format!(
"SELECT id, from_agent, to_agent, channel, content, msg_type, created_at
FROM ipc_messages {where_cl} ORDER BY created_at DESC LIMIT {limit}"
);
let refs: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|v| v.as_ref()).collect();
let mut stmt = conn.prepare(&sql)?;
let msgs = stmt
.query_map(refs.as_slice(), map_message)?
.filter_map(|r| r.ok())
.collect();
Ok(msgs)
}
fn check_rate_limit(conn: &rusqlite::Connection, from: &str, limit: u32) -> IpcResult<()> {
let count: u32 = conn.query_row(
"SELECT COUNT(*) FROM ipc_messages
WHERE from_agent = ?1 AND created_at > datetime('now', '-1 minute')",
params![from],
|r| r.get(0),
)?;
if count >= limit {
return Err(IpcError::RateLimited(format!(
"agent '{from}' exceeded {limit} msgs/min"
)));
}
Ok(())
}
fn generate_msg_id() -> String {
uuid::Uuid::new_v4().to_string()
}
fn map_message(row: &rusqlite::Row<'_>) -> rusqlite::Result<MessageInfo> {
Ok(MessageInfo {
id: row.get(0)?,
from_agent: row.get(1)?,
to_agent: row.get(2)?,
channel: row.get(3)?,
content: row.get(4)?,
msg_type: row.get(5)?,
created_at: row.get(6)?,
})
}
#[cfg(test)]
#[path = "messaging_tests.rs"]
mod tests;