use std::collections::VecDeque;
use std::time::Duration;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use futures_util::{SinkExt, StreamExt};
use http::header::{HeaderName, HeaderValue};
use serde_json::Value;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::config::KEEPALIVE_PING_INTERVAL_SECS;
use crate::error::{Error, Result};
/// A websocket connection to a sandboxed process that exchanges JSON frames.
pub struct ProcessSocket {
/// Underlying websocket over a plain or TLS-wrapped TCP stream.
socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
/// Frames read while waiting for an acknowledgement; `next_frame` drains these (in arrival
/// order) before reading from the socket again.
pending_frames: VecDeque<Value>,
}
impl ProcessSocket {
pub async fn connect(base_url: &str, token: &str, path: &str) -> Result<Self> {
let mut request = ws_url(base_url, path)?.into_client_request()?;
request.headers_mut().insert(
HeaderName::from_static("authorization"),
HeaderValue::from_str(&format!("Bearer {token}"))?,
);
let (socket, _response) = connect_async(request).await?;
Ok(Self {
socket,
pending_frames: VecDeque::new(),
})
}
pub async fn send_json(&mut self, payload: &Value) -> Result<()> {
self.socket
.send(Message::Text(payload.to_string().into()))
.await?;
Ok(())
}
pub async fn send_stdin(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
self.send_json(&stdin_payload(data)).await?;
self.wait_for_ack("stdin_ack").await
}
pub async fn close_stdin(&mut self) -> Result<()> {
self.send_json(&close_stdin_payload()).await?;
self.wait_for_ack("close_stdin_ack").await
}
pub async fn send_ping(&mut self) -> Result<()> {
self.socket
.send(Message::Ping(b"watasu-sdk".to_vec().into()))
.await?;
Ok(())
}
pub async fn close(&mut self) -> Result<()> {
self.socket.close(None).await?;
Ok(())
}
pub async fn next_frame(&mut self) -> Result<Option<Value>> {
if let Some(frame) = self.pending_frames.pop_front() {
return Ok(Some(frame));
}
self.read_frame().await
}
async fn wait_for_ack(&mut self, ack_type: &str) -> Result<()> {
while let Some(frame) = self.read_frame().await? {
if frame.get("type").and_then(Value::as_str) == Some(ack_type) {
return Ok(());
}
self.pending_frames.push_back(frame);
}
Err(Error::Sandbox(format!(
"process websocket closed before {ack_type}"
)))
}
async fn read_frame(&mut self) -> Result<Option<Value>> {
let idle = Duration::from_secs((KEEPALIVE_PING_INTERVAL_SECS / 2).max(1));
loop {
match timeout(idle, self.socket.next()).await {
Ok(Some(Ok(message))) => match message {
Message::Text(text) => {
let frame: Value = serde_json::from_str(&text)?;
match frame.get("type").and_then(Value::as_str) {
Some("ready" | "pong") => continue,
Some("error") => {
let message = frame
.get("message")
.or_else(|| frame.get("code"))
.and_then(Value::as_str)
.unwrap_or("process error");
return Err(Error::Sandbox(message.to_string()));
}
_ => return Ok(Some(frame)),
}
}
Message::Binary(_) => {
return Err(Error::Sandbox(
"process websocket returned binary frame".into(),
))
}
Message::Close(_) => return Ok(None),
Message::Ping(payload) => self.socket.send(Message::Pong(payload)).await?,
Message::Pong(_) => continue,
Message::Frame(_) => continue,
},
Ok(Some(Err(error))) => return Err(error.into()),
Ok(None) => return Ok(None),
Err(_elapsed) => {
self.send_ping().await?;
continue;
}
}
}
}
pub fn keepalive_interval_secs(&self) -> u64 {
KEEPALIVE_PING_INTERVAL_SECS
}
}
/// Base64-encodes raw bytes with the `STANDARD` engine for transport inside JSON frames.
pub fn encode_runtime_data(data: impl AsRef<[u8]>) -> String {
    let raw: &[u8] = data.as_ref();
    BASE64.encode(raw)
}
pub(crate) fn stdin_payload(data: impl AsRef<[u8]>) -> Value {
serde_json::json!({
"type": "stdin",
"data": encode_runtime_data(data)
})
}
pub(crate) fn close_stdin_payload() -> Value {
serde_json::json!({"type": "close_stdin"})
}
/// Decodes base64 runtime data back into raw bytes.
///
/// This is best-effort: if `value` is not valid base64 it is not rejected; its own UTF-8 bytes
/// are returned unchanged instead.
pub fn decode_runtime_data_bytes(value: &str) -> Vec<u8> {
    match BASE64.decode(value) {
        Ok(decoded) => decoded,
        Err(_) => value.as_bytes().to_vec(),
    }
}
/// Decodes runtime data into text, replacing invalid UTF-8 sequences with U+FFFD.
pub fn decode_runtime_data(value: &str) -> String {
    let bytes = decode_runtime_data_bytes(value);
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(invalid) => String::from_utf8_lossy(invalid.as_bytes()).into_owned(),
    }
}
/// Builds the websocket URL for `path` on the server at `base_url`.
///
/// `https`/`http` become `wss`/`ws`; any other scheme is kept as-is. The text of `path`
/// before its first `?` replaces the URL path. If `path` contains a `?`, everything after it
/// becomes the query; otherwise `base_url`'s query is left untouched.
fn ws_url(base_url: &str, path: &str) -> Result<String> {
    let mut url = url::Url::parse(base_url)?;
    let ws_scheme = match url.scheme() {
        "https" => "wss".to_owned(),
        "http" => "ws".to_owned(),
        other => other.to_owned(),
    };
    url.set_scheme(&ws_scheme)
        .map_err(|_| Error::Sandbox("invalid websocket scheme".into()))?;
    let (route, query) = match path.split_once('?') {
        Some((route, query)) => (route, Some(query)),
        None => (path, None),
    };
    url.set_path(route);
    if let Some(query) = query {
        url.set_query(Some(query));
    }
    Ok(url.to_string())
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{
        close_stdin_payload, decode_runtime_data, decode_runtime_data_bytes, stdin_payload,
    };

    /// Outgoing stdin frames carry base64 data; the close frame carries only its type.
    #[test]
    fn process_input_payloads_match_runtime_protocol() {
        assert_eq!(
            stdin_payload("hi\n"),
            json!({"type": "stdin", "data": "aGkK"})
        );
        assert_eq!(close_stdin_payload(), json!({"type": "close_stdin"}));
    }

    /// Valid base64 decodes to the exact original bytes, including non-UTF-8 ones.
    #[test]
    fn runtime_data_decoder_preserves_bytes() {
        assert_eq!(
            decode_runtime_data_bytes("AJ+Slg=="),
            vec![0, 159, 146, 150]
        );
        assert_eq!(decode_runtime_data_bytes(""), Vec::<u8>::new());
    }

    /// Input that is not valid base64 is passed through as its raw bytes.
    #[test]
    fn runtime_data_decoder_falls_back_to_raw_bytes() {
        assert_eq!(
            decode_runtime_data_bytes("not base64!"),
            b"not base64!".to_vec()
        );
    }

    /// Text decoding returns valid UTF-8 unchanged and replaces invalid bytes with U+FFFD.
    #[test]
    fn runtime_data_text_decoder_is_lossy() {
        assert_eq!(decode_runtime_data("aGkK"), "hi\n");
        assert_eq!(decode_runtime_data("/w=="), "\u{FFFD}");
    }
}