use std::collections::VecDeque;
use std::time::Duration;
use futures_util::{SinkExt, Stream, StreamExt};
use secrecy::{ExposeSecret, SecretString};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tokio_tungstenite::{connect_async, connect_async_tls_with_config, Connector};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue;
use tokio_tungstenite::tungstenite::{Error as WsError, Message};
use tracing::{debug, info, trace, warn};
use crate::exit_api::ExitApiClient;
use crate::stream::proto::{ClientMessage, MirrorConfigMsg, ServerMessage, StrategyConfigMsg, TakeProfitLevelMsg, WatchWalletEntryMsg};
/// Delay used for the first reconnect wait; the worker also resets its backoff
/// to this value after a session that had been established ends.
const MIN_RECONNECT_BACKOFF: Duration = Duration::from_millis(100);
/// Upper bound applied to the doubling reconnect backoff.
const MAX_RECONNECT_BACKOFF: Duration = Duration::from_secs(2);
/// Period between WebSocket ping frames sent by the client on an established session.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
/// Endpoint used when local mode is off and no endpoint override is set.
pub const STREAM_ENDPOINT: &str = "wss://stream.lasersell.io/v1/ws";
/// Endpoint used when local mode is on and no endpoint override is set.
pub const LOCAL_STREAM_ENDPOINT: &str = "ws://localhost:8082/v1/ws";
/// Connection settings for the stream WebSocket API.
///
/// Configure it with the `with_*` builder methods, then call `connect` to
/// spawn the background worker and obtain a `StreamConnection`.
#[derive(Clone)]
pub struct StreamClient {
/// API key; sent as the `x-api-key` header when the WebSocket is opened.
api_key: SecretString,
/// When `true` (and no override is set) the local endpoint is used.
local: bool,
/// Explicit endpoint; takes precedence over `local` when present.
endpoint_override: Option<String>,
/// Optional pins handed to `super::tls::build_pinned_tls_config`.
/// NOTE(review): presumably base64-encoded SHA-256 SPKI hashes (per the
/// `with_spki_pins` parameter name) — confirm against the tls module.
spki_pins: Option<Vec<String>>,
}
impl StreamClient {
    /// Creates a client that targets [`STREAM_ENDPOINT`], with local mode off,
    /// no endpoint override and no certificate pins.
    pub fn new(api_key: SecretString) -> Self {
        Self {
            api_key,
            local: false,
            endpoint_override: None,
            spki_pins: None,
        }
    }

    /// Switches between the production endpoint and [`LOCAL_STREAM_ENDPOINT`].
    ///
    /// Has no effect while an endpoint override is set.
    pub fn with_local_mode(mut self, local: bool) -> Self {
        self.local = local;
        self
    }

