use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
str::FromStr,
};
use clap::Parser;
use convert_case::{Case, Casing};
use endpoint_gen::{
definitions::{Definition, EndpointSchemaElement, EnumElement, ErrorCodeSchema, GenService, StructElement},
docs::{self, Data},
error_codes::{build_error_code_catalog, validate_endpoint_error_codes, validate_reserved_enum_names},
rust,
};
use endpoint_libs::model::Type;
use eyre::*;
use ron::from_str;
use semver::{Version, VersionReq};
use serde::{Deserialize, Serialize};
use std::env;
use std::result::Result::Ok;
use walkdir::WalkDir;
// Command-line arguments for the `endpoint-gen` binary.
// NOTE: plain `//` comments are used deliberately here — clap's derive turns `///` doc
// comments into `--help` text, which would change the binary's user-facing output.
#[derive(Parser, Debug)]
#[command(name = "endpoint-gen", version, about = "Generate endpoint documentation and code.")]
struct Cli {
    // Directory holding `version.toml` and the RON definition files (searched recursively).
    // `main` falls back to the current working directory when this is omitted.
    #[arg(short, long)]
    config_dir: Option<String>,
    // Generation root: `main` uses it as the project root and writes generated output under
    // `<output_dir>/generated`. Falls back to the current working directory when omitted.
    #[arg(short, long)]
    output_dir: Option<String>,
}
fn main() -> Result<()> {
let args = Cli::parse();
let generation_root: PathBuf = {
if let Some(output_dir) = &args.output_dir {
PathBuf::from_str(output_dir)?
} else {
env::current_dir()?
}
};
let config_dir = {
if let Some(config_dir) = &args.config_dir {
PathBuf::from_str(config_dir)?
} else {
env::current_dir()?
}
};
let version_config = read_version_file(&config_dir.join("version.toml"))
.wrap_err("Error opening version.toml. Make sure it exists and is structured correctly")?;
check_compatibility(version_config)?;
let output_dir = generation_root.join("generated");
let input_objects = build_object_lists(config_dir)?;
let data = Data {
project_root: generation_root,
output_dir,
services: input_objects.services,
enums: input_objects.enums,
structs: input_objects.structs,
error_codes: input_objects.error_codes,
};
let docs_data = format_for_docs(&data);
docs::gen_services_docs(&docs_data)?;
docs::gen_md_docs(&docs_data)?;
rust::gen_model_rs(&data)?;
docs::gen_error_message_md(&data.project_root, &data.error_codes)?;
Ok(())
}
/// Returns a copy of `data` with field names rewritten to camelCase for documentation.
///
/// The conversion applies only to elements whose config has `snake_case_fields` set. For
/// such endpoints, it rewrites the parameter names, return names, error names and error
/// field names. For such struct elements, it rewrites the struct's field names. Paths, enums
/// and error codes are cloned unchanged.
///
/// # Panics
/// Panics (via `unreachable!`) if a struct element with `snake_case_fields` set has an
/// `inner` type other than `Type::Struct`.
fn format_for_docs(data: &Data) -> Data {
    /// Rewrites one field's name to camelCase in place.
    fn camelize(field: &mut endpoint_libs::model::Field) {
        field.name = field.name.to_case(Case::Camel);
    }

    let mut services = data.services.clone();
    for service in &mut services {
        for endpoint in &mut service.endpoints {
            if !endpoint.config.snake_case_fields {
                continue;
            }
            let schema = &mut endpoint.schema;
            schema.parameters.iter_mut().for_each(camelize);
            schema.returns.iter_mut().for_each(camelize);
            for error in &mut schema.errors {
                error.name = error.name.to_case(Case::Camel);
                error.fields.iter_mut().for_each(camelize);
            }
        }
    }

    let structs = data
        .structs
        .iter()
        .cloned()
        .map(|mut element| {
            if element.config.snake_case_fields {
                // Rebuild through `Type::struct_` so the constructor still produces the type.
                element.inner = match element.inner {
                    Type::Struct { name, fields } => Type::struct_(
                        name,
                        fields
                            .into_iter()
                            .map(|mut field| {
                                camelize(&mut field);
                                field
                            })
                            .collect(),
                    ),
                    _ => unreachable!(),
                };
            }
            element
        })
        .collect();

    Data {
        project_root: data.project_root.clone(),
        output_dir: data.output_dir.clone(),
        services,
        enums: data.enums.clone(),
        structs,
        error_codes: data.error_codes.clone(),
    }
}
/// Parses `file_path` as a RON [`Config`] if it has a `.ron` extension.
///
/// Any other file, including one with no extension, yields `Ok(None)`. This lets callers
/// pass every file from a directory walk through this function.
///
/// # Errors
/// Returns an error if a `.ron` file cannot be read or does not deserialize into `Config`.
fn process_file(file_path: &Path) -> eyre::Result<Option<Definition>> {
    if !matches!(file_path.extension(), Some(ext) if ext == "ron") {
        return Ok(None);
    }
    let contents = std::fs::read_to_string(file_path)?;
    let parsed: Config = from_str(&contents)?;
    Ok(Some(parsed.definition))
}
/// Recursively loads every `.ron` definition file under `dir`.
///
/// Files are processed in sorted path order, so the returned definitions have a stable
/// order. Files without a `.ron` extension are ignored (see `process_file`).
///
/// # Errors
/// Fails when any of the following happens:
/// - Traversing the directory tree fails, for example because a directory is unreadable.
/// - A `.ron` file cannot be read or parsed.
/// - No definitions are found at all.
///
/// Traversal and parse failures are collected and reported together.
fn process_input_files(dir: PathBuf) -> eyre::Result<Vec<Definition>> {
    let mut config_errors: Vec<String> = vec![];
    let mut paths: Vec<PathBuf> = vec![];
    for entry in WalkDir::new(&dir) {
        match entry {
            Ok(entry) if entry.file_type().is_file() => paths.push(entry.into_path()),
            Ok(_) => {}
            // Report traversal failures: skipping them silently could drop definitions
            // and generate incomplete output. WalkDir's message includes the path.
            Err(err) => config_errors.push(err.to_string()),
        }
    }
    paths.sort();

    let mut rust_configs: Vec<Definition> = vec![];
    for path in paths {
        match process_file(&path) {
            Ok(Some(config)) => rust_configs.push(config),
            Ok(None) => {}
            // version.toml sits next to the RON files and is not a definition. Compare as
            // `OsStr` so non-UTF-8 file names cannot cause a panic.
            Err(_) if matches!(path.file_name(), Some(name) if name == "version.toml") => {}
            Err(err) => config_errors.push(format!("{path:?}: {err}")),
        }
    }

    if !config_errors.is_empty() {
        bail!("Error processing RON config files:\n{}", config_errors.join("\n"));
    }
    if rust_configs.is_empty() {
        bail!("No valid RON config files found in given path, aborting generation process");
    }
    Ok(rust_configs)
}
/// Definitions loaded from the config directory, grouped and sorted by `build_object_lists`.
struct InputObjects {
    /// Services assembled from endpoint schemas, ordered by service id.
    services: Vec<GenService>,
    /// All enum definitions, sorted.
    enums: Vec<EnumElement>,
    /// All struct definitions, sorted.
    structs: Vec<StructElement>,
    /// Error-code catalog produced by `build_error_code_catalog` from the custom codes.
    error_codes: Vec<ErrorCodeSchema>,
}
/// Loads all definitions under `dir` and groups them into the collections used for codegen.
///
/// Endpoint schemas are grouped into services keyed by `(service_name, service_id)`. For
/// `*List` definitions, each element inherits the list-level config unless it sets
/// `override_parent`.
///
/// The output is ordered deterministically:
/// - services by id, with ties broken by name;
/// - endpoints within a service by code;
/// - enums and structs by their `Ord` impls.
///
/// # Errors
/// Propagates failures from `process_input_files`, `build_error_code_catalog`,
/// `validate_reserved_enum_names` and `validate_endpoint_error_codes`.
fn build_object_lists(dir: PathBuf) -> eyre::Result<InputObjects> {
    let rust_configs = process_input_files(dir)?;
    let mut service_schema_map: HashMap<(String, u16), Vec<EndpointSchemaElement>> = HashMap::new();
    let mut enums: Vec<EnumElement> = vec![];
    let mut structs: Vec<StructElement> = vec![];
    let mut custom_error_codes: Vec<ErrorCodeSchema> = vec![];
    for config in rust_configs {
        match config {
            Definition::EndpointSchema(schema_definition) => service_schema_map
                .entry((schema_definition.service_name, schema_definition.service_id))
                .or_default()
                .push(schema_definition.schema),
            Definition::EndpointSchemaList(schema_list_definition) => service_schema_map
                .entry((schema_list_definition.service_name, schema_list_definition.service_id))
                .or_default()
                .extend(schema_list_definition.endpoints.into_iter().map(|mut ele| {
                    if !ele.config.override_parent {
                        ele.config = schema_list_definition.config.clone();
                    }
                    ele
                })),
            Definition::Enum(enum_type) => enums.push(enum_type),
            Definition::EnumList(enums_definition) => {
                enums.extend(enums_definition.enum_elements.into_iter().map(|mut ele| {
                    if !ele.config.override_parent {
                        ele.config = enums_definition.config.clone();
                    }
                    ele
                }))
            }
            Definition::ErrorCodeList(error_code_list) => custom_error_codes.extend(error_code_list.codes),
            Definition::Struct(struct_element) => structs.push(struct_element),
            Definition::StructList(structs_definition) => {
                structs.extend(structs_definition.struct_elements.into_iter().map(|mut ele| {
                    if !ele.config.override_parent {
                        ele.config = structs_definition.config.clone();
                    }
                    ele
                }))
            }
        }
    }

    // HashMap iteration order is randomized per process. Sort the grouped entries by
    // (name, id) first so that services sharing an id keep a stable relative order through
    // the stable sort by id below, and generated output stays reproducible.
    let mut grouped: Vec<_> = service_schema_map.into_iter().collect();
    grouped.sort_by(|a, b| a.0.cmp(&b.0));
    let mut services: Vec<GenService> = grouped
        .into_iter()
        .map(|((service_name, service_id), endpoint_schemas)| {
            GenService::new(service_name, service_id, endpoint_schemas)
        })
        .collect();
    services.sort_by_key(|a| a.id);
    services
        .iter_mut()
        .for_each(|service| service.endpoints.sort_by_key(|a| a.schema.code));
    enums.sort();
    structs.sort();

    let error_codes = build_error_code_catalog(custom_error_codes)?;
    validate_reserved_enum_names(&enums)?;
    validate_endpoint_error_codes(&services, &error_codes)?;
    Ok(InputObjects {
        services,
        enums,
        structs,
        error_codes,
    })
}
/// Top-level shape of every `.ron` definition file: a single `definition` entry.
///
/// The field name is part of the on-disk RON schema, so it must not be renamed.
#[derive(Deserialize, Serialize)]
struct Config {
    /// The endpoint, enum, struct, or error-code definition carried by the file.
    definition: Definition,
}
/// Contents of `version.toml`, which is read from the config directory.
///
/// Expected layout:
/// ```toml
/// [binary]
/// version = "^1.2"   # semver requirement on this endpoint-gen binary
/// [libs]
/// version = "1.4.0"  # exact endpoint-libs version used by the caller
/// ```
#[derive(Debug, Deserialize)]
struct VersionConfig {
    /// `[binary]` table: constraint on the endpoint-gen binary version.
    binary: BinaryVersion,
    /// `[libs]` table: the caller's endpoint-libs version.
    libs: LibsVersion,
}
/// `[binary]` table of `version.toml`.
#[derive(Debug, Deserialize)]
struct BinaryVersion {
    /// Semver requirement (e.g. `"^1.2"`) that this binary's version must satisfy.
    version: String,
}
/// `[libs]` table of `version.toml`.
#[derive(Debug, Deserialize)]
struct LibsVersion {
    /// Exact semver version (e.g. `"1.4.0"`) of endpoint-libs used by the caller. It is
    /// checked against the requirement baked into this binary at build time.
    version: String,
}
/// Reads the TOML file at `path` and deserializes it into a [`VersionConfig`].
///
/// # Errors
/// Returns an error if the file cannot be read or does not match the `VersionConfig` layout.
fn read_version_file(path: &Path) -> eyre::Result<VersionConfig> {
    let raw_toml = fs::read_to_string(path)?;
    Ok(toml::from_str::<VersionConfig>(&raw_toml)?)
}
fn check_compatibility(version_config: VersionConfig) -> eyre::Result<()> {
let current_crate_version = Version::parse(get_crate_version()).unwrap();
let binary_version_req = VersionReq::parse(&version_config.binary.version).unwrap();
let libs_version_requirement = env!("ENDPOINT_LIBS_REQUIREMENT");
let libs_version_req = VersionReq::parse(libs_version_requirement).unwrap();
let caller_libs_version = Version::parse(&version_config.libs.version).unwrap();
if !binary_version_req.matches(¤t_crate_version) {
Err(eyre!(
"Binary version constraint not satisfied. Version: {} is specified in version.toml. Current binary version is: {}",
&version_config.binary.version,
&get_crate_version()
))
} else if !libs_version_req.matches(&caller_libs_version) {
Err(eyre!(
"endpoint-libs version constraint not satisfied. Version: {} is specified in version.toml. This version of endpoint-gen requires: {}",
caller_libs_version,
libs_version_requirement
))
} else {
Ok(())
}
}
/// Returns this binary's package version, as embedded by Cargo at compile time.
fn get_crate_version() -> &'static str {
    const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
    CRATE_VERSION
}
#[cfg(test)]
mod tests {
    use super::*;
    use endpoint_gen::definitions::RustGenConfig;
    use endpoint_libs::model::{EndpointErrorCodeRef, EndpointErrorSchema, EndpointSchema, Field};

