use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use std::sync::Arc;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
use crate::config::Config;
use crate::output;
/// Certificate verifier that accepts every server certificate and every
/// handshake signature without inspecting them.
///
/// SECURITY: using this verifier disables TLS server authentication entirely.
/// In this file it is only installed by `build_connector` when the caller
/// explicitly asks for insecure TLS.
#[derive(Debug)]
struct AcceptAnyServerCert;
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
use rustls::SignatureScheme::*;
vec![
RSA_PKCS1_SHA256,
RSA_PKCS1_SHA384,
RSA_PKCS1_SHA512,
ECDSA_NISTP256_SHA256,
ECDSA_NISTP384_SHA384,
ECDSA_NISTP521_SHA512,
RSA_PSS_SHA256,
RSA_PSS_SHA384,
RSA_PSS_SHA512,
ED25519,
ED448,
]
}
}
/// Builds the TLS connector used for the WebSocket connection.
///
/// When `insecure_tls` is `true` the resulting client configuration uses
/// [`AcceptAnyServerCert`], i.e. the server certificate is NOT verified.
/// Otherwise the bundled `webpki_roots` trust anchors are used. No client
/// certificate is presented in either mode.
fn build_connector(insecure_tls: bool) -> Result<Connector> {
    // The outcome of installing the process-wide crypto provider is
    // deliberately ignored (presumably because a provider may already have
    // been installed by an earlier call).
    let _ = rustls::crypto::ring::default_provider().install_default();

    let builder = rustls::ClientConfig::builder();
    let tls_config = if insecure_tls {
        let verifier = Arc::new(AcceptAnyServerCert);
        builder
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth()
    } else {
        let mut trust_anchors = rustls::RootCertStore::empty();
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        builder
            .with_root_certificates(trust_anchors)
            .with_no_client_auth()
    };

    let connector = Connector::Rustls(Arc::new(tls_config));
    Ok(connector)
}
/// GraphQL subscription document for the `embeddingChanges` root field.
///
/// Takes two nullable `String` variables, `$metaType` and `$metaId`, and
/// selects the fields printed for each event. Sent by `embeddings`.
const EMBEDDINGS_QUERY: &str = r#"
subscription Watch($metaType: String, $metaId: String) {
embeddingChanges(metaType: $metaType, metaId: $metaId) {
metaType
metaId
key
state
model
embeddedAt
molecularHash
}
}
"#;
/// GraphQL subscription document for the `dagChanges` root field.
///
/// Takes one nullable `String` variable, `$cellSlug`, and selects the fields
/// printed for each event. Sent by `dag`.
const DAG_QUERY: &str = r#"
subscription Watch($cellSlug: String) {
dagChanges(cellSlug: $cellSlug) {
eventType
molecularHash
status
height
cellSlug
bondHash
createdAt
bundle
bondType
}
}
"#;
pub async fn embeddings(
cfg: &Config,
meta_type: Option<String>,
meta_id: Option<String>,
) -> Result<()> {
let variables = json!({
"metaType": meta_type,
"metaId": meta_id,
});
run_subscription(cfg, EMBEDDINGS_QUERY, variables, "embeddingChanges").await
}
pub async fn dag(cfg: &Config, cell_slug: Option<String>) -> Result<()> {
let variables = json!({
"cellSlug": cell_slug,
});
run_subscription(cfg, DAG_QUERY, variables, "dagChanges").await
}
async fn run_subscription(
cfg: &Config,
query: &str,
variables: Value,
root_field: &str,
) -> Result<()> {
let ws_url = to_ws_url(&cfg.validator.url);
output::info(&format!("Connecting to {} …", ws_url));
let mut request = ws_url
.as_str()
.into_client_request()
.context("failed to build ws request")?;
request.headers_mut().insert(
"Sec-WebSocket-Protocol",
"graphql-transport-ws"
.parse()
.expect("static subprotocol string"),
);
let connector = Some(build_connector(cfg.validator.insecure_tls)?);
let (mut ws, _resp) =
connect_async_tls_with_config(request, None, false, connector)
.await
.context("WS connect failed — is the validator running and serving TLS correctly?")?;
ws.send(Message::Text(
serde_json::to_string(&json!({"type": "connection_init", "payload": {}}))
.expect("static json"),
))
.await
.context("failed to send connection_init")?;
let ack = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next())
.await
.context("timed out waiting for connection_ack (5s)")?
.ok_or_else(|| anyhow::anyhow!("WS closed before sending connection_ack"))?
.context("error reading connection_ack")?;
match parse_text(&ack)? {
v if v.get("type").and_then(|t| t.as_str()) == Some("connection_ack") => {}
other => {
anyhow::bail!(
"expected connection_ack, got: {}",
serde_json::to_string(&other).unwrap_or_default()
);
}
}
let sub_id = "sub-1".to_string();
let sub_msg = json!({
"id": sub_id,
"type": "subscribe",
"payload": {
"query": query,
"variables": variables,
},
});
ws.send(Message::Text(
serde_json::to_string(&sub_msg).expect("static json"),
))
.await
.context("failed to send subscribe")?;
output::info(&format!(
"Subscribed to {}; streaming events (Ctrl-C to stop)…",
root_field
));
let stop = tokio::signal::ctrl_c();
tokio::pin!(stop);
loop {
tokio::select! {
biased;
_ = &mut stop => {
let _ = ws
.send(Message::Text(
serde_json::to_string(&json!({"id": sub_id, "type": "complete"}))
.expect("static json"),
))
.await;
let _ = ws.close(None).await;
output::info("\nSubscription closed.");
return Ok(());
}
msg = ws.next() => {
match msg {
Some(Ok(m)) => {
if let Err(e) = handle_message(m, root_field) {
output::warn(&format!("skipping malformed message: {e}"));
}
}
Some(Err(e)) => {
output::error(&format!("WS error: {e}"));
return Err(anyhow::anyhow!(e));
}
None => {
output::warn("Server closed the connection.");
return Ok(());
}
}
}
}
}
}
/// Processes one incoming WebSocket message of a `graphql-transport-ws`
/// session.
///
/// * Non-text frames are ignored, except that a close frame produces a
///   warning.
/// * `next`: the value at `payload.data.<root_field>` (if present) is printed
///   to stdout as one line of compact JSON. Independently of that, a
///   non-empty `payload.errors` is reported as a warning, so the errors of a
///   partial result (both `data` and `errors` present) are no longer dropped.
/// * `error`: the payload is reported as an error.
/// * `complete`: an informational notice is printed.
/// * Any other message type (including `ping`) is ignored here.
///
/// # Errors
/// Returns an error only when a text frame is not valid JSON.
fn handle_message(msg: Message, root_field: &str) -> Result<()> {
    let v = match msg {
        Message::Text(t) => serde_json::from_str::<Value>(&t).context("json parse")?,
        Message::Close(_) => {
            output::warn("Server sent close frame.");
            return Ok(());
        }
        Message::Binary(_) | Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
            return Ok(());
        }
    };

    match v.get("type").and_then(|t| t.as_str()) {
        Some("next") => {
            if let Some(event) = v
                .pointer("/payload/data")
                .and_then(|d| d.get(root_field))
            {
                println!("{}", serde_json::to_string(event).unwrap_or_default());
            }
            // A GraphQL result may carry `errors` alongside `data`, so this is
            // checked separately. `null` and `[]` carry no information and are
            // skipped to avoid a warning per event.
            let reportable_errors = v
                .pointer("/payload/errors")
                .filter(|errs| !errs.is_null() && errs.as_array().map_or(true, |a| !a.is_empty()));
            if let Some(errors) = reportable_errors {
                output::warn(&format!("server error: {}", errors));
            }
        }
        Some("error") => {
            output::error(&format!(
                "subscription error: {}",
                v.get("payload").map(|p| p.to_string()).unwrap_or_default()
            ));
        }
        Some("complete") => {
            output::info("Server signalled subscription complete.");
        }
        // `ping` and every unrecognized or missing type need no output.
        _ => {}
    }
    Ok(())
}
/// Parses the JSON body of a text frame.
///
/// # Errors
/// Fails when `msg` is not a text frame, or when its content is not valid
/// JSON.
fn parse_text(msg: &Message) -> Result<Value> {
    if let Message::Text(body) = msg {
        return serde_json::from_str::<Value>(body).context("json parse");
    }
    anyhow::bail!("expected text frame, got {:?}", msg)
}
/// Derives the GraphQL WebSocket endpoint from a base URL.
///
/// Trailing slashes are removed, the scheme is mapped and `/graphql/ws` is
/// appended:
///
/// * `https://…` becomes `wss://…`
/// * `http://…` becomes `ws://…`
/// * `wss://…` / `ws://…` are kept as they are (previously these were
///   mangled into `wss://wss://…`)
/// * anything else is treated as a bare host and gets `wss://` prepended
fn to_ws_url(base: &str) -> String {
    // (prefix found in the input, scheme used in the output)
    const SCHEME_MAP: [(&str, &str); 4] = [
        ("https://", "wss://"),
        ("http://", "ws://"),
        ("wss://", "wss://"),
        ("ws://", "ws://"),
    ];

    let trimmed = base.trim_end_matches('/');
    let (scheme, rest) = SCHEME_MAP
        .iter()
        .find_map(|&(from, to)| trimmed.strip_prefix(from).map(|rest| (to, rest)))
        // No recognized scheme: assume a bare host and default to TLS.
        .unwrap_or(("wss://", trimmed));

    format!("{scheme}{rest}/graphql/ws")
}