    /// Overrides the endpoint the client connects to.
    ///
    /// Whitespace on both sides of the value is stripped, so values copied
    /// from environment variables or config files (which often carry a
    /// trailing newline or leading spaces) still form a parseable URL.
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        let endpoint = endpoint.into();
        // `trim` (not only `trim_end`): leading whitespace would otherwise
        // survive and make the later URL parsing fail.
        self.endpoint_override = Some(endpoint.trim().to_string());
        self
    }

    /// Sets the pins passed to `super::tls::build_pinned_tls_config` when the
    /// connection is opened.
    pub fn with_spki_pins(mut self, spki_sha256_b64: Vec<String>) -> Self {
        self.spki_pins = Some(spki_sha256_b64);
        self
    }

    /// Validates `configure`, spawns the background connection worker and
    /// waits for the first session to be established.
    ///
    /// # Errors
    ///
    /// * The strategy/deadline validation error, if `configure` is rejected.
    /// * The error reported by the worker when the initial session fails.
    /// * [`StreamClientError::Protocol`] if the worker stops before reporting
    ///   the outcome of the initial connect.
    pub async fn connect(
        &self,
        configure: StreamConfigure,
    ) -> Result<StreamConnection, StreamClientError> {
        validate_strategy_thresholds(&configure.strategy, configure.deadline_timeout_sec)?;

        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
        let (status_tx, status_rx) = mpsc::unbounded_channel();
        // One-shot used by the worker to report the result of the first session.
        let (ready_tx, ready_rx) = oneshot::channel();

        let url = self.endpoint().to_string();
        let api_key = self.api_key.clone();
        // A custom TLS connector is only built when pins were configured.
        let tls_connector = self.spki_pins.as_ref().map(|pins| {
            let config = super::tls::build_pinned_tls_config(pins);
            Connector::Rustls(std::sync::Arc::new(config))
        });

        tokio::spawn(async move {
            stream_connection_worker(
                url,
                api_key,
                configure,
                outbound_rx,
                inbound_tx,
                status_tx,
                ready_tx,
                tls_connector,
            )
            .await;
        });

        match ready_rx.await {
            Ok(Ok(())) => Ok(StreamConnection {
                sender: StreamSender { tx: outbound_tx },
                receiver: inbound_rx,
                status: Some(status_rx),
            }),
            Ok(Err(err)) => Err(err),
            // The worker dropped the sender without reporting a result.
            Err(_) => Err(StreamClientError::Protocol(
                "stream worker stopped before initial connect".to_string(),
            )),
        }
    }

    /// Registers every wallet proof through the exit API and then connects
    /// with the wallets' public keys.
    ///
    /// Proofs are registered one after another; the first failure aborts the
    /// call. The resulting configuration leaves `send_mode`, `tip_lamports`
    /// and `mirror_config` unset and `watch_wallets` empty.
    ///
    /// # Errors
    ///
    /// [`StreamClientError::Protocol`] when the exit client cannot be created
    /// or a registration fails; otherwise any error from [`Self::connect`].
    pub async fn connect_with_wallets(
        &self,
        proofs: &[crate::exit_api::WalletProof],
        strategy: StrategyConfigMsg,
        deadline_timeout_sec: u64,
    ) -> Result<StreamConnection, StreamClientError> {
        let exit_client = ExitApiClient::with_api_key(self.api_key.clone())
            .map_err(|e| StreamClientError::Protocol(format!("failed to create exit client: {e}")))?;
        for proof in proofs {
            exit_client
                .register_wallet(proof, None)
                .await
                .map_err(|e| StreamClientError::Protocol(format!("wallet registration failed: {e}")))?;
        }
        let wallet_pubkeys = proofs.iter().map(|p| p.wallet_pubkey.clone()).collect();
        self.connect(StreamConfigure {
            wallet_pubkeys,
            strategy,
            deadline_timeout_sec,
            send_mode: None,
            tip_lamports: None,
            watch_wallets: Vec::new(),
            mirror_config: None,
        })
        .await
    }

    /// Resolves the endpoint: the override if set, otherwise the local or
    /// production constant depending on local mode.
    fn endpoint(&self) -> &str {
        if let Some(endpoint) = self.endpoint_override.as_deref() {
            return endpoint;
        }
        if self.local {
            LOCAL_STREAM_ENDPOINT
        } else {
            STREAM_ENDPOINT
        }
    }
}
/// Session configuration supplied to `StreamClient::connect`.
///
/// The worker keeps this value and builds a fresh `ClientMessage::Configure`
/// from it at the start of every session.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct StreamConfigure {
/// Wallet public keys copied into the `Configure` message.
pub wallet_pubkeys: Vec<String>,
/// Strategy copied into the `Configure` message; validated before connecting.
pub strategy: StrategyConfigMsg,
/// Deadline in seconds; `0` counts as "not set" for validation purposes.
/// NOTE(review): this value is validated in `connect` but is not a field of
/// the `Configure` message built in this file — confirm how it reaches the server.
pub deadline_timeout_sec: u64,
/// Optional send mode copied into the `Configure` message.
pub send_mode: Option<String>,
/// Optional tip (lamports, per the field name) copied into the `Configure` message.
pub tip_lamports: Option<u64>,
/// Watch-wallet entries copied into the `Configure` message.
pub watch_wallets: Vec<WatchWalletEntryMsg>,
/// Optional mirror configuration copied into the `Configure` message.
pub mirror_config: Option<MirrorConfigMsg>,
}
impl StreamConfigure {
    /// Builds a configuration for the given wallets and strategy.
    ///
    /// The deadline is `0`, the optional settings are `None` and the
    /// watch-wallet list is empty.
    pub fn new(wallet_pubkeys: Vec<String>, strategy: StrategyConfigMsg) -> Self {
        let watch_wallets = Vec::new();
        Self {
            wallet_pubkeys,
            strategy,
            deadline_timeout_sec: 0,
            send_mode: None,
            tip_lamports: None,
            watch_wallets,
            mirror_config: None,
        }
    }

    /// Same as [`StreamConfigure::new`], for exactly one wallet.
    pub fn single_wallet(wallet_pubkey: impl Into<String>, strategy: StrategyConfigMsg) -> Self {
        let only_wallet = vec![wallet_pubkey.into()];
        Self::new(only_wallet, strategy)
    }

    /// Builds a single-wallet configuration from optional thresholds.
    ///
    /// Missing percentages and a missing deadline become `0`; the strategy is
    /// produced by [`strategy_config_from_optional`] with the trailing stop
    /// and sell-on-graduation arguments left as `None`.
    pub fn single_wallet_optional(
        wallet_pubkey: impl Into<String>,
        target_profit_pct: Option<f64>,
        stop_loss_pct: Option<f64>,
        deadline_timeout_sec: Option<u64>,
    ) -> Self {
        let only_wallet = vec![wallet_pubkey.into()];
        let strategy = strategy_config_from_optional(target_profit_pct, stop_loss_pct, None, None);
        let mut configure = Self::new(only_wallet, strategy);
        configure.deadline_timeout_sec = deadline_timeout_sec.unwrap_or_default();
        configure
    }
}
/// Builds a [`StrategyConfigMsg`] from optional settings.
///
/// Every `None` percentage becomes `0.0` and a `None` flag becomes `false`;
/// all other strategy fields keep the values set by
/// [`StrategyConfigBuilder::new`].
pub fn strategy_config_from_optional(
    target_profit_pct: Option<f64>,
    stop_loss_pct: Option<f64>,
    trailing_stop_pct: Option<f64>,
    sell_on_graduation: Option<bool>,
) -> StrategyConfigMsg {
    let builder = StrategyConfigBuilder::new();
    let builder = builder.target_profit_pct(target_profit_pct.unwrap_or_default());
    let builder = builder.stop_loss_pct(stop_loss_pct.unwrap_or_default());
    let builder = builder.trailing_stop_pct(trailing_stop_pct.unwrap_or_default());
    let builder = builder.sell_on_graduation(sell_on_graduation.unwrap_or_default());
    builder.build()
}
/// Fluent builder for `StrategyConfigMsg`.
pub struct StrategyConfigBuilder {
/// The message being assembled; returned as-is by `build`.
msg: StrategyConfigMsg,
}
impl StrategyConfigBuilder {
    /// Starts a builder with every percentage at `0.0`, every flag `false`
    /// and no take-profit levels.
    pub fn new() -> Self {
        let msg = StrategyConfigMsg {
            target_profit_pct: 0.0,
            stop_loss_pct: 0.0,
            trailing_stop_pct: 0.0,
            sell_on_graduation: false,
            take_profit_levels: Vec::new(),
            liquidity_guard: false,
            breakeven_trail_pct: 0.0,
        };
        Self { msg }
    }

