use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::error;
use uuid::Uuid;
use crate::command::{
CommandContext, CommandEvent, CommandExecutor, CommandHandle, CommandRequest, CommandResult,
};
use crate::error::UbiquityError;
/// Command executor that hands commands to a remote worker over HTTP and
/// receives the resulting events over a WebSocket.
pub struct CloudCommandExecutor {
/// Shared context in which a handle is registered for each running command.
context: Arc<CommandContext>,
/// Capacity of the event channel created for each executed command.
event_buffer_size: usize,
/// HTTP client used for every request sent to the worker.
client: Client,
/// Base URL to which the `/api/commands/...` paths are appended.
worker_url: String,
/// Token sent as `Authorization: Bearer <token>` on each HTTP request.
api_token: String,
/// Namespace identifier included in the body of execute requests.
namespace_id: String,
}
/// JSON body posted to the worker's `/api/commands/execute` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CloudExecuteRequest {
/// The command to run.
pub request: CommandRequest,
/// Namespace identifier copied from the executor that sends the request.
pub namespace_id: String,
}
/// JSON body decoded from a successful `/api/commands/execute` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CloudExecuteResponse {
/// Identifier returned by the worker; not read anywhere in this file.
pub durable_object_id: String,
/// URL the executor connects to in order to receive the command's events.
pub websocket_url: String,
}
/// Messages exchanged over the command WebSocket.
///
/// Serialized as JSON with the variant name stored in a `"type"` field
/// (serde's internally tagged representation).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
enum CloudWebSocketMessage {
/// Sent by the executor right after connecting to subscribe to a command.
Subscribe {
command_id: Uuid,
},
/// Carries one command event; the executor forwards it to the caller's stream.
Event {
event: CommandEvent,
},
/// Cancel message for a command. Not constructed by the code in this file.
Cancel {
command_id: Uuid,
},
/// Status query for a command. Not constructed by the code in this file.
Status {
command_id: Uuid,
},
/// Carries an optional command result. Ignored by the executor's reader.
StatusResponse {
result: Option<CommandResult>,
},
/// Error reported by the remote side; the executor turns it into a
/// `CommandEvent::Failed` and stops reading.
Error {
message: String,
},
}
impl CloudCommandExecutor {
/// Creates an executor that talks to the worker at `worker_url`.
///
/// `api_token` is sent as a bearer token on every HTTP request and
/// `namespace_id` is included in execute request bodies. The HTTP client is
/// built with a 30-second request timeout, and each command's event channel
/// holds up to 1024 events.
///
/// # Panics
///
/// Panics if the HTTP client cannot be built (per the `reqwest` docs this
/// happens when the TLS backend cannot be initialised or the resolver
/// cannot load the system configuration).
pub fn new(worker_url: String, api_token: String, namespace_id: String) -> Self {
    let context = Arc::new(CommandContext::new());
    // Building the client only fails on environment-level problems that the
    // caller cannot recover from through this infallible constructor, so the
    // invariant is stated explicitly instead of using a bare `unwrap()`.
    let client = Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .expect("failed to build the reqwest HTTP client for CloudCommandExecutor");
    Self {
        context,
        event_buffer_size: 1024,
        client,
        worker_url,
        api_token,
        namespace_id,
    }
}
/// Asks the worker to start `request` and returns the connection details
/// from its reply.
///
/// Posts a JSON body (the request plus this executor's namespace id) to
/// `{worker_url}/api/commands/execute` with a bearer `Authorization` header.
///
/// # Errors
///
/// * `UbiquityError::Network` if the request cannot be sent.
/// * `UbiquityError::CloudExecution` if the worker replies with a
///   non-success status; the status and response body are included.
/// * `UbiquityError::Serialization` if the reply body cannot be decoded.
async fn create_durable_object(
    &self,
    request: CommandRequest,
) -> Result<CloudExecuteResponse, UbiquityError> {
    let endpoint = format!("{}/api/commands/execute", self.worker_url);
    let payload = CloudExecuteRequest {
        request,
        namespace_id: self.namespace_id.clone(),
    };
    let bearer = format!("Bearer {}", self.api_token);

    let sent = self
        .client
        .post(&endpoint)
        .header("Authorization", bearer)
        .json(&payload)
        .send()
        .await;
    let reply = match sent {
        Ok(reply) => reply,
        Err(e) => {
            return Err(UbiquityError::Network(format!(
                "Failed to create durable object: {}",
                e
            )));
        }
    };

    let status = reply.status();
    if status.is_success() {
        return reply
            .json::<CloudExecuteResponse>()
            .await
            .map_err(|e| UbiquityError::Serialization(format!("Failed to parse response: {}", e)));
    }

    // Failure path: include whatever body the worker sent (empty if unreadable).
    let body = reply.text().await.unwrap_or_default();
    Err(UbiquityError::CloudExecution(format!(
        "Failed to create durable object: {} - {}",
        status, body
    )))
}
async fn connect_websocket(
&self,
websocket_url: &str,
command_id: Uuid,
event_tx: mpsc::Sender<CommandEvent>,
) -> Result<(), UbiquityError> {
use tokio_tungstenite::{connect_async, tungstenite::Message};
let (ws_stream, _) = connect_async(websocket_url)
.await
.map_err(|e| UbiquityError::Network(format!("Failed to connect WebSocket: {}", e)))?;
let (write, read) = ws_stream.split();
let (internal_tx, mut internal_rx) = mpsc::channel::<Message>(100);
let subscribe_msg = CloudWebSocketMessage::Subscribe { command_id };
let msg_text = serde_json::to_string(&subscribe_msg)
.map_err(|e| UbiquityError::Serialization(e.to_string()))?;
internal_tx
.send(Message::Text(msg_text))
.await
.map_err(|_| UbiquityError::Internal("Failed to send subscribe message".to_string()))?;
let write_task = tokio::spawn(async move {
use futures::SinkExt;
let mut write = write;
while let Some(msg) = internal_rx.recv().await {
if let Err(e) = write.send(msg).await {
error!("WebSocket write error: {}", e);
break;
}
}
});
let read_task = tokio::spawn(async move {
use futures::StreamExt;
let mut read = read;
while let Some(result) = read.next().await {
match result {
Ok(Message::Text(text)) => {
match serde_json::from_str::<CloudWebSocketMessage>(&text) {
Ok(CloudWebSocketMessage::Event { event }) => {
if event_tx.send(event).await.is_err() {
break;
}
}
Ok(CloudWebSocketMessage::Error { message }) => {
error!("Cloud execution error: {}", message);
let _ = event_tx
.send(CommandEvent::Failed {
command_id,
error: message,
duration_ms: 0,
timestamp: chrono::Utc::now(),
})
.await;
break;
}
_ => {}
}
}
Ok(Message::Close(_)) => break,
Err(e) => {
error!("WebSocket read error: {}", e);
break;
}
_ => {}
}
}
});
tokio::select! {
_ = write_task => {}
_ = read_task => {}
}
Ok(())
}
async fn execute_cloud(
request: CommandRequest,
event_tx: mpsc::Sender<CommandEvent>,
executor: CloudCommandExecutor,
) -> Result<(), UbiquityError> {
let command_id = request.id;
let response = executor.create_durable_object(request).await?;
executor
.connect_websocket(&response.websocket_url, command_id, event_tx)
.await
}
}
#[async_trait]
impl CommandExecutor for CloudCommandExecutor {
async fn execute(
&self,
request: CommandRequest,
) -> Result<Pin<Box<dyn Stream<Item = CommandEvent> + Send>>, UbiquityError> {
let (event_tx, event_rx) = mpsc::channel(self.event_buffer_size);
let (cancel_tx, _cancel_rx) = mpsc::channel(1);
let (status_tx, _status_rx) = mpsc::channel(1);
let command_id = request.id;
let handle = CommandHandle::new(command_id, cancel_tx, status_tx);
self.context.register(command_id, handle).await;
let executor = self.clone();
let context = self.context.clone();
tokio::spawn(async move {
let result = Self::execute_cloud(request, event_tx, executor).await;
context.unregister(&command_id).await;
if let Err(e) = result {
error!("Cloud command execution error: {}", e);
}
});
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(event_rx)))
}
/// Asks the worker to cancel `command_id`.
///
/// Posts to `{worker_url}/api/commands/{command_id}/cancel` with a bearer
/// `Authorization` header.
///
/// # Errors
///
/// * `UbiquityError::Network` if the request cannot be sent.
/// * `UbiquityError::CloudExecution` if the worker replies with a
///   non-success status; the status and response body are included.
async fn cancel(&self, command_id: Uuid) -> Result<(), UbiquityError> {
    let endpoint = format!("{}/api/commands/{}/cancel", self.worker_url, command_id);
    let bearer = format!("Bearer {}", self.api_token);
    let reply = self
        .client
        .post(&endpoint)
        .header("Authorization", bearer)
        .send()
        .await
        .map_err(|e| UbiquityError::Network(format!("Failed to cancel command: {}", e)))?;

    let status = reply.status();
    if status.is_success() {
        return Ok(());
    }

    // Failure path: include whatever body the worker sent (empty if unreadable).
    let body = reply.text().await.unwrap_or_default();
    Err(UbiquityError::CloudExecution(format!(
        "Failed to cancel command: {} - {}",
        status, body
    )))
}
async fn status(&self, command_id: Uuid) -> Result<Option<CommandResult>, UbiquityError> {
let url = format!("{}/api/commands/{}/status", self.worker_url, command_id);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_token))
.send()
.await
.map_err(|e| UbiquityError::Network(format!("Failed to get command status: {}", e)))?;
if response.status() == 404 {
return Ok(None);
}
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(UbiquityError::CloudExecution(format!(
"Failed to get command status: {} - {}",
status, body
)));
}
let result = response
.json::<CommandResult>()
.await
.map_err(|e| UbiquityError::Serialization(format!("Failed to parse status: {}", e)))?;
Ok(Some(result))
}
}
impl Clone for CloudCommandExecutor {
fn clone(&self) -> Self {
Self {
context: self.context.clone(),
event_buffer_size: self.event_buffer_size,
client: self.client.clone(),
worker_url: self.worker_url.clone(),
api_token: self.api_token.clone(),
namespace_id: self.namespace_id.clone(),
}
}
}
#[cfg(feature = "cloudflare-worker")]
pub mod worker {
use super::*;
// State kept for one command: the connected sockets, the final result once
// there is one, and the events recorded by `broadcast_event`.
// NOTE(review): `durable_object`, `State`, `Env` and `WebSocket` are not
// imported anywhere in the visible source - confirm where they come from.
#[durable_object]
pub struct CommandDurableObject {
// Stored by `new`; not read by any method in this module.
state: State,
// Stored by `new`; not read by any method in this module.
env: Env,
// Server-side sockets accepted by `handle_websocket`.
websockets: Vec<WebSocket>,
// Set at the end of `simulate_command_execution`; `None` until then.
command_result: Option<CommandResult>,
// Events appended by `broadcast_event`.
event_history: Vec<CommandEvent>,
}
#[durable_object]
impl DurableObject for CommandDurableObject {
fn new(state: State, env: Env) -> Self {
Self {
state,
env,
websockets: Vec::new(),
command_result: None,
event_history: Vec::new(),
}
}
async fn fetch(&mut self, req: Request) -> Result<Response> {
let path = req.path();
match path.as_str() {
"/execute" => self.handle_execute(req).await,
"/websocket" => self.handle_websocket(req).await,
"/cancel" => self.handle_cancel(req).await,
"/status" => self.handle_status(req).await,
_ => Response::error("Not Found", 404),
}
}
}
impl CommandDurableObject {
/// Handles `/execute`: decodes a `CommandRequest` from the JSON body, runs
/// the simulated execution to completion and then replies with a fixed
/// confirmation text.
async fn handle_execute(&mut self, mut req: Request) -> Result<Response> {
    let command: CommandRequest = req.json().await?;
    let broadcaster = self.create_event_broadcaster();
    self.simulate_command_execution(command, broadcaster).await;
    Response::ok("Command execution started")
}
async fn handle_websocket(&mut self, req: Request) -> Result<Response> {
let pair = WebSocketPair::new()?;
let server = pair.server;
server.accept()?;
self.websockets.push(server);
Response::from_websocket(pair.client)
}
/// Handles `/cancel`: broadcasts a `Cancelled` event (with `duration_ms` 0)
/// for the stored command and replies with a fixed confirmation text.
///
/// Fails with the error from `get_command_id` when no command result has
/// been stored yet.
async fn handle_cancel(&mut self, _req: Request) -> Result<Response> {
    let command_id = self.get_command_id()?;
    let cancelled = CommandEvent::Cancelled {
        command_id,
        duration_ms: 0,
        timestamp: chrono::Utc::now(),
    };
    self.broadcast_event(cancelled).await;
    Response::ok("Command cancelled")
}
async fn handle_status(&mut self, _req: Request) -> Result<Response> {
match &self.command_result {
Some(result) => Response::ok(serde_json::to_string(result)?),
None => Response::error("Command not found", 404),
}
}
/// Simulates running `request` without executing anything.
///
/// Sends `Started`, `Stdout` and `Completed` (exit code 0) events through
/// `event_tx`, ignoring send failures, then stores a successful
/// `CommandResult` for the command.
async fn simulate_command_execution(
    &mut self,
    request: CommandRequest,
    event_tx: mpsc::Sender<CommandEvent>,
) {
    let timer = std::time::Instant::now();
    let id = request.id;

    let started = CommandEvent::Started {
        command_id: id,
        command: request.command.clone(),
        args: request.args.clone(),
        timestamp: chrono::Utc::now(),
    };
    let _ = event_tx.send(started).await;

    let output = CommandEvent::Stdout {
        command_id: id,
        data: format!("Executing: {} {}", request.command, request.args.join(" ")),
        timestamp: chrono::Utc::now(),
    };
    let _ = event_tx.send(output).await;

    // The duration is measured once and reused for the event and the result.
    let elapsed_ms = timer.elapsed().as_millis() as u64;
    let completed = CommandEvent::Completed {
        command_id: id,
        exit_code: 0,
        duration_ms: elapsed_ms,
        timestamp: chrono::Utc::now(),
    };
    let _ = event_tx.send(completed).await;

    let result = CommandResult {
        id,
        exit_code: Some(0),
        stdout: format!("Executed: {} {}", request.command, request.args.join(" ")),
        stderr: String::new(),
        duration_ms: elapsed_ms,
        cancelled: false,
    };
    self.command_result = Some(result);
}
fn create_event_broadcaster(&self) -> mpsc::Sender<CommandEvent> {
let (tx, mut rx) = mpsc::channel(100);
let websockets = self.websockets.clone();
wasm_bindgen_futures::spawn_local(async move {
while let Some(event) = rx.recv().await {
let msg = CloudWebSocketMessage::Event { event };
let text = serde_json::to_string(&msg).unwrap();
for ws in &websockets {
let _ = ws.send_with_str(&text);
}
}
});
tx
}
async fn broadcast_event(&mut self, event: CommandEvent) {
self.event_history.push(event.clone());
let msg = CloudWebSocketMessage::Event { event };
let text = serde_json::to_string(&msg).unwrap();
self.websockets.retain(|ws| {
ws.send_with_str(&text).is_ok()
});
}
fn get_command_id(&self) -> Result<Uuid> {
self.command_result
.as_ref()
.map(|r| r.id)
.ok_or_else(|| Error::RustError("No command ID found".to_string()))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores its three string arguments unchanged.
    #[tokio::test]
    async fn test_cloud_executor_creation() {
        let url = "https://example.workers.dev";
        let token = "test-token";
        let namespace = "test-namespace";

        let executor = CloudCommandExecutor::new(
            url.to_string(),
            token.to_string(),
            namespace.to_string(),
        );

        assert_eq!(executor.worker_url, url);
        assert_eq!(executor.api_token, token);
        assert_eq!(executor.namespace_id, namespace);
    }
}