use std::{collections::HashMap, sync::Arc, time::Duration};
use alloy::{
primitives::{Address, B256},
signers::local::PrivateKeySigner,
};
use async_trait::async_trait;
use eigensdk::{crypto_bls::BlsG1Point, types::operator::OperatorId};
use newton_prover_chainio::operator_rpc_auth::{sign_authenticated, DEFAULT_EXPIRY_SECS};
use newton_prover_core::{
rpc::state_commit::{
GetStateCommitProposalRequest, GetStateCommitProposalResponse, SignStateCommitRequest, SignStateCommitResponse,
StateCommitWire,
},
state_commit_registry::IStateRootCommittable::StateCommit,
};
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, instrument, warn};
use super::operator_client::{OperatorClientError, OperatorProposal, StateCommitOperatorClient};
/// JSON-RPC error code that `is_digest_disagreement_error` looks for (rendered as
/// `code=-32616`) when deciding whether an operator's error is a digest disagreement.
const OPERATOR_CONTRACT_ERROR_CODE: i64 = -32616;
/// Per-call timeout applied by `HttpStateCommitOperatorClient::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
/// HTTP JSON-RPC implementation of `StateCommitOperatorClient`.
///
/// Each call looks up the operator's socket, wraps the request in an authenticated
/// envelope via `sign_authenticated`, and POSTs it as a JSON-RPC 2.0 request.
pub struct HttpStateCommitOperatorClient {
// Shared reqwest client; configured in `with_timeout`.
http: Client,
// Operator id -> socket string. Replaced wholesale by `set_sockets`, read by `url_for`.
sockets: Arc<RwLock<HashMap<OperatorId, String>>>,
// Deadline applied to sending each request (see `rpc_call`); also reported in `Timeout` errors.
timeout: Duration,
// Key handed to `sign_authenticated` for every outgoing request.
signer: PrivateKeySigner,
// Chain id passed to `sign_authenticated`.
chain_id: u64,
// Task-manager address passed to `sign_authenticated`.
task_manager: Address,
}
/// Hand-written `Debug` that prints only the timeout (in milliseconds), chain id and
/// task-manager address; every other field is omitted and the output is marked
/// non-exhaustive.
impl std::fmt::Debug for HttpStateCommitOperatorClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let timeout_ms = self.timeout.as_millis();
        let mut out = f.debug_struct("HttpStateCommitOperatorClient");
        out.field("timeout_ms", &timeout_ms);
        out.field("chain_id", &self.chain_id);
        out.field("task_manager", &self.task_manager);
        out.finish_non_exhaustive()
    }
}
impl HttpStateCommitOperatorClient {
pub fn new(
sockets: HashMap<OperatorId, String>,
signer: PrivateKeySigner,
chain_id: u64,
task_manager: Address,
) -> Result<Self, reqwest::Error> {
Self::with_timeout(sockets, DEFAULT_TIMEOUT, signer, chain_id, task_manager)
}
pub fn with_timeout(
sockets: HashMap<OperatorId, String>,
timeout: Duration,
signer: PrivateKeySigner,
chain_id: u64,
task_manager: Address,
) -> Result<Self, reqwest::Error> {
let http = ClientBuilder::new()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(90))
.timeout(timeout + Duration::from_secs(1))
.tcp_keepalive(Some(Duration::from_secs(60)))
.tcp_nodelay(true)
.build()?;
Ok(Self {
http,
sockets: Arc::new(RwLock::new(sockets)),
timeout,
signer,
chain_id,
task_manager,
})
}
pub async fn set_sockets(&self, sockets: HashMap<OperatorId, String>) {
*self.sockets.write().await = sockets;
}
async fn url_for(&self, operator_id: &OperatorId) -> Result<String, OperatorClientError> {
let sockets = self.sockets.read().await;
let raw = sockets.get(operator_id).ok_or_else(|| OperatorClientError::Malformed {
operator_id: hex::encode(operator_id),
reason: format!("no socket registered for operator {}", hex::encode(operator_id)),
})?;
Ok(format_socket_as_http_url(raw))
}
async fn rpc_call<P, R>(
&self,
operator_id: &OperatorId,
url: &str,
method: &str,
params: P,
) -> Result<R, OperatorClientError>
where
P: Serialize,
R: for<'de> Deserialize<'de>,
{
let envelope = JsonRpcRequest {
jsonrpc: "2.0",
method,
params,
id: 1,
};
let send_fut = self.http.post(url).json(&envelope).send();
let response = match tokio::time::timeout(self.timeout, send_fut).await {
Ok(Ok(r)) => r,
Ok(Err(e)) if e.is_timeout() => {
return Err(OperatorClientError::Timeout {
operator_id: hex::encode(operator_id),
timeout_ms: self.timeout.as_millis() as u64,
});
}
Ok(Err(e)) => {
return Err(OperatorClientError::Transport {
operator_id: hex::encode(operator_id),
source: Box::new(e),
});
}
Err(_elapsed) => {
return Err(OperatorClientError::Timeout {
operator_id: hex::encode(operator_id),
timeout_ms: self.timeout.as_millis() as u64,
});
}
};
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(OperatorClientError::Transport {
operator_id: hex::encode(operator_id),
source: format!("HTTP {status}: {body}").into(),
});
}
let envelope: JsonRpcResponse<R> = response.json().await.map_err(|e| {
OperatorClientError::Malformed {
operator_id: hex::encode(operator_id),
reason: format!("response body parse: {e}"),
}
})?;
if let Some(err) = envelope.error {
return Err(OperatorClientError::Malformed {
operator_id: hex::encode(operator_id),
reason: format!("rpc error code={} message={}", err.code, err.message),
});
}
envelope.result.ok_or_else(|| OperatorClientError::Malformed {
operator_id: hex::encode(operator_id),
reason: "rpc envelope missing both `result` and `error`".to_string(),
})
}
}
/// Turns an operator socket string into an HTTP URL.
///
/// Surrounding whitespace is trimmed. A value that already starts with an `http://` or
/// `https://` scheme (compared ASCII case-insensitively) is returned as-is; anything
/// else is prefixed with `http://`.
fn format_socket_as_http_url(socket: &str) -> String {
    let socket = socket.trim();
    // `get` (rather than slicing) returns `None` when the prefix length falls inside a
    // multi-byte character or past the end, so arbitrary input cannot panic here.
    let has_scheme = |scheme: &str| {
        socket
            .get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    };
    if has_scheme("http://") || has_scheme("https://") {
        socket.to_string()
    } else {
        format!("http://{socket}")
    }
}
/// Reports whether an RPC error `reason` string describes a digest disagreement.
///
/// Both conditions must hold: the text contains `code=<OPERATOR_CONTRACT_ERROR_CODE>`,
/// and it contains at least one of the digest-related message markers below.
fn is_digest_disagreement_error(reason: &str) -> bool {
    const MESSAGE_MARKERS: [&str; 3] = ["digest", "newStateRoot", "diverges from commit"];

    let code_marker = format!("code={OPERATOR_CONTRACT_ERROR_CODE}");
    if !reason.contains(code_marker.as_str()) {
        return false;
    }
    MESSAGE_MARKERS.iter().any(|marker| reason.contains(*marker))
}
#[async_trait]
impl StateCommitOperatorClient for HttpStateCommitOperatorClient {
    /// Asks `operator_id` for its state-commit proposal at `sequence_no`.
    ///
    /// The request is wrapped in an authenticated envelope and sent as the
    /// `newt_getStateCommitProposal` JSON-RPC method.
    ///
    /// # Errors
    /// `Malformed` when the operator has no registered socket or the envelope cannot be
    /// signed; otherwise whatever `rpc_call` returns.
    // `sequence_no` is given an explicit value on purpose: a bare name inside `fields(...)`
    // declares an *empty* field and suppresses automatic recording of the same-named
    // argument, which would leave the span without a sequence number.
    #[instrument(
        skip(self),
        fields(operator_id = %hex::encode(operator_id), sequence_no = sequence_no)
    )]
    async fn get_state_commit_proposal(
        &self,
        operator_id: &OperatorId,
        sequence_no: u64,
    ) -> Result<OperatorProposal, OperatorClientError> {
        let url = self.url_for(operator_id).await?;
        let req = GetStateCommitProposalRequest { sequence_no };
        let envelope = sign_authenticated(
            &self.signer,
            "newt_getStateCommitProposal",
            self.chain_id,
            self.task_manager,
            DEFAULT_EXPIRY_SECS,
            req,
        )
        .map_err(|e| OperatorClientError::Malformed {
            operator_id: hex::encode(operator_id),
            reason: format!("authenticate envelope: {e}"),
        })?;
        let resp: GetStateCommitProposalResponse = self
            .rpc_call(operator_id, &url, "newt_getStateCommitProposal", envelope)
            .await?;
        debug!(
            operator_id = %hex::encode(operator_id),
            sequence_no,
            new_state_root = %resp.new_state_root,
            "state-commit prepare: received proposal"
        );
        Ok(OperatorProposal {
            new_state_root: resp.new_state_root,
            da_cert_hash: resp.da_cert_hash,
            pcr0_commitment: resp.pcr0_commitment,
        })
    }

    /// Asks `operator_id` to sign `digest` for `commit` and decodes the returned G1 point.
    ///
    /// The request is wrapped in an authenticated envelope and sent as the
    /// `newt_signStateCommit` JSON-RPC method.
    ///
    /// # Errors
    /// * `DigestDisagreement` — the operator answered with an RPC error that
    ///   `is_digest_disagreement_error` recognises.
    /// * `Malformed` — no registered socket, envelope signing failed, the signature is
    ///   not 64 bytes, or `decode_g1_point` rejected it.
    /// * Any other error returned by `rpc_call` is passed through unchanged.
    #[instrument(
        skip(self, digest, commit),
        fields(
            operator_id = %hex::encode(operator_id),
            digest = %digest,
            sequence_no = commit.sequenceNo,
        )
    )]
    async fn sign_state_commit(
        &self,
        operator_id: &OperatorId,
        digest: B256,
        commit: &StateCommit,
        reference_timestamp: u32,
    ) -> Result<BlsG1Point, OperatorClientError> {
        let url = self.url_for(operator_id).await?;
        let sequence_no = commit.sequenceNo;
        let req = SignStateCommitRequest {
            digest,
            commit: StateCommitWire::from(commit),
            reference_timestamp,
        };
        let envelope = sign_authenticated(
            &self.signer,
            "newt_signStateCommit",
            self.chain_id,
            self.task_manager,
            DEFAULT_EXPIRY_SECS,
            req,
        )
        .map_err(|e| OperatorClientError::Malformed {
            operator_id: hex::encode(operator_id),
            reason: format!("authenticate envelope: {e}"),
        })?;
        let resp: SignStateCommitResponse = match self
            .rpc_call::<_, SignStateCommitResponse>(operator_id, &url, "newt_signStateCommit", envelope)
            .await
        {
            Ok(r) => r,
            // `rpc_call` reports JSON-RPC errors as `Malformed`; re-classify the ones that
            // describe a digest disagreement so callers can tell them apart.
            Err(OperatorClientError::Malformed {
                operator_id: id,
                reason,
            }) if is_digest_disagreement_error(&reason) => {
                warn!(
                    operator_id = %id,
                    sequence_no,
                    reason = %reason,
                    "state-commit sign: operator refused (digest disagreement)"
                );
                return Err(OperatorClientError::DigestDisagreement {
                    operator_id: id,
                    sequence_no,
                });
            }
            Err(e) => return Err(e),
        };
        if resp.signature_bytes.len() != 64 {
            return Err(OperatorClientError::Malformed {
                operator_id: hex::encode(operator_id),
                reason: format!("expected 64-byte G1 encoding, got {}", resp.signature_bytes.len()),
            });
        }
        decode_g1_point(&resp.signature_bytes).map_err(|reason| OperatorClientError::Malformed {
            operator_id: hex::encode(operator_id),
            reason,
        })
    }
}
/// Outgoing JSON-RPC request envelope serialized by `rpc_call`.
#[derive(Debug, Serialize)]
struct JsonRpcRequest<'a, P> {
// Protocol version string; `rpc_call` always sends "2.0".
jsonrpc: &'a str,
// RPC method name.
method: &'a str,
// Method parameters, serialized as-is.
params: P,
// Request id; `rpc_call` always sends 1.
id: u64,
}
/// Incoming JSON-RPC response envelope parsed by `rpc_call`.
///
/// `rpc_call` checks `error` first, then requires `result`; an envelope with neither
/// is reported as malformed.
#[derive(Debug, Deserialize)]
struct JsonRpcResponse<R> {
// Deserialized but never read by this file, hence the `dead_code` allowance.
#[allow(dead_code)]
jsonrpc: Option<String>,
// Successful payload, if any.
result: Option<R>,
// Error object, if any.
error: Option<JsonRpcError>,
// Deserialized but never read by this file, hence the `dead_code` allowance.
// NOTE(review): typed as a number — a response with a string id would fail to parse; confirm operators echo the numeric id.
#[allow(dead_code)]
id: Option<u64>,
}
/// JSON-RPC error object; `rpc_call` renders it as `rpc error code=<code> message=<message>`.
#[derive(Debug, Deserialize)]
struct JsonRpcError {
// Numeric error code.
code: i64,
// Human-readable error message.
message: String,
}
/// Decodes a 64-byte `x || y` big-endian encoding into a BN254 G1 point.
///
/// The bytes come from a remote operator, so they are validated rather than trusted:
/// each coordinate must be a canonical field element (not silently reduced modulo the
/// field prime) and the resulting point must lie on the curve. As a consequence the
/// all-zero `(0, 0)` encoding is rejected.
///
/// # Errors
/// Returns a description of the problem when the length is not 64, a coordinate is
/// non-canonical, or the point is off-curve.
fn decode_g1_point(bytes: &[u8]) -> Result<BlsG1Point, String> {
    use ark_bn254::{Fq, G1Affine};
    use ark_ff::{BigInteger, PrimeField};

    if bytes.len() != 64 {
        return Err(format!("g1 point: want 64 bytes, got {}", bytes.len()));
    }
    let (x_bytes, y_bytes) = bytes.split_at(32);

    // `from_be_bytes_mod_order` reduces out-of-range values; round-trip the result to
    // detect that and reject encodings that are not already canonical.
    let to_fq = |name: &str, raw: &[u8]| -> Result<Fq, String> {
        let fq = Fq::from_be_bytes_mod_order(raw);
        if fq.into_bigint().to_bytes_be().as_slice() != raw {
            return Err(format!("g1 point: {name} coordinate is not a canonical field element"));
        }
        Ok(fq)
    };
    let x = to_fq("x", x_bytes)?;
    let y = to_fq("y", y_bytes)?;

    // `new_unchecked` performs no validation, so check curve membership explicitly.
    let affine = G1Affine::new_unchecked(x, y);
    if !affine.is_on_curve() {
        return Err("g1 point: coordinates are not on the BN254 G1 curve".to_string());
    }
    Ok(BlsG1Point::new(affine))
}
// Unit tests for URL formatting, digest-disagreement classification, socket lookup,
// and timeout/transport error classification against local TCP listeners.
#[cfg(test)]
mod tests {
use super::*;
// Operator id made of 32 copies of `byte`.
fn op_id(byte: u8) -> OperatorId {
OperatorId::repeat_byte(byte)
}
// Client with the default timeout, a random signer, chain id 1 and a zero task manager.
fn test_client(sockets: HashMap<OperatorId, String>) -> HttpStateCommitOperatorClient {
HttpStateCommitOperatorClient::new(sockets, PrivateKeySigner::random(), 1u64, Address::ZERO)
.expect("build client")
}
// Same as `test_client`, but with an explicit per-call timeout.
fn test_client_with_timeout(
sockets: HashMap<OperatorId, String>,
timeout: Duration,
) -> HttpStateCommitOperatorClient {
HttpStateCommitOperatorClient::with_timeout(sockets, timeout, PrivateKeySigner::random(), 1u64, Address::ZERO)
.expect("build client")
}
// Bare `host:port` gets an `http://` prefix; existing http/https URLs are untouched.
#[test]
fn format_socket_adds_http_prefix() {
assert_eq!(format_socket_as_http_url("127.0.0.1:9000"), "http://127.0.0.1:9000");
assert_eq!(format_socket_as_http_url("http://op:9000"), "http://op:9000");
assert_eq!(format_socket_as_http_url("https://op:9000"), "https://op:9000");
}
// Requires BOTH the contract error code and a digest-related message marker.
#[test]
fn digest_disagreement_pattern_matches() {
assert!(is_digest_disagreement_error(
"rpc error code=-32616 message=state-commit sign: digest mismatch — expected 0xaa got 0xbb"
));
assert!(is_digest_disagreement_error(
"rpc error code=-32616 message=state-commit sign: local root 0xaa diverges from commit.newStateRoot 0xbb — stale view, refusing to sign"
));
// Wrong code.
assert!(!is_digest_disagreement_error(
"rpc error code=-32603 message=internal error"
));
// Right code, but no digest-related marker in the message.
assert!(!is_digest_disagreement_error(
"rpc error code=-32616 message=pcr0_provider: enclave unreachable"
));
}
// Looking up an operator with no registered socket yields `Malformed`.
#[tokio::test]
async fn url_for_unknown_operator_is_malformed() {
let client = test_client(HashMap::new());
let err = client.url_for(&op_id(0x01)).await.expect_err("unknown operator");
assert!(matches!(
err,
OperatorClientError::Malformed { ref operator_id, ref reason }
if operator_id == &hex::encode(op_id(0x01))
&& reason.contains("no socket registered")
));
}
// `set_sockets` swaps in the new map, which `url_for` then resolves from.
#[tokio::test]
async fn set_sockets_replaces_map() {
let client = test_client(HashMap::new());
let mut next = HashMap::new();
next.insert(op_id(0x42), "127.0.0.1:9000".to_string());
client.set_sockets(next).await;
let url = client.url_for(&op_id(0x42)).await.expect("registered");
assert_eq!(url, "http://127.0.0.1:9000");
}
// A listener that accepts but never answers must surface as `Timeout` carrying the
// configured 150 ms deadline.
#[tokio::test]
async fn timeout_classifies_as_timeout_error() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("addr");
// Accept each connection and hold it open without writing a response.
let _accept_loop = tokio::spawn(async move {
while let Ok((_socket, _)) = listener.accept().await {
tokio::time::sleep(Duration::from_secs(60)).await;
}
});
let mut sockets = HashMap::new();
sockets.insert(op_id(0xab), addr.to_string());
let client = test_client_with_timeout(sockets, Duration::from_millis(150));
let err = client
.get_state_commit_proposal(&op_id(0xab), 1)
.await
.expect_err("expect timeout");
assert!(
matches!(err, OperatorClientError::Timeout { ref operator_id, timeout_ms }
if operator_id == &hex::encode(op_id(0xab)) && timeout_ms == 150),
"got {err:?}"
);
}
// Connecting to a port with no listener must surface as `Transport`.
#[tokio::test]
async fn connection_refused_classifies_as_transport() {
// Bind to get a free port, then drop the listener so nothing is listening on it.
// NOTE(review): assumes the port is not re-bound by another process before the call.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("addr");
drop(listener);
let mut sockets = HashMap::new();
sockets.insert(op_id(0x07), addr.to_string());
let client = test_client_with_timeout(sockets, Duration::from_secs(2));
let err = client
.get_state_commit_proposal(&op_id(0x07), 1)
.await
.expect_err("expect transport error");
assert!(
matches!(err, OperatorClientError::Transport { ref operator_id, .. }
if operator_id == &hex::encode(op_id(0x07))),
"got {err:?}"
);
}
}