use std::collections::HashSet;
use cge::{Network, gene::GeneExtras, Activation};
use proc_macro2::TokenStream;
use quote::quote;
use crate::{recurrence, evaluator, macro_core::{Invocation, CgeType}, numeric_type::NumericType};
/// Token fragments produced by [`synthesize`] for a single CGE network.
///
/// Each field is a self-contained piece of generated code that the macro
/// front-end splices into the network type it emits.
#[derive(Debug, Clone)]
pub struct Synthesis {
    /// Number of recurrent state slots detected in the network
    /// (`0` means the generated network is stateless).
    pub recurrency_count: usize,
    /// `#[doc = ...]` attributes describing how the network was built and
    /// showing how to declare its input/output arrays.
    pub documentation: TokenStream,
    /// The `persistence: [T; N],` field declaration, or empty when stateless.
    pub persistence_field: TokenStream,
    /// The `INPUT_COUNT`, `OUTPUT_COUNT` and `PERSISTENT_SIZE` associated constants.
    pub associated_constants: TokenStream,
    /// Constructor and accessors for the recurrent state, or empty when stateless.
    pub persistence_methods: TokenStream,
    /// The generated `evaluate` function.
    pub evaluate_function: TokenStream,
}
/// Reads and parses the CGE network stored in the file at `cge_path`.
///
/// # Panics
///
/// Panics with the underlying load error if the file cannot be loaded.
fn load_network(cge_path: &str) -> Network {
    Network::load_from_file(cge_path)
        .unwrap_or_else(|err| panic!("Failed to open CGE file ({})", err))
}
/// Builds the path `const_cge::activations::<numeric module>::<function>` that
/// names the activation function for `activation`, specialised to the module
/// selected by `numeric_type`.
fn activation_path(activation: Activation, numeric_type: NumericType) -> TokenStream {
    let module = numeric_type.token();
    let function_name = match activation {
        Activation::Linear => quote!(linear),
        Activation::Threshold => quote!(threshold),
        Activation::Relu => quote!(relu),
        Activation::Sign => quote!(sign),
        Activation::Sigmoid => quote!(sigmoid),
        Activation::Tanh => quote!(tanh),
        Activation::SoftSign => quote!(soft_sign),
        Activation::BentIdentity => quote!(bent_identity),
    };
    quote! { const_cge::activations::#module::#function_name }
}
pub fn synthesize(invocation: &Invocation) -> Synthesis {
let mut network = match invocation.config.cge {
CgeType::File(ref path) => load_network(path),
CgeType::Direct(ref data) => {
Network::from_str(data)
.expect("I've inferred that you're trying to supply CGE data directly as a string, but I can't parse your input as CGE.")
},
CgeType::Module(_) => unreachable!()
};
let mut computations_list = vec![];
let mut computations_end = vec![];
let recurrence_table = recurrence::identify_recurrence(&network);
let size = network.size;
let activation = network.function.clone();
let activation_fn_path = activation_path(activation, invocation.config.numeric_type);
let recurrency_count = recurrence_table.len();
let input_count = {
let mut input_ids = HashSet::new();
network
.genome
.iter()
.filter_map(|g| match g.variant {
GeneExtras::Input(_) => Some(g.id),
_ => None
})
.for_each(|id| { input_ids.insert(id); });
input_ids.len()
};
let output_count = evaluator::evaluate(
&mut network,
0..size,
true,
false,
true,
&mut computations_list,
&mut computations_end,
&mut 0,
&recurrence_table,
invocation.config.numeric_type,
activation_fn_path
).expect("Corrupt CGE: network appears to have no outputs");
let numeric_token = invocation.config.numeric_type.token();
let numeric_bytes = invocation.config.numeric_type.size_of();
let (persistence_field, persistence_methods) = {
if recurrency_count == 0 {
(quote!(), quote!())
} else {
(
quote!(persistence: [#numeric_token; #recurrency_count],),
quote!(
pub fn with_recurrent_state(persistence: &[#numeric_token; #recurrency_count]) -> Self {
Self { persistence: *persistence }
}
pub fn set_recurrent_state(&mut self, persistence: &[#numeric_token; #recurrency_count]) {
self.persistence = persistence.clone();
}
pub fn recurrent_state(&self) -> &[#numeric_token; #recurrency_count] {
&self.persistence
}
pub fn recurrent_state_mut(&mut self) -> &mut [#numeric_token; #recurrency_count] {
&mut self.persistence
}
)
)
}
};
let documentation = {
let build_info = format!(
"{source_statement}- {recurrency_statement}",
source_statement = match invocation.config.cge {
CgeType::File(ref path) => format!("- Compiled from CGE file: `{}`\n", path),
CgeType::Direct(_) => "".into(),
CgeType::Module(_) => "".into()
},
recurrency_statement = if recurrency_count == 0 {
"No recurrency detected
- network is stateless (a ZST)
- `Self::evaluate` is static.".into()
} else {
format!(
"Network is recurrent (stateful)
- {state_count} persistent state{state_plural}: `{byte_count} byte{byte_plural}`
- `Self::evaluate` must take `&mut self`",
state_count = recurrency_count,
state_plural = if recurrency_count == 1 { "" } else { "s" },
byte_count = recurrency_count * numeric_bytes,
byte_plural = if recurrency_count * numeric_bytes == 1 { "" } else { "s" },
)
}
);
let input_declr = format!("let input = [{}];", {
if input_count <= 4 {
(0..input_count).map(|_| "0.").collect::<Vec<&str>>().join(", ")
} else {
format!("0.; {}", input_count)
}
});
let output_declr = format!("let mut output = [{}];", {
if output_count <= 4 {
(0..output_count).map(|_| "0.").collect::<Vec<&str>>().join(", ")
} else {
format!("0.; {}", output_count)
}
});
let network_declr = format!(
"let{mutability} network = Network::default(); // {comment}",
mutability = if recurrency_count == 0 { "" } else { " mut" },
comment = if recurrency_count == 0 { "no recurrency, zero-size type" } else { "recurrent state all zeros" },
);
quote! {
#[doc = #build_info]
#[doc = #input_declr]
#[doc = #output_declr]
#[doc = #network_declr]
}
};
let associated_constants = quote! {
pub const INPUT_COUNT: usize = #input_count;
pub const OUTPUT_COUNT: usize = #output_count;
pub const PERSISTENT_SIZE: usize = #recurrency_count;
};
let evaluate_function = {
let self_argument = if recurrency_count == 0 { quote!() } else { quote!(&mut self,) };
let numeric_comment = format!("- - how fast your target hardware can perform numeric (`{}`) operations", numeric_token);
quote! {
#[doc = #numeric_comment]
pub fn evaluate(#self_argument inputs: &[#numeric_token; #input_count], outputs: &mut [#numeric_token; #output_count]) {
#(#computations_list)*
#(#computations_end)*
}
}
};
Synthesis {
recurrency_count,
documentation,
persistence_field,
associated_constants,
persistence_methods,
evaluate_function,
}
}