use std::fmt::Write as FmtWrite;
use crate::{DomainInfo, PredicateInfo, SymbolTable};
/// Generates Rust source code from a [`SymbolTable`].
///
/// Configure it with the builder-style `with_*` methods, then call
/// `generate` to render the output as a `String`.
#[derive(Debug, Clone)]
pub struct RustCodegen {
    /// Module name written into the header comment of the generated file.
    module_name: String,
    /// Whether generated types carry a `#[derive(...)]` attribute.
    derive_common: bool,
    /// Whether type-level doc comments are emitted for generated domain and
    /// predicate types.
    include_docs: bool,
}
impl RustCodegen {
    /// Banner line drawn above and below every section heading.
    const SECTION_RULE: &'static str = "// ============================================";

    /// Creates a generator for the module called `module_name`.
    ///
    /// Common derives and type-level doc comments are both enabled by default.
    pub fn new(module_name: impl Into<String>) -> Self {
        Self {
            module_name: module_name.into(),
            derive_common: true,
            include_docs: true,
        }
    }

    /// Enables or disables the `#[derive(...)]` attribute on generated types.
    ///
    /// With derives disabled the generated domain types are not `Copy`, so
    /// predicate accessors are generated to return references instead of
    /// values.
    pub fn with_common_derives(mut self, enable: bool) -> Self {
        self.derive_common = enable;
        self
    }

    /// Enables or disables the type-level doc comments on generated domain
    /// and predicate types.
    ///
    /// Doc comments on generated constants and methods are always emitted.
    pub fn with_docs(mut self, enable: bool) -> Self {
        self.include_docs = enable;
        self
    }

    /// Renders the complete Rust source for `table`.
    ///
    /// The output is, in order: a module header, one newtype per domain, one
    /// tuple struct per predicate, and a `SchemaMetadata` type holding summary
    /// constants. Domains and predicates appear in the iteration order of
    /// `table.domains` and `table.predicates`.
    pub fn generate(&self, table: &SymbolTable) -> String {
        let mut code = String::new();

        Self::push_line(&mut code, "//! Generated from TensorLogic schema.");
        Self::push_fmt(&mut code, format_args!("//! Module: {}", self.module_name));
        Self::push_line(&mut code, "//!");
        Self::push_line(&mut code, "//! This code was automatically generated.");
        Self::push_line(&mut code, "//! DO NOT EDIT MANUALLY.");
        code.push('\n');
        Self::push_line(&mut code, "#![allow(dead_code)]");
        code.push('\n');

        Self::push_section_header(&mut code, "Domain Types");
        for domain in table.domains.values() {
            self.generate_domain(&mut code, domain);
            code.push('\n');
        }

        Self::push_section_header(&mut code, "Predicate Types");
        for predicate in table.predicates.values() {
            self.generate_predicate(&mut code, predicate, table);
            code.push('\n');
        }

        Self::push_section_header(&mut code, "Schema Metadata");
        self.generate_schema_metadata(&mut code, table);

        code
    }

    /// Appends `text` followed by a newline.
    fn push_line(code: &mut String, text: &str) {
        code.push_str(text);
        code.push('\n');
    }

    /// Appends formatted arguments followed by a newline.
    fn push_fmt(code: &mut String, args: std::fmt::Arguments<'_>) {
        code.write_fmt(args)
            .expect("writing to String is infallible");
        code.push('\n');
    }

    /// Appends a three-line section banner titled `title` and a blank line.
    fn push_section_header(code: &mut String, title: &str) {
        Self::push_line(code, Self::SECTION_RULE);
        Self::push_fmt(code, format_args!("// {}", title));
        Self::push_line(code, Self::SECTION_RULE);
        code.push('\n');
    }

    /// Appends `text` as `///` doc-comment lines, one per input line.
    ///
    /// Every line is prefixed individually so that a multi-line description
    /// cannot spill out of the comment and corrupt the generated file. Blank
    /// lines (and an empty description) become a bare `///`.
    fn push_doc_text(code: &mut String, text: impl std::fmt::Display) {
        let text = text.to_string();
        let mut wrote_any = false;
        for line in text.lines() {
            wrote_any = true;
            if line.is_empty() {
                Self::push_line(code, "///");
            } else {
                Self::push_fmt(code, format_args!("/// {}", line));
            }
        }
        if !wrote_any {
            Self::push_line(code, "///");
        }
    }

