use clang::{Entity, EntityKind, EntityVisitResult, TypeKind};
use inflector::Inflector;
/// Pointer representation choice, selected separately for the API and SPI side via
/// `CodegenConfig::api_pointer_style` / `CodegenConfig::spi_pointer_style`.
///
/// NOTE(review): the variant meanings below are inferred from their names; the code
/// that consumes them is not in this file — confirm against the emitter.
#[derive(Clone, Copy, Debug, PartialEq)]
// Only `MutRef` and `OptionRef` are constructed in this file (by `CodegenConfig::default`);
// presumably the other variants are kept as selectable options, hence the lint allowance.
#[allow(dead_code)]
pub enum PointerStyle {
/// Presumably `*const T`.
RawConst,
/// Presumably `*mut T`.
RawMut,
/// Presumably `&mut T`.
MutRef,
/// Presumably `&T`.
Ref,
/// Presumably `Option<&T>`.
OptionRef,
}
/// Options controlling how C++ declarations are translated into Rust.
#[derive(Clone, Debug)]
pub struct CodegenConfig {
/// Convert C++ method names to snake_case (passed to `to_rust_method_name`).
pub method_to_snake: bool,
/// Convert C++ parameter names to snake_case (passed to `to_rust_param_name`).
pub param_to_snake: bool,
/// Pointer style for the API-class side.
/// NOTE(review): not read in this file; presumably consumed by the emitter.
pub api_pointer_style: PointerStyle,
/// Pointer style for the SPI-class side.
/// NOTE(review): not read in this file; presumably consumed by the emitter.
pub spi_pointer_style: PointerStyle,
}
impl Default for CodegenConfig {
fn default() -> Self {
Self {
method_to_snake: true,
param_to_snake: true,
api_pointer_style: PointerStyle::MutRef,
spi_pointer_style: PointerStyle::OptionRef,
}
}
}
/// Which `CThostFtdc*` API/SPI class pair to generate bindings for.
#[derive(Clone, Copy, Debug)]
pub enum ApiKind {
    /// The `CThostFtdcMdApi` / `CThostFtdcMdSpi` pair.
    Md,
    /// The `CThostFtdcTraderApi` / `CThostFtdcTraderSpi` pair.
    Trader,
}