    /// Applies one mutation to the message under construction and hands the
    /// builder back for chaining.
    fn update(mut self, change: impl FnOnce(&mut StrategyConfigMsg)) -> Self {
        change(&mut self.msg);
        self
    }

    /// Sets `target_profit_pct`.
    pub fn target_profit_pct(self, pct: f64) -> Self {
        self.update(|msg| msg.target_profit_pct = pct)
    }

    /// Sets `stop_loss_pct`.
    pub fn stop_loss_pct(self, pct: f64) -> Self {
        self.update(|msg| msg.stop_loss_pct = pct)
    }

    /// Sets `trailing_stop_pct`.
    pub fn trailing_stop_pct(self, pct: f64) -> Self {
        self.update(|msg| msg.trailing_stop_pct = pct)
    }

    /// Sets `sell_on_graduation`.
    pub fn sell_on_graduation(self, enabled: bool) -> Self {
        self.update(|msg| msg.sell_on_graduation = enabled)
    }

    /// Replaces the list of take-profit levels.
    pub fn take_profit_levels(self, levels: Vec<TakeProfitLevelMsg>) -> Self {
        self.update(|msg| msg.take_profit_levels = levels)
    }

    /// Sets `liquidity_guard`.
    pub fn liquidity_guard(self, enabled: bool) -> Self {
        self.update(|msg| msg.liquidity_guard = enabled)
    }

    /// Sets `breakeven_trail_pct`.
    pub fn breakeven_trail_pct(self, pct: f64) -> Self {
        self.update(|msg| msg.breakeven_trail_pct = pct)
    }

    /// Finishes the builder and returns the assembled message.
    pub fn build(self) -> StrategyConfigMsg {
        let Self { msg } = self;
        msg
    }
}
impl Default for StrategyConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Connection state changes published by the background worker on the
/// status channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamConnectionStatus {
/// Sent after the server's `HelloOk` has been received for a session.
Connected,
/// Sent when an established session ends, whether the worker is about to
/// reconnect or is shutting down.
Disconnected,
}
/// Handle returned by `StreamClient::connect`: an outbound sender plus the
/// receivers fed by the background worker.
#[derive(Debug)]
pub struct StreamConnection {
/// Queue of client messages consumed by the worker.
sender: StreamSender,
/// Server messages forwarded by the worker.
receiver: mpsc::UnboundedReceiver<ServerMessage>,
/// Status updates; `None` once `take_status` has removed the receiver.
status: Option<mpsc::UnboundedReceiver<StreamConnectionStatus>>,
}
impl StreamConnection {
    /// Returns a clone of the outbound sender.
    pub fn sender(&self) -> StreamSender {
        self.sender.clone()
    }

    /// Splits the connection into its sender and server-message receiver.
    ///
    /// The status receiver, if still present, is dropped.
    pub fn split(self) -> (StreamSender, mpsc::UnboundedReceiver<ServerMessage>) {
        (self.sender, self.receiver)
    }

    /// Splits the connection into sender, server-message receiver and status
    /// receiver.
    ///
    /// If the status receiver was already removed with [`Self::take_status`],
    /// the returned status receiver belongs to a channel whose sender has
    /// been dropped, so it never yields a value.
    pub fn split_with_status(
        self,
    ) -> (
        StreamSender,
        mpsc::UnboundedReceiver<ServerMessage>,
        mpsc::UnboundedReceiver<StreamConnectionStatus>,
    ) {
        let status = self
            .status
            .unwrap_or_else(|| mpsc::unbounded_channel().1);
        (self.sender, self.receiver, status)
    }

    /// Removes and returns the status receiver; later calls return `None`.
    pub fn take_status(&mut self) -> Option<mpsc::UnboundedReceiver<StreamConnectionStatus>> {
        self.status.take()
    }

    /// Waits for the next server message; `None` once the worker has stopped
    /// and the queue is drained.
    pub async fn recv(&mut self) -> Option<ServerMessage> {
        self.receiver.recv().await
    }

