//! zera-sdk 0.1.0
//!
//! Rust SDK for ZERA transactions, validator APIs, and bridge workflows.
//!
//! Documentation: gRPC-Web unary transport used to reach ZERA RPC endpoints.
use std::collections::BTreeMap;
use std::time::Duration;

use async_trait::async_trait;
use prost::Message;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
use url::Url;

use crate::error::{Result, ZeraError};
use crate::types::{ResolvedRpcEndpoint, RpcConfig};

const GRPC_WEB_CONTENT_TYPE: &str = "application/grpc-web+proto";
const GRPC_WEB_HEADER: &str = "x-grpc-web";

/// Raw outcome of one unary HTTP exchange: status code, headers, and undecoded body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransportResponse {
    /// HTTP status code returned by the server.
    pub status: u16,
    /// Response headers; being a map, at most one value is kept per name.
    pub headers: BTreeMap<String, String>,
    /// Response body bytes exactly as received.
    pub body: Vec<u8>,
}

impl TransportResponse {
    /// Wraps `body` in a `200 OK` response that carries no headers.
    pub fn ok(body: Vec<u8>) -> Self {
        let status = 200;
        let headers = BTreeMap::new();
        Self {
            body,
            status,
            headers,
        }
    }
}

/// Something that can carry out a single unary gRPC-Web exchange.
///
/// `unary` hands implementations a request already framed by `frame_grpc_web_message`
/// and decodes whatever comes back with `decode_grpc_web_response`.
#[async_trait]
pub trait UnaryTransport: Clone + Send + Sync {
    /// Sends `framed_request` to the RPC method at `path` and returns the raw response.
    async fn unary_bytes(&self, path: &str, framed_request: Vec<u8>) -> Result<TransportResponse>;
}

/// [`UnaryTransport`] that POSTs gRPC-Web requests with `reqwest`, optionally retrying
/// over plain HTTP when `RpcConfig::fallback_to_http` is enabled.
#[derive(Clone)]
pub struct GrpcWebTransport {
    // HTTP client built in `new` with a request timeout of `RpcConfig::timeout_ms`.
    client: reqwest::Client,
    // Configuration the transport was built from; consulted for fallback settings.
    config: RpcConfig,
    // Endpoint resolved once from `config` in `new`.
    endpoint: ResolvedRpcEndpoint,
}

impl GrpcWebTransport {
    /// Creates a transport for the endpoint described by `config`.
    ///
    /// # Errors
    ///
    /// Fails if the endpoint cannot be resolved from `config` or the HTTP client cannot be
    /// built.
    pub fn new(config: RpcConfig) -> Result<Self> {
        let endpoint = config.resolve_endpoint()?;
        let client = reqwest::Client::builder()
            .timeout(Duration::from_millis(config.timeout_ms))
            .build()?;

        Ok(Self {
            client,
            config,
            endpoint,
        })
    }

    /// Returns the endpoint this transport sends requests to.
    pub fn endpoint(&self) -> &ResolvedRpcEndpoint {
        &self.endpoint
    }

    /// Builds the absolute URL for the RPC method `path`.
    ///
    /// Exactly one `/` separates the base URL from the method path, whether or not the base
    /// URL ends with a slash or `path` starts with one.
    fn url_for(&self, path: &str) -> String {
        // Trimming both sides prevents `https://host//pkg.Service/Method` when the resolved
        // base URL carries a trailing slash.
        let base = self.endpoint.base_url.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        format!("{base}/{path}")
    }

    /// Computes the plain-HTTP retry URL for `original_url`.
    ///
    /// Returns `Ok(None)` when HTTP fallback is disabled, when `original_url` is not HTTPS,
    /// or when it does not target the configured endpoint host. Otherwise the scheme is
    /// rewritten to `http` and the port to `RpcConfig::fallback_port`.
    ///
    /// # Errors
    ///
    /// Returns [`ZeraError::Transport`] if `original_url` cannot be parsed or rewritten.
    fn fallback_url(&self, original_url: &str) -> Result<Option<String>> {
        if !self.config.fallback_to_http {
            return Ok(None);
        }

        let mut url = Url::parse(original_url).map_err(|error| {
            ZeraError::Transport(format!(
                "Unable to parse transport URL \"{original_url}\": {error}"
            ))
        })?;

        // `Url` normalizes hosts to lowercase, so an exact comparison would silently disable
        // the fallback for a configured hostname containing uppercase letters.
        let targets_endpoint = matches!(
            url.host_str(),
            Some(host) if host.eq_ignore_ascii_case(&self.endpoint.hostname)
        );
        if url.scheme() != "https" || !targets_endpoint {
            return Ok(None);
        }

        url.set_scheme("http")
            .map_err(|_| ZeraError::Transport("Unable to rewrite HTTPS URL to HTTP".to_string()))?;
        url.set_port(Some(self.config.fallback_port)).map_err(|_| {
            ZeraError::Transport(format!(
                "Unable to apply fallback port {} to transport URL",
                self.config.fallback_port
            ))
        })?;

        Ok(Some(url.to_string()))
    }

