use crate::bridge::SDKMessage;
use crate::bridge::poll_config_defaults::PollIntervalConfig;
use crate::bridge::repl_bridge_handle::{BridgeControlRequest, BridgeControlResponse, BridgeState};
use crate::bridge::repl_bridge_transport::ReplBridgeTransport;
use crate::error::AgentError;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Base delay in milliseconds applied after the first poll error; `compute_backoff`
/// doubles it for each further consecutive error.
const POLL_ERROR_INITIAL_DELAY_MS: u64 = 2_000;
/// Upper bound in milliseconds on the delay returned by `compute_backoff`.
const POLL_ERROR_MAX_DELAY_MS: u64 = 60_000;
/// 15 minutes, in milliseconds. NOTE(review): not referenced in this part of the
/// file; presumably passed as `give_up_ms` to `should_give_up` — confirm.
const POLL_ERROR_GIVE_UP_MS: u64 = 15 * 60 * 1000;
/// Caller-supplied configuration and callbacks for the bridge core.
///
/// `#[derive(Clone)]` clones each `Arc`-wrapped callback by bumping its reference
/// count; the `String`/`Vec` fields are copied.
///
/// NOTE(review): in this part of the file only `get_poll_interval_config` and
/// `on_state_change` are actually read; descriptions of the other fields are
/// inferred from their names and types and should be confirmed.
#[derive(Clone)]
pub struct BridgeCoreParams {
/// Directory string; presumably the local working/project directory — confirm.
pub dir: String,
/// Presumably the host name reported for this machine — confirm.
pub machine_name: String,
/// Presumably the current git branch name — confirm.
pub branch: String,
/// Optional repository URL; presumably the git remote — confirm.
pub git_repo_url: Option<String>,
/// Presumably the initial session title — confirm.
pub title: String,
/// Base URL string; presumably the API endpoint the bridge talks to — confirm.
pub base_url: String,
/// Session ingress URL string supplied by the caller.
pub session_ingress_url: String,
/// Worker type string; semantics not visible here.
pub worker_type: String,
/// Returns the current access token, or `None` when none is available.
pub get_access_token: Arc<dyn Fn() -> Option<String> + Send + Sync>,
/// Async callback that resolves to an optional id string or an `AgentError`.
/// NOTE(review): presumably creates a remote session and returns its id; the
/// meaning and order of the four arguments are not documented here — confirm.
pub create_session: Arc<
dyn Fn(
String,
String,
Option<String>,
String,
) -> future::BoxFuture<'static, Result<Option<String>, AgentError>>
+ Send
+ Sync,
>,
/// Async callback taking a string (presumably a session id) to archive — confirm.
pub archive_session:
Arc<dyn Fn(String) -> future::BoxFuture<'static, Result<(), AgentError>> + Send + Sync>,
/// Optional provider of the current title string.
pub get_current_title: Option<Arc<dyn Fn() -> String + Send + Sync>>,
/// Optional converter from internal `Message`s to `SDKMessage`s.
pub to_sdk_messages:
Option<Arc<dyn Fn(Vec<crate::types::Message>) -> Vec<SDKMessage> + Send + Sync>>,
/// Optional async hook resolving to a `bool`. NOTE(review): presumably invoked on
/// an HTTP 401 with the failing token, returning whether auth was refreshed — confirm.
pub on_auth_401: Option<
Arc<dyn Fn(String) -> future::BoxFuture<'static, Result<bool, AgentError>> + Send + Sync>,
>,
/// Poll-interval provider; `init_bridge_core` calls it and falls back to
/// `PollIntervalConfig::default()` when this is `None`.
pub get_poll_interval_config: Option<Arc<dyn Fn() -> PollIntervalConfig + Send + Sync>>,
/// Optional cap; presumably limits how much initial history is sent — confirm.
pub initial_history_cap: Option<u32>,
/// Optional initial messages; presumably replayed on connect — confirm.
pub initial_messages: Option<Vec<crate::types::Message>>,
/// Optional provider of a set of UUID strings; presumably messages already flushed
/// and therefore not to be re-sent — confirm.
pub previously_flushed_uuids:
Option<Arc<dyn Fn() -> std::collections::HashSet<String> + Send + Sync>>,
/// Presumably invoked for each inbound `SDKMessage` — confirm.
pub on_inbound_message: Option<Arc<dyn Fn(SDKMessage) + Send + Sync>>,
/// Presumably invoked with each permission `BridgeControlResponse` — confirm.
pub on_permission_response: Option<Arc<dyn Fn(BridgeControlResponse) + Send + Sync>>,
/// Presumably invoked when an interrupt is requested — confirm.
pub on_interrupt: Option<Arc<dyn Fn() + Send + Sync>>,
/// Presumably invoked when the model is changed (`None` = default?) — confirm.
pub on_set_model: Option<Arc<dyn Fn(Option<String>) + Send + Sync>>,
/// Presumably invoked when the max-thinking-tokens setting changes — confirm.
pub on_set_max_thinking_tokens: Option<Arc<dyn Fn(Option<u32>) + Send + Sync>>,
/// Presumably invoked to change the permission mode; `Err(String)` rejects it — confirm.
pub on_set_permission_mode:
Option<Arc<dyn Fn(crate::permission::PermissionMode) -> Result<(), String> + Send + Sync>>,
/// State-change observer. `init_bridge_core` reports `BridgeState::Ready` with no
/// detail, and `BridgeCoreHandle::teardown` reports `BridgeState::Failed` with
/// detail `"teardown"`.
pub on_state_change: Option<Arc<dyn Fn(BridgeState, Option<String>) + Send + Sync>>,
/// Callback taking two strings and returning a `bool`; semantics not visible here.
pub on_user_message: Option<Arc<dyn Fn(String, String) -> bool + Send + Sync>>,
/// Optional flag; semantics not visible here.
pub perpetual: Option<bool>,
/// Optional starting SSE sequence number; presumably used to resume a stream — confirm.
pub initial_sse_sequence_num: Option<u64>,
}
/// Runtime state for one bridge session.
///
/// Mutable state sits behind `tokio::sync::RwLock`, so every method on the handle
/// takes `&self`.
pub struct BridgeCoreHandle {
/// Current session id; attached to every outbound control/result message.
pub session_id: RwLock<String>,
/// Environment id passed to `new`; not read by the methods visible here.
pub environment_id: RwLock<String>,
/// Session ingress URL passed to `new`.
pub session_ingress_url: String,
/// Attached transport, or `None` before one is attached and after `teardown`
/// takes it. Message writes are silent no-ops while this is `None`.
pub transport: RwLock<Option<Box<dyn ReplBridgeTransport>>>,
/// Starts as `None`; presumably the id of the work item being processed — confirm.
pub current_work_id: RwLock<Option<String>>,
/// Starts as `None`; presumably the token for the ingress endpoint — confirm.
pub current_ingress_token: RwLock<Option<String>>,
/// Last SSE sequence number; starts at 0 and is read by `get_sse_sequence_num`.
pub last_sequence_num: RwLock<u64>,
/// Watch sender for the poll-abort flag; starts `false`, `teardown` sets it `true`.
pub poll_abort: tokio::sync::watch::Sender<bool>,
/// Set by the first `teardown` call so later calls return immediately.
pub teardown_started: RwLock<bool>,
/// Caller-supplied configuration and callbacks.
params: BridgeCoreParams,
}
impl BridgeCoreHandle {
pub fn new(
session_id: String,
environment_id: String,
session_ingress_url: String,
params: BridgeCoreParams,
) -> Self {
let (poll_abort, _) = tokio::sync::watch::channel(false);
Self {
session_id: RwLock::new(session_id),
environment_id: RwLock::new(environment_id),
session_ingress_url,
transport: RwLock::new(None),
current_work_id: RwLock::new(None),
current_ingress_token: RwLock::new(None),
last_sequence_num: RwLock::new(0),
poll_abort,
teardown_started: RwLock::new(false),
params,
}
}
pub async fn write_messages(&self, messages: Vec<SDKMessage>) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
t.write_batch(messages).await;
}
}
pub async fn write_sdk_messages(&self, messages: Vec<SDKMessage>) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
t.write_batch(messages).await;
}
}
pub async fn send_control_request(&self, request: BridgeControlRequest) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
let msg =
bridge_message_from_control_request(request, self.session_id.read().await.clone());
t.write(msg).await;
}
}
pub async fn send_control_response(&self, response: BridgeControlResponse) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
let msg = bridge_message_from_control_response(
response,
self.session_id.read().await.clone(),
);
t.write(msg).await;
}
}
pub async fn send_control_cancel_request(&self, request_id: &str) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
let msg = bridge_control_cancel_request(
request_id.to_string(),
self.session_id.read().await.clone(),
);
t.write(msg).await;
}
}
pub async fn send_result(&self) {
let transport = self.transport.read().await;
if let Some(t) = transport.as_ref() {
let msg = bridge_result_message(self.session_id.read().await.clone());
t.write(msg).await;
}
}
pub fn get_sse_sequence_num(&self) -> u64 {
*self.last_sequence_num.blocking_read()
}
pub async fn teardown(&self) {
let mut started = self.teardown_started.write().await;
if *started {
return;
}
*started = true;
drop(started);
let _ = self.poll_abort.send(true);
let mut transport = self.transport.write().await;
if let Some(t) = transport.take() {
t.close();
}
if let Some(ref callback) = self.params.on_state_change {
callback(BridgeState::Failed, Some("teardown".to_string()));
}
}
}
/// Internal connection-lifecycle states of the bridge.
///
/// NOTE(review): not referenced in this part of the file; presumably an internal
/// counterpart to `BridgeState` — confirm before relying on a mapping between them.
///
/// The enum is fieldless, so it also derives `Copy`, `Eq` and `Hash`. Values can be
/// passed by value, satisfy `Eq` bounds, and serve as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BridgeStateInternal {
    Ready,
    Connecting,
    Connected,
    Reconnecting,
    Failed,
}
/// Initializes the bridge core from `params`.
///
/// As currently implemented, this does three things:
/// - calls `get_poll_interval_config`, falling back to
///   `PollIntervalConfig::default()` when it is `None`;
/// - reports `BridgeState::Ready` (with no detail) through `on_state_change`, if set;
/// - returns `Ok(None)`, meaning no [`BridgeCoreHandle`] is created.
///
/// NOTE(review): this looks like a stub. The poll configuration is computed but not
/// used yet, and the function never constructs a handle — confirm intended.
pub async fn init_bridge_core(
    params: BridgeCoreParams,
) -> Result<Option<BridgeCoreHandle>, AgentError> {
    // The provider is a caller-supplied closure, so it is still invoked. The
    // leading underscore marks the result as intentionally unused until polling is
    // wired in, which silences the unused-variable warning.
    let _poll_config = params
        .get_poll_interval_config
        .as_ref()
        .map(|provider| provider())
        .unwrap_or_default();
    if let Some(callback) = params.on_state_change.as_ref() {
        callback(BridgeState::Ready, None);
    }
    Ok(None)
}
/// Returns the delay before the next poll after `consecutive_errors` failures in a row.
///
/// The delay is `POLL_ERROR_INITIAL_DELAY_MS` for the first error and doubles with
/// each further error, capped at `POLL_ERROR_MAX_DELAY_MS`. An error count of `0`
/// is treated like `1`.
///
/// All arithmetic saturates, so large error counts yield the cap. Previously the
/// multiplication overflowed from roughly 55 errors on, which panics in debug
/// builds and wraps to a meaningless delay in release builds.
pub fn compute_backoff(consecutive_errors: u32) -> Duration {
    let doublings = consecutive_errors.saturating_sub(1);
    let delay_ms = POLL_ERROR_INITIAL_DELAY_MS.saturating_mul(2u64.saturating_pow(doublings));
    Duration::from_millis(delay_ms.min(POLL_ERROR_MAX_DELAY_MS))
}
/// Reports whether a sustained run of errors has lasted at least `give_up_ms`.
///
/// `first_error_time` is the wall-clock time, in milliseconds since the Unix epoch,
/// of the first error in the current run. With `None` (no active error run) this
/// always returns `false`.
///
/// Clock skew is tolerated. If `first_error_time` lies in the future, or the system
/// clock reads before the epoch, the elapsed time counts as 0 instead of
/// underflowing. Underflow used to panic in debug builds and wrap to an immediate
/// give-up in release builds.
pub fn should_give_up(first_error_time: Option<u64>, give_up_ms: u64) -> bool {
    match first_error_time {
        None => false,
        Some(start) => {
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX))
                .unwrap_or(0);
            now_ms.saturating_sub(start) >= give_up_ms
        }
    }
}
/// Builds the outbound [`SDKMessage`] for a control request within `session_id`.
///
/// NOTE(review): the request value is currently discarded; the result is whatever
/// `SDKMessage::user_message_with_session` yields for the session alone. This looks
/// like a placeholder — confirm before relying on the request being forwarded.
pub fn bridge_message_from_control_request(
    _control_request: BridgeControlRequest,
    session_id: String,
) -> SDKMessage {
    SDKMessage::user_message_with_session(session_id)
}
/// Builds the outbound [`SDKMessage`] for a control response within `session_id`.
///
/// NOTE(review): the response value is currently discarded; the result is whatever
/// `SDKMessage::user_message_with_session` yields for the session alone. This looks
/// like a placeholder — confirm before relying on the response being forwarded.
pub fn bridge_message_from_control_response(
    _control_response: BridgeControlResponse,
    session_id: String,
) -> SDKMessage {
    SDKMessage::user_message_with_session(session_id)
}
/// Builds the outbound [`SDKMessage`] cancelling a control request within `session_id`.
///
/// The parameter is named `_request_id`, matching the sibling builders' `_request` /
/// `_response`, because it is not used yet. The old name triggered an
/// unused-variable warning.
///
/// NOTE(review): the request id is currently not encoded in the message; the result
/// is whatever `SDKMessage::user_message_with_session` yields for the session alone.
/// This looks like a placeholder — confirm.
pub fn bridge_control_cancel_request(_request_id: String, session_id: String) -> SDKMessage {
    SDKMessage::user_message_with_session(session_id)
}
/// Builds the outbound result [`SDKMessage`] for `session_id`.
///
/// NOTE(review): no result payload is attached; the output is identical to a plain
/// `SDKMessage::user_message_with_session` for the session — confirm intended.
pub fn bridge_result_message(
    session_id: String,
) -> SDKMessage {
    SDKMessage::user_message_with_session(session_id)
}
/// Local future alias used by the async callback fields of `BridgeCoreParams`.
mod future {
    use std::future::Future;
    use std::pin::Pin;

    /// A heap-allocated, pinned, `Send` future yielding `T` that may borrow for `'a`.
    ///
    /// Same shape as `futures::future::BoxFuture`.
    pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
}