use crate::error::{Error, Result};
use crate::sprite::Sprite;
use crate::types::{ExitStatus, Output};
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};
use tokio_tungstenite::tungstenite::Message;
// Stream identifiers. In non-TTY mode the first byte of each binary WebSocket
// frame is one of these values and the remaining bytes are the payload.

/// Prefix of frames that carry stdin data sent to the remote process.
const STREAM_STDIN: u8 = 0;
/// Prefix of frames that carry stdout data.
const STREAM_STDOUT: u8 = 1;
/// Prefix of frames that carry stderr data.
const STREAM_STDERR: u8 = 2;
/// Prefix of the frame whose payload is the exit code as UTF-8 decimal text.
const STREAM_EXIT: u8 = 3;
/// Single-byte frame sent when the local side closes stdin.
const STREAM_STDIN_EOF: u8 = 4;
/// Builder describing a program to execute on a sprite over a WebSocket.
///
/// Configure it with the chained setters, then run it with `output`,
/// `combined_output`, `status` or `spawn`.
pub struct Command {
// Sprite the command runs on; also supplies the client's base URL and token.
sprite: Sprite,
// Program to execute; sent as both the `path` and the first `cmd` query value.
program: String,
// Additional arguments, each sent as a further `cmd` query value.
args: Vec<String>,
// Environment variables, each sent as an `env=KEY=VALUE` query value.
env: HashMap<String, String>,
// Working directory, sent as the `cwd` query value when set.
dir: Option<String>,
// When true, `tty=true` is requested and binary frames are treated as raw output.
tty: bool,
// When true, `control=true` is added to the query string.
control_mode: bool,
// Sent as the `max_run_after_disconnect` query value when set.
max_run_after_disconnect: Option<u32>,
}
impl Command {
pub(crate) fn new(sprite: Sprite, program: impl Into<String>) -> Self {
Self {
sprite,
program: program.into(),
args: Vec::new(),
env: HashMap::new(),
dir: None,
tty: false,
control_mode: false,
max_run_after_disconnect: None,
}
}
/// Appends a single argument to the command line and returns the builder.
pub fn arg(mut self, arg: impl Into<String>) -> Self {
    let value: String = arg.into();
    self.args.push(value);
    self
}
/// Appends every item of `args`, in iteration order, to the command line.
pub fn args<I, S>(mut self, args: I) -> Self
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    for item in args {
        let value: String = item.into();
        self.args.push(value);
    }
    self
}
/// Sets one environment variable, replacing any earlier value for `key`.
pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
    let name: String = key.into();
    let value: String = val.into();
    self.env.insert(name, value);
    self
}
/// Sets several environment variables at once.
///
/// Later pairs overwrite earlier ones (and previously set values) that share a key.
pub fn envs<I, K, V>(mut self, vars: I) -> Self
where
    I: IntoIterator<Item = (K, V)>,
    K: Into<String>,
    V: Into<String>,
{
    let pairs = vars.into_iter().map(|(k, v)| {
        let name: String = k.into();
        let value: String = v.into();
        (name, value)
    });
    self.env.extend(pairs);
    self
}
pub fn current_dir(mut self, dir: impl Into<String>) -> Self {
self.dir = Some(dir.into());
self
}
pub fn tty(mut self, enable: bool) -> Self {
self.tty = enable;
self
}
pub fn control_mode(mut self, enable: bool) -> Self {
self.control_mode = enable;
self
}
pub fn max_run_after_disconnect(mut self, seconds: u32) -> Self {
self.max_run_after_disconnect = Some(seconds);
self
}
/// Builds the WebSocket URL of the exec endpoint for this command.
///
/// The client's base URL has its leading `https://` / `http://` scheme swapped
/// for `wss://` / `ws://`; any other base URL is used unchanged. The program is
/// sent as `path` and as the first `cmd` value, followed by one `cmd` per
/// argument, the optional `tty`, `control`, `max_run_after_disconnect` and `cwd`
/// values, and one `env=KEY=VALUE` entry per environment variable (in the map's
/// iteration order).
///
/// # Errors
///
/// Currently never fails; the `Result` is kept for the callers' `?`.
fn build_ws_url(&self) -> Result<String> {
    use std::fmt::Write as _;

    let base_url = self.sprite.client().base_url();
    // Swap only the leading scheme so that a scheme-like substring elsewhere in
    // the base URL is left untouched.
    let ws_base = if let Some(rest) = base_url.strip_prefix("https://") {
        format!("wss://{rest}")
    } else if let Some(rest) = base_url.strip_prefix("http://") {
        format!("ws://{rest}")
    } else {
        base_url.to_string()
    };

    let program = urlencoding::encode(&self.program);
    let mut url = format!(
        "{}/v1/sprites/{}/exec?path={}&cmd={}",
        ws_base,
        self.sprite.name(),
        program,
        program
    );

    // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
    for arg in &self.args {
        let _ = write!(url, "&cmd={}", urlencoding::encode(arg));
    }
    if self.tty {
        url.push_str("&tty=true");
    }
    if self.control_mode {
        url.push_str("&control=true");
    }
    if let Some(seconds) = self.max_run_after_disconnect {
        let _ = write!(url, "&max_run_after_disconnect={seconds}");
    }
    if let Some(dir) = &self.dir {
        let _ = write!(url, "&cwd={}", urlencoding::encode(dir));
    }
    for (key, val) in &self.env {
        let _ = write!(
            url,
            "&env={}={}",
            urlencoding::encode(key),
            urlencoding::encode(val)
        );
    }
    Ok(url)
}
/// Runs the command to completion and collects stdout, stderr and the exit code.
///
/// Opens a WebSocket to the exec endpoint and reads frames until the server
/// sends a close frame or the stream ends. In TTY mode every binary frame is
/// appended to `stdout` as-is; otherwise the first byte of each binary frame
/// selects the stream. Text frames are parsed as JSON and an `exit_code` field,
/// when present and representable as `i32`, sets the exit code.
///
/// # Errors
///
/// Returns an error if the handshake request cannot be built, the connection
/// cannot be established, or reading a frame fails.
///
/// # Panics
///
/// Panics if the system clock is set before the UNIX epoch.
pub async fn output(&self) -> Result<Output> {
    let url = self.build_ws_url()?;
    let token = self.sprite.client().token();
    // The handshake key is derived from the current time in nanoseconds.
    let ws_key = {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before UNIX epoch")
            .as_nanos();
        base64_encode(&nanos.to_le_bytes())
    };
    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .method("GET")
        .uri(&url)
        .header("Authorization", format!("Bearer {token}"))
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", &ws_key)
        .header("Host", extract_host(&url).unwrap_or("api.sprites.dev"))
        .body(())
        .map_err(|e| Error::InvalidResponse(e.to_string()))?;
    let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
    // The write half is unused but kept bound until the function returns.
    let (_write, mut read) = ws_stream.split();

    let mut stdout = Vec::new();
    let mut stderr = Vec::new();
    let mut exit_code: Option<i32> = None;
    while let Some(msg) = read.next().await {
        match msg? {
            Message::Binary(data) => {
                if self.tty {
                    stdout.extend_from_slice(&data);
                } else if let Some((&stream_id, payload)) = data.split_first() {
                    // Empty frames fall through: they carry no stream id.
                    match stream_id {
                        STREAM_STDOUT => stdout.extend_from_slice(payload),
                        STREAM_STDERR => stderr.extend_from_slice(payload),
                        STREAM_EXIT => {
                            if let Ok(code_str) = std::str::from_utf8(payload) {
                                exit_code = code_str.trim().parse().ok();
                            }
                        }
                        _ => {}
                    }
                }
            }
            Message::Text(text) => {
                if let Ok(val) = serde_json::from_str::<serde_json::Value>(&text) {
                    // Ignore codes that do not fit in `i32` instead of truncating them.
                    let reported = val
                        .get("exit_code")
                        .and_then(|c| c.as_i64())
                        .and_then(|c| i32::try_from(c).ok());
                    if let Some(code) = reported {
                        exit_code = Some(code);
                    }
                }
            }
            Message::Close(_) => break,
            _ => {}
        }
    }
    // NOTE(review): a session that ends without any exit code is reported as
    // status 0 here, whereas `spawn` reports 1 — confirm which is intended.
    Ok(Output {
        status: exit_code.unwrap_or(0),
        stdout,
        stderr,
    })
}
/// Runs the command to completion and returns only its exit status.
///
/// # Errors
///
/// Propagates any error returned by `output`.
pub async fn status(&self) -> Result<ExitStatus> {
    self.output()
        .await
        .map(|finished| ExitStatus::new(finished.status))
}
/// Runs the command to completion with stdout and stderr merged into one buffer.
///
/// Works like `output`, except that stdout and stderr payloads are appended to
/// the same buffer in arrival order. The merged bytes are returned in
/// `Output::stdout`; `Output::stderr` is always empty.
///
/// # Errors
///
/// Returns an error if the handshake request cannot be built, the connection
/// cannot be established, or reading a frame fails.
///
/// # Panics
///
/// Panics if the system clock is set before the UNIX epoch.
pub async fn combined_output(&self) -> Result<Output> {
    let url = self.build_ws_url()?;
    let token = self.sprite.client().token();
    // The handshake key is derived from the current time in nanoseconds.
    let ws_key = {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before UNIX epoch")
            .as_nanos();
        base64_encode(&nanos.to_le_bytes())
    };
    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .method("GET")
        .uri(&url)
        .header("Authorization", format!("Bearer {token}"))
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", &ws_key)
        .header("Host", extract_host(&url).unwrap_or("api.sprites.dev"))
        .body(())
        .map_err(|e| Error::InvalidResponse(e.to_string()))?;
    let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
    // The write half is unused but kept bound until the function returns.
    let (_write, mut read) = ws_stream.split();

    let mut combined = Vec::new();
    let mut exit_code: Option<i32> = None;
    while let Some(msg) = read.next().await {
        match msg? {
            Message::Binary(data) => {
                if self.tty {
                    combined.extend_from_slice(&data);
                } else if let Some((&stream_id, payload)) = data.split_first() {
                    // Empty frames fall through: they carry no stream id.
                    match stream_id {
                        STREAM_STDOUT | STREAM_STDERR => combined.extend_from_slice(payload),
                        STREAM_EXIT => {
                            if let Ok(code_str) = std::str::from_utf8(payload) {
                                exit_code = code_str.trim().parse().ok();
                            }
                        }
                        _ => {}
                    }
                }
            }
            Message::Text(text) => {
                if let Ok(val) = serde_json::from_str::<serde_json::Value>(&text) {
                    // Ignore codes that do not fit in `i32` instead of truncating them.
                    let reported = val
                        .get("exit_code")
                        .and_then(|c| c.as_i64())
                        .and_then(|c| i32::try_from(c).ok());
                    if let Some(code) = reported {
                        exit_code = Some(code);
                    }
                }
            }
            Message::Close(_) => break,
            _ => {}
        }
    }
    // NOTE(review): a session that ends without any exit code is reported as
    // status 0 here, whereas `spawn` reports 1 — confirm which is intended.
    Ok(Output {
        status: exit_code.unwrap_or(0),
        stdout: combined,
        stderr: Vec::new(),
    })
}
/// Starts the command and returns a [`Child`] handle for interacting with it.
///
/// After the WebSocket is connected, a background task is spawned that
/// forwards control messages (stdin data, stdin EOF, resize, kill) to the
/// socket and routes incoming frames to the child's stdout / stderr channels.
/// The task finishes when the socket reports a close frame, an error, or the
/// end of the stream, or when a kill is requested; it then publishes the exit
/// code, using `1` if none was received.
///
/// In TTY mode every binary frame is delivered on stdout unframed.
///
/// # Errors
///
/// Returns an error if the handshake request cannot be built or the
/// connection cannot be established.
///
/// # Panics
///
/// Panics if the system clock is set before the UNIX epoch.
pub async fn spawn(self) -> Result<Child> {
    let url = self.build_ws_url()?;
    let token = self.sprite.client().token().to_string();
    let tty = self.tty;
    // The handshake key is derived from the current time in nanoseconds.
    let ws_key = {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before UNIX epoch")
            .as_nanos();
        base64_encode(&nanos.to_le_bytes())
    };
    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .method("GET")
        .uri(&url)
        .header("Authorization", format!("Bearer {token}"))
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", &ws_key)
        .header("Host", extract_host(&url).unwrap_or("api.sprites.dev"))
        .body(())
        .map_err(|e| Error::InvalidResponse(e.to_string()))?;
    let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
    let (mut ws_write, mut ws_read) = ws_stream.split();

    let (control_tx, mut control_rx) = mpsc::channel::<ControlMessage>(32);
    let (stdout_tx, stdout_rx) = mpsc::channel::<Vec<u8>>(32);
    let (stderr_tx, stderr_rx) = mpsc::channel::<Vec<u8>>(32);
    let (exit_tx, exit_rx) = oneshot::channel::<i32>();

    tokio::spawn(async move {
        let mut exit_code: Option<i32> = None;
        loop {
            tokio::select! {
                // Disabled once every control sender has been dropped; the
                // socket branch below then keeps draining output on its own.
                Some(msg) = control_rx.recv() => {
                    match msg {
                        ControlMessage::Stdin(data) => {
                            let mut frame = vec![STREAM_STDIN];
                            frame.extend(data);
                            if ws_write.send(Message::Binary(frame)).await.is_err() {
                                break;
                            }
                        }
                        ControlMessage::StdinClose => {
                            let frame = vec![STREAM_STDIN_EOF];
                            let _ = ws_write.send(Message::Binary(frame)).await;
                        }
                        ControlMessage::Resize { rows, cols } => {
                            let resize_msg = serde_json::json!({
                                "type": "resize",
                                "rows": rows,
                                "cols": cols
                            });
                            let _ = ws_write.send(Message::Text(resize_msg.to_string())).await;
                        }
                        ControlMessage::Kill => {
                            let _ = ws_write.close().await;
                            break;
                        }
                    }
                }
                // Bound without a refutable pattern so that the end of the
                // stream (`None`) is observed and terminates the task; this
                // also means the branches can never all be disabled.
                incoming = ws_read.next() => {
                    match incoming {
                        Some(Ok(Message::Binary(data))) => {
                            if data.is_empty() {
                                continue;
                            }
                            if tty {
                                let _ = stdout_tx.send(data).await;
                            } else {
                                let stream_id = data[0];
                                let payload = &data[1..];
                                match stream_id {
                                    STREAM_STDOUT => {
                                        let _ = stdout_tx.send(payload.to_vec()).await;
                                    }
                                    STREAM_STDERR => {
                                        let _ = stderr_tx.send(payload.to_vec()).await;
                                    }
                                    STREAM_EXIT => {
                                        if let Ok(code_str) = std::str::from_utf8(payload) {
                                            exit_code = code_str.trim().parse().ok();
                                        }
                                    }
                                    _ => {}
                                }
                            }
                        }
                        Some(Ok(Message::Text(text))) => {
                            if let Ok(val) = serde_json::from_str::<serde_json::Value>(&text) {
                                // Ignore codes that do not fit in `i32` instead of truncating.
                                let reported = val
                                    .get("exit_code")
                                    .and_then(|c| c.as_i64())
                                    .and_then(|c| i32::try_from(c).ok());
                                if let Some(code) = reported {
                                    exit_code = Some(code);
                                }
                            }
                        }
                        Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                        Some(Ok(_)) => {}
                    }
                }
            }
        }
        // The receiver may already be gone; nothing to do in that case.
        let _ = exit_tx.send(exit_code.unwrap_or(1));
    });

    Ok(Child {
        stdin: Some(ChildStdin {
            control_tx: control_tx.clone(),
        }),
        stdout: Some(ChildStdout {
            rx: stdout_rx,
            buffer: Vec::new(),
        }),
        stderr: Some(ChildStderr {
            rx: stderr_rx,
            buffer: Vec::new(),
        }),
        control_tx,
        exit_rx: Some(exit_rx),
        tty,
    })
}
}
/// Debug output lists the sprite name, program, arguments, working directory
/// and TTY flag; the remaining fields are not printed.
impl std::fmt::Debug for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Command");
        out.field("sprite", &self.sprite.name());
        out.field("program", &self.program);
        out.field("args", &self.args);
        out.field("dir", &self.dir);
        out.field("tty", &self.tty);
        out.finish()
    }
}
/// Minimal percent-encoding helper for building query strings.
mod urlencoding {
    /// Percent-encodes `s`.
    ///
    /// ASCII letters, digits and `-`, `_`, `.`, `~` are copied unchanged; every
    /// other byte of the UTF-8 encoding is written as `%XX` with uppercase hex.
    pub fn encode(s: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut result = String::with_capacity(s.len());
        // Working on bytes is equivalent to working on chars here: every byte of
        // a multi-byte character is >= 0x80 and therefore never "unreserved".
        for byte in s.bytes() {
            match byte {
                b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                    result.push(char::from(byte));
                }
                _ => {
                    result.push('%');
                    result.push(char::from(HEX[usize::from(byte >> 4)]));
                    result.push(char::from(HEX[usize::from(byte & 0x0f)]));
                }
            }
        }
        result
    }
}
/// Encodes `data` as padded Base64 using the `A-Z a-z 0-9 + /` alphabet.
fn base64_encode(data: &[u8]) -> String {
    const TABLE: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut encoded = String::new();
    for group in data.chunks(3) {
        // Missing trailing bytes of the last group count as zero.
        let first = group[0];
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);
        let sextets = [
            first >> 2,
            (first & 0x03) << 4 | second >> 4,
            (second & 0x0f) << 2 | third >> 6,
            third & 0x3f,
        ];
        // A group of k bytes produces k + 1 symbols; the rest is `=` padding.
        for (position, sextet) in sextets.iter().enumerate() {
            if position <= group.len() {
                encoded.push(TABLE[usize::from(*sextet)] as char);
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
/// Returns the authority (`host[:port]`) part of a `ws://` or `wss://` URL.
///
/// The authority ends at the first `/`, `?` or `#`, so a URL without a path
/// does not leak its query or fragment into the result. Returns `None` when
/// `url` does not start with one of the two WebSocket schemes.
fn extract_host(url: &str) -> Option<&str> {
    let without_scheme = url
        .strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))?;
    without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
}
/// Handle to a command started with `Command::spawn`.
pub struct Child {
// Stdin handle; `None` after `take_stdin`.
stdin: Option<ChildStdin>,
// Stdout handle; `None` after `take_stdout`.
stdout: Option<ChildStdout>,
// Stderr handle; `None` after `take_stderr`.
stderr: Option<ChildStderr>,
// Channel used by `resize` and `kill` to send control messages.
control_tx: mpsc::Sender<ControlMessage>,
// Receives the exit code; taken by the first call to `wait`.
exit_rx: Option<oneshot::Receiver<i32>>,
// Whether the child was started in TTY mode.
tty: bool,
}
/// Requests sent from the `Child` / `ChildStdin` handles over the control channel.
enum ControlMessage {
// Bytes to forward to the remote process's stdin.
Stdin(Vec<u8>),
// Signals that no more stdin data will be sent.
StdinClose,
// New terminal size.
Resize { rows: u16, cols: u16 },
// Request to terminate the session.
Kill,
}
impl Child {
pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
self.stdin.as_mut()
}
pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
self.stdout.as_mut()
}
pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
self.stderr.as_mut()
}
pub fn take_stdin(&mut self) -> Option<ChildStdin> {
self.stdin.take()
}
pub fn take_stdout(&mut self) -> Option<ChildStdout> {
self.stdout.take()
}
pub fn take_stderr(&mut self) -> Option<ChildStderr> {
self.stderr.take()
}
pub fn resize(&mut self, rows: u16, cols: u16) -> Result<()> {
if !self.tty {
return Err(Error::InvalidResponse(
"resize() only works in TTY mode".to_string(),
));
}
self.control_tx
.try_send(ControlMessage::Resize { rows, cols })
.map_err(|_| Error::InvalidResponse("Child process already exited".to_string()))
}
pub fn kill(&mut self) -> Result<()> {
self.control_tx
.try_send(ControlMessage::Kill)
.map_err(|_| Error::InvalidResponse("Child process already exited".to_string()))
}
pub async fn wait(&mut self) -> Result<ExitStatus> {
match self.exit_rx.take() {
Some(rx) => {
let code = rx.await.unwrap_or(1);
Ok(ExitStatus::new(code))
}
None => Err(Error::InvalidResponse(
"wait() already called or child not started".to_string(),
)),
}
}
pub fn is_tty(&self) -> bool {
self.tty
}
}
/// Write handle for the child's standard input.
pub struct ChildStdin {
// Control channel on which stdin data and the stdin-close request are sent.
control_tx: mpsc::Sender<ControlMessage>,
}
impl ChildStdin {
pub async fn write(&self, data: &[u8]) -> Result<()> {
self.control_tx
.send(ControlMessage::Stdin(data.to_vec()))
.await
.map_err(|_| Error::InvalidResponse("Child process already exited".to_string()))
}
pub async fn write_all(&self, data: &[u8]) -> Result<()> {
self.write(data).await
}
pub async fn close(&self) -> Result<()> {
self.control_tx
.send(ControlMessage::StdinClose)
.await
.map_err(|_| Error::InvalidResponse("Child process already exited".to_string()))
}
}
/// Read handle for the child's standard output.
pub struct ChildStdout {
// Channel delivering output chunks.
rx: mpsc::Receiver<Vec<u8>>,
// Bytes of a received chunk that did not fit into the caller's last `read` buffer.
buffer: Vec<u8>,
}
impl ChildStdout {
    /// Reads up to `buf.len()` bytes of output into `buf`.
    ///
    /// Leftover bytes from an earlier chunk are served first; otherwise the
    /// call waits for the next chunk. Returns the number of bytes copied, or
    /// `0` once the channel is closed and drained.
    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        if self.buffer.is_empty() {
            let copied = match self.rx.recv().await {
                None => 0,
                Some(chunk) => {
                    let count = buf.len().min(chunk.len());
                    let (head, tail) = chunk.split_at(count);
                    buf[..count].copy_from_slice(head);
                    // Keep whatever did not fit for the next call.
                    self.buffer.extend_from_slice(tail);
                    count
                }
            };
            return Ok(copied);
        }
        let count = buf.len().min(self.buffer.len());
        buf[..count].copy_from_slice(&self.buffer[..count]);
        self.buffer.drain(..count);
        Ok(count)
    }

    /// Collects all remaining output until the channel is closed.
    pub async fn read_to_end(&mut self) -> Result<Vec<u8>> {
        let mut collected = std::mem::take(&mut self.buffer);
        loop {
            match self.rx.recv().await {
                Some(chunk) => collected.extend_from_slice(&chunk),
                None => break,
            }
        }
        Ok(collected)
    }

    /// Like [`read_to_end`](Self::read_to_end), converting the bytes to a
    /// string with invalid UTF-8 replaced lossily.
    pub async fn read_to_string(&mut self) -> Result<String> {
        self.read_to_end()
            .await
            .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())
    }
}
/// Read handle for the child's standard error.
pub struct ChildStderr {
// Channel delivering error-output chunks.
rx: mpsc::Receiver<Vec<u8>>,
// Bytes of a received chunk that did not fit into the caller's last `read` buffer.
buffer: Vec<u8>,
}
impl ChildStderr {
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
if !self.buffer.is_empty() {
let n = std::cmp::min(buf.len(), self.buffer.len());
buf[..n].copy_from_slice(&self.buffer[..n]);
self.buffer.drain(..n);
return Ok(n);
}
match self.rx.recv().await {
Some(data) => {
let n = std::cmp::min(buf.len(), data.len());
buf[..n].copy_from_slice(&data[..n]);
if n < data.len() {
self.buffer.extend_from_slice(&data[n..]);
}
Ok(n)
}
None => Ok(0), }
}
pub async fn read_to_end(&mut self) -> Result<Vec<u8>> {
let mut result = std::mem::take(&mut self.buffer);
while let Some(data) = self.rx.recv().await {
result.extend(data);
}
Ok(result)
}
pub async fn read_to_string(&mut self) -> Result<String> {
let data = self.read_to_end().await?;
Ok(String::from_utf8_lossy(&data).to_string())
}
#[doc(hidden)]
pub fn new_public(rx: mpsc::Receiver<Vec<u8>>) -> Self {
Self {
rx,
buffer: Vec::new(),
}
}
}
/// Public mirror of the internal control messages, used by the `new_public`
/// constructors. It has no counterpart for the internal stdin-close message.
pub enum ControlMessagePublic {
// Bytes destined for the child's stdin.
Stdin(Vec<u8>),
// New terminal size.
Resize { rows: u16, cols: u16 },
// Request to terminate the session.
Kill,
}
impl Child {
#[doc(hidden)]
pub fn new_public(
stdin: Option<ChildStdin>,
stdout: Option<ChildStdout>,
stderr: Option<ChildStderr>,
control_tx: mpsc::Sender<ControlMessagePublic>,
exit_rx: Option<oneshot::Receiver<i32>>,
tty: bool,
) -> Self {
let (internal_tx, mut internal_rx) = mpsc::channel::<ControlMessage>(32);
let external_tx = control_tx;
tokio::spawn(async move {
while let Some(msg) = internal_rx.recv().await {
let public_msg = match msg {
ControlMessage::Stdin(data) => ControlMessagePublic::Stdin(data),
ControlMessage::StdinClose => continue, ControlMessage::Resize { rows, cols } => {
ControlMessagePublic::Resize { rows, cols }
}
ControlMessage::Kill => ControlMessagePublic::Kill,
};
if external_tx.send(public_msg).await.is_err() {
break;
}
}
});
Self {
stdin,
stdout,
stderr,
control_tx: internal_tx,
exit_rx,
tty,
}
}
}
impl ChildStdin {
#[doc(hidden)]
pub fn new_public(control_tx: mpsc::Sender<ControlMessagePublic>) -> Self {
let (internal_tx, mut internal_rx) = mpsc::channel::<ControlMessage>(32);
let external_tx = control_tx;
tokio::spawn(async move {
while let Some(msg) = internal_rx.recv().await {
let public_msg = match msg {
ControlMessage::Stdin(data) => ControlMessagePublic::Stdin(data),
ControlMessage::StdinClose => continue,
ControlMessage::Resize { rows, cols } => {
ControlMessagePublic::Resize { rows, cols }
}
ControlMessage::Kill => ControlMessagePublic::Kill,
};
if external_tx.send(public_msg).await.is_err() {
break;
}
}
});
Self {
control_tx: internal_tx,
}
}
}
impl ChildStdout {
#[doc(hidden)]
pub fn new_public(rx: mpsc::Receiver<Vec<u8>>) -> Self {
Self {
rx,
buffer: Vec::new(),
}
}
}
/// Public wrapper around the private `base64_encode` helper; returns its
/// result for `data` unchanged.
#[doc(hidden)]
pub fn base64_encode_public(data: &[u8]) -> String {
base64_encode(data)
}