use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashSet;
use std::fs;
use std::process::Command;
use std::str::FromStr;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::visit::{self, Visit};
use syn::{
AttrStyle, Fields, File, GenericArgument, Ident, ItemStruct, Macro, Path, PathArguments, Token,
Type, TypePath,
};
type StructFieldInfo = Vec<(Ident, Option<String>, Type)>;
type WrappedStructInfo = Vec<(Ident, Ident, StructFieldInfo)>;
struct Visitor {
info_wrap: WrappedStructInfo,
pyany_to_operation: Vec<TokenStream>,
operation_to_pyobject: Vec<TokenStream>,
}
impl Visitor {
pub fn new() -> Self {
Self {
info_wrap: Vec::new(),
pyany_to_operation: Vec::new(),
operation_to_pyobject: Vec::new(),
}
}
}
/// Set of names passed as arguments to a derive-style macro, each reduced to
/// the identifier of its last path segment.
#[derive(Debug)]
struct DeriveMacroArguments(HashSet<String>);
impl DeriveMacroArguments {
    /// Returns `true` when `st` is one of the collected argument names.
    pub fn _contains(&self, st: &str) -> bool {
        let DeriveMacroArguments(names) = self;
        names.contains(st)
    }
}
impl Parse for DeriveMacroArguments {
fn parse(input: ParseStream) -> syn::parse::Result<Self> {
let arguments = Punctuated::<Path, Token![,]>::parse_terminated(input)?;
Ok(Self(
arguments
.into_iter()
.map(|p| match p.get_ident() {
Some(id) => id.to_string(),
_ => p
.segments
.last()
.expect("Last path segment can not be accessed")
.ident
.to_string(),
})
.collect(),
))
}
}
impl<'ast> Visit<'ast> for Visitor {
    /// Records every struct carrying an outer `#[wrap]` attribute together with
    /// its wrapper name (`<Struct>Wrapper`) and its field information, then
    /// continues the default traversal into the struct.
    fn visit_item_struct(&mut self, itemstruct: &'ast ItemStruct) {
        // `any` records a struct at most once even if `#[wrap]` is repeated;
        // pushing it twice would produce duplicate arms in the generated matches.
        let is_wrapped = itemstruct
            .attrs
            .iter()
            .any(|att| matches!(att.style, AttrStyle::Outer) && att.path.is_ident("wrap"));
        if is_wrapped {
            let wrapper_ident = format_ident!("{}Wrapper", itemstruct.ident);
            let field_information = extract_fields_with_types(itemstruct.fields.clone());
            self.info_wrap
                .push((itemstruct.ident.clone(), wrapper_ident, field_information));
        }
        visit::visit_item_struct(self, itemstruct);
    }

