use std::pin::Pin;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use crate::domain::core::task::TaskStateExt;
use crate::domain::{
A2AError, ListTasksParams, ListTasksResult, Message, RetryPolicy, Task,
TaskPushNotificationConfig,
};
use crate::port::{StreamEvent, StreamItem, Transport};
/// Boxed, `Send` stream of task events; the same shape `Transport::subscribe_to_task` returns.
type EventStream = Pin<Box<dyn Stream<Item = Result<StreamEvent, A2AError>> + Send + 'static>>;
/// Subscribes to the event stream of `task_id`, transparently reconnecting on failure.
///
/// The returned stream calls `Transport::subscribe_to_task` and forwards every
/// successfully received [`StreamEvent`]. Any of the following triggers a reconnect:
/// * a failed subscribe call,
/// * an `Err` item from the underlying stream,
/// * the underlying stream ending before a terminal event.
///
/// Each reconnect first waits `policy.backoff(attempt, seed)`. It then passes the most
/// recently seen `event_id` as `last_event_id`. Presumably this lets a server that
/// supports resumption continue where the previous connection stopped.
///
/// `attempt` counts failures since the last successfully received event.
/// `policy.max_retries` therefore bounds *consecutive* failures, not the total over the
/// subscription's lifetime.
///
/// The stream ends:
/// * right after yielding an event whose item is terminal (see `is_terminal`), or
/// * after yielding a single [`A2AError::Internal`] once more than `policy.max_retries`
///   consecutive failures have occurred. Intermediate failures are not yielded, but the
///   description of the most recent one is appended to that final error so the root
///   cause is not lost.
///
/// Backoff uses `tokio::time::sleep`, so the stream must be polled inside a Tokio
/// runtime with the time driver enabled.
pub fn subscribe_resilient(
    transport: Arc<dyn Transport>,
    task_id: impl Into<String>,
    history_length: Option<u32>,
    last_event_id: Option<u64>,
    policy: RetryPolicy,
) -> EventStream {
    /// Mutable state threaded through `futures::stream::unfold`.
    struct State {
        transport: Arc<dyn Transport>,
        task_id: String,
        history_length: Option<u32>,
        policy: RetryPolicy,
        /// Per-subscription seed handed to `RetryPolicy::backoff`.
        seed: u64,
        /// Most recent event id seen; sent as the resume point on reconnect.
        last_event_id: Option<u64>,
        /// Failures since the last successfully received event.
        attempt: u32,
        /// Description of the most recent failure, reported if retries run out.
        last_error: Option<String>,
        /// Live connection, if any; `None` means the next poll must (re)connect.
        inner: Option<EventStream>,
        /// Set once the stream has yielded its final item.
        done: bool,
    }

    let task_id = task_id.into();
    let seed = seed_for(&task_id);
    let state = State {
        transport,
        task_id,
        history_length,
        policy,
        seed,
        last_event_id,
        attempt: 0,
        last_error: None,
        inner: None,
        done: false,
    };

    Box::pin(futures::stream::unfold(state, |mut st| async move {
        loop {
            if st.done {
                return None;
            }

            // Reuse the live connection or establish a new one. The stream is taken out
            // of the state and only put back after a healthy, non-terminal event, so
            // every failure path below drops (closes) the broken connection.
            let mut stream = match st.inner.take() {
                Some(stream) => stream,
                None => {
                    if st.attempt > st.policy.max_retries {
                        st.done = true;
                        let cause = st
                            .last_error
                            .take()
                            .unwrap_or_else(|| "unknown error".to_string());
                        let err = A2AError::Internal(format!(
                            "subscription to '{}' failed after {} retries: {}",
                            st.task_id, st.policy.max_retries, cause
                        ));
                        return Some((Err(err), st));
                    }
                    if st.attempt > 0 {
                        tokio::time::sleep(st.policy.backoff(st.attempt, st.seed)).await;
                    }
                    let resume = st.last_event_id.map(|n| n.to_string());
                    match st
                        .transport
                        .subscribe_to_task(&st.task_id, st.history_length, resume.as_deref())
                        .await
                    {
                        Ok(stream) => stream,
                        Err(e) => {
                            st.last_error = Some(e.to_string());
                            st.attempt += 1;
                            continue;
                        }
                    }
                }
            };

            match stream.next().await {
                Some(Ok(event)) => {
                    st.attempt = 0;
                    st.last_error = None;
                    if let Some(id) = event.event_id {
                        st.last_event_id = Some(id);
                    }
                    if is_terminal(&event.item) {
                        // Nothing more is read after a terminal event: finish, and release
                        // the connection now rather than when the outer stream is dropped.
                        st.done = true;
                    } else {
                        st.inner = Some(stream);
                    }
                    return Some((Ok(event), st));
                }
                Some(Err(e)) => {
                    st.last_error = Some(e.to_string());
                    st.attempt += 1;
                }
                None => {
                    st.last_error = Some("stream ended before a terminal event".to_string());
                    st.attempt += 1;
                }
            }
        }
    }))
}
/// Reports whether `item` carries a task state for which `is_terminal()` holds.
///
/// `subscribe_resilient` uses this to decide when to stop reading. A `Task` whose
/// `status.as_option()` is `None` is not terminal, and neither is an artifact update.
fn is_terminal(item: &StreamItem) -> bool {
    match item {
        StreamItem::Task(task) => match task.status.as_option() {
            Some(status) => status.state.is_terminal(),
            None => false,
        },
        StreamItem::StatusUpdate(update) => update.status.state.is_terminal(),
        StreamItem::ArtifactUpdate(_) => false,
    }
}
/// Derives a per-subscription seed for `RetryPolicy::backoff`.
///
/// The task id's bytes are folded with a wrapping multiply-add hash. The current
/// wall-clock time since the Unix epoch, in nanoseconds, is then mixed in, so repeated
/// calls for the same id generally yield different seeds. If the clock reads before the
/// epoch, the time contributes 0.
fn seed_for(task_id: &str) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    let id_hash = task_id.bytes().fold(0u64, |acc, byte| {
        acc.wrapping_mul(MULTIPLIER).wrapping_add(u64::from(byte))
    });
    // Truncating the u128 nanosecond count to u64 is acceptable for seeding.
    let clock = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos() as u64);
    id_hash.wrapping_mul(MULTIPLIER).wrapping_add(clock)
}
/// [`Transport`] decorator that keeps task subscriptions alive across dropped connections.
///
/// `subscribe_to_task` is routed through [`subscribe_resilient`] with `policy`. Every
/// other call is forwarded to `inner` once, without retries.
pub struct RetryingTransport {
    /// Transport that actually performs the requests.
    inner: Arc<dyn Transport>,
    /// Retry/backoff settings applied to subscriptions.
    policy: RetryPolicy,
}

