use std::sync::Arc;
use crate::error::{Error, Result};
use crate::network::Network;
/// Request/response and inbound-message layer over a shared [`Network`].
///
/// Cloning is cheap: the derived `Clone` only bumps the `Arc` refcount, so all
/// clones share the same underlying network.
#[derive(Clone)]
pub struct Transport {
    /// The network that request bodies are sent over and inbound payloads arrive from.
    network: Arc<dyn Network>,
}
/// Inbound message delivered to the callback registered with [`Transport::recv`].
///
/// Internally tagged by the top-level `"kind"` field: `"execute"` decodes into
/// [`Message::Execute`] and `"unblock"` into [`Message::Unblock`]; the remaining
/// fields of the object are decoded into the variant's payload struct.
///
/// Derives `Serialize` as well as `Deserialize` (like every payload type it wraps)
/// so a message can be written back out in the same `kind`-tagged wire shape.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind")]
pub enum Message {
    /// Wire form: `{"kind": "execute", "data": {"task": {...}}}`.
    #[serde(rename = "execute")]
    Execute(ExecuteMsg),
    /// Wire form: `{"kind": "unblock", "data": {"promise": ...}}`.
    #[serde(rename = "unblock")]
    Unblock(UnblockMsg),
}
/// Payload of a `"kind": "execute"` [`Message`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExecuteMsg {
    /// The message's `data` object.
    pub data: ExecuteData,
}
impl ExecuteMsg {
pub fn task_id(&self) -> &str {
&self.data.task.id
}
pub fn version(&self) -> i64 {
self.data.task.version
}
}
/// The `data` object of an execute message.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExecuteData {
    /// The task this execute message refers to.
    pub task: TaskRef,
}
/// Identifies a task by id and version.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskRef {
    /// Task identifier.
    pub id: String,
    /// Task version; deserializes to `0` (the `i64` default) when the field is absent.
    #[serde(default)]
    pub version: i64,
}
/// Payload of a `"kind": "unblock"` [`Message`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UnblockMsg {
    /// The message's `data` object.
    pub data: UnblockData,
}
impl UnblockMsg {
pub fn promise(&self) -> &serde_json::Value {
&self.data.promise
}
}
/// The `data` object of an unblock message.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UnblockData {
    /// The promise payload, kept as untyped JSON; its schema is not decoded here.
    pub promise: serde_json::Value,
}
pub fn response_data(resp: &serde_json::Value) -> Result<&serde_json::Value> {
resp.get("data")
.ok_or_else(|| Error::DecodingError("response missing 'data' envelope field".into()))
}
pub fn response_status(resp: &serde_json::Value) -> Result<u64> {
resp.get("head")
.and_then(|h| h.get("status"))
.and_then(|s| s.as_u64())
.ok_or_else(|| Error::DecodingError("response missing 'head.status' envelope field".into()))
}
impl Transport {
pub fn new(network: Arc<dyn Network>) -> Self {
Self { network }
}
pub async fn send(&self, kind: &str, corr_id: &str, body: &str) -> Result<serde_json::Value> {
tracing::debug!(direction = "send_req", body = %body, "transport");
let resp_str = self.network.send(body.to_owned()).await?;
tracing::debug!(direction = "send_res", body = %resp_str, "transport");
let response: serde_json::Value = serde_json::from_str(&resp_str).map_err(|e| {
Error::DecodingError(format!("invalid response JSON: {e}, resp: {resp_str}"))
})?;
let resp_kind = response.get("kind").and_then(|k| k.as_str()).unwrap_or("");
if resp_kind != kind {
return Err(Error::ServerError {
code: 500,
message: format!(
"response kind mismatch: expected '{}', got '{}'",
kind, resp_kind
),
});
}
let resp_corr = response
.get("head")
.and_then(|h| h.get("corrId"))
.and_then(|c| c.as_str())
.unwrap_or("");
if resp_corr != corr_id {
return Err(Error::ServerError {
code: 500,
message: format!(
"response corrId mismatch: expected '{}', got '{}'",
corr_id, resp_corr
),
});
}
Ok(response)
}
pub async fn send_json(&self, request: serde_json::Value) -> Result<serde_json::Value> {
let kind = request.get("kind").and_then(|k| k.as_str()).unwrap_or("");
let corr_id = request
.get("head")
.and_then(|h| h.get("corrId"))
.and_then(|c| c.as_str())
.unwrap_or("");
let body = serde_json::to_string(&request)?;
self.send(kind, corr_id, &body).await
}
pub fn recv(&self, callback: Box<dyn Fn(Message) + Send + Sync>) {
self.network.recv(Box::new(move |raw: String| {
match serde_json::from_str::<Message>(&raw) {
Ok(msg) => {
tracing::debug!(direction = "recv", body = %raw, "transport");
callback(msg)
}
Err(e) => {
tracing::warn!(error = %e, raw = %raw, "failed to parse incoming message");
}
}
}));
}
pub fn network(&self) -> &Arc<dyn Network> {
&self.network
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::LocalNetwork;

    /// Sends a `promise.create` request through a `LocalNetwork`. Checks that the
    /// response echoes the request's `kind` and `head.corrId` and carries a
    /// promise with the requested id.
    #[tokio::test]
    async fn transport_send_and_validate_envelope_format() {
        const KIND: &str = "promise.create";
        const CORR_ID: &str = "env123";

        let transport = Transport::new(Arc::new(LocalNetwork::new(Some("test".into()), None)));

        let request = serde_json::json!({
            "kind": KIND,
            "head": { "corrId": CORR_ID, "version": "2025-01-15" },
            "data": { "id": "p2", "timeoutAt": i64::MAX, "param": {}, "tags": {} },
        });
        let body = serde_json::to_string(&request).expect("request JSON serializes");

        let response = transport
            .send(KIND, CORR_ID, &body)
            .await
            .expect("send should succeed");

        assert_eq!(response["kind"], KIND);
        assert_eq!(response["head"]["corrId"], CORR_ID);
        assert_eq!(response["data"]["promise"]["id"], "p2");
    }
}