    /// Converts the connection into two lanes fed by a spawned forwarding task.
    ///
    /// `PnlUpdate` messages go to a bounded "low" lane of `low_capacity`
    /// slots and are discarded when that lane is full; every other message
    /// goes to an unbounded "high" lane. The status receiver is dropped.
    ///
    /// A `low_capacity` of `0` is treated as `1`, because tokio's bounded
    /// channel panics when created with a zero capacity. The forwarding task
    /// ends when the worker stops or once both lane receivers are gone.
    pub fn into_lanes(self, low_capacity: usize) -> StreamConnectionLanes {
        let Self {
            sender,
            mut receiver,
            status: _,
        } = self;
        let (high_tx, high_rx) = mpsc::unbounded_channel();
        // Clamp to avoid the zero-capacity panic in `mpsc::channel`.
        let (low_tx, low_rx) = mpsc::channel(low_capacity.max(1));
        tokio::spawn(async move {
            while let Some(message) = receiver.recv().await {
                match message {
                    ServerMessage::PnlUpdate { .. } => {
                        // Best effort: a full (or closed) low lane drops the update.
                        let _ = low_tx.try_send(message);
                    }
                    _ => {
                        let _ = high_tx.send(message);
                    }
                }
                // Nobody can receive any more: stop instead of draining forever.
                if high_tx.is_closed() && low_tx.is_closed() {
                    break;
                }
            }
        });
        StreamConnectionLanes {
            sender,
            high: high_rx,
            low: low_rx,
        }
    }
}
/// Connection whose server messages are split into two lanes by
/// `StreamConnection::into_lanes`.
#[derive(Debug)]
pub struct StreamConnectionLanes {
/// Queue of client messages consumed by the worker.
sender: StreamSender,
/// Unbounded lane carrying every message except `PnlUpdate`.
high: mpsc::UnboundedReceiver<ServerMessage>,
/// Bounded lane carrying `PnlUpdate` messages.
low: mpsc::Receiver<ServerMessage>,
}
impl StreamConnectionLanes {
    /// Returns a clone of the outbound sender.
    pub fn sender(&self) -> StreamSender {
        let handle = &self.sender;
        handle.clone()
    }

    /// Breaks the value apart into `(sender, high lane, low lane)`.
    pub fn split(
        self,
    ) -> (
        StreamSender,
        mpsc::UnboundedReceiver<ServerMessage>,
        mpsc::Receiver<ServerMessage>,
    ) {
        let Self { sender, high, low } = self;
        (sender, high, low)
    }

    /// Waits for the next message on the high lane; `None` once it is closed
    /// and drained.
    pub async fn recv_high(&mut self) -> Option<ServerMessage> {
        let lane = &mut self.high;
        lane.recv().await
    }

    /// Waits for the next message on the low lane; `None` once it is closed
    /// and drained.
    pub async fn recv_low(&mut self) -> Option<ServerMessage> {
        let lane = &mut self.low;
        lane.recv().await
    }
}
/// Identifies a position either by its token account or by its numeric id.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PositionSelector {
    /// Select by token account (string form).
    TokenAccount(String),
    /// Select by position id.
    PositionId(u64),
}

/// Conversion into a [`PositionSelector`].
///
/// String-like values select by token account; a `u64` selects by position id.
pub trait IntoPositionSelector {
    /// Consumes the value and produces the selector it stands for.
    fn into_position_selector(self) -> PositionSelector;
}

/// A selector converts to itself.
impl IntoPositionSelector for PositionSelector {
    fn into_position_selector(self) -> PositionSelector {
        self
    }
}

/// An owned string is taken as a token account without copying.
impl IntoPositionSelector for String {
    fn into_position_selector(self) -> PositionSelector {
        let token_account = self;
        PositionSelector::TokenAccount(token_account)
    }
}

/// A borrowed `String` is cloned into a token-account selector.
impl IntoPositionSelector for &String {
    fn into_position_selector(self) -> PositionSelector {
        let token_account = self.to_owned();
        PositionSelector::TokenAccount(token_account)
    }
}

/// A string slice is copied into a token-account selector.
impl IntoPositionSelector for &str {
    fn into_position_selector(self) -> PositionSelector {
        let token_account = String::from(self);
        PositionSelector::TokenAccount(token_account)
    }
}

