use crate::types::{QuestionHookResult, Step, ToolResult};
#[cfg(test)]
use futures_util::StreamExt;
use futures_util::stream::BoxStream;
/// Common interface over a single conversation connection.
///
/// Implementors must be `Send + Sync`. Every asynchronous method returns a `Send`
/// future that resolves to `Result<(), anyhow::Error>`.
pub trait Connection: Send + Sync {
/// Returns the conversation identifier as a borrowed string.
fn conversation_id(&self) -> &str;
/// Returns the connection's idle flag.
///
/// NOTE(review): what exactly counts as "idle" is decided by each implementor — confirm.
fn is_idle(&self) -> bool;
/// Returns a boxed `'static` stream whose items are `Result<Step, anyhow::Error>`.
fn receive_steps(&self) -> BoxStream<'static, Result<Step, anyhow::Error>>;
/// Sends `content` over the connection.
///
/// NOTE(review): presumably a user prompt (the test mock in this file records it
/// in a field named `sent_prompts`) — confirm against the real implementors.
fn send(
&self,
content: &str,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Sends a trigger notification carrying `content`.
///
/// NOTE(review): how this differs from `send` on the wire is not visible here — confirm.
fn send_trigger_notification(
&self,
content: &str,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Sends a halt request; takes no payload.
fn send_halt_request(
&self,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Sends a tool confirmation addressed by `trajectory_id` and `step_index`,
/// with `accepted` carrying the accept/reject decision.
fn send_tool_confirmation(
&self,
trajectory_id: &str,
step_index: u32,
accepted: bool,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Sends a tool `result` associated with `id`; `result` is passed by value.
///
/// NOTE(review): presumably `id` names the tool call being answered — confirm.
fn send_tool_response(
&self,
id: &str,
result: ToolResult,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Sends question `answers` addressed by `trajectory_id` and `step_index`;
/// `answers` is passed by value.
fn send_question_response(
&self,
trajectory_id: &str,
step_index: u32,
answers: QuestionHookResult,
) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
/// Disconnects the connection.
///
/// NOTE(review): whether the connection may be reused afterwards is up to the
/// implementor — confirm.
fn disconnect(&self) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send;
}
/// Enum wrapper over the concrete connection types, each held behind an `Arc`.
///
/// Which variants exist depends on the build: `Local` on non-`wasm32` targets,
/// `Wasm` on `wasm32` targets, and `Mock` only in test builds. The derived `Clone`
/// clones the `Arc` handle, so clones share the same underlying connection.
#[derive(Clone)]
pub enum AnyConnection {
/// Connection backed by `crate::local::LocalConnection` (non-`wasm32` targets only).
#[cfg(not(target_arch = "wasm32"))]
Local(std::sync::Arc<crate::local::LocalConnection>),
/// Connection backed by `crate::wasm::WasmConnection` (`wasm32` targets only).
#[cfg(target_arch = "wasm32")]
Wasm(std::sync::Arc<crate::wasm::WasmConnection>),
/// Connection backed by `MockConnection` (test builds only).
#[cfg(test)]
Mock(std::sync::Arc<MockConnection>),
}
impl std::fmt::Debug for AnyConnection {
    /// Writes only the variant's qualified name; the wrapped connection is never
    /// inspected, so it does not need to implement `Debug` itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            #[cfg(not(target_arch = "wasm32"))]
            Self::Local(_) => "AnyConnection::Local",
            #[cfg(target_arch = "wasm32")]
            Self::Wasm(_) => "AnyConnection::Wasm",
            #[cfg(test)]
            Self::Mock(_) => "AnyConnection::Mock",
        };
        f.write_str(label)
    }
}
impl Connection for AnyConnection {
fn conversation_id(&self) -> &str {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.conversation_id(),
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.conversation_id(),
#[cfg(test)]
Self::Mock(c) => c.conversation_id(),
}
}
fn is_idle(&self) -> bool {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.is_idle(),
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.is_idle(),
#[cfg(test)]
Self::Mock(c) => c.is_idle(),
}
}
fn receive_steps(&self) -> BoxStream<'static, Result<Step, anyhow::Error>> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.receive_steps(),
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.receive_steps(),
#[cfg(test)]
Self::Mock(c) => c.receive_steps(),
}
}
async fn send(&self, content: &str) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.send(content).await,
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.send(content).await,
#[cfg(test)]
Self::Mock(c) => c.send(content).await,
}
}
async fn send_trigger_notification(&self, content: &str) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.send_trigger_notification(content).await,
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.send_trigger_notification(content).await,
#[cfg(test)]
Self::Mock(c) => c.send_trigger_notification(content).await,
}
}
async fn send_halt_request(&self) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.send_halt_request().await,
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.send_halt_request().await,
#[cfg(test)]
Self::Mock(c) => c.send_halt_request().await,
}
}
async fn send_tool_confirmation(
&self,
trajectory_id: &str,
step_index: u32,
accepted: bool,
) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => {
c.send_tool_confirmation(trajectory_id, step_index, accepted)
.await
}
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => {
c.send_tool_confirmation(trajectory_id, step_index, accepted)
.await
}
#[cfg(test)]
Self::Mock(c) => {
c.send_tool_confirmation(trajectory_id, step_index, accepted)
.await
}
}
}
async fn send_tool_response(&self, id: &str, result: ToolResult) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.send_tool_response(id, result).await,
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.send_tool_response(id, result).await,
#[cfg(test)]
Self::Mock(c) => c.send_tool_response(id, result).await,
}
}
async fn send_question_response(
&self,
trajectory_id: &str,
step_index: u32,
answers: QuestionHookResult,
) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => {
c.send_question_response(trajectory_id, step_index, answers)
.await
}
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => {
c.send_question_response(trajectory_id, step_index, answers)
.await
}
#[cfg(test)]
Self::Mock(c) => {
c.send_question_response(trajectory_id, step_index, answers)
.await
}
}
}
async fn disconnect(&self) -> Result<(), anyhow::Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Local(c) => c.disconnect().await,
#[cfg(target_arch = "wasm32")]
Self::Wasm(c) => c.disconnect().await,
#[cfg(test)]
Self::Mock(c) => c.disconnect().await,
}
}
}
/// Test-only, in-memory `Connection` implementation.
///
/// All fields are public so tests can seed inputs and inspect what was sent.
#[cfg(test)]
pub struct MockConnection {
/// Value returned by `conversation_id`.
pub id: String,
/// Flag returned by `is_idle`.
pub is_idle: std::sync::atomic::AtomicBool,
/// Steps that `receive_steps` clones into the stream it returns.
pub steps_to_yield: std::sync::Mutex<Vec<Step>>,
/// Every `content` string passed to `send`, in call order.
pub sent_prompts: std::sync::Mutex<Vec<String>>,
}
#[cfg(test)]
impl std::fmt::Debug for MockConnection {
    /// Formats the mock as `MockConnection { id: .., is_idle: .., .. }`.
    ///
    /// Only `id` and the current `is_idle` value are printed; the mutex-guarded
    /// fields are left out (hence the trailing `..`), so formatting never takes a lock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Load with the same ordering that `Connection::is_idle` uses for this flag,
        // so both readers of the atomic follow a single convention.
        let idle = self.is_idle.load(std::sync::atomic::Ordering::SeqCst);
        f.debug_struct("MockConnection")
            .field("id", &self.id)
            .field("is_idle", &idle)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
impl MockConnection {
    /// Creates a mock with conversation id `id` that starts idle, with no steps
    /// queued and no prompts recorded.
    pub fn new(id: &str) -> Self {
        use std::sync::Mutex;
        use std::sync::atomic::AtomicBool;

        let id = String::from(id);
        let is_idle = AtomicBool::new(true);
        let steps_to_yield = Mutex::new(Vec::new());
        let sent_prompts = Mutex::new(Vec::new());

        Self {
            id,
            is_idle,
            steps_to_yield,
            sent_prompts,
        }
    }
}
/// `Connection` behaviour of the mock: `send` records its input, `receive_steps`
/// replays the seeded steps, and every other send-style method succeeds without
/// doing anything.
#[cfg(test)]
impl Connection for MockConnection {
    /// Returns the `id` the mock was built with.
    fn conversation_id(&self) -> &str {
        self.id.as_str()
    }