impl ApiKind {
    /// Chooses the `Md` or `Trader` spelling of a name. Every accessor below goes
    /// through this one exhaustive match, so adding a variant only needs one new arm.
    fn pick(self, md: &'static str, trader: &'static str) -> &'static str {
        match self {
            Self::Md => md,
            Self::Trader => trader,
        }
    }

    /// Name of the C++ API class in the header.
    pub fn api_class(&self) -> &'static str {
        self.pick("CThostFtdcMdApi", "CThostFtdcTraderApi")
    }

    /// Name of the C++ SPI class in the header.
    pub fn spi_class(&self) -> &'static str {
        self.pick("CThostFtdcMdSpi", "CThostFtdcTraderSpi")
    }

    /// API class name without the `CThostFtdc` prefix.
    pub fn api_name(&self) -> &'static str {
        self.pick("MdApi", "TraderApi")
    }

    /// SPI class name without the `CThostFtdc` prefix.
    pub fn spi_name(&self) -> &'static str {
        self.pick("MdSpi", "TraderSpi")
    }

    /// Lower-case SPI identifier (`mdspi` / `traderspi`), presumably used as a module name.
    pub fn spi_mod(&self) -> &'static str {
        self.pick("mdspi", "traderspi")
    }
}
/// A method extracted from a C++ class declaration (built by `extract_method_info`).
#[derive(Clone, Debug)]
pub struct MethodInfo {
/// Method name as spelled in the C++ header.
pub cpp_name: String,
/// Rust name; snake_case when `CodegenConfig::method_to_snake` is set.
pub rust_name: String,
/// The method's comment from the header, or an empty string if it has none.
pub comment: String,
/// Parameters in declaration order.
pub params: Vec<ParamInfo>,
/// Return type, classified from libclang's display name.
pub return_type: ReturnType,
/// Whether the C++ method is declared `static`.
pub is_static: bool,
/// Category assigned by `classify_method_kind`.
pub method_kind: MethodKind,
}
/// Category of a method, assigned by `classify_method_kind` from its name, its
/// `static`-ness and its parameters.
#[derive(Clone, Debug, PartialEq)]
pub enum MethodKind {
/// Any method not matched by a more specific variant.
General,
/// The method named `RegisterSpi`.
RegisterSpi,
/// The method named `Release`.
Release,
/// `RegisterFront` or `RegisterNameServer`.
RegisterString,
/// A method whose name contains `Subscribe` and that takes an incomplete-array
/// (`T x[]`) parameter.
Subscribe,
/// A `static` method. `extract_methods` skips static methods, so it never yields
/// a method of this kind.
Static,
}
/// A single method parameter (built by `extract_param_info`).
#[derive(Clone, Debug)]
pub struct ParamInfo {
/// Parameter name as spelled in the C++ header (empty if the parameter is unnamed).
// Not read by the code in this file, hence the lint allowance; presumably kept for
// diagnostics or future use.
#[allow(dead_code)]
pub cpp_name: String,
/// Rust name; snake_case when `CodegenConfig::param_to_snake` is set.
pub rust_name: String,
/// Classified parameter type.
pub kind: ParamKind,
}
/// Classification of a parameter's type, produced by `classify_param_type`.
#[derive(Clone, Debug)]
pub enum ParamKind {
/// `int`, written directly or through a sugared (elaborated) type.
Int,
/// `bool`, written directly or through a sugared (elaborated) type.
Bool,
/// An enum; carries the type's display name as written.
Enum(String),
/// A pointer to `char`.
CharPtr,
/// A pointer to a record, elaborated type or another pointer; carries the pointee's
/// display name. Also the fallback for sugared types whose canonical kind is not
/// otherwise handled (then it carries the sugared type's display name).
StructPtr(String),
/// An incomplete array parameter such as `char *ids[]`.
IncompleteArray,
/// A fixed-size array reached through a sugared type; carries that type's display name.
ConstantArray(String),
}
/// Supported method return types, matched on libclang's display name by
/// `classify_return_type`.
#[derive(Clone, Debug, PartialEq)]
pub enum ReturnType {
/// `void`
Void,
/// `int`
Int,
/// `const char *`
CharPtr,
}
/// Collects every non-static method of the C++ class named `class_name` found under `entity`.
///
/// Without the `sopt` feature only the direct children of `entity` are inspected. With it,
/// the visitor descends into nested entities (presumably because the class is not declared
/// at the top level in those headers — TODO confirm). Methods are returned in the order
/// libclang visits them.
pub fn extract_methods(
    entity: &Entity,
    class_name: &str,
    config: &CodegenConfig,
) -> Vec<MethodInfo> {
    let mut found = Vec::new();
    entity.visit_children(|node, _parent| {
        let is_target_class = node.get_kind() == EntityKind::ClassDecl
            && node.get_name().as_deref() == Some(class_name);
        if is_target_class {
            found.extend(
                node.get_children()
                    .iter()
                    .filter(|member| {
                        member.get_kind() == EntityKind::Method && !member.is_static_method()
                    })
                    .map(|member| extract_method_info(member, config)),
            );
        }
        match cfg!(feature = "sopt") {
            true => EntityVisitResult::Recurse,
            false => EntityVisitResult::Continue,
        }
    });
    found
}
/// Builds a [`MethodInfo`] from a libclang method entity.
///
/// The Rust name honours `config.method_to_snake`; each parameter is converted with
/// `extract_param_info`. A method without a comment gets an empty `comment`.
///
/// # Panics
///
/// Panics if libclang reports no name or no result type for the method, or if the
/// return type or any parameter type is not handled by the classifiers.
fn extract_method_info(m: &Entity, config: &CodegenConfig) -> MethodInfo {
    let cpp_name = m.get_name().expect("method entity has no name");
    let rust_name = to_rust_method_name(&cpp_name, config.method_to_snake);
    let comment = m.get_comment().unwrap_or_default();
    let is_static = m.is_static_method();
    let params: Vec<ParamInfo> = m
        .get_arguments()
        .unwrap_or_default()
        .iter()
        .map(|a| extract_param_info(a, config))
        .collect();
    let result_type = m
        .get_result_type()
        .unwrap_or_else(|| panic!("method {} has no result type", cpp_name));
    let return_type = classify_return_type(&result_type.get_display_name());
    let method_kind = classify_method_kind(&cpp_name, is_static, &params);
    MethodInfo {
        cpp_name,
        rust_name,
        comment,
        params,
        return_type,
        is_static,
        method_kind,
    }
}
/// Converts a libclang parameter entity into a [`ParamInfo`].
///
/// An unnamed parameter gets an empty C++ name. Panics if libclang reports no type for
/// the parameter, or if `classify_param_type` does not handle that type.
fn extract_param_info(e: &Entity, config: &CodegenConfig) -> ParamInfo {
    let original = e.get_name().unwrap_or_default();
    let param_type = e.get_type().unwrap();
    ParamInfo {
        rust_name: to_rust_param_name(&original, config.param_to_snake),
        kind: classify_param_type(&param_type),
        cpp_name: original,
    }
}
/// Classifies a parameter type into a [`ParamKind`].
///
/// Sugared types are classified by their canonical type but keep their written
/// (sugared) display name, so typedef/enum names survive into the generated code.
///
/// # Panics
///
/// Panics on type kinds that are not handled.
fn classify_param_type(tp: &clang::Type) -> ParamKind {
    match tp.get_kind() {
        TypeKind::Int => ParamKind::Int,
        TypeKind::Bool => ParamKind::Bool,
        TypeKind::Enum => ParamKind::Enum(tp.get_display_name()),
        // Newer libclang (16+) wraps named types in `Elaborated`. Older releases report a
        // typedef'd parameter directly as `Typedef`, which previously fell into the panic
        // arm. Both are sugar over the same canonical type.
        TypeKind::Elaborated | TypeKind::Typedef => {
            let canonical = tp.get_canonical_type();
            match canonical.get_kind() {
                TypeKind::Int => ParamKind::Int,
                TypeKind::Bool => ParamKind::Bool,
                TypeKind::Enum => ParamKind::Enum(tp.get_display_name()),
                TypeKind::ConstantArray => ParamKind::ConstantArray(tp.get_display_name()),
                TypeKind::Pointer => classify_pointer_type(&canonical),
                // NOTE(review): every other canonical kind (e.g. a record passed by value,
                // or a typedef of `char`) is reported as `StructPtr` — confirm this is intended.
                _ => ParamKind::StructPtr(tp.get_display_name()),
            }
        }
        TypeKind::Pointer => classify_pointer_type(tp),
        TypeKind::IncompleteArray => ParamKind::IncompleteArray,
        other => panic!("未处理的参数类型: {:?} ({})", other, tp.get_display_name()),
    }
}
/// Classifies a pointer type by the kind of its pointee.
///
/// A pointer to `char` becomes [`ParamKind::CharPtr`]. A pointer to a record, an
/// elaborated type or another pointer becomes [`ParamKind::StructPtr`], carrying the
/// pointee's display name.
///
/// # Panics
///
/// Panics if `tp` is not a pointer type or the pointee kind is not handled.
fn classify_pointer_type(tp: &clang::Type) -> ParamKind {
    let pointee = tp
        .get_pointee_type()
        .expect("classify_pointer_type called on a non-pointer type");
    match pointee.get_kind() {
        // libclang reports plain `char` as `CharS` on targets where it is signed (x86) and
        // as `CharU` where it is unsigned (e.g. ARM Linux). Both are the same C string type.
        TypeKind::CharS | TypeKind::CharU => ParamKind::CharPtr,
        TypeKind::Pointer | TypeKind::Elaborated | TypeKind::Record => {
            ParamKind::StructPtr(pointee.get_display_name())
        }
        other => panic!(
            "未处理的指针目标类型: {:?} ({})",
            other,
            pointee.get_display_name()
        ),
    }
}
/// Maps libclang's display name of a method's return type onto a [`ReturnType`].
///
/// Only `void`, `int` and `const char *` are recognised; any other spelling panics.
fn classify_return_type(display_name: &str) -> ReturnType {
    let recognised = match display_name {
        "const char *" => Some(ReturnType::CharPtr),
        "int" => Some(ReturnType::Int),
        "void" => Some(ReturnType::Void),
        _ => None,
    };
    recognised.unwrap_or_else(|| panic!("未处理的返回类型: {}", display_name))
}
/// Assigns a [`MethodKind`] to a method.
///
/// Rules, in order:
/// - a static method is `Static`;
/// - `RegisterSpi` and `Release` have their own kinds;
/// - `RegisterFront` / `RegisterNameServer` are `RegisterString`;
/// - a name containing `Subscribe` is `Subscribe` only if some parameter is an
///   incomplete array;
/// - everything else is `General`.
fn classify_method_kind(cpp_name: &str, is_static: bool, params: &[ParamInfo]) -> MethodKind {
    let has_array_param = || {
        params
            .iter()
            .any(|param| matches!(param.kind, ParamKind::IncompleteArray))
    };
    match cpp_name {
        _ if is_static => MethodKind::Static,
        "RegisterSpi" => MethodKind::RegisterSpi,
        "Release" => MethodKind::Release,
        "RegisterFront" | "RegisterNameServer" => MethodKind::RegisterString,
        name if name.contains("Subscribe") && has_array_param() => MethodKind::Subscribe,
        _ => MethodKind::General,
    }
}
/// Converts a C++ method name into its Rust spelling.
///
/// With `snake_case` off, the name is returned unchanged. With it on, `UnSub` is first
/// folded to `Unsub` so that, e.g., `UnSubscribeMarketData` becomes
/// `unsubscribe_market_data` instead of splitting off `un` as a separate word.
pub fn to_rust_method_name(cpp_name: &str, snake_case: bool) -> String {
    if !snake_case {
        return cpp_name.to_owned();
    }
    let merged = cpp_name.replace("UnSub", "Unsub");
    merged.to_snake_case()
}
/// Converts a C++ parameter name into its Rust spelling: snake_case when `snake_case`
/// is set, otherwise an unchanged copy.
pub fn to_rust_param_name(cpp_name: &str, snake_case: bool) -> String {
    match snake_case {
        true => cpp_name.to_snake_case(),
        false => String::from(cpp_name),
    }
}