impl RetryingTransport {
    /// Wraps an already-shared transport.
    pub fn new(inner: Arc<dyn Transport>, policy: RetryPolicy) -> Self {
        Self { inner, policy }
    }

    /// Wraps an owned, boxed transport, converting it into an `Arc` first.
    pub fn wrap(inner: Box<dyn Transport>, policy: RetryPolicy) -> Self {
        Self::new(Arc::from(inner), policy)
    }
}
#[async_trait]
impl Transport for RetryingTransport {
    /// Reports the wrapped transport's protocol unchanged.
    fn protocol(&self) -> &str {
        self.inner.protocol()
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn send_task_message(
        &self,
        task_id: &str,
        message: &Message,
        session_id: Option<&str>,
        history_length: Option<u32>,
    ) -> Result<Task, A2AError> {
        self.inner
            .send_task_message(task_id, message, session_id, history_length)
            .await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
        self.inner.get_task(task_id, history_length).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
        self.inner.cancel_task(task_id).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn set_task_push_notification(
        &self,
        config: &TaskPushNotificationConfig,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        self.inner.set_task_push_notification(config).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn get_task_push_notification(
        &self,
        task_id: &str,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        self.inner.get_task_push_notification(task_id).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn list_tasks(&self, params: &ListTasksParams) -> Result<ListTasksResult, A2AError> {
        self.inner.list_tasks(params).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn list_push_notification_configs(
        &self,
        task_id: &str,
    ) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
        self.inner.list_push_notification_configs(task_id).await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn get_push_notification_config(
        &self,
        task_id: &str,
        config_id: &str,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        self.inner
            .get_push_notification_config(task_id, config_id)
            .await
    }

    /// Forwarded once to the wrapped transport; no retry is attempted.
    async fn delete_push_notification_config(
        &self,
        task_id: &str,
        config_id: &str,
    ) -> Result<(), A2AError> {
        self.inner
            .delete_push_notification_config(task_id, config_id)
            .await
    }

    /// Returns a self-healing subscription built by [`subscribe_resilient`].
    ///
    /// `last_event_id` is trimmed and parsed as a `u64` to seed the resume point.
    /// NOTE(review): an id that fails to parse is silently ignored, and the subscription
    /// then starts without a resume point. Confirm this is intended.
    ///
    /// Always returns `Ok`, because the returned stream is lazy. Connection failures
    /// surface later as items of that stream.
    async fn subscribe_to_task(
        &self,
        task_id: &str,
        history_length: Option<u32>,
        last_event_id: Option<&str>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, A2AError>> + Send>>, A2AError> {
        let resume: Option<u64> = match last_event_id {
            Some(raw) => raw.trim().parse().ok(),
            None => None,
        };
        let stream = subscribe_resilient(
            Arc::clone(&self.inner),
            task_id,
            history_length,
            resume,
            self.policy,
        );
        Ok(stream)
    }
}