use crate::convert::{ProtoPayload, ProtoRunId, ProtoWorkflowId, WireEnvelope};
use crate::error::ProtoWireError;
/// Wire message for a start-workflow request.
///
/// Derives `prost::Message` (protobuf) and serde (JSON), so the same type is
/// used for both encodings. Tag numbers are part of the protobuf wire
/// contract and must not be renumbered.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoStartWorkflowRequest {
/// Namespace string (this file's tests use `"tenant-a"`).
#[prost(string, tag = "1")]
pub namespace: String,
/// Workflow type name (this file's tests use `"checkout"`).
#[prost(string, tag = "2")]
pub workflow_type: String,
/// Optional input payload; `None` when the field is absent on the wire.
#[prost(message, optional, tag = "3")]
pub input: Option<ProtoPayload>,
}
/// Wire message returned for a start-workflow request.
///
/// Both identifiers are `Option` because they are declared as optional
/// protobuf message fields; `None` means the field was not present on the wire.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoStartWorkflowResponse {
/// Workflow identifier (tag 1).
#[prost(message, optional, tag = "1")]
pub workflow_id: Option<ProtoWorkflowId>,
/// Run identifier (tag 2).
#[prost(message, optional, tag = "2")]
pub run_id: Option<ProtoRunId>,
}
/// Wire message for sending a named signal, with an optional payload, to a workflow.
///
/// Tag numbers are part of the protobuf wire contract and must not be renumbered.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoSignalRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Target workflow identifier; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub workflow_id: Option<ProtoWorkflowId>,
/// Optional run identifier (tag 3).
/// NOTE(review): how the receiver interprets `None` is not defined here — confirm.
#[prost(message, optional, tag = "3")]
pub run_id: Option<ProtoRunId>,
/// Name of the signal (tag 4).
#[prost(string, tag = "4")]
pub signal_name: String,
/// Optional signal payload (tag 5).
#[prost(message, optional, tag = "5")]
pub payload: Option<ProtoPayload>,
}
/// Empty response message for a signal request.
///
/// It has no fields, so it encodes to zero protobuf bytes and to `{}` in JSON.
#[derive(Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoSignalResponse {}
/// Wire message for running a named query against a workflow.
///
/// Tag numbers are part of the protobuf wire contract and must not be renumbered.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoQueryRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Target workflow identifier; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub workflow_id: Option<ProtoWorkflowId>,
/// Optional run identifier (tag 3).
/// NOTE(review): how the receiver interprets `None` is not defined here — confirm.
#[prost(message, optional, tag = "3")]
pub run_id: Option<ProtoRunId>,
/// Name of the query to run (this file's tests use `"state"`) (tag 4).
#[prost(string, tag = "4")]
pub query_name: String,
}
/// Wire message carrying the outcome of a query request.
///
/// The outcome is a protobuf `oneof` over tags 1 and 2, either a result payload
/// or a wire error (see [`proto_query_response::Outcome`]).
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoQueryResponse {
/// Query outcome; `None` when neither oneof variant is present on the wire.
#[prost(oneof = "proto_query_response::Outcome", tags = "1, 2")]
pub outcome: Option<proto_query_response::Outcome>,
}
/// Oneof types nested under [`ProtoQueryResponse`].
pub mod proto_query_response {
/// Mutually exclusive outcome of a query: a result payload (tag 1) or a
/// wire error (tag 2).
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Oneof)]
pub enum Outcome {
/// Query result payload (tag 1).
#[prost(message, tag = "1")]
Result(super::ProtoPayload),
/// Error returned instead of a result (tag 2), e.g. an unknown query name
/// in this file's tests.
#[prost(message, tag = "2")]
Error(super::ProtoWireError),
}
}
/// Wire message for requesting cancellation of a workflow.
///
/// Tag numbers are part of the protobuf wire contract and must not be renumbered.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoCancelRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Target workflow identifier; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub workflow_id: Option<ProtoWorkflowId>,
/// Optional run identifier (tag 3).
/// NOTE(review): how the receiver interprets `None` is not defined here — confirm.
#[prost(message, optional, tag = "3")]
pub run_id: Option<ProtoRunId>,
/// Cancellation reason text (tag 4).
#[prost(string, tag = "4")]
pub reason: String,
}
/// Empty response message for a cancel request.
///
/// It has no fields, so it encodes to zero protobuf bytes and to `{}` in JSON.
#[derive(Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoCancelResponse {}
/// Wire message for listing workflows in a namespace.
///
/// The filter is not a protobuf schema of its own. It travels as an opaque
/// [`WireEnvelope`]; this file's tests build it with `encode_core_value` from an
/// `aion_store::visibility::ListWorkflowsFilter`.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoListWorkflowsRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Optional envelope-encoded filter; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub filter: Option<WireEnvelope>,
}
/// Wire message carrying the result of a list-workflows request.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoListWorkflowsResponse {
/// Repeated envelope-encoded summaries (tag 1). In this file's tests each
/// envelope holds an `aion_store::visibility::WorkflowSummary`.
#[prost(message, repeated, tag = "1")]
pub summaries: Vec<WireEnvelope>,
}
/// Wire message for counting workflows in a namespace.
///
/// Uses the same envelope-encoded filter shape as [`ProtoListWorkflowsRequest`].
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoCountWorkflowsRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Optional envelope-encoded filter; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub filter: Option<WireEnvelope>,
}
/// Wire message carrying the result of a count-workflows request.
#[derive(Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoCountWorkflowsResponse {
/// Number of matching workflows, encoded as protobuf `uint64` (tag 1).
#[prost(uint64, tag = "1")]
pub count: u64,
}
/// Wire message for describing a single workflow.
///
/// Tag numbers are part of the protobuf wire contract and must not be renumbered.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoDescribeWorkflowRequest {
/// Namespace string (tag 1).
#[prost(string, tag = "1")]
pub namespace: String,
/// Target workflow identifier; `None` when absent on the wire (tag 2).
#[prost(message, optional, tag = "2")]
pub workflow_id: Option<ProtoWorkflowId>,
/// Optional run identifier (tag 3).
/// NOTE(review): how the receiver interprets `None` is not defined here — confirm.
#[prost(message, optional, tag = "3")]
pub run_id: Option<ProtoRunId>,
/// History flag (tag 4); `false` when absent on the wire.
/// NOTE(review): presumably controls whether the response's `history` is
/// populated — confirm against the server implementation.
#[prost(bool, tag = "4")]
pub include_history: bool,
}
/// Wire message carrying the result of a describe-workflow request.
///
/// Both fields carry opaque [`WireEnvelope`]s rather than typed protobuf messages.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct ProtoDescribeWorkflowResponse {
/// Optional envelope-encoded summary (tag 1).
#[prost(message, optional, tag = "1")]
pub summary: Option<WireEnvelope>,
/// Repeated envelope-encoded history entries (tag 2).
/// NOTE(review): the value type inside each envelope is not visible here — confirm.
#[prost(message, repeated, tag = "2")]
pub history: Vec<WireEnvelope>,
}
#[cfg(test)]
mod tests {
    //! Round-trip tests for the wire messages in this module.
    //!
    //! Every message is checked twice, once through serde JSON and once through
    //! protobuf encode/decode. Both must reproduce the original value exactly.