    /// Stores the token bodies of `insert_pyany_to_operation!` and
    /// `insert_operation_to_pyobject!` invocations. The macro is identified by
    /// the last segment of its path, so qualified invocations are matched too.
    fn visit_macro(&mut self, i: &'ast Macro) {
        if let Some(segment) = i.path.segments.last() {
            let ident = &segment.ident;
            if ident == "insert_pyany_to_operation" {
                self.pyany_to_operation.push(i.tokens.clone());
            } else if ident == "insert_operation_to_pyobject" {
                self.operation_to_pyobject.push(i.tokens.clone());
            }
        }
        visit::visit_macro(self, i);
    }
}
/// Source files scanned for `#[wrap]` structs and `insert_*!` macro
/// invocations. The paths are relative, so they resolve against the working
/// directory the generator is run from.
const SOURCE_FILES: &[&str] = &[
    // single-qubit gates
    "src/operations/single_qubit_gate_operations.rs",
    // pragmas
    "src/operations/pragma_operations.rs",
    // two-qubit gates
    "src/operations/two_qubit_gate_operations.rs",
    // measurements
    "src/operations/measurement_operations.rs",
    // definitions
    "src/operations/define_operations.rs",
];
fn main() {
let mut vis = Visitor::new();
for source_location in SOURCE_FILES {
let source = fs::read_to_string(source_location).expect("Unable to open source file");
let code = proc_macro2::TokenStream::from_str(&source).expect("Could not lex code");
let syntax_tree: File = syn::parse2(code).unwrap();
vis.visit_file(&syntax_tree);
}
let pyany_to_operation_quotes =
vis.info_wrap
.clone()
.into_iter()
.map(|(ident, _wrapper_ident, field_information)| {
let ident_string = ident.to_string();
let arguments = field_information.iter().map(|(id, _, _)| {
quote! {#id}
});
let field_quotes = field_information.iter().map(|(ident, type_string, ty)| {
let pyobject_name = format_ident!("{}_pyobject", ident);
let ident_string = ident.to_string();
match type_string {
Some(type_str) =>
match type_str.as_str(){
"CalculatorFloat" => {quote!{
let #pyobject_name = op
.call_method0(#ident_string)
.map_err(|_| QoqoError::ConversionError)?;
let #ident = convert_into_calculator_float(#pyobject_name).map_err(|_|
QoqoError::ConversionError)?;
}},
"Circuit" => {quote!{
let #pyobject_name = op
.call_method0(#ident_string)
.map_err(|_| QoqoError::ConversionError)?;
let #ident = convert_into_circuit(#pyobject_name).map_err(|_|
QoqoError::ConversionError)?;
}},
"Option<Circuit>" => {quote!{
let #pyobject_name = op
.call_method0(#ident_string)
.map_err(|_| QoqoError::ConversionError)?;
let tmp: Option<&PyAny> = #pyobject_name.extract().map_err(|_|
QoqoError::ConversionError)?;
let #ident = match tmp{
Some(cw) => Some(convert_into_circuit(cw)
.map_err(|_| QoqoError::ConversionError)?),
_ => None
};
}},
_ => {
quote!{
let #pyobject_name = op
.call_method0(#ident_string)
.map_err(|_| QoqoError::ConversionError)?;
let #ident: #ty = #pyobject_name.extract()
.map_err(|_| QoqoError::ConversionError)?;
}}
},
None => {
quote!{
let #pyobject_name = op
.call_method0(#ident_string)
.map_err(|_| QoqoError::ConversionError)?;
let #ident: #ty = #pyobject_name.extract()
.map_err(|_| QoqoError::ConversionError)?;
}
}
}
}
);
quote! {#ident_string => {
#(#field_quotes)*
Ok(#ident::new(#(#arguments),*).into())
}
}
});
let operation_to_pyobject_quotes =
vis.info_wrap
.into_iter()
.map(|(ident, wrapper_ident, _field_information)| {
quote! {
Operation::#ident(internal) => {
let pyref: Py<#wrapper_ident> =
Py::new(py, #wrapper_ident { internal }).unwrap();
let pyobject: PyObject = pyref.to_object(py);
Ok(pyobject)
}
}
});
let operation_to_pyobject_injected_quotes: Vec<TokenStream> = vis.operation_to_pyobject;
let pyany_to_operation_injected_quotes: Vec<TokenStream> = vis.pyany_to_operation;
let final_quote = quote! {
use crate::operations::*;
use crate::QoqoError;
use crate::convert_into_circuit;
use qoqo_calculator::CalculatorFloat;
use qoqo_calculator_pyo3::convert_into_calculator_float;
use pyo3::conversion::ToPyObject;
use pyo3::prelude::*;
use roqoqo::operations::*;
use std::collections::HashMap;
use ndarray::{Array1, Array};
use num_complex::Complex64;
pub fn convert_operation_to_pyobject(operation: Operation) -> PyResult<PyObject> {
let gil = Python::acquire_gil();
let py = gil.python();
match operation {
#(#operation_to_pyobject_quotes),*
#(#operation_to_pyobject_injected_quotes),*
}
}
pub fn convert_pyany_to_operation(op: &PyAny) -> Result<Operation, QoqoError> {
let hqslang_pyobject = op
.call_method0("hqslang")
.map_err(|_| QoqoError::ConversionError)?;
let hqslang: String = String::extract(hqslang_pyobject)
.map_err(|_| QoqoError::ConversionError)?;
match hqslang.as_str() {
#(#pyany_to_operation_quotes),*
#(#pyany_to_operation_injected_quotes),*
_ => Err(QoqoError::ConversionError),
}
}
};
let final_str = format!("{}", final_quote);
fs::write(
"src/operations/_auto_generated_operation_conversion.rs",
final_str,
)
.expect("Could not write to file");
let _unused_output = Command::new("rustfmt")
.arg("src/operations/_auto_generated_operation_conversion.rs")
.output();
}
fn extract_fields_with_types(input_fields: Fields) -> Vec<(Ident, Option<String>, Type)> {
let fields = match input_fields {
Fields::Named(fields) => fields,
_ => panic!("Trait can only be derived on structs with named fields"),
};
fields.named.into_iter().map(|f| {
let id = f
.ident
.expect("Operate can only be derived on structs with named fields");
let ty = f.ty;
let type_path =match &ty {
Type::Path(TypePath{path:p,..}) => p,
_ => panic!("Trait only supports fields with normal types of form path (e.g. CalculatorFloat, qoqo_calculator::CalculatorFloat)")
};
let mut type_string = match type_path.get_ident(){
Some(ident_path) => Some(ident_path.to_string()),
_ => type_path
.segments
.last().map(|segment|{segment.ident.to_string()})
};
if let Some(ref x) = type_string{
if x.as_str() == "Option"{
let inner_type = match &type_path.segments.iter().next().unwrap().arguments{
PathArguments::AngleBracketed(angle_argumnets) => match angle_argumnets.args.iter().next().unwrap() {
GenericArgument::Type(Type::Path(TypePath{path:innerty,..})) => match innerty.get_ident(){
Some(ident_path) => Some(ident_path.to_string()),
_ =>innerty
.segments
.last().map(|segment|{segment.ident.to_string()})
},
_ => panic!("Expected GenericArgument")
},
_ => panic!("Expected AngleBracketed")
};
if let Some(s) = inner_type { if s.as_str() == "Circuit"{
type_string = Some("Option<Circuit>".to_string())
}}}
}
(id, type_string, ty)
}).collect()
}