    /// POSTs `framed_request` to `url` with gRPC-Web content headers and collects the whole
    /// response.
    ///
    /// HTTP error statuses are not treated as failures here; they are reported through
    /// `TransportResponse::status`. Only failures from `reqwest` (sending or reading the
    /// body) are returned as errors.
    async fn send_to(&self, url: String, framed_request: &[u8]) -> Result<TransportResponse> {
        let mut headers = HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            HeaderValue::from_static(GRPC_WEB_CONTENT_TYPE),
        );
        headers.insert(ACCEPT, HeaderValue::from_static(GRPC_WEB_CONTENT_TYPE));
        headers.insert(GRPC_WEB_HEADER, HeaderValue::from_static("1"));

        let response = self
            .client
            .post(url)
            .headers(headers)
            .body(framed_request.to_vec())
            .send()
            .await?;
        let status = response.status().as_u16();
        // Capture headers before `bytes()` consumes the response.
        let headers = normalize_headers(response.headers());
        let body = response.bytes().await?.to_vec();

        Ok(TransportResponse {
            status,
            headers,
            body,
        })
    }
}

#[async_trait]
impl UnaryTransport for GrpcWebTransport {
    /// Sends `framed_request` to `path`. If that attempt fails and a plain-HTTP fallback URL
    /// applies, the request is retried there once.
    ///
    /// When the retry also fails, the error from the first attempt is returned. An error
    /// while computing the fallback URL is propagated as-is.
    async fn unary_bytes(&self, path: &str, framed_request: Vec<u8>) -> Result<TransportResponse> {
        let primary_url = self.url_for(path);
        let first_failure = match self.send_to(primary_url.clone(), &framed_request).await {
            Ok(response) => return Ok(response),
            Err(error) => error,
        };

        match self.fallback_url(&primary_url)? {
            Some(retry_url) => self
                .send_to(retry_url, &framed_request)
                .await
                .map_err(|_| first_failure),
            None => Err(first_failure),
        }
    }
}

/// Performs one unary gRPC-Web call.
///
/// Frames `request`, sends it to `path` through `transport`, and decodes the reply as `Res`.
///
/// # Errors
///
/// Propagates framing, transport, and decoding / gRPC-status errors.
pub async fn unary<Req, Res, T>(transport: &T, path: &str, request: &Req) -> Result<Res>
where
    Req: Message,
    Res: Message + Default,
    T: UnaryTransport,
{
    let request_bytes = frame_grpc_web_message(request)?;
    transport
        .unary_bytes(path, request_bytes)
        .await
        .and_then(|response| decode_grpc_web_response(&response))
}

pub fn frame_grpc_web_message<MessageType: Message>(message: &MessageType) -> Result<Vec<u8>> {
    let payload = message.encode_to_vec();
    let mut framed = Vec::with_capacity(payload.len() + 5);
    framed.push(0);
    framed.extend_from_slice(&(payload.len() as u32).to_be_bytes());
    framed.extend_from_slice(&payload);
    Ok(framed)
}