/// A `u64` is taken as a position id.
impl IntoPositionSelector for u64 {
    fn into_position_selector(self) -> PositionSelector {
        let position_id = self;
        PositionSelector::PositionId(position_id)
    }
}
/// Cloneable handle for queueing client messages to the background worker.
#[derive(Clone, Debug)]
pub struct StreamSender {
/// Sending half of the outbound queue; the worker holds the receiving half.
tx: mpsc::UnboundedSender<ClientMessage>,
}
impl StreamSender {
pub fn send(&self, message: ClientMessage) -> Result<(), StreamClientError> {
self.tx
.send(message)
.map_err(|_| StreamClientError::SendQueueClosed)
}
pub fn ping(&self, client_time_ms: u64) -> Result<(), StreamClientError> {
self.send(ClientMessage::Ping { client_time_ms })
}
pub fn update_strategy(&self, strategy: StrategyConfigMsg) -> Result<(), StreamClientError> {
self.send(ClientMessage::UpdateStrategy { strategy })
}
pub fn update_position_strategy(
&self,
position_id: u64,
strategy: StrategyConfigMsg,
) -> Result<(), StreamClientError> {
self.send(ClientMessage::UpdatePositionStrategy {
position_id,
strategy,
})
}
pub fn update_wallets(&self, wallet_pubkeys: Vec<String>) -> Result<(), StreamClientError> {
self.send(ClientMessage::UpdateWallets { wallet_pubkeys })
}
pub fn update_watch_wallets(&self, watch_wallets: Vec<WatchWalletEntryMsg>) -> Result<(), StreamClientError> {
self.send(ClientMessage::UpdateWatchWallets { watch_wallets })
}
pub fn close_position<S>(&self, selector: S) -> Result<(), StreamClientError>
where
S: IntoPositionSelector,
{
self.send(close_message(selector.into_position_selector()))
}
pub fn close_by_id(&self, position_id: u64) -> Result<(), StreamClientError> {
self.close_position(PositionSelector::PositionId(position_id))
}
pub fn request_exit_signal<S>(
&self,
selector: S,
slippage_bps: Option<u16>,
) -> Result<(), StreamClientError>
where
S: IntoPositionSelector,
{
self.send(request_exit_signal_message(
selector.into_position_selector(),
slippage_bps,
))
}
pub fn request_exit_signal_by_id(
&self,
position_id: u64,
slippage_bps: Option<u16>,
) -> Result<(), StreamClientError> {
self.request_exit_signal(PositionSelector::PositionId(position_id), slippage_bps)
}
pub fn mirror_buy_result(
&self,
mint: impl Into<String>,
success: bool,
) -> Result<(), StreamClientError> {
self.send(ClientMessage::MirrorBuyResult {
mint: mint.into(),
success,
})
}
}
/// Builds the `ClosePosition` message for `selector`.
///
/// Exactly one of `position_id` / `token_account` is populated, depending on
/// the selector variant.
fn close_message(selector: PositionSelector) -> ClientMessage {
    let (position_id, token_account) = match selector {
        PositionSelector::PositionId(id) => (Some(id), None),
        PositionSelector::TokenAccount(account) => (None, Some(account)),
    };
    ClientMessage::ClosePosition {
        position_id,
        token_account,
    }
}
/// Builds the `RequestExitSignal` message for `selector`.
///
/// Exactly one of `position_id` / `token_account` is populated, depending on
/// the selector variant; `slippage_bps` is passed through unchanged.
fn request_exit_signal_message(
    selector: PositionSelector,
    slippage_bps: Option<u16>,
) -> ClientMessage {
    let (position_id, token_account) = match selector {
        PositionSelector::PositionId(id) => (Some(id), None),
        PositionSelector::TokenAccount(account) => (None, Some(account)),
    };
    ClientMessage::RequestExitSignal {
        position_id,
        token_account,
        slippage_bps,
    }
}
/// Errors produced by the stream client and its background worker.
#[derive(Debug, Error)]
pub enum StreamClientError {
/// Error reported by the WebSocket layer.
#[error("websocket error: {0}")]
WebSocket(#[from] WsError),
/// A message could not be serialized or deserialized as JSON.
#[error("json error: {0}")]
Json(#[from] serde_json::Error),
/// The API key could not be turned into an HTTP header value.
#[error("invalid api-key header: {0}")]
InvalidApiKeyHeader(#[from] InvalidHeaderValue),
/// The outbound queue to the worker is closed.
#[error("send queue is closed")]
SendQueueClosed,
/// Validation failure, unexpected handshake result, or another
/// condition described by the contained text.
#[error("protocol error: {0}")]
Protocol(String),
}
/// Checks that the strategy percentages are usable and that at least one
/// exit condition is enabled.
///
/// The target-profit, stop-loss and trailing-stop percentages are validated
/// in that order (each must be finite and non-negative). The configuration is
/// then accepted when any of them is greater than zero or when
/// `deadline_timeout_sec` is greater than zero.
///
/// # Errors
///
/// [`StreamClientError::Protocol`] naming the first invalid field, or stating
/// that no exit condition is enabled.
pub(crate) fn validate_strategy_thresholds(
    strategy: &StrategyConfigMsg,
    deadline_timeout_sec: u64,
) -> Result<(), StreamClientError> {
    let thresholds = [
        (strategy.target_profit_pct, "strategy.target_profit_pct"),
        (strategy.stop_loss_pct, "strategy.stop_loss_pct"),
        (strategy.trailing_stop_pct, "strategy.trailing_stop_pct"),
    ];
    for &(value, field) in &thresholds {
        validate_strategy_value(value, field)?;
    }

    let any_pct_enabled = thresholds.iter().any(|&(value, _)| value > 0.0);
    if !any_pct_enabled && deadline_timeout_sec == 0 {
        return Err(StreamClientError::Protocol(
            "at least one of strategy.target_profit_pct, strategy.stop_loss_pct, strategy.trailing_stop_pct, or deadline_timeout_sec must be > 0"
                .to_string(),
        ));
    }
    Ok(())
}
/// Accepts `value` only when it is finite and not negative.
///
/// The finiteness check runs first, so NaN and infinities are reported as
/// non-finite rather than as negative.
///
/// # Errors
///
/// [`StreamClientError::Protocol`] with a message that starts with `field`.
fn validate_strategy_value(value: f64, field: &str) -> Result<(), StreamClientError> {
    let complaint = if !value.is_finite() {
        format!("{field} must be a finite number")
    } else if value < 0.0 {
        format!("{field} must be >= 0")
    } else {
        return Ok(());
    };
    Err(StreamClientError::Protocol(complaint))
}
/// How an established session ended, as reported by `run_connected_session`.
enum SessionOutcome {
/// The outbound queue was closed; the worker should stop.
GracefulShutdown,
/// The socket failed or was closed; the worker should open a new session.
Reconnect,
}
/// Background task that owns the WebSocket for the lifetime of a connection.
///
/// Runs sessions in a loop via `run_connected_session`. Between sessions it
/// waits for the current backoff while buffering outbound messages into
/// `pending`, so they can be flushed at the start of the next session.
///
/// `ready_tx` reports the outcome of the *first* session to `connect`: it is
/// consumed either by the session (on success) or here (on failure). A
/// failure of the first session ends the worker; later failures are retried.
///
/// The worker stops when the outbound queue is closed (every `StreamSender`
/// has been dropped).
async fn stream_connection_worker(
url: String,
api_key: SecretString,
configure: StreamConfigure,
mut outbound_rx: mpsc::UnboundedReceiver<ClientMessage>,
inbound_tx: mpsc::UnboundedSender<ServerMessage>,
status_tx: mpsc::UnboundedSender<StreamConnectionStatus>,
ready_tx: oneshot::Sender<Result<(), StreamClientError>>,
tls_connector: Option<Connector>,
) {
// Wrapped in `Option` so whoever reports first can `take()` it.
let mut ready_tx = Some(ready_tx);
// Outbound messages that could not be sent yet; survives across sessions.
let mut pending = VecDeque::new();
let mut backoff = MIN_RECONNECT_BACKOFF;
debug!(event = "stream_connecting");
loop {
match run_connected_session(
&url,
&api_key,
&configure,
&mut outbound_rx,
&inbound_tx,
&status_tx,
&mut pending,
&mut ready_tx,
tls_connector.clone(),
)
.await
{
Ok(SessionOutcome::GracefulShutdown) => {
info!(event = "stream_worker_graceful_shutdown");
let _ = status_tx.send(StreamConnectionStatus::Disconnected);
// Only relevant if the session ended without ever signalling readiness.
if let Some(tx) = ready_tx.take() {
let _ = tx.send(Err(StreamClientError::SendQueueClosed));
}
break;
}
Ok(SessionOutcome::Reconnect) => {
warn!(event = "stream_worker_reconnect");
let _ = status_tx.send(StreamConnectionStatus::Disconnected);
// The session had been established, so start the backoff over.
backoff = MIN_RECONNECT_BACKOFF;
}
Err(err) => {
warn!(event = "stream_worker_connect_error", error = %err);
// First session failed: hand the error to `connect` and stop.
// Otherwise fall through and retry after the backoff.
if let Some(tx) = ready_tx.take() {
let _ = tx.send(Err(err));
return;
}
}
}
// No sender left: nothing more can ever be queued, so stop retrying.
if outbound_rx.is_closed() {
break;
}
debug!(event = "stream_reconnect_backoff", delay_ms = backoff.as_millis() as u64);
// Returns `false` when the outbound queue closes during the wait.
if !collect_messages_during_delay(backoff, &mut outbound_rx, &mut pending).await {
break;
}
// Double the delay for the next attempt, capped at the maximum.
backoff = std::cmp::min(backoff.saturating_mul(2), MAX_RECONNECT_BACKOFF);
}
}
/// Opens one WebSocket session and drives it until it ends.
///
/// Steps, in order:
/// 1. Connect with the API key in the `x-api-key` header (through the pinned
///    TLS connector when one is supplied).
/// 2. Send a `Configure` message built from `configure`.
/// 3. Wait for the server's first text message and require it to be `HelloOk`.
/// 4. Forward `HelloOk`, publish `Connected`, and signal readiness through
///    `ready_tx` if it has not been used yet.
/// 5. Flush `pending`, then loop over keepalive ticks, outbound messages and
///    inbound frames.
///
/// Returns `Ok(GracefulShutdown)` when the outbound queue is closed and
/// `Ok(Reconnect)` when the socket fails or closes after the handshake.
///
/// # Errors
///
/// Any failure up to and including the `HelloOk` check (request building,
/// header parsing, connecting, sending `Configure`, unexpected first message)
/// is returned as `Err`.
async fn run_connected_session(
url: &str,
api_key: &SecretString,
configure: &StreamConfigure,
outbound_rx: &mut mpsc::UnboundedReceiver<ClientMessage>,
inbound_tx: &mpsc::UnboundedSender<ServerMessage>,
status_tx: &mpsc::UnboundedSender<StreamConnectionStatus>,
pending: &mut VecDeque<ClientMessage>,
ready_tx: &mut Option<oneshot::Sender<Result<(), StreamClientError>>>,
tls_connector: Option<Connector>,
) -> Result<SessionOutcome, StreamClientError> {
let mut request = url.into_client_request()?;
let api_key_header = api_key.expose_secret().parse()?;
request.headers_mut().insert("x-api-key", api_key_header);
let (mut socket, _) = if let Some(connector) = tls_connector {
connect_async_tls_with_config(request, None, false, Some(connector)).await?
} else {
connect_async(request).await?
};
debug!(event = "stream_ws_connected");
// NOTE(review): `configure.deadline_timeout_sec` is not part of this message.
let configure_msg = ClientMessage::Configure {
wallet_pubkeys: configure.wallet_pubkeys.clone(),
strategy: configure.strategy.clone(),
send_mode: configure.send_mode.clone(),
tip_lamports: configure.tip_lamports,
watch_wallets: configure.watch_wallets.clone(),
mirror_config: configure.mirror_config.clone(),
};
send_client_message(&mut socket, &configure_msg).await?;
debug!(event = "stream_configure_sent");
let hello_message = recv_server_message_before_configure(&mut socket).await?;
if !matches!(&hello_message, ServerMessage::HelloOk { .. }) {
// Surface a server-side error verbatim; describe anything else as JSON
// (falling back to the Debug form if it cannot be serialized).
let detail = match &hello_message {
ServerMessage::Error { code, message } => format!("server error [{code}]: {message}"),
other => format!("unexpected message: {}", serde_json::to_string(other).unwrap_or_else(|_| format!("{other:?}"))),
};
return Err(StreamClientError::Protocol(
format!("expected hello_ok after configure, got: {detail}"),
));
}
debug!(event = "stream_hello_ok_received");
// Send results are ignored: the consumer may have dropped its receivers.
let _ = inbound_tx.send(hello_message);
let _ = status_tx.send(StreamConnectionStatus::Connected);
info!(event = "stream_configured");
// Only present during the first session; unblocks `connect`.
if let Some(tx) = ready_tx.take() {
let _ = tx.send(Ok(()));
}
// Flush messages buffered while disconnected, oldest first. On failure the
// message goes back to the front so ordering is kept for the next session.
while let Some(next) = pending.pop_front() {
if send_client_message(&mut socket, &next).await.is_err() {
pending.push_front(next);
return Ok(SessionOutcome::Reconnect);
}
}
let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL);
// Per tokio's docs the first tick of an interval completes immediately;
// resetting defers the first ping by one full period.
keepalive.reset();
loop {
tokio::select! {
_ = keepalive.tick() => {
if socket.send(Message::Ping(vec![].into())).await.is_err() {
return Ok(SessionOutcome::Reconnect);
}
}
maybe_outbound = outbound_rx.recv() => {
match maybe_outbound {
Some(client_msg) => {
// Keep the unsent message so the next session retries it.
if send_client_message(&mut socket, &client_msg).await.is_err() {
pending.push_front(client_msg);
return Ok(SessionOutcome::Reconnect);
}
}
None => {
// Every sender is gone: close the socket and stop the worker.
let _ = socket.close(None).await;
return Ok(SessionOutcome::GracefulShutdown);
}
}
}
maybe_inbound = socket.next() => {
match maybe_inbound {
Some(Ok(Message::Text(text))) => {
trace!(event = "stream_raw_message", raw = %text);
match parse_server_message(&text) {
Ok(server_msg) => {
debug!(event = "stream_server_msg", msg_type = server_msg_label(&server_msg));
let _ = inbound_tx.send(server_msg);
}
// An unparseable message is logged and skipped; the session continues.
Err(err) => {
warn!(
event = "stream_msg_parse_error",
error = %err,
raw_message = %text,
"skipping unparseable server message"
);
}
}
}
Some(Ok(Message::Ping(payload))) => {
if socket.send(Message::Pong(payload)).await.is_err() {
return Ok(SessionOutcome::Reconnect);
}
}
Some(Ok(Message::Pong(_))) => {}
Some(Ok(Message::Close(_))) => {
debug!(event = "stream_ws_close_received");
return Ok(SessionOutcome::Reconnect);
}
// Any other frame kind is treated as a reason to reconnect.
Some(Ok(_)) => return Ok(SessionOutcome::Reconnect),
Some(Err(_)) => {
warn!(event = "stream_ws_error");
return Ok(SessionOutcome::Reconnect);
}
None => {
debug!(event = "stream_ws_ended");
return Ok(SessionOutcome::Reconnect);
}
}
}
}
}
}
/// Reads frames until the first text frame arrives and returns it parsed.
///
/// Ping frames are answered with a pong and pong frames are ignored while
/// waiting.
///
/// # Errors
///
/// * [`StreamClientError::WebSocket`] when reading a frame or sending a pong fails.
/// * [`StreamClientError::Json`] when the text frame cannot be parsed.
/// * [`StreamClientError::Protocol`] when the socket closes, ends, or
///   delivers any other kind of frame first.
async fn recv_server_message_before_configure<S>(
    socket: &mut tokio_tungstenite::WebSocketStream<S>,
) -> Result<ServerMessage, StreamClientError>
where
    tokio_tungstenite::WebSocketStream<S>: futures_util::Sink<Message, Error = WsError>
        + Stream<Item = Result<Message, WsError>>
        + Unpin,
{
    while let Some(frame) = socket.next().await {
        // A read error converts into `StreamClientError::WebSocket`.
        match frame? {
            Message::Text(text) => return parse_server_message(&text),
            Message::Ping(payload) => {
                socket.send(Message::Pong(payload)).await?;
            }
            Message::Pong(_) => {}
            Message::Close(_) => {
                return Err(StreamClientError::Protocol(
                    "socket closed before hello_ok".to_string(),
                ));
            }
            _ => {
                return Err(StreamClientError::Protocol(
                    "received non-text frame before hello_ok".to_string(),
                ));
            }
        }
    }
    // The stream finished without producing a text frame.
    Err(StreamClientError::Protocol(
        "socket ended before hello_ok".to_string(),
    ))
}
fn server_msg_label(msg: &ServerMessage) -> &'static str {
match msg {
ServerMessage::HelloOk { .. } => "hello_ok",
ServerMessage::BalanceUpdate { .. } => "balance_update",
ServerMessage::PositionOpened { .. } => "position_opened",
ServerMessage::PositionClosed { .. } => "position_closed",
ServerMessage::ExitSignalWithTx { .. } => "exit_signal_with_tx",
ServerMessage::PnlUpdate { .. } => "pnl_update",
ServerMessage::LiquiditySnapshot { .. } => "liquidity_snapshot",
ServerMessage::TradeTick { .. } => "trade_tick",
ServerMessage::Pong { .. } => "pong",
ServerMessage::Error { .. } => "error",
ServerMessage::MirrorBuySignal { .. } => "mirror_buy_signal",
ServerMessage::MirrorBuyFailed { .. } => "mirror_buy_failed",
ServerMessage::MirrorWalletAutoDisabled { .. } => "mirror_wallet_auto_disabled",
}
}
/// Deserializes one JSON text frame into a [`ServerMessage`].
///
/// # Errors
///
/// [`StreamClientError::Json`] when `text` is not a valid server message.
fn parse_server_message(text: &str) -> Result<ServerMessage, StreamClientError> {
    match serde_json::from_str(text) {
        Ok(message) => Ok(message),
        Err(err) => Err(StreamClientError::Json(err)),
    }
}
async fn send_client_message<S>(
socket: &mut tokio_tungstenite::WebSocketStream<S>,
message: &ClientMessage,
) -> Result<(), StreamClientError>
where
tokio_tungstenite::WebSocketStream<S>: futures_util::Sink<Message, Error = WsError> + Unpin,
{
let text = serde_json::to_string(message)?;
socket.send(Message::Text(text)).await?;
Ok(())
}
/// Waits for `delay` while moving any outbound messages that arrive into
/// `pending` (appended at the back).
///
/// Returns `true` when the delay elapsed and `false` when the outbound queue
/// was closed first.
async fn collect_messages_during_delay(
    delay: Duration,
    outbound_rx: &mut mpsc::UnboundedReceiver<ClientMessage>,
    pending: &mut VecDeque<ClientMessage>,
) -> bool {
    let timer = tokio::time::sleep(delay);
    tokio::pin!(timer);
    loop {
        let received = tokio::select! {
            _ = &mut timer => return true,
            maybe_message = outbound_rx.recv() => maybe_message,
        };
        if let Some(message) = received {
            pending.push_back(message);
        } else {
            // Every sender is gone; nothing else can arrive.
            return false;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{
        strategy_config_from_optional, validate_strategy_thresholds, StreamClient,
        LOCAL_STREAM_ENDPOINT, STREAM_ENDPOINT,
    };
    use secrecy::SecretString;

    /// Client with a throwaway API key, shared by the endpoint tests.
    fn test_client() -> StreamClient {
        StreamClient::new(SecretString::new("test-api-key".to_string()))
    }

    #[test]
    fn stream_client_uses_production_endpoint_by_default() {
        assert_eq!(test_client().endpoint(), STREAM_ENDPOINT);
    }

    #[test]
    fn stream_client_uses_local_endpoint_when_enabled() {
        let local_client = test_client().with_local_mode(true);
        assert_eq!(local_client.endpoint(), LOCAL_STREAM_ENDPOINT);
    }

    #[test]
    fn stream_client_endpoint_override_takes_precedence() {
        // The override wins over local mode and trailing whitespace is removed.
        let overridden = test_client()
            .with_local_mode(true)
            .with_endpoint("wss://stream-dev.example/ws \n");
        assert_eq!(overridden.endpoint(), "wss://stream-dev.example/ws");
    }

    #[test]
    fn optional_strategy_builder_defaults_unset_values_to_zero() {
        let built = strategy_config_from_optional(None, Some(1.5), None, None);
        assert_eq!(built.target_profit_pct, 0.0);
        assert_eq!(built.stop_loss_pct, 1.5);
    }

    #[test]
    fn validation_accepts_target_only() {
        let target_only = strategy_config_from_optional(Some(2.0), None, None, None);
        assert!(validate_strategy_thresholds(&target_only, 0).is_ok());
    }

    #[test]
    fn validation_accepts_stop_only() {
        let stop_only = strategy_config_from_optional(None, Some(1.0), None, None);
        assert!(validate_strategy_thresholds(&stop_only, 0).is_ok());
    }

    #[test]
    fn validation_accepts_deadline_only() {
        let nothing_set = strategy_config_from_optional(None, None, None, None);
        assert!(validate_strategy_thresholds(&nothing_set, 45).is_ok());
    }

    #[test]
    fn validation_rejects_when_all_thresholds_disabled() {
        let nothing_set = strategy_config_from_optional(None, None, None, None);
        assert!(validate_strategy_thresholds(&nothing_set, 0).is_err());
    }

    #[test]
    fn validation_rejects_negative_values() {
        let negative_target = strategy_config_from_optional(Some(-1.0), None, None, None);
        assert!(validate_strategy_thresholds(&negative_target, 0).is_err());
    }
}