pub mod output;
pub mod param;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt, TryStreamExt, stream::SplitSink};
pub use output::*;
pub use param::*;
use reqwest_websocket::{CloseCode, Message, WebSocket};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::time::{Duration, interval};
use crate::{error::DashScopeError, operation::ws_client::WsClient};
/// Event handlers invoked by [`WebsocketInference::call`] while it drives a
/// WebSocket session.
///
/// Handlers that receive a `tx` argument are given exclusive (locked) access to
/// the sending half of the socket for the duration of that handler, so they may
/// send frames in response.
pub trait WebsocketCallback {
    // ----- required handlers -----

    /// Called once, before any incoming frame is read.
    fn on_open(
        &self,
        tx: &mut SplitSink<WebSocket, Message>,
    ) -> impl std::future::Future<Output = ()> + Send;

    /// Called for every text frame that converts into a [`WebSocketEvent`].
    fn on_event(
        &self,
        tx: &mut SplitSink<WebSocket, Message>,
        event: WebSocketEvent,
    ) -> impl std::future::Future<Output = ()> + Send;

    /// Called when a text frame cannot be converted into a [`WebSocketEvent`].
    ///
    /// Errors from reading the socket itself are not routed here; they are
    /// returned from [`WebsocketInference::call`].
    fn on_error(&self, error: DashScopeError) -> impl std::future::Future<Output = ()> + Send;

    /// Called when the peer sends a close frame; the receive loop stops right
    /// after this handler completes. Not called if the stream ends or fails
    /// without a close frame.
    fn on_close(
        &self,
        code: CloseCode,
        reason: String,
    ) -> impl std::future::Future<Output = ()> + Send;

    // ----- optional handlers (no-ops by default) -----

    /// Called for every binary frame. Does nothing by default.
    fn on_data(
        &self,
        _tx: &mut SplitSink<WebSocket, Message>,
        _data: Bytes,
    ) -> impl std::future::Future<Output = ()> + Send {
        std::future::ready(())
    }

    /// Does nothing by default.
    ///
    /// NOTE(review): this is not invoked by `WebsocketInference::call` in this
    /// file; presumably implementors call it themselves when a task finishes —
    /// confirm.
    fn on_complete(&self) -> impl std::future::Future<Output = ()> + Send {
        std::future::ready(())
    }

    /// Called with the payload of every pong frame received. Does nothing by
    /// default.
    fn on_pong(&self, _bytes: Bytes) -> impl std::future::Future<Output = ()> + Send {
        std::future::ready(())
    }

    /// Period at which `WebsocketInference::call` sends keep-alive ping frames,
    /// or `None` (the default) to send none. The period should be non-zero.
    fn heartbeat_interval(&self) -> Option<Duration> {
        None
    }
}
pub struct WebsocketInference {
ws_client: WsClient,
}
impl WebsocketInference {
pub fn new(ws_client: WsClient) -> Self {
Self { ws_client }
}
pub async fn call(self, callback: impl WebsocketCallback) -> Result<(), DashScopeError> {
let (tx, mut rx) = self.ws_client.0.split();
let tx = Arc::new(Mutex::new(tx));
let tx_for_open = Arc::clone(&tx);
{
let mut tx_guard = tx_for_open.lock().await;
callback.on_open(&mut tx_guard).await;
}
let heartbeat_handle = if let Some(interval_duration) = callback.heartbeat_interval() {
let tx_for_heartbeat = Arc::clone(&tx);
let heartbeat_task = tokio::spawn(async move {
let mut interval_timer = interval(interval_duration);
loop {
interval_timer.tick().await;
{
let mut tx_guard = tx_for_heartbeat.lock().await;
if tx_guard.send(Message::Ping(Bytes::new())).await.is_err() {
break;
}
}
}
});
Some(heartbeat_task)
} else {
None
};
while let Some(message) = rx.try_next().await? {
match message {
Message::Text(t) => {
let mut tx_guard = tx.lock().await;
match WebSocketEvent::try_from(t) {
Ok(event) => callback.on_event(&mut tx_guard, event).await,
Err(e) => callback.on_error(e).await,
}
}
Message::Binary(b) => {
let mut tx_guard = tx.lock().await;
callback.on_data(&mut tx_guard, b).await; }
Message::Ping(bytes) => {
{
let mut tx_guard = tx.lock().await;
let _ = tx_guard.send(Message::Pong(bytes)).await;
}
}
Message::Pong(bytes) => {
callback.on_pong(bytes).await;
}
Message::Close { code, reason } => {
if let Some(handle) = &heartbeat_handle {
handle.abort();
}
callback.on_close(code, reason).await;
break;
}
}
}
if let Some(handle) = heartbeat_handle {
handle.abort();
}
Ok(())
}
}