use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::time::{timeout, Duration};
use tokio_tungstenite::{connect_async, tungstenite::http::Request};
use futures::{SinkExt, StreamExt};
/// Client that probes an OpenClaw WebSocket endpoint for cross-site WebSocket hijacking
/// (CSWSH) by connecting with an arbitrary `Origin` header.
#[derive(Debug, Clone)]
pub struct OpenClawClient {
    /// WebSocket URL of the endpoint, e.g. `ws://localhost:18789`.
    url: String,
    /// Value sent as the `Origin` header of the upgrade request; `None` adds no `Origin`
    /// header of our own.
    origin: Option<String>,
}
/// Request frame sent to the gateway, serialized as
/// `{"type": .., "id": .., "method": .., "params": {..}}`.
#[derive(Serialize)]
struct ConnectRequest {
    /// Frame kind; serde strips the raw-identifier prefix, so this is emitted as `"type"`.
    r#type: String,
    /// Identifier of this request.
    id: String,
    /// Name of the method being invoked.
    method: String,
    /// Arguments of the `connect` call.
    params: ConnectParams,
}
/// Arguments of the `connect` request.
///
/// A container-level camelCase rule turns `min_protocol`/`max_protocol` into
/// `minProtocol`/`maxProtocol`; the single-word fields keep their names.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ConnectParams {
    /// Lowest protocol version advertised (emitted as `minProtocol`).
    min_protocol: u8,
    /// Highest protocol version advertised (emitted as `maxProtocol`).
    max_protocol: u8,
    /// Description of the connecting client.
    client: ClientInfo,
    /// Role requested for the connection.
    role: String,
    /// Scopes requested for the connection.
    scopes: Vec<String>,
    /// Credentials presented with the request.
    auth: AuthInfo,
}
/// Self-description of the client, sent inside [`ConnectParams`].
#[derive(Serialize)]
struct ClientInfo {
    /// Client identifier.
    id: String,
    /// Client version string.
    version: String,
    /// Platform the client reports running on.
    platform: String,
    /// Client mode string.
    mode: String,
}
/// Authentication block of the `connect` request: `{"token": ..}`.
#[derive(Serialize)]
struct AuthInfo {
    /// Token presented to the gateway.
    token: String,
}
/// Reply frame to the `connect` request.
///
/// `payload` is expected when `ok` is true and `error` when it is false; both are optional
/// so a missing section does not fail deserialization.
#[derive(Deserialize)]
struct ConnectResponse {
    /// Frame kind (`"type"` on the wire). It is required for the reply to parse, but nothing
    /// reads it after deserialization, so the `dead_code` lint is silenced deliberately.
    #[allow(dead_code)]
    r#type: String,
    /// Whether the gateway accepted the request.
    ok: bool,
    /// Success payload, if present.
    payload: Option<HelloPayload>,
    /// Error details, if present.
    error: Option<ErrorPayload>,
}
/// Success payload of a [`ConnectResponse`]; only its `auth` section is inspected.
#[derive(Deserialize)]
struct HelloPayload {
    /// Authentication result; absent in the JSON means `None`.
    auth: Option<AuthResponse>,
}
/// Authentication section of the hello payload.
///
/// Field names are read in camelCase (`deviceToken`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct AuthResponse {
    /// Device token issued by the gateway, if any.
    device_token: Option<String>,
    /// Scopes granted to the connection. Defaults to empty when the field is omitted, so a
    /// reply without `scopes` is still accepted instead of failing to parse as a whole.
    #[serde(default)]
    scopes: Vec<String>,
}
/// Error section of a rejected [`ConnectResponse`].
#[derive(Deserialize)]
struct ErrorPayload {
    /// Human-readable error message from the gateway.
    message: String,
}
impl OpenClawClient {
pub fn new(url: String, origin: Option<String>) -> Self {
Self { url, origin }
}
pub async fn test_cswsh(&self, token: Option<String>) -> Result<CswshResult> {
let request = if let Some(ref origin) = self.origin {
Request::builder()
.uri(&self.url)
.header("Origin", origin)
.body(())?
} else {
Request::builder()
.uri(&self.url)
.body(())?
};
let ws_result = timeout(
Duration::from_secs(10),
connect_async(request)
).await;
let (ws_stream, _) = match ws_result {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => {
return Ok(CswshResult {
success: false,
origin_accepted: false,
device_token: None,
granted_scopes: vec![],
error: Some(format!("WebSocket connection failed: {}", e)),
});
}
Err(_) => {
return Ok(CswshResult {
success: false,
origin_accepted: false,
device_token: None,
granted_scopes: vec![],
error: Some("Connection timeout".to_string()),
});
}
};
let (mut write, mut read) = ws_stream.split();
let connect_req = ConnectRequest {
r#type: "req".to_string(),
id: uuid::Uuid::new_v4().to_string(),
method: "connect".to_string(),
params: ConnectParams {
min_protocol: 3,
max_protocol: 3,
client: ClientInfo {
id: "clawscan".to_string(),
version: "1.0.0".to_string(),
platform: "rust".to_string(),
mode: "operator".to_string(),
},
role: "operator".to_string(),
scopes: vec!["operator.read".to_string(), "operator.write".to_string()],
auth: AuthInfo {
token: token.unwrap_or_else(|| "test-token".to_string()),
},
},
};
let msg = tokio_tungstenite::tungstenite::Message::Text(
serde_json::to_string(&connect_req)?
);
if let Err(e) = write.send(msg).await {
return Ok(CswshResult {
success: false,
origin_accepted: true, device_token: None,
granted_scopes: vec![],
error: Some(format!("Failed to send request: {}", e)),
});
}
let response_result = timeout(
Duration::from_secs(5),
read.next()
).await;
let response_msg = match response_result {
Ok(Some(Ok(msg))) => msg,
Ok(Some(Err(e))) => {
return Ok(CswshResult {
success: false,
origin_accepted: true,
device_token: None,
granted_scopes: vec![],
error: Some(format!("WebSocket error: {}", e)),
});
}
Ok(None) => {
return Ok(CswshResult {
success: false,
origin_accepted: true,
device_token: None,
granted_scopes: vec![],
error: Some("Connection closed".to_string()),
});
}
Err(_) => {
return Ok(CswshResult {
success: false,
origin_accepted: true,
device_token: None,
granted_scopes: vec![],
error: Some("Response timeout".to_string()),
});
}
};
let response: ConnectResponse = match response_msg {
tokio_tungstenite::tungstenite::Message::Text(text) => {
serde_json::from_str(&text)?
}
_ => {
return Ok(CswshResult {
success: false,
origin_accepted: true,
device_token: None,
granted_scopes: vec![],
error: Some("Unexpected message type".to_string()),
});
}
};
if !response.ok {
let error_msg = response.error
.map(|e| e.message)
.unwrap_or_else(|| "Unknown error".to_string());
return Ok(CswshResult {
success: false,
origin_accepted: true,
device_token: None,
granted_scopes: vec![],
error: Some(error_msg),
});
}
let (device_token, granted_scopes) = if let Some(payload) = response.payload {
if let Some(auth) = payload.auth {
(auth.device_token, auth.scopes)
} else {
(None, vec![])
}
} else {
(None, vec![])
};
Ok(CswshResult {
success: true,
origin_accepted: true,
device_token,
granted_scopes,
error: None,
})
}
}
/// Outcome of one CSWSH probe performed by [`OpenClawClient::test_cswsh`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CswshResult {
    /// `true` when the gateway answered the `connect` request with `ok: true`.
    pub success: bool,
    /// `true` once the WebSocket upgrade handshake completed with the probe's `Origin`.
    pub origin_accepted: bool,
    /// Device token returned in the reply's `auth` section, if any.
    pub device_token: Option<String>,
    /// Scopes the gateway reported as granted; empty when none were reported.
    pub granted_scopes: Vec<String>,
    /// Reason the probe failed; `None` on success.
    pub error: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = OpenClawClient::new(
            "ws://localhost:18789".to_string(),
            Some("http://attacker.com".to_string()),
        );
        assert_eq!(client.url, "ws://localhost:18789");
        assert_eq!(client.origin, Some("http://attacker.com".to_string()));
    }

    #[test]
    fn test_client_without_origin() {
        let client = OpenClawClient::new("ws://localhost:18789".to_string(), None);
        assert_eq!(client.url, "ws://localhost:18789");
        assert_eq!(client.origin, None);
    }

    #[test]
    fn test_cswsh_result_equality() {
        let result1 = CswshResult {
            success: true,
            origin_accepted: true,
            device_token: Some("token123".to_string()),
            granted_scopes: vec!["operator.read".to_string()],
            error: None,
        };
        let result2 = CswshResult {
            success: true,
            origin_accepted: true,
            device_token: Some("token123".to_string()),
            granted_scopes: vec!["operator.read".to_string()],
            error: None,
        };
        assert_eq!(result1, result2);
    }

    /// An unresolvable host must be reported as `Ok` with `success == false`, not as `Err`.
    ///
    /// `.invalid` is a reserved TLD (RFC 2606) that never resolves. `.local` is avoided because
    /// it can trigger slow multicast-DNS lookups.
    #[tokio::test]
    async fn test_cswsh_unresolvable_host() {
        let client = OpenClawClient::new(
            "ws://invalid-host-that-does-not-exist.invalid:18789".to_string(),
            Some("http://attacker.com".to_string()),
        );
        let result = client.test_cswsh(None).await.unwrap();
        assert!(!result.success);
        assert!(!result.origin_accepted);
        assert!(result.error.is_some());
    }

    /// An Origin containing CR/LF cannot become a header value, so the probe fails with `Err`
    /// before touching the network instead of sending an injected header.
    #[tokio::test]
    async fn test_cswsh_rejects_invalid_origin_header() {
        let client = OpenClawClient::new(
            "ws://localhost:18789".to_string(),
            Some("http://attacker.com\r\nX-Injected: 1".to_string()),
        );
        assert!(client.test_cswsh(None).await.is_err());
    }

    #[test]
    fn test_cswsh_result_structure() {
        let result = CswshResult {
            success: true,
            origin_accepted: true,
            device_token: Some("test-token-123".to_string()),
            granted_scopes: vec!["operator.read".to_string(), "operator.write".to_string()],
            error: None,
        };
        assert!(result.success);
        assert!(result.origin_accepted);
        assert_eq!(result.device_token.as_deref(), Some("test-token-123"));
        assert_eq!(result.granted_scopes.len(), 2);
    }
}