    /// Emits the newtype for `domain`: a `usize` wrapper with a `CARDINALITY`
    /// constant, a bounds-checked `new`, an unchecked constructor, and an
    /// `id` accessor.
    fn generate_domain(&self, code: &mut String, domain: &DomainInfo) {
        let type_name = Self::to_type_name(&domain.name);

        if self.include_docs {
            match &domain.description {
                Some(desc) => Self::push_doc_text(code, desc),
                None => Self::push_fmt(code, format_args!("/// Domain: {}", domain.name)),
            }
            Self::push_line(code, "///");
            Self::push_fmt(
                code,
                format_args!("/// Cardinality: {}", domain.cardinality),
            );
        }
        if self.derive_common {
            Self::push_line(code, "#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]");
        }
        Self::push_fmt(code, format_args!("pub struct {}(pub usize);", type_name));
        code.push('\n');

        Self::push_fmt(code, format_args!("impl {} {{", type_name));
        Self::push_line(code, " /// Maximum valid ID for this domain (exclusive).");
        Self::push_fmt(
            code,
            format_args!(" pub const CARDINALITY: usize = {};", domain.cardinality),
        );
        code.push('\n');

        // Bounds-checked constructor.
        Self::push_fmt(
            code,
            format_args!(" /// Create a new {} instance.", type_name),
        );
        Self::push_line(code, " ///");
        Self::push_line(code, " /// # Panics");
        Self::push_line(code, " ///");
        Self::push_fmt(
            code,
            format_args!(" /// Panics if `id >= {}`.", domain.cardinality),
        );
        Self::push_line(code, " pub fn new(id: usize) -> Self {");
        Self::push_line(
            code,
            " assert!(id < Self::CARDINALITY, \"ID {} exceeds cardinality {}\", id, Self::CARDINALITY);",
        );
        Self::push_line(code, " Self(id)");
        Self::push_line(code, " }");
        code.push('\n');

        // Unchecked constructor.
        Self::push_fmt(
            code,
            format_args!(
                " /// Create a new {} instance without bounds checking.",
                type_name
            ),
        );
        Self::push_line(code, " ///");
        Self::push_line(code, " /// # Safety");
        Self::push_line(code, " ///");
        Self::push_fmt(
            code,
            format_args!(" /// Caller must ensure `id < {}`.", domain.cardinality),
        );
        Self::push_line(code, " pub unsafe fn new_unchecked(id: usize) -> Self {");
        Self::push_line(code, " Self(id)");
        Self::push_line(code, " }");
        code.push('\n');

        // Accessor for the wrapped ID.
        Self::push_line(code, " /// Get the underlying ID.");
        Self::push_line(code, " pub fn id(&self) -> usize {");
        Self::push_line(code, " self.0");
        Self::push_line(code, " }");
        Self::push_line(code, "}");
    }

