use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::future::Future;
use std::pin::Pin;
use indexmap::IndexMap;
use objectiveai_mcp_proxy::{QueueDelegate, QueueRead};
use objectiveai_sdk::client_objectiveai_mcp::server_response::ReadMessageQueueResult;
use objectiveai_sdk::mcp::tool::ContentBlock;
use tokio::sync::Mutex;
use crate::objectiveai_mcp::ReverseAttachHandle;
/// Agent-argument key carrying the agent instance hierarchy (AIH).
/// Keys are compared against it ASCII-case-insensitively in `read_pending_blocks`.
const AIH_HEADER: &'static str = "X-OBJECTIVEAI-AGENT-INSTANCE-HIERARCHY";
/// How long `read_message_queue_via_ws` waits for the `ReadMessageQueue` reply.
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
/// [`QueueDelegate`] implementation that reads an agent's message queue through
/// its reverse-attach handle (see `read_message_queue_via_ws`).
///
/// It also tracks which content ids have been handed out or confirmed.
/// State is kept per agent instance hierarchy (AIH): `register` inserts an
/// entry and `unregister` removes it.
pub struct ApiQueueDelegate {
    // AIH string -> state registered for that hierarchy.
    states: Mutex<HashMap<String, Arc<PerLoopState>>>,
}
/// State registered for a single agent instance hierarchy.
struct PerLoopState {
    // Handle passed to `read_message_queue_via_ws` when reading this AIH's queue.
    reverse_attach: Arc<ReverseAttachHandle>,
    // Delivery bookkeeping, locked independently of the outer `states` map.
    inner: Mutex<PerLoopInner>,
}
/// Message-queue content-id bookkeeping for one agent instance hierarchy.
struct PerLoopInner {
    // Ids seeded from `initial_banned` at registration and extended whenever a
    // read token is confirmed. A queue row whose ids all appear here (or in
    // `pending`) is skipped by `read_pending_blocks`.
    confirmed: HashSet<i64>,
    // Read token -> ids handed out by `read_pending_blocks` that have not been
    // confirmed yet; `confirm` removes the entry for its token.
    pending: HashMap<String, Vec<i64>>,
    // Ids confirmed since the last `drain_undelivered` call.
    undelivered: Vec<i64>,
}
impl ApiQueueDelegate {
pub fn new() -> Self {
Self { states: Mutex::new(HashMap::new()) }
}
pub async fn register(
&self,
aih: String,
reverse_attach: Arc<ReverseAttachHandle>,
initial_banned: Vec<i64>,
) {
let mut confirmed: HashSet<i64> = HashSet::with_capacity(initial_banned.len());
confirmed.extend(initial_banned);
let state = PerLoopState {
reverse_attach,
inner: Mutex::new(PerLoopInner {
confirmed,
pending: HashMap::new(),
undelivered: Vec::new(),
}),
};
self.states.lock().await.insert(aih, Arc::new(state));
}
pub async fn confirm(&self, aih: &str, token: &str) -> Vec<i64> {
let Some(state) = self.states.lock().await.get(aih).cloned() else {
return Vec::new();
};
let mut inner = state.inner.lock().await;
let Some(ids) = inner.pending.remove(token) else {
return Vec::new();
};
inner.confirmed.extend(ids.iter().copied());
inner.undelivered.extend(ids.iter().copied());
ids
}
pub async fn drain_undelivered(&self, aih: &str) -> Vec<i64> {
let Some(state) = self.states.lock().await.get(aih).cloned() else {
return Vec::new();
};
let mut inner = state.inner.lock().await;
std::mem::take(&mut inner.undelivered)
}
pub async fn unregister(&self, aih: &str) {
self.states.lock().await.remove(aih);
}
}
impl Default for ApiQueueDelegate {
fn default() -> Self {
Self::new()
}
}
impl QueueDelegate for ApiQueueDelegate {
    /// Reads the message queue for the agent instance hierarchy (AIH) named in
    /// `agent_arguments`. Hands out every row not already confirmed or claimed
    /// by an earlier, still-unconfirmed read.
    ///
    /// Returns `None` when:
    /// - the AIH argument is missing,
    /// - the AIH is not registered,
    /// - the reverse-attach read fails (errors are intentionally swallowed and
    ///   treated as "nothing to deliver"), or
    /// - no new rows remain.
    ///
    /// Otherwise the delivered ids are recorded in `pending` under a fresh
    /// token until [`ApiQueueDelegate::confirm`] is called with that token.
    fn read_pending_blocks<'a>(
        &'a self,
        agent_arguments: &'a IndexMap<String, String>,
        _mcp_session_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Option<QueueRead>> + Send + 'a>> {
        Box::pin(async move {
            // Argument keys are header names, so match them case-insensitively.
            let aih = agent_arguments
                .iter()
                .find(|(k, _)| k.eq_ignore_ascii_case(AIH_HEADER))
                .map(|(_, v)| v.as_str())?;
            let state = self.states.lock().await.get(aih).cloned()?;

            // Best-effort: a failed read simply yields nothing this round.
            let result = read_message_queue_via_ws(&state.reverse_attach, aih)
                .await
                .ok()?;

            // Filter and claim under ONE lock acquisition. Before this change,
            // the filter was snapshotted before the network read, and the claim
            // was inserted into `pending` only after it. Two concurrent reads for
            // the same AIH could therefore both see the same rows as new and
            // deliver them twice.
            let mut inner = state.inner.lock().await;
            let claimed: HashSet<i64> = inner.pending.values().flatten().copied().collect();

            let mut new_ids: Vec<i64> = Vec::new();
            let mut blocks: Vec<Vec<ContentBlock>> = Vec::new();
            for row in result.rows {
                // Skip rows whose every id is already confirmed or pending.
                let already_seen = row
                    .content_ids
                    .iter()
                    .all(|id| inner.confirmed.contains(id) || claimed.contains(id));
                if already_seen {
                    continue;
                }
                new_ids.extend(row.content_ids.iter().copied());
                let row_blocks: Vec<ContentBlock> = row.rich_content.into();
                blocks.push(row_blocks);
            }
            if new_ids.is_empty() {
                return None;
            }

            let token = uuid::Uuid::new_v4().to_string();
            inner.pending.insert(token.clone(), new_ids);
            Some(QueueRead { token, blocks })
        })
    }
}
/// Sends a `ReadMessageQueue` request for `agent_instance_hierarchy` over the
/// reverse-attach channel of `handle` and waits up to [`READ_TIMEOUT`] for the reply.
///
/// # Errors
///
/// - [`ReverseAttachError::Closed`] if `send_server_request` fails.
/// - [`ReverseAttachError::Timeout`] if no reply arrives in time.
/// - [`ReverseAttachError::Dropped`] if the reply receiver resolves with an error.
/// - [`ReverseAttachError::CliError`] if the peer returns a JSON-RPC error.
/// - [`ReverseAttachError::WrongVariant`] if the reply carries another payload kind.
async fn read_message_queue_via_ws(
    handle: &Arc<ReverseAttachHandle>,
    agent_instance_hierarchy: &str,
) -> Result<ReadMessageQueueResult, ReverseAttachError> {
    use objectiveai_sdk::client_objectiveai_mcp::server_request;
    use objectiveai_sdk::client_objectiveai_mcp::server_response::{
        JsonRpcResult, Payload as ResponsePayload,
    };

    let channel = handle.channel();
    let request_id = uuid::Uuid::new_v4().to_string();
    let payload = server_request::Payload::ReadMessageQueue(
        server_request::ReadMessageQueueRequest {
            agent_instance_hierarchy: agent_instance_hierarchy.to_string(),
        },
    );
    let request = server_request::Request {
        id: request_id,
        headers: IndexMap::new(),
        payload,
    };

    let receiver =
        crate::objectiveai_mcp::send_server_request(&channel.sink, &channel.pending, request)
            .await
            .map_err(|()| ReverseAttachError::Closed)?;

    // Outer error: the timer elapsed. Inner error: the receiver itself failed.
    let response = tokio::time::timeout(READ_TIMEOUT, receiver)
        .await
        .map_err(|_elapsed| ReverseAttachError::Timeout)?
        .map_err(|_recv_error| ReverseAttachError::Dropped)?;

    match response.payload {
        ResponsePayload::ReadMessageQueue(JsonRpcResult::Ok { result }) => Ok(result),
        ResponsePayload::ReadMessageQueue(JsonRpcResult::Err { code, message, .. }) => {
            Err(ReverseAttachError::CliError { code, message })
        }
        _ => Err(ReverseAttachError::WrongVariant),
    }
}
/// Failure modes of `read_message_queue_via_ws`.
#[derive(Debug)]
enum ReverseAttachError {
    /// `send_server_request` returned an error, so no reply can be awaited.
    Closed,
    /// Awaiting the reply receiver produced an error instead of a response.
    Dropped,
    /// No reply arrived within `READ_TIMEOUT`.
    Timeout,
    /// The peer answered the `ReadMessageQueue` request with a JSON-RPC error.
    CliError { code: i64, message: String },
    /// The peer answered with a response payload of a different kind.
    WrongVariant,
}

// A human-readable form so the error can be logged or wrapped. This also keeps
// `CliError`'s fields from being flagged as never read; derived `Debug` impls
// are ignored by the dead_code lint.
impl std::fmt::Display for ReverseAttachError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("failed to send ReadMessageQueue request"),
            Self::Dropped => f.write_str("ReadMessageQueue response channel dropped"),
            Self::Timeout => f.write_str("timed out waiting for ReadMessageQueue response"),
            Self::CliError { code, message } => {
                write!(f, "ReadMessageQueue failed with error {code}: {message}")
            }
            Self::WrongVariant => f.write_str("unexpected response variant for ReadMessageQueue"),
        }
    }
}

impl std::error::Error for ReverseAttachError {}