use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
Expr, Pat, Token,
parse::{Parse, ParseStream},
token::{Comma, FatArrow},
};
/// Parsed contents of a `select!` invocation.
pub struct SelectInput {
    /// Every arm of the select, including any `default` arm, in the order it was written.
    pub cases: Vec<SelectCase>,
}
/// One arm of a `select!`.
///
/// `parse_select_case` produces exactly one of these shapes:
/// - `pat <- channel => body`: `recv` is `Some((pat, channel))` and `send` is `None`;
/// - `channel.send(value) => body`: `send` is `Some((channel, value))` and `recv` is `None`;
/// - `default => body`: both `recv` and `send` are `None`.
pub struct SelectCase {
    /// Pattern bound to the received value, and the channel expression to receive from.
    pub recv: Option<(Pat, Expr)>,
    /// Channel expression to send on, and the value expression to send.
    pub send: Option<(Expr, Expr)>,
    /// Expression evaluated when this arm is the one selected.
    pub body: Expr,
}
impl Parse for SelectInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut cases = Vec::new();
while !input.is_empty() {
let case = parse_select_case(input)?;
cases.push(case);
if input.peek(Comma) {
input.parse::<Comma>()?;
}
}
Ok(SelectInput { cases })
}
}
fn parse_select_case(input: ParseStream) -> syn::Result<SelectCase> {
let lookahead = input.lookahead1();
if lookahead.peek(Token![default]) {
input.parse::<Token![default]>()?;
input.parse::<FatArrow>()?;
let body = input.parse::<Expr>()?;
return Ok(SelectCase {
recv: None,
send: None,
body,
});
}
let fork = input.fork();
if let Ok(_pat) = Pat::parse_single(&fork) {
if fork.peek(Token![<-]) {
let pat = Pat::parse_single(input)?;
input.parse::<Token![<-]>()?;
let channel = input.parse::<Expr>()?;
input.parse::<FatArrow>()?;
let body = input.parse::<Expr>()?;
return Ok(SelectCase {
recv: Some((pat, channel)),
send: None,
body,
});
}
}
let expr_span = input.span(); let expr = input.parse::<Expr>()?;
if let Expr::MethodCall(method) = &expr {
if method.method == "send" {
input.parse::<FatArrow>()?;
let body = input.parse::<Expr>()?;
let channel = *method.receiver.clone();
let value = method.args.first().cloned().ok_or_else(|| {
syn::Error::new(method.method.span(), "send() requires an argument")
})?;
return Ok(SelectCase {
recv: None,
send: Some((channel, value)),
body,
});
}
}
Err(syn::Error::new(
expr_span,
"Expected pattern <- channel or channel.send(value)",
))
}
/// Parses the body of a `select!` invocation and generates its expansion.
///
/// A `default` arm (one with neither a receive nor a send) selects the non-blocking
/// expansion; without one, the blocking expansion is used.
///
/// # Errors
/// Returns the syn parse error, rendered as `"Parse error: …"`, if `input_str` is not a
/// valid sequence of select arms.
pub fn parse_select(input_str: String) -> Result<TokenStream2, String> {
    let parsed = syn::parse_str::<SelectInput>(&input_str)
        .map_err(|err| format!("Parse error: {}", err))?;
    let has_default = parsed
        .cases
        .iter()
        .any(|case| matches!((&case.recv, &case.send), (None, None)));
    Ok(generate_select_impl(parsed.cases, has_default))
}
fn generate_select_impl(cases: Vec<SelectCase>, has_default: bool) -> TokenStream2 {
if has_default {
generate_non_blocking_select(cases)
} else {
generate_blocking_select(cases)
}
}
/// Expands a `select!` that has a `default` arm into a single pass of non-blocking probes.
///
/// Each receive arm becomes a `try_recv()` probe and each send arm a `try_send()` probe.
/// The probes are tried in source order. The first one that succeeds runs its body, and no
/// later probe is attempted. If none succeeds, the `default` body runs.
///
/// The probes are chained with `if … else`, so the expansion is an ordinary expression. It
/// never `return`s from the function that invokes `select!`, and its value is the value of
/// whichever body ran.
///
/// The received value is bound to a private `__select_val` before being destructured into
/// the user's pattern, so arm bodies cannot accidentally see a macro-introduced name.
fn generate_non_blocking_select(cases: Vec<SelectCase>) -> TokenStream2 {
    let mut probes = Vec::with_capacity(cases.len());
    let mut default_body = None;
    for case in cases {
        let body = case.body;
        match (case.recv, case.send) {
            (Some((pat, chan)), None) => probes.push(quote! {
                if let ::std::result::Result::Ok(__select_val) = #chan.try_recv() {
                    let #pat = __select_val;
                    #body
                }
            }),
            (None, Some((chan, val))) => probes.push(quote! {
                if #chan.try_send(#val).is_ok() {
                    #body
                }
            }),
            // If several defaults are supplied, the last one wins.
            (None, None) => default_body = Some(body),
            (Some(_), Some(_)) => unreachable!("a select case is either a receive or a send"),
        }
    }

    // Build `probe0 else probe1 else … else { default }` from the innermost branch outward.
    // With no probes at all, this is just the default block.
    let fallback = quote! { { #default_body } };
    let chain = probes
        .into_iter()
        .rev()
        .fold(fallback, |rest, probe| quote! { #probe else #rest });

    quote! {
        {
            use ::gorust::channel::Selectable;
            #chain
        }
    }
}
/// Expands a `select!` with no `default` arm into a wait for the first arm to complete.
///
/// How the expansion works:
/// - Each receive arm is started on its own `::gorust::go` task, blocked in `recv()`.
/// - Each send arm is started on its own task, blocked in `send()` with a clone of the value.
/// - A task whose operation succeeds reports its arm index over a private
///   `std::sync::mpsc` channel.
/// - The first report received decides which body runs.
///
/// Because every report travels over that one channel, all receive arms must carry the
/// same element type.
///
/// The macro's own sender is dropped before waiting. So when every task finishes without
/// reporting (its `recv()` yielded `None` or its `send()` returned `Err`), the wait ends and
/// no body runs. An arm-less `select! {}` keeps its sender and therefore blocks forever.
///
/// NOTE(review): tasks for arms that lose the race are not cancelled. A losing receive
/// task may still take a value off its channel afterwards. Its report then fails because
/// the result receiver has been dropped, and that value is discarded. Several send arms may
/// also all complete. This differs from Go, where exactly one case proceeds.
fn generate_blocking_select(cases: Vec<SelectCase>) -> TokenStream2 {
    let mut recv_spawns = Vec::new();
    let mut send_spawns = Vec::new();
    let mut branches = Vec::new();
    for (case_id, case) in cases.iter().enumerate() {
        let body = &case.body;
        match (&case.recv, &case.send) {
            (Some((pat, chan)), None) => {
                recv_spawns.push(quote! {
                    {
                        let __tx = __result_tx.clone();
                        let __chan = #chan.clone();
                        let __case_id = #case_id;
                        ::gorust::go(move || {
                            if let Some(__val) = __chan.recv() {
                                let _ = __tx.send((__case_id, Ok(__val)));
                            }
                        });
                    }
                });
                branches.push(quote! {
                    #case_id => {
                        if let Ok(__val) = __result_val {
                            let #pat = __val;
                            #body
                        }
                    }
                });
            }
            (None, Some((chan, val))) => {
                send_spawns.push(quote! {
                    {
                        let __tx = __result_tx.clone();
                        let __chan = #chan.clone();
                        let __val = #val.clone();
                        let __case_id = #case_id;
                        ::gorust::go(move || {
                            if __chan.send(__val).is_ok() {
                                let _ = __tx.send((__case_id, Err(())));
                            }
                        });
                    }
                });
                branches.push(quote! {
                    #case_id => {
                        let _ = __result_val;
                        #body
                    }
                });
            }
            // `default` arms are handled by the non-blocking expansion instead.
            _ => {}
        }
    }

    // Without a receive arm nothing ever sends `Ok(..)`, so `_` could not be inferred;
    // name the payload type `()` explicitly in that case.
    let payload_ty = if recv_spawns.is_empty() {
        quote!(())
    } else {
        quote!(_)
    };
    // Release our own sender before blocking so `recv()` can observe that every task is
    // gone. An arm-less select keeps it, so the wait below never finishes (Go's `select {}`).
    let release_sender = if cases.is_empty() {
        quote!()
    } else {
        quote!(::std::mem::drop(__result_tx);)
    };

    quote! {
        {
            use ::gorust::channel::Selectable;
            // Fully qualified so that no `channel` import leaks into the user's arm bodies.
            let (__result_tx, __result_rx) =
                ::std::sync::mpsc::channel::<(usize, ::std::result::Result<#payload_ty, ()>)>();
            #(#recv_spawns)*
            #(#send_spawns)*
            #release_sender
            if let Ok((__case_id, __result_val)) = __result_rx.recv() {
                match __case_id {
                    #(#branches)*
                    _ => {}
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_select_basic() {
        let input = r#"
            val <- ch1 => {
                println!("Got: {}", val);
            }
        "#;
        let result = parse_select(input.to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_select_with_send() {
        let input = r#"
            ch2.send(42) => {
                println!("Sent!");
            }
        "#;
        let result = parse_select(input.to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_select_with_default() {
        let input = r#"
            val <- ch1 => {
                println!("Got: {}", val);
            },
            default => {
                println!("No op");
            }
        "#;
        let result = parse_select(input.to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_select_multiple_cases() {
        let input = r#"
            val1 <- ch1 => {
                println!("From ch1: {}", val1);
            },
            val2 <- ch2 => {
                println!("From ch2: {}", val2);
            },
            ch3.send(42) => {
                println!("Sent to ch3");
            }
        "#;
        let result = parse_select(input.to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_select_default_only() {
        let result = parse_select("default => {}".to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_default_expansion_uses_non_blocking_probes() {
        let input = r#"
            val <- ch1 => {},
            default => {}
        "#;
        let tokens = parse_select(input.to_string()).expect("valid select input");
        assert!(tokens.to_string().contains("try_recv"));
    }

    #[test]
    fn test_parse_select_send_without_argument_is_rejected() {
        let err = parse_select("ch.send() => {}".to_string()).unwrap_err();
        assert!(err.contains("send() requires an argument"), "{}", err);
    }

    #[test]
    fn test_parse_select_rejects_unknown_case_form() {
        let err = parse_select("1 + 2 => {}".to_string()).unwrap_err();
        assert!(
            err.contains("Expected pattern <- channel or channel.send(value)"),
            "{}",
            err
        );
    }
}