    /// Emits the tuple struct for `predicate` with one field per argument
    /// domain, a `new` constructor, and one `argN` accessor per argument.
    ///
    /// A predicate without arguments becomes a unit struct with an empty
    /// `impl`. When common derives are disabled the accessors return
    /// references, because the generated domain types are then not `Copy` and
    /// returning them by value from `&self` would not compile.
    fn generate_predicate(
        &self,
        code: &mut String,
        predicate: &PredicateInfo,
        _table: &SymbolTable,
    ) {
        let type_name = Self::to_type_name(&predicate.name);
        let arg_types: Vec<String> = predicate
            .arg_domains
            .iter()
            .map(|domain_name| Self::to_type_name(domain_name))
            .collect();

        if self.include_docs {
            match &predicate.description {
                Some(desc) => Self::push_doc_text(code, desc),
                None => Self::push_fmt(code, format_args!("/// Predicate: {}", predicate.name)),
            }
            Self::push_line(code, "///");
            Self::push_fmt(code, format_args!("/// Arity: {}", arg_types.len()));
            if let Some(constraints) = &predicate.constraints {
                if !constraints.properties.is_empty() {
                    Self::push_line(code, "///");
                    Self::push_line(code, "/// Properties:");
                    for prop in &constraints.properties {
                        Self::push_fmt(code, format_args!("/// - {:?}", prop));
                    }
                }
            }
        }
        if self.derive_common {
            Self::push_line(code, "#[derive(Clone, Debug, PartialEq, Eq, Hash)]");
        }

        if arg_types.is_empty() {
            Self::push_fmt(code, format_args!("pub struct {};", type_name));
        } else {
            let fields: Vec<String> = arg_types.iter().map(|ty| format!("pub {}", ty)).collect();
            Self::push_fmt(
                code,
                format_args!("pub struct {}({});", type_name, fields.join(", ")),
            );
        }
        code.push('\n');

        Self::push_fmt(code, format_args!("impl {} {{", type_name));
        if !arg_types.is_empty() {
            let params: Vec<String> = arg_types
                .iter()
                .enumerate()
                .map(|(i, ty)| format!("arg{}: {}", i, ty))
                .collect();
            let forwarded: Vec<String> = (0..arg_types.len())
                .map(|i| format!("arg{}", i))
                .collect();

            Self::push_fmt(
                code,
                format_args!(" /// Create a new {} instance.", type_name),
            );
            Self::push_fmt(
                code,
                format_args!(" pub fn new({}) -> Self {{", params.join(", ")),
            );
            Self::push_fmt(code, format_args!(" Self({})", forwarded.join(", ")));
            Self::push_line(code, " }");
            code.push('\n');

            // Without the common derives the field types are not `Copy`, so
            // the accessor must hand out a borrow rather than move the field.
            let borrow = if self.derive_common { "" } else { "&" };
            for (i, ty) in arg_types.iter().enumerate() {
                Self::push_fmt(code, format_args!(" /// Get argument {}.", i));
                Self::push_fmt(
                    code,
                    format_args!(" pub fn arg{}(&self) -> {}{} {{", i, borrow, ty),
                );
                Self::push_fmt(code, format_args!(" {}self.{}", borrow, i));
                Self::push_line(code, " }");
                code.push('\n');
            }
        }
        Self::push_line(code, "}");
    }

    /// Emits the `SchemaMetadata` unit struct with the domain count, the
    /// predicate count, and the sum of all domain cardinalities as constants.
    fn generate_schema_metadata(&self, code: &mut String, table: &SymbolTable) {
        let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();

        Self::push_line(code, "/// Schema metadata and statistics.");
        Self::push_line(code, "pub struct SchemaMetadata;");
        code.push('\n');
        Self::push_line(code, "impl SchemaMetadata {");
        Self::push_line(code, " /// Number of domains in the schema.");
        Self::push_fmt(
            code,
            format_args!(" pub const DOMAIN_COUNT: usize = {};", table.domains.len()),
        );
        code.push('\n');
        Self::push_line(code, " /// Number of predicates in the schema.");
        Self::push_fmt(
            code,
            format_args!(
                " pub const PREDICATE_COUNT: usize = {};",
                table.predicates.len()
            ),
        );
        code.push('\n');
        Self::push_line(code, " /// Total cardinality across all domains.");
        Self::push_fmt(
            code,
            format_args!(" pub const TOTAL_CARDINALITY: usize = {};", total_card),
        );
        Self::push_line(code, "}");
    }

    /// Converts an underscore-separated name into a capitalized type name.
    ///
    /// The name is split on `_`; the first character of every segment is
    /// uppercased and the rest of the segment is kept as written, so
    /// `has_parent` becomes `HasParent`. Empty segments (from leading,
    /// trailing, or doubled underscores) contribute nothing. No other
    /// characters are altered or validated.
    pub(super) fn to_type_name(name: &str) -> String {
        let mut type_name = String::with_capacity(name.len());
        for word in name.split('_') {
            let mut rest = word.chars();
            if let Some(first) = rest.next() {
                type_name.extend(first.to_uppercase());
                type_name.push_str(rest.as_str());
            }
        }
        type_name
    }
}