    use std::collections::HashMap;
    use std::fmt::Debug;

    use aion_core::SearchAttributeValue;
    use aion_store::visibility::{ListWorkflowsFilter, SearchAttributePredicate};
    use chrono::{DateTime, Utc};
    use prost::Message;
    use serde::de::DeserializeOwned;
    use serde_json::json;

    use super::{
        ProtoCancelRequest, ProtoCancelResponse, ProtoCountWorkflowsRequest,
        ProtoCountWorkflowsResponse, ProtoDescribeWorkflowRequest, ProtoDescribeWorkflowResponse,
        ProtoListWorkflowsRequest, ProtoListWorkflowsResponse, ProtoQueryRequest,
        ProtoQueryResponse, ProtoSignalRequest, ProtoSignalResponse, ProtoStartWorkflowRequest,
        ProtoStartWorkflowResponse, proto_query_response,
    };
    use crate::convert::{
        ProtoPayload, ProtoRunId, ProtoWorkflowId, decode_core_value, encode_core_value,
    };
    use crate::error::{ProtoWireError, WireError};

    /// Result type for tests and helpers that propagate failures with `?`.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Deterministic workflow id (nil UUID) shared by the fixtures.
    fn workflow_id() -> aion_core::WorkflowId {
        aion_core::WorkflowId::new(uuid::Uuid::nil())
    }

