use prelude::*;
use super::common::*;
use std::iter;
use std::collections::BTreeMap;
use idents;
use utils;
use tyhandlers::{Direction, ModelTypeSystem};
use model;
use methodinfo::ComMethodInfo;
extern crate proc_macro;
/// State accumulated across all type-system variants of a single interface.
#[derive(Default)]
struct InterfaceOutput {
    /// `match` arms of the form `<type system> => Some( &<IID constant> )`,
    /// one per processed variant.
    iid_arms : Vec<TokenStream>,

    /// Rust-side method implementations, keyed by the method's display name.
    method_impls : BTreeMap< String, MethodImpl >,
}
/// A COM method plus the generated forwarding body for each type system
/// that declares it.
struct MethodImpl {
    /// Method metadata; taken from the first processed variant that
    /// declares a method with this display name.
    info : ComMethodInfo,

    /// Generated call-forwarding body, one per type system.
    impls : BTreeMap< ModelTypeSystem, TokenStream >,
}
impl MethodImpl {
pub fn new( mi : ComMethodInfo ) -> Self {
MethodImpl {
info: mi,
impls : Default::default(),
}
}
}
/// Expands a `#[com_interface]` attribute applied to a trait or `impl` item.
///
/// The attribute and item are parsed into a `model::ComInterface`. For every
/// type-system variant, `process_itf_variant` emits the IID constant and the
/// vtable struct. For trait items, the trait is also implemented for
/// `::intercom::ComItf< Itf >`, forwarding each method through the COM vtable.
/// Finally, `::intercom::ComInterface` is implemented for the interface type
/// (IID lookup per type system and `deref`).
///
/// The generated items are combined with the original item tokens through
/// `tokens_to_tokenstream`.
///
/// # Errors
///
/// Returns the `model::ParseError` produced by `model::ComInterface::parse`.
pub fn expand_com_interface(
    attr_tokens: TokenStreamNightly,
    item_tokens: TokenStreamNightly,
) -> Result<TokenStreamNightly, model::ParseError>
{
    let mut output = vec![];
    let itf = model::ComInterface::parse(
            &lib_name(),
            attr_tokens.into(),
            &item_tokens.to_string() )?;
    let itf_ident = itf.name();
    let is_trait = itf.item_type() == utils::InterfaceType::Trait;

    // Each variant emits its IID constant and vtable into `output` and
    // records its IID match arm and per-method delegates in `itf_output`.
    let mut itf_output = InterfaceOutput::default();
    for ( &ts, itf_variant ) in itf.variants() {
        process_itf_variant(
                &itf, ts, itf_variant,
                &mut output, &mut itf_output );
    }

    if is_trait {
        // Only the values are needed; the map keys (display names) are
        // already part of `method.info`.
        let impls = itf_output.method_impls.values()
            .map( |method| {
                // One branch per type system, tried in map order. The first
                // type system for which the ComItf holds a pointer handles
                // the call.
                let impl_branches = method.impls.iter()
                    .map( |( ts, method_ts_impl )| {
                        let ts_tokens = ts.as_typesystem_tokens();
                        quote!(
                            if let Some( comptr ) = ComItf::maybe_ptr( self, #ts_tokens ) {
                                #method_ts_impl
                            }
                        )
                    } )
                    .collect::<Vec<_>>();

                let impl_args = method.info.args.iter().map( |ca| {
                    let name = &ca.name;
                    let ty = &ca.ty;
                    quote!( #name : #ty )
                } );

                let method_unsafety = if method.info.is_unsafe { quote!( unsafe ) } else { quote!() };
                let self_arg = &method.info.rust_self_arg;
                let method_rust_ident = &method.info.display_name;
                let return_ty = &method.info.rust_return_ty;

                // If no branch matched, report E_POINTER through the return
                // type's ErrorValue conversion.
                quote!(
                    #method_unsafety fn #method_rust_ident(
                        #self_arg, #( #impl_args ),*
                    ) -> #return_ty {
                        #[allow(unused_imports)]
                        use ::intercom::ComInto;
                        #[allow(unused_imports)]
                        use ::intercom::ErrorValue;

                        #( #impl_branches )*

                        < #return_ty as ::intercom::ErrorValue >::from_com_error(
                                ::intercom::ComError::E_POINTER.into() )
                    }
                )
            } )
            .collect::<Vec<_>>();

        let itf_unsafety = if itf.is_unsafe() { quote!( unsafe ) } else { quote!() };
        output.push( quote!(
            #itf_unsafety impl #itf_ident for ::intercom::ComItf< #itf_ident > {
                #( #impls )*
            }
        ) );
    }

    let iid_arms = itf_output.iid_arms;

    // Trait interfaces deref to the ComItf itself as a trait object. Other
    // interfaces query for IUnknown, treat the returned pointer as the
    // ComBox, release the extra reference taken by the query, and deref the
    // box.
    let ( deref_impl, deref_ret ) = if is_trait {
        (
            quote!( com_itf ),
            quote!( &( #itf_ident + 'static ) )
        )
    } else {
        (
            quote!(
                let some_iunk : &::intercom::ComItf<::intercom::IUnknown> = com_itf.as_ref();
                let iunknown_iid = ::intercom::IUnknown::iid(
                        ::intercom::TypeSystem::Automation )
                            .expect( "IUnknown must have Automation IID" );
                let primary_iunk = some_iunk.query_interface( iunknown_iid )
                        .expect( "All types must implement IUnknown" );

                let combox : *mut ::intercom::ComBox< #itf_ident > =
                        primary_iunk as *mut ::intercom::ComBox< #itf_ident >;
                unsafe {
                    ::intercom::ComBox::release( combox );

                    use std::ops::Deref;
                    (*combox).deref()
                }
            ),
            quote!( & #itf_ident )
        )
    };

    output.push( quote!(
        impl ::intercom::ComInterface for #itf_ident {

            #[doc = "Returns the IID of the requested interface."]
            fn iid( ts : ::intercom::TypeSystem ) -> Option< &'static ::intercom::IID > {
                match ts {
                    #( #iid_arms ),*
                }
            }

            fn deref(
                com_itf : &::intercom::ComItf< #itf_ident >
            ) -> #deref_ret {
                #deref_impl
            }
        }
    ) );

    Ok( tokens_to_tokenstream( item_tokens, output ) )
}
/// Generates the items for one type-system variant of a COM interface.
///
/// Pushes into `output`:
/// - a documented `IID` constant (named via `idents::iid`), and
/// - a `#[repr(C)]` vtable struct (named via `idents::vtable_struct`). If the
///   interface has a base, the struct starts with a `__base` field holding
///   the base vtable, followed by one `unsafe extern` function pointer per
///   method.
///
/// Records in `itf_output`:
/// - the `match` arm mapping `ts` to the IID constant, and
/// - for each method, the Rust-to-COM delegate body for this variant's type
///   system. The method entry is created on first sight and is keyed by
///   display name.
fn process_itf_variant(
    itf : &model::ComInterface,
    ts : ModelTypeSystem,
    itf_variant : &model::ComInterfaceVariant,
    output : &mut Vec<TokenStream>,
    itf_output : &mut InterfaceOutput,
) {
    let itf_ident = itf.name();
    let visibility = itf.visibility();

    let iid_ident = idents::iid( itf_variant.unique_name() );
    let vtable_ident = idents::vtable_struct( itf_variant.unique_name() );

    // IID constant for this variant.
    let iid_tokens = utils::get_guid_tokens( itf_variant.iid() );
    let iid_doc = format!( "`{}` interface ID.", itf_ident );
    output.push( quote!(
        #[doc = #iid_doc]
        #[allow(non_upper_case_globals)]
        #visibility const #iid_ident : ::intercom::IID = #iid_tokens;
    ) );

    let ts_match = ts.as_typesystem_tokens();
    itf_output.iid_arms.push( quote!( #ts_match => Some( & #iid_ident ) ) );

    // The base vtable is pushed first so this vtable begins with the base
    // interface's entries. IUnknown's vtable comes from the runtime crate;
    // other bases use their generated vtable struct name.
    let mut vtbl_fields = vec![];
    if let Some( ref base ) = *itf.base_interface() {
        let vtbl = match base.to_string().as_ref() {
            "IUnknown" => quote!( ::intercom::IUnknownVtbl ),
            _ => { let vtbl = idents::vtable_struct( &base ); quote!( #vtbl ) }
        };
        vtbl_fields.push( quote!( pub __base : #vtbl ) );
    }

    let calling_convention = get_calling_convetion();
    for method_info in itf_variant.methods() {

        // Vtable slot: raw `self` pointer followed by the COM-typed
        // arguments; out/retval arguments are passed as `*mut`.
        let method_ident = &method_info.unique_name;
        let in_out_args = method_info.raw_com_args()
                .into_iter()
                .map( |com_arg| {
                    let name = &com_arg.name;
                    let com_ty = &com_arg.handler.com_ty( com_arg.dir );
                    let dir = match com_arg.dir {
                        Direction::In => quote!(),
                        Direction::Out | Direction::Retval => quote!( *mut )
                    };
                    quote!( #name : #dir #com_ty )
                } );
        let self_arg = quote!( self_vtable : ::intercom::RawComPtr );
        let args = iter::once( self_arg ).chain( in_out_args );

        let ret_ty = method_info.returnhandler.com_ty();
        vtbl_fields.push( quote!(
            pub #method_ident :
                unsafe extern #calling_convention fn( #( #args ),* ) -> #ret_ty
        ) );

        // One entry per display name, shared across variants. The entry API
        // does a single lookup and clones the method info only when the
        // entry is created.
        let method_impl = itf_output.method_impls
                .entry( method_info.display_name.to_string() )
                .or_insert_with( || MethodImpl::new( method_info.clone() ) );
        method_impl.impls.insert(
                itf_variant.type_system(),
                rust_to_com_delegate( itf_variant, method_info, &vtable_ident ) );
    }

    output.push( quote!(
        #[allow(non_camel_case_types)]
        #[repr(C)]
        #[doc(hidden)]
        #visibility struct #vtable_ident { #( #vtbl_fields, )* }
    ) );
}
fn rust_to_com_delegate(
itf_variant : &model::ComInterfaceVariant,
method_info : &ComMethodInfo,
vtable_ident : &Ident,
) -> TokenStream {
let out_arg_declarations = method_info.returnhandler.com_out_args()
.iter()
.map( |ca| {
let ident = &ca.name;
let ty = &ca.handler.com_ty( Direction::Retval );
let default = ca.handler.default_value();
quote!( let mut #ident : #ty = #default; )
} ).collect::<Vec<_>>();
let ( temporaries, params ) : ( Vec<_>, Vec<_> ) = method_info.raw_com_args()
.into_iter()
.map( |com_arg| {
let name = com_arg.name;
match com_arg.dir {
Direction::In => {
let param = com_arg.handler.rust_to_com( &name, Direction::In );
( param.temporary, param.value )
},
Direction::Out | Direction::Retval
=> ( None, quote!( &mut #name ) ),
}
} )
.unzip();
let params = iter::once( quote!( comptr.ptr ) ).chain( params );
let return_ident = Ident::new( "__result", Span::call_site() );
let return_statement = method_info
.returnhandler
.com_to_rust_return( &return_ident );
let method_ident = &method_info.unique_name;
let return_ty = &method_info.rust_return_ty;
let iid_tokens = utils::get_guid_tokens( itf_variant.iid() );
quote!(
let vtbl = comptr.ptr as *const *const #vtable_ident;
#( #temporaries )*
#[allow(unused_unsafe)] let result : Result< #return_ty, ::intercom::ComError > = ( || unsafe {
#( #out_arg_declarations )*
let #return_ident = ((**vtbl).#method_ident)( #( #params ),* );
let INTERCOM_iid = #iid_tokens;
Ok( { #return_statement } )
} )();
return match result {
Ok( v ) => v,
Err( err ) => < #return_ty as ::intercom::ErrorValue >::from_com_error( err ),
};
)
}