use heck::ToSnakeCase;
use std::fmt::Write;
use crate::ir::{
ApiKind, ApiSpec, EnumVariant, Field, HttpMethod, Operation, Param, ParamLocation, Protocol,
StreamingMode, TypeDef,
};
pub fn render(spec: &ApiSpec) -> String {
let mut out = String::new();
writeln!(
out,
"//! `{}` client crate, generated by `oxide-gen`.",
spec.display_name
)
.unwrap();
if let Some(desc) = &spec.description {
for line in desc.lines() {
writeln!(out, "//! {line}").unwrap();
}
}
writeln!(out, "//!").unwrap();
writeln!(out, "//! Spec kind: {}", spec.kind.slug()).unwrap();
writeln!(out, "//! Operations: {}", spec.operations.len()).unwrap();
writeln!(out, "//!").unwrap();
writeln!(
out,
"//! Do not edit by hand — re-run `oxide-gen` to regenerate."
)
.unwrap();
writeln!(out).unwrap();
writeln!(
out,
"#![allow(clippy::all, dead_code, unused_imports, unused_variables, unused_mut)]"
)
.unwrap();
writeln!(out).unwrap();
writeln!(out, "use serde::{{Deserialize, Serialize}};").unwrap();
writeln!(out).unwrap();
if spec.kind == ApiKind::Grpc {
writeln!(out, "pub mod proto {{").unwrap();
writeln!(out, " tonic::include_proto!(\"{}\");", spec.display_name).unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
for td in &spec.types {
writeln!(out, "pub use proto::{};", td.name()).unwrap();
}
writeln!(out).unwrap();
} else {
for td in &spec.types {
render_type(&mut out, td);
writeln!(out).unwrap();
}
}
render_client(&mut out, spec);
out
}
/// Emits one IR type definition as a Rust item.
///
/// Structs and enums get their doc comment and the derive list `Debug`,
/// `Clone`, `Serialize`, `Deserialize`. Aliases become a bare `pub type`.
fn render_type(out: &mut String, td: &TypeDef) {
    /// Writes the doc comment, the derive attribute and the opening line of a
    /// `pub struct` / `pub enum` item.
    fn open_item(
        out: &mut String,
        doc: Option<&str>,
        keyword: &str,
        ident: &dyn std::fmt::Display,
    ) {
        emit_doc(out, doc, "");
        writeln!(out, "#[derive(Debug, Clone, Serialize, Deserialize)]").unwrap();
        writeln!(out, "pub {keyword} {ident} {{").unwrap();
    }

    match td {
        TypeDef::Alias { name, target } => {
            writeln!(out, "pub type {name} = {target};").unwrap();
        }
        TypeDef::Struct {
            name,
            description,
            fields,
        } => {
            open_item(out, description.as_deref(), "struct", name);
            fields.iter().for_each(|field| render_field(out, field));
            writeln!(out, "}}").unwrap();
        }
        TypeDef::Enum {
            name,
            description,
            variants,
        } => {
            open_item(out, description.as_deref(), "enum", name);
            variants.iter().for_each(|variant| render_variant(out, variant));
            writeln!(out, "}}").unwrap();
        }
    }
}
/// Emits one struct field, in this order:
/// - its doc comment, indented by one space;
/// - an optional `serde(rename)` attribute;
/// - a `skip_serializing_if` attribute when the field is optional;
/// - the `pub name: Type,` line itself.
fn render_field(out: &mut String, field: &Field) {
    emit_doc(out, field.description.as_deref(), " ");
    let attributes = [
        field
            .serde_rename
            .as_ref()
            .map(|wire| format!(" #[serde(rename = \"{}\")]", escape(wire))),
        field
            .optional
            .then(|| " #[serde(skip_serializing_if = \"Option::is_none\")]".to_string()),
    ];
    for attr in attributes.into_iter().flatten() {
        writeln!(out, "{attr}").unwrap();
    }
    let (ident, ty) = (&field.name, &field.rust_type);
    writeln!(out, " pub {ident}: {ty},").unwrap();
}
/// Emits one enum variant line.
///
/// When the IR supplies a wire name, the variant is preceded by a
/// `serde(rename)` attribute.
fn render_variant(out: &mut String, v: &EnumVariant) {
    let rename_attr = v.serde_rename.as_ref().map(|wire| escape(wire));
    if let Some(wire) = rename_attr {
        writeln!(out, " #[serde(rename = \"{wire}\")]").unwrap();
    }
    let ident = &v.name;
    writeln!(out, " {ident},").unwrap();
}
/// Emits the generated `Client` struct and its `impl` block.
///
/// - HTTP-based specs (OpenAPI, GraphQL) get a `reqwest::Client` field; gRPC
///   clients only store the base URL.
/// - `default_endpoint()` uses the spec's base URL when present, otherwise a
///   per-kind placeholder URL.
/// - Every operation is rendered as a method inside the same `impl` block.
fn render_client(out: &mut String, spec: &ApiSpec) {
    let (needs_http, fallback_url) = match spec.kind {
        ApiKind::OpenApi => (true, "https://example.com"),
        ApiKind::GraphQl => (true, "https://example.com/graphql"),
        ApiKind::Grpc => (false, "http://localhost:50051"),
    };
    let default_base = spec
        .base_url
        .clone()
        .unwrap_or_else(|| fallback_url.into());

    // The struct definition.
    writeln!(out, "/// Asynchronous client for `{}`.", spec.display_name).unwrap();
    out.push_str("#[derive(Debug, Clone)]\n");
    out.push_str("pub struct Client {\n");
    out.push_str(" pub base_url: String,\n");
    if needs_http {
        out.push_str(" pub http: reqwest::Client,\n");
    }
    out.push_str("}\n\n");

    // The constructors.
    out.push_str("impl Client {\n");
    out.push_str(" /// Build a client pointing at `base_url`.\n");
    out.push_str(" pub fn new(base_url: impl Into<String>) -> Self {\n");
    out.push_str(" Self {\n");
    out.push_str(" base_url: base_url.into(),\n");
    if needs_http {
        out.push_str(" http: reqwest::Client::new(),\n");
    }
    out.push_str(" }\n }\n\n");
    writeln!(
        out,
        " /// Build a client pointing at the default endpoint `{default_base}`."
    )
    .unwrap();
    out.push_str(" pub fn default_endpoint() -> Self {\n");
    writeln!(out, " Self::new(\"{}\")", escape(&default_base)).unwrap();
    out.push_str(" }\n");

    // One method per operation, each separated by a blank line.
    for op in &spec.operations {
        out.push('\n');
        render_operation(out, spec, op);
    }
    out.push_str("}\n");
}
fn render_operation(out: &mut String, spec: &ApiSpec, op: &Operation) {
if let Some(desc) = &op.description {
for line in desc.lines() {
writeln!(out, " /// {line}").unwrap();
}
}
writeln!(out, " /// Endpoint: `{}`", op.endpoint).unwrap();
if op.streaming.is_streaming() {
writeln!(
out,
" /// Streaming mode: `{}` — this generated stub returns an error; wire your runtime (`tonic` / GraphQL subscriptions) to convert it into an `impl Stream`.",
op.streaming.label()
)
.unwrap();
}
let sig_params = op
.params
.iter()
.map(|p| {
let is_grpc_client_or_bidi_stream = (op.streaming == StreamingMode::ClientStream
|| op.streaming == StreamingMode::BidiStream)
&& op.protocol == Protocol::Grpc;
let ty = if is_grpc_client_or_bidi_stream {
format!(
"impl futures_util::Stream<Item = {}> + Send + 'static",
p.rust_type
)
} else if p.required {
p.rust_type.clone()
} else {
format!("Option<{}>", p.rust_type)
};
format!("{}: {}", p.name, ty)
})
.collect::<Vec<_>>()
.join(", ");
let prefix = if sig_params.is_empty() { "" } else { ", " };
let ret = if op.streaming.is_streaming() {
match op.protocol {
Protocol::GraphQl => {
format!(
"futures_util::stream::BoxStream<'static, anyhow::Result<{}>>",
op.return_type
)
}
Protocol::Grpc => {
format!("futures_util::stream::BoxStream<'static, std::result::Result<{}, tonic::Status>>", op.return_type)
}
_ => {
format!(
"futures_util::stream::BoxStream<'static, anyhow::Result<{}>>",
op.return_type
)
}
}
} else {
op.return_type.clone()
};
writeln!(
out,
" pub async fn {name}(&self{prefix}{sig_params}) -> anyhow::Result<{ret}> {{",
name = op.id,
ret = ret,
)
.unwrap();
if op.streaming.is_streaming() {
match op.protocol {
Protocol::GraphQl => render_graphql_subscription_body(out, op),
Protocol::Grpc => render_grpc_body(out, op),
_ => render_streaming_body(out, op),
}
} else {
match op.protocol {
Protocol::Rest => render_rest_body(out, spec, op),
Protocol::GraphQl => render_graphql_body(out, op),
Protocol::Grpc => render_grpc_body(out, op),
}
}
writeln!(out, " }}").unwrap();
}
fn render_streaming_body(out: &mut String, op: &Operation) {
writeln!(
out,
" // Streaming ({}) scaffold. Wire `tonic` server-streaming or a GraphQL subscription client to fulfil this method.",
op.streaming.label()
)
.unwrap();
for p in &op.params {
writeln!(out, " let _ = &{n};", n = p.name).unwrap();
}
writeln!(
out,
" anyhow::bail!(\"streaming operation `{}` ({}) not yet wired; regenerate after attaching a streaming runtime.\");",
op.original_id,
op.streaming.label()
)
.unwrap();
}
fn render_rest_body(out: &mut String, _spec: &ApiSpec, op: &Operation) {
let path = extract_path(&op.endpoint);
let path_params = path_params_in_order(&path, &op.params);
let mut url_template = path.clone();
for p in &path_params {
url_template = url_template.replace(
&format!("{{{}}}", p.original_name),
&format!("{{{}}}", p.name),
);
}
write!(
out,
" let url = format!(\"{{base}}{}\"",
url_template
)
.unwrap();
write!(out, ", base = self.base_url").unwrap();
for p in &path_params {
write!(out, ", {n} = {n}", n = p.name).unwrap();
}
writeln!(out, ");").unwrap();
let method_fn = op.http_method.reqwest_fn().unwrap_or("get");
writeln!(out, " let mut req = self.http.{method_fn}(&url);").unwrap();
for p in op
.params
.iter()
.filter(|p| p.location == ParamLocation::Query)
{
if p.required {
writeln!(
out,
" req = req.query(&[(\"{}\", {}.to_string())]);",
escape(&p.original_name),
p.name
)
.unwrap();
} else {
writeln!(
out,
" if let Some(v) = {n}.as_ref() {{ req = req.query(&[(\"{orig}\", v.to_string())]); }}",
n = p.name,
orig = escape(&p.original_name),
)
.unwrap();
}
}
for p in op
.params
.iter()
.filter(|p| p.location == ParamLocation::Header)
{
if p.required {
writeln!(
out,
" req = req.header(\"{}\", {}.to_string());",
escape(&p.original_name),
p.name
)
.unwrap();
} else {
writeln!(
out,
" if let Some(v) = {n}.as_ref() {{ req = req.header(\"{orig}\", v.to_string()); }}",
n = p.name,
orig = escape(&p.original_name),
)
.unwrap();
}
}
if let Some(body) = op.params.iter().find(|p| p.location == ParamLocation::Body) {
if body.required {
writeln!(out, " req = req.json(&{});", body.name).unwrap();
} else {
writeln!(
out,
" if let Some(v) = {n}.as_ref() {{ req = req.json(v); }}",
n = body.name
)
.unwrap();
}
}
writeln!(
out,
" let response = req.send().await?.error_for_status()?;"
)
.unwrap();
writeln!(
out,
" let parsed: {ret} = response.json().await?;",
ret = op.return_type
)
.unwrap();
writeln!(out, " Ok(parsed)").unwrap();
}
fn render_graphql_body(out: &mut String, op: &Operation) {
let var_decls = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| {
if p.required {
format!("vars.insert(\"{n}\".to_string(), serde_json::to_value(&{n})?);", n = p.name)
} else {
format!("if let Some(v) = {n}.as_ref() {{ vars.insert(\"{n}\".to_string(), serde_json::to_value(v)?); }}", n = p.name)
}
})
.collect::<Vec<_>>()
.join("\n ");
let op_word = match op.http_method {
HttpMethod::Post => "query", _ => "query",
};
let op_word = if op.endpoint.starts_with("POST") && op.id.starts_with("create") {
"mutation"
} else {
op_word
};
let arg_list = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| {
let ty = if p.required {
format!("{}!", p.rust_type)
} else {
p.rust_type.clone()
};
format!("${n}: {ty}", n = p.name)
})
.collect::<Vec<_>>()
.join(", ");
let arg_decl = if arg_list.is_empty() {
String::new()
} else {
format!("({})", arg_list)
};
let inner_args = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| format!("{n}: ${n}", n = p.name))
.collect::<Vec<_>>()
.join(", ");
let inner_args_block = if inner_args.is_empty() {
String::new()
} else {
format!("({inner_args})")
};
let query_str = format!(
"{op_word} {oid}{arg_decl} {{ {oid}{inner_args_block} }}",
oid = op.original_id,
);
writeln!(out, " let query = \"{}\";", escape(&query_str)).unwrap();
writeln!(
out,
" let mut vars: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();"
)
.unwrap();
if !var_decls.is_empty() {
writeln!(out, " {var_decls}").unwrap();
}
writeln!(
out,
" let body = serde_json::json!({{ \"query\": query, \"variables\": serde_json::Value::Object(vars) }});"
)
.unwrap();
writeln!(
out,
" let response = self.http.post(&self.base_url).json(&body).send().await?.error_for_status()?;"
)
.unwrap();
writeln!(
out,
" let envelope: serde_json::Value = response.json().await?;"
)
.unwrap();
writeln!(
out,
" let data = envelope.get(\"data\").and_then(|d| d.get(\"{}\")).cloned().unwrap_or(serde_json::Value::Null);",
op.original_id
)
.unwrap();
writeln!(
out,
" let parsed: {ret} = serde_json::from_value(data)?;",
ret = op.return_type
)
.unwrap();
writeln!(out, " Ok(parsed)").unwrap();
}
fn render_graphql_subscription_body(out: &mut String, op: &Operation) {
let var_decls = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| {
if p.required {
format!("vars.insert(\"{n}\".to_string(), serde_json::to_value(&{n})?);", n = p.name)
} else {
format!("if let Some(v) = {n}.as_ref() {{ vars.insert(\"{n}\".to_string(), serde_json::to_value(v)?); }}", n = p.name)
}
})
.collect::<Vec<_>>()
.join("\n ");
let arg_list = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| {
let ty = if p.required {
format!("{}!", p.rust_type)
} else {
p.rust_type.clone()
};
format!("${n}: {ty}", n = p.name)
})
.collect::<Vec<_>>()
.join(", ");
let arg_decl = if arg_list.is_empty() {
String::new()
} else {
format!("({})", arg_list)
};
let inner_args = op
.params
.iter()
.filter(|p| p.location == ParamLocation::GraphQlVariable)
.map(|p| format!("{n}: ${n}", n = p.name))
.collect::<Vec<_>>()
.join(", ");
let inner_args_block = if inner_args.is_empty() {
String::new()
} else {
format!("({inner_args})")
};
let query_str = format!(
"subscription {oid}{arg_decl} {{ {oid}{inner_args_block} }}",
oid = op.original_id,
);
writeln!(out, " use futures_util::{{SinkExt, StreamExt}};").unwrap();
writeln!(out, " let ws_url = self.base_url.replace(\"https://\", \"wss://\").replace(\"http://\", \"ws://\");").unwrap();
writeln!(
out,
" let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;"
)
.unwrap();
writeln!(
out,
" let (mut write, mut read) = ws_stream.split();"
)
.unwrap();
writeln!(out).unwrap();
writeln!(
out,
" let init_msg = tokio_tungstenite::tungstenite::Message::Text("
)
.unwrap();
writeln!(
out,
" r#\"{{\"type\":\"connection_init\"}}\"#.into()"
)
.unwrap();
writeln!(out, " );").unwrap();
writeln!(out, " write.send(init_msg).await?;").unwrap();
writeln!(out).unwrap();
writeln!(out, " if let Some(msg) = read.next().await {{").unwrap();
writeln!(out, " let msg = msg?;").unwrap();
writeln!(
out,
" if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {{"
)
.unwrap();
writeln!(
out,
" let ack: serde_json::Value = serde_json::from_str(&text)?;"
)
.unwrap();
writeln!(out, " if ack.get(\"type\").and_then(|t| t.as_str()) != Some(\"connection_ack\") {{").unwrap();
writeln!(
out,
" anyhow::bail!(\"Expected connection_ack, got: {{}}\", text);"
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }} else {{").unwrap();
writeln!(
out,
" anyhow::bail!(\"Expected connection_ack, got non-text message\");"
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }} else {{").unwrap();
writeln!(
out,
" anyhow::bail!(\"Connection closed during handshake\");"
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(out).unwrap();
writeln!(out, " let query = \"{}\";", escape(&query_str)).unwrap();
writeln!(out, " let mut vars: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();").unwrap();
if !var_decls.is_empty() {
writeln!(out, " {var_decls}").unwrap();
}
writeln!(out, " let payload = serde_json::json!({{").unwrap();
writeln!(out, " \"query\": query,").unwrap();
writeln!(
out,
" \"variables\": serde_json::Value::Object(vars),"
)
.unwrap();
writeln!(out, " }});").unwrap();
writeln!(out, " let sub_msg = serde_json::json!({{").unwrap();
writeln!(out, " \"id\": \"sub_1\",").unwrap();
writeln!(out, " \"type\": \"subscribe\",").unwrap();
writeln!(out, " \"payload\": payload,").unwrap();
writeln!(out, " }});").unwrap();
writeln!(
out,
" let sub_text = serde_json::to_string(&sub_msg)?;"
)
.unwrap();
writeln!(out, " write.send(tokio_tungstenite::tungstenite::Message::Text(sub_text.into())).await?;").unwrap();
writeln!(out).unwrap();
writeln!(out, " let stream = futures_util::stream::unfold((write, read), |(mut write, mut read)| async move {{").unwrap();
writeln!(out, " loop {{").unwrap();
writeln!(out, " match read.next().await {{").unwrap();
writeln!(out, " Some(Ok(msg)) => {{").unwrap();
writeln!(out, " match msg {{").unwrap();
writeln!(
out,
" tokio_tungstenite::tungstenite::Message::Text(text) => {{"
)
.unwrap();
writeln!(out, " let val: serde_json::Value = match serde_json::from_str(&text) {{").unwrap();
writeln!(out, " Ok(v) => v,").unwrap();
writeln!(out, " Err(e) => return Some((Err(anyhow::anyhow!(e)), (write, read))),").unwrap();
writeln!(out, " }};").unwrap();
writeln!(out, " let msg_type = val.get(\"type\").and_then(|t| t.as_str());").unwrap();
writeln!(out, " match msg_type {{").unwrap();
writeln!(
out,
" Some(\"ping\") => {{"
)
.unwrap();
writeln!(out, " let pong_msg = tokio_tungstenite::tungstenite::Message::Text(").unwrap();
writeln!(
out,
" r#\"{{\"type\":\"pong\"}}\"#.into()"
)
.unwrap();
writeln!(out, " );").unwrap();
writeln!(
out,
" if let Err(e) = write.send(pong_msg).await {{"
)
.unwrap();
writeln!(out, " return Some((Err(anyhow::anyhow!(e)), (write, read)));").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(
out,
" Some(\"pong\") => {{}}"
)
.unwrap();
writeln!(
out,
" Some(\"next\") => {{"
)
.unwrap();
writeln!(
out,
" let data = val.get(\"payload\")"
)
.unwrap();
writeln!(
out,
" .and_then(|p| p.get(\"data\"))"
)
.unwrap();
writeln!(
out,
" .and_then(|d| d.get(\"{}\"))",
op.original_id
)
.unwrap();
writeln!(out, " .cloned()").unwrap();
writeln!(
out,
" .unwrap_or(serde_json::Value::Null);"
)
.unwrap();
writeln!(out, " let parsed = match serde_json::from_value::<{}>(data) {{", op.return_type).unwrap();
writeln!(
out,
" Ok(p) => p,"
)
.unwrap();
writeln!(out, " Err(e) => return Some((Err(anyhow::anyhow!(e)), (write, read))),").unwrap();
writeln!(out, " }};").unwrap();
writeln!(
out,
" return Some((Ok(parsed), (write, read)));"
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(
out,
" Some(\"error\") => {{"
)
.unwrap();
writeln!(out, " let errors = val.get(\"payload\").cloned().unwrap_or(serde_json::Value::Null);").unwrap();
writeln!(out, " return Some((Err(anyhow::anyhow!(\"GraphQL subscription error: {{}}\", errors)), (write, read)));").unwrap();
writeln!(out, " }}").unwrap();
writeln!(
out,
" Some(\"complete\") => {{"
)
.unwrap();
writeln!(out, " return None;").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " _ => {{}}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(
out,
" tokio_tungstenite::tungstenite::Message::Close(_) => {{"
)
.unwrap();
writeln!(out, " return None;").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " _ => {{}}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " Some(Err(e)) => {{").unwrap();
writeln!(
out,
" return Some((Err(anyhow::anyhow!(e)), (write, read)));"
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " None => {{").unwrap();
writeln!(out, " return None;").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, " }});").unwrap();
writeln!(out, " Ok(stream.boxed())").unwrap();
}
fn render_grpc_body(out: &mut String, op: &Operation) {
let parts: Vec<&str> = op.endpoint.split('/').collect();
if parts.len() < 3 {
writeln!(
out,
" anyhow::bail!(\"gRPC endpoint '{}' format invalid. Expected /Service/Method\");",
op.endpoint
)
.unwrap();
return;
}
let service_name = parts[1];
let method_name = parts[2];
let service_snake = service_name.to_snake_case();
let method_snake = method_name.to_snake_case();
writeln!(
out,
" let mut client = proto::{service_snake}_client::{service_name}Client::connect(self.base_url.clone()).await?;"
)
.unwrap();
if op.streaming == StreamingMode::ClientStream || op.streaming == StreamingMode::BidiStream {
writeln!(
out,
" let response = client.{method_snake}(request).await?;"
)
.unwrap();
} else {
writeln!(
out,
" let response = client.{method_snake}(tonic::Request::new(request)).await?;"
)
.unwrap();
}
if op.streaming.is_streaming() {
writeln!(out, " use futures_util::StreamExt;").unwrap();
writeln!(out, " Ok(response.into_inner().boxed())").unwrap();
} else {
writeln!(out, " Ok(response.into_inner())").unwrap();
}
}
/// Returns the path part of an endpoint string such as `"GET /pets/{id}"`.
///
/// - The text before the first whitespace run is treated as the HTTP method
///   and dropped. Spaces and tabs both count, and repeated separators are
///   tolerated.
/// - An endpoint without whitespace is returned as-is.
/// - Surrounding whitespace is always trimmed.
fn extract_path(endpoint: &str) -> String {
    let trimmed = endpoint.trim();
    match trimmed.split_once(char::is_whitespace) {
        Some((_method, path)) => path.trim_start().to_string(),
        None => trimmed.to_string(),
    }
}
/// Returns the `Path` parameters referenced by `{placeholder}`s in `path`, in
/// the order the placeholders appear.
///
/// - Each placeholder is matched against `Param::original_name`; placeholders
///   with no matching path parameter are skipped.
/// - A placeholder that appears twice yields the same parameter twice.
/// - An unterminated `{` ends the scan.
fn path_params_in_order<'a>(path: &str, params: &'a [Param]) -> Vec<&'a Param> {
    let mut found = Vec::new();
    let mut rest = path;
    while let Some(open) = rest.find('{') {
        let after_open = &rest[open + 1..];
        let Some(close) = after_open.find('}') else {
            break;
        };
        let placeholder = &after_open[..close];
        let matching = params
            .iter()
            .find(|p| p.location == ParamLocation::Path && p.original_name == placeholder);
        if let Some(param) = matching {
            found.push(param);
        }
        rest = &after_open[close + 1..];
    }
    found
}
/// Appends `text` to `out` as `///` doc-comment lines, each prefixed by
/// `indent`. Does nothing when `text` is `None`.
///
/// Blank (or whitespace-only) lines are written as a bare `///`, so the
/// generated source carries no trailing whitespace.
fn emit_doc(out: &mut String, text: Option<&str>, indent: &str) {
    let Some(t) = text else { return };
    for line in t.lines() {
        if line.trim().is_empty() {
            writeln!(out, "{indent}///").unwrap();
        } else {
            writeln!(out, "{indent}/// {line}").unwrap();
        }
    }
}
/// Escapes `s` so it can sit between double quotes in a generated Rust string
/// literal and evaluate back to exactly `s`.
///
/// Escapes applied:
/// - backslashes and double quotes are backslash-escaped;
/// - `\n`, `\r` and `\t` use their short escapes;
/// - every other control character becomes a `\u{..}` escape (a lone carriage
///   return is not accepted inside a Rust string literal).
///
/// Braces are left untouched, so the result is NOT safe to use as a `format!`
/// string on its own.
fn escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if c.is_control() => escaped.extend(c.escape_unicode()),
            c => escaped.push(c),
        }
    }
    escaped
}