use std::collections::HashMap;
use tonic::transport::Channel;
use crate::error::{ClientError, ClientResult};
use stepflow_flow::workflow::Flow;
use stepflow_proto::{
CreateRunRequest, GetRunEventsRequest, GetRunItemsRequest, GetRunRequest, HealthCheckRequest,
ListRegisteredComponentsRequest, StoreFlowRequest,
components_service_client::ComponentsServiceClient, flows_service_client::FlowsServiceClient,
health_service_client::HealthServiceClient, runs_service_client::RunsServiceClient,
};
/// Snapshot of a run as returned by [`StepflowClient::get_run`].
#[derive(Debug, Clone)]
pub struct RunStatus {
/// ID of the run, taken from the run summary.
pub run_id: String,
/// Status code from the run summary, kept as the raw `i32` from the proto
/// message (presumably a proto enum discriminant — confirm against `stepflow_proto`).
pub status: i32,
/// Run outputs. [`StepflowClient::get_run`] always leaves this empty; use
/// [`StepflowClient::get_run_items`] to fetch the outputs of a run.
pub outputs: Vec<serde_json::Value>,
}
/// A registered component as reported by [`StepflowClient::list_components`].
#[derive(Debug, Clone)]
pub struct ComponentInfo {
/// Component identifier (the proto `component_id` field).
pub component: String,
/// Optional description of the component.
pub description: Option<String>,
/// Input schema converted from a proto `Struct` to JSON; `None` when the server
/// omits it (presumably always `None` when `exclude_schemas` is set — confirm).
pub input_schema: Option<serde_json::Value>,
/// Output schema converted from a proto `Struct` to JSON; `None` when the server
/// omits it (presumably always `None` when `exclude_schemas` is set — confirm).
pub output_schema: Option<serde_json::Value>,
}
/// Result of [`StepflowClient::list_components`].
#[derive(Debug, Clone)]
pub struct ListComponentsResult {
/// Components returned by the server.
pub components: Vec<ComponentInfo>,
/// The server's `complete` flag (presumably `false` when some plugins could not
/// be listed — confirm against the server implementation).
pub complete: bool,
/// `(plugin, error)` pairs for plugins the server reported as failed.
pub failed_plugins: Vec<(String, String)>,
}
/// A variable declared by a stored flow, as returned by
/// [`StepflowClient::get_flow_variables`].
#[derive(Debug, Clone)]
pub struct FlowVariable {
/// Optional description of the variable.
pub description: Option<String>,
/// Default value, converted from the proto `Value` to JSON.
pub default_value: Option<serde_json::Value>,
/// Whether the variable is marked as required.
pub required: bool,
/// Schema for the variable's value, converted from a proto `Struct` to JSON
/// (presumably a JSON Schema document — confirm).
pub schema: Option<serde_json::Value>,
/// Name of an environment variable associated with this variable
/// (presumably used to supply its value — confirm).
pub env_var: Option<String>,
}
/// Server-streaming response returned by [`StepflowClient::status_events`]:
/// a tonic stream of `StatusEvent` messages.
pub type StatusEventStream = tonic::codec::Streaming<stepflow_proto::StatusEvent>;
/// Async gRPC client for a Stepflow server.
///
/// Holds one typed client per service; all of them wrap clones of the same
/// [`Channel`]. Fields are private, so construct it with [`StepflowClient::connect`].
pub struct StepflowClient {
/// Client for the flows service (storing flows, reading flow variables).
flows: FlowsServiceClient<Channel>,
/// Client for the runs service (creating runs, reading results and events).
runs: RunsServiceClient<Channel>,
/// Client for the health service.
health: HealthServiceClient<Channel>,
/// Client for the components service.
components: ComponentsServiceClient<Channel>,
}
impl StepflowClient {
pub async fn connect(url: impl Into<String>) -> ClientResult<Self> {
let url = url.into();
let channel = Channel::from_shared(url.clone())
.map_err(|e| ClientError::Connection {
url: url.clone(),
source: Box::new(e),
})?
.connect()
.await
.map_err(|e| ClientError::Connection {
url,
source: Box::new(e),
})?;
Ok(Self {
flows: FlowsServiceClient::new(channel.clone()),
runs: RunsServiceClient::new(channel.clone()),
health: HealthServiceClient::new(channel.clone()),
components: ComponentsServiceClient::new(channel),
})
}
pub async fn store_flow(&mut self, flow: &Flow) -> ClientResult<String> {
let flow_json = serde_json::to_value(flow)?;
let flow_value = json_to_proto_value(flow_json);
let flow_struct = match flow_value.kind {
Some(prost_wkt_types::value::Kind::StructValue(s)) => s,
_ => {
return Err(ClientError::InvalidResponse(
"Flow JSON must be an object".to_string(),
));
}
};
let request = StoreFlowRequest {
flow: Some(flow_struct),
dry_run: false,
};
let response = self.flows.store_flow(request).await?.into_inner();
Ok(response.flow_id)
}
pub async fn run(
&mut self,
flow_id: &str,
input: serde_json::Value,
) -> ClientResult<serde_json::Value> {
let input_proto = json_to_proto_value(input);
let request = CreateRunRequest {
flow_id: flow_id.to_string(),
input: vec![input_proto],
wait: true,
..Default::default()
};
let response = self.runs.create_run(request).await?.into_inner();
if let Some(item) = response.results.first() {
if let Some(output) = &item.output {
return Ok(proto_value_to_json(output));
}
if let Some(msg) = &item.error_message {
return Err(ClientError::InvalidResponse(format!("Run failed: {msg}")));
}
}
Err(ClientError::InvalidResponse(
"Run completed but returned no output".to_string(),
))
}
pub async fn submit(
&mut self,
flow_id: &str,
input: serde_json::Value,
) -> ClientResult<String> {
let input_proto = json_to_proto_value(input);
let request = CreateRunRequest {
flow_id: flow_id.to_string(),
input: vec![input_proto],
wait: false,
..Default::default()
};
let response = self.runs.create_run(request).await?.into_inner();
Ok(response.summary.map(|s| s.run_id).unwrap_or_default())
}
pub async fn get_run(&mut self, run_id: &str, wait: bool) -> ClientResult<RunStatus> {
let request = GetRunRequest {
run_id: run_id.to_string(),
wait,
timeout_secs: None,
};
let response = self.runs.get_run(request).await?.into_inner();
let summary = response.summary.unwrap_or_default();
Ok(RunStatus {
run_id: summary.run_id,
status: summary.status,
outputs: vec![],
})
}
pub async fn get_run_items(&mut self, run_id: &str) -> ClientResult<Vec<serde_json::Value>> {
let request = GetRunItemsRequest {
run_id: run_id.to_string(),
result_order: 0, };
let response = self.runs.get_run_items(request).await?.into_inner();
let mut outputs = Vec::with_capacity(response.results.len());
for item in &response.results {
if let Some(output) = &item.output {
outputs.push(proto_value_to_json(output));
} else if let Some(msg) = &item.error_message {
return Err(ClientError::InvalidResponse(format!(
"Run item failed: {msg}"
)));
} else {
outputs.push(serde_json::Value::Null);
}
}
Ok(outputs)
}
pub async fn list_components(
&mut self,
exclude_schemas: bool,
) -> ClientResult<ListComponentsResult> {
let request = ListRegisteredComponentsRequest { exclude_schemas };
let response = self
.components
.list_registered_components(request)
.await?
.into_inner();
let components = response
.components
.into_iter()
.map(|c| ComponentInfo {
component: c.component_id,
description: c.description,
input_schema: c.input_schema.map(proto_struct_to_json),
output_schema: c.output_schema.map(proto_struct_to_json),
})
.collect();
let failed_plugins = response
.failed_plugins
.into_iter()
.map(|e| (e.plugin, e.error))
.collect();
Ok(ListComponentsResult {
components,
complete: response.complete,
failed_plugins,
})
}
pub async fn status_events(
&mut self,
run_id: &str,
include_sub_runs: bool,
include_results: bool,
) -> ClientResult<StatusEventStream> {
let request = GetRunEventsRequest {
run_id: run_id.to_string(),
since: None,
event_types: vec![],
include_sub_runs,
include_results,
};
let stream = self.runs.get_run_events(request).await?.into_inner();
Ok(stream)
}
pub async fn get_flow_variables(
&mut self,
flow_id: &str,
) -> ClientResult<HashMap<String, FlowVariable>> {
use stepflow_proto::GetFlowVariablesRequest;
let request = GetFlowVariablesRequest {
flow_id: flow_id.to_string(),
};
let response = self.flows.get_flow_variables(request).await?.into_inner();
let variables = response
.variables
.into_iter()
.map(|(name, v)| {
(
name,
FlowVariable {
description: v.description,
default_value: v.default_value.as_ref().map(proto_value_to_json),
required: v.required,
schema: v.schema.map(proto_struct_to_json),
env_var: v.env_var,
},
)
})
.collect();
Ok(variables)
}
pub async fn is_healthy(&mut self) -> bool {
self.health
.health_check(HealthCheckRequest {})
.await
.is_ok()
}
}
pub(crate) fn json_to_proto_value(value: serde_json::Value) -> prost_wkt_types::Value {
use prost_wkt_types::value::Kind;
prost_wkt_types::Value {
kind: Some(match value {
serde_json::Value::Null => Kind::NullValue(0),
serde_json::Value::Bool(b) => Kind::BoolValue(b),
serde_json::Value::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)),
serde_json::Value::String(s) => Kind::StringValue(s),
serde_json::Value::Array(arr) => Kind::ListValue(prost_wkt_types::ListValue {
values: arr.into_iter().map(json_to_proto_value).collect(),
}),
serde_json::Value::Object(obj) => Kind::StructValue(prost_wkt_types::Struct {
fields: obj
.into_iter()
.map(|(k, v)| (k, json_to_proto_value(v)))
.collect(),
}),
}),
}
}
/// Converts a protobuf well-known `Value` into a `serde_json::Value`.
///
/// Integral numbers strictly inside the `i64` range become JSON integers. Other finite
/// numbers become JSON floats. Non-finite numbers (NaN, ±infinity) have no JSON
/// representation and become `null`. A `Value` with no `kind` is treated as `null`.
/// Structs and lists are converted recursively.
pub(crate) fn proto_value_to_json(value: &prost_wkt_types::Value) -> serde_json::Value {
    use prost_wkt_types::value::Kind;
    match &value.kind {
        Some(Kind::NullValue(_)) | None => serde_json::Value::Null,
        Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
        Some(Kind::NumberValue(n)) => {
            let n = *n;
            // `i64::MIN as f64` is exactly -2^63, and its negation (2^63) is one past
            // `i64::MAX`, hence the half-open range. A round-trip check on `n as i64` is not
            // enough: the cast saturates 2^63 to `i64::MAX`, which rounds back to 2^63 and
            // would yield an off-by-one integer. NaN/infinity fail both tests.
            let i64_range = (i64::MIN as f64)..-(i64::MIN as f64);
            if n.fract() == 0.0 && i64_range.contains(&n) {
                // Exact conversion: `n` is integral and inside i64's range.
                return serde_json::Value::Number((n as i64).into());
            }
            serde_json::Number::from_f64(n)
                .map(serde_json::Value::Number)
                .unwrap_or(serde_json::Value::Null)
        }
        Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
        Some(Kind::StructValue(s)) => {
            let map = s
                .fields
                .iter()
                .map(|(k, v)| (k.clone(), proto_value_to_json(v)))
                .collect();
            serde_json::Value::Object(map)
        }
        Some(Kind::ListValue(l)) => {
            serde_json::Value::Array(l.values.iter().map(proto_value_to_json).collect())
        }
    }
}
/// Converts an owned protobuf `Struct` into a JSON object, converting each field value
/// with [`proto_value_to_json`].
fn proto_struct_to_json(s: prost_wkt_types::Struct) -> serde_json::Value {
    let mut object = serde_json::Map::new();
    for (key, field) in s.fields {
        object.insert(key, proto_value_to_json(&field));
    }
    serde_json::Value::Object(object)
}