use heck::ToLowerCamelCase;
use specta::{
Format, Type, Types,
datatype::{DataType, Field, NamedReferenceType, Primitive, Reference, Struct},
};
use specta_serde::Phase;
use specta_typescript::{Error, Exporter as TsExporter, FrameworkExporter, Typescript, define};
use specta_util::Remapper;
use std::borrow::Cow;
use std::collections::BTreeMap;
use crate::TauRpcFunction;
/// Header comment handed to the TypeScript exporter as the framework prelude of the
/// generated bindings file.
const FRAMEWORK_HEADER: &str =
"// This file has been generated by TauRPC. Do not edit this file manually.";
/// Contents of the `package.json` that `Exporter::export` writes next to the generated
/// bindings when they are exported into the `.taurpc` package inside `node_modules`.
static PACKAGE_JSON: &str = r#"
{
"name": ".taurpc",
"types": "index.ts"
}
"#;
/// TypeScript import statement emitted first in the generated runtime section, before
/// the rendered types. It pulls the proxy factory and helper types from the `taurpc`
/// npm package.
static BOILERPLATE_TS_IMPORT: &str = r#"
import { createTauRPCProxy as createProxy, type InferCommandOutput, type TauRpcResult, type UnlistenFn } from 'taurpc'
"#;
/// TypeScript emitted last in the generated runtime section, after the generated
/// `createTauRPCProxy` definition. It re-exports the helper types; the first line of the
/// literal is a commented-out TypeScript statement that is emitted verbatim.
static BOILERPLATE_TS_EXPORT: &str = r#"
// export const createTauRPCProxy = () => createProxy<Router>(ARGS_MAP)
export type { InferCommandOutput, TauRpcResult }
"#;
/// Error type returned by the export API; an alias for `specta_typescript::Error`.
pub type ExportError = Error;
/// Source of everything `Exporter::export` needs to generate TypeScript bindings.
pub trait Exportable {
/// Returns, in order:
///
/// 1. the collected specta `Types`,
/// 2. the functions to expose, grouped by the key that becomes a field of the
///    generated `Router` type,
/// 3. for each key, a string holding the JSON description of the function arguments
///    (it is parsed as JSON and emitted as `ARGS_MAP`).
fn generate_types(
&self,
) -> (
Types,
BTreeMap<String, Vec<TauRpcFunction>>,
BTreeMap<String, String>,
);
}
/// Selects how the generated TypeScript client surfaces errors of commands whose Rust
/// return type is a `Result`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ErrorHandlingMode {
/// Default mode. Emitted as `errorHandling: "throw"`; a command returning
/// `Result<T, E>` is typed as `Promise<T>`.
#[default]
Throw,
/// Emitted as `errorHandling: "result"`; a command returning `Result<T, E>` is typed
/// as `Promise<TauRpcResult<T, E>>`.
Result,
}
/// Builder that configures and runs the generation of the TypeScript bindings file.
///
/// Note: the derived `Default` leaves `specta_phases` at `false`, whereas
/// `Exporter::new` sets it to `true`.
#[derive(Default)]
pub struct Exporter {
// Configuration forwarded to the underlying specta TypeScript exporter.
ts_config: Typescript,
// When true, types are mapped with `specta_serde::PhasesFormat`; otherwise with
// `specta_serde::Format`.
specta_phases: bool,
// Error-handling mode written into the generated proxy options and used when typing
// commands that return a `Result`.
error_handling: ErrorHandlingMode,
// Extra TypeScript source appended to the generated file; when non-empty, `typedError`
// is also passed to the generated proxy options.
typed_error_impl: Cow<'static, str>,
}
impl Exporter {
pub fn new() -> Self {
Self {
specta_phases: true,
..Default::default()
}
}
pub fn ts_config(mut self, config: Typescript) -> Self {
self.ts_config = config;
self
}
pub fn specta_phases(mut self, enabled: bool) -> Self {
self.specta_phases = enabled;
self
}
pub fn error_handling(mut self, error_handling: ErrorHandlingMode) -> Self {
self.error_handling = error_handling;
self
}
pub fn typed_error_impl(mut self, runtime: impl Into<Cow<'static, str>>) -> Self {
self.typed_error_impl = runtime.into();
self
}
pub fn export(
self,
exportable: &impl Exportable,
path: impl AsRef<std::path::Path>,
) -> Result<(), Error> {
let (types, functions, args_map) = exportable.generate_types();
let format = SpectaFormat::new(self.specta_phases);
let format_clone = format.clone();
TsExporter::from(self.ts_config)
.framework_prelude(FRAMEWORK_HEADER)
.framework_runtime(move |mut exporter| {
let mut out = String::new();
out.push_str(BOILERPLATE_TS_IMPORT);
out.push_str(&exporter.render_types()?);
let parsed_args_map = generate_args_map(&args_map)?;
out.push_str(r#"const ARGS_MAP = "#);
out.push_str(
&serde_json::to_string(&parsed_args_map)
.map_err(|err| Error::framework("error stringify argument map", err))?,
);
out.push_str(";\n\n");
let result_map = generate_result_map(&functions, &exporter);
out.push_str(r#"const RESULT_MAP = "#);
out.push_str(
&serde_json::to_string(&result_map)
.map_err(|err| Error::framework("error stringify result map", err))?,
);
out.push_str(";\n\n");
out.push_str(
&generate_functions_router(
&functions,
&exporter,
&format_clone,
self.error_handling,
)
.map_err(|err| Error::framework("failed to generate router type", err))?,
);
if !self.typed_error_impl.is_empty() {
out.push_str(&self.typed_error_impl);
out.push_str("\n\n");
}
out.push_str("export const createTauRPCProxy = () => createProxy<Router>({\n");
out.push_str(" argsMap: ARGS_MAP,\n");
out.push_str(" resultMap: RESULT_MAP,\n");
out.push_str(" errorHandling: ");
out.push_str(match self.error_handling {
ErrorHandlingMode::Throw => "\"throw\"",
ErrorHandlingMode::Result => "\"result\"",
});
out.push_str(",\n");
if !self.typed_error_impl.is_empty() {
out.push_str(" typedError,\n");
}
out.push_str("})\n");
out.push_str(BOILERPLATE_TS_EXPORT);
Ok(out.into())
})
.export_to(path.as_ref(), &types, format)?;
if path
.as_ref()
.to_string_lossy()
.ends_with("node_modules\\.taurpc\\index.ts")
{
let package_json_path = path
.as_ref()
.parent()
.ok_or(Error::framework("", "Failed to create 'package.json' path"))?
.join("package.json");
std::fs::write(package_json_path, PACKAGE_JSON)
.map_err(|err| Error::framework("Failed to create 'package.json' file", err))?
}
Ok(())
}
}
/// `specta::Format` implementation used for the export: it applies one of the
/// `specta_serde` formats and then rewrites data types through `remapper`.
#[derive(Debug, Clone)]
struct SpectaFormat {
// Selects `specta_serde::PhasesFormat` (true) or `specta_serde::Format` (false).
specta_phases_enabled: bool,
// Rewrite rules applied to every mapped type; built in `SpectaFormat::new`.
remapper: Remapper,
}
impl SpectaFormat {
    /// Builds the format together with its remapper.
    ///
    /// The remapper replaces the primitives `usize`, `isize`, `u64`, `i64`, `u128` and
    /// `i128`, as well as the `specta_typescript::BigInt` definition, with the
    /// `specta_typescript::Number` definition.
    fn new(specta_phases_enabled: bool) -> Self {
        let number = <specta_typescript::Number as Type>::definition(&mut Types::default());
        let big_int = <specta_typescript::BigInt as Type>::definition(&mut Types::default());

        let wide_primitives = [
            Primitive::usize,
            Primitive::isize,
            Primitive::u64,
            Primitive::i64,
            Primitive::u128,
            Primitive::i128,
        ];

        // Register one rule per wide primitive, in the order listed above.
        let primitives_remapped = wide_primitives
            .into_iter()
            .fold(Remapper::new(), |rules, primitive| {
                rules.rule(DataType::Primitive(primitive), number.clone())
            });
        let remapper = primitives_remapped.rule(big_int, number);

        Self {
            specta_phases_enabled,
            remapper,
        }
    }
}
impl Format for SpectaFormat {
    /// Maps the whole type collection with the selected `specta_serde` format and then
    /// runs the result through the remapper.
    fn map_types(&'_ self, types: &Types) -> Result<Cow<'_, Types>, specta::FormatError> {
        let mapped = match self.specta_phases_enabled {
            true => specta_serde::PhasesFormat.map_types(types),
            false => specta_serde::Format.map_types(types),
        }?;
        let remapped = self.remapper.remap_types(mapped.into_owned());
        Ok(Cow::Owned(remapped))
    }

    /// Maps a single data type with the selected `specta_serde` format and then runs the
    /// result through the remapper.
    fn map_type(
        &'_ self,
        types: &Types,
        dt: &DataType,
    ) -> Result<Cow<'_, DataType>, specta::FormatError> {
        let mapped = match self.specta_phases_enabled {
            true => specta_serde::PhasesFormat.map_type(types, dt),
            false => specta_serde::Format.map_type(types, dt),
        }?;
        let remapped = self.remapper.remap_dt(mapped.into_owned());
        Ok(Cow::Owned(remapped))
    }
}
fn generate_functions_router(
functions: &BTreeMap<String, Vec<TauRpcFunction>>,
exporter: &FrameworkExporter,
format: &SpectaFormat,
error_handling: ErrorHandlingMode,
) -> Result<String, Error> {
let mut router = Struct::named();
for (path, path_functions) in functions {
let mut function_names_and_funcs: Vec<_> = path_functions
.iter()
.map(|f| (f.function.name(), f))
.collect();
function_names_and_funcs.sort_by(|a, b| a.0.cmp(b.0));
let mut path_router = Struct::named();
for (_, function) in function_names_and_funcs {
let (name, field) =
generate_function_field(function, exporter, format, error_handling)?;
path_router = path_router.field(name, field);
}
router = router.field(path.clone(), Field::new(path_router.build()));
}
let router_type = exporter.inline(&router.build())?;
Ok(format!("export type Router = {router_type};\n"))
}
fn generate_function_field(
function: &TauRpcFunction,
exporter: &FrameworkExporter,
format: &SpectaFormat,
error_handling: ErrorHandlingMode,
) -> Result<(String, Field), Error> {
let specta_fn = &function.function;
let args = specta_fn
.args()
.iter()
.map(|(name, typ)| {
let phase = if function.is_event {
Phase::Serialize
} else {
Phase::Deserialize
};
render_reference_dt_for_phase(typ, phase, exporter, format)
.map(|ty| format!("{}: {ty}", name.to_lower_camel_case()))
})
.collect::<Result<Vec<_>, _>>()?
.join(", ");
let return_ty = if function.is_event {
"void".to_string()
} else if let Some(result) = specta_fn.result() {
if let Some((dt_ok, dt_err)) = extract_std_result(result, exporter.types) {
let ok_str = render_reference_dt_for_phase(dt_ok, Phase::Serialize, exporter, format)?;
if error_handling == ErrorHandlingMode::Result {
let err_str =
render_reference_dt_for_phase(dt_err, Phase::Serialize, exporter, format)?;
format!("TauRpcResult<{}, {}>", ok_str, err_str)
} else {
ok_str
}
} else {
render_reference_dt_for_phase(result, Phase::Serialize, exporter, format)?
}
} else {
"void".to_string()
};
let name = specta_fn.name().split_once("_taurpc_fn__").unwrap().1;
let field_type = if function.is_event {
format!("{{ on: (listener: ({args}) => void) => Promise<UnlistenFn> }}")
} else {
format!("({args}) => Promise<{return_ty}>")
};
let mut field = Field::new(DataType::Reference(define(field_type).into()));
field.docs = specta_fn.docs.clone();
Ok((name.to_string(), field))
}
/// Renders `dt` as TypeScript for the given serde `phase`.
///
/// A named reference whose definition is called `TAURI_CHANNEL` and whose module path
/// starts with `tauri::` is rendered as the callback type `(response: T) => void`. `T`
/// is the first generic of the reference rendered for the `Serialize` phase, or `void`
/// when there is none.
///
/// Every other type is narrowed to `phase`, passed through the remapper of `format`
/// and rendered with `render_reference_dt`.
fn render_reference_dt_for_phase(
    dt: &DataType,
    phase: Phase,
    exporter: &FrameworkExporter,
    format: &SpectaFormat,
) -> Result<String, Error> {
    if let DataType::Reference(Reference::Named(named)) = dt
        && let Some(definition) = exporter.types.get(named)
        && definition.name == "TAURI_CHANNEL"
        && definition.module_path.starts_with("tauri::")
    {
        let payload = match &named.inner {
            NamedReferenceType::Reference { generics, .. } => match generics.first() {
                Some((_, payload_dt)) => {
                    render_reference_dt_for_phase(payload_dt, Phase::Serialize, exporter, format)?
                }
                None => String::from("void"),
            },
            _ => String::from("void"),
        };
        return Ok(format!("(response: {payload}) => void"));
    }

    let phased = specta_serde::select_phase_datatype(dt, exporter.types, phase);
    let remapped = format.remapper.remap_dt(phased);
    render_reference_dt(&remapped, exporter)
}
/// Renders `dt` as TypeScript: references are rendered by name through the exporter,
/// every other data type is rendered inline.
///
/// A named reference whose definition is called `TAURI_CHANNEL` and whose module path
/// starts with `tauri::` is rendered as the callback type `(response: T) => void`,
/// where `T` is the rendered first generic of the reference or `void` when there is
/// none.
fn render_reference_dt(dt: &DataType, exporter: &FrameworkExporter) -> Result<String, Error> {
    // Shared rendering rule: by reference when possible, inline otherwise.
    let render = |ty: &DataType| match ty {
        DataType::Reference(reference) => exporter.reference(reference),
        other => exporter.inline(other),
    };

    if let DataType::Reference(Reference::Named(named)) = dt
        && let Some(definition) = exporter.types.get(named)
        && definition.name == "TAURI_CHANNEL"
        && definition.module_path.starts_with("tauri::")
    {
        let payload = match &named.inner {
            NamedReferenceType::Reference { generics, .. } => match generics.first() {
                Some((_, payload_dt)) => render(payload_dt)?,
                None => String::from("void"),
            },
            _ => String::from("void"),
        };
        return Ok(format!("(response: {payload}) => void"));
    }

    render(dt)
}
/// Returns the `(ok, err)` data types when `dt` is a named reference to the standard
/// library `Result`.
///
/// The reference qualifies when its definition in `types` is named `Result`, lives in
/// the module `std::result` or `core::result`, and carries at least two generics.
/// Otherwise `None` is returned.
fn extract_std_result<'a>(
    dt: &'a DataType,
    types: &'a Types,
) -> Option<(&'a DataType, &'a DataType)> {
    let DataType::Reference(Reference::Named(named)) = dt else {
        return None;
    };
    let definition = types.get(named)?;

    let is_std_result = definition.name == "Result"
        && (definition.module_path == "std::result" || definition.module_path == "core::result");
    if !is_std_result {
        return None;
    }

    let NamedReferenceType::Reference { generics, .. } = &named.inner else {
        return None;
    };
    match generics.as_slice() {
        [(_, ok), (_, err), ..] => Some((ok, err)),
        _ => None,
    }
}
/// Builds the data behind the generated `RESULT_MAP` constant.
///
/// For every key of `functions` the map holds, per function name (the part after the
/// `_taurpc_fn__` marker), whether the function returns a standard library `Result`.
///
/// # Panics
///
/// Panics when a function name does not contain the `_taurpc_fn__` marker.
fn generate_result_map(
    functions: &BTreeMap<String, Vec<TauRpcFunction>>,
    exporter: &FrameworkExporter,
) -> BTreeMap<String, BTreeMap<String, bool>> {
    functions
        .iter()
        .map(|(path, path_functions)| {
            let flags: BTreeMap<String, bool> = path_functions
                .iter()
                .map(|entry| {
                    let short_name = entry
                        .function
                        .name()
                        .split_once("_taurpc_fn__")
                        .unwrap()
                        .1;
                    let returns_result = match entry.function.result() {
                        Some(result) => extract_std_result(result, exporter.types).is_some(),
                        None => false,
                    };
                    (short_name.to_string(), returns_result)
                })
                .collect();
            (path.to_string(), flags)
        })
        .collect()
}
fn generate_args_map(
args_map_json: &BTreeMap<String, String>,
) -> Result<BTreeMap<String, serde_json::Value>, Error> {
let mut parsed_args_map = std::collections::BTreeMap::new();
for (path, args) in args_map_json {
let parsed: serde_json::Value = serde_json::from_str(args)
.map_err(|err| Error::framework("error parsing argument map json", err))?;
parsed_args_map.insert(path.clone(), parsed);
}
Ok(parsed_args_map)
}
impl<R: tauri::Runtime> Exportable for crate::Router<R> {
    /// Hands out clones of the router's collected types, its function map and its JSON
    /// argument map, in that order.
    fn generate_types(
        &self,
    ) -> (
        Types,
        BTreeMap<String, Vec<TauRpcFunction>>,
        BTreeMap<String, String>,
    ) {
        let types = self.types.clone();
        let functions = self.fns_map.clone();
        let args_map = self.args_map_json.clone();
        (types, functions, args_map)
    }
}