    /// Builds a `Login` endpoint with `snake_case_fields` enabled.
    ///
    /// Its parameters, returns and error fields all use snake_case names.
    fn snake_case_login_endpoint() -> EndpointSchemaElement {
        let password_too_short = EndpointErrorSchema {
            name: "PasswordTooShort".to_string(),
            code: EndpointErrorCodeRef::new("BadRequest"),
            message: "Password too short".to_string(),
            fields: vec![
                Field::new("min_length", Type::Int32),
                Field::new("actual_length", Type::Int32),
            ],
        };
        let schema = EndpointSchema::new(
            "Login",
            10001,
            vec![Field::new("user_name", Type::String)],
            vec![Field::new("access_token", Type::String)],
        )
        .with_errors(vec![password_too_short]);

        EndpointSchemaElement {
            frontend_facing: true,
            config: RustGenConfig {
                snake_case_fields: true,
                ..Default::default()
            },
            schema,
        }
    }

    #[test]
    fn format_for_docs_camel_cases_endpoint_error_fields() {
        let service = GenService::new("test_service".to_string(), 1, vec![snake_case_login_endpoint()]);
        let data = Data {
            project_root: PathBuf::new(),
            output_dir: PathBuf::new(),
            services: vec![service],
            enums: vec![],
            structs: vec![],
            error_codes: vec![],
        };

        let docs = format_for_docs(&data);
        let schema = &docs.services[0].endpoints[0].schema;
        let error_fields = &schema.errors[0].fields;

        assert_eq!(schema.parameters[0].name, "userName");
        assert_eq!(schema.returns[0].name, "accessToken");
        assert_eq!(error_fields[0].name, "minLength");
        assert_eq!(error_fields[1].name, "actualLength");
    }
}