    /// Reads the `is_idle` flag with sequentially-consistent ordering.
    fn is_idle(&self) -> bool {
        use std::sync::atomic::Ordering;

        self.is_idle.load(Ordering::SeqCst)
    }

    /// Returns a stream that yields a clone of each seeded step wrapped in `Ok`.
    ///
    /// `steps_to_yield` is cloned rather than drained, so each call replays whatever
    /// the vector holds at that moment. A poisoned mutex is recovered, not propagated.
    fn receive_steps(&self) -> BoxStream<'static, Result<Step, anyhow::Error>> {
        let snapshot = {
            let guard = self
                .steps_to_yield
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            guard.clone()
        };
        futures_util::stream::iter(snapshot).map(Ok).boxed()
    }

    /// Appends `content` to `sent_prompts` and reports success.
    ///
    /// A poisoned mutex is recovered, not propagated.
    async fn send(&self, content: &str) -> Result<(), anyhow::Error> {
        let mut recorded = self
            .sent_prompts
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        recorded.push(content.to_string());
        Ok(())
    }

    /// No-op; always succeeds.
    async fn send_trigger_notification(&self, _content: &str) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn send_halt_request(&self) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn send_tool_confirmation(
        &self,
        _trajectory_id: &str,
        _step_index: u32,
        _accepted: bool,
    ) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn send_tool_response(
        &self,
        _id: &str,
        _result: ToolResult,
    ) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn send_question_response(
        &self,
        _trajectory_id: &str,
        _step_index: u32,
        _answers: QuestionHookResult,
    ) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn disconnect(&self) -> Result<(), anyhow::Error> {
        Ok(())
    }
}