pub fn decode_grpc_web_response<MessageType: Message + Default>(
    response: &TransportResponse,
) -> Result<MessageType> {
    let (message_frame, trailers) = split_frames(&response.body)?;
    let grpc_status = trailers
        .get("grpc-status")
        .or_else(|| response.headers.get("grpc-status"))
        .cloned()
        .unwrap_or_else(|| "0".to_string());
    let grpc_message = trailers
        .get("grpc-message")
        .or_else(|| response.headers.get("grpc-message"))
        .cloned();

    if response.status >= 400 && grpc_status == "0" {
        return Err(ZeraError::Rpc(format!(
            "HTTP {} with no grpc-status trailer",
            response.status
        )));
    }

    if grpc_status != "0" {
        let detail = grpc_message.unwrap_or_else(|| "unknown gRPC error".to_string());
        return Err(ZeraError::Rpc(format!("[{}] {}", grpc_status, detail)));
    }

    if let Some(frame) = message_frame {
        return Ok(MessageType::decode(frame.as_slice())?);
    }

    Ok(MessageType::default())
}

/// Message payload (if any) and merged trailer map extracted from a gRPC-Web response body.
type GrpcWebFrameSplit = (Option<Vec<u8>>, BTreeMap<String, String>);

/// Splits a gRPC-Web response body into its message payload and trailers.
///
/// The body is a sequence of frames. Each frame is a 1-byte flag, a 4-byte big-endian
/// length, and `length` payload bytes. Flag `0x00` marks a message frame. Flag `0x80` marks
/// a trailer frame, whose entries are merged into the returned map (later values win).
///
/// # Errors
///
/// Returns [`ZeraError::Serialization`] in any of these cases:
/// - a frame header or payload is truncated;
/// - a flag other than `0x00` / `0x80` appears;
/// - more than one message frame is present, since a unary response carries at most one.
fn split_frames(body: &[u8]) -> Result<GrpcWebFrameSplit> {
    let mut rest = body;
    let mut message_frame = None;
    let mut trailers = BTreeMap::new();

    while !rest.is_empty() {
        if rest.len() < 5 {
            return Err(ZeraError::Serialization(
                "Malformed gRPC-Web frame header".to_string(),
            ));
        }

        let (header, after_header) = rest.split_at(5);
        let flag = header[0];
        let length = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;

        // Compare against the bytes that remain instead of computing `offset + length`.
        // That sum could overflow `usize` on 32-bit targets for a hostile length field.
        if length > after_header.len() {
            return Err(ZeraError::Serialization(
                "Malformed gRPC-Web frame length".to_string(),
            ));
        }
        let (frame, tail) = after_header.split_at(length);
        rest = tail;

        match flag {
            0 => {
                // Silently keeping only the last message would hide a protocol violation.
                if message_frame.is_some() {
                    return Err(ZeraError::Serialization(
                        "Unexpected second message frame in unary gRPC-Web response"
                            .to_string(),
                    ));
                }
                message_frame = Some(frame.to_vec());
            }
            0x80 => trailers.extend(parse_trailer_frame(frame)),
            _ => {
                return Err(ZeraError::Serialization(format!(
                    "Unsupported gRPC-Web frame flag: {flag}"
                )))
            }
        }
    }

    Ok((message_frame, trailers))
}

/// Parses a gRPC-Web trailer frame payload into lowercased names mapped to trimmed values.
///
/// The payload holds HTTP/1-style `name: value` lines. Both `\r\n` (what the gRPC-Web spec
/// prescribes) and bare `\n` separators are accepted. Splitting on `\r\n` alone would fold
/// `\n`-separated trailers into one garbled value.
///
/// Lines without a `:` or with an empty name are ignored, invalid UTF-8 is replaced lossily,
/// and a repeated name keeps its last value.
fn parse_trailer_frame(frame: &[u8]) -> BTreeMap<String, String> {
    let text = String::from_utf8_lossy(frame);
    let mut trailers = BTreeMap::new();
    // `extend` inserts in order, so a later duplicate overwrites an earlier one.
    trailers.extend(
        text.lines()
            .filter_map(|line| line.split_once(':'))
            .map(|(name, value)| (name.trim(), value.trim()))
            .filter(|(name, _)| !name.is_empty())
            .map(|(name, value)| (name.to_ascii_lowercase(), value.to_string())),
    );
    trailers
}

/// Converts `reqwest` headers into a map keyed by lowercase header name.
///
/// Headers whose value is not representable as `&str` are skipped. When a name has several
/// values, the last one iterated wins.
fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
    let mut normalized = BTreeMap::new();
    for (name, value) in headers {
        if let Ok(text) = value.to_str() {
            normalized.insert(name.as_str().to_ascii_lowercase(), text.to_string());
        }
    }
    normalized
}