    /// Deterministic run id (nil UUID) shared by the fixtures.
    fn run_id() -> aion_core::RunId {
        aion_core::RunId::new(uuid::Uuid::nil())
    }

    /// Builds the JSON payload `{ "label": <label> }` in its wire form.
    fn payload(label: &str) -> Result<ProtoPayload, aion_core::PayloadError> {
        Ok(ProtoPayload::from(aion_core::Payload::from_json(
            &json!({ "label": label }),
        )?))
    }

    /// Fixed UTC timestamp so the fixtures are reproducible.
    fn recorded_at() -> Result<DateTime<Utc>, chrono::ParseError> {
        Ok(DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")?.with_timezone(&Utc))
    }

    /// Visibility summary fixture shared by the list and describe tests.
    fn sample_summary() -> Result<aion_store::visibility::WorkflowSummary, chrono::ParseError> {
        Ok(aion_store::visibility::WorkflowSummary {
            workflow_id: workflow_id(),
            run_id: run_id(),
            workflow_type: String::from("checkout"),
            status: aion_core::WorkflowStatus::Running,
            start_time: recorded_at()?,
            close_time: None,
            search_attributes: HashMap::from([(
                String::from("customer_id"),
                SearchAttributeValue::String(String::from("12345")),
            )]),
        })
    }

    /// Asserts that `value` survives a serde JSON round trip unchanged.
    ///
    /// Uses `assert_eq!`, so a failure prints both values and the encoded JSON.
    fn assert_json_round_trip<T>(value: &T) -> Result<(), serde_json::Error>
    where
        T: PartialEq + Debug + serde::Serialize + DeserializeOwned,
    {
        let encoded = serde_json::to_string(value)?;
        let decoded = serde_json::from_str::<T>(&encoded)?;
        assert_eq!(&decoded, value, "JSON round trip changed the value; encoded: {encoded}");
        Ok(())
    }

    /// Asserts that `value` survives a protobuf encode/decode round trip unchanged.
    ///
    /// `Debug` is a supertrait of `prost::Message`, so `assert_eq!` needs no extra bound.
    fn assert_proto_round_trip<T>(value: &T) -> TestResult
    where
        T: PartialEq + Message + Default,
    {
        let mut bytes = Vec::new();
        value.encode(&mut bytes)?;
        let decoded = T::decode(bytes.as_slice())?;
        assert_eq!(&decoded, value, "protobuf round trip changed the value");
        Ok(())
    }

    /// Runs both the JSON and the protobuf round-trip checks on `value`.
    fn assert_round_trips<T>(value: &T) -> TestResult
    where
        T: PartialEq + Message + Default + serde::Serialize + DeserializeOwned,
    {
        assert_json_round_trip(value)?;
        assert_proto_round_trip(value)
    }

    #[test]
    fn start_workflow_round_trips_json_and_proto() -> TestResult {
        let request = ProtoStartWorkflowRequest {
            namespace: String::from("tenant-a"),
            workflow_type: String::from("checkout"),
            input: Some(payload("input")?),
        };
        let response = ProtoStartWorkflowResponse {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            run_id: Some(ProtoRunId::from(run_id())),
        };
        assert_round_trips(&request)?;
        assert_round_trips(&response)?;
        Ok(())
    }

    #[test]
    fn signal_round_trips_json_and_proto() -> TestResult {
        let request = ProtoSignalRequest {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            run_id: Some(ProtoRunId::from(run_id())),
            signal_name: String::from("approve"),
            payload: Some(payload("signal")?),
        };
        assert_round_trips(&request)?;
        assert_round_trips(&ProtoSignalResponse {})?;
        Ok(())
    }

    #[test]
    fn cancel_round_trips_json_and_proto() -> TestResult {
        // `run_id` is left unset to exercise an absent optional message field.
        let request = ProtoCancelRequest {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            run_id: None,
            reason: String::from("customer requested cancellation"),
        };
        assert_round_trips(&request)?;
        assert_round_trips(&ProtoCancelResponse {})?;
        Ok(())
    }

    #[test]
    fn list_workflows_round_trips_json_and_proto() -> TestResult {
        let filter = ListWorkflowsFilter {
            workflow_type: Some(String::from("checkout")),
            status: Some(aion_core::WorkflowStatus::Running),
            search_attributes: vec![SearchAttributePredicate::Equals {
                name: String::from("customer_id"),
                value: SearchAttributeValue::String(String::from("12345")),
            }],
            limit: Some(10),
            offset: Some(5),
            ..ListWorkflowsFilter::default()
        };
        let summary = sample_summary()?;
        let filter_envelope = encode_core_value("tenant-a", Some(String::from("r1")), &filter)?;
        let summary_envelope = encode_core_value("tenant-a", None, &summary)?;

        let request = ProtoListWorkflowsRequest {
            namespace: String::from("tenant-a"),
            filter: Some(filter_envelope.clone()),
        };
        let response = ProtoListWorkflowsResponse {
            summaries: vec![summary_envelope.clone()],
        };
        let count_request = ProtoCountWorkflowsRequest {
            namespace: String::from("tenant-a"),
            filter: Some(filter_envelope.clone()),
        };
        let count_response = ProtoCountWorkflowsResponse { count: 1 };

        assert_round_trips(&request)?;
        assert_round_trips(&response)?;
        assert_round_trips(&count_request)?;
        assert_round_trips(&count_response)?;

        // The envelopes must also decode back to the domain values they carry.
        assert_eq!(
            decode_core_value::<ListWorkflowsFilter>(&filter_envelope)?,
            filter
        );
        assert_eq!(
            decode_core_value::<aion_store::visibility::WorkflowSummary>(&summary_envelope)?,
            summary
        );
        Ok(())
    }

    #[test]
    fn describe_workflow_round_trips_json_and_proto() -> TestResult {
        let summary_envelope = encode_core_value("tenant-a", None, &sample_summary()?)?;
        let request = ProtoDescribeWorkflowRequest {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            run_id: Some(ProtoRunId::from(run_id())),
            include_history: true,
        };
        // The transport message treats envelopes as opaque, so any envelope is
        // enough to exercise the repeated `history` field.
        let response = ProtoDescribeWorkflowResponse {
            summary: Some(summary_envelope.clone()),
            history: vec![summary_envelope],
        };
        assert_round_trips(&request)?;
        assert_round_trips(&response)?;
        Ok(())
    }

    #[test]
    fn query_round_trips_json_and_proto() -> TestResult {
        let request = ProtoQueryRequest {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            run_id: Some(ProtoRunId::from(run_id())),
            query_name: String::from("state"),
        };
        let result_response = ProtoQueryResponse {
            outcome: Some(proto_query_response::Outcome::Result(payload("result")?)),
        };
        let error_response = ProtoQueryResponse {
            outcome: Some(proto_query_response::Outcome::Error(ProtoWireError::from(
                WireError::unknown_query("state query is not registered"),
            ))),
        };
        // A response with no oneof variant set must survive both encodings too.
        let empty_response = ProtoQueryResponse { outcome: None };

        assert_round_trips(&request)?;
        assert_round_trips(&result_response)?;
        assert_round_trips(&error_response)?;
        assert_round_trips(&empty_response)?;
